Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
biggroup_nafs.hpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: Complete, auditors: [Suyash], commit: 553c5eb82901955c638b943065acd3e47fc918c0}
3// external_1: { status: not started, auditors: [], commit: }
4// external_2: { status: not started, auditors: [], commit: }
5// =====================
6
7#pragma once
11
13
14template <typename C, class Fq, class Fr, class G>
15template <size_t wnaf_size>
// Computes the offset-encoded wnaf fragment covering the lowest `stagger` bits of a scalar,
// together with the final skew flag of the staggered representation.
//
// NOTE(review): the declaration line (original line 16 — function name, return type and the
// first parameter `fragment_u64`) is elided in this rendering; the call site later in this file
// invokes it as `get_staggered_wnaf_fragment_value<wnaf_size>(fragment, stagger, is_negative,
// skew)` and destructures a (value, skew) pair — confirm the exact signature against the source.
//
// Visible parameters:
//   stagger     - number of low-order scalar bits held in this fragment; must be < 32
//   is_negative - true if the whole scalar has been negated
//   wnaf_skew   - skew flag of the wnaf form of the remaining (shifted-down) scalar bits
// Returns: std::make_pair(offset-encoded fragment value, output skew flag)
17 const uint64_t stagger,
18 bool is_negative,
19 bool wnaf_skew)
20{
21 BB_ASSERT_LT(stagger, 32ULL, "biggroup_nafs: stagger value ≥ 32");
22
23 // If there is no stagger then there is no need to change anything
24 if (stagger == 0) {
25 return std::make_pair(0, wnaf_skew);
26 }
27
28 // Sanity check input fragment
29 BB_ASSERT_LT(fragment_u64, (1ULL << stagger), "biggroup_nafs: fragment value ≥ 2^{stagger}");
30
31 // Convert the fragment to signed int for easier manipulation
32 int fragment = static_cast<int>(fragment_u64);
33
34 // Inverse the fragment if it's negative
35 if (is_negative) {
36 fragment = -fragment;
37 }
38 // If the value is positive and there is a skew in wnaf, subtract 2^{stagger}.
39 if (!is_negative && wnaf_skew) {
40 fragment -= (1 << stagger);
41 }
42
43 // If the value is negative and there is a skew in wnaf, add 2^{stagger}.
44 if (is_negative && wnaf_skew) {
45 fragment += (1 << stagger);
46 }
47
48 // If the lowest bit is zero, then set final skew to 1 and
49 // (i) add 1 to the absolute value of the fragment if it's positive
50 // (ii) subtract 1 from the absolute value of the fragment if it's negative
51 bool output_skew = (fragment_u64 & 1) == 0;
52 if (!is_negative && output_skew) {
53 fragment += 1;
54 } else if (is_negative && output_skew) {
55 fragment -= 1;
56 }
57
58 // Compute raw wnaf value: w = 2e + 1 => e = (w - 1) / 2 => e = ⌊w / 2⌋
59 const int signed_wnaf_value = (fragment / 2);
60 constexpr int wnaf_window_size = (1ULL << (wnaf_size - 1));
// Offset-encode the signed value into an unsigned fragment. For negative `fragment`, C++
// integer division truncates toward zero (e.g. -3/2 == -1), so subtract 1 to recover
// e = (w - 1)/2; non-negative values need no correction.
61 uint64_t output_fragment = 0;
62 if (fragment < 0) {
63 output_fragment = static_cast<uint64_t>(wnaf_window_size + signed_wnaf_value - 1);
64 } else {
65 output_fragment = static_cast<uint64_t>(wnaf_window_size + signed_wnaf_value);
66 }
67
68 return std::make_pair(output_fragment, output_skew);
69}
70
71template <typename C, class Fq, class Fr, class G>
72template <size_t wnaf_size>
// Converts raw packed wnaf values into in-circuit witnesses using the offset encoding
// entry = wnaf_window_size + signed_wnaf_value, so every entry is a non-negative field value.
//
// NOTE(review): the declaration line (original line 73 — function name and return type) is
// elided in this rendering; the call site later in this file invokes it as
// `convert_wnaf_values_to_witnesses<wnaf_size>(...)` and assigns the result to a
// std::vector<field_ct> — confirm against the source.
//
// Parameters:
//   builder              - circuit builder in which the witnesses are created
//   wnaf_values          - raw wnaf values; bit 31 is the sign, low 31 bits the magnitude
//   is_negative          - true if the whole scalar has been negated
//   rounds               - number of wnaf entries to convert
//   range_constrain_wnaf - if true, each entry is range constrained to wnaf_size bits
// Returns: vector of offset-encoded wnaf entries as circuit field elements
74 C* builder, const uint64_t* wnaf_values, bool is_negative, size_t rounds, const bool range_constrain_wnaf)
75{
76 constexpr uint64_t wnaf_window_size = (1ULL << (wnaf_size - 1));
77
78 std::vector<field_ct> wnaf_entries;
79 for (size_t i = 0; i < rounds; ++i) {
80 // Predicate == sign of current wnaf value
81 const bool predicate = (wnaf_values[i] >> 31U) & 1U; // sign bit (32nd bit)
82 const uint64_t wnaf_magnitude = (wnaf_values[i] & 0x7fffffffU); // 31-bit magnitude
83
84 // If the signs of current entry and the whole scalar are the same, then add the magnitude of the
85 // wnaf value to the window size to form an entry. Otherwise, subtract the magnitude along with 1.
86 // The extra 1 is needed to get a uniform representation of (2e' + 1) as explained in the README.
87 uint64_t offset_wnaf_entry = 0;
88 if ((!predicate && !is_negative) || (predicate && is_negative)) {
89 offset_wnaf_entry = wnaf_window_size + wnaf_magnitude;
90 } else {
91 offset_wnaf_entry = wnaf_window_size - wnaf_magnitude - 1;
92 }
93 field_ct wnaf_entry(witness_ct(builder, offset_wnaf_entry));
94
95 // In some cases we may want to skip range constraining the wnaf entries. For example when we use these
96 // entries to lookup in a ROM or regular table, it implicitly enforces the range constraint.
97 if (range_constrain_wnaf) {
98 wnaf_entry.create_range_constraint(wnaf_size, "biggroup_nafs: wnaf_entry is not in range");
99 }
100 wnaf_entries.emplace_back(wnaf_entry);
101 }
102 return wnaf_entries;
103}
104
105template <typename C, class Fq, class Fr, class G>
106template <size_t wnaf_size>
// Reconstructs a bigfield (Fr) scalar from its staggered wnaf representation in-circuit:
// accumulates the offset-encoded wnaf entries, shifts by the stagger, adds the stagger fragment
// and skews, then subtracts the constant negative offset introduced by the offset encoding.
//
// NOTE(review): the declaration line (original line 107 — function name, return type and the
// `builder` parameter) is elided in this rendering; the call site later in this file invokes it
// as `reconstruct_bigfield_from_wnaf<wnaf_size>(builder, ...)` returning Fr — confirm against
// the source.
//
// Visible parameters:
//   wnaf             - offset-encoded wnaf entries, most significant first
//   positive_skew    - skew bit applied when the scalar is non-negated
//   negative_skew    - skew bit applied when the scalar is negated
//   stagger_fragment - offset-encoded fragment of the lowest `stagger` bits
//   stagger          - number of staggered low-order bits
//   rounds           - number of wnaf entries (excluding the stagger fragment)
108 const std::vector<field_t<Builder>>& wnaf,
109 const bool_ct& positive_skew,
110 const bool_ct& negative_skew,
111 const field_t<Builder>& stagger_fragment,
112 const size_t stagger,
113 const size_t rounds)
114{
115 // Collect positive wnaf entries for accumulation
116 std::vector<field_ct> accumulator;
117 for (size_t i = 0; i < rounds; ++i) {
// wnaf[0] is the most significant entry, so entry i is weighted by 2^{i * wnaf_size}
// counting from the least significant end.
118 field_ct entry = wnaf[rounds - 1 - i];
119 entry *= field_ct(uint256_t(1) << (i * wnaf_size));
120 accumulator.emplace_back(entry);
121 }
122
123 // Accumulate entries, shift by stagger and add the stagger itself
124 field_ct sum = field_ct::accumulate(accumulator);
125 sum = sum * field_ct(bb::fr(1ULL << stagger));
126 sum += (stagger_fragment);
127 sum = sum.normalize();
128
129 // Convert this value to bigfield element
130 Fr reconstructed_positive_part =
131 Fr(sum, field_ct::from_witness_index(builder, builder->zero_idx()), /*can_overflow*/ false);
132
133 // Double the final value and add the positive skew
134 reconstructed_positive_part =
135 (reconstructed_positive_part + reconstructed_positive_part)
136 .add_to_lower_limb(field_t<Builder>(positive_skew), /*other_maximum_value*/ uint256_t(1));
137
138 // Start reconstructing the negative part: start with wnaf constant 0xff...ff
139 // See the README for explanation of this constant
140 constexpr uint64_t wnaf_window_size = (1ULL << (wnaf_size - 1));
141 uint256_t negative_constant_wnaf_offset(0);
142 for (size_t i = 0; i < rounds; ++i) {
// Each offset-encoded entry over-counts by (2 * wnaf_window_size - 1) at its bit position.
143 negative_constant_wnaf_offset += uint256_t((wnaf_window_size * 2) - 1) * (uint256_t(1) << (i * wnaf_size));
144 }
145
146 // Shift by stagger
147 negative_constant_wnaf_offset = negative_constant_wnaf_offset << stagger;
148
149 // Add for stagger (if any)
150 if (stagger > 0) {
151 negative_constant_wnaf_offset += ((1ULL << wnaf_size) - 1ULL); // from stagger fragment
152 }
153
154 // Add the negative skew to the bigfield constant
155 Fr reconstructed_negative_part =
156 Fr(nullptr, negative_constant_wnaf_offset).add_to_lower_limb(field_t<Builder>(negative_skew), uint256_t(1));
157
158 // output = x_pos - x_neg (x_pos and x_neg are both non-negative)
159 Fr reconstructed = reconstructed_positive_part - reconstructed_negative_part;
160
161 return reconstructed;
162}
163
164template <typename C, class Fq, class Fr, class G>
165template <size_t num_bits, size_t wnaf_size, size_t lo_stagger, size_t hi_stagger>
// Computes the staggered wnaf representation of one half (klo or khi) of an endomorphism-split
// secp256k1 scalar: derives native wnaf values, converts them to in-circuit witnesses,
// constrains the skews, and reconstructs the scalar as a bigfield element for verification.
//
// NOTE(review): the declaration line (original line 166 — function name and return type) is
// elided in this rendering; the function returns std::make_pair(reconstructed Fr, secp256k1_wnaf)
// — confirm the name against the source.
//
// Parameters:
//   builder              - circuit builder
//   scalar               - native secp256k1 scalar half to decompose
//   stagger              - number of low-order bits handled as a separate fragment
//   is_negative          - true if this scalar half was negated during the endo split
//   range_constrain_wnaf - if true, wnaf entries and the stagger fragment are range constrained
//   is_lo                - true for the low half (uses lo_stagger), false for the high half
167 C* builder,
168 const secp256k1::fr& scalar,
169 size_t stagger,
170 bool is_negative,
171 const bool range_constrain_wnaf,
172 bool is_lo)
173{
174 // The number of rounds is the minimal required to cover the whole scalar with wnaf_size windows
175 constexpr size_t num_rounds = ((num_bits + wnaf_size - 1) / wnaf_size);
176
177 // Stagger mask is needed to retrieve the lowest bits that will not be used in montgomery ladder directly
178 const uint64_t stagger_mask = (1ULL << stagger) - 1;
179
180 // Stagger scalar represents the lower "staggered" bits that are not used in the ladder
181 const uint64_t stagger_scalar = scalar.data[0] & stagger_mask;
182
183 std::array<uint64_t, num_rounds> wnaf_values = { 0 };
184 bool skew_without_stagger = false;
// Compute the native wnaf of the scalar with the staggered bits shifted out.
185 uint256_t k_u256{ scalar.data[0], scalar.data[1], scalar.data[2], scalar.data[3] };
186 k_u256 = k_u256 >> stagger;
187 if (is_lo) {
188 bb::wnaf::fixed_wnaf<num_bits - lo_stagger, 1, wnaf_size>(
189 &k_u256.data[0], &wnaf_values[0], skew_without_stagger, 0);
190 } else {
191 bb::wnaf::fixed_wnaf<num_bits - hi_stagger, 1, wnaf_size>(
192 &k_u256.data[0], &wnaf_values[0], skew_without_stagger, 0);
193 }
194
195 // Number of rounds that are needed to reconstruct the scalar without staggered bits
196 const size_t num_rounds_excluding_stagger_bits = ((num_bits + wnaf_size - 1 - stagger) / wnaf_size);
197
198 // Compute the stagger-related fragment and the final skew due to the same
199 const auto [first_fragment, skew] =
200 get_staggered_wnaf_fragment_value<wnaf_size>(stagger_scalar, stagger, is_negative, skew_without_stagger);
201
202 // Get wnaf witnesses
203 // Note that we only range constrain the wnaf entries if range_constrain_wnaf is set to true.
204 std::vector<field_ct> wnaf = convert_wnaf_values_to_witnesses<wnaf_size>(
205 builder, &wnaf_values[0], is_negative, num_rounds_excluding_stagger_bits, range_constrain_wnaf);
206
207 // Compute and constrain skews
// The skew is attributed to the positive or negative side depending on the sign of the scalar.
208 bool_ct negative_skew(witness_ct(builder, is_negative ? 0 : skew), /*use_range_constraint*/ true);
209 bool_ct positive_skew(witness_ct(builder, is_negative ? skew : 0), /*use_range_constraint*/ true);
210
211 // Enforce that both positive_skew, negative_skew are not set at the same time
212 bool_ct both_skews_cannot_be_one = !(positive_skew & negative_skew);
213 both_skews_cannot_be_one.assert_equal(
214 bool_ct(builder, true), "biggroup_nafs: both positive and negative skews cannot be set at the same time");
215
216 // Initialize stagger witness
217 field_ct stagger_fragment = witness_ct(builder, first_fragment);
218
219 // We only range constrain the stagger fragment if range_constrain_wnaf is set. This is because in some cases
220 // we may use the stagger fragment to lookup in a ROM/regular table, which implicitly enforces the range constraint.
221 if (range_constrain_wnaf) {
222 stagger_fragment.create_range_constraint(wnaf_size, "biggroup_nafs: stagger fragment is not in range");
223 }
224
225 // Reconstruct the bigfield scalar from (wnaf + stagger) representation
226 Fr reconstructed = reconstruct_bigfield_from_wnaf<wnaf_size>(
227 builder, wnaf, positive_skew, negative_skew, stagger_fragment, stagger, num_rounds_excluding_stagger_bits);
228
229 secp256k1_wnaf wnaf_out{ .wnaf = wnaf,
230 .positive_skew = positive_skew,
231 .negative_skew = negative_skew,
232 .least_significant_wnaf_fragment = stagger_fragment,
233 .has_wnaf_fragment = (stagger > 0) };
234
235 return std::make_pair(reconstructed, wnaf_out);
236}
237
325template <typename C, class Fq, class Fr, class G>
326template <size_t wnaf_size, size_t lo_stagger, size_t hi_stagger>
// Splits a secp256k1 scalar via the curve endomorphism into two short halves (klo, khi),
// computes a staggered wnaf representation for each, reconstructs the original scalar from
// both halves in-circuit, and asserts it equals the input scalar.
//
// NOTE(review): several original lines are elided in this rendering: line 327 (function name /
// return type), lines 330-355 (presumably the explanatory comment block) — confirm against the
// source. Returns a struct aggregating the two wnaf representations as { .klo, .khi }.
328 const Fr& scalar, const bool range_constrain_wnaf)
329{
356 C* builder = scalar.get_context();
357
358 constexpr size_t num_bits = 129;
359
360 // Decomposes the scalar k into two 129-bit scalars klo, khi such that
361 // k = klo + ζ * khi (mod n)
362 // = klo - λ * khi (mod n)
363 // where ζ is the primitive sixth root of unity mod n, and λ is the primitive cube root of unity mod n
364 // (note that ζ = -λ). We know that for any scalar k, such a decomposition exists and klo and khi are 128-bits long.
365 secp256k1::fr k(uint256_t(scalar.get_value() % Fr::modulus_u512));
366 secp256k1::fr klo(0);
367 secp256k1::fr khi(0);
368 bool klo_negative = false;
369 bool khi_negative = false;
// NOTE(review): original line 370 is elided in this rendering — presumably the call
// `secp256k1::fr::split_into_endomorphism_scalars(k, klo, khi);` that populates klo/khi
// (the symbol index of this file references that function) — confirm against the source.
371
372 // The low and high scalars must be less than 2^129 in absolute value. In some cases, the klo or khi value
373 // is returned as negative, in which case we negate it and set a flag to indicate this.
374 if (khi.uint256_t_no_montgomery_conversion().get_msb() >= 129) {
375 khi_negative = true;
376 khi = -khi;
377 }
378 if (klo.uint256_t_no_montgomery_conversion().get_msb() >= 129) {
379 klo_negative = true;
380 klo = -klo;
381 }
382
383 BB_ASSERT_LT(klo.uint256_t_no_montgomery_conversion().get_msb(), 129ULL, "biggroup_nafs: klo > 129 bits");
384 BB_ASSERT_LT(khi.uint256_t_no_montgomery_conversion().get_msb(), 129ULL, "biggroup_nafs: khi > 129 bits");
385
// NOTE(review): original lines 387 and 391 (the name of the per-half wnaf helper being called)
// are elided in this rendering — confirm against the source.
386 const auto [klo_reconstructed, klo_out] =
388 builder, klo, lo_stagger, klo_negative, range_constrain_wnaf, true);
389
390 const auto [khi_reconstructed, khi_out] =
392 builder, khi, hi_stagger, khi_negative, range_constrain_wnaf, false);
393
// -λ is a 256-bit constant, so it is split into two limb chunks at bit 136 to build the
// bigfield constant.
394 uint256_t minus_lambda_val(-secp256k1::fr::cube_root_of_unity());
395 Fr minus_lambda(bb::fr(minus_lambda_val.slice(0, 136)), bb::fr(minus_lambda_val.slice(136, 256)), false);
396
// k = klo - λ * khi (mod n): recombine the two reconstructed halves.
397 Fr reconstructed_scalar = khi_reconstructed.madd(minus_lambda, { klo_reconstructed });
398
399 // Constant scalars are always reduced mod n by design (scalar < n), however
400 // the reconstructed_scalar may be larger than n as it's a witness. So we need to
401 // reduce the reconstructed_scalar mod n explicitly to match the original scalar.
402 // This is necessary for assert_equal to pass.
403 if (scalar.is_constant()) {
404 reconstructed_scalar.self_reduce();
405 }
406
407 // Validate that the reconstructed scalar matches the original scalar in circuit
408 scalar.assert_equal(reconstructed_scalar, "biggroup_nafs: reconstructed scalar does not match reduced input");
409
410 return { .klo = klo_out, .khi = khi_out };
411}
412
413template <typename C, class Fq, class Fr, class G>
// Computes the in-circuit non-adjacent-form (NAF) representation of a scalar.
// Produces (num_rounds + 1) boolean witnesses: entries [0, num_rounds) encode the NAF digits
// (false = +1, true = -1, most significant first) and entry [num_rounds] is the skew bit.
// The representation is verified in-circuit by reconstructing the scalar from the NAF entries
// and asserting equality with the input.
//
// @param scalar       the scalar to decompose (reduced mod the scalar field modulus first)
// @param max_num_bits number of NAF rounds; 0 means use the full modulus bit length
// @return vector of (num_rounds + 1) boolean NAF entries (last one is the skew)
414std::vector<bool_t<C>> element<C, Fq, Fr, G>::compute_naf(const Fr& scalar, const size_t max_num_bits)
415{
416 // Get the circuit builder
417 C* builder = scalar.get_context();
418
419 // To compute the NAF representation, we first reduce the scalar modulo r (the scalar field modulus).
420 uint512_t scalar_multiplier_512 = uint512_t(scalar.get_value()) % uint512_t(Fr::modulus);
421 uint256_t scalar_multiplier = scalar_multiplier_512.lo;
422
423 // Number of rounds is either the max_num_bits provided, or the full size of the scalar field modulus.
424 // If the scalar is zero, we use the full size of the scalar field modulus as we use scalar = r in this case.
425 const size_t num_rounds = (max_num_bits == 0 || scalar_multiplier == 0) ? Fr::modulus.get_msb() + 1 : max_num_bits;
426
427 // NAF can't handle 0 so we set scalar = r in this case.
428 if (scalar_multiplier == 0) {
429 scalar_multiplier = Fr::modulus;
430 }
431
432 // NAF representation consists of num_rounds bits and a skew bit.
433 // Given a scalar k, we compute the NAF representation as follows:
434 //
435 // k = -skew + ₀∑ⁿ⁻¹ (1 - 2 * naf_i) * 2^i
436 //
437 // where naf_i = (1 - k_{i + 1}) ∈ {0, 1} and k_{i + 1} is the (i + 1)-th bit of the scalar k.
438 // If naf_i = 0, then the i-th NAF entry is +1, otherwise it is -1. See the README for more details.
439 //
440 std::vector<bool_ct> naf_entries(num_rounds + 1);
441
442 // If the scalar is even, we set the skew flag to true and add 1 to the scalar.
443 // Sidenote: we apply range constraints to the boolean witnesses instead of full 1-bit range gates.
444 const bool skew_value = !scalar_multiplier.get_bit(0);
445 scalar_multiplier += uint256_t(static_cast<uint64_t>(skew_value));
446 naf_entries[num_rounds] = bool_ct(witness_ct(builder, skew_value), /*use_range_constraint*/ true);
447
448 // We need to manually propagate the origin tag
449 naf_entries[num_rounds].set_origin_tag(scalar.get_origin_tag());
450
451 for (size_t i = 0; i < num_rounds - 1; ++i) {
452 // If the next entry is false, we need to flip the sign of the current entry (naf_entry := (1 - next_bit)).
453 // Apply a basic range constraint per bool, and not a full 1-bit range gate. Results in ~`num_rounds`/4 gates
454 // per scalar.
455 const bool next_entry = scalar_multiplier.get_bit(i + 1);
456 naf_entries[num_rounds - i - 1] = bool_ct(witness_ct(builder, !next_entry), /*use_range_constraint*/ true);
457
458 // We need to manually propagate the origin tag
459 naf_entries[num_rounds - i - 1].set_origin_tag(scalar.get_origin_tag());
460 }
461
462 // The most significant NAF entry is always (+1) as we are working with scalars < 2^{max_num_bits}.
463 // Recall that true represents (-1) and false represents (+1).
464 naf_entries[0] = bool_ct(witness_ct(builder, false), /*use_range_constraint*/ true);
465 naf_entries[0].set_origin_tag(scalar.get_origin_tag());
466
467 // validate correctness of NAF
468 if constexpr (!Fr::is_composite) {
// Native-field case: reconstruct k = -skew + Σ (1 - 2 * naf_i) * 2^i directly in Fr.
469 std::vector<Fr> accumulators;
470 for (size_t i = 0; i < num_rounds; ++i) {
471 // bit = 1 - 2 * naf
472 Fr entry(naf_entries[num_rounds - i - 1]);
473 entry *= -2;
474 entry += 1;
475 entry *= static_cast<Fr>(uint256_t(1) << (i));
476 accumulators.emplace_back(entry);
477 }
478 accumulators.emplace_back(-Fr(naf_entries[num_rounds])); // -skew
479 Fr accumulator_result = Fr::accumulate(accumulators);
480 scalar.assert_equal(accumulator_result);
481 } else {
// Bigfield case: accumulate the positive (+1) and negative (-1) NAF contributions separately
// in two native-field halves (hi/lo), then recombine as bigfield elements.
482 const auto reconstruct_half_naf = [](bool_ct* nafs, const size_t half_round_length) {
483 field_ct negative_accumulator(0);
484 field_ct positive_accumulator(0);
485 for (size_t i = 0; i < half_round_length; ++i) {
486 negative_accumulator = negative_accumulator + negative_accumulator + field_ct(nafs[i]);
487 positive_accumulator = positive_accumulator + positive_accumulator + field_ct(1) - field_ct(nafs[i]);
488 }
489 return std::make_pair(positive_accumulator, negative_accumulator);
490 };
491
492 std::pair<field_ct, field_ct> hi_accumulators;
493 std::pair<field_ct, field_ct> lo_accumulators;
494
495 if (num_rounds > Fr::NUM_LIMB_BITS * 2) {
496 const size_t midpoint = num_rounds - (Fr::NUM_LIMB_BITS * 2);
497 hi_accumulators = reconstruct_half_naf(&naf_entries[0], midpoint);
498 lo_accumulators = reconstruct_half_naf(&naf_entries[midpoint], num_rounds - midpoint);
499 } else {
500 // If the number of rounds is ≤ (2 * Fr::NUM_LIMB_BITS), the high bits of the resulting Fr element are 0.
// NOTE(review): original line 501 is elided in this rendering — presumably the declaration
// of `zero` (a field element for the builder's zero_idx, per the comment below) used in the
// next lines — confirm against the source.
502 // The zero_idx is a constant zero, so set the CONSTANT tag to allow merging with origin-tagged elements
503 auto const_tag = OriginTag::constant();
504 zero.set_origin_tag(const_tag);
505 lo_accumulators = reconstruct_half_naf(&naf_entries[0], num_rounds);
506 hi_accumulators = std::make_pair(zero, zero);
507 }
508
509 // Add the skew bit to the low accumulator's negative part.
510 // This addition can produce exactly 2^136 if negative accumulator is 2^136-1 and skew is 1.
511 // When this happens, we need to carry the overflow to the high bits.
512 field_ct lo_neg_with_skew = lo_accumulators.second + field_ct(naf_entries[num_rounds]);
513
514 // Detect if we hit exactly 2^136 (the only overflow case possible)
515 const uint256_t two_pow_136 = uint256_t(1) << (Fr::NUM_LIMB_BITS * 2);
516 field_ct overflow_check = lo_neg_with_skew - field_ct(two_pow_136);
517 bool_ct has_overflow = overflow_check.is_zero();
518
519 // If overflow: set lo to 0, carry 1 to hi_neg
520 // If no overflow: keep lo_neg_with_skew, hi_neg unchanged
521 lo_accumulators.second = lo_neg_with_skew * field_ct(!has_overflow);
522 hi_accumulators.second = hi_accumulators.second + field_ct(has_overflow);
523
524 Fr reconstructed_positive = Fr(lo_accumulators.first, hi_accumulators.first);
525 Fr reconstructed_negative = Fr(lo_accumulators.second, hi_accumulators.second);
526 Fr accumulator = reconstructed_positive - reconstructed_negative;
527
528 // Constant scalars are always reduced mod n by design (scalar < n), however
529 // the reconstructed accumulator may be larger than n as its a witness. So we need to
530 // reduce the reconstructed accumulator mod n explicitly to match the original scalar.
531 // This is necessary for assert_equal to pass.
532 if (scalar.is_constant()) {
533 accumulator.self_reduce();
534 }
535
536 // Validate that the reconstructed scalar matches the original scalar in circuit
537 accumulator.assert_equal(scalar);
538 }
539
540 // Propagate tags to naf
541 const auto original_tag = scalar.get_origin_tag();
542 for (auto& naf_entry : naf_entries) {
543 naf_entry.set_origin_tag(original_tag);
544 }
545 return naf_entries;
546}
547} // namespace bb::stdlib::element_default
#define BB_ASSERT_LT(left, right,...)
Definition assert.hpp:143
constexpr bool get_bit(uint64_t bit_index) const
constexpr uint256_t slice(uint64_t start, uint64_t end) const
constexpr uint64_t get_msb() const
Implements boolean logic in-circuit.
Definition bool.hpp:60
void assert_equal(const bool_t &rhs, std::string const &msg="bool_t::assert_equal") const
Implements copy constraint for bool_t elements.
Definition bool.cpp:433
static field_t from_witness_index(Builder *ctx, uint32_t witness_index)
Definition field.cpp:67
static field_t accumulate(const std::vector< field_t > &input)
Efficiently compute the sum of vector entries. Using big_add_gate we reduce the number of gates neede...
Definition field.cpp:1178
void create_range_constraint(size_t num_bits, std::string const &msg="field_t::range_constraint") const
Let x = *this.normalize(), constrain x.v < 2^{num_bits}.
Definition field.cpp:919
bool_t< Builder > is_zero() const
Validate whether a field_t element is zero.
Definition field.cpp:783
void set_origin_tag(const OriginTag &new_tag) const
Definition field.hpp:357
AluTraceBuilder builder
Definition alu.test.cpp:124
stdlib::witness_t< Builder > witness_ct
stdlib::field_t< Builder > field_ct
constexpr T get_msb(const T in)
Definition get_msb.hpp:49
uintx< uint256_t > uint512_t
Definition uintx.hpp:306
void fixed_wnaf(const uint64_t *scalar, uint64_t *wnaf, bool &skew_map, const uint64_t point_index, const uint64_t num_points, const size_t wnaf_bits) noexcept
Performs fixed-window non-adjacent form (WNAF) computation for scalar multiplication.
Definition wnaf.hpp:117
Inner sum(Cont< Inner, Args... > const &in)
Definition container.hpp:70
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
Curve::ScalarField Fr
static OriginTag constant()
static constexpr field cube_root_of_unity()
static constexpr uint256_t modulus
static void split_into_endomorphism_scalars(const field &k, field &k1, field &k2)
Full-width endomorphism decomposition: k ≡ k1 - k2·λ (mod r). Modifies the field elements k1 and k2.
constexpr uint256_t uint256_t_no_montgomery_conversion() const noexcept
BB_INLINE constexpr field from_montgomery_form() const noexcept