Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
process_buckets.cpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: Planned, auditors: [Sergei], commit: }
3// external_1: { status: not started, auditors: [], commit: }
4// external_2: { status: not started, auditors: [], commit: }
5// =====================
6
7#include "process_buckets.hpp"
8
9#include <array>
10
namespace bb::scalar_multiplication {

13// NOLINTNEXTLINE(misc-no-recursion) recursion is fine here, max depth is 4 (32-bit bucket index / 8 bits per call)
15 const size_t num_entries,
16 const uint32_t shift,
17 size_t& num_zero_entries,
18 const uint32_t bucket_index_bits,
19 const uint64_t* top_level_keys) noexcept
20{
21 constexpr size_t NUM_RADIX_BUCKETS = 1UL << RADIX_BITS;
22 constexpr uint32_t RADIX_MASK = static_cast<uint32_t>(NUM_RADIX_BUCKETS) - 1U;
23
24 // Step 1: Count entries in each radix bucket
26 for (size_t i = 0; i < num_entries; ++i) {
27 bucket_counts[(keys[i] >> shift) & RADIX_MASK]++;
28 }
29
30 // Step 2: Convert counts to cumulative offsets (prefix sum)
33 offsets[0] = 0;
34 for (size_t i = 0; i < NUM_RADIX_BUCKETS - 1; ++i) {
35 bucket_counts[i + 1] += bucket_counts[i];
36 }
37
38 // Count zero entries only at the final recursion level (shift == 0) and only for the full array
39 if ((shift == 0) && (keys == top_level_keys)) {
40 num_zero_entries = bucket_counts[0];
41 }
42
43 for (size_t i = 1; i < NUM_RADIX_BUCKETS + 1; ++i) {
44 offsets[i] = bucket_counts[i - 1];
45 }
46 for (size_t i = 0; i < NUM_RADIX_BUCKETS + 1; ++i) {
47 offsets_copy[i] = offsets[i];
48 }
49
50 // Step 3: In-place permutation using cycle sort
51 // For each radix bucket, repeatedly swap elements to their correct positions until all elements
52 // in that bucket's range belong there. The offsets array tracks the next write position for each bucket.
53 uint64_t* start = &keys[0];
54 for (size_t i = 0; i < NUM_RADIX_BUCKETS; ++i) {
55 uint64_t* bucket_start = &keys[offsets[i]];
56 const uint64_t* bucket_end = &keys[offsets_copy[i + 1]];
57 while (bucket_start != bucket_end) {
58 for (uint64_t* it = bucket_start; it < bucket_end; ++it) {
59 const size_t value = (*it >> shift) & RADIX_MASK;
60 const uint64_t offset = offsets[value]++;
61 std::iter_swap(it, start + offset);
62 }
63 bucket_start = &keys[offsets[i]];
64 }
65 }
66
67 // Step 4: Recursively sort each bucket by the next less-significant byte
68 if (shift > 0) {
69 for (size_t i = 0; i < NUM_RADIX_BUCKETS; ++i) {
70 const size_t bucket_size = offsets_copy[i + 1] - offsets_copy[i];
71 if (bucket_size > 1) {
73 &keys[offsets_copy[i]], bucket_size, shift - RADIX_BITS, num_zero_entries, bucket_index_bits, keys);
74 }
75 }
76 }
77}
78
79size_t sort_point_schedule_and_count_zero_buckets(uint64_t* point_schedule,
80 const size_t num_entries,
81 const uint32_t bucket_index_bits) noexcept
82{
83 if (num_entries == 0) {
84 return 0;
85 }
86
87 // Round bucket_index_bits up to next multiple of RADIX_BITS for proper MSD radix sort alignment.
88 // E.g., if bucket_index_bits=10, we need to start sorting from bit 16 (2 bytes) not bit 10.
89 const uint32_t remainder = bucket_index_bits % RADIX_BITS;
90 const uint32_t padded_bits = (remainder == 0) ? bucket_index_bits : bucket_index_bits - remainder + RADIX_BITS;
91 const uint32_t initial_shift = padded_bits - RADIX_BITS;
92
93 size_t num_zero_entries = 0;
95 point_schedule, num_entries, initial_shift, num_zero_entries, bucket_index_bits, point_schedule);
96
97 // The radix sort counts entries where the least significant BYTE is zero, but we need entries where
98 // the entire bucket_index (lower 32 bits) is zero. Verify the first entry after sorting.
99 if ((point_schedule[0] & BUCKET_INDEX_MASK) != 0) {
100 num_zero_entries = 0;
101 }
102
103 return num_zero_entries;
104}
105
106} // namespace bb::scalar_multiplication
ssize_t offset
Definition engine.cpp:52
constexpr uint32_t RADIX_BITS
size_t sort_point_schedule_and_count_zero_buckets(uint64_t *point_schedule, const size_t num_entries, const uint32_t bucket_index_bits) noexcept
Sort point schedule by bucket index and count zero-bucket entries.
void radix_sort_count_zero_entries(uint64_t *keys, const size_t num_entries, const uint32_t shift, size_t &num_zero_entries, const uint32_t bucket_index_bits, const uint64_t *top_level_keys) noexcept
Recursive MSD radix sort that also counts entries with zero bucket index.
constexpr uint64_t BUCKET_INDEX_MASK
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13