1use std::{array, borrow::Borrow, cmp::max, iter::once};
2
3use openvm_circuit_primitives::{
4 bitwise_op_lookup::BitwiseOperationLookupBus,
5 encoder::Encoder,
6 utils::{not, select},
7 SubAir,
8};
9use openvm_stark_backend::{
10 interaction::{BusIndex, InteractionBuilder, PermutationCheckBus},
11 p3_air::{AirBuilder, BaseAir},
12 p3_field::{Field, FieldAlgebra},
13 p3_matrix::Matrix,
14};
15
16use super::{
17 big_sig0_field, big_sig1_field, ch_field, compose, maj_field, small_sig0_field,
18 small_sig1_field, u32_into_limbs, Sha256DigestCols, Sha256RoundCols, SHA256_DIGEST_WIDTH,
19 SHA256_H, SHA256_HASH_WORDS, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROUND_WIDTH,
20 SHA256_WORD_BITS, SHA256_WORD_U16S, SHA256_WORD_U8S,
21};
22use crate::constraint_word_addition;
23
/// AIR constraining the SHA-256 compression function, used as a [`SubAir`] whose columns start
/// at a caller-chosen column offset.
///
/// Each block occupies rows with row index `0..=15` (round rows) followed by row index `16`
/// (the digest row); row index `17` marks padding rows at the end of the trace.
#[derive(Clone, Debug)]
pub struct Sha256Air {
    /// Bus of the bitwise-operation lookup table, used here for range checks (work-variable
    /// carries and the limbs of the final hash).
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    /// Encoder for the per-block row index (`0..=17`) stored in the `row_idx` flag columns.
    pub row_idx_encoder: Encoder,
    /// Internal bus for self-interactions of this AIR: each digest row sends its chaining hash
    /// and receives its own `prev_hash` on it.
    bus: PermutationCheckBus,
}
32
33impl Sha256Air {
34 pub fn new(bitwise_lookup_bus: BitwiseOperationLookupBus, self_bus_idx: BusIndex) -> Self {
35 Self {
36 bitwise_lookup_bus,
37 row_idx_encoder: Encoder::new(18, 2, false),
38 bus: PermutationCheckBus::new(self_bus_idx),
39 }
40 }
41}
42
43impl<F> BaseAir<F> for Sha256Air {
44 fn width(&self) -> usize {
45 max(
46 Sha256RoundCols::<F>::width(),
47 Sha256DigestCols::<F>::width(),
48 )
49 }
50}
51
52impl<AB: InteractionBuilder> SubAir<AB> for Sha256Air {
53 type AirContext<'a>
55 = usize
56 where
57 Self: 'a,
58 AB: 'a,
59 <AB as AirBuilder>::Var: 'a,
60 <AB as AirBuilder>::Expr: 'a;
61
62 fn eval<'a>(&'a self, builder: &'a mut AB, start_col: Self::AirContext<'a>)
63 where
64 <AB as AirBuilder>::Var: 'a,
65 <AB as AirBuilder>::Expr: 'a,
66 {
67 self.eval_row(builder, start_col);
68 self.eval_transitions(builder, start_col);
69 }
70}
71
// Constraint evaluation for `Sha256Air`.
//
// NOTE(review): the `Encoder` helpers are used as indicators throughout. `contains_flag_range` /
// `contains_flag` presumably evaluate to 1 exactly when `row_idx` encodes an index in the given
// set (else 0). `flag_with_val` presumably evaluates to the value paired with the encoded index
// (0 if the index is not listed). Confirm against `Encoder`.
impl Sha256Air {
    /// Single-row constraints (only the `local` row is read).
    ///
    /// - The flags `is_round_row`, `is_first_4_rows`, `is_digest_row` and `is_last_block` are
    ///   boolean, and at most one of `is_round_row` / `is_digest_row` is set.
    /// - `row_idx` is a valid encoding of one of `0..=17`, and the flags agree with it:
    ///   `0..=3` ⇔ `is_first_4_rows`, `0..=15` ⇔ `is_round_row`, `16` ⇔ `is_digest_row`,
    ///   `17` ⇔ `is_padding_row()`.
    /// - Every bit of `hash.a` / `hash.e` is boolean on every row, including padding rows.
    fn eval_row<AB: InteractionBuilder>(&self, builder: &mut AB, start_col: usize) {
        let main = builder.main();
        let local = main.row_slice(0);

        // Only the flag and `hash` columns are used here; they are read through the digest-row
        // view.
        // NOTE(review): this assumes these columns sit at the same offsets in `Sha256RoundCols`
        // and `Sha256DigestCols` (so `hash` presumably aliases `work_vars` on round rows). Confirm.
        let local_cols: &Sha256DigestCols<AB::Var> =
            local[start_col..start_col + SHA256_DIGEST_WIDTH].borrow();
        let flags = &local_cols.flags;
        builder.assert_bool(flags.is_round_row);
        builder.assert_bool(flags.is_first_4_rows);
        builder.assert_bool(flags.is_digest_row);
        // Round and digest flags are mutually exclusive.
        builder.assert_bool(flags.is_round_row + flags.is_digest_row);
        builder.assert_bool(flags.is_last_block);

        self.row_idx_encoder
            .eval(builder, &local_cols.flags.row_idx);
        // `row_idx` must encode one of the 18 indices 0..=17.
        builder.assert_one(
            self.row_idx_encoder
                .contains_flag_range::<AB>(&local_cols.flags.row_idx, 0..=17),
        );
        builder.assert_eq(
            self.row_idx_encoder
                .contains_flag_range::<AB>(&local_cols.flags.row_idx, 0..=3),
            flags.is_first_4_rows,
        );
        builder.assert_eq(
            self.row_idx_encoder
                .contains_flag_range::<AB>(&local_cols.flags.row_idx, 0..=15),
            flags.is_round_row,
        );
        builder.assert_eq(
            self.row_idx_encoder
                .contains_flag::<AB>(&local_cols.flags.row_idx, &[16]),
            flags.is_digest_row,
        );
        // Padding rows are exactly the rows whose row index is 17.
        builder.assert_eq(
            self.row_idx_encoder
                .contains_flag::<AB>(&local_cols.flags.row_idx, &[17]),
            flags.is_padding_row(),
        );

        // `hash.a` / `hash.e` are stored as bits. This is enforced on every row because no flag
        // gates it.
        for i in 0..SHA256_ROUNDS_PER_ROW {
            for j in 0..SHA256_WORD_BITS {
                builder.assert_bool(local_cols.hash.a[i][j]);
                builder.assert_bool(local_cols.hash.e[i][j]);
            }
        }
    }

