// Fragment of a scratch-space accessor (used below as `get_scratch_space<Fr>(n)`):
// lazily grows a cached buffer so repeated calls reuse one allocation.
// NOTE(review): leading integers are original-file line numbers left behind by
// extraction (not code), and original lines 25/27 (the allocation itself) are
// missing from this view — TODO restore from the original source.
23 static size_t current_size = 0;
// Only (re)allocate when the request exceeds the cached capacity.
24 if (num_elements > current_size) {
26 current_size = num_elements;
// `working_memory` is presumably the static/thread-local backing store whose
// declaration sits on a missing line — verify against the original file.
28 return working_memory;
/**
 * @brief Reverse the lowest `bit_length` bits of `x`.
 *
 * Performs a full 32-bit reversal via the classic mask-and-swap ladder
 * (swap adjacent bits, then pairs, nibbles, bytes, half-words), then shifts
 * right so only the requested `bit_length` window remains.
 *
 * @param x value whose bits are reversed
 * @param bit_length number of low bits that are significant (1..32;
 *        bit_length == 0 would shift by 32, which is undefined behaviour)
 * @return x with its low `bit_length` bits reversed
 *
 * NOTE(review): the function header was reconstructed from the declaration
 * dump in this file; the extraction had stripped it and fused original line
 * numbers into the statements.
 */
uint32_t reverse_bits(uint32_t x, uint32_t bit_length)
{
    x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1));
    x = (((x & 0xcccccccc) >> 2) | ((x & 0x33333333) << 2));
    x = (((x & 0xf0f0f0f0) >> 4) | ((x & 0x0f0f0f0f) << 4));
    x = (((x & 0xff00ff00) >> 8) | ((x & 0x00ff00ff) << 8));
    // Final half-word swap completes the 32-bit reversal; discard the unused
    // high bits so the result occupies the low `bit_length` bits.
    return (((x >> 16) | (x << 16))) >> (32 - bit_length);
}
/**
 * @brief Return true iff `x` is a nonzero power of two.
 *
 * A power of two has exactly one set bit, so clearing the lowest set bit
 * (`x & (x - 1)`) yields zero; the leading `x &&` rejects zero itself.
 *
 * NOTE(review): the `constexpr bool is_power_of_two(uint64_t)` header was
 * reconstructed from the declaration dump in this file; the extraction had
 * stripped it and fused an original line number into the statement.
 */
