1use std::{array, borrow::Borrow, cmp::min};
2
3use openvm_circuit::{
4 arch::ExecutionBridge,
5 system::{
6 memory::{offline_checker::MemoryBridge, MemoryAddress},
7 SystemPort,
8 },
9};
10use openvm_circuit_primitives::{
11 bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::not, SubAir,
12};
13use openvm_instructions::{
14 riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
15 LocalOpcode,
16};
17use openvm_sha256_air::{
18 compose, Sha256Air, SHA256_BLOCK_U8S, SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW,
19 SHA256_WORD_U16S, SHA256_WORD_U8S,
20};
21use openvm_sha256_transpiler::Rv32Sha256Opcode;
22use openvm_stark_backend::{
23 interaction::{BusIndex, InteractionBuilder},
24 p3_air::{Air, AirBuilder, BaseAir},
25 p3_field::{Field, FieldAlgebra},
26 p3_matrix::Matrix,
27 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
28};
29
30use super::{
31 Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, SHA256VM_DIGEST_WIDTH,
32 SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, SHA256_READ_SIZE,
33};
34
/// AIR for the RV32 `SHA256` instruction.
///
/// Wraps the generic [`Sha256Air`] sub-AIR and adds the VM-facing constraints: message padding
/// (`eval_padding`), per-row bookkeeping of `len` / `read_ptr` / `cur_timestamp`
/// (`eval_transitions`), message reads from memory (`eval_reads`) and, on the last digest row of
/// a message, the register reads, digest write and execution-bus interaction (`eval_last_row`).
#[derive(Clone, Debug)]
pub struct Sha256VmAir {
    /// Bridge used for the `execute_and_increment_pc` interaction of the instruction.
    pub execution_bridge: ExecutionBridge,
    /// Bridge used for the register reads, the message reads and the digest write.
    pub memory_bridge: MemoryBridge,
    /// Lookup bus used to range-check the most significant limbs of `dst_ptr` / `src_ptr`
    /// (also handed to the inner [`Sha256Air`] by `Sha256VmAir::new`).
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    /// Maximum bit-width of a memory pointer; determines the shift used in the pointer
    /// range check.
    pub ptr_max_bits: usize,
    /// Inner SHA-256 sub-AIR, evaluated starting at column `SHA256VM_CONTROL_WIDTH`.
    pub(super) sha256_subair: Sha256Air,
    /// Encoder for the [`PaddingFlags`] stored in each round row's `control.pad_flags` columns.
    pub(super) padding_encoder: Encoder,
}
49
50impl Sha256VmAir {
51 pub fn new(
52 SystemPort {
53 execution_bus,
54 program_bus,
55 memory_bridge,
56 }: SystemPort,
57 bitwise_lookup_bus: BitwiseOperationLookupBus,
58 ptr_max_bits: usize,
59 self_bus_idx: BusIndex,
60 ) -> Self {
61 Self {
62 execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
63 memory_bridge,
64 bitwise_lookup_bus,
65 ptr_max_bits,
66 sha256_subair: Sha256Air::new(bitwise_lookup_bus, self_bus_idx),
67 padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false),
68 }
69 }
70}
71
72impl<F: Field> BaseAirWithPublicValues<F> for Sha256VmAir {}
73impl<F: Field> PartitionedBaseAir<F> for Sha256VmAir {}
74impl<F: Field> BaseAir<F> for Sha256VmAir {
75 fn width(&self) -> usize {
76 SHA256VM_WIDTH
77 }
78}
79
80impl<AB: InteractionBuilder> Air<AB> for Sha256VmAir {
81 fn eval(&self, builder: &mut AB) {
82 self.eval_padding(builder);
83 self.eval_transitions(builder);
84 self.eval_reads(builder);
85 self.eval_last_row(builder);
86
87 self.sha256_subair.eval(builder, SHA256VM_CONTROL_WIDTH);
88 }
89}
90
/// Padding state of a round row, stored in `control.pad_flags` through
/// `Sha256VmAir::padding_encoder`.
///
/// "Bytes" below are the bytes of the row's message read (indexed `0..16`), as checked by
/// `Sha256VmAir::eval_padding_row`.
///
/// NOTE: the constraint code takes `as usize` ranges over this enum (e.g.
/// `FirstPadding0..=FirstPadding15`, `FirstPadding0_LastRow..=EntirePaddingLastRow`), so the
/// declaration order is load-bearing — do not reorder or insert variants.
// `dead_code`: many variants are never named and are only reached through discriminant
// arithmetic such as `FirstPadding0 as usize + i`.
// `non_camel_case_types`: the `_LastRow` suffix.
#[allow(dead_code, non_camel_case_types)]
pub(super) enum PaddingFlags {
    /// Row is not one of the first 4 rows of a block; no byte-level padding checks apply.
    NotConsidered,
    /// One of the first 4 rows, with no padding yet: every byte equals the message byte.
    NotPadding,
    /// `FirstPadding{k}`: first row containing padding. Bytes `0..k` are message bytes, byte `k`
    /// is `0x80`, and the remaining bytes are zero.
    FirstPadding0,
    FirstPadding1,
    FirstPadding2,
    FirstPadding3,
    FirstPadding4,
    FirstPadding5,
    FirstPadding6,
    FirstPadding7,
    FirstPadding8,
    FirstPadding9,
    FirstPadding10,
    FirstPadding11,
    FirstPadding12,
    FirstPadding13,
    FirstPadding14,
    FirstPadding15,
    /// `FirstPadding{k}_LastRow`: first row containing padding AND the 4th row of the last
    /// block. Bytes `0..k` are message bytes, byte `k` is `0x80`, and bytes `k+1..12` are zero.
    /// Bytes `12..16` hold the message length in bits. Only `k <= 7` exists because bytes
    /// `8..12` must be zero on this row.
    FirstPadding0_LastRow,
    FirstPadding1_LastRow,
    FirstPadding2_LastRow,
    FirstPadding3_LastRow,
    FirstPadding4_LastRow,
    FirstPadding5_LastRow,
    FirstPadding6_LastRow,
    FirstPadding7_LastRow,
    /// Padding already started on an earlier row AND this is the 4th row of the last block:
    /// bytes `0..12` are zero and bytes `12..16` hold the message length in bits.
    EntirePaddingLastRow,
    /// Padding already started on an earlier row: every byte is zero.
    EntirePadding,
}
132
133impl PaddingFlags {
134 pub const COUNT: usize = EntirePadding as usize + 1;
136}
137
138use PaddingFlags::*;
impl Sha256VmAir {
    /// Constrains the padding flags (`control.pad_flags`) of every round row.
    ///
    /// Requires the encoded flag to be one of the [`PaddingFlags`] variants. It then delegates
    /// the row-to-row rules to `eval_padding_transitions` and the per-byte rules to
    /// `eval_padding_row`.
    fn eval_padding<AB: InteractionBuilder>(&self, builder: &mut AB) {
        let main = builder.main();
        let (local, next) = (main.row_slice(0), main.row_slice(1));
        // Both rows are read through the round-row column view (a prefix of the full row).
        let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();
        let next_cols: &Sha256VmRoundCols<AB::Var> = next[..SHA256VM_ROUND_WIDTH].borrow();

        // Encoder-internal constraints on the `pad_flags` columns.
        self.padding_encoder
            .eval(builder, &local_cols.control.pad_flags);

        // The encoded value must lie in the full enum range, i.e. be a defined variant.
        builder.assert_one(self.padding_encoder.contains_flag_range::<AB>(
            &local_cols.control.pad_flags,
            NotConsidered as usize..=EntirePadding as usize,
        ));

        Self::eval_padding_transitions(self, builder, local_cols, next_cols);
        Self::eval_padding_row(self, builder, local_cols);
    }