    /// Constraints tying the last round row (`local`, round-row view) to the following row
    /// (`next`, digest-row view). Every constraint is gated by flags of `next`.
    ///
    /// 1. On padding rows, and on the digest row of a block with `is_last_block` set,
    ///    `next.hash` must equal the initial hash `SHA256_H`. Word mapping:
    ///    `a[i]` ↔ `H[R - 1 - i]`, `e[i]` ↔ `H[R + 3 - i]`, with `R = SHA256_ROUNDS_PER_ROW`.
    /// 2. On the digest row of a block without `is_last_block`, `next.hash` must equal
    ///    `next.final_hash` byte by byte. That is, the next block continues from this block's
    ///    result.
    /// 3. On digest rows, `final_hash[i] = prev_hash[i] + work_var[i] (mod 2^32)`, where the work
    ///    variables come from `local`. This is checked per 16-bit limb with boolean carries; the
    ///    carry out of the top limb is dropped. The bytes of `final_hash` are range-checked in
    ///    pairs via the bitwise lookup bus (presumably to 8 bits each).
    fn eval_digest_row<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &Sha256RoundCols<AB::Var>,
        next: &Sha256DigestCols<AB::Var>,
    ) {
        // (1) `hash` is the initial hash value on padding rows and on last-block digest rows.
        for i in 0..SHA256_ROUNDS_PER_ROW {
            let a = next.hash.a[i].map(|x| x.into());
            let e = next.hash.e[i].map(|x| x.into());
            for j in 0..SHA256_WORD_U16S {
                // 16-bit limb `j` recomposed from the bit columns.
                let a_limb = compose::<AB::Expr>(&a[j * 16..(j + 1) * 16], 1);
                let e_limb = compose::<AB::Expr>(&e[j * 16..(j + 1) * 16], 1);

                builder
                    .when(
                        next.flags.is_padding_row()
                            + next.flags.is_last_block * next.flags.is_digest_row,
                    )
                    .assert_eq(
                        a_limb,
                        AB::Expr::from_canonical_u32(
                            u32_into_limbs::<2>(SHA256_H[SHA256_ROUNDS_PER_ROW - i - 1])[j],
                        ),
                    );

                builder
                    .when(
                        next.flags.is_padding_row()
                            + next.flags.is_last_block * next.flags.is_digest_row,
                    )
                    .assert_eq(
                        e_limb,
                        AB::Expr::from_canonical_u32(
                            u32_into_limbs::<2>(SHA256_H[SHA256_ROUNDS_PER_ROW - i + 3])[j],
                        ),
                    );
            }
        }

        // (2) On a non-last block's digest row, `hash` carries this block's `final_hash`.
        for i in 0..SHA256_ROUNDS_PER_ROW {
            let prev_a = next.hash.a[i].map(|x| x.into());
            let prev_e = next.hash.e[i].map(|x| x.into());
            let cur_a = next.final_hash[SHA256_ROUNDS_PER_ROW - i - 1].map(|x| x.into());

            let cur_e = next.final_hash[SHA256_ROUNDS_PER_ROW - i + 3].map(|x| x.into());
            for j in 0..SHA256_WORD_U8S {
                // `final_hash` is stored as bytes, so compare 8-bit chunks of the bit columns.
                let prev_a_limb = compose::<AB::Expr>(&prev_a[j * 8..(j + 1) * 8], 1);
                let prev_e_limb = compose::<AB::Expr>(&prev_e[j * 8..(j + 1) * 8], 1);

                builder
                    .when(not(next.flags.is_last_block) * next.flags.is_digest_row)
                    .assert_eq(prev_a_limb, cur_a[j].clone());

                builder
                    .when(not(next.flags.is_last_block) * next.flags.is_digest_row)
                    .assert_eq(prev_e_limb, cur_e[j].clone());
            }
        }

        // (3) `final_hash[i] = prev_hash[i] + work_var[i] (mod 2^32)` on digest rows.
        for i in 0..SHA256_HASH_WORDS {
            let mut carry = AB::Expr::ZERO;
            for j in 0..SHA256_WORD_U16S {
                // Hash word `i` maps to `a[R - 1 - i]` for the first `R` words and to
                // `e[R + 3 - i]` for the rest.
                let work_var_limb = if i < SHA256_ROUNDS_PER_ROW {
                    compose::<AB::Expr>(
                        &local.work_vars.a[SHA256_ROUNDS_PER_ROW - 1 - i][j * 16..(j + 1) * 16],
                        1,
                    )
                } else {
                    compose::<AB::Expr>(
                        &local.work_vars.e[SHA256_ROUNDS_PER_ROW + 3 - i][j * 16..(j + 1) * 16],
                        1,
                    )
                };
                // Two bytes of `final_hash` form one 16-bit limb.
                let final_hash_limb =
                    compose::<AB::Expr>(&next.final_hash[i][j * 2..(j + 1) * 2], 8);

                // carry_j = (prev_hash_j + work_var_j + carry_{j-1} - final_hash_j) / 2^16.
                // It must be 0 or 1. The last carry is not forced to zero, which gives
                // addition modulo 2^32.
                carry = AB::Expr::from(AB::F::from_canonical_u32(1 << 16).inverse())
                    * (next.prev_hash[i][j] + work_var_limb + carry - final_hash_limb);
                builder
                    .when(next.flags.is_digest_row)
                    .assert_bool(carry.clone());
            }
            // Range-check the `final_hash` bytes two per lookup interaction.
            for chunk in next.final_hash[i].chunks(2) {
                self.bitwise_lookup_bus
                    .send_range(chunk[0], chunk[1])
                    .eval(builder, next.flags.is_digest_row);
            }
        }
    }

    /// Constraints between the `local` and `next` rows.
    ///
    /// This function constrains:
    /// - the row sequencing (round/digest/padding flags and `row_idx`);
    /// - the `global_block_idx` and `local_block_idx` counters.
    ///
    /// It then delegates to [`Self::eval_message_schedule`], [`Self::eval_work_vars`],
    /// [`Self::eval_digest_row`] and [`Self::eval_prev_hash`]. The same row slices are
    /// reinterpreted as round-row or digest-row column views as each helper requires.
    ///
    /// Constraints not wrapped in `when_transition()` are also enforced on the last row. There,
    /// `next` presumably wraps around to the first row (Plonky3 convention). Confirm.
    fn eval_transitions<AB: InteractionBuilder>(&self, builder: &mut AB, start_col: usize) {
        let main = builder.main();
        let local = main.row_slice(0);
        let next = main.row_slice(1);

        // Round-row view of both rows. Only the shared flag columns are read directly here.
        let local_cols: &Sha256RoundCols<AB::Var> =
            local[start_col..start_col + SHA256_ROUND_WIDTH].borrow();
        let next_cols: &Sha256RoundCols<AB::Var> =
            next[start_col..start_col + SHA256_ROUND_WIDTH].borrow();

        let local_is_padding_row = local_cols.flags.is_padding_row();
        // On a digest row, a following padding row means this is the last block of the trace.
        let next_is_padding_row = next_cols.flags.is_padding_row();

        // The final block of the trace (its digest row is followed by padding) must have
        // `is_last_block` set.
        builder
            .when(next_is_padding_row.clone())
            .when(local_cols.flags.is_digest_row)
            .assert_one(local_cols.flags.is_last_block);
        // A round row is never followed by a padding row.
        builder
            .when(local_cols.flags.is_round_row)
            .assert_zero(next_is_padding_row.clone());
        // The trace starts with a round row.
        builder
            .when_first_row()
            .assert_one(local_cols.flags.is_round_row);
        // Once padding starts, every following row is padding.
        builder
            .when_transition()
            .when(local_is_padding_row.clone())
            .assert_one(next_is_padding_row.clone());
        // Two digest rows are never adjacent.
        builder
            .when(local_cols.flags.is_digest_row)
            .assert_zero(next_cols.flags.is_digest_row);
        // Expected change of the row index from `local` to `next`:
        //   round   -> round/digest : +1
        //   digest  -> round        : -16 (16 -> 0, start of the next block)
        //   digest  -> padding      : +1  (16 -> 17)
        //   padding -> padding      : 0   (stays at 17)
        // The remaining combinations are ruled out by the constraints above.
        let delta = local_cols.flags.is_round_row * AB::Expr::ONE
            + local_cols.flags.is_digest_row
                * next_cols.flags.is_round_row
                * AB::Expr::from_canonical_u32(16)
                * AB::Expr::NEG_ONE
            + local_cols.flags.is_digest_row * next_is_padding_row.clone() * AB::Expr::ONE;

        // Numeric row index (0..=17) decoded from the `row_idx` flag columns.
        let local_row_idx = self.row_idx_encoder.flag_with_val::<AB>(
            &local_cols.flags.row_idx,
            &(0..18).map(|i| (i, i)).collect::<Vec<_>>(),
        );
        let next_row_idx = self.row_idx_encoder.flag_with_val::<AB>(
            &next_cols.flags.row_idx,
            &(0..18).map(|i| (i, i)).collect::<Vec<_>>(),
        );

        builder
            .when_transition()
            .assert_eq(local_row_idx.clone() + delta, next_row_idx.clone());
        // The trace starts at row index 0.
        builder.when_first_row().assert_zero(local_row_idx);

        // Global block index: starts at 1 on the first row, so block rows never share the
        // value 0 used by padding rows.
        builder
            .when_first_row()
            .assert_one(local_cols.flags.global_block_idx);

        // Constant from a round row to the next row of the same block.
        builder.when(local_cols.flags.is_round_row).assert_eq(
            local_cols.flags.global_block_idx,
            next_cols.flags.global_block_idx,
        );
        // Incremented by 1 from a digest row to the first row of the next block.
        builder
            .when_transition()
            .when(local_cols.flags.is_digest_row)
            .when(next_cols.flags.is_round_row)
            .assert_eq(
                local_cols.flags.global_block_idx + AB::Expr::ONE,
                next_cols.flags.global_block_idx,
            );
        // Zero on padding rows.
        builder
            .when(local_is_padding_row.clone())
            .assert_zero(local_cols.flags.global_block_idx);

        // Local block index (position of the block within its message).
        // Copied to the next row from every non-digest row, including padding rows.
        builder.when(not(local_cols.flags.is_digest_row)).assert_eq(
            local_cols.flags.local_block_idx,
            next_cols.flags.local_block_idx,
        );
        // Incremented by 1 after the digest row of a block that is not the last of its message.
        builder
            .when(local_cols.flags.is_digest_row)
            .when(not(local_cols.flags.is_last_block))
            .assert_eq(
                local_cols.flags.local_block_idx + AB::Expr::ONE,
                next_cols.flags.local_block_idx,
            );
        // Reset to 0 after the digest row of a message's last block. This also makes it 0 on
        // the padding rows that follow the final block.
        builder
            .when(local_cols.flags.is_digest_row)
            .when(local_cols.flags.is_last_block)
            .assert_zero(next_cols.flags.local_block_idx);

        self.eval_message_schedule::<AB>(builder, local_cols, next_cols);
        self.eval_work_vars::<AB>(builder, local_cols, next_cols);
        // `next` re-viewed as a digest row for the digest-row constraints.
        let next_cols: &Sha256DigestCols<AB::Var> =
            next[start_col..start_col + SHA256_DIGEST_WIDTH].borrow();
        self.eval_digest_row(builder, local_cols, next_cols);
        // `local` re-viewed as a digest row. On a digest row, `next_is_padding_row` marks the
        // last block of the trace.
        let local_cols: &Sha256DigestCols<AB::Var> =
            local[start_col..start_col + SHA256_DIGEST_WIDTH].borrow();
        self.eval_prev_hash::<AB>(builder, local_cols, next_is_padding_row);
    }

    /// Links each block's `prev_hash` to the chaining hash of the block before it, via
    /// interactions of this AIR with itself on `self.bus`. Both interactions are gated by
    /// `is_digest_row`.
    ///
    /// - **Send**: the digest row's `hash`, as 16-bit limbs in hash-word order, tagged with the
    ///   global index of the block it is meant for. That index is `global_block_idx + 1`, or 1
    ///   for the last block of the trace (assuming `select(c, a, b) = c·a + (1 - c)·b`).
    /// - **Receive**: the digest row's `prev_hash`, tagged with its own `global_block_idx`.
    ///
    /// The caller passes `next_is_padding_row` as `is_last_block_of_trace`, and the trace's final
    /// block is forced to have `is_last_block` set, which makes its `hash` equal `SHA256_H`. As a
    /// result, the send from the final block supplies `SHA256_H` as the `prev_hash` of block 1.
    fn eval_prev_hash<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &Sha256DigestCols<AB::Var>,
        // Set on the digest row of the last block of the whole trace, not merely the last
        // block of a message.
        is_last_block_of_trace: AB::Expr,
    ) {
        // `hash` recomposed into 16-bit limbs, hash word `i` taken from `a[R - 1 - i]` for the
        // first `R` words and from `e[R + 3 - i]` for the rest.
        let composed_hash: [[<AB as AirBuilder>::Expr; SHA256_WORD_U16S]; SHA256_HASH_WORDS] =
            array::from_fn(|i| {
                let hash_bits = if i < SHA256_ROUNDS_PER_ROW {
                    local.hash.a[SHA256_ROUNDS_PER_ROW - 1 - i].map(|x| x.into())
                } else {
                    local.hash.e[SHA256_ROUNDS_PER_ROW + 3 - i].map(|x| x.into())
                };
                array::from_fn(|j| compose::<AB::Expr>(&hash_bits[j * 16..(j + 1) * 16], 1))
            });
        // The final block of the trace wraps around to block 1.
        let next_global_block_idx = select(
            is_last_block_of_trace,
            AB::Expr::ONE,
            local.flags.global_block_idx + AB::Expr::ONE,
        );
        self.bus.send(
            builder,
            composed_hash
                .into_iter()
                .flatten()
                .chain(once(next_global_block_idx)),
            local.flags.is_digest_row,
        );

        self.bus.receive(
            builder,
            local
                .prev_hash
                .into_iter()
                .flatten()
                .map(|x| x.into())
                .chain(once(local.flags.global_block_idx.into())),
            local.flags.is_digest_row,
        );
    }

    /// Constrains the message schedule words of the `next` row.
    ///
    /// `w` concatenates the words of `local` and `next`, so `w[i + 4]` is word `i` of `next`.
    /// NOTE(review): the `+ 4`, `- 3` and `i < 3` offsets assume `SHA256_ROUNDS_PER_ROW == 4`.
    /// Confirm.
    ///
    /// Writing `w_t` for that word, this enforces the FIPS 180-4 recurrence
    /// `w_t = σ1(w_{t-2}) + w_{t-7} + σ0(w_{t-15}) + w_{t-16} (mod 2^32)` limb by limb, with the
    /// help of these columns:
    /// - `next.schedule_helper.w_3[i]` holds the 16-bit limbs of `local` word `i + 1`. This
    ///   supplies `w_{t-7}` when it lies two rows back.
    /// - `intermed_4 → intermed_8 → intermed_12` hold `w_{t-16} + σ0(w_{t-15})`. The value is
    ///   computed three rows before it is needed and shifted down one row at a time.
    /// - `message_schedule.carry_or_buffer` holds each limb's carry as two bits. The bits are
    ///   boolean-checked only on rows 4..=15.
    ///
    /// The bits of `w` are boolean on round rows.
    fn eval_message_schedule<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &Sha256RoundCols<AB::Var>,
        next: &Sha256RoundCols<AB::Var>,
    ) {
        let w = [local.message_schedule.w, next.message_schedule.w].concat();

        // `next.w_3[i]` = limbs of `local` word `i + 1`, enforced when `local` is a round row.
        for i in 0..SHA256_ROUNDS_PER_ROW - 1 {
            let w_3 = w[i + 1].map(|x| x.into());
            let expected_w_3 = next.schedule_helper.w_3[i];
            for j in 0..SHA256_WORD_U16S {
                let w_3_limb = compose::<AB::Expr>(&w_3[j * 16..(j + 1) * 16], 1);
                builder
                    .when(local.flags.is_round_row)
                    .assert_eq(w_3_limb, expected_w_3[j].into());
            }
        }

        // The shifts into `intermed_8` / `intermed_12` are enforced only when `next` is on rows
        // 2..=13 / 3..=14. On other rows those columns are left free, so the addition constraint
        // below can be satisfied there with filler values.
        let is_row_3_14 = self
            .row_idx_encoder
            .contains_flag_range::<AB>(&next.flags.row_idx, 3..=14);
        let is_row_2_13 = self
            .row_idx_encoder
            .contains_flag_range::<AB>(&next.flags.row_idx, 2..=13);
        for i in 0..SHA256_ROUNDS_PER_ROW {
            let w_idx = w[i].map(|x| x.into());
            let sig_w = small_sig0_field::<AB::Expr>(&w[i + 1]);
            for j in 0..SHA256_WORD_U16S {
                let w_idx_limb = compose::<AB::Expr>(&w_idx[j * 16..(j + 1) * 16], 1);
                let sig_w_limb = compose::<AB::Expr>(&sig_w[j * 16..(j + 1) * 16], 1);

                // Enforced on every transition, with no row-index gate.
                builder.when_transition().assert_eq(
                    next.schedule_helper.intermed_4[i][j],
                    w_idx_limb + sig_w_limb,
                );

                builder.when(is_row_2_13.clone()).assert_eq(
                    next.schedule_helper.intermed_8[i][j],
                    local.schedule_helper.intermed_4[i][j],
                );

                builder.when(is_row_3_14.clone()).assert_eq(
                    next.schedule_helper.intermed_12[i][j],
                    local.schedule_helper.intermed_8[i][j],
                );
            }
        }

        for i in 0..SHA256_ROUNDS_PER_ROW {
            // w_{t-7}: two rows back (via `w_3`) for the first three words, otherwise in `local`.
            let w_7 = if i < 3 {
                local.schedule_helper.w_3[i].map(|x| x.into())
            } else {
                let w_3 = w[i - 3].map(|x| x.into());
                array::from_fn(|j| compose::<AB::Expr>(&w_3[j * 16..(j + 1) * 16], 1))
            };
            // σ0(w_{t-15}) + w_{t-16}, as 16-bit limbs.
            let intermed_16 = local.schedule_helper.intermed_12[i].map(|x| x.into());

            // Carry of limb `j`, encoded by two bits (value 0..=3).
            let carries = array::from_fn(|j| {
                next.message_schedule.carry_or_buffer[i][j * 2]
                    + AB::Expr::TWO * next.message_schedule.carry_or_buffer[i][j * 2 + 1]
            });

            // w_t = σ1(w_{t-2}) + w_{t-7} + intermed_16, on every transition (no row-index gate).
            // NOTE(review): `constraint_word_addition` presumably checks the limb-wise sum with
            // the given carries. Confirm against its definition.
            constraint_word_addition(
                &mut builder.when_transition(),
                &[&small_sig1_field::<AB::Expr>(&w[i + 2])],
                &[&w_7, &intermed_16],
                &w[i + 4],
                &carries,
            );

            for j in 0..SHA256_WORD_U16S {
                // Round rows that are not among the first four, i.e. rows 4..=15.
                let is_row_4_15 = next.flags.is_round_row - next.flags.is_first_4_rows;
                builder
                    .when(is_row_4_15.clone())
                    .assert_bool(next.message_schedule.carry_or_buffer[i][j * 2]);
                builder
                    .when(is_row_4_15)
                    .assert_bool(next.message_schedule.carry_or_buffer[i][j * 2 + 1]);
            }
            // `w` is stored as bits on round rows.
            for j in 0..SHA256_WORD_BITS {
                builder
                    .when(next.flags.is_round_row)
                    .assert_bool(next.message_schedule.w[i][j]);
            }
        }
    }

    /// Constrains the working variables `a` / `e` of the `next` row from those of `local`, as in
    /// the SHA-256 compression rounds (FIPS 180-4).
    ///
    /// `a` / `e` concatenate the words of `local` and `next`. For the new words `a[i + 4]` /
    /// `e[i + 4]`:
    /// - `a[i + 3], a[i + 2], a[i + 1]` play the roles of `a, b, c`;
    /// - `e[i + 3], e[i + 2], e[i + 1]` play the roles of `e, f, g`;
    /// - `a[i]` / `e[i]` play the roles of `d` / `h`.
    ///
    /// The two round equations are:
    /// - `new a = h + Σ1(e) + Ch(e, f, g) + Σ0(a) + Maj(a, b, c) + W + K`;
    /// - `new e = d + h + Σ1(e) + Ch(e, f, g) + W + K`.
    ///
    /// Both additions are checked with the carries in `next.work_vars` and are not gated by any
    /// flag. `W` is zeroed unless `next` is a round row. On round rows, the carries are
    /// range-checked through the bitwise lookup bus.
    fn eval_work_vars<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &Sha256RoundCols<AB::Var>,
        next: &Sha256RoundCols<AB::Var>,
    ) {
        let a = [local.work_vars.a, next.work_vars.a].concat();
        let e = [local.work_vars.e, next.work_vars.e].concat();
        for i in 0..SHA256_ROUNDS_PER_ROW {
            // Range-check this row's carries, `carry_a` and `carry_e` together in one lookup.
            for j in 0..SHA256_WORD_U16S {
                self.bitwise_lookup_bus
                    .send_range(local.work_vars.carry_a[i][j], local.work_vars.carry_e[i][j])
                    .eval(builder, local.flags.is_round_row);
            }

            // W: 16-bit limbs of message word `i` of `next`, zeroed on non-round rows.
            let w_limbs = array::from_fn(|j| {
                compose::<AB::Expr>(&next.message_schedule.w[i][j * 16..(j + 1) * 16], 1)
                    * next.flags.is_round_row
            });
            // K: round constant for round `row_idx * R + i`, selected by `next`'s row index
            // (rows 0..16 only).
            let k_limbs = array::from_fn(|j| {
                self.row_idx_encoder.flag_with_val::<AB>(
                    &next.flags.row_idx,
                    &(0..16)
                        .map(|rw_idx| {
                            (
                                rw_idx,
                                u32_into_limbs::<SHA256_WORD_U16S>(
                                    SHA256_K[rw_idx * SHA256_ROUNDS_PER_ROW + i],
                                )[j] as usize,
                            )
                        })
                        .collect::<Vec<_>>(),
                )
            });

            // new a = h + Σ1(e) + Ch(e, f, g) + Σ0(a) + Maj(a, b, c) + W + K
            constraint_word_addition(
                builder,
                &[
                    &e[i].map(|x| x.into()), // `h`
                    &big_sig1_field::<AB::Expr>(&e[i + 3]), // Σ1(e)
                    &ch_field::<AB::Expr>(&e[i + 3], &e[i + 2], &e[i + 1]), // Ch(e, f, g)
                    &big_sig0_field::<AB::Expr>(&a[i + 3]), // Σ0(a)
                    &maj_field::<AB::Expr>(&a[i + 3], &a[i + 2], &a[i + 1]), // Maj(a, b, c)
                ],
                &[&w_limbs, &k_limbs], // W and K as 16-bit limbs
                &a[i + 4], // new `a`
                &next.work_vars.carry_a[i], // carries of the addition
            );

            // new e = d + h + Σ1(e) + Ch(e, f, g) + W + K
            constraint_word_addition(
                builder,
                &[
                    &a[i].map(|x| x.into()), // `d`
                    &e[i].map(|x| x.into()), // `h`
                    &big_sig1_field::<AB::Expr>(&e[i + 3]), // Σ1(e)
                    &ch_field::<AB::Expr>(&e[i + 3], &e[i + 2], &e[i + 1]), // Ch(e, f, g)
                ],
                &[&w_limbs, &k_limbs], // W and K as 16-bit limbs
                &e[i + 4], // new `e`
                &next.work_vars.carry_e[i], // carries of the addition
            );
        }
    }
}