1use std::{array, borrow::Borrow, cmp::min};
2
3use openvm_circuit::{
4 arch::ExecutionBridge,
5 system::memory::{offline_checker::MemoryBridge, MemoryAddress},
6};
7use openvm_circuit_primitives::{
8 bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::not, SubAir,
9};
10use openvm_instructions::{
11 riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
12 LocalOpcode,
13};
14use openvm_sha256_air::{
15 compose, Sha256Air, SHA256_BLOCK_U8S, SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW,
16 SHA256_WORD_U16S, SHA256_WORD_U8S,
17};
18use openvm_sha256_transpiler::Rv32Sha256Opcode;
19use openvm_stark_backend::{
20 interaction::InteractionBuilder,
21 p3_air::{Air, AirBuilder, BaseAir},
22 p3_field::{Field, FieldAlgebra},
23 p3_matrix::Matrix,
24 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
25};
26
27use super::{
28 Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, SHA256VM_DIGEST_WIDTH,
29 SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, SHA256_READ_SIZE,
30};
31
/// AIR for the RV32 SHA-256 instruction.
///
/// This AIR itself constrains the message padding, the memory reads of the message, the
/// register reads, the write of the final hash and the instruction's execution record; the
/// SHA-256 round/digest computation is delegated to `sha256_subair`.
#[derive(Clone, Debug, derive_new::new)]
pub struct Sha256VmAir {
    /// Bridge used to record the instruction's execution and advance the PC.
    pub execution_bridge: ExecutionBridge,
    /// Bridge for the register reads, the message reads and the hash write.
    pub memory_bridge: MemoryBridge,
    /// Bus to which the range checks of the pointers' most significant limbs are sent.
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    /// Maximum number of bits allowed in `dst_ptr` / `src_ptr`.
    ///
    /// `eval_last_row` shifts by `RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - ptr_max_bits`,
    /// so this must not exceed `RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS`.
    pub ptr_max_bits: usize,
    /// Sub-AIR for the SHA-256 rounds; evaluated in `eval` with offset
    /// `SHA256VM_CONTROL_WIDTH`.
    pub(super) sha256_subair: Sha256Air,
    /// Encoder for the per-row `PaddingFlags` value stored in `control.pad_flags`.
    pub(super) padding_encoder: Encoder,
}
46
47impl<F: Field> BaseAirWithPublicValues<F> for Sha256VmAir {}
48impl<F: Field> PartitionedBaseAir<F> for Sha256VmAir {}
49impl<F: Field> BaseAir<F> for Sha256VmAir {
50 fn width(&self) -> usize {
51 SHA256VM_WIDTH
52 }
53}
54
impl<AB: InteractionBuilder> Air<AB> for Sha256VmAir {
    /// Evaluates all constraints of the AIR: padding, row-to-row transitions of the control
    /// columns, message reads, last-row (instruction-level) constraints, and finally the
    /// SHA-256 sub-AIR on the columns starting at offset `SHA256VM_CONTROL_WIDTH`.
    // NOTE(review): several of these calls register bus interactions; presumably their order
    // is reflected in the generated keys, so keep the call order stable.
    fn eval(&self, builder: &mut AB) {
        self.eval_padding(builder);
        self.eval_transitions(builder);
        self.eval_reads(builder);
        self.eval_last_row(builder);

        self.sha256_subair.eval(builder, SHA256VM_CONTROL_WIDTH);
    }
}
65
/// Per-row padding state, stored (encoded) in `control.pad_flags`.
///
/// Variant order is significant: the AIR addresses variants arithmetically
/// (e.g. `FirstPadding0 as usize + i`) and by contiguous ranges, so the variants must stay
/// declared in exactly this order.
// `dead_code`: several variants are never named directly in this file, only reached via
// the arithmetic above. `non_camel_case_types`: the `_LastRow` suffix.
#[allow(dead_code, non_camel_case_types)]
pub(super) enum PaddingFlags {
    /// No byte of the row is constrained by the padding checks. Required exactly on rows
    /// that are not among the first 4 rows of a block.
    NotConsidered,
    /// Every byte of the row equals the message byte read from memory.
    NotPadding,
    /// `FirstPaddingK` (K in 0..16): first row containing padding, with K message bytes.
    /// Bytes `0..K` equal the message, byte `K` is `0x80`, bytes `K+1..16` are zero.
    FirstPadding0,
    FirstPadding1,
    FirstPadding2,
    FirstPadding3,
    FirstPadding4,
    FirstPadding5,
    FirstPadding6,
    FirstPadding7,
    FirstPadding8,
    FirstPadding9,
    FirstPadding10,
    FirstPadding11,
    FirstPadding12,
    FirstPadding13,
    FirstPadding14,
    FirstPadding15,
    /// `FirstPaddingK_LastRow` (K in 0..8): first row containing padding that is also the
    /// row carrying the appended message length. Bytes `0..K` equal the message, byte `K`
    /// is `0x80`, bytes `K+1..12` are zero and bytes `12..16` hold the message bit-length.
    /// Only K < 8 exists because the last 8 bytes of the row are taken by the length.
    FirstPadding0_LastRow,
    FirstPadding1_LastRow,
    FirstPadding2_LastRow,
    FirstPadding3_LastRow,
    FirstPadding4_LastRow,
    FirstPadding5_LastRow,
    FirstPadding6_LastRow,
    FirstPadding7_LastRow,
    /// Row that is entirely padding (not the first padding row) and carries the appended
    /// length: bytes `0..12` are zero, bytes `12..16` hold the message bit-length.
    EntirePaddingLastRow,
    /// Row that is entirely padding and is not the first padding row: all bytes are zero.
    EntirePadding,
}
107
108impl PaddingFlags {
109 pub const COUNT: usize = EntirePadding as usize + 1;
111}
112
113use PaddingFlags::*;
114impl Sha256VmAir {
115 fn eval_padding<AB: InteractionBuilder>(&self, builder: &mut AB) {
117 let main = builder.main();
118 let (local, next) = (main.row_slice(0), main.row_slice(1));
119 let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();
120 let next_cols: &Sha256VmRoundCols<AB::Var> = next[..SHA256VM_ROUND_WIDTH].borrow();
121
122 self.padding_encoder
124 .eval(builder, &local_cols.control.pad_flags);
125
126 builder.assert_one(self.padding_encoder.contains_flag_range::<AB>(
127 &local_cols.control.pad_flags,
128 NotConsidered as usize..=EntirePadding as usize,
129 ));
130
131 Self::eval_padding_transitions(self, builder, local_cols, next_cols);
132 Self::eval_padding_row(self, builder, local_cols);
133 }
134
135 fn eval_padding_transitions<AB: InteractionBuilder>(
136 &self,
137 builder: &mut AB,
138 local: &Sha256VmRoundCols<AB::Var>,
139 next: &Sha256VmRoundCols<AB::Var>,
140 ) {
141 let next_is_last_row = next.inner.flags.is_digest_row * next.inner.flags.is_last_block;
142
143 builder.assert_bool(local.control.padding_occurred);
148 builder
151 .when(next_is_last_row.clone())
152 .assert_one(local.control.padding_occurred);
153
154 builder
156 .when(next_is_last_row.clone())
157 .assert_zero(next.control.padding_occurred);
158
159 builder
162 .when(local.control.padding_occurred - next_is_last_row.clone())
163 .assert_one(next.control.padding_occurred);
164
165 builder
169 .when_transition()
170 .when(not(next.inner.flags.is_first_4_rows) - next_is_last_row)
171 .assert_eq(
172 next.control.padding_occurred,
173 local.control.padding_occurred,
174 );
175
176 let next_is_first_padding_row =
178 next.control.padding_occurred - local.control.padding_occurred;
179 let next_row_idx = self.sha256_subair.row_idx_encoder.flag_with_val::<AB>(
181 &next.inner.flags.row_idx,
182 &(0..4).map(|x| (x, x)).collect::<Vec<_>>(),
183 );
184 let next_padding_offset = self.padding_encoder.flag_with_val::<AB>(
187 &next.control.pad_flags,
188 &(0..16)
189 .map(|i| (FirstPadding0 as usize + i, i))
190 .collect::<Vec<_>>(),
191 ) + self.padding_encoder.flag_with_val::<AB>(
192 &next.control.pad_flags,
193 &(0..8)
194 .map(|i| (FirstPadding0_LastRow as usize + i, i))
195 .collect::<Vec<_>>(),
196 );
197
198 let expected_len = next.inner.flags.local_block_idx
203 * next.control.padding_occurred
204 * AB::Expr::from_canonical_usize(SHA256_BLOCK_U8S)
205 + next_row_idx * AB::Expr::from_canonical_usize(SHA256_READ_SIZE)
206 + next_padding_offset;
207
208 builder.when(next_is_first_padding_row).assert_eq(
212 expected_len,
213 next.control.len * next.control.padding_occurred,
214 );
215
216 let is_next_first_padding = self.padding_encoder.contains_flag_range::<AB>(
218 &next.control.pad_flags,
219 FirstPadding0 as usize..=FirstPadding7_LastRow as usize,
220 );
221
222 let is_next_last_padding = self.padding_encoder.contains_flag_range::<AB>(
223 &next.control.pad_flags,
224 FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize,
225 );
226
227 let is_next_entire_padding = self.padding_encoder.contains_flag_range::<AB>(
228 &next.control.pad_flags,
229 EntirePaddingLastRow as usize..=EntirePadding as usize,
230 );
231
232 let is_next_not_considered = self
233 .padding_encoder
234 .contains_flag::<AB>(&next.control.pad_flags, &[NotConsidered as usize]);
235
236 let is_next_not_padding = self
237 .padding_encoder
238 .contains_flag::<AB>(&next.control.pad_flags, &[NotPadding as usize]);
239
240 let is_next_4th_row = self
241 .sha256_subair
242 .row_idx_encoder
243 .contains_flag::<AB>(&next.inner.flags.row_idx, &[3]);
244
245 builder.assert_eq(
247 not(next.inner.flags.is_first_4_rows),
248 is_next_not_considered,
249 );
250
251 builder.when(next.inner.flags.is_first_4_rows).assert_eq(
253 local.control.padding_occurred * next.control.padding_occurred,
254 is_next_entire_padding,
255 );
256
257 builder.when(next.inner.flags.is_first_4_rows).assert_eq(
260 not(local.control.padding_occurred) * next.control.padding_occurred,
261 is_next_first_padding,
262 );
263
264 builder
266 .when(next.inner.flags.is_first_4_rows)
267 .assert_eq(not(next.control.padding_occurred), is_next_not_padding);
268
269 builder
271 .when(next.inner.flags.is_last_block)
272 .assert_eq(is_next_4th_row, is_next_last_padding);
273 }
274
275 fn eval_padding_row<AB: InteractionBuilder>(
276 &self,
277 builder: &mut AB,
278 local: &Sha256VmRoundCols<AB::Var>,
279 ) {
280 let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| {
281 local.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U8S)]
282 [i % (SHA256_WORD_U8S)]
283 });
284
285 let get_ith_byte = |i: usize| {
286 let word_idx = i / SHA256_ROUNDS_PER_ROW;
287 let word = local.inner.message_schedule.w[word_idx].map(|x| x.into());
288 let byte_idx = 4 - i % 4 - 1;
290 compose::<AB::Expr>(&word[byte_idx * 8..(byte_idx + 1) * 8], 1)
291 };
292
293 let is_not_padding = self
294 .padding_encoder
295 .contains_flag::<AB>(&local.control.pad_flags, &[NotPadding as usize]);
296
297 for (i, message_byte) in message.iter().enumerate() {
299 let w = get_ith_byte(i);
300 let should_be_message = is_not_padding.clone()
301 + if i < 15 {
302 self.padding_encoder.contains_flag_range::<AB>(
303 &local.control.pad_flags,
304 FirstPadding0 as usize + i + 1..=FirstPadding15 as usize,
305 )
306 } else {
307 AB::Expr::ZERO
308 }
309 + if i < 7 {
310 self.padding_encoder.contains_flag_range::<AB>(
311 &local.control.pad_flags,
312 FirstPadding0_LastRow as usize + i + 1..=FirstPadding7_LastRow as usize,
313 )
314 } else {
315 AB::Expr::ZERO
316 };
317 builder
318 .when(should_be_message)
319 .assert_eq(w.clone(), *message_byte);
320
321 let should_be_zero = self
322 .padding_encoder
323 .contains_flag::<AB>(&local.control.pad_flags, &[EntirePadding as usize])
324 + if i < 12 {
325 self.padding_encoder.contains_flag::<AB>(
326 &local.control.pad_flags,
327 &[EntirePaddingLastRow as usize],
328 ) + if i > 0 {
329 self.padding_encoder.contains_flag_range::<AB>(
330 &local.control.pad_flags,
331 FirstPadding0_LastRow as usize
332 ..=min(
333 FirstPadding0_LastRow as usize + i - 1,
334 FirstPadding7_LastRow as usize,
335 ),
336 )
337 } else {
338 AB::Expr::ZERO
339 }
340 } else {
341 AB::Expr::ZERO
342 }
343 + if i > 0 {
344 self.padding_encoder.contains_flag_range::<AB>(
345 &local.control.pad_flags,
346 FirstPadding0 as usize..=FirstPadding0 as usize + i - 1,
347 )
348 } else {
349 AB::Expr::ZERO
350 };
351 builder.when(should_be_zero).assert_zero(w.clone());
352
353 let should_be_128 = self
356 .padding_encoder
357 .contains_flag::<AB>(&local.control.pad_flags, &[FirstPadding0 as usize + i])
358 + if i < 8 {
359 self.padding_encoder.contains_flag::<AB>(
360 &local.control.pad_flags,
361 &[FirstPadding0_LastRow as usize + i],
362 )
363 } else {
364 AB::Expr::ZERO
365 };
366
367 builder
368 .when(should_be_128)
369 .assert_eq(AB::Expr::from_canonical_u32(1 << 7), w);
370
371 }
373 let appended_len = compose::<AB::Expr>(
374 &[
375 get_ith_byte(15),
376 get_ith_byte(14),
377 get_ith_byte(13),
378 get_ith_byte(12),
379 ],
380 RV32_CELL_BITS,
381 );
382
383 let actual_len = local.control.len;
384
385 let is_last_padding_row = self.padding_encoder.contains_flag_range::<AB>(
386 &local.control.pad_flags,
387 FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize,
388 );
389
390 builder.when(is_last_padding_row.clone()).assert_eq(
391 appended_len * AB::F::from_canonical_usize(RV32_CELL_BITS).inverse(), actual_len,
393 );
394
395 builder.when(is_last_padding_row.clone()).assert_zero(
397 local.inner.message_schedule.w[3][0]
398 + local.inner.message_schedule.w[3][1]
399 + local.inner.message_schedule.w[3][2],
400 );
401
402 for i in 8..12 {
406 builder
407 .when(is_last_padding_row.clone())
408 .assert_zero(get_ith_byte(i));
409 }
410 }
411 fn eval_transitions<AB: InteractionBuilder>(&self, builder: &mut AB) {
413 let main = builder.main();
414 let (local, next) = (main.row_slice(0), main.row_slice(1));
415 let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();
416 let next_cols: &Sha256VmRoundCols<AB::Var> = next[..SHA256VM_ROUND_WIDTH].borrow();
417
418 let is_last_row =
419 local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row;
420
421 builder
423 .when_transition()
424 .when(not::<AB::Expr>(is_last_row.clone()))
425 .assert_eq(next_cols.control.len, local_cols.control.len);
426
427 let read_ptr_delta = local_cols.inner.flags.is_first_4_rows
430 * AB::Expr::from_canonical_usize(SHA256_READ_SIZE);
431 builder
432 .when_transition()
433 .when(not::<AB::Expr>(is_last_row.clone()))
434 .assert_eq(
435 next_cols.control.read_ptr,
436 local_cols.control.read_ptr + read_ptr_delta,
437 );
438
439 let timestamp_delta = local_cols.inner.flags.is_first_4_rows * AB::Expr::ONE;
441 builder
442 .when_transition()
443 .when(not::<AB::Expr>(is_last_row.clone()))
444 .assert_eq(
445 next_cols.control.cur_timestamp,
446 local_cols.control.cur_timestamp + timestamp_delta,
447 );
448 }
449
450 fn eval_reads<AB: InteractionBuilder>(&self, builder: &mut AB) {
452 let main = builder.main();
453 let local = main.row_slice(0);
454 let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();
455
456 let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| {
457 local_cols.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U16S * 2)]
458 [i % (SHA256_WORD_U16S * 2)]
459 });
460
461 self.memory_bridge
462 .read(
463 MemoryAddress::new(
464 AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
465 local_cols.control.read_ptr,
466 ),
467 message,
468 local_cols.control.cur_timestamp,
469 &local_cols.read_aux,
470 )
471 .eval(builder, local_cols.inner.flags.is_first_4_rows);
472 }
473 fn eval_last_row<AB: InteractionBuilder>(&self, builder: &mut AB) {
475 let main = builder.main();
476 let local = main.row_slice(0);
477 let local_cols: &Sha256VmDigestCols<AB::Var> = local[..SHA256VM_DIGEST_WIDTH].borrow();
478
479 let timestamp: AB::Var = local_cols.from_state.timestamp;
480 let mut timestamp_delta: usize = 0;
481 let mut timestamp_pp = || {
482 timestamp_delta += 1;
483 timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1)
484 };
485
486 let is_last_row =
487 local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row;
488
489 self.memory_bridge
490 .read(
491 MemoryAddress::new(
492 AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
493 local_cols.rd_ptr,
494 ),
495 local_cols.dst_ptr,
496 timestamp_pp(),
497 &local_cols.register_reads_aux[0],
498 )
499 .eval(builder, is_last_row.clone());
500
501 self.memory_bridge
502 .read(
503 MemoryAddress::new(
504 AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
505 local_cols.rs1_ptr,
506 ),
507 local_cols.src_ptr,
508 timestamp_pp(),
509 &local_cols.register_reads_aux[1],
510 )
511 .eval(builder, is_last_row.clone());
512
513 self.memory_bridge
514 .read(
515 MemoryAddress::new(
516 AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
517 local_cols.rs2_ptr,
518 ),
519 local_cols.len_data,
520 timestamp_pp(),
521 &local_cols.register_reads_aux[2],
522 )
523 .eval(builder, is_last_row.clone());
524
525 let shift = AB::Expr::from_canonical_usize(
529 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.ptr_max_bits),
530 );
531 self.bitwise_lookup_bus
533 .send_range(
534 local_cols.dst_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(),
537 local_cols.src_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(),
538 )
539 .eval(builder, is_last_row.clone());
540
541 let time_delta = (local_cols.inner.flags.local_block_idx + AB::Expr::ONE)
543 * AB::Expr::from_canonical_usize(4);
544 let read_ptr_delta = time_delta.clone() * AB::Expr::from_canonical_usize(SHA256_READ_SIZE);
546
547 let result: [AB::Var; SHA256_WORD_U8S * SHA256_HASH_WORDS] = array::from_fn(|i| {
548 local_cols.inner.final_hash[i / SHA256_WORD_U8S]
550 [SHA256_WORD_U8S - i % SHA256_WORD_U8S - 1]
551 });
552
553 let dst_ptr_val =
554 compose::<AB::Expr>(&local_cols.dst_ptr.map(|x| x.into()), RV32_CELL_BITS);
555
556 self.memory_bridge
560 .write(
561 MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_MEMORY_AS), dst_ptr_val),
562 result,
563 timestamp_pp() + time_delta.clone(),
564 &local_cols.writes_aux,
565 )
566 .eval(builder, is_last_row.clone());
567
568 self.execution_bridge
569 .execute_and_increment_pc(
570 AB::Expr::from_canonical_usize(Rv32Sha256Opcode::SHA256.global_opcode().as_usize()),
571 [
572 local_cols.rd_ptr.into(),
573 local_cols.rs1_ptr.into(),
574 local_cols.rs2_ptr.into(),
575 AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
576 AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
577 ],
578 local_cols.from_state,
579 AB::Expr::from_canonical_usize(timestamp_delta) + time_delta.clone(),
580 )
581 .eval(builder, is_last_row.clone());
582
583 let len_val = compose::<AB::Expr>(&local_cols.len_data.map(|x| x.into()), RV32_CELL_BITS);
585 builder
586 .when(is_last_row.clone())
587 .assert_eq(local_cols.control.len, len_val);
588 let src_val = compose::<AB::Expr>(&local_cols.src_ptr.map(|x| x.into()), RV32_CELL_BITS);
590 builder
591 .when(is_last_row.clone())
592 .assert_eq(local_cols.control.read_ptr, src_val + read_ptr_delta);
593 builder.when(is_last_row.clone()).assert_eq(
595 local_cols.control.cur_timestamp,
596 local_cols.from_state.timestamp + AB::Expr::from_canonical_u32(3) + time_delta,
597 );
598 }
599}