    /// Row-to-row padding constraints between `local` and `next`.
    ///
    /// Enforces the following:
    /// - `padding_occurred` is boolean and, within a message, is 1 on a suffix of rows. That
    ///   suffix starts on one of the first 4 rows of some block and ends just before the last
    ///   digest row, which has `padding_occurred = 0`.
    /// - The first padding row pins down `control.len`.
    /// - `next.control.pad_flags` matches the `padding_occurred` transition and the row position.
    fn eval_padding_transitions<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &Sha256VmRoundCols<AB::Var>,
        next: &Sha256VmRoundCols<AB::Var>,
    ) {
        // 1 iff `next` is the digest row of the last block, i.e. the message's final row.
        let next_is_last_row = next.inner.flags.is_digest_row * next.inner.flags.is_last_block;

        builder.assert_bool(local.control.padding_occurred);
        // The row right before the final digest row must be a padding row.
        builder
            .when(next_is_last_row.clone())
            .assert_one(local.control.padding_occurred);

        // The final digest row itself has `padding_occurred = 0`.
        builder
            .when(next_is_last_row.clone())
            .assert_zero(next.control.padding_occurred);

        // Once padding has started it continues, unless `next` is the final digest row. When
        // `next` is that row, `local.padding_occurred` is 1 and the selector becomes 0.
        builder
            .when(local.control.padding_occurred - next_is_last_row.clone())
            .assert_one(next.control.padding_occurred);

        // Outside the first 4 rows of a block (and excluding the final digest row)
        // `padding_occurred` is copied from the previous row, so it can only switch on within
        // the first 4 rows.
        builder
            .when_transition()
            .when(not(next.inner.flags.is_first_4_rows) - next_is_last_row)
            .assert_eq(
                next.control.padding_occurred,
                local.control.padding_occurred,
            );

        // Takes values in {-1, 0, 1}. It is 1 exactly when `next` is the first padding row.
        let next_is_first_padding_row =
            next.control.padding_occurred - local.control.padding_occurred;
        // `row_idx` of `next` when it is in 0..4; presumably 0 for other rows (flag_with_val
        // only assigns values to the listed flags).
        let next_row_idx = self.sha256_subair.row_idx_encoder.flag_with_val::<AB>(
            &next.inner.flags.row_idx,
            &(0..4).map(|x| (x, x)).collect::<Vec<_>>(),
        );
        // Number of message bytes before the `0x80` byte on `next`: `k` for
        // `FirstPadding{k}` / `FirstPadding{k}_LastRow`, 0 for every other flag.
        let next_padding_offset = self.padding_encoder.flag_with_val::<AB>(
            &next.control.pad_flags,
            &(0..16)
                .map(|i| (FirstPadding0 as usize + i, i))
                .collect::<Vec<_>>(),
        ) + self.padding_encoder.flag_with_val::<AB>(
            &next.control.pad_flags,
            &(0..8)
                .map(|i| (FirstPadding0_LastRow as usize + i, i))
                .collect::<Vec<_>>(),
        );

        // Message bytes preceding the first padding byte: whole earlier blocks, plus whole
        // earlier reads in this block, plus the message bytes on the padding row itself.
        let expected_len = next.inner.flags.local_block_idx
            * next.control.padding_occurred
            * AB::Expr::from_canonical_usize(SHA256_BLOCK_U8S)
            + next_row_idx * AB::Expr::from_canonical_usize(SHA256_READ_SIZE)
            + next_padding_offset;

        // When the selector is 1, this fixes `len`. A selector of -1 only occurs when `next` is
        // the final digest row: `next.padding_occurred = 0`, and `pad_flags` there is forced to
        // `NotConsidered` below. Presumably `row_idx` is outside 0..4 there too, so both sides
        // are 0 — TODO confirm against the sub-AIR's row_idx encoding.
        builder.when(next_is_first_padding_row).assert_eq(
            expected_len,
            next.control.len * next.control.padding_occurred,
        );

        // Any `FirstPadding*` flag, including the `_LastRow` ones.
        let is_next_first_padding = self.padding_encoder.contains_flag_range::<AB>(
            &next.control.pad_flags,
            FirstPadding0 as usize..=FirstPadding7_LastRow as usize,
        );

        // Any flag that marks the 4th row of the last block.
        let is_next_last_padding = self.padding_encoder.contains_flag_range::<AB>(
            &next.control.pad_flags,
            FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize,
        );

        // `EntirePaddingLastRow` or `EntirePadding`.
        let is_next_entire_padding = self.padding_encoder.contains_flag_range::<AB>(
            &next.control.pad_flags,
            EntirePaddingLastRow as usize..=EntirePadding as usize,
        );

        let is_next_not_considered = self
            .padding_encoder
            .contains_flag::<AB>(&next.control.pad_flags, &[NotConsidered as usize]);

        let is_next_not_padding = self
            .padding_encoder
            .contains_flag::<AB>(&next.control.pad_flags, &[NotPadding as usize]);

        let is_next_4th_row = self
            .sha256_subair
            .row_idx_encoder
            .contains_flag::<AB>(&next.inner.flags.row_idx, &[3]);

        // `NotConsidered` exactly on rows that are not among the first 4 of a block.
        // This constraint has no `when_transition` guard.
        builder.assert_eq(
            not(next.inner.flags.is_first_4_rows),
            is_next_not_considered,
        );

        // On the first 4 rows: `Entire*` iff both the previous and this row are padding.
        builder.when(next.inner.flags.is_first_4_rows).assert_eq(
            local.control.padding_occurred * next.control.padding_occurred,
            is_next_entire_padding,
        );

        // On the first 4 rows: `FirstPadding*` iff padding starts on this row.
        builder.when(next.inner.flags.is_first_4_rows).assert_eq(
            not(local.control.padding_occurred) * next.control.padding_occurred,
            is_next_first_padding,
        );

        // On the first 4 rows: `NotPadding` iff this row is not padding.
        builder
            .when(next.inner.flags.is_first_4_rows)
            .assert_eq(not(next.control.padding_occurred), is_next_not_padding);

        // In the last block, the `*LastRow` flags are used exactly on row index 3.
        builder
            .when(next.inner.flags.is_last_block)
            .assert_eq(is_next_4th_row, is_next_last_padding);
    }

    /// Per-byte padding constraints on a single round row.
    ///
    /// Compares each byte of the message-schedule words `w` against one of the following,
    /// selected by `pad_flags`:
    /// - the bytes read from memory (kept in `carry_or_buffer`),
    /// - zero,
    /// - the `0x80` padding marker.
    ///
    /// On the last padding row it also checks the appended bit-length against `control.len`.
    fn eval_padding_row<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &Sha256VmRoundCols<AB::Var>,
    ) {
        // The bytes read from memory on this row, taken from `carry_or_buffer` in row-major
        // order (same layout as `eval_reads`).
        let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| {
            local.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U8S)]
                [i % (SHA256_WORD_U8S)]
        });

        // Rebuilds byte `i` of the row from the 1-bit limbs of word `w[i / 4]`. The byte index is
        // mirrored within the word, so byte 0 of each group of 4 comes from the bit range
        // `24..32` — presumably the most significant byte, since memory is big-endian relative
        // to the word's little-endian bit layout; confirm against `compose`.
        // NOTE(review): `SHA256_ROUNDS_PER_ROW` is used here as "bytes per word"; this works
        // because it equals `SHA256_WORD_U8S` (presumably both 4). Likewise the literal `4` is
        // bytes per word and `8` is bits per byte.
        let get_ith_byte = |i: usize| {
            let word_idx = i / SHA256_ROUNDS_PER_ROW;
            let word = local.inner.message_schedule.w[word_idx].map(|x| x.into());
            let byte_idx = 4 - i % 4 - 1;
            compose::<AB::Expr>(&word[byte_idx * 8..(byte_idx + 1) * 8], 1)
        };

        let is_not_padding = self
            .padding_encoder
            .contains_flag::<AB>(&local.control.pad_flags, &[NotPadding as usize]);

        for (i, message_byte) in message.iter().enumerate() {
            let w = get_ith_byte(i);
            // Byte `i` is a message byte when:
            // - the row is `NotPadding`, or
            // - the row is `FirstPadding{k}` / `FirstPadding{k}_LastRow` with `k > i`.
            let should_be_message = is_not_padding.clone()
                + if i < 15 {
                    self.padding_encoder.contains_flag_range::<AB>(
                        &local.control.pad_flags,
                        FirstPadding0 as usize + i + 1..=FirstPadding15 as usize,
                    )
                } else {
                    AB::Expr::ZERO
                }
                + if i < 7 {
                    self.padding_encoder.contains_flag_range::<AB>(
                        &local.control.pad_flags,
                        FirstPadding0_LastRow as usize + i + 1..=FirstPadding7_LastRow as usize,
                    )
                } else {
                    AB::Expr::ZERO
                };
            builder
                .when(should_be_message)
                .assert_eq(w.clone(), *message_byte);

            // Byte `i` is zero when:
            // - the row is `EntirePadding`;
            // - `i < 12` on an `EntirePaddingLastRow` row;
            // - `i < 12` on a `FirstPadding{k}_LastRow` row with `k < i`;
            // - the row is `FirstPadding{k}` with `k < i`.
            // Bytes 12..16 of the last row are left for the length.
            let should_be_zero = self
                .padding_encoder
                .contains_flag::<AB>(&local.control.pad_flags, &[EntirePadding as usize])
                + if i < 12 {
                    self.padding_encoder.contains_flag::<AB>(
                        &local.control.pad_flags,
                        &[EntirePaddingLastRow as usize],
                    ) + if i > 0 {
                        self.padding_encoder.contains_flag_range::<AB>(
                            &local.control.pad_flags,
                            FirstPadding0_LastRow as usize
                                ..=min(
                                    FirstPadding0_LastRow as usize + i - 1,
                                    FirstPadding7_LastRow as usize,
                                ),
                        )
                    } else {
                        AB::Expr::ZERO
                    }
                } else {
                    AB::Expr::ZERO
                }
                + if i > 0 {
                    self.padding_encoder.contains_flag_range::<AB>(
                        &local.control.pad_flags,
                        FirstPadding0 as usize..=FirstPadding0 as usize + i - 1,
                    )
                } else {
                    AB::Expr::ZERO
                };
            builder.when(should_be_zero).assert_zero(w.clone());

            // Byte `i` is the `0x80` marker on `FirstPadding{i}` / `FirstPadding{i}_LastRow`. A
            // whole marker byte presumes the message length is a whole number of bytes.
            let should_be_128 = self
                .padding_encoder
                .contains_flag::<AB>(&local.control.pad_flags, &[FirstPadding0 as usize + i])
                + if i < 8 {
                    self.padding_encoder.contains_flag::<AB>(
                        &local.control.pad_flags,
                        &[FirstPadding0_LastRow as usize + i],
                    )
                } else {
                    AB::Expr::ZERO
                };

            builder
                .when(should_be_128)
                .assert_eq(AB::Expr::from_canonical_u32(1 << 7), w);
        }
        // Bytes 12..16 read as a big-endian 32-bit integer: the appended message bit-length.
        let appended_len = compose::<AB::Expr>(
            &[
                get_ith_byte(15),
                get_ith_byte(14),
                get_ith_byte(13),
                get_ith_byte(12),
            ],
            RV32_CELL_BITS,
        );

        let actual_len = local.control.len;

        let is_last_padding_row = self.padding_encoder.contains_flag_range::<AB>(
            &local.control.pad_flags,
            FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize,
        );

        // Bit-length divided by 8 (field inverse of RV32_CELL_BITS) must equal the byte length.
        builder.when(is_last_padding_row.clone()).assert_eq(
            appended_len * AB::F::from_canonical_usize(RV32_CELL_BITS).inverse(),
            actual_len,
        );

        // The three lowest bits of word 3 (the length word) must be zero, so the appended
        // bit-length is a multiple of 8. Assumes these bits are boolean-constrained elsewhere,
        // presumably by the inner sub-AIR.
        builder.when(is_last_padding_row.clone()).assert_zero(
            local.inner.message_schedule.w[3][0]
                + local.inner.message_schedule.w[3][1]
                + local.inner.message_schedule.w[3][2],
        );

        // Bytes 8..12 (upper half of the 64-bit length field) must be zero, so only lengths
        // that fit in the lower 32 bits are accepted.
        for i in 8..12 {
            builder
                .when(is_last_padding_row.clone())
                .assert_zero(get_ith_byte(i));
        }
    }

    /// Row-to-row bookkeeping of `len`, `read_ptr` and `cur_timestamp`.
    ///
    /// These columns are copied forward, except across the boundary after a message's final
    /// (last-block digest) row.
    fn eval_transitions<AB: InteractionBuilder>(&self, builder: &mut AB) {
        let main = builder.main();
        let (local, next) = (main.row_slice(0), main.row_slice(1));
        let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();
        let next_cols: &Sha256VmRoundCols<AB::Var> = next[..SHA256VM_ROUND_WIDTH].borrow();

        // 1 on the final row of a message; the next row then starts a new message.
        let is_last_row =
            local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row;

        // `len` is constant across all rows of one message.
        builder
            .when_transition()
            .when(not::<AB::Expr>(is_last_row.clone()))
            .assert_eq(next_cols.control.len, local_cols.control.len);

        // `read_ptr` advances by SHA256_READ_SIZE after each of the first 4 rows of a block
        // (the rows that perform a read) and is unchanged otherwise.
        let read_ptr_delta = local_cols.inner.flags.is_first_4_rows
            * AB::Expr::from_canonical_usize(SHA256_READ_SIZE);
        builder
            .when_transition()
            .when(not::<AB::Expr>(is_last_row.clone()))
            .assert_eq(
                next_cols.control.read_ptr,
                local_cols.control.read_ptr + read_ptr_delta,
            );

        // `cur_timestamp` advances by 1 after each reading row and is unchanged otherwise.
        let timestamp_delta = local_cols.inner.flags.is_first_4_rows * AB::Expr::ONE;
        builder
            .when_transition()
            .when(not::<AB::Expr>(is_last_row.clone()))
            .assert_eq(
                next_cols.control.cur_timestamp,
                local_cols.control.cur_timestamp + timestamp_delta,
            );
    }

    /// Memory read of SHA256_READ_SIZE message bytes at `read_ptr` (memory address space).
    ///
    /// The read happens at `cur_timestamp` and is enabled only on the first 4 rows of a block.
    /// The bytes read are the `carry_or_buffer` cells, which `eval_padding_row` ties to `w`.
    fn eval_reads<AB: InteractionBuilder>(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();

        // Same row-major byte layout as `message` in `eval_padding_row`
        // (`SHA256_WORD_U16S * 2` bytes per word).
        let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| {
            local_cols.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U16S * 2)]
                [i % (SHA256_WORD_U16S * 2)]
        });

        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
                    local_cols.control.read_ptr,
                ),
                message,
                local_cols.control.cur_timestamp,
                &local_cols.read_aux,
            )
            .eval(builder, local_cols.inner.flags.is_first_4_rows);
    }

    /// Instruction-level constraints on a message's final row (digest row of the last block).
    ///
    /// On that row it does the following, all gated by `is_last_row`:
    /// - reads `rd`, `rs1`, `rs2` at timestamps `t`, `t+1`, `t+2` (where `t` is
    ///   `from_state.timestamp`);
    /// - range-checks the pointers;
    /// - writes the digest;
    /// - sends the execution-bus interaction;
    /// - ties `len`, `read_ptr` and `cur_timestamp` back to the register values.
    fn eval_last_row<AB: InteractionBuilder>(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local_cols: &Sha256VmDigestCols<AB::Var> = local[..SHA256VM_DIGEST_WIDTH].borrow();

        // Each call returns `timestamp + k`, where `k` is the number of previous calls.
        // After all calls, `timestamp_delta` holds the total count.
        let timestamp: AB::Var = local_cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1)
        };

        let is_last_row =
            local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row;

        // rd -> destination pointer (timestamp t).
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.rd_ptr,
                ),
                local_cols.dst_ptr,
                timestamp_pp(),
                &local_cols.register_reads_aux[0],
            )
            .eval(builder, is_last_row.clone());

        // rs1 -> source pointer (timestamp t + 1).
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.rs1_ptr,
                ),
                local_cols.src_ptr,
                timestamp_pp(),
                &local_cols.register_reads_aux[1],
            )
            .eval(builder, is_last_row.clone());

        // rs2 -> message length in bytes (timestamp t + 2).
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.rs2_ptr,
                ),
                local_cols.len_data,
                timestamp_pp(),
                &local_cols.register_reads_aux[2],
            )
            .eval(builder, is_last_row.clone());

        // Range check: the top limb of each pointer, scaled by 2^(32 - ptr_max_bits), must fit
        // in RV32_CELL_BITS bits. This bounds each pointer by 2^ptr_max_bits.
        // NOTE(review): this only gives that bound when ptr_max_bits >= 24 (the shift must not
        // exceed one cell), and the shift amount underflows if ptr_max_bits > 32.
        let shift = AB::Expr::from_canonical_usize(
            1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.ptr_max_bits),
        );
        self.bitwise_lookup_bus
            .send_range(
                local_cols.dst_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(),
                local_cols.src_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(),
            )
            .eval(builder, is_last_row.clone());

        // Number of message reads: 4 per block over `local_block_idx + 1` blocks. Each read
        // advanced the timestamp by 1 and the read pointer by SHA256_READ_SIZE.
        let time_delta = (local_cols.inner.flags.local_block_idx + AB::Expr::ONE)
            * AB::Expr::from_canonical_usize(4);
        let read_ptr_delta = time_delta.clone() * AB::Expr::from_canonical_usize(SHA256_READ_SIZE);

        // Digest bytes with the limb order reversed within each word, so each word is written
        // with its last `final_hash` limb first (presumably big-endian output — confirm).
        let result: [AB::Var; SHA256_WORD_U8S * SHA256_HASH_WORDS] = array::from_fn(|i| {
            local_cols.inner.final_hash[i / SHA256_WORD_U8S]
                [SHA256_WORD_U8S - i % SHA256_WORD_U8S - 1]
        });

        let dst_ptr_val =
            compose::<AB::Expr>(&local_cols.dst_ptr.map(|x| x.into()), RV32_CELL_BITS);

        // Digest write at timestamp t + 3 + time_delta, i.e. after all message reads.
        self.memory_bridge
            .write(
                MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_MEMORY_AS), dst_ptr_val),
                result,
                timestamp_pp() + time_delta.clone(),
                &local_cols.writes_aux,
            )
            .eval(builder, is_last_row.clone());

        // The instruction consumes `timestamp_delta` (4 after the calls above) + time_delta
        // timestamps in total.
        self.execution_bridge
            .execute_and_increment_pc(
                AB::Expr::from_canonical_usize(Rv32Sha256Opcode::SHA256.global_opcode().as_usize()),
                [
                    local_cols.rd_ptr.into(),
                    local_cols.rs1_ptr.into(),
                    local_cols.rs2_ptr.into(),
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
                ],
                local_cols.from_state,
                AB::Expr::from_canonical_usize(timestamp_delta) + time_delta.clone(),
            )
            .eval(builder, is_last_row.clone());

        // The `len` carried through the rows equals the value read from rs2.
        let len_val = compose::<AB::Expr>(&local_cols.len_data.map(|x| x.into()), RV32_CELL_BITS);
        builder
            .when(is_last_row.clone())
            .assert_eq(local_cols.control.len, len_val);
        // `read_ptr` advanced from the rs1 value by exactly the total read size.
        let src_val = compose::<AB::Expr>(&local_cols.src_ptr.map(|x| x.into()), RV32_CELL_BITS);
        builder
            .when(is_last_row.clone())
            .assert_eq(local_cols.control.read_ptr, src_val + read_ptr_delta);
        // Message reads started at t + 3 (after the three register reads) and advanced by
        // time_delta.
        builder.when(is_last_row.clone()).assert_eq(
            local_cols.control.cur_timestamp,
            local_cols.from_state.timestamp + AB::Expr::from_canonical_u32(3) + time_delta,
        );
    }
}