Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
univariate_coefficient_basis.hpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: Complete, auditors: [Nishat], commit: 94f596f8b3bbbc216f9ad7dc33253256141156b2 }
3// external_1: { status: not started, auditors: [], commit: }
4// external_2: { status: not started, auditors: [], commit: }
5// =====================
6
7#pragma once
11#include <span>
12
13namespace bb {
14
// UnivariateCoefficientBasis: a univariate polynomial of degree LENGTH - 1 (LENGTH = domain_end,
// restricted to 2 or 3 by the static_assert below), stored by coefficients in `coefficients`.
//   - LENGTH == 2: P(X) = a0 + a1*X is stored as {a0, a1, a0 + a1} when has_a0_plus_a1 is true;
//     otherwise the third slot is not read by the arithmetic below.
//   - LENGTH == 3: P(X) = a0 + a1*X + a2*X^2 is stored as {a0, a1 + a2, a2}, i.e. coefficients[1]
//     holds the sum of the X and X^2 coefficients (see operator* and sqr()).
// The cached a0 + a1 lets operator* / sqr() reuse the sum instead of recomputing it.
// NOTE(review): this is a Doxygen source listing in which hyperlinked source lines (most member
// signatures and the local `result` / `res` declarations) are missing, as the gaps in the line
// numbers show; consult the original header before relying on any signature or constraint.
39template <class Fr, size_t domain_end, bool has_a0_plus_a1> class UnivariateCoefficientBasis {
40 public:
41 static constexpr size_t LENGTH = domain_end;
42 static_assert(LENGTH == 2 || LENGTH == 3);
43 using value_type = Fr; // used to get the type of the elements consistently with std::array
44
// Always three entries regardless of LENGTH; see the class comment for how each index is read.
52 std::array<Fr, 3> coefficients;
53
55
// Conversion from an object that caches a0 + a1 into one that does not (signature elided; the
// index lists it as taking UnivariateCoefficientBasis<Fr, domain_end, true>). For LENGTH == 2 the
// cached sum is not copied; for LENGTH == 3 all three entries are copied.
57 requires(!has_a0_plus_a1)
58 {
59 coefficients[0] = other.coefficients[0];
60 coefficients[1] = other.coefficients[1];
61 if constexpr (domain_end == 3) {
62 coefficients[2] = other.coefficients[2];
63 }
64 }
65
71
// Widening conversion from a lower-degree object (domain_end > other_domain_end, i.e. LENGTH 2 -> 3).
// a0 and a1 are copied and the X^2 coefficient is zeroed, so coefficients[1] = a1 + 0 remains valid
// for the LENGTH-3 representation; any cached a0 + a1 in `other` is ignored.
72 template <size_t other_domain_end, bool other_has_a0_plus_a1 = true>
74 requires(domain_end > other_domain_end)
75 {
76 coefficients[0] = other.coefficients[0];
77 coefficients[1] = other.coefficients[1];
78 if constexpr (domain_end == 3) {
79 coefficients[2] = 0;
80 }
81 };
82
83 // Operations between UnivariateCoefficientBasis and other UnivariateCoefficientBasis
// Defaulted comparison: compares all three stored entries, including coefficients[2] when it is
// unused (LENGTH == 2 without a cached sum).
// NOTE(review): two equal polynomials may therefore compare unequal if that unused slot differs.
84 bool operator==(const UnivariateCoefficientBasis& other) const = default;
85
// In-place addition (signature elided; the index lists the return type as
// UnivariateCoefficientBasis<Fr, domain_end, false>&).
// NOTE(review): if *this has LENGTH 2 and `other` has LENGTH 3, other's X^2 term is silently dropped
// unless the elided signature constrains that combination; confirm against the original header.
86 template <size_t other_domain_end, bool other_has_a0_plus_a1>
89 {
90 // coefficients[2] is combined only when both operands have LENGTH == 3. For LENGTH == 2 it may
91 // hold a cached a0 + a1, which is not updated here, so the result type has has_a0_plus_a1 = false
92 // and its coefficients[2] must be treated as stale.
93 coefficients[0] += other.coefficients[0];
94 coefficients[1] += other.coefficients[1];
95 if constexpr (other_domain_end == 3 && domain_end == 3) {
96 coefficients[2] += other.coefficients[2];
97 }
98 return *this;
99 }
100
// In-place subtraction (signature elided); mirrors operator+=, with the same LENGTH caveat.
101 template <size_t other_domain_end, bool other_has_a0_plus_a1>
104 {
105 // coefficients[2] is combined only when both operands have LENGTH == 3. For LENGTH == 2 it may
106 // hold a cached a0 + a1, which is not updated here, so the result type has has_a0_plus_a1 = false
107 // and its coefficients[2] must be treated as stale.
108 coefficients[0] -= other.coefficients[0];
109 coefficients[1] -= other.coefficients[1];
110 if constexpr (other_domain_end == 3 && domain_end == 3) {
111 coefficients[2] -= other.coefficients[2];
112 }
113 return *this;
114 }
115
// Product of two degree-1 polynomials (LENGTH == 2), giving a LENGTH-3 result without a cached sum
// (the index lists the return type as UnivariateCoefficientBasis<Fr, 3, false>; the declaration of
// the local `result` is elided). It uses three field multiplications: a0*b0, a1*b1 and
// (a0 + a1) * (b0 + b1).
116 template <bool other_has_a0_plus_a1>
119 requires(LENGTH == 2)
120 {
122 // result.coefficients[0] = a0 * b0;
123 // result.coefficients[2] = a1 * b1 (the X^2 coefficient)
124 result.coefficients[0] = coefficients[0] * other.coefficients[0];
125 result.coefficients[2] = coefficients[1] * other.coefficients[1];
126
127 // This is why a0 + a1 is cached in coefficients[2] for LENGTH == 2 operands:
128 // coefficients[1] = sum of X^2 and X coefficients
129 // (a0 + a1X) * (b0 + b1X) = a0b0 + (a0b1 + a1b0)X + a1b1XX
130 // coefficients[1] = a0b1 + a1b0 + a1b1
131 // which equals (a0 + a1) * (b0 + b1) - a0b0,
132 // reusing the cached sum(s) from coefficients[2] where available.
133 if constexpr (has_a0_plus_a1 && other_has_a0_plus_a1) {
134 result.coefficients[1] = (coefficients[2] * other.coefficients[2] - result.coefficients[0]);
135 } else if constexpr (has_a0_plus_a1 && !other_has_a0_plus_a1) {
136 result.coefficients[1] =
137 coefficients[2] * (other.coefficients[0] + other.coefficients[1]) - result.coefficients[0];
138 } else if constexpr (!has_a0_plus_a1 && other_has_a0_plus_a1) {
139 result.coefficients[1] =
140 (coefficients[0] + coefficients[1]) * other.coefficients[2] - result.coefficients[0];
141 } else {
142 result.coefficients[1] =
143 (coefficients[0] + coefficients[1]) * (other.coefficients[0] + other.coefficients[1]) -
144 result.coefficients[0];
145 }
146 return result;
147 }
148
// Out-of-place addition (signature elided; `res` is presumably initialised from *this, but its
// declaration is elided). It mirrors operator+= and has the same LENGTH caveat.
149 template <size_t other_domain_end, bool other_has_a0_plus_a1>
152 {
154 // coefficients[2] is combined only when both operands have LENGTH == 3. For LENGTH == 2 it may
155 // hold a cached a0 + a1, which is not updated here, so the result type has has_a0_plus_a1 = false
156 // and its coefficients[2] must be treated as stale.
157 res.coefficients[0] += other.coefficients[0];
158 res.coefficients[1] += other.coefficients[1];
159 if constexpr (other_domain_end == 3 && domain_end == 3) {
160 res.coefficients[2] += other.coefficients[2];
161 }
162 return res;
163 }
164
// Out-of-place subtraction (signature and `res` declaration elided); mirrors operator-=.
165 template <size_t other_domain_end, bool other_has_a0_plus_a1>
168 {
170 // coefficients[2] is combined only when both operands have LENGTH == 3. For LENGTH == 2 it may
171 // hold a cached a0 + a1, which is not updated here, so the result type has has_a0_plus_a1 = false
172 // and its coefficients[2] must be treated as stale.
173 res.coefficients[0] -= other.coefficients[0];
174 res.coefficients[1] -= other.coefficients[1];
175 if constexpr (other_domain_end == 3 && domain_end == 3) {
176 res.coefficients[2] -= other.coefficients[2];
177 }
178 return res;
179 }
180
// Unary negation (signature and `res` declaration elided). coefficients[2] is negated only for
// LENGTH == 3; the result does not cache a0 + a1.
182 {
184 res.coefficients[0] = -coefficients[0];
185 res.coefficients[1] = -coefficients[1];
186 if constexpr (domain_end == 3) {
187 res.coefficients[2] = -coefficients[2];
188 }
189
190 return res;
191 }
192
// Square of a degree-1 polynomial (LENGTH == 2), giving a LENGTH-3 result (signature and `result`
// declaration elided). It uses two squarings plus one multiplication on either branch.
194 requires(LENGTH == 2)
195 {
197 result.coefficients[0] = coefficients[0].sqr();
198 result.coefficients[2] = coefficients[1].sqr();
199
200 // (a0 + a1.X)^2 = a0a0 + 2a0a1.X + a1a1.XX
201 // coefficients[0] = a0a0
202 // coefficients[1] = 2a0a1 + a1a1 = (a0 + a0 + a1).a1
203 // coefficients[2] = a1a1
204 // With the cached sum, (a0 + a1) + a0 = 2a0 + a1, so a single multiplication by a1 suffices.
205 if constexpr (has_a0_plus_a1) {
206 result.coefficients[1] = (coefficients[2] + coefficients[0]) * coefficients[1];
207 } else {
208 result.coefficients[1] = coefficients[0] * coefficients[1];
209 result.coefficients[1] += result.coefficients[1];
210 result.coefficients[1] += result.coefficients[2];
211 }
212 return result;
213 }
214
215 // Operations between Univariate and scalar
// Adding or subtracting a scalar changes only the constant term, coefficients[0], in both LENGTH
// representations. The in-place scalar operators update the stored coefficients without touching a
// LENGTH-2 cached a0 + a1, which is why they require !has_a0_plus_a1 (signatures elided).
217 requires(!has_a0_plus_a1)
218 {
219 coefficients[0] += scalar;
220 return *this;
221 }
222
224 requires(!has_a0_plus_a1)
225 {
226 coefficients[0] -= scalar;
227 return *this;
228 }
230 requires(!has_a0_plus_a1)
231 {
232 coefficients[0] *= scalar;
233 coefficients[1] *= scalar;
234 if constexpr (domain_end == 3) {
235 coefficients[2] *= scalar;
236 }
237 return *this;
238 }
239
// Out-of-place scalar operators (signatures and the local `res` declarations elided). The index
// lists their return type as UnivariateCoefficientBasis<Fr, domain_end, false>, i.e. the result does
// not cache a0 + a1.
241 {
243 res += scalar;
244 return res;
245 }
246
248 {
250 res -= scalar;
251 return res;
252 }
253
255 {
257 res.coefficients[0] *= scalar;
258 res.coefficients[1] *= scalar;
259 if constexpr (domain_end == 3) {
260 res.coefficients[2] *= scalar;
261 }
262 return res;
263 }
264
265 // Output is immediately parsable as a list of integers by Python.
// NOTE(review): all three stored entries are printed regardless of LENGTH / has_a0_plus_a1, so for
// LENGTH == 2 without a cached sum the last value is an unused (possibly stale) slot. std::endl
// also flushes the stream after every entry; '\n' would avoid that if flushing is not required.
266 friend std::ostream& operator<<(std::ostream& os, const UnivariateCoefficientBasis& u)
267 {
268 os << "[";
269 os << u.coefficients[0] << "," << std::endl;
270 for (size_t i = 1; i < u.coefficients.size(); i++) {
271 os << " " << u.coefficients[i];
272 if (i + 1 < u.coefficients.size()) {
273 os << "," << std::endl;
274 } else {
275 os << "]";
276 };
277 }
278 return os;
279 }
280};
281
// Deserialization hook (signature elided in this listing): reads the full 3-entry `coefficients`
// array via serialize::read. The array is read in full for every LENGTH / has_a0_plus_a1.
282template <typename B, class Fr, size_t domain_end, bool has_a0_plus_a1>
284{
285 using serialize::read;
286 read(it, univariate.coefficients);
287}
288
// Serialization hook (signature elided in this listing): writes the full 3-entry `coefficients`
// array via serialize::write. The unused slot is written too when LENGTH == 2 without a cached sum.
289template <typename B, class Fr, size_t domain_end, bool has_a0_plus_a1>
291{
292 using serialize::write;
293 write(it, univariate.coefficients);
294}
295
296} // namespace bb
297
298namespace std {
// Reports the tuple size of bb::UnivariateCoefficientBasis<T, N, X> as N (its LENGTH / domain_end),
// not as the 3 entries physically stored in `coefficients`.
299template <typename T, size_t N, bool X>
300struct tuple_size<bb::UnivariateCoefficientBasis<T, N, X>> : std::integral_constant<std::size_t, N> {};
301
302} // namespace std
A view of a univariate, also used to truncate univariates.
bool operator==(const UnivariateCoefficientBasis &other) const =default
friend std::ostream & operator<<(std::ostream &os, const UnivariateCoefficientBasis &u)
UnivariateCoefficientBasis & operator=(const UnivariateCoefficientBasis &other)=default
UnivariateCoefficientBasis< Fr, domain_end, false > operator-(const Fr &scalar) const
UnivariateCoefficientBasis(const UnivariateCoefficientBasis< Fr, other_domain_end, other_has_a0_plus_a1 > &other)
UnivariateCoefficientBasis< Fr, domain_end, false > & operator*=(const Fr &scalar)
UnivariateCoefficientBasis(UnivariateCoefficientBasis &&other) noexcept=default
UnivariateCoefficientBasis< Fr, 3, false > sqr() const
UnivariateCoefficientBasis(const UnivariateCoefficientBasis< Fr, domain_end, true > &other)
UnivariateCoefficientBasis & operator+=(const Fr &scalar)
UnivariateCoefficientBasis< Fr, domain_end, false > operator+(const UnivariateCoefficientBasis< Fr, other_domain_end, other_has_a0_plus_a1 > &other) const
std::array< Fr, 3 > coefficients
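coefficients is a length-3 array with the following representation: for LENGTH == 2, P(X) = a0 + a1*X is stored as {a0, a1, a0 + a1} when has_a0_plus_a1 is true (otherwise the third entry is unused); for LENGTH == 3, P(X) = a0 + a1*X + a2*X^2 is stored as {a0, a1 + a2, a2}.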
UnivariateCoefficientBasis< Fr, domain_end, false > & operator+=(const UnivariateCoefficientBasis< Fr, other_domain_end, other_has_a0_plus_a1 > &other)
UnivariateCoefficientBasis< Fr, domain_end, false > operator-() const
UnivariateCoefficientBasis & operator-=(const Fr &scalar)
UnivariateCoefficientBasis & operator=(UnivariateCoefficientBasis &&other) noexcept=default
UnivariateCoefficientBasis(const UnivariateCoefficientBasis &other)=default
UnivariateCoefficientBasis< Fr, domain_end, false > & operator-=(const UnivariateCoefficientBasis< Fr, other_domain_end, other_has_a0_plus_a1 > &other)
UnivariateCoefficientBasis< Fr, 3, false > operator*(const UnivariateCoefficientBasis< Fr, domain_end, other_has_a0_plus_a1 > &other) const
UnivariateCoefficientBasis< Fr, domain_end, false > operator*(const Fr &scalar) const
UnivariateCoefficientBasis< Fr, domain_end, false > operator-(const UnivariateCoefficientBasis< Fr, other_domain_end, other_has_a0_plus_a1 > &other) const
UnivariateCoefficientBasis< Fr, domain_end, false > operator+(const Fr &scalar) const
Entry point for Barretenberg command-line interface.
Definition api.hpp:5
void read(B &it, field2< base_field, Params > &value)
void write(B &buf, field2< base_field, Params > const &value)
void read(auto &it, msgpack_concepts::HasMsgPack auto &obj)
Automatically derived read for any object that defines .msgpack() (implicitly defined by MSGPACK_FIEL...
void write(auto &buf, const msgpack_concepts::HasMsgPack auto &obj)
Automatically derived write for any object that defines .msgpack() (implicitly defined by MSGPACK_FIE...
STL namespace.
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
Curve::ScalarField Fr