constexpr bool is_power_of_two(uint64_t x)
{
    return x && !(x & (x - 1));
}
// Fragment of scale_by_generator(coeffs, target, domain, generator_start,
// generator_shift, generator_size): writes
//   target[i] = coeffs[i] * generator_start * generator_shift^i
// with the index range split evenly across domain.num_threads threads.
// NOTE(review): leading integers are extraction line-number artifacts; the
// original lines 54-55 (opening brace and the parallel_for header that binds
// the thread index `j`) are missing from this view.
51 const Fr& generator_start,
52 const Fr& generator_shift,
53 const size_t generator_size)
// Each thread jumps its running generator to the start of its slice:
// generator_start * generator_shift^(offset), offset = j * slice_size.
56 Fr thread_shift = generator_shift.pow(static_cast<uint64_t>(j * (generator_size / domain.num_threads)));
57 Fr work_generator = generator_start * thread_shift;
58 const size_t offset = j * (generator_size / domain.num_threads);
59 const size_t end = offset + (generator_size / domain.num_threads);
// Scale each coefficient in the slice, advancing the generator one shift per step.
60 for (size_t i = offset; i < end; ++i) {
61 target[i] = coeffs[i] * work_generator;
62 work_generator *= generator_shift;
// Fragment of fft_inner_parallel: per-thread bit-reversal permutation fused
// with the first radix-2 butterfly round (twiddle factor 1, so the pair is
// combined as sum/difference).
// NOTE(review): leading integers are extraction line-number artifacts; the
// original lines 80 and 83 (presumably a blank line and the declarations of
// temp_1/temp_2) are missing from this view.
68 requires SupportsFFT<Fr>
75 for (size_t i = (j * domain.thread_size); i < ((j + 1) * domain.thread_size); i += 2) {
// Prefetch the bit-reversed sources of the NEXT iteration to hide the random
// access pattern's memory latency.
76 uint32_t next_index_1 = (uint32_t)reverse_bits((uint32_t)i + 2, (uint32_t)domain.log2_size);
77 uint32_t next_index_2 = (uint32_t)reverse_bits((uint32_t)i + 3, (uint32_t)domain.log2_size);
78 __builtin_prefetch(&coeffs[next_index_1]);
79 __builtin_prefetch(&coeffs[next_index_2]);
// Load the two bit-reversed inputs for positions i and i+1.
81 uint32_t swap_index_1 = (uint32_t)reverse_bits((uint32_t)i, (uint32_t)domain.log2_size);
82 uint32_t swap_index_2 = (uint32_t)reverse_bits((uint32_t)i + 1, (uint32_t)domain.log2_size);
84 Fr::__copy(coeffs[swap_index_1], temp_1);
85 Fr::__copy(coeffs[swap_index_2], temp_2);
// First-round butterfly with root == 1: (a, b) -> (a + b, a - b).
86 target[i + 1] = temp_1 - temp_2;
87 target[i] = temp_1 + temp_2;
// Fragment of the FFT's main round loop: the butterfly block size `m` doubles
// each round until it reaches domain.size.
// NOTE(review): leading integers are extraction line-number artifacts; many
// original lines (93-102, 105-123, 125-133, 135-138, 140-143) are missing
// from this view, including the per-thread dispatch and the declaration of
// `temp`.
92 for (
size_t m = 2; m < (domain.size); m <<= 1) {
// Each thread processes half its slice of indices: one butterfly touches two
// elements, hence the >> 1.
103 const size_t start = j * (domain.thread_size >> 1);
104 const size_t end = (j + 1) * (domain.thread_size >> 1);
// block_mask extracts the butterfly's position j1 within a block of size m;
// index_mask extracts the block number (k1 addresses the block's base, * 2).
124 const size_t block_mask = m - 1;
134 const size_t index_mask = ~block_mask;
// Twiddle factors for this round live at root_table[log2(m) - 1].
139 const Fr* round_roots = root_table[static_cast<size_t>(numeric::get_msb(m)) - 1];
144 for (size_t i = start; i < end; ++i) {
145 size_t k1 = (i & index_mask) << 1;
146 size_t j1 = i & block_mask;
// Classic radix-2 butterfly: (a, b) -> (a + w*b, a - w*b) with w = round_roots[j1].
147 temp = round_roots[j1] * target[k1 + j1 + m];
148 target[k1 + j1 + m] = target[k1 + j1] - temp;
149 target[k1 + j1] += temp;
// Fragment of ifft(coeffs, target, domain): after the inverse-FFT butterflies
// (on missing lines), every output coefficient is scaled by
// domain.domain_inverse (presumably 1/domain.size — verify in the original),
// with each thread scaling its own contiguous slice.
// NOTE(review): leading integers are extraction line-number artifacts; the
// original lines 157-161 (signature body, braces, parallel_for header binding
// `j`) are missing from this view.
155template <
typename Fr>
156 requires SupportsFFT<Fr>
162 const size_t start = j * domain.thread_size;
163 const size_t end = (j + 1) * domain.thread_size;
164 for (size_t i = start; i < end; ++i) {
165 target[i] *= domain.domain_inverse;
// Fragment of evaluate(coeffs, z, n): multi-threaded polynomial evaluation at
// point z. Each thread accumulates sum(z^i * coeffs[i]) over its chunk into a
// per-thread slot of `evaluations`; the final loop (line 191) reduces the
// partial sums.
// NOTE(review): leading integers are extraction line-number artifacts and the
// extraction wrapped several statements mid-token; many original lines
// (174-176, 178-180, 185-190, 192+) are missing from this view.
173 std::vector<Fr> evaluations(num_threads,
Fr::zero());
// `chunk.range(n)` yields this worker's index range over [0, n).
177 auto range = chunk.
range(n);
// Seed the running power z_acc = z^start so the chunk can proceed
// incrementally (the per-iteration update sits on a missing line).
181 size_t start = *range.begin();
182 Fr z_acc = z.
pow(
static_cast<uint64_t
>(start));
183 for (
size_t i : range) {
184 Fr work_var = z_acc * coeffs[i];
// Reduce the per-thread partial evaluations into the final result.
191 for (
const auto& eval : evaluations) {
// Fragment of Fr evaluate(const std::vector<Fr*> coeffs, const Fr& z,
// const size_t large_n): evaluates the concatenation of coeffs.size()
// equally-sized polynomials at z, treating index i in [0, large_n) as
// (polynomial i >> log2_poly_size, coefficient i & (poly_size - 1)).
// The masking implies poly_size is assumed to be a power of two — TODO confirm
// whether the original asserts this on a missing line.
// NOTE(review): leading integers are extraction line-number artifacts and the
// extraction wrapped statements mid-token; original lines 198, 201, 203,
// 205-207, 209-211, 216-221, 223+ are missing from this view.
197template <
typename Fr>
Fr evaluate(
const std::vector<Fr*> coeffs,
const Fr& z,
const size_t large_n)
199 const size_t num_polys = coeffs.size();
200 const size_t poly_size = large_n / num_polys;
202 const size_t log2_poly_size = (size_t)numeric::get_msb(poly_size);
// One partial-sum slot per worker thread.
204 std::vector<Fr> evaluations(num_threads,
Fr::zero());
208 auto range = chunk.
range(large_n);
// Seed z_acc = z^start for this worker's chunk.
212 size_t start = *range.begin();
213 Fr z_acc = z.
pow(
static_cast<uint64_t
>(start));
214 for (
size_t i : range) {
215 Fr work_var = z_acc * coeffs[i >> log2_poly_size][i & (poly_size - 1)];
// Reduce the per-thread partial evaluations into the final result.
222 for (
const auto& eval : evaluations) {
// Fragment of compute_linear_polynomial_product(roots, dest, n): computes the
// coefficients of the monic product  prod_i (X - roots[i])  into dest.
// Works on a reusable scratch buffer seeded with a copy of `roots`; each outer
// iteration i produces one coefficient via suffix sums (compute_sum), making
// this quadratic-or-worse in n — acceptable for the small n it serves,
// presumably interpolation setup.
// NOTE(review): leading integers are extraction line-number artifacts; original
// lines 233-241, 245-250, 252, 256, 258+ are missing, including the
// declarations of `temp` and `constant` (the latter presumably carries the
// alternating sign — verify in the original).
232 for (
size_t i = 0; i < n; ++i) {
// Copy the roots into scratch space so they can be overwritten in-place.
242 auto scratch_space_ptr = get_scratch_space<Fr>(n);
243 auto scratch_space = scratch_space_ptr.get();
244 memcpy((
void*)scratch_space, (
void*)roots, n *
sizeof(
Fr));
251 for (
size_t i = 0; i < n - 1; ++i) {
253 for (
size_t j = 0; j < n - 1 - i; ++j) {
// scratch_space[j] becomes roots[j] * (sum of the remaining suffix),
// and `temp` accumulates the symmetric-sum contribution for this degree.
254 scratch_space[j] = roots[j] *
compute_sum(&scratch_space[j + 1], n - 1 - i - j);
255 temp += scratch_space[j];
// Emit coefficients from high degree down: dest[n-2-i].
257 dest[n - 2 - i] = temp * constant;
// Fragment of compute_linear_polynomial_product_evaluation(roots, z, n):
// evaluates prod_i (z - roots[i]) directly, one factor per iteration
// (`result` is initialized on a missing line, presumably to one).
// NOTE(review): leading integers are extraction line-number artifacts.
265 for (
size_t i = 0; i < n; ++i) {
266 result *= (z - roots[i]);
// Fragment of compute_efficient_interpolation(src, dest, evaluation_points, n):
// Lagrange-interpolates the polynomial taking value src[i] at
// evaluation_points[i], writing its n coefficients into dest.
// Strategy visible here:
//   1. numerator_polynomial = prod_i (X - x_i)  (n+1 coefficients);
//   2. roots_and_denominators packs the negated points (first half, fed to a
//      batch inversion on missing lines — see batch_invert in the declaration
//      dump) and the Lagrange denominators prod_{j}(x_i - x_j) (second half);
//   3. for each i, divide the numerator by (X - x_i) via the forward recurrence
//      temp_dest[j] = multiplier * numerator[j] - temp_dest[j-1] and accumulate
//      the weighted quotient into dest;
//   4. a special-cased branch handles an interpolation point equal to zero
//      (idx_zero, set on missing lines), where the generic division would hit a
//      zero denominator.
// NOTE(review): leading integers are extraction line-number artifacts; many
// original lines are missing, including (presumably) the `j == i` skip inside
// the denominator loop at line 323, the batch inversion, the declarations of
// z/multiplier/idx_zero, and the zeroing of dest — verify against the original.
271template <
typename Fr>
308 std::vector<Fr> numerator_polynomial(n + 1);
309 polynomial_arithmetic::compute_linear_polynomial_product(evaluation_points, numerator_polynomial.
data(), n);
311 std::vector<Fr> roots_and_denominators(2 * n);
312 std::vector<Fr> temp_src(n);
313 for (
size_t i = 0; i < n; ++i) {
// First half: -x_i (roots for later batch inversion); also snapshot src so
// dest may alias it safely.
314 roots_and_denominators[i] = -evaluation_points[i];
315 temp_src[i] = src[i];
// Second half: the Lagrange denominator for point i.
318 roots_and_denominators[n + i] = 1;
319 for (
size_t j = 0; j < n; ++j) {
323 roots_and_denominators[n + i] *= (evaluation_points[i] - evaluation_points[j]);
331 std::vector<Fr> temp_dest(n);
// Detect whether 0 is one of the interpolation points: then the numerator's
// constant term prod(-x_i) vanishes, and the generic path cannot be used.
333 bool interpolation_domain_contains_zero =
false;
336 if (numerator_polynomial[0] ==
Fr(0)) {
337 for (
size_t i = 0; i < n; ++i) {
338 if (evaluation_points[i] ==
Fr(0)) {
340 interpolation_domain_contains_zero =
true;
346 if (!interpolation_domain_contains_zero) {
// Generic path: synthetic division of the numerator by (X - x_i), scaled by
// multiplier = src[i] * (inverted denominator), accumulated into dest.
347 for (
size_t i = 0; i < n; ++i) {
349 z = roots_and_denominators[i];
351 multiplier = temp_src[i] * roots_and_denominators[n + i];
352 temp_dest[0] = multiplier * numerator_polynomial[0];
354 dest[0] += temp_dest[0];
355 for (
size_t j = 1; j < n; ++j) {
356 temp_dest[j] = multiplier * numerator_polynomial[j] - temp_dest[j - 1];
358 dest[j] += temp_dest[j];
// Zero-point path: same recurrence but started from coefficient 1, since the
// constant term of the numerator is zero (handling of the zero point itself
// sits on missing lines).
362 for (
size_t i = 0; i < n; ++i) {
368 z = roots_and_denominators[i];
370 multiplier = temp_src[i] * roots_and_denominators[n + i];
372 temp_dest[1] = multiplier * numerator_polynomial[1];
376 dest[1] += temp_dest[1];
378 for (
size_t j = 2; j < n; ++j) {
379 temp_dest[j] = multiplier * numerator_polynomial[j] - temp_dest[j - 1];
381 dest[j] += temp_dest[j];
// Contribution of the zero interpolation point: dividing by (X - 0) just
// shifts the numerator's coefficients down by one.
385 for (
size_t i = 0; i < n; ++i) {
386 dest[i] += temp_src[idx_zero] * roots_and_denominators[n + idx_zero] * numerator_polynomial[i + 1];
// Explicit template instantiations for the bn254 scalar field `fr`.
// NOTE(review): the leading integers (391/395/396/397) are extraction
// line-number artifacts fused onto the `template` keywords, and the extraction
// wrapped each declaration across several lines.
391template fr evaluate<fr>(
const fr*,
const fr&,
const size_t);
395template fr compute_sum<fr>(
const fr*,
const size_t);
396template void compute_linear_polynomial_product<fr>(
const fr*,
fr*,
const size_t);
397template void compute_efficient_interpolation<fr>(
const fr*,
fr*,
const fr*,
const size_t);
403template void compute_efficient_interpolation<grumpkin::fr>(
const grumpkin::fr*,
#define BB_ASSERT(expression,...)
#define BB_ASSERT_EQ(actual, expected,...)
const std::vector< FF * > & get_inverse_round_roots() const
constexpr bool is_power_of_two(uint64_t x)
Fr compute_linear_polynomial_product_evaluation(const Fr *roots, const Fr z, const size_t n)
uint32_t reverse_bits(uint32_t x, uint32_t bit_length)
void ifft(Fr *coeffs, Fr *target, const EvaluationDomain< Fr > &domain)
void compute_linear_polynomial_product(const Fr *roots, Fr *dest, const size_t n)
Fr evaluate(const Fr *coeffs, const Fr &z, const size_t n)
bool is_power_of_two(uint64_t x)
void fft_inner_parallel(Fr *coeffs, Fr *target, const EvaluationDomain< Fr > &domain, const Fr &, const std::vector< Fr * > &root_table)
void compute_efficient_interpolation(const Fr *src, Fr *dest, const Fr *evaluation_points, const size_t n)
void scale_by_generator(Fr *coeffs, Fr *target, const EvaluationDomain< Fr > &domain, const Fr &generator_start, const Fr &generator_shift, const size_t generator_size)
Fr compute_sum(const Fr *src, const size_t n)
void parallel_for(size_t num_iterations, const std::function< void(size_t)> &func)
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
auto range(size_t size, size_t offset=0) const
static constexpr field neg_one()
BB_INLINE constexpr field pow(const uint256_t &exponent) const noexcept
static void batch_invert(C &coeffs) noexcept
Batch invert a collection of field elements using Montgomery's trick.
static constexpr field zero()