1use std::{array, borrow::BorrowMut, ops::Range};
2
3use openvm_circuit_primitives::{
4 bitwise_op_lookup::BitwiseOperationLookupChip, encoder::Encoder,
5 utils::next_power_of_two_or_zero,
6};
7use openvm_stark_backend::{
8 p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, p3_maybe_rayon::prelude::*,
9};
10use sha2::{compress256, digest::generic_array::GenericArray};
11
12use super::{
13 big_sig0_field, big_sig1_field, ch_field, columns::Sha256RoundCols, compose, get_flag_pt_array,
14 maj_field, small_sig0_field, small_sig1_field, SHA256_BLOCK_WORDS, SHA256_DIGEST_WIDTH,
15 SHA256_HASH_WORDS, SHA256_ROUND_WIDTH,
16};
17use crate::{
18 big_sig0, big_sig1, ch, columns::Sha256DigestCols, limbs_into_u32, maj, small_sig0, small_sig1,
19 u32_into_bits_field, u32_into_u16s, SHA256_BLOCK_U8S, SHA256_H, SHA256_INVALID_CARRY_A,
20 SHA256_INVALID_CARRY_E, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROWS_PER_BLOCK,
21 SHA256_WORD_U16S, SHA256_WORD_U8S,
22};
23
/// Stateless helper used to fill SHA-256 trace rows (see `generate_block_trace`,
/// `generate_missing_cells`, and `generate_default_row` below).
pub struct Sha256FillerHelper {
    // Encoder for the per-row `row_idx` flag columns; constructed in `new` with
    // 18 flag values (16 round rows + digest row + padding/invalid row) at max degree 2.
    pub row_idx_encoder: Encoder,
}
29
30impl Default for Sha256FillerHelper {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
impl Sha256FillerHelper {
    /// Creates a helper whose row-index encoder covers 18 flag values
    /// (16 round rows, the digest row, and the padding/invalid row) at max degree 2.
    pub fn new() -> Self {
        Self {
            row_idx_encoder: Encoder::new(18, 2, false),
        }
    }
    /// Applies the raw SHA-256 compression function to `prev_hash` with one
    /// 64-byte block and returns the new hash state.
    ///
    /// NOTE: this is the bare compression step — no message padding is done here.
    pub fn get_block_hash(
        prev_hash: &[u32; SHA256_HASH_WORDS],
        input: [u8; SHA256_BLOCK_U8S],
    ) -> [u32; SHA256_HASH_WORDS] {
        let mut new_hash = *prev_hash;
        let input_array = [GenericArray::from(input)];
        compress256(&mut new_hash, &input_array);
        new_hash
    }

    /// Fills the `SHA256_ROWS_PER_BLOCK` rows of `trace` for a single 512-bit block:
    /// 16 "round" rows (each covering `SHA256_ROUNDS_PER_ROW` = 4 compression rounds)
    /// followed by one digest row.
    ///
    /// Parameters:
    /// - `trace`: exactly `trace_width * SHA256_ROWS_PER_BLOCK` field elements
    /// - `trace_start_col`: column offset where the SHA-256 columns begin in each row
    /// - `input`: the 16 message words of this block
    /// - `bitwise_lookup_chip`: receives range-check requests for carries and digest bytes
    /// - `prev_hash`: hash state entering this block; must equal `SHA256_H` when
    ///   `local_block_idx == 0` (checked in debug builds)
    /// - `is_last_block`: whether this block ends the current message
    /// - `global_block_idx` / `local_block_idx`: block index over the whole trace /
    ///   within the current message
    ///
    /// Some cells that depend on the *next* block are left unfilled here and are
    /// completed later by `generate_missing_cells`.
    #[allow(clippy::too_many_arguments)]
    pub fn generate_block_trace<F: PrimeField32>(
        &self,
        trace: &mut [F],
        trace_width: usize,
        trace_start_col: usize,
        input: &[u32; SHA256_BLOCK_WORDS],
        bitwise_lookup_chip: &BitwiseOperationLookupChip<8>,
        prev_hash: &[u32; SHA256_HASH_WORDS],
        is_last_block: bool,
        global_block_idx: u32,
        local_block_idx: u32,
    ) {
        #[cfg(debug_assertions)]
        {
            assert!(trace.len() == trace_width * SHA256_ROWS_PER_BLOCK);
            assert!(trace_start_col + super::SHA256_WIDTH <= trace_width);
            if local_block_idx == 0 {
                assert!(*prev_hash == SHA256_H);
            }
        }
        let get_range = |start: usize, len: usize| -> Range<usize> { start..start + len };
        // Full 64-word message schedule; the first 16 words are the block input,
        // the rest are computed below as the round rows are filled.
        let mut message_schedule = [0u32; 64];
        message_schedule[..input.len()].copy_from_slice(input);
        // a..h working variables of the compression function, as an 8-word array.
        let mut work_vars = *prev_hash;
        for (i, row) in trace.chunks_exact_mut(trace_width).enumerate() {
            // Rows 0..16 are round rows; row 16 is the digest row.
            if i < 16 {
                let cols: &mut Sha256RoundCols<F> =
                    row[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut();
                cols.flags.is_round_row = F::ONE;
                cols.flags.is_first_4_rows = if i < 4 { F::ONE } else { F::ZERO };
                cols.flags.is_digest_row = F::ZERO;
                cols.flags.is_last_block = F::from_bool(is_last_block);
                cols.flags.row_idx = get_flag_pt_array(&self.row_idx_encoder, i).map(F::from_u32);
                cols.flags.global_block_idx = F::from_u32(global_block_idx);
                cols.flags.local_block_idx = F::from_u32(local_block_idx);

                if i < 4 {
                    // The first 4 rows carry the raw input words (bit-decomposed).
                    for j in 0..SHA256_ROUNDS_PER_ROW {
                        cols.message_schedule.w[j] =
                            u32_into_bits_field::<F>(input[i * SHA256_ROUNDS_PER_ROW + j]);
                    }
                } else {
                    // Later rows extend the schedule:
                    // w[t] = sig1(w[t-2]) + w[t-7] + sig0(w[t-15]) + w[t-16]  (mod 2^32).
                    for j in 0..SHA256_ROUNDS_PER_ROW {
                        let idx = i * SHA256_ROUNDS_PER_ROW + j;
                        let nums: [u32; 4] = [
                            small_sig1(message_schedule[idx - 2]),
                            message_schedule[idx - 7],
                            small_sig0(message_schedule[idx - 15]),
                            message_schedule[idx - 16],
                        ];
                        let w: u32 = nums.iter().fold(0, |acc, &num| acc.wrapping_add(num));
                        cols.message_schedule.w[j] = u32_into_bits_field::<F>(w);

                        let nums_limbs = nums.map(u32_into_u16s);
                        let w_limbs = u32_into_u16s(w);

                        // Witness the carries of the 16-bit-limb addition above. Each
                        // limb's carry (0..=3 here, since we add four words plus the
                        // incoming carry) is stored as two bits in `carry_or_buffer`.
                        for k in 0..SHA256_WORD_U16S {
                            let mut sum = nums_limbs.iter().fold(0, |acc, num| acc + num[k]);
                            if k > 0 {
                                // Add the carry out of the previous limb (recompose its 2 bits).
                                sum += (cols.message_schedule.carry_or_buffer[j][k * 2 - 2]
                                    + F::TWO * cols.message_schedule.carry_or_buffer[j][k * 2 - 1])
                                    .as_canonical_u32();
                            }
                            let carry = (sum - w_limbs[k]) >> 16;
                            cols.message_schedule.carry_or_buffer[j][k * 2] =
                                F::from_u32(carry & 1);
                            cols.message_schedule.carry_or_buffer[j][k * 2 + 1] =
                                F::from_u32(carry >> 1);
                        }
                        message_schedule[idx] = w;
                    }
                }
                // Apply this row's 4 compression rounds to the working variables.
                for j in 0..SHA256_ROUNDS_PER_ROW {
                    // t1 = h + Sig1(e) + ch(e, f, g) + K[t] + w[t]
                    let t1 = [
                        work_vars[7],
                        big_sig1(work_vars[4]),
                        ch(work_vars[4], work_vars[5], work_vars[6]),
                        SHA256_K[i * SHA256_ROUNDS_PER_ROW + j],
                        limbs_into_u32(cols.message_schedule.w[j].map(|f| f.as_canonical_u32())),
                    ];
                    let t1_sum: u32 = t1.iter().fold(0, |acc, &num| acc.wrapping_add(num));

                    // t2 = Sig0(a) + maj(a, b, c)
                    let t2 = [
                        big_sig0(work_vars[0]),
                        maj(work_vars[0], work_vars[1], work_vars[2]),
                    ];

                    let t2_sum: u32 = t2.iter().fold(0, |acc, &num| acc.wrapping_add(num));

                    // e' = d + t1, a' = t1 + t2 (mod 2^32); store bit decompositions.
                    let e = work_vars[3].wrapping_add(t1_sum);
                    cols.work_vars.e[j] = u32_into_bits_field::<F>(e);
                    let e_limbs = u32_into_u16s(e);
                    let a = t1_sum.wrapping_add(t2_sum);
                    cols.work_vars.a[j] = u32_into_bits_field::<F>(a);
                    let a_limbs = u32_into_u16s(a);
                    // Witness the per-limb carries of the a/e updates and range-check them.
                    for k in 0..SHA256_WORD_U16S {
                        let t1_limb = t1.iter().fold(0, |acc, &num| acc + u32_into_u16s(num)[k]);
                        let t2_limb = t2.iter().fold(0, |acc, &num| acc + u32_into_u16s(num)[k]);

                        let mut e_limb = t1_limb + u32_into_u16s(work_vars[3])[k];
                        let mut a_limb = t1_limb + t2_limb;
                        if k > 0 {
                            a_limb += cols.work_vars.carry_a[j][k - 1].as_canonical_u32();
                            e_limb += cols.work_vars.carry_e[j][k - 1].as_canonical_u32();
                        }
                        let carry_a = (a_limb - a_limbs[k]) >> 16;
                        let carry_e = (e_limb - e_limbs[k]) >> 16;
                        cols.work_vars.carry_a[j][k] = F::from_u32(carry_a);
                        cols.work_vars.carry_e[j][k] = F::from_u32(carry_e);
                        // Both carries fit in 8 bits; check them via the bitwise lookup.
                        bitwise_lookup_chip.request_range(carry_a, carry_e);
                    }

                    // Rotate the working variables: (a..h) <- (a', a, b, c, e', e, f, g).
                    work_vars[7] = work_vars[6];
                    work_vars[6] = work_vars[5];
                    work_vars[5] = work_vars[4];
                    work_vars[4] = e;
                    work_vars[3] = work_vars[2];
                    work_vars[2] = work_vars[1];
                    work_vars[1] = work_vars[0];
                    work_vars[0] = a;
                }

                // Schedule-helper columns used by the message-schedule constraints:
                // intermed_4[j] = w[t-4] + sig0(w[t-3]) per 16-bit limb, and w_3 keeps
                // the limbs of w[t-3] for the first 3 rounds of the row.
                if i > 0 {
                    for j in 0..SHA256_ROUNDS_PER_ROW {
                        let idx = i * SHA256_ROUNDS_PER_ROW + j;
                        let w_4 = u32_into_u16s(message_schedule[idx - 4]);
                        let sig_0_w_3 = u32_into_u16s(small_sig0(message_schedule[idx - 3]));
                        cols.schedule_helper.intermed_4[j] =
                            array::from_fn(|k| F::from_u32(w_4[k] + sig_0_w_3[k]));
                        if j < SHA256_ROUNDS_PER_ROW - 1 {
                            let w_3 = message_schedule[idx - 3];
                            cols.schedule_helper.w_3[j] = u32_into_u16s(w_3).map(F::from_u32);
                        }
                    }
                }
            } else {
                // Digest row (i == 16): publish the final hash and seed the next block.
                let cols: &mut Sha256DigestCols<F> =
                    row[get_range(trace_start_col, SHA256_DIGEST_WIDTH)].borrow_mut();
                for j in 0..SHA256_ROUNDS_PER_ROW - 1 {
                    let w_3 = message_schedule[i * SHA256_ROUNDS_PER_ROW + j - 3];
                    cols.schedule_helper.w_3[j] = u32_into_u16s(w_3).map(F::from_u32);
                }
                cols.flags.is_round_row = F::ZERO;
                cols.flags.is_first_4_rows = F::ZERO;
                cols.flags.is_digest_row = F::ONE;
                cols.flags.is_last_block = F::from_bool(is_last_block);
                // Digest row uses the dedicated flag value 16.
                cols.flags.row_idx = get_flag_pt_array(&self.row_idx_encoder, 16).map(F::from_u32);
                cols.flags.global_block_idx = F::from_u32(global_block_idx);

                cols.flags.local_block_idx = F::from_u32(local_block_idx);
                // final_hash = prev_hash + compressed working variables (mod 2^32).
                let final_hash: [u32; SHA256_HASH_WORDS] =
                    array::from_fn(|i| work_vars[i].wrapping_add(prev_hash[i]));
                let final_hash_limbs: [[u8; SHA256_WORD_U8S]; SHA256_HASH_WORDS] =
                    array::from_fn(|i| final_hash[i].to_le_bytes());
                // Range-check every byte of the final hash, two at a time.
                for word in final_hash_limbs.iter() {
                    for chunk in word.chunks(2) {
                        bitwise_lookup_chip.request_range(chunk[0] as u32, chunk[1] as u32);
                    }
                }
                cols.final_hash =
                    array::from_fn(|i| array::from_fn(|j| F::from_u8(final_hash_limbs[i][j])));
                cols.prev_hash = prev_hash.map(|f| u32_into_u16s(f).map(F::from_u32));
                // The `hash` columns hold the state carried into the NEXT block: the
                // freshly computed hash if the message continues, or the SHA-256 IV
                // if this was the last block (the next block starts a new message).
                let hash = if is_last_block {
                    SHA256_H.map(u32_into_bits_field::<F>)
                } else {
                    cols.final_hash
                        .map(|f| u32::from_le_bytes(f.map(|x| x.as_canonical_u32() as u8)))
                        .map(u32_into_bits_field::<F>)
                };

                // a holds words 3,2,1,0 and e holds words 7,6,5,4 (reversed order),
                // mirroring the layout written by `generate_default_row`.
                for i in 0..SHA256_ROUNDS_PER_ROW {
                    cols.hash.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1];
                    cols.hash.e[i] = hash[SHA256_ROUNDS_PER_ROW - i + 3];
                }
            }
        }

        // Second pass over consecutive row pairs: propagate the delayed schedule
        // helpers (intermed_4 -> intermed_8 -> intermed_12 one/two rows later) and
        // fill cells on the digest-row boundary that the first pass could not.
        for i in 0..SHA256_ROWS_PER_BLOCK - 1 {
            let rows = &mut trace[i * trace_width..(i + 2) * trace_width];
            let (local, next) = rows.split_at_mut(trace_width);
            let local_cols: &mut Sha256RoundCols<F> =
                local[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut();
            let next_cols: &mut Sha256RoundCols<F> =
                next[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut();
            if i > 0 {
                for j in 0..SHA256_ROUNDS_PER_ROW {
                    next_cols.schedule_helper.intermed_8[j] =
                        local_cols.schedule_helper.intermed_4[j];
                    if (2..SHA256_ROWS_PER_BLOCK - 3).contains(&i) {
                        next_cols.schedule_helper.intermed_12[j] =
                            local_cols.schedule_helper.intermed_8[j];
                    }
                }
            }
            if i == SHA256_ROWS_PER_BLOCK - 2 {
                // Transition into the digest row: derive its carry_a/carry_e and
                // intermed_4 values from the bit columns (the digest row was filled
                // via Sha256DigestCols, which does not populate these).
                Self::generate_carry_ae(local_cols, next_cols);
                Self::generate_intermed_4(local_cols, next_cols);
            }
            if i <= 2 {
                // Early rows: intermed_12 has no source row to copy from, so it is
                // reconstructed from the schedule constraint directly.
                Self::generate_intermed_12(local_cols, next_cols);
            }
        }
    }

    /// Fills cells that depend on the NEXT block and therefore could not be set by
    /// `generate_block_trace`. Expects `trace` to be a window shifted one row down
    /// relative to the block (see the caller `generate_trace`), so the slice rows
    /// at indices 14..17 here are named rows 15..17.
    pub fn generate_missing_cells<F: PrimeField32>(
        &self,
        trace: &mut [F],
        trace_width: usize,
        trace_start_col: usize,
    ) {
        // Split out three consecutive rows so we can hold simultaneous &mut borrows.
        let rows_15_17 = &mut trace[14 * trace_width..17 * trace_width];
        let (row_15, row_16_17) = rows_15_17.split_at_mut(trace_width);
        let (row_16, row_17) = row_16_17.split_at_mut(trace_width);
        let cols_15: &mut Sha256RoundCols<F> =
            row_15[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut();
        let cols_16: &mut Sha256RoundCols<F> =
            row_16[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut();
        let cols_17: &mut Sha256RoundCols<F> =
            row_17[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut();
        Self::generate_intermed_12(cols_15, cols_16);
        Self::generate_intermed_12(cols_16, cols_17);
        Self::generate_intermed_4(cols_16, cols_17);
    }

    /// Fills a padding row (below all real blocks) so the AIR constraints stay
    /// satisfied: the dedicated "invalid" row index 17, the SHA-256 IV in the a/e
    /// columns, and precomputed invalid-carry constants.
    pub fn generate_default_row<F: PrimeField32>(
        self: &Sha256FillerHelper,
        cols: &mut Sha256RoundCols<F>,
    ) {
        cols.flags.row_idx = get_flag_pt_array(&self.row_idx_encoder, 17).map(F::from_u32);

        let hash = SHA256_H.map(u32_into_bits_field::<F>);

        // Same reversed a/e layout as the digest row's `hash` columns.
        for i in 0..SHA256_ROUNDS_PER_ROW {
            cols.work_vars.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1];
            cols.work_vars.e[i] = hash[SHA256_ROUNDS_PER_ROW - i + 3];
        }

        cols.work_vars.carry_a =
            array::from_fn(|i| array::from_fn(|j| F::from_u32(SHA256_INVALID_CARRY_A[i][j])));
        cols.work_vars.carry_e =
            array::from_fn(|i| array::from_fn(|j| F::from_u32(SHA256_INVALID_CARRY_E[i][j])));
    }

    /// Reconstructs `next_cols`' carry_a/carry_e from the a/e bit columns of the
    /// local and next rows, by solving the round constraint for each 16-bit limb:
    /// carry = (d + t1 [+ prev carry] - e') / 2^16 (and similarly for a').
    fn generate_carry_ae<F: PrimeField32>(
        local_cols: &Sha256RoundCols<F>,
        next_cols: &mut Sha256RoundCols<F>,
    ) {
        // 8-entry windows: previous row's 4 values followed by the next row's 4.
        let a = [local_cols.work_vars.a, next_cols.work_vars.a].concat();
        let e = [local_cols.work_vars.e, next_cols.work_vars.e].concat();
        for i in 0..SHA256_ROUNDS_PER_ROW {
            let cur_a = a[i + 4];
            let sig_a = big_sig0_field::<F>(&a[i + 3]);
            let maj_abc = maj_field::<F>(&a[i + 3], &a[i + 2], &a[i + 1]);
            let d = a[i];
            let cur_e = e[i + 4];
            let sig_e = big_sig1_field::<F>(&e[i + 3]);
            let ch_efg = ch_field::<F>(&e[i + 3], &e[i + 2], &e[i + 1]);
            let h = e[i];

            let t1 = [h, sig_e, ch_efg];
            let t2 = [sig_a, maj_abc];
            for j in 0..SHA256_WORD_U16S {
                // Compose each 16-bit limb from the bit columns.
                let t1_limb_sum = t1.iter().fold(F::ZERO, |acc, x| {
                    acc + compose::<F>(&x[j * 16..(j + 1) * 16], 1)
                });
                let t2_limb_sum = t2.iter().fold(F::ZERO, |acc, x| {
                    acc + compose::<F>(&x[j * 16..(j + 1) * 16], 1)
                });
                let d_limb = compose::<F>(&d[j * 16..(j + 1) * 16], 1);
                let cur_a_limb = compose::<F>(&cur_a[j * 16..(j + 1) * 16], 1);
                let cur_e_limb = compose::<F>(&cur_e[j * 16..(j + 1) * 16], 1);
                // carry_e = (d + t1 + carry_in - e') / 2^16, computed in the field.
                let sum = d_limb
                    + t1_limb_sum
                    + if j == 0 {
                        F::ZERO
                    } else {
                        next_cols.work_vars.carry_e[i][j - 1]
                    }
                    - cur_e_limb;
                let carry_e = sum * (F::from_u32(1 << 16).inverse());

                // carry_a = (t1 + t2 + carry_in - a') / 2^16.
                let sum = t1_limb_sum
                    + t2_limb_sum
                    + if j == 0 {
                        F::ZERO
                    } else {
                        next_cols.work_vars.carry_a[i][j - 1]
                    }
                    - cur_a_limb;
                let carry_a = sum * (F::from_u32(1 << 16).inverse());
                next_cols.work_vars.carry_e[i][j] = carry_e;
                next_cols.work_vars.carry_a[i][j] = carry_a;
            }
        }
    }

    /// Fills `next_cols.schedule_helper.intermed_4[i]` (limbs of w[t-4] + sig0(w[t-3]))
    /// from the w bit columns of the local and next rows.
    fn generate_intermed_4<F: PrimeField32>(
        local_cols: &Sha256RoundCols<F>,
        next_cols: &mut Sha256RoundCols<F>,
    ) {
        // 8-entry window of schedule words: local row's 4 then next row's 4.
        let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat();
        let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w
            .iter()
            .map(|x| array::from_fn(|i| compose::<F>(&x[i * 16..(i + 1) * 16], 1)))
            .collect();
        for i in 0..SHA256_ROUNDS_PER_ROW {
            let sig_w = small_sig0_field::<F>(&w[i + 1]);
            let sig_w_limbs: [F; SHA256_WORD_U16S] =
                array::from_fn(|j| compose::<F>(&sig_w[j * 16..(j + 1) * 16], 1));
            for (j, sig_w_limb) in sig_w_limbs.iter().enumerate() {
                next_cols.schedule_helper.intermed_4[i][j] = w_limbs[i][j] + *sig_w_limb;
            }
        }
    }

    /// Fills `local_cols.schedule_helper.intermed_12[i]` by solving the message-schedule
    /// constraint for the missing term: intermed_12 is the NEGATION of
    /// sig1(w[t-2]) + w[t-7] - carry * 2^16 - w[t] (+ carry_in), per 16-bit limb.
    fn generate_intermed_12<F: PrimeField32>(
        local_cols: &mut Sha256RoundCols<F>,
        next_cols: &Sha256RoundCols<F>,
    ) {
        let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat();
        let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w
            .iter()
            .map(|x| array::from_fn(|i| compose::<F>(&x[i * 16..(i + 1) * 16], 1)))
            .collect();
        for i in 0..SHA256_ROUNDS_PER_ROW {
            // sig1(w[t-2]) composed into 16-bit limbs.
            let sig_w_2: [F; SHA256_WORD_U16S] = array::from_fn(|j| {
                compose::<F>(&small_sig1_field::<F>(&w[i + 2])[j * 16..(j + 1) * 16], 1)
            });
            // w[t-7]: taken from the w_3 helper columns when it falls outside the
            // 8-word window, otherwise directly from the window.
            let w_7 = if i < 3 {
                local_cols.schedule_helper.w_3[i]
            } else {
                w_limbs[i - 3]
            };
            let w_cur = w_limbs[i + 4];
            for j in 0..SHA256_WORD_U16S {
                // Carry out of this limb, recomposed from its two witness bits.
                let carry = next_cols.message_schedule.carry_or_buffer[i][j * 2]
                    + F::TWO * next_cols.message_schedule.carry_or_buffer[i][j * 2 + 1];
                let sum = sig_w_2[j] + w_7[j] - carry * F::from_u32(1 << 16) - w_cur[j]
                    + if j > 0 {
                        next_cols.message_schedule.carry_or_buffer[i][j * 2 - 2]
                            + F::from_u32(2)
                                * next_cols.message_schedule.carry_or_buffer[i][j * 2 - 1]
                    } else {
                        F::ZERO
                    };
                local_cols.schedule_helper.intermed_12[i][j] = -sum;
            }
        }
    }
}
466
467pub fn generate_trace<F: PrimeField32>(
470 step: &Sha256FillerHelper,
471 bitwise_lookup_chip: &BitwiseOperationLookupChip<8>,
472 width: usize,
473 records: Vec<([u8; SHA256_BLOCK_U8S], bool)>,
474) -> RowMajorMatrix<F> {
475 let non_padded_height = records.len() * SHA256_ROWS_PER_BLOCK;
476 let height = next_power_of_two_or_zero(non_padded_height);
477 let mut values = F::zero_vec(height * width);
478
479 struct BlockContext {
480 prev_hash: [u32; 8],
481 local_block_idx: u32,
482 global_block_idx: u32,
483 input: [u8; SHA256_BLOCK_U8S],
484 is_last_block: bool,
485 }
486 let mut block_ctx: Vec<BlockContext> = Vec::with_capacity(records.len());
487 let mut prev_hash = SHA256_H;
488 let mut local_block_idx = 0;
489 let mut global_block_idx = 1;
490 for (input, is_last_block) in records {
491 block_ctx.push(BlockContext {
492 prev_hash,
493 local_block_idx,
494 global_block_idx,
495 input,
496 is_last_block,
497 });
498 global_block_idx += 1;
499 if is_last_block {
500 local_block_idx = 0;
501 prev_hash = SHA256_H;
502 } else {
503 local_block_idx += 1;
504 prev_hash = Sha256FillerHelper::get_block_hash(&prev_hash, input);
505 }
506 }
507 values
509 .par_chunks_exact_mut(width * SHA256_ROWS_PER_BLOCK)
510 .zip(block_ctx)
511 .for_each(|(block, ctx)| {
512 let BlockContext {
513 prev_hash,
514 local_block_idx,
515 global_block_idx,
516 input,
517 is_last_block,
518 } = ctx;
519 let input_words = array::from_fn(|i| {
520 limbs_into_u32::<SHA256_WORD_U8S>(array::from_fn(|j| {
521 input[(i + 1) * SHA256_WORD_U8S - j - 1] as u32
522 }))
523 });
524 step.generate_block_trace(
525 block,
526 width,
527 0,
528 &input_words,
529 bitwise_lookup_chip,
530 &prev_hash,
531 is_last_block,
532 global_block_idx,
533 local_block_idx,
534 );
535 });
536 values[width * non_padded_height..]
538 .par_chunks_mut(width)
539 .for_each(|row| {
540 let cols: &mut Sha256RoundCols<F> = row.borrow_mut();
541 step.generate_default_row(cols);
542 });
543 values[width..]
545 .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK)
546 .take(non_padded_height / SHA256_ROWS_PER_BLOCK)
547 .for_each(|chunk| {
548 step.generate_missing_cells(chunk, width, 0);
549 });
550 RowMajorMatrix::new(values, width)
551}