Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
aes128.cpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: Complete, auditors: [Khashayar], commit: 21476601b111f046f023474465598843e4cfd8ac}
3// external_1: { status: not started, auditors: [], commit: }
4// external_2: { status: not started, auditors: [], commit: }
5// =====================
6
7#include "./aes128.hpp"
8
12
15
16#include <span>
17
18using namespace bb::crypto;
19
21
22template <typename Builder> using byte_pair = std::pair<field_t<Builder>, field_t<Builder>>;
23template <typename Builder> using state_span = std::span<byte_pair<Builder>, BLOCK_SIZE>;
24template <typename Builder> using column_span = std::span<byte_pair<Builder>, COLUMN_SIZE>;
25template <typename Builder> using key_span = std::span<field_t<Builder>, EXTENDED_KEY_LENGTH>;
26template <typename Builder> using block_span = std::span<field_t<Builder>, BLOCK_SIZE>;
27using namespace bb::plookup;
28
30{
32 return result;
33}
34
39
40template <typename Builder>
42{
43 std::array<field_t<Builder>, 16> sparse_bytes;
44 auto block_data_copy = block_data;
45 if (block_data.is_constant()) {
46 // The algorithm expects that the sparse bytes are witnesses, so the block_data_copy must be a witness
47 block_data_copy.convert_constant_to_fixed_witness(ctx);
48 }
49 // Convert block data into sparse bytes using the AES_INPUT lookup table
50 auto lookup = plookup_read<Builder>::get_lookup_accumulators(AES_INPUT, block_data_copy);
51 for (size_t i = 0; i < 16; ++i) {
52 sparse_bytes[15 - i] = lookup[ColumnIdx::C2][i];
53 }
54 return sparse_bytes;
55}
56
57template <typename Builder> field_t<Builder> convert_from_sparse_bytes(Builder* ctx, block_span<Builder> sparse_bytes)
58{
59 uint256_t accumulator = 0;
60 for (size_t i = 0; i < BLOCK_SIZE; ++i) {
61 uint64_t sparse_byte = uint256_t(sparse_bytes[i].get_value()).data[0];
62 uint256_t byte = numeric::map_from_sparse_form<AES128_BASE>(sparse_byte);
63 accumulator <<= 8;
64 accumulator += (byte);
65 }
66
67 field_t<Builder> result = witness_t(ctx, fr(accumulator));
68
70
71 for (size_t i = 0; i < BLOCK_SIZE; ++i) {
72 sparse_bytes[BLOCK_SIZE - 1 - i].assert_equal(lookup[ColumnIdx::C2][i]);
73 }
74
75 return result;
76}
77
108template <typename Builder>
110{
111 // Round constants (Rcon) from FIPS 197. Index 0 is a placeholder (never used);
112 // indices 1-10 are Rcon[1] through Rcon[10] = {0x01, 0x02, 0x04, ..., 0x36}.
113 // These are powers of 2 in GF(2^8): Rcon[i] = 2^(i-1) mod P(x).
114 constexpr std::array<uint8_t, 11> round_constants = { 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10,
115 0x20, 0x40, 0x80, 0x1b, 0x36 };
116 const auto sparse_round_constants = [&]() {
117 std::array<field_t<Builder>, 11> result;
118 for (size_t i = 0; i < 11; ++i) {
119 result[i] = field_t<Builder>(ctx, fr(numeric::map_into_sparse_form<AES128_BASE>(round_constants[i])));
120 }
121 return result;
122 }();
123
125 const auto sparse_key = convert_into_sparse_bytes(ctx, key);
126
128 std::array<uint64_t, 4> temp_add_counts{};
129 // Track the number of additions in each byte to normalize to prevent overflow in the sparse representation
131 for (size_t i = 0; i < EXTENDED_KEY_LENGTH; ++i) {
132 add_counts[i] = 1;
133 }
134
135 // For the first round (first 16 bytes of the expanded key), the round key is the same as the original key
136 for (size_t i = 0; i < 16; ++i) {
137 round_key[i] = sparse_key[i];
138 }
139
140 // Ittereate over the 40 words (4 words per round for 10 rounds)
141 for (size_t i = 4; i < 44; ++i) {
142 size_t k = (i - 1) * 4;
143 // Each word is 4 bytes, hence all the operations are done on 4 bytes at a time
144 temp_add_counts[0] = add_counts[k + 0];
145 temp_add_counts[1] = add_counts[k + 1];
146 temp_add_counts[2] = add_counts[k + 2];
147 temp_add_counts[3] = add_counts[k + 3];
148
149 temp[0] = round_key[k];
150 temp[1] = round_key[k + 1];
151 temp[2] = round_key[k + 2];
152 temp[3] = round_key[k + 3];
153
154 // If the word index is a multiple of 4, then we need to apply the RotWord and SubWord operations
155 if ((i & 0x03) == 0) {
156 // Apply the RotWord operation to the 4 bytes
157 const auto t = temp[0];
158 temp[0] = temp[1];
159 temp[1] = temp[2];
160 temp[2] = temp[3];
161 temp[3] = t;
162
163 // Apply the SubWord operation to the 4 bytes by looking up the S-box value in the AES_SBOX lookup table
164 temp[0] = apply_aes_sbox_map(ctx, temp[0]).first;
165 temp[1] = apply_aes_sbox_map(ctx, temp[1]).first;
166 temp[2] = apply_aes_sbox_map(ctx, temp[2]).first;
167 temp[3] = apply_aes_sbox_map(ctx, temp[3]).first;
168
169 // Add the round constant to the word. Since the round constants are 1 byte long we can just add them to the
170 // first byte of the word
171 temp[0] = temp[0] + sparse_round_constants[i >> 2];
172 ++temp_add_counts[0];
173 }
174
175 // The index of the expanded key bytes that need to be updated
176 size_t j = i * 4;
177 // The index if the key bytes corresponding to the previous word
178 k = (i - 4) * 4;
179 round_key[j] = round_key[k] + temp[0];
180 round_key[j + 1] = round_key[k + 1] + temp[1];
181 round_key[j + 2] = round_key[k + 2] + temp[2];
182 round_key[j + 3] = round_key[k + 3] + temp[3];
183
184 add_counts[j] = add_counts[k] + temp_add_counts[0];
185 add_counts[j + 1] = add_counts[k + 1] + temp_add_counts[1];
186 add_counts[j + 2] = add_counts[k + 2] + temp_add_counts[2];
187 add_counts[j + 3] = add_counts[k + 3] + temp_add_counts[3];
188
189 // Number of additions before we need to normalize the sparse form
190 constexpr uint64_t target = 3;
191 for (size_t k = 0; k < 4; ++k) {
192 // If the number of additions exceeds the target or the byte corresponds to a word index that is a multiple
193 // of 4 (i.e. the byte is used as input to the S-box) we normalize the sparse form
194 size_t byte_index = j + k;
195 if (add_counts[byte_index] > target || (add_counts[byte_index] > 1 && (byte_index & 12) == 12)) {
196 round_key[byte_index] = normalize_sparse_form(ctx, round_key[byte_index]);
197 // Reset the addition counter
198 add_counts[byte_index] = 1;
199 }
200 }
201 }
202
203 return round_key;
204}
205
213template <typename Builder> void shift_rows(state_span<Builder> state)
214{
215 byte_pair<Builder> temp = state[1];
216 state[1] = state[5];
217 state[5] = state[9];
218 state[9] = state[13];
219 state[13] = temp;
220
221 temp = state[2];
222 state[2] = state[10];
223 state[10] = temp;
224 temp = state[6];
225 state[6] = state[14];
226 state[14] = temp;
227
228 temp = state[3];
229 state[3] = state[15];
230 state[15] = state[11];
231 state[11] = state[7];
232 state[7] = temp;
233}
234
265template <typename Builder>
267 key_span<Builder> round_key,
268 size_t round,
269 size_t column)
270{
271 // Intermediate values to reduce the number of additions (optimization)
272 // t0 = s0 + s3 + 3·s1
273 auto t0 = column_pairs[0].first.add_two(column_pairs[3].first, column_pairs[1].second);
274 // t1 = s1 + s2 + 3·s3
275 auto t1 = column_pairs[1].first.add_two(column_pairs[2].first, column_pairs[3].second);
276
277 // r0 = 2·s0 ⊕ 3·s1 ⊕ s2 ⊕ s3 = t0 + s2 + 3·s0 = (s0 + 3·s0) + 3·s1 + s2 + s3
278 auto r0 = t0.add_two(column_pairs[2].first, column_pairs[0].second);
279 // r1 = s0 ⊕ 2·s1 ⊕ 3·s2 ⊕ s3 = t0 + s1 + 3·s2 = s0 + (s1 + 3·s1) + 3·s2 + s3
280 auto r1 = t0.add_two(column_pairs[1].first, column_pairs[2].second);
281 // r2 = s0 ⊕ s1 ⊕ 2·s2 ⊕ 3·s3 = t1 + s0 + 3·s2 = s0 + s1 + (s2 + 3·s2) + 3·s3
282 auto r2 = t1.add_two(column_pairs[0].first, column_pairs[2].second);
283 // r3 = 3·s0 ⊕ s1 ⊕ s2 ⊕ 2·s3 = t1 + 3·s0 + s3 = 3·s0 + s1 + s2 + (s3 + 3·s3)
284 auto r3 = t1.add_two(column_pairs[0].second, column_pairs[3].first);
285
286 // Round key offset: round * 16 (bytes per round) + column * 4 (bytes per column)
287 const size_t key_offset = round * BLOCK_SIZE + column * COLUMN_SIZE;
288
289 // Add round key and store result back (only .first is updated; .second will be recomputed by next SubBytes)
290 column_pairs[0].first = r0 + round_key[key_offset];
291 column_pairs[1].first = r1 + round_key[key_offset + 1];
292 column_pairs[2].first = r2 + round_key[key_offset + 2];
293 column_pairs[3].first = r3 + round_key[key_offset + 3];
294}
295
296template <typename Builder>
298{
299 mix_column_and_add_round_key<Builder>(state_pairs.template subspan<0, COLUMN_SIZE>(), round_key, round, 0);
300 mix_column_and_add_round_key<Builder>(state_pairs.template subspan<4, COLUMN_SIZE>(), round_key, round, 1);
301 mix_column_and_add_round_key<Builder>(state_pairs.template subspan<8, COLUMN_SIZE>(), round_key, round, 2);
302 mix_column_and_add_round_key<Builder>(state_pairs.template subspan<12, COLUMN_SIZE>(), round_key, round, 3);
303}
304
305template <typename Builder> void sub_bytes(Builder* ctx, state_span<Builder> state_pairs)
306{
307 for (size_t i = 0; i < BLOCK_SIZE; ++i) {
308 state_pairs[i] = apply_aes_sbox_map(ctx, state_pairs[i].first);
309 }
310}
311
312template <typename Builder>
313void add_round_key(state_span<Builder> sparse_state, key_span<Builder> sparse_round_key, size_t round)
314{
315 const size_t key_offset = round * BLOCK_SIZE;
316 for (size_t i = 0; i < BLOCK_SIZE; i += COLUMN_SIZE) {
317 for (size_t j = 0; j < COLUMN_SIZE; ++j) {
318 sparse_state[i + j].first += sparse_round_key[key_offset + i + j];
319 }
320 }
321}
322
323template <typename Builder> void xor_with_iv(state_span<Builder> state, block_span<Builder> iv)
324{
325 for (size_t i = 0; i < BLOCK_SIZE; ++i) {
326 state[i].first += iv[i];
327 }
328}
329
330template <typename Builder>
332{
333 add_round_key<Builder>(state, sparse_round_key, 0);
334 for (size_t i = 0; i < BLOCK_SIZE; ++i) {
335 state[i].first = normalize_sparse_form(ctx, state[i].first);
336 }
337
338 for (size_t round = 1; round < NUM_ROUNDS; ++round) {
339 sub_bytes(ctx, state);
340 shift_rows<Builder>(state);
341 mix_columns_and_add_round_key<Builder>(state, sparse_round_key, round);
342 for (size_t i = 0; i < BLOCK_SIZE; ++i) {
343 state[i].first = normalize_sparse_form(ctx, state[i].first);
344 }
345 }
346
347 sub_bytes(ctx, state);
348 shift_rows<Builder>(state);
349 add_round_key<Builder>(state, sparse_round_key, NUM_ROUNDS);
350}
351
352template <typename Builder>
354 const field_t<Builder>& iv,
355 const field_t<Builder>& key)
356{
357 // Check if all inputs are constants
358 bool all_constants = key.is_constant() && iv.is_constant();
359 for (const auto& input_block : input) {
360 if (!input_block.is_constant()) {
361 all_constants = false;
362 break;
363 }
364 }
365
366 if (all_constants) {
367 // Compute result directly using native crypto implementation
369 std::vector<uint8_t> key_bytes(16);
370 std::vector<uint8_t> iv_bytes(16);
371 std::vector<uint8_t> input_bytes(input.size() * 16);
372
373 // Convert key to bytes
374 uint256_t key_value = key.get_value();
375 for (size_t i = 0; i < 16; ++i) {
376 key_bytes[15 - i] = static_cast<uint8_t>((key_value >> (i * 8)) & 0xFF);
377 }
378
379 // Convert IV to bytes
380 uint256_t iv_value = iv.get_value();
381 for (size_t i = 0; i < 16; ++i) {
382 iv_bytes[15 - i] = static_cast<uint8_t>((iv_value >> (i * 8)) & 0xFF);
383 }
384
385 // Convert input blocks to bytes
386 for (size_t block_idx = 0; block_idx < input.size(); ++block_idx) {
387 uint256_t block_value = input[block_idx].get_value();
388 for (size_t i = 0; i < 16; ++i) {
389 input_bytes[block_idx * 16 + 15 - i] = static_cast<uint8_t>((block_value >> (i * 8)) & 0xFF);
390 }
391 }
392
393 // Run native AES encryption
394 crypto::aes128_encrypt_buffer_cbc(input_bytes.data(), iv_bytes.data(), key_bytes.data(), input_bytes.size());
395
396 // Convert result back to field elements
397 for (size_t block_idx = 0; block_idx < input.size(); ++block_idx) {
398 uint256_t result_value = 0;
399 for (size_t i = 0; i < 16; ++i) {
400 result_value <<= 8;
401 result_value += input_bytes[block_idx * 16 + i];
402 }
403 result.push_back(field_t<Builder>(result_value));
404 }
405
406 return result;
407 }
408
409 // Find a valid context from any of the inputs
410 Builder* ctx = nullptr;
411 if (!key.is_constant()) {
412 ctx = key.get_context();
413 } else if (!iv.is_constant()) {
414 ctx = iv.get_context();
415 } else {
416 for (const auto& input_block : input) {
417 if (!input_block.is_constant()) {
418 ctx = input_block.get_context();
419 break;
420 }
421 }
422 }
423
424 BB_ASSERT(ctx);
425
426 auto round_key = expand_key(ctx, key);
427 key_span<Builder> round_key_span{ round_key };
428
429 const size_t num_blocks = input.size();
430
431 std::vector<byte_pair<Builder>> sparse_state;
432 for (size_t i = 0; i < num_blocks; ++i) {
433 auto bytes = convert_into_sparse_bytes(ctx, input[i]);
434 for (const auto& byte : bytes) {
435 sparse_state.push_back({ byte, field_t(ctx, fr(0)) });
436 }
437 }
438
439 auto sparse_iv = convert_into_sparse_bytes(ctx, iv);
440 block_span<Builder> sparse_iv_span{ sparse_iv };
441
442 for (size_t i = 0; i < num_blocks; ++i) {
443 state_span<Builder> round_state{ &sparse_state[i * BLOCK_SIZE], BLOCK_SIZE };
444 xor_with_iv<Builder>(round_state, sparse_iv_span);
445 aes128_cipher(ctx, round_state, round_key_span);
446
447 for (size_t j = 0; j < BLOCK_SIZE; ++j) {
448 sparse_iv[j] = round_state[j].first;
449 }
450 }
451
452 std::vector<field_t<Builder>> sparse_output;
453 for (auto& element : sparse_state) {
454 sparse_output.push_back(normalize_sparse_form(ctx, element.first));
455 }
456
458 for (size_t i = 0; i < num_blocks; ++i) {
459 block_span<Builder> output_span{ &sparse_output[i * BLOCK_SIZE], BLOCK_SIZE };
460 output.push_back(convert_from_sparse_bytes(ctx, output_span));
461 }
462 return output;
463}
// Explicit template instantiations.
// Instantiates the public entry point and both sparse-byte conversion helpers for a given
// circuit Builder type; invoked once per supported Builder.
#define INSTANTIATE_AES128_TEMPLATES(Builder)                                                                          \
    template std::vector<field_t<Builder>> encrypt_buffer_cbc<Builder>(                                                \
        const std::vector<field_t<Builder>>&, const field_t<Builder>&, const field_t<Builder>&);                       \
    template std::array<field_t<Builder>, BLOCK_SIZE> convert_into_sparse_bytes<Builder>(Builder*,                     \
                                                                                         const field_t<Builder>&);    \
    template field_t<Builder> convert_from_sparse_bytes<Builder>(Builder*, std::span<field_t<Builder>, BLOCK_SIZE>)
474
475} // namespace bb::stdlib::aes128
#define BB_ASSERT(expression,...)
Definition assert.hpp:70
Builder * get_context() const
Definition field.hpp:431
bb::fr get_value() const
Given a := *this, compute its value given by a.v * a.mul + a.add.
Definition field.cpp:836
void convert_constant_to_fixed_witness(Builder *ctx)
Definition field.hpp:456
bool is_constant() const
Definition field.hpp:441
static plookup::ReadData< field_pt > get_lookup_accumulators(const plookup::MultiTableId id, const field_pt &key_a, const field_pt &key_b=0, const bool is_2_to_1_lookup=false)
Definition plookup.cpp:19
static field_pt read_from_1_to_2_table(const plookup::MultiTableId id, const field_pt &key_a)
Definition plookup.cpp:89
static std::pair< field_pt, field_pt > read_pair_from_table(const plookup::MultiTableId id, const field_pt &key)
Definition plookup.cpp:70
void aes128_encrypt_buffer_cbc(uint8_t *buffer, uint8_t *iv, const uint8_t *key, const size_t length)
Definition aes128.cpp:232
@ AES_NORMALIZE
Definition types.hpp:98
byte_pair< Builder > apply_aes_sbox_map(Builder *, field_t< Builder > &input)
Definition aes128.cpp:35
std::span< byte_pair< Builder >, COLUMN_SIZE > column_span
Definition aes128.cpp:24
constexpr size_t NUM_ROUNDS
Definition aes128.hpp:23
field_t< Builder > normalize_sparse_form(Builder *, field_t< Builder > &byte)
Definition aes128.cpp:29
void sub_bytes(Builder *ctx, state_span< Builder > state_pairs)
Definition aes128.cpp:305
constexpr size_t BLOCK_SIZE
Definition aes128.hpp:21
std::array< field_t< Builder >, 16 > convert_into_sparse_bytes(Builder *ctx, const field_t< Builder > &block_data)
Converts a 128-bit block into 16 sparse-form bytes via AES_INPUT plookup table.
Definition aes128.cpp:41
field_t< Builder > convert_from_sparse_bytes(Builder *ctx, block_span< Builder > sparse_bytes)
Definition aes128.cpp:57
std::pair< field_t< Builder >, field_t< Builder > > byte_pair
Definition aes128.cpp:22
constexpr size_t COLUMN_SIZE
Definition aes128.hpp:24
std::span< field_t< Builder >, EXTENDED_KEY_LENGTH > key_span
Definition aes128.cpp:25
std::array< field_t< Builder >, EXTENDED_KEY_LENGTH > expand_key(Builder *ctx, const field_t< Builder > &key)
Expands a 128-bit AES key into the full key schedule (EXTENDED_KEY_LENGTH bytes / 11 round keys).
Definition aes128.cpp:109
void mix_column_and_add_round_key(column_span< Builder > column_pairs, key_span< Builder > round_key, size_t round, size_t column)
Performs MixColumns on a single column and adds the round key (FIPS 197, Sections 5....
Definition aes128.cpp:266
constexpr size_t EXTENDED_KEY_LENGTH
Definition aes128.hpp:22
void shift_rows(state_span< Builder > state)
The SHIFTROW() operation as in FIPS 197, Section 5.1.2.
Definition aes128.cpp:213
std::span< field_t< Builder >, BLOCK_SIZE > block_span
Definition aes128.cpp:26
void add_round_key(state_span< Builder > sparse_state, key_span< Builder > sparse_round_key, size_t round)
Definition aes128.cpp:313
void xor_with_iv(state_span< Builder > state, block_span< Builder > iv)
Definition aes128.cpp:323
std::vector< field_t< Builder > > encrypt_buffer_cbc(const std::vector< field_t< Builder > > &input, const field_t< Builder > &iv, const field_t< Builder > &key)
Main public interface: AES-128 CBC encryption.
Definition aes128.cpp:353
void aes128_cipher(Builder *ctx, state_span< Builder > state, key_span< Builder > sparse_round_key)
Definition aes128.cpp:331
void mix_columns_and_add_round_key(state_span< Builder > state_pairs, key_span< Builder > round_key, size_t round)
Definition aes128.cpp:297
std::span< byte_pair< Builder >, BLOCK_SIZE > state_span
Definition aes128.cpp:23
std::conditional_t< IsGoblinBigGroup< C, Fq, Fr, G >, element_goblin::goblin_element< C, goblin_field< C >, Fr, G >, element_default::element< C, Fq, Fr, G > > element
element wraps either element_default::element or element_goblin::goblin_element depending on parametr...
field< Bn254FrParams > fr
Definition fr.hpp:155
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
#define INSTANTIATE_AES128_TEMPLATES(Builder)
Definition aes128.cpp:465