use std::{array, borrow::BorrowMut, ops::Range};

use openvm_circuit_primitives::{
    bitwise_op_lookup::BitwiseOperationLookupChip, encoder::Encoder,
    utils::next_power_of_two_or_zero,
};
use openvm_stark_backend::{
    p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, p3_maybe_rayon::prelude::*,
};
use sha2::{compress256, digest::generic_array::GenericArray};

use super::{
    big_sig0_field, big_sig1_field, ch_field, columns::Sha256RoundCols, compose, get_flag_pt_array,
    maj_field, small_sig0_field, small_sig1_field, SHA256_BLOCK_WORDS, SHA256_DIGEST_WIDTH,
    SHA256_HASH_WORDS, SHA256_ROUND_WIDTH,
};
use crate::{
    big_sig0, big_sig1, ch, columns::Sha256DigestCols, limbs_into_u32, maj, small_sig0, small_sig1,
    u32_into_bits_field, u32_into_u16s, SHA256_BLOCK_U8S, SHA256_H, SHA256_INVALID_CARRY_A,
    SHA256_INVALID_CARRY_E, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROWS_PER_BLOCK,
    SHA256_WORD_U16S, SHA256_WORD_U8S,
};

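/// Helper for filling in the SHA-256 trace. Trace generation happens in two passes:
/// [`generate_block_trace`](Self::generate_block_trace) fills each block (and
/// [`generate_default_row`](Self::generate_default_row) each padding row), then
/// [`generate_missing_cells`](Self::generate_missing_cells) fills in the cells that depend on the
/// following rows.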
pub struct Sha256FillerHelper {
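    /// Encodes the row index within a block: indices 0..16 are round rows, 16 is the digest row,
    /// and 17 is the invalid (padding) row.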
    pub row_idx_encoder: Encoder,
}

impl Default for Sha256FillerHelper {
    fn default() -> Self {
        Self::new()
    }
}

impl Sha256FillerHelper {
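    /// The row-index [`Encoder`] distinguishes 18 row variants (16 round rows, the digest row,
    /// and the invalid row) with flag expressions of degree at most 2.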
    pub fn new() -> Self {
        Self {
            row_idx_encoder: Encoder::new(18, 2, false),
        }
    }
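    /// Runs the SHA-256 compression function on `prev_hash` with the 512-bit block `input`
    /// (message padding is not handled here) and returns the updated hash state.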
    pub fn get_block_hash(
        prev_hash: &[u32; SHA256_HASH_WORDS],
        input: [u8; SHA256_BLOCK_U8S],
    ) -> [u32; SHA256_HASH_WORDS] {
        let mut new_hash = *prev_hash;
        let input_array = [GenericArray::from(input)];
        compress256(&mut new_hash, &input_array);
        new_hash
    }

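    /// Populates the 17 rows of one block in `trace`: `input` is the 512-bit chunk as words
    /// (padding not handled), `prev_hash` the hash state before this block, `is_last_block`
    /// whether this block ends a message, and `global_block_idx`/`local_block_idx` the 1-indexed
    /// position in the whole trace and the 0-indexed position within the message. The trace has
    /// width `trace_width` and this chip's columns start at `trace_start_col`.
    /// Note: a second pass ([`Self::generate_missing_cells`]) is still required afterwards.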
    #[allow(clippy::too_many_arguments)]
    pub fn generate_block_trace<F: PrimeField32>(
        &self,
        trace: &mut [F],
        trace_width: usize,
        trace_start_col: usize,
        input: &[u32; SHA256_BLOCK_WORDS],
        bitwise_lookup_chip: &BitwiseOperationLookupChip<8>,
        prev_hash: &[u32; SHA256_HASH_WORDS],
        is_last_block: bool,
        global_block_idx: u32,
        local_block_idx: u32,
    ) {
        #[cfg(debug_assertions)]
        {
            assert!(trace.len() == trace_width * SHA256_ROWS_PER_BLOCK);
            assert!(trace_start_col + super::SHA256_WIDTH <= trace_width);
            if local_block_idx == 0 {
                assert!(*prev_hash == SHA256_H);
            }
        }
        let get_range = |start: usize, len: usize| -> Range<usize> { start..start + len };
        let mut message_schedule = [0u32; 64];
        message_schedule[..input.len()].copy_from_slice(input);
        let mut work_vars = *prev_hash;
        for (i, row) in trace.chunks_exact_mut(trace_width).enumerate() {
            if i < 16 {
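                // This is a round row (rows 0..16).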
                let cols: &mut Sha256RoundCols<F> =
                    row[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut();
                cols.flags.is_round_row = F::ONE;
                cols.flags.is_first_4_rows = if i < 4 { F::ONE } else { F::ZERO };
                cols.flags.is_digest_row = F::ZERO;
                cols.flags.is_last_block = F::from_bool(is_last_block);
                cols.flags.row_idx =
                    get_flag_pt_array(&self.row_idx_encoder, i).map(F::from_canonical_u32);
                cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx);
                cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx);

                if i < 4 {
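                    // Rows 0..4: the message schedule words come directly from the input.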
                    for j in 0..SHA256_ROUNDS_PER_ROW {
                        cols.message_schedule.w[j] =
                            u32_into_bits_field::<F>(input[i * SHA256_ROUNDS_PER_ROW + j]);
                    }
                } else {
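                    // Rows 4..16: w[idx] = sig_1(w[idx - 2]) + w[idx - 7] + sig_0(w[idx - 15])
                    // + w[idx - 16] (wrapping), recording the limb carries in `carry_or_buffer`.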
                    for j in 0..SHA256_ROUNDS_PER_ROW {
                        let idx = i * SHA256_ROUNDS_PER_ROW + j;
                        let nums: [u32; 4] = [
                            small_sig1(message_schedule[idx - 2]),
                            message_schedule[idx - 7],
                            small_sig0(message_schedule[idx - 15]),
                            message_schedule[idx - 16],
                        ];
                        let w: u32 = nums.iter().fold(0, |acc, &num| acc.wrapping_add(num));
                        cols.message_schedule.w[j] = u32_into_bits_field::<F>(w);

                        let nums_limbs = nums.map(u32_into_u16s);
                        let w_limbs = u32_into_u16s(w);

                        for k in 0..SHA256_WORD_U16S {
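                            // Compute the carry of this 16-bit limb's addition.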
                            let mut sum = nums_limbs.iter().fold(0, |acc, num| acc + num[k]);
                            if k > 0 {
                                sum += (cols.message_schedule.carry_or_buffer[j][k * 2 - 2]
                                    + F::TWO * cols.message_schedule.carry_or_buffer[j][k * 2 - 1])
                                    .as_canonical_u32();
                            }
                            let carry = (sum - w_limbs[k]) >> 16;
                            cols.message_schedule.carry_or_buffer[j][k * 2] =
                                F::from_canonical_u32(carry & 1);
                            cols.message_schedule.carry_or_buffer[j][k * 2 + 1] =
                                F::from_canonical_u32(carry >> 1);
                        }
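                        // Update the message schedule.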
                        message_schedule[idx] = w;
                    }
                }
                for j in 0..SHA256_ROUNDS_PER_ROW {
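                    // t1 = h + Sig_1(e) + ch(e, f, g) + K[idx] + w[idx]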
                    let t1 = [
                        work_vars[7],
                        big_sig1(work_vars[4]),
                        ch(work_vars[4], work_vars[5], work_vars[6]),
                        SHA256_K[i * SHA256_ROUNDS_PER_ROW + j],
                        limbs_into_u32(cols.message_schedule.w[j].map(|f| f.as_canonical_u32())),
                    ];
                    let t1_sum: u32 = t1.iter().fold(0, |acc, &num| acc.wrapping_add(num));

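                    // t2 = Sig_0(a) + maj(a, b, c)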
                    let t2 = [
                        big_sig0(work_vars[0]),
                        maj(work_vars[0], work_vars[1], work_vars[2]),
                    ];

                    let t2_sum: u32 = t2.iter().fold(0, |acc, &num| acc.wrapping_add(num));

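                    // e = d + t1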
                    let e = work_vars[3].wrapping_add(t1_sum);
                    cols.work_vars.e[j] = u32_into_bits_field::<F>(e);
                    let e_limbs = u32_into_u16s(e);
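                    // a = t1 + t2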
                    let a = t1_sum.wrapping_add(t2_sum);
                    cols.work_vars.a[j] = u32_into_bits_field::<F>(a);
                    let a_limbs = u32_into_u16s(a);
                    for k in 0..SHA256_WORD_U16S {
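                        // Fill in the limb carries for `a` and `e`, and range check each carry to
                        // 8 bits via the bitwise lookup.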
                        let t1_limb = t1.iter().fold(0, |acc, &num| acc + u32_into_u16s(num)[k]);
                        let t2_limb = t2.iter().fold(0, |acc, &num| acc + u32_into_u16s(num)[k]);

                        let mut e_limb = t1_limb + u32_into_u16s(work_vars[3])[k];
                        let mut a_limb = t1_limb + t2_limb;
                        if k > 0 {
                            a_limb += cols.work_vars.carry_a[j][k - 1].as_canonical_u32();
                            e_limb += cols.work_vars.carry_e[j][k - 1].as_canonical_u32();
                        }
                        let carry_a = (a_limb - a_limbs[k]) >> 16;
                        let carry_e = (e_limb - e_limbs[k]) >> 16;
                        cols.work_vars.carry_a[j][k] = F::from_canonical_u32(carry_a);
                        cols.work_vars.carry_e[j][k] = F::from_canonical_u32(carry_e);
                        bitwise_lookup_chip.request_range(carry_a, carry_e);
                    }

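                    // Rotate the working variables for the next round.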
                    work_vars[7] = work_vars[6];
                    work_vars[6] = work_vars[5];
                    work_vars[5] = work_vars[4];
                    work_vars[4] = e;
                    work_vars[3] = work_vars[2];
                    work_vars[2] = work_vars[1];
                    work_vars[1] = work_vars[0];
                    work_vars[0] = a;
                }

                if i > 0 {
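                    // Fill in the `w_3` and `intermed_4` helper cells used by the message
                    // schedule constraints; `intermed_8` and `intermed_12` are filled in later.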
                    for j in 0..SHA256_ROUNDS_PER_ROW {
                        let idx = i * SHA256_ROUNDS_PER_ROW + j;
                        let w_4 = u32_into_u16s(message_schedule[idx - 4]);
                        let sig_0_w_3 = u32_into_u16s(small_sig0(message_schedule[idx - 3]));
                        cols.schedule_helper.intermed_4[j] =
                            array::from_fn(|k| F::from_canonical_u32(w_4[k] + sig_0_w_3[k]));
                        if j < SHA256_ROUNDS_PER_ROW - 1 {
                            let w_3 = message_schedule[idx - 3];
                            cols.schedule_helper.w_3[j] =
                                u32_into_u16s(w_3).map(F::from_canonical_u32);
                        }
                    }
                }
            } else {
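                // This is the digest row (row 16).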
                let cols: &mut Sha256DigestCols<F> =
                    row[get_range(trace_start_col, SHA256_DIGEST_WIDTH)].borrow_mut();
                for j in 0..SHA256_ROUNDS_PER_ROW - 1 {
                    let w_3 = message_schedule[i * SHA256_ROUNDS_PER_ROW + j - 3];
                    cols.schedule_helper.w_3[j] = u32_into_u16s(w_3).map(F::from_canonical_u32);
                }
                cols.flags.is_round_row = F::ZERO;
                cols.flags.is_first_4_rows = F::ZERO;
                cols.flags.is_digest_row = F::ONE;
                cols.flags.is_last_block = F::from_bool(is_last_block);
                cols.flags.row_idx =
                    get_flag_pt_array(&self.row_idx_encoder, 16).map(F::from_canonical_u32);
                cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx);
                cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx);
                let final_hash: [u32; SHA256_HASH_WORDS] =
                    array::from_fn(|i| work_vars[i].wrapping_add(prev_hash[i]));
                let final_hash_limbs: [[u8; SHA256_WORD_U8S]; SHA256_HASH_WORDS] =
                    array::from_fn(|i| final_hash[i].to_le_bytes());
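                // Range check the bytes of the final hash, two at a time, via the bitwise lookup.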
239 for word in final_hash_limbs.iter() {
243 for chunk in word.chunks(2) {
244 bitwise_lookup_chip.request_range(chunk[0] as u32, chunk[1] as u32);
245 }
246 }
247 cols.final_hash = array::from_fn(|i| {
248 array::from_fn(|j| F::from_canonical_u8(final_hash_limbs[i][j]))
249 });
250 cols.prev_hash = prev_hash.map(|f| u32_into_u16s(f).map(F::from_canonical_u32));
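                // The digest row's `hash` cells must line up with the work variables of the next
                // block's first row: the IV if this message ends here, else this block's final
                // hash.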
                let hash = if is_last_block {
                    SHA256_H.map(u32_into_bits_field::<F>)
                } else {
                    cols.final_hash
                        .map(|f| u32::from_le_bytes(f.map(|x| x.as_canonical_u32() as u8)))
                        .map(u32_into_bits_field::<F>)
                };

                for i in 0..SHA256_ROUNDS_PER_ROW {
                    cols.hash.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1];
                    cols.hash.e[i] = hash[SHA256_ROUNDS_PER_ROW - i + 3];
                }
            }
        }

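        // Fill in the cells that depend on the next row: the `intermed_8`/`intermed_12` helper
        // cells, and on the digest row the work-variable carries and `intermed_4`.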
        for i in 0..SHA256_ROWS_PER_BLOCK - 1 {
            let rows = &mut trace[i * trace_width..(i + 2) * trace_width];
            let (local, next) = rows.split_at_mut(trace_width);
            let local_cols: &mut Sha256RoundCols<F> =
                local[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut();
            let next_cols: &mut Sha256RoundCols<F> =
                next[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut();
            if i > 0 {
                for j in 0..SHA256_ROUNDS_PER_ROW {
                    next_cols.schedule_helper.intermed_8[j] =
                        local_cols.schedule_helper.intermed_4[j];
                    if (2..SHA256_ROWS_PER_BLOCK - 3).contains(&i) {
                        next_cols.schedule_helper.intermed_12[j] =
                            local_cols.schedule_helper.intermed_8[j];
                    }
                }
            }
            if i == SHA256_ROWS_PER_BLOCK - 2 {
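                // `next` is the digest row: fill in `carry_a`/`carry_e` so that the round
                // constraints still hold across this transition.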
                Self::generate_carry_ae(local_cols, next_cols);
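                // Fill in the digest row's `intermed_4`.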
                Self::generate_intermed_4(local_cols, next_cols);
            }
            if i <= 2 {
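                // Rows 0..3 have no real message schedule to check, but their `intermed_12`
                // cells still appear in constraints, so fill in the values that make those
                // constraints hold.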
                Self::generate_intermed_12(local_cols, next_cols);
            }
        }
    }

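    /// Fills in the cells that could not be generated in the first pass. Call this only after
    /// [`Self::generate_block_trace`] has run on every block and [`Self::generate_default_row`]
    /// on every padding row. `trace` must contain rows 1..17 of a block followed by the first
    /// row of the next block.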
    pub fn generate_missing_cells<F: PrimeField32>(
        &self,
        trace: &mut [F],
        trace_width: usize,
        trace_start_col: usize,
    ) {
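        // Within this chunk, "row 17" is the first row of the next block.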
        let rows_15_17 = &mut trace[14 * trace_width..17 * trace_width];
        let (row_15, row_16_17) = rows_15_17.split_at_mut(trace_width);
        let (row_16, row_17) = row_16_17.split_at_mut(trace_width);
        let cols_15: &mut Sha256RoundCols<F> =
            row_15[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut();
        let cols_16: &mut Sha256RoundCols<F> =
            row_16[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut();
        let cols_17: &mut Sha256RoundCols<F> =
            row_17[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut();
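        // Fill in row 15's `intermed_12`.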
        Self::generate_intermed_12(cols_15, cols_16);
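        // Fill in row 16's `intermed_12`.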
        Self::generate_intermed_12(cols_16, cols_17);
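        // Fill in `intermed_4` on the first row of the next block.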
        Self::generate_intermed_4(cols_16, cols_17);
    }

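    /// Fills `cols` as an invalid (padding) row: row index 17, work variables set from the IV,
    /// and precomputed carries that keep the round constraints satisfied.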
    pub fn generate_default_row<F: PrimeField32>(&self, cols: &mut Sha256RoundCols<F>) {
        cols.flags.row_idx =
            get_flag_pt_array(&self.row_idx_encoder, 17).map(F::from_canonical_u32);

        let hash = SHA256_H.map(u32_into_bits_field::<F>);

        for i in 0..SHA256_ROUNDS_PER_ROW {
            cols.work_vars.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1];
            cols.work_vars.e[i] = hash[SHA256_ROUNDS_PER_ROW - i + 3];
        }

        cols.work_vars.carry_a = array::from_fn(|i| {
            array::from_fn(|j| F::from_canonical_u32(SHA256_INVALID_CARRY_A[i][j]))
        });
        cols.work_vars.carry_e = array::from_fn(|i| {
            array::from_fn(|j| F::from_canonical_u32(SHA256_INVALID_CARRY_E[i][j]))
        });
    }

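    /// Solves the round constraints for the `carry_a`/`carry_e` cells of `next_cols`, given the
    /// `a`/`e` bit columns of both rows. The resulting carries need not correspond to a real
    /// SHA-256 round.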
    fn generate_carry_ae<F: PrimeField32>(
        local_cols: &Sha256RoundCols<F>,
        next_cols: &mut Sha256RoundCols<F>,
    ) {
        let a = [local_cols.work_vars.a, next_cols.work_vars.a].concat();
        let e = [local_cols.work_vars.e, next_cols.work_vars.e].concat();
        for i in 0..SHA256_ROUNDS_PER_ROW {
            let cur_a = a[i + 4];
            let sig_a = big_sig0_field::<F>(&a[i + 3]);
            let maj_abc = maj_field::<F>(&a[i + 3], &a[i + 2], &a[i + 1]);
            let d = a[i];
            let cur_e = e[i + 4];
            let sig_e = big_sig1_field::<F>(&e[i + 3]);
            let ch_efg = ch_field::<F>(&e[i + 3], &e[i + 2], &e[i + 1]);
            let h = e[i];

            let t1 = [h, sig_e, ch_efg];
            let t2 = [sig_a, maj_abc];
            for j in 0..SHA256_WORD_U16S {
                let t1_limb_sum = t1.iter().fold(F::ZERO, |acc, x| {
                    acc + compose::<F>(&x[j * 16..(j + 1) * 16], 1)
                });
                let t2_limb_sum = t2.iter().fold(F::ZERO, |acc, x| {
                    acc + compose::<F>(&x[j * 16..(j + 1) * 16], 1)
                });
                let d_limb = compose::<F>(&d[j * 16..(j + 1) * 16], 1);
                let cur_a_limb = compose::<F>(&cur_a[j * 16..(j + 1) * 16], 1);
                let cur_e_limb = compose::<F>(&cur_e[j * 16..(j + 1) * 16], 1);
                let sum = d_limb
                    + t1_limb_sum
                    + if j == 0 {
                        F::ZERO
                    } else {
                        next_cols.work_vars.carry_e[i][j - 1]
                    }
                    - cur_e_limb;
                let carry_e = sum * (F::from_canonical_u32(1 << 16).inverse());

                let sum = t1_limb_sum
                    + t2_limb_sum
                    + if j == 0 {
                        F::ZERO
                    } else {
                        next_cols.work_vars.carry_a[i][j - 1]
                    }
                    - cur_a_limb;
                let carry_a = sum * (F::from_canonical_u32(1 << 16).inverse());
                next_cols.work_vars.carry_e[i][j] = carry_e;
                next_cols.work_vars.carry_a[i][j] = carry_a;
            }
        }
    }

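    /// Fills the `intermed_4` cells of `next_cols`: for each word, the 16-bit limbs of
    /// `w[idx] + sig_0(w[idx + 1])`.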
    fn generate_intermed_4<F: PrimeField32>(
        local_cols: &Sha256RoundCols<F>,
        next_cols: &mut Sha256RoundCols<F>,
    ) {
        let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat();
        let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w
            .iter()
            .map(|x| array::from_fn(|i| compose::<F>(&x[i * 16..(i + 1) * 16], 1)))
            .collect();
        for i in 0..SHA256_ROUNDS_PER_ROW {
            let sig_w = small_sig0_field::<F>(&w[i + 1]);
            let sig_w_limbs: [F; SHA256_WORD_U16S] =
                array::from_fn(|j| compose::<F>(&sig_w[j * 16..(j + 1) * 16], 1));
            for (j, sig_w_limb) in sig_w_limbs.iter().enumerate() {
                next_cols.schedule_helper.intermed_4[i][j] = w_limbs[i][j] + *sig_w_limb;
            }
        }
    }

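    /// Fills the `intermed_12` cells of `local_cols` by solving the message schedule constraint
    /// for them, using the schedule words and carries of both rows.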
    fn generate_intermed_12<F: PrimeField32>(
        local_cols: &mut Sha256RoundCols<F>,
        next_cols: &Sha256RoundCols<F>,
    ) {
        let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat();
        let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w
            .iter()
            .map(|x| array::from_fn(|i| compose::<F>(&x[i * 16..(i + 1) * 16], 1)))
            .collect();
        for i in 0..SHA256_ROUNDS_PER_ROW {
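            // sig_1(w_{t-2})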
            let sig_w_2: [F; SHA256_WORD_U16S] = array::from_fn(|j| {
                compose::<F>(&small_sig1_field::<F>(&w[i + 2])[j * 16..(j + 1) * 16], 1)
            });
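            // w_{t-7}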
            let w_7 = if i < 3 {
                local_cols.schedule_helper.w_3[i]
            } else {
                w_limbs[i - 3]
            };
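            // w_t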
            let w_cur = w_limbs[i + 4];
            for j in 0..SHA256_WORD_U16S {
                let carry = next_cols.message_schedule.carry_or_buffer[i][j * 2]
                    + F::TWO * next_cols.message_schedule.carry_or_buffer[i][j * 2 + 1];
                let sum = sig_w_2[j] + w_7[j] - carry * F::from_canonical_u32(1 << 16) - w_cur[j]
                    + if j > 0 {
                        next_cols.message_schedule.carry_or_buffer[i][j * 2 - 2]
                            + F::from_canonical_u32(2)
                                * next_cols.message_schedule.carry_or_buffer[i][j * 2 - 1]
                    } else {
                        F::ZERO
                    };
                local_cols.schedule_helper.intermed_12[i][j] = -sum;
            }
        }
    }
}

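/// Generates the SHA-256 trace from `records`, where each record is one 512-bit block together
/// with a flag marking whether it is the last block of its message. The returned matrix has the
/// given `width` and height padded to a power of two.
///
/// A minimal usage sketch (`block_bytes` and `bitwise_chip` are assumed to be in scope, and
/// `width` must accommodate the chip's columns):
/// ```ignore
/// let helper = Sha256FillerHelper::new();
/// // One single-block message:
/// let records = vec![(block_bytes, true)];
/// let trace = generate_trace::<F>(&helper, &bitwise_chip, width, records);
/// ```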
pub fn generate_trace<F: PrimeField32>(
    step: &Sha256FillerHelper,
    bitwise_lookup_chip: &BitwiseOperationLookupChip<8>,
    width: usize,
    records: Vec<([u8; SHA256_BLOCK_U8S], bool)>,
) -> RowMajorMatrix<F> {
    let non_padded_height = records.len() * SHA256_ROWS_PER_BLOCK;
    let height = next_power_of_two_or_zero(non_padded_height);
    let mut values = F::zero_vec(height * width);

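    // Compute each block's starting hash state sequentially, since it depends on the previous
    // block of the same message.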
    struct BlockContext {
        prev_hash: [u32; 8],
        local_block_idx: u32,
        global_block_idx: u32,
        input: [u8; SHA256_BLOCK_U8S],
        is_last_block: bool,
    }
    let mut block_ctx: Vec<BlockContext> = Vec::with_capacity(records.len());
    let mut prev_hash = SHA256_H;
    let mut local_block_idx = 0;
    let mut global_block_idx = 1;
    for (input, is_last_block) in records {
        block_ctx.push(BlockContext {
            prev_hash,
            local_block_idx,
            global_block_idx,
            input,
            is_last_block,
        });
        global_block_idx += 1;
        if is_last_block {
            local_block_idx = 0;
            prev_hash = SHA256_H;
        } else {
            local_block_idx += 1;
            prev_hash = Sha256FillerHelper::get_block_hash(&prev_hash, input);
        }
    }
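    // First pass: fill in the main trace of each block, in parallel.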
    values
        .par_chunks_exact_mut(width * SHA256_ROWS_PER_BLOCK)
        .zip(block_ctx)
        .for_each(|(block, ctx)| {
            let BlockContext {
                prev_hash,
                local_block_idx,
                global_block_idx,
                input,
                is_last_block,
            } = ctx;
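            // SHA-256 reads the block as big-endian 32-bit words.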
            let input_words = array::from_fn(|i| {
                limbs_into_u32::<SHA256_WORD_U8S>(array::from_fn(|j| {
                    input[(i + 1) * SHA256_WORD_U8S - j - 1] as u32
                }))
            });
            step.generate_block_trace(
                block,
                width,
                0,
                &input_words,
                bitwise_lookup_chip,
                &prev_hash,
                is_last_block,
                global_block_idx,
                local_block_idx,
            );
        });
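    // Fill the padding rows with the default (invalid) row.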
    values[width * non_padded_height..]
        .par_chunks_mut(width)
        .for_each(|row| {
            let cols: &mut Sha256RoundCols<F> = row.borrow_mut();
            step.generate_default_row(cols);
        });
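    // Second pass: fill in the missing cells. Offsetting by one row makes each chunk hold rows
    // 1..17 of a block plus the first row of the next block, as `generate_missing_cells` expects.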
    values[width..]
        .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK)
        .take(non_padded_height / SHA256_ROWS_PER_BLOCK)
        .for_each(|chunk| {
            step.generate_missing_cells(chunk, width, 0);
        });
    RowMajorMatrix::new(values, width)
}