1use std::{array, borrow::Borrow, cmp::min};
2
3use openvm_circuit::{
4 arch::ExecutionBridge,
5 system::{
6 memory::{offline_checker::MemoryBridge, MemoryAddress},
7 SystemPort,
8 },
9};
10use openvm_circuit_primitives::{
11 bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::not, SubAir,
12};
13use openvm_instructions::{
14 riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
15 LocalOpcode,
16};
17use openvm_sha256_air::{
18 compose, Sha256Air, SHA256_BLOCK_U8S, SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW,
19 SHA256_WORD_U16S, SHA256_WORD_U8S,
20};
21use openvm_sha256_transpiler::Rv32Sha256Opcode;
22use openvm_stark_backend::{
23 interaction::{BusIndex, InteractionBuilder},
24 p3_air::{Air, AirBuilder, BaseAir},
25 p3_field::{Field, PrimeCharacteristicRing},
26 p3_matrix::Matrix,
27 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
28};
29
30use super::{
31 Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, SHA256VM_DIGEST_WIDTH,
32 SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, SHA256_READ_SIZE,
33};
34
/// AIR of the SHA-256 VM chip.
///
/// Its `eval` applies the padding, transition, message-read and last-row constraints
/// implemented on this type and then evaluates `sha256_subair`, passing
/// `SHA256VM_CONTROL_WIDTH` as the second argument.
#[derive(Clone, Debug)]
pub struct Sha256VmAir {
    /// Bridge built from the execution and program buses; `eval_last_row` uses it to emit
    /// the execution of the `SHA256` opcode.
    pub execution_bridge: ExecutionBridge,
    /// Bridge used for every memory access of this AIR: the three register reads, the
    /// message reads and the digest write.
    pub memory_bridge: MemoryBridge,
    /// Lookup bus that receives the range checks sent by `eval_last_row`; the same bus is
    /// handed to `Sha256Air::new`.
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    /// Pointer bit width from which `eval_last_row` derives the range-check shift
    /// `1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - ptr_max_bits)`.
    /// NOTE(review): a value above `RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS` would make that
    /// subtraction underflow; the intended lower bound is not visible here — confirm with
    /// the callers.
    pub ptr_max_bits: usize,
    /// Sub-AIR evaluated at the end of `eval`; its `row_idx_encoder` is also used by the
    /// padding constraints.
    pub(super) sha256_subair: Sha256Air,
    /// Encoder for the [`PaddingFlags`] value stored in the `pad_flags` control columns.
    pub(super) padding_encoder: Encoder,
}
49
50impl Sha256VmAir {
51 pub fn new(
52 SystemPort {
53 execution_bus,
54 program_bus,
55 memory_bridge,
56 }: SystemPort,
57 bitwise_lookup_bus: BitwiseOperationLookupBus,
58 ptr_max_bits: usize,
59 self_bus_idx: BusIndex,
60 ) -> Self {
61 Self {
62 execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
63 memory_bridge,
64 bitwise_lookup_bus,
65 ptr_max_bits,
66 sha256_subair: Sha256Air::new(bitwise_lookup_bus, self_bus_idx),
67 padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false),
68 }
69 }
70}
71
// Both marker-style traits are implemented with empty bodies, i.e. this AIR relies entirely
// on whatever defaults the traits provide.
impl<F: Field> BaseAirWithPublicValues<F> for Sha256VmAir {}
impl<F: Field> PartitionedBaseAir<F> for Sha256VmAir {}
impl<F: Field> BaseAir<F> for Sha256VmAir {
    /// Number of trace columns: always `SHA256VM_WIDTH`.
    fn width(&self) -> usize {
        SHA256VM_WIDTH
    }
}
79
impl<AB: InteractionBuilder> Air<AB> for Sha256VmAir {
    /// Emits every constraint of the chip.
    ///
    /// The VM-level constraint groups (padding, column transitions, message reads, last-row
    /// interactions) are emitted first, followed by the constraints of the inner
    /// `sha256_subair`, which is given `SHA256VM_CONTROL_WIDTH` as its second argument.
    /// The calls emit constraints into `builder` in this order, so they should not be
    /// reordered casually.
    fn eval(&self, builder: &mut AB) {
        self.eval_padding(builder);
        self.eval_transitions(builder);
        self.eval_reads(builder);
        self.eval_last_row(builder);

        self.sha256_subair.eval(builder, SHA256VM_CONTROL_WIDTH);
    }
}
90
/// Classification of a row with respect to message padding, stored (encoded by
/// `padding_encoder`) in the `pad_flags` control columns.
///
/// "Byte `k`" below means the `k`-th of the 16 bytes that `eval_padding_row` extracts from
/// the row's message-schedule words `w`.
///
/// The variant order matters: the constraints do arithmetic on the discriminants
/// (e.g. `FirstPadding0 as usize + i`) and use contiguous discriminant ranges.
// `dead_code`: most variants are only reached through that discriminant arithmetic, never
// by name. `non_camel_case_types`: the `*_LastRow` variant names contain an underscore.
#[allow(dead_code, non_camel_case_types)]
pub(super) enum PaddingFlags {
    /// Constrained to be the flag exactly on rows where `is_first_4_rows` is 0.
    /// `eval_padding_row` places no constraint on the bytes of such rows.
    NotConsidered,
    /// Row in the first 4 rows of a block with `padding_occurred = 0`: every byte must equal
    /// the corresponding message byte.
    NotPadding,
    /// `FirstPadding{i}`: first row containing padding. Bytes `0..i` equal the message,
    /// byte `i` is 128 (`1 << 7`), and bytes after `i` are zero.
    FirstPadding0,
    FirstPadding1,
    FirstPadding2,
    FirstPadding3,
    FirstPadding4,
    FirstPadding5,
    FirstPadding6,
    FirstPadding7,
    FirstPadding8,
    FirstPadding9,
    FirstPadding10,
    FirstPadding11,
    FirstPadding12,
    FirstPadding13,
    FirstPadding14,
    FirstPadding15,
    /// `FirstPadding{i}_LastRow`: like `FirstPadding{i}`, but the row is also the one on
    /// which the message length is checked. Bytes `0..i` equal the message, byte `i` is 128,
    /// bytes between `i` and 12 are zero, and bytes `12..16` are tied to `control.len`.
    /// Only `i` in `0..8` exists.
    FirstPadding0_LastRow,
    FirstPadding1_LastRow,
    FirstPadding2_LastRow,
    FirstPadding3_LastRow,
    FirstPadding4_LastRow,
    FirstPadding5_LastRow,
    FirstPadding6_LastRow,
    FirstPadding7_LastRow,
    /// Row that follows earlier padding and is also the length-check row: bytes `0..12` are
    /// zero and bytes `12..16` are tied to `control.len`.
    EntirePaddingLastRow,
    /// Row that follows earlier padding and is not the length-check row: all 16 bytes are
    /// zero.
    EntirePadding,
}
132
impl PaddingFlags {
    /// Total number of variants (including `NotConsidered`).
    ///
    /// Computed from the discriminant of `EntirePadding`, so it is only correct while
    /// `EntirePadding` stays the last variant of the enum.
    pub const COUNT: usize = EntirePadding as usize + 1;
}
137
138use PaddingFlags::*;
139impl Sha256VmAir {
140 fn eval_padding<AB: InteractionBuilder>(&self, builder: &mut AB) {
142 let main = builder.main();
143 let (local, next) = (
144 main.row_slice(0).expect("window should have two elements"),
145 main.row_slice(1).expect("window should have two elements"),
146 );
147 let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();
148 let next_cols: &Sha256VmRoundCols<AB::Var> = next[..SHA256VM_ROUND_WIDTH].borrow();
149
150 self.padding_encoder
152 .eval(builder, &local_cols.control.pad_flags);
153
154 builder.assert_one(self.padding_encoder.contains_flag_range::<AB>(
155 &local_cols.control.pad_flags,
156 NotConsidered as usize..=EntirePadding as usize,
157 ));
158
159 Self::eval_padding_transitions(self, builder, local_cols, next_cols);
160 Self::eval_padding_row(self, builder, local_cols);
161 }
162
163 fn eval_padding_transitions<AB: InteractionBuilder>(
164 &self,
165 builder: &mut AB,
166 local: &Sha256VmRoundCols<AB::Var>,
167 next: &Sha256VmRoundCols<AB::Var>,
168 ) {
169 let next_is_last_row = next.inner.flags.is_digest_row * next.inner.flags.is_last_block;
170
171 builder.assert_bool(local.control.padding_occurred);
176 builder.assert_bool(local.control.padding_spills);
177 builder
180 .when(next_is_last_row.clone())
181 .assert_one(local.control.padding_occurred);
182
183 builder
185 .when(next_is_last_row.clone())
186 .assert_zero(next.control.padding_occurred);
187
188 builder
191 .when(local.control.padding_occurred - next_is_last_row.clone())
192 .assert_one(next.control.padding_occurred);
193
194 builder
198 .when_transition()
199 .when(not(next.inner.flags.is_first_4_rows) - next_is_last_row)
200 .assert_eq(
201 next.control.padding_occurred,
202 local.control.padding_occurred,
203 );
204
205 let next_is_first_padding_row =
207 next.control.padding_occurred - local.control.padding_occurred;
208 let next_row_idx = self.sha256_subair.row_idx_encoder.flag_with_val::<AB>(
210 &next.inner.flags.row_idx,
211 &(0..4).map(|x| (x, x)).collect::<Vec<_>>(),
212 );
213 let next_padding_offset = self.padding_encoder.flag_with_val::<AB>(
216 &next.control.pad_flags,
217 &(0..16)
218 .map(|i| (FirstPadding0 as usize + i, i))
219 .collect::<Vec<_>>(),
220 ) + self.padding_encoder.flag_with_val::<AB>(
221 &next.control.pad_flags,
222 &(0..8)
223 .map(|i| (FirstPadding0_LastRow as usize + i, i))
224 .collect::<Vec<_>>(),
225 );
226
227 let expected_len = next.inner.flags.local_block_idx
232 * next.control.padding_occurred
233 * AB::Expr::from_usize(SHA256_BLOCK_U8S)
234 + next_row_idx * AB::Expr::from_usize(SHA256_READ_SIZE)
235 + next_padding_offset;
236
237 builder.when(next_is_first_padding_row).assert_eq(
241 expected_len,
242 next.control.len * next.control.padding_occurred,
243 );
244
245 let is_next_first_padding = self.padding_encoder.contains_flag_range::<AB>(
247 &next.control.pad_flags,
248 FirstPadding0 as usize..=FirstPadding7_LastRow as usize,
249 );
250
251 let is_next_last_padding = self.padding_encoder.contains_flag_range::<AB>(
252 &next.control.pad_flags,
253 FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize,
254 );
255
256 let is_next_entire_padding = self.padding_encoder.contains_flag_range::<AB>(
257 &next.control.pad_flags,
258 EntirePaddingLastRow as usize..=EntirePadding as usize,
259 );
260
261 let is_next_not_considered = self
262 .padding_encoder
263 .contains_flag::<AB>(&next.control.pad_flags, &[NotConsidered as usize]);
264
265 let is_next_not_padding = self
266 .padding_encoder
267 .contains_flag::<AB>(&next.control.pad_flags, &[NotPadding as usize]);
268
269 let is_next_4th_row = self
270 .sha256_subair
271 .row_idx_encoder
272 .contains_flag::<AB>(&next.inner.flags.row_idx, &[3]);
273
274 let is_next_fp8_15 = self.padding_encoder.contains_flag_range::<AB>(
275 &next.control.pad_flags,
276 FirstPadding8 as usize..=FirstPadding15 as usize,
277 );
278
279 builder.assert_eq(
281 not(next.inner.flags.is_first_4_rows),
282 is_next_not_considered,
283 );
284
285 builder.when(next.inner.flags.is_first_4_rows).assert_eq(
287 local.control.padding_occurred * next.control.padding_occurred,
288 is_next_entire_padding,
289 );
290
291 builder.when(next.inner.flags.is_first_4_rows).assert_eq(
294 not(local.control.padding_occurred) * next.control.padding_occurred,
295 is_next_first_padding.clone(),
296 );
297
298 builder
300 .when(next.inner.flags.is_first_4_rows)
301 .assert_eq(not(next.control.padding_occurred), is_next_not_padding);
302
303 builder.assert_zero(
306 next.control.padding_spills * (AB::Expr::ONE - is_next_first_padding.clone()),
307 );
308 builder
309 .assert_zero(next.control.padding_spills * (AB::Expr::ONE - is_next_4th_row.clone()));
310 builder.assert_zero(next.control.padding_spills * (AB::Expr::ONE - is_next_fp8_15));
311
312 builder.when(is_next_first_padding.clone()).assert_eq(
314 next.inner.flags.is_last_block + next.control.padding_spills,
315 AB::Expr::ONE,
316 );
317
318 builder
321 .when(local.inner.flags.is_digest_row * next.control.padding_occurred)
322 .assert_one(next.inner.flags.is_last_block);
323
324 builder.assert_eq(
326 is_next_4th_row * next.inner.flags.is_last_block,
327 is_next_last_padding,
328 );
329 }
330
331 fn eval_padding_row<AB: InteractionBuilder>(
332 &self,
333 builder: &mut AB,
334 local: &Sha256VmRoundCols<AB::Var>,
335 ) {
336 let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| {
337 local.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U8S)]
338 [i % (SHA256_WORD_U8S)]
339 });
340
341 let get_ith_byte = |i: usize| {
342 let word_idx = i / SHA256_ROUNDS_PER_ROW;
343 let word = local.inner.message_schedule.w[word_idx].map(|x| x.into());
344 let byte_idx = 4 - i % 4 - 1;
346 compose::<AB::Expr>(&word[byte_idx * 8..(byte_idx + 1) * 8], 1)
347 };
348
349 let is_not_padding = self
350 .padding_encoder
351 .contains_flag::<AB>(&local.control.pad_flags, &[NotPadding as usize]);
352
353 for (i, message_byte) in message.iter().enumerate() {
355 let w = get_ith_byte(i);
356 let should_be_message = is_not_padding.clone()
357 + if i < 15 {
358 self.padding_encoder.contains_flag_range::<AB>(
359 &local.control.pad_flags,
360 FirstPadding0 as usize + i + 1..=FirstPadding15 as usize,
361 )
362 } else {
363 AB::Expr::ZERO
364 }
365 + if i < 7 {
366 self.padding_encoder.contains_flag_range::<AB>(
367 &local.control.pad_flags,
368 FirstPadding0_LastRow as usize + i + 1..=FirstPadding7_LastRow as usize,
369 )
370 } else {
371 AB::Expr::ZERO
372 };
373 builder
374 .when(should_be_message)
375 .assert_eq(w.clone(), *message_byte);
376
377 let should_be_zero = self
378 .padding_encoder
379 .contains_flag::<AB>(&local.control.pad_flags, &[EntirePadding as usize])
380 + if i < 12 {
381 self.padding_encoder.contains_flag::<AB>(
382 &local.control.pad_flags,
383 &[EntirePaddingLastRow as usize],
384 ) + if i > 0 {
385 self.padding_encoder.contains_flag_range::<AB>(
386 &local.control.pad_flags,
387 FirstPadding0_LastRow as usize
388 ..=min(
389 FirstPadding0_LastRow as usize + i - 1,
390 FirstPadding7_LastRow as usize,
391 ),
392 )
393 } else {
394 AB::Expr::ZERO
395 }
396 } else {
397 AB::Expr::ZERO
398 }
399 + if i > 0 {
400 self.padding_encoder.contains_flag_range::<AB>(
401 &local.control.pad_flags,
402 FirstPadding0 as usize..=FirstPadding0 as usize + i - 1,
403 )
404 } else {
405 AB::Expr::ZERO
406 };
407 builder.when(should_be_zero).assert_zero(w.clone());
408
409 let should_be_128 = self
412 .padding_encoder
413 .contains_flag::<AB>(&local.control.pad_flags, &[FirstPadding0 as usize + i])
414 + if i < 8 {
415 self.padding_encoder.contains_flag::<AB>(
416 &local.control.pad_flags,
417 &[FirstPadding0_LastRow as usize + i],
418 )
419 } else {
420 AB::Expr::ZERO
421 };
422
423 builder
424 .when(should_be_128)
425 .assert_eq(AB::Expr::from_u32(1 << 7), w);
426
427 }
429 let appended_len = compose::<AB::Expr>(
430 &[
431 get_ith_byte(15),
432 get_ith_byte(14),
433 get_ith_byte(13),
434 get_ith_byte(12),
435 ],
436 RV32_CELL_BITS,
437 );
438
439 let actual_len = local.control.len;
440
441 let is_last_padding_row = self.padding_encoder.contains_flag_range::<AB>(
442 &local.control.pad_flags,
443 FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize,
444 );
445
446 builder.when(is_last_padding_row.clone()).assert_eq(
447 appended_len * AB::F::from_usize(RV32_CELL_BITS).inverse(), actual_len,
449 );
450
451 builder.when(is_last_padding_row.clone()).assert_zero(
453 local.inner.message_schedule.w[3][0]
454 + local.inner.message_schedule.w[3][1]
455 + local.inner.message_schedule.w[3][2],
456 );
457
458 for i in 8..12 {
462 builder
463 .when(is_last_padding_row.clone())
464 .assert_zero(get_ith_byte(i));
465 }
466 }
467 fn eval_transitions<AB: InteractionBuilder>(&self, builder: &mut AB) {
469 let main = builder.main();
470 let (local, next) = (
471 main.row_slice(0).expect("window should have two elements"),
472 main.row_slice(1).expect("window should have two elements"),
473 );
474 let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();
475 let next_cols: &Sha256VmRoundCols<AB::Var> = next[..SHA256VM_ROUND_WIDTH].borrow();
476
477 let is_last_row =
478 local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row;
479
480 builder
482 .when_transition()
483 .when(not::<AB::Expr>(is_last_row.clone()))
484 .assert_eq(next_cols.control.len, local_cols.control.len);
485
486 let read_ptr_delta =
489 local_cols.inner.flags.is_first_4_rows * AB::Expr::from_usize(SHA256_READ_SIZE);
490 builder
491 .when_transition()
492 .when(not::<AB::Expr>(is_last_row.clone()))
493 .assert_eq(
494 next_cols.control.read_ptr,
495 local_cols.control.read_ptr + read_ptr_delta,
496 );
497
498 let timestamp_delta = local_cols.inner.flags.is_first_4_rows * AB::Expr::ONE;
500 builder
501 .when_transition()
502 .when(not::<AB::Expr>(is_last_row.clone()))
503 .assert_eq(
504 next_cols.control.cur_timestamp,
505 local_cols.control.cur_timestamp + timestamp_delta,
506 );
507 }
508
509 fn eval_reads<AB: InteractionBuilder>(&self, builder: &mut AB) {
511 let main = builder.main();
512 let local = main.row_slice(0).expect("window should have two elements");
513 let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();
514
515 let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| {
516 local_cols.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U16S * 2)]
517 [i % (SHA256_WORD_U16S * 2)]
518 });
519
520 self.memory_bridge
521 .read(
522 MemoryAddress::new(
523 AB::Expr::from_u32(RV32_MEMORY_AS),
524 local_cols.control.read_ptr,
525 ),
526 message,
527 local_cols.control.cur_timestamp,
528 &local_cols.read_aux,
529 )
530 .eval(builder, local_cols.inner.flags.is_first_4_rows);
531 }
532 fn eval_last_row<AB: InteractionBuilder>(&self, builder: &mut AB) {
534 let main = builder.main();
535 let local = main.row_slice(0).expect("window should have two elements");
536 let local_cols: &Sha256VmDigestCols<AB::Var> = local[..SHA256VM_DIGEST_WIDTH].borrow();
537
538 let timestamp: AB::Var = local_cols.from_state.timestamp;
539 let mut timestamp_delta: usize = 0;
540 let mut timestamp_pp = || {
541 timestamp_delta += 1;
542 timestamp + AB::Expr::from_usize(timestamp_delta - 1)
543 };
544
545 let is_last_row =
546 local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row;
547
548 self.memory_bridge
549 .read(
550 MemoryAddress::new(AB::Expr::from_u32(RV32_REGISTER_AS), local_cols.rd_ptr),
551 local_cols.dst_ptr,
552 timestamp_pp(),
553 &local_cols.register_reads_aux[0],
554 )
555 .eval(builder, is_last_row.clone());
556
557 self.memory_bridge
558 .read(
559 MemoryAddress::new(AB::Expr::from_u32(RV32_REGISTER_AS), local_cols.rs1_ptr),
560 local_cols.src_ptr,
561 timestamp_pp(),
562 &local_cols.register_reads_aux[1],
563 )
564 .eval(builder, is_last_row.clone());
565
566 self.memory_bridge
567 .read(
568 MemoryAddress::new(AB::Expr::from_u32(RV32_REGISTER_AS), local_cols.rs2_ptr),
569 local_cols.len_data,
570 timestamp_pp(),
571 &local_cols.register_reads_aux[2],
572 )
573 .eval(builder, is_last_row.clone());
574
575 let shift = AB::Expr::from_usize(
577 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.ptr_max_bits),
578 );
579 self.bitwise_lookup_bus
581 .send_range(
582 local_cols.dst_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(),
585 local_cols.src_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(),
586 )
587 .eval(builder, is_last_row.clone());
588
589 self.bitwise_lookup_bus
597 .send_range(
598 local_cols.len_data[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(),
599 AB::Expr::ZERO,
600 )
601 .eval(builder, is_last_row.clone());
602
603 let time_delta =
605 (local_cols.inner.flags.local_block_idx + AB::Expr::ONE) * AB::Expr::from_usize(4);
606 let read_ptr_delta = time_delta.clone() * AB::Expr::from_usize(SHA256_READ_SIZE);
608
609 let result: [AB::Var; SHA256_WORD_U8S * SHA256_HASH_WORDS] = array::from_fn(|i| {
610 local_cols.inner.final_hash[i / SHA256_WORD_U8S]
612 [SHA256_WORD_U8S - i % SHA256_WORD_U8S - 1]
613 });
614
615 let dst_ptr_val =
616 compose::<AB::Expr>(&local_cols.dst_ptr.map(|x| x.into()), RV32_CELL_BITS);
617
618 self.memory_bridge
622 .write(
623 MemoryAddress::new(AB::Expr::from_u32(RV32_MEMORY_AS), dst_ptr_val),
624 result,
625 timestamp_pp() + time_delta.clone(),
626 &local_cols.writes_aux,
627 )
628 .eval(builder, is_last_row.clone());
629
630 self.execution_bridge
631 .execute_and_increment_pc(
632 AB::Expr::from_usize(Rv32Sha256Opcode::SHA256.global_opcode().as_usize()),
633 [
634 local_cols.rd_ptr.into(),
635 local_cols.rs1_ptr.into(),
636 local_cols.rs2_ptr.into(),
637 AB::Expr::from_u32(RV32_REGISTER_AS),
638 AB::Expr::from_u32(RV32_MEMORY_AS),
639 ],
640 local_cols.from_state,
641 AB::Expr::from_usize(timestamp_delta) + time_delta.clone(),
642 )
643 .eval(builder, is_last_row.clone());
644
645 let len_val = compose::<AB::Expr>(&local_cols.len_data.map(|x| x.into()), RV32_CELL_BITS);
647 builder
648 .when(is_last_row.clone())
649 .assert_eq(local_cols.control.len, len_val);
650 let src_val = compose::<AB::Expr>(&local_cols.src_ptr.map(|x| x.into()), RV32_CELL_BITS);
652 builder
653 .when(is_last_row.clone())
654 .assert_eq(local_cols.control.read_ptr, src_val + read_ptr_delta);
655 builder.when(is_last_row.clone()).assert_eq(
657 local_cols.control.cur_timestamp,
658 local_cols.from_state.timestamp + AB::Expr::from_u32(3) + time_delta,
659 );
660 }
661}