1use std::{array, borrow::Borrow, cmp::min};
2
3use openvm_circuit::{
4 arch::ExecutionBridge,
5 system::{
6 memory::{offline_checker::MemoryBridge, MemoryAddress},
7 SystemPort,
8 },
9};
10use openvm_circuit_primitives::{
11 bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::not, SubAir,
12};
13use openvm_instructions::{
14 riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
15 LocalOpcode,
16};
17use openvm_sha256_air::{
18 compose, Sha256Air, SHA256_BLOCK_U8S, SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW,
19 SHA256_WORD_U16S, SHA256_WORD_U8S,
20};
21use openvm_sha256_transpiler::Rv32Sha256Opcode;
22use openvm_stark_backend::{
23 interaction::{BusIndex, InteractionBuilder},
24 p3_air::{Air, AirBuilder, BaseAir},
25 p3_field::{Field, PrimeCharacteristicRing},
26 p3_matrix::Matrix,
27 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
28};
29
30use super::{
31 Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, SHA256VM_DIGEST_WIDTH,
32 SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, SHA256_READ_SIZE,
33};
34
/// AIR for the RV32 SHA-256 instruction.
///
/// `Sha256VmAir` constrains the message padding, the memory reads of the message, the
/// `len` / `read_ptr` / `cur_timestamp` bookkeeping between rows, and the instruction-level
/// register reads, digest write and execution record. The compression itself is constrained by
/// the inner `sha256_subair`.
#[derive(Clone, Debug)]
pub struct Sha256VmAir {
    /// Records the execution of the instruction and the pc increment (see `eval_last_row`).
    pub execution_bridge: ExecutionBridge,
    /// Used for the three register reads, the message reads and the digest write.
    pub memory_bridge: MemoryBridge,
    /// Bus for the pointer range checks; the same bus is handed to the inner `Sha256Air`.
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    /// Maximum number of bits allowed for a memory pointer. The pointer range check only
    /// scales the most significant limb, so it is meaningful only when this is at least 24.
    pub ptr_max_bits: usize,
    /// Inner AIR constraining the SHA-256 compression.
    pub(super) sha256_subair: Sha256Air,
    /// Encoder for the `pad_flags` columns, with one flag per `PaddingFlags` variant.
    pub(super) padding_encoder: Encoder,
}
49
50impl Sha256VmAir {
51 pub fn new(
52 SystemPort {
53 execution_bus,
54 program_bus,
55 memory_bridge,
56 }: SystemPort,
57 bitwise_lookup_bus: BitwiseOperationLookupBus,
58 ptr_max_bits: usize,
59 self_bus_idx: BusIndex,
60 ) -> Self {
61 Self {
62 execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
63 memory_bridge,
64 bitwise_lookup_bus,
65 ptr_max_bits,
66 sha256_subair: Sha256Air::new(bitwise_lookup_bus, self_bus_idx),
67 padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false),
68 }
69 }
70}
71
72impl<F: Field> BaseAirWithPublicValues<F> for Sha256VmAir {}
73impl<F: Field> PartitionedBaseAir<F> for Sha256VmAir {}
74impl<F: Field> BaseAir<F> for Sha256VmAir {
75 fn width(&self) -> usize {
76 SHA256VM_WIDTH
77 }
78}
79
impl<AB: InteractionBuilder> Air<AB> for Sha256VmAir {
    /// Adds all constraints and bus interactions of the SHA-256 VM chip to `builder`.
    ///
    /// This AIR adds the padding, row-transition, message-read and digest-row constraints
    /// itself. The compression constraints come from the inner `Sha256Air`.
    ///
    /// The inner AIR is passed `SHA256VM_CONTROL_WIDTH`. This is presumably the offset of its
    /// columns after the control columns (TODO confirm against `Sha256Air::eval`).
    fn eval(&self, builder: &mut AB) {
        // NOTE(review): this call order fixes the order in which bus interactions are
        // registered, which likely affects the generated keys; keep it stable.
        self.eval_padding(builder);
        self.eval_transitions(builder);
        self.eval_reads(builder);
        self.eval_last_row(builder);

        self.sha256_subair.eval(builder, SHA256VM_CONTROL_WIDTH);
    }
}
90
/// Padding state of the 16 message bytes held by a round row.
///
/// The state is stored in the `pad_flags` columns through `padding_encoder`. Variants are used
/// only through their discriminants, as `usize` flag indices.
// `dead_code`: most variants are never named directly. They are addressed by discriminant
// arithmetic such as `FirstPadding0 as usize + i`.
// `non_camel_case_types`: needed for the `_LastRow` suffix.
#[allow(dead_code, non_camel_case_types)]
pub(super) enum PaddingFlags {
    /// Row outside the first 4 rows of a block.
    /// The row's bytes are not constrained by the padding logic.
    NotConsidered,
    /// None of the row's bytes are padding; all 16 bytes must equal the message bytes read.
    NotPadding,
    /// `FirstPadding{i}`: first row of the message that contains padding.
    ///
    /// * Bytes `0..i` are message bytes.
    /// * Byte `i` is `0x80`.
    /// * All later bytes are zero.
    FirstPadding0,
    FirstPadding1,
    FirstPadding2,
    FirstPadding3,
    FirstPadding4,
    FirstPadding5,
    FirstPadding6,
    FirstPadding7,
    FirstPadding8,
    FirstPadding9,
    FirstPadding10,
    FirstPadding11,
    FirstPadding12,
    FirstPadding13,
    FirstPadding14,
    FirstPadding15,
    /// `FirstPadding{i}_LastRow`: first row with padding that is also the 4th row of the last
    /// block, the row that carries the message length.
    ///
    /// * Bytes `0..i` are message bytes.
    /// * Byte `i` is `0x80`.
    /// * Bytes `i + 1..12` are zero.
    /// * Bytes `12..16` hold the bit-length.
    ///
    /// Only `i < 8` exists, because bytes `8..12` must be zero on such rows.
    FirstPadding0_LastRow,
    FirstPadding1_LastRow,
    FirstPadding2_LastRow,
    FirstPadding3_LastRow,
    FirstPadding4_LastRow,
    FirstPadding5_LastRow,
    FirstPadding6_LastRow,
    FirstPadding7_LastRow,
    /// Padding started on an earlier row, and this is the 4th row of the last block.
    /// Bytes `0..12` are zero and bytes `12..16` hold the bit-length.
    EntirePaddingLastRow,
    /// Padding started on an earlier row; all 16 bytes are zero.
    EntirePadding,
}
132
133impl PaddingFlags {
134 pub const COUNT: usize = EntirePadding as usize + 1;
136}
137
138use PaddingFlags::*;
139impl Sha256VmAir {
140 fn eval_padding<AB: InteractionBuilder>(&self, builder: &mut AB) {
142 let main = builder.main();
143 let (local, next) = (
144 main.row_slice(0).expect("window should have two elements"),
145 main.row_slice(1).expect("window should have two elements"),
146 );
147 let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();
148 let next_cols: &Sha256VmRoundCols<AB::Var> = next[..SHA256VM_ROUND_WIDTH].borrow();
149
150 self.padding_encoder
152 .eval(builder, &local_cols.control.pad_flags);
153
154 builder.assert_one(self.padding_encoder.contains_flag_range::<AB>(
155 &local_cols.control.pad_flags,
156 NotConsidered as usize..=EntirePadding as usize,
157 ));
158
159 Self::eval_padding_transitions(self, builder, local_cols, next_cols);
160 Self::eval_padding_row(self, builder, local_cols);
161 }
162
163 fn eval_padding_transitions<AB: InteractionBuilder>(
164 &self,
165 builder: &mut AB,
166 local: &Sha256VmRoundCols<AB::Var>,
167 next: &Sha256VmRoundCols<AB::Var>,
168 ) {
169 let next_is_last_row = next.inner.flags.is_digest_row * next.inner.flags.is_last_block;
170
171 builder.assert_bool(local.control.padding_occurred);
176 builder
179 .when(next_is_last_row.clone())
180 .assert_one(local.control.padding_occurred);
181
182 builder
184 .when(next_is_last_row.clone())
185 .assert_zero(next.control.padding_occurred);
186
187 builder
190 .when(local.control.padding_occurred - next_is_last_row.clone())
191 .assert_one(next.control.padding_occurred);
192
193 builder
197 .when_transition()
198 .when(not(next.inner.flags.is_first_4_rows) - next_is_last_row)
199 .assert_eq(
200 next.control.padding_occurred,
201 local.control.padding_occurred,
202 );
203
204 let next_is_first_padding_row =
206 next.control.padding_occurred - local.control.padding_occurred;
207 let next_row_idx = self.sha256_subair.row_idx_encoder.flag_with_val::<AB>(
209 &next.inner.flags.row_idx,
210 &(0..4).map(|x| (x, x)).collect::<Vec<_>>(),
211 );
212 let next_padding_offset = self.padding_encoder.flag_with_val::<AB>(
215 &next.control.pad_flags,
216 &(0..16)
217 .map(|i| (FirstPadding0 as usize + i, i))
218 .collect::<Vec<_>>(),
219 ) + self.padding_encoder.flag_with_val::<AB>(
220 &next.control.pad_flags,
221 &(0..8)
222 .map(|i| (FirstPadding0_LastRow as usize + i, i))
223 .collect::<Vec<_>>(),
224 );
225
226 let expected_len = next.inner.flags.local_block_idx
231 * next.control.padding_occurred
232 * AB::Expr::from_usize(SHA256_BLOCK_U8S)
233 + next_row_idx * AB::Expr::from_usize(SHA256_READ_SIZE)
234 + next_padding_offset;
235
236 builder.when(next_is_first_padding_row).assert_eq(
240 expected_len,
241 next.control.len * next.control.padding_occurred,
242 );
243
244 let is_next_first_padding = self.padding_encoder.contains_flag_range::<AB>(
246 &next.control.pad_flags,
247 FirstPadding0 as usize..=FirstPadding7_LastRow as usize,
248 );
249
250 let is_next_last_padding = self.padding_encoder.contains_flag_range::<AB>(
251 &next.control.pad_flags,
252 FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize,
253 );
254
255 let is_next_entire_padding = self.padding_encoder.contains_flag_range::<AB>(
256 &next.control.pad_flags,
257 EntirePaddingLastRow as usize..=EntirePadding as usize,
258 );
259
260 let is_next_not_considered = self
261 .padding_encoder
262 .contains_flag::<AB>(&next.control.pad_flags, &[NotConsidered as usize]);
263
264 let is_next_not_padding = self
265 .padding_encoder
266 .contains_flag::<AB>(&next.control.pad_flags, &[NotPadding as usize]);
267
268 let is_next_4th_row = self
269 .sha256_subair
270 .row_idx_encoder
271 .contains_flag::<AB>(&next.inner.flags.row_idx, &[3]);
272
273 builder.assert_eq(
275 not(next.inner.flags.is_first_4_rows),
276 is_next_not_considered,
277 );
278
279 builder.when(next.inner.flags.is_first_4_rows).assert_eq(
281 local.control.padding_occurred * next.control.padding_occurred,
282 is_next_entire_padding,
283 );
284
285 builder.when(next.inner.flags.is_first_4_rows).assert_eq(
288 not(local.control.padding_occurred) * next.control.padding_occurred,
289 is_next_first_padding,
290 );
291
292 builder
294 .when(next.inner.flags.is_first_4_rows)
295 .assert_eq(not(next.control.padding_occurred), is_next_not_padding);
296
297 builder
299 .when(next.inner.flags.is_last_block)
300 .assert_eq(is_next_4th_row, is_next_last_padding);
301 }
302
303 fn eval_padding_row<AB: InteractionBuilder>(
304 &self,
305 builder: &mut AB,
306 local: &Sha256VmRoundCols<AB::Var>,
307 ) {
308 let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| {
309 local.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U8S)]
310 [i % (SHA256_WORD_U8S)]
311 });
312
313 let get_ith_byte = |i: usize| {
314 let word_idx = i / SHA256_ROUNDS_PER_ROW;
315 let word = local.inner.message_schedule.w[word_idx].map(|x| x.into());
316 let byte_idx = 4 - i % 4 - 1;
318 compose::<AB::Expr>(&word[byte_idx * 8..(byte_idx + 1) * 8], 1)
319 };
320
321 let is_not_padding = self
322 .padding_encoder
323 .contains_flag::<AB>(&local.control.pad_flags, &[NotPadding as usize]);
324
325 for (i, message_byte) in message.iter().enumerate() {
327 let w = get_ith_byte(i);
328 let should_be_message = is_not_padding.clone()
329 + if i < 15 {
330 self.padding_encoder.contains_flag_range::<AB>(
331 &local.control.pad_flags,
332 FirstPadding0 as usize + i + 1..=FirstPadding15 as usize,
333 )
334 } else {
335 AB::Expr::ZERO
336 }
337 + if i < 7 {
338 self.padding_encoder.contains_flag_range::<AB>(
339 &local.control.pad_flags,
340 FirstPadding0_LastRow as usize + i + 1..=FirstPadding7_LastRow as usize,
341 )
342 } else {
343 AB::Expr::ZERO
344 };
345 builder
346 .when(should_be_message)
347 .assert_eq(w.clone(), *message_byte);
348
349 let should_be_zero = self
350 .padding_encoder
351 .contains_flag::<AB>(&local.control.pad_flags, &[EntirePadding as usize])
352 + if i < 12 {
353 self.padding_encoder.contains_flag::<AB>(
354 &local.control.pad_flags,
355 &[EntirePaddingLastRow as usize],
356 ) + if i > 0 {
357 self.padding_encoder.contains_flag_range::<AB>(
358 &local.control.pad_flags,
359 FirstPadding0_LastRow as usize
360 ..=min(
361 FirstPadding0_LastRow as usize + i - 1,
362 FirstPadding7_LastRow as usize,
363 ),
364 )
365 } else {
366 AB::Expr::ZERO
367 }
368 } else {
369 AB::Expr::ZERO
370 }
371 + if i > 0 {
372 self.padding_encoder.contains_flag_range::<AB>(
373 &local.control.pad_flags,
374 FirstPadding0 as usize..=FirstPadding0 as usize + i - 1,
375 )
376 } else {
377 AB::Expr::ZERO
378 };
379 builder.when(should_be_zero).assert_zero(w.clone());
380
381 let should_be_128 = self
384 .padding_encoder
385 .contains_flag::<AB>(&local.control.pad_flags, &[FirstPadding0 as usize + i])
386 + if i < 8 {
387 self.padding_encoder.contains_flag::<AB>(
388 &local.control.pad_flags,
389 &[FirstPadding0_LastRow as usize + i],
390 )
391 } else {
392 AB::Expr::ZERO
393 };
394
395 builder
396 .when(should_be_128)
397 .assert_eq(AB::Expr::from_u32(1 << 7), w);
398
399 }
401 let appended_len = compose::<AB::Expr>(
402 &[
403 get_ith_byte(15),
404 get_ith_byte(14),
405 get_ith_byte(13),
406 get_ith_byte(12),
407 ],
408 RV32_CELL_BITS,
409 );
410
411 let actual_len = local.control.len;
412
413 let is_last_padding_row = self.padding_encoder.contains_flag_range::<AB>(
414 &local.control.pad_flags,
415 FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize,
416 );
417
418 builder.when(is_last_padding_row.clone()).assert_eq(
419 appended_len * AB::F::from_usize(RV32_CELL_BITS).inverse(), actual_len,
421 );
422
423 builder.when(is_last_padding_row.clone()).assert_zero(
425 local.inner.message_schedule.w[3][0]
426 + local.inner.message_schedule.w[3][1]
427 + local.inner.message_schedule.w[3][2],
428 );
429
430 for i in 8..12 {
434 builder
435 .when(is_last_padding_row.clone())
436 .assert_zero(get_ith_byte(i));
437 }
438 }
439 fn eval_transitions<AB: InteractionBuilder>(&self, builder: &mut AB) {
441 let main = builder.main();
442 let (local, next) = (
443 main.row_slice(0).expect("window should have two elements"),
444 main.row_slice(1).expect("window should have two elements"),
445 );
446 let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();
447 let next_cols: &Sha256VmRoundCols<AB::Var> = next[..SHA256VM_ROUND_WIDTH].borrow();
448
449 let is_last_row =
450 local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row;
451
452 builder
454 .when_transition()
455 .when(not::<AB::Expr>(is_last_row.clone()))
456 .assert_eq(next_cols.control.len, local_cols.control.len);
457
458 let read_ptr_delta =
461 local_cols.inner.flags.is_first_4_rows * AB::Expr::from_usize(SHA256_READ_SIZE);
462 builder
463 .when_transition()
464 .when(not::<AB::Expr>(is_last_row.clone()))
465 .assert_eq(
466 next_cols.control.read_ptr,
467 local_cols.control.read_ptr + read_ptr_delta,
468 );
469
470 let timestamp_delta = local_cols.inner.flags.is_first_4_rows * AB::Expr::ONE;
472 builder
473 .when_transition()
474 .when(not::<AB::Expr>(is_last_row.clone()))
475 .assert_eq(
476 next_cols.control.cur_timestamp,
477 local_cols.control.cur_timestamp + timestamp_delta,
478 );
479 }
480
481 fn eval_reads<AB: InteractionBuilder>(&self, builder: &mut AB) {
483 let main = builder.main();
484 let local = main.row_slice(0).expect("window should have two elements");
485 let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();
486
487 let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| {
488 local_cols.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U16S * 2)]
489 [i % (SHA256_WORD_U16S * 2)]
490 });
491
492 self.memory_bridge
493 .read(
494 MemoryAddress::new(
495 AB::Expr::from_u32(RV32_MEMORY_AS),
496 local_cols.control.read_ptr,
497 ),
498 message,
499 local_cols.control.cur_timestamp,
500 &local_cols.read_aux,
501 )
502 .eval(builder, local_cols.inner.flags.is_first_4_rows);
503 }
504 fn eval_last_row<AB: InteractionBuilder>(&self, builder: &mut AB) {
506 let main = builder.main();
507 let local = main.row_slice(0).expect("window should have two elements");
508 let local_cols: &Sha256VmDigestCols<AB::Var> = local[..SHA256VM_DIGEST_WIDTH].borrow();
509
510 let timestamp: AB::Var = local_cols.from_state.timestamp;
511 let mut timestamp_delta: usize = 0;
512 let mut timestamp_pp = || {
513 timestamp_delta += 1;
514 timestamp + AB::Expr::from_usize(timestamp_delta - 1)
515 };
516
517 let is_last_row =
518 local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row;
519
520 self.memory_bridge
521 .read(
522 MemoryAddress::new(AB::Expr::from_u32(RV32_REGISTER_AS), local_cols.rd_ptr),
523 local_cols.dst_ptr,
524 timestamp_pp(),
525 &local_cols.register_reads_aux[0],
526 )
527 .eval(builder, is_last_row.clone());
528
529 self.memory_bridge
530 .read(
531 MemoryAddress::new(AB::Expr::from_u32(RV32_REGISTER_AS), local_cols.rs1_ptr),
532 local_cols.src_ptr,
533 timestamp_pp(),
534 &local_cols.register_reads_aux[1],
535 )
536 .eval(builder, is_last_row.clone());
537
538 self.memory_bridge
539 .read(
540 MemoryAddress::new(AB::Expr::from_u32(RV32_REGISTER_AS), local_cols.rs2_ptr),
541 local_cols.len_data,
542 timestamp_pp(),
543 &local_cols.register_reads_aux[2],
544 )
545 .eval(builder, is_last_row.clone());
546
547 let shift = AB::Expr::from_usize(
551 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.ptr_max_bits),
552 );
553 self.bitwise_lookup_bus
555 .send_range(
556 local_cols.dst_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(),
559 local_cols.src_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(),
560 )
561 .eval(builder, is_last_row.clone());
562
563 let time_delta =
565 (local_cols.inner.flags.local_block_idx + AB::Expr::ONE) * AB::Expr::from_usize(4);
566 let read_ptr_delta = time_delta.clone() * AB::Expr::from_usize(SHA256_READ_SIZE);
568
569 let result: [AB::Var; SHA256_WORD_U8S * SHA256_HASH_WORDS] = array::from_fn(|i| {
570 local_cols.inner.final_hash[i / SHA256_WORD_U8S]
572 [SHA256_WORD_U8S - i % SHA256_WORD_U8S - 1]
573 });
574
575 let dst_ptr_val =
576 compose::<AB::Expr>(&local_cols.dst_ptr.map(|x| x.into()), RV32_CELL_BITS);
577
578 self.memory_bridge
582 .write(
583 MemoryAddress::new(AB::Expr::from_u32(RV32_MEMORY_AS), dst_ptr_val),
584 result,
585 timestamp_pp() + time_delta.clone(),
586 &local_cols.writes_aux,
587 )
588 .eval(builder, is_last_row.clone());
589
590 self.execution_bridge
591 .execute_and_increment_pc(
592 AB::Expr::from_usize(Rv32Sha256Opcode::SHA256.global_opcode().as_usize()),
593 [
594 local_cols.rd_ptr.into(),
595 local_cols.rs1_ptr.into(),
596 local_cols.rs2_ptr.into(),
597 AB::Expr::from_u32(RV32_REGISTER_AS),
598 AB::Expr::from_u32(RV32_MEMORY_AS),
599 ],
600 local_cols.from_state,
601 AB::Expr::from_usize(timestamp_delta) + time_delta.clone(),
602 )
603 .eval(builder, is_last_row.clone());
604
605 let len_val = compose::<AB::Expr>(&local_cols.len_data.map(|x| x.into()), RV32_CELL_BITS);
607 builder
608 .when(is_last_row.clone())
609 .assert_eq(local_cols.control.len, len_val);
610 let src_val = compose::<AB::Expr>(&local_cols.src_ptr.map(|x| x.into()), RV32_CELL_BITS);
612 builder
613 .when(is_last_row.clone())
614 .assert_eq(local_cols.control.read_ptr, src_val + read_ptr_delta);
615 builder.when(is_last_row.clone()).assert_eq(
617 local_cols.control.cur_timestamp,
618 local_cols.from_state.timestamp + AB::Expr::from_u32(3) + time_delta,
619 );
620 }
621}