1use std::{array::from_fn, borrow::Borrow, iter::zip};
2
3use itertools::{izip, Itertools};
4use openvm_circuit::{
5 arch::{ExecutionBridge, ExecutionState},
6 system::memory::{
7 offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
8 MemoryAddress,
9 },
10};
11use openvm_circuit_primitives::{
12 bitwise_op_lookup::BitwiseOperationLookupBus,
13 utils::{assert_array_eq, not, select},
14};
15use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS};
16use openvm_keccak256_transpiler::Rv32KeccakOpcode;
17use openvm_rv32im_circuit::adapters::abstract_compose;
18use openvm_stark_backend::{
19 air_builders::sub::SubAirBuilder,
20 interaction::InteractionBuilder,
21 p3_air::{Air, AirBuilder, BaseAir},
22 p3_field::FieldAlgebra,
23 p3_matrix::Matrix,
24 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
25};
26use p3_keccak_air::{KeccakAir, NUM_KECCAK_COLS as NUM_KECCAK_PERM_COLS, U64_LIMBS};
27
28use super::{
29 columns::{KeccakVmCols, NUM_KECCAK_VM_COLS},
30 KECCAK_ABSORB_READS, KECCAK_DIGEST_BYTES, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES,
31 KECCAK_RATE_U16S, KECCAK_REGISTER_READS, KECCAK_WIDTH_U16S, KECCAK_WORD_SIZE,
32 NUM_ABSORB_ROUNDS,
33};
34
35#[derive(Clone, Copy, Debug, derive_new::new)]
36pub struct KeccakVmAir {
37 pub execution_bridge: ExecutionBridge,
38 pub memory_bridge: MemoryBridge,
39 pub bitwise_lookup_bus: BitwiseOperationLookupBus,
41 pub ptr_max_bits: usize,
43 pub(super) offset: usize,
44}
45
46impl<F> BaseAirWithPublicValues<F> for KeccakVmAir {}
47impl<F> PartitionedBaseAir<F> for KeccakVmAir {}
48impl<F> BaseAir<F> for KeccakVmAir {
49 fn width(&self) -> usize {
50 NUM_KECCAK_VM_COLS
51 }
52}
53
/// Constraint system of the keccak VM chip, evaluated on a `local`/`next` row pair.
impl<AB: InteractionBuilder> Air<AB> for KeccakVmAir {
    /// Registers every constraint and interaction of the chip on `builder`.
    ///
    /// NOTE(review): each helper below appends its constraints/interactions to `builder`
    /// in the order it is called; keep this order stable.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let (local, next) = (main.row_slice(0), main.row_slice(1));
        let local: &KeccakVmCols<AB::Var> = (*local).borrow();
        let next: &KeccakVmCols<AB::Var> = (*next).borrow();

        // `is_new_start` is boolean and may only be set on a first-round row:
        // x = x * f  <=>  x * (1 - f) = 0.
        builder.assert_bool(local.sponge.is_new_start);
        builder.assert_eq(
            local.sponge.is_new_start,
            local.sponge.is_new_start * local.is_first_round(),
        );
        // `is_enabled_first_round` is a dedicated column pinned to the product
        // `is_enabled * is_first_round` (presumably materialized to keep the degree of
        // expressions that use it low — confirm).
        builder.assert_eq(
            local.instruction.is_enabled_first_round,
            local.instruction.is_enabled * local.is_first_round(),
        );
        // The trace must begin with the start of a new input.
        builder
            .when_first_row()
            .assert_one(local.sponge.is_new_start);

        self.eval_keccak_f(builder);
        self.constrain_padding(builder, local, next);
        self.constrain_consistency_across_rounds(builder, local, next);

        let mem = &local.mem_oc;
        // Interactions. The timestamp is threaded through the memory accesses:
        // register reads start at `instruction.start_timestamp`, input reads follow them,
        // and digest writes start where the input reads end.
        self.constrain_absorb(builder, local, next);
        let start_read_timestamp = self.eval_instruction(builder, local, &mem.register_aux);
        let start_write_timestamp =
            self.constrain_input_read(builder, local, start_read_timestamp, &mem.absorb_reads);
        self.constrain_output_write(
            builder,
            local,
            start_write_timestamp.clone(),
            &mem.digest_writes,
        );

        // The next block of the same input resumes at the timestamp where this block's
        // reads ended.
        self.constrain_block_transition(builder, local, next, start_write_timestamp);
    }
}
95
impl KeccakVmAir {
    /// Evaluates the Plonky3 keccak-f permutation AIR on the first `NUM_KECCAK_PERM_COLS`
    /// columns of the row.
    ///
    /// NOTE(review): presumably these leading columns are `KeccakVmCols::inner` — confirm
    /// against the column layout.
    #[inline]
    pub fn eval_keccak_f<AB: AirBuilder>(&self, builder: &mut AB) {
        let keccak_f_air = KeccakAir {};
        let mut sub_builder =
            SubAirBuilder::<AB, KeccakAir, AB::Var>::new(builder, 0..NUM_KECCAK_PERM_COLS);
        keccak_f_air.eval(&mut sub_builder);
    }

    /// Keeps the instruction columns fixed within a block.
    ///
    /// Applies on transition rows where `local` is not the last round.
    pub fn constrain_consistency_across_rounds<AB: AirBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        next: &KeccakVmCols<AB::Var>,
    ) {
        let mut transition_builder = builder.when_transition();
        let mut round_builder = transition_builder.when(not(local.is_last_round()));
        // Helper on the instruction column struct; presumably asserts field-by-field
        // equality between `local` and `next`.
        local
            .instruction
            .assert_eq(&mut round_builder, next.instruction);
    }

    /// Links the last round of a non-final block to the first round of the next block
    /// of the same input.
    ///
    /// Active when `local` is the last round and `next` is not a new start. Then:
    /// - `pc`, `is_enabled`, `dst` and the three register pointers are unchanged;
    /// - `src` advances by `KECCAK_RATE_BYTES`;
    /// - `next.start_timestamp` equals `start_write_timestamp`, the timestamp after this
    ///   block's register and input reads;
    /// - `remaining_len` decreases by `KECCAK_RATE_BYTES`.
    ///
    /// NOTE(review): there is no `when_transition()` here. Assuming `next` wraps around to
    /// row 0 on the last row (the usual Plonky3 convention — confirm), `eval` forces
    /// `is_new_start = 1` on row 0, so the selector is zero at the wrap-around.
    pub fn constrain_block_transition<AB: AirBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        next: &KeccakVmCols<AB::Var>,
        start_write_timestamp: AB::Expr,
    ) {
        let mut block_transition = builder.when(local.is_last_round() * not(next.is_new_start()));
        block_transition.assert_eq(local.instruction.pc, next.instruction.pc);
        block_transition.assert_eq(local.instruction.is_enabled, next.instruction.is_enabled);
        // `dst` is read from its register on the first block but only consumed by the
        // digest write on the final block, so it is carried unchanged across blocks.
        assert_array_eq(
            &mut block_transition,
            local.instruction.dst,
            next.instruction.dst,
        );
        // The register pointers are only read on the new-start row; they are nonetheless
        // kept fixed across blocks.
        block_transition.assert_eq(local.instruction.dst_ptr, next.instruction.dst_ptr);
        block_transition.assert_eq(local.instruction.src_ptr, next.instruction.src_ptr);
        block_transition.assert_eq(local.instruction.len_ptr, next.instruction.len_ptr);
        // A non-final block consumes exactly one full rate of input bytes.
        block_transition.assert_eq(
            next.instruction.src,
            local.instruction.src + AB::F::from_canonical_usize(KECCAK_RATE_BYTES),
        );
        block_transition.assert_eq(next.instruction.start_timestamp, start_write_timestamp);
        block_transition.assert_eq(
            next.instruction.remaining_len,
            local.instruction.remaining_len - AB::F::from_canonical_usize(KECCAK_RATE_BYTES),
        );
    }

    /// Constrains the padding flags `is_padding_byte` and the padding bytes of `block_bytes`.
    ///
    /// - Every flag is boolean, and the flags are monotone: once a byte is padding, every
    ///   later byte of the block is padding.
    /// - The flags do not change between the rounds of one block.
    /// - A block is final iff its last byte is padding.
    /// - A final block has `remaining_len = KECCAK_RATE_BYTES - #padding_bytes`.
    /// - A non-final block hands `remaining_len - KECCAK_RATE_BYTES` to the next block.
    /// - On a last round, the block is final exactly when the next row is a new start.
    /// - Padding bytes follow the Keccak byte pattern: a single padding byte is `0x81`;
    ///   otherwise the pattern is `0x01`, then zeros, then `0x80`.
    pub fn constrain_padding<AB: AirBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        next: &KeccakVmCols<AB::Var>,
    ) {
        let is_padding_byte = local.sponge.is_padding_byte;
        let block_bytes = &local.sponge.block_bytes;
        let remaining_len = local.remaining_len();

        // Each flag is boolean.
        for &is_padding_byte in is_padding_byte.iter() {
            builder.assert_bool(is_padding_byte);
        }
        // Monotone: flag[i - 1] = 1 implies flag[i] = 1, so padding is a suffix of the block.
        for i in 1..KECCAK_RATE_BYTES {
            builder
                .when(is_padding_byte[i - 1])
                .assert_one(is_padding_byte[i]);
        }
        // The flags stay fixed across the rounds of a block. "local is the last round" is
        // read off `next` here.
        // NOTE(review): presumably `step_flags[0]` marks round 0, so this is 1 exactly when
        // `local` ends a permutation. Confirm why it is preferred over
        // `local.is_last_round()`, e.g. for a trace that does not end on a last round.
        let is_last_round = next.inner.step_flags[0];
        let is_not_last_round = not(is_last_round);
        for i in 0..KECCAK_RATE_BYTES {
            builder.when(is_not_last_round.clone()).assert_eq(
                local.sponge.is_padding_byte[i],
                next.sponge.is_padding_byte[i],
            );
        }

        let num_padding_bytes = local
            .sponge
            .is_padding_byte
            .iter()
            .fold(AB::Expr::ZERO, |a, &b| a + b);

        // Padding is a suffix, so a block contains padding iff its last byte is padding.
        // Such a block is the final block of its input.
        let is_final_block = is_padding_byte[KECCAK_RATE_BYTES - 1];

        // In the final block, the non-padding bytes are exactly the remaining input.
        builder.when(is_final_block).assert_eq(
            remaining_len,
            AB::Expr::from_canonical_usize(KECCAK_RATE_BYTES) - num_padding_bytes,
        );
        // A non-final block consumes a full rate of input.
        builder
            .when(is_last_round)
            .when(not(is_final_block))
            .assert_eq(
                remaining_len - AB::F::from_canonical_usize(KECCAK_RATE_BYTES),
                next.remaining_len(),
            );
        // A new input may only start right after a final block...
        builder
            .when(is_last_round)
            .when(next.is_new_start())
            .assert_one(is_final_block);
        // ...and a final block must be followed by a new input (no repeated padding blocks).
        builder
            .when(is_last_round)
            .when(is_final_block)
            .assert_one(next.is_new_start());
        // With boolean, monotone flags this is 1 iff only the last byte is padding.
        let has_single_padding_byte: AB::Expr =
            is_padding_byte[KECCAK_RATE_BYTES - 1] - is_padding_byte[KECCAK_RATE_BYTES - 2];

        // A single padding byte carries both end markers: 0x01 | 0x80 = 0x81.
        builder.when(has_single_padding_byte.clone()).assert_eq(
            block_bytes[KECCAK_RATE_BYTES - 1],
            AB::F::from_canonical_u8(0b10000001),
        );

        // Also 1 for blocks with no padding at all. The constraints below are vacuous there
        // because no flag is set.
        let has_multiple_padding_bytes: AB::Expr = not(has_single_padding_byte.clone());
        for i in 0..KECCAK_RATE_BYTES - 1 {
            // 1 exactly at the index where the padding suffix begins.
            let is_first_padding_byte: AB::Expr = {
                if i > 0 {
                    is_padding_byte[i] - is_padding_byte[i - 1]
                } else {
                    is_padding_byte[i].into()
                }
            };
            // The first padding byte is 0x01.
            builder
                .when(has_multiple_padding_bytes.clone())
                .when(is_first_padding_byte.clone())
                .assert_eq(block_bytes[i], AB::F::from_canonical_u8(0x01));
            // Later padding bytes before the last one are zero. With a single padding byte,
            // every flag in this loop's range is 0, so this is vacuous.
            builder
                .when(is_padding_byte[i])
                .when(not::<AB::Expr>(is_first_padding_byte))
                .assert_zero(block_bytes[i]);
        }

        // With multiple padding bytes, the last byte of the final block is 0x80.
        builder
            .when(is_final_block)
            .when(has_multiple_padding_bytes)
            .assert_eq(
                block_bytes[KECCAK_RATE_BYTES - 1],
                AB::F::from_canonical_u8(0x80),
            );
    }

    /// Constrains sponge absorption and the initial sponge state.
    ///
    /// Each state limb is decomposed into two bytes `[lo, hi]` (little-endian), with
    /// `limb = lo + 256 * hi`. Here `hi` is a `sponge.state_hi` column and `lo` is derived.
    ///
    /// XOR lookups: on the last round of every enabled block, one lookup is sent per rate
    /// byte, `next_block_byte ^ postimage_byte = next_preimage_byte`. On a final block the
    /// input is zeroed and the result replaced by the postimage byte, so the lookup
    /// degenerates to `0 ^ b = b`. Either way every postimage rate byte goes through the
    /// lookup.
    /// NOTE(review): this presumably doubles as the 8-bit range check of these bytes (and
    /// so of the digest bytes written by `constrain_output_write`), assuming the bus
    /// receiver only accepts byte operands — confirm.
    ///
    /// Initial state: on an enabled new-start row, the rate part of the preimage equals the
    /// block bytes. On any new-start row, the capacity part is zero.
    ///
    /// Block boundaries: across a boundary within one input, the capacity part of the
    /// postimage carries over unchanged into the next preimage.
    pub fn constrain_absorb<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        next: &KeccakVmCols<AB::Var>,
    ) {
        // Bytes of `local`'s rate lanes after the permutation. `NUM_ABSORB_ROUNDS` is used
        // here as the number of rate lanes; lane `i` lives at `[y][x] = [i / 5][i % 5]`.
        let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = local.postimage(y, x, limb);
                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
                // Little-endian: low byte first.
                [lo, hi.into()]
            })
        });

        // Bytes of `next`'s rate lanes before its permutation.
        let post_absorb_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = next.inner.preimage[y][x][limb];
                let hi = next.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
                [lo, hi.into()]
            })
        });

        // The XOR is sent on the last round of every enabled block, final blocks included.
        let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
        for (input, prev, post) in izip!(
            next.sponge.block_bytes,
            updated_state_bytes,
            post_absorb_state_bytes
        ) {
            // On a final block this sends xor(0, prev, prev): the input is zeroed and the
            // result is selected to be `prev`.
            self.bitwise_lookup_bus
                .send_xor(
                    input * not(is_final_block),
                    prev.clone(),
                    select(is_final_block, prev, post),
                )
                .eval(
                    builder,
                    local.is_last_round() * local.instruction.is_enabled,
                );
        }

        // Bytes of `local`'s rate lanes before its permutation.
        let local_preimage_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = local.inner.preimage[y][x][limb];
                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
                [lo, hi.into()]
            })
        });
        // The first block of an enabled input is absorbed into an all-zero state, so the
        // rate part of the preimage is the block itself.
        let mut when_is_new_start =
            builder.when(local.is_new_start() * local.instruction.is_enabled);
        for (preimage_byte, block_byte) in zip(local_preimage_bytes, local.sponge.block_bytes) {
            when_is_new_start.assert_eq(preimage_byte, block_byte);
        }

        // Capacity limbs (limb indices KECCAK_RATE_U16S..KECCAK_WIDTH_U16S) are zero on
        // every new-start row, enabled or not...
        let mut reset_builder = builder.when(local.is_new_start());
        for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
            let y = i / U64_LIMBS / 5;
            let x = (i / U64_LIMBS) % 5;
            let limb = i % U64_LIMBS;
            reset_builder.assert_zero(local.inner.preimage[y][x][limb]);
        }
        // ...and are carried unchanged from the postimage into the next block's preimage.
        let mut absorb_builder = builder.when(local.is_last_round() * not(next.is_new_start()));
        for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
            let y = i / U64_LIMBS / 5;
            let x = (i / U64_LIMBS) % 5;
            let limb = i % U64_LIMBS;
            absorb_builder.assert_eq(local.postimage(y, x, limb), next.inner.preimage[y][x][limb]);
        }
    }

    /// Receives the `KECCAK256` instruction and reads its `dst`, `src` and `len` registers.
    ///
    /// Every interaction here is gated by `should_receive = is_enabled * is_new_start`, so it
    /// only fires on the first row of an enabled input.
    ///
    /// The timestamp change reported to the execution bridge is
    /// `timestamp_change(remaining_len)`. On that row, `remaining_len` is tied to the value
    /// read from the `len` register.
    ///
    /// The three register reads use timestamps `start_timestamp`, `+1` and `+2`. The most
    /// significant limb of each value is range-checked after scaling (see below).
    ///
    /// Returns `start_timestamp` advanced by one per register read. The advance happens on
    /// every row, whether or not the reads are enabled.
    pub fn eval_instruction<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        register_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_REGISTER_READS],
    ) -> AB::Expr {
        let instruction = local.instruction;
        // Product of two columns (degree 2); the multiplicity of every interaction below.
        let should_receive = local.instruction.is_enabled * local.sponge.is_new_start;

        let [dst_ptr, src_ptr, len_ptr] = [
            instruction.dst_ptr,
            instruction.src_ptr,
            instruction.len_ptr,
        ];
        // Address space used for the register reads below.
        let reg_addr_sp = AB::F::ONE;
        let timestamp_change: AB::Expr = Self::timestamp_change(instruction.remaining_len);
        self.execution_bridge
            .execute_and_increment_pc(
                AB::Expr::from_canonical_usize(Rv32KeccakOpcode::KECCAK256 as usize + self.offset),
                [
                    dst_ptr.into(),
                    src_ptr.into(),
                    len_ptr.into(),
                    reg_addr_sp.into(),
                    AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
                ],
                ExecutionState::new(instruction.pc, instruction.start_timestamp),
                timestamp_change,
            )
            .eval(builder, should_receive.clone());

        let mut timestamp: AB::Expr = instruction.start_timestamp.into();
        // Rebuilds the little-endian limbs of a register value stored as its full value
        // `val` plus its upper limbs:
        // limb 0 = val - sum_j limbs[j] * 2^(RV32_CELL_BITS * (j + 1)).
        let recover_limbs = |limbs: [AB::Var; RV32_REGISTER_NUM_LIMBS - 1],
                             val: AB::Var|
         -> [AB::Expr; RV32_REGISTER_NUM_LIMBS] {
            from_fn(|i| {
                if i == 0 {
                    limbs
                        .into_iter()
                        .enumerate()
                        .fold(val.into(), |acc, (j, limb)| {
                            acc - limb
                                * AB::Expr::from_canonical_usize(1 << ((j + 1) * RV32_CELL_BITS))
                        })
                } else {
                    limbs[i - 1].into()
                }
            })
        };
        // `dst` is stored as all of its limbs. `src` and `len` are stored as a composed value
        // plus their upper limbs; for `len` the composed value is `remaining_len`.
        let dst_data = instruction.dst.map(Into::into);
        let src_data = recover_limbs(instruction.src_limbs, instruction.src);
        let len_data = recover_limbs(instruction.len_limbs, instruction.remaining_len);
        for (ptr, value, aux) in izip!(
            [dst_ptr, src_ptr, len_ptr],
            [dst_data, src_data, len_data],
            register_aux,
        ) {
            self.memory_bridge
                .read(
                    MemoryAddress::new(reg_addr_sp, ptr),
                    value,
                    timestamp.clone(),
                    aux,
                )
                .eval(builder, should_receive.clone());

            timestamp += AB::Expr::ONE;
        }
        // Top-limb range checks for dst, src and len.
        //
        // Each top limb is scaled by 2^(RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - ptr_max_bits)
        // before being sent. Assuming `send_range` checks each operand is a
        // RV32_CELL_BITS-bit value (confirm), this bounds the register value by
        // 2^ptr_max_bits. It also assumes ptr_max_bits <= RV32_CELL_BITS *
        // RV32_REGISTER_NUM_LIMBS; otherwise the usize subtraction below underflows.
        //
        // The `len` limb appears twice only so that the array has even length;
        // `chunks_exact(2)` would otherwise drop it.
        let need_range_check = [
            *instruction.dst.last().unwrap(),
            *instruction.src_limbs.last().unwrap(),
            *instruction.len_limbs.last().unwrap(),
            *instruction.len_limbs.last().unwrap(),
        ];
        let limb_shift = AB::F::from_canonical_usize(
            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.ptr_max_bits),
        );
        for pair in need_range_check.chunks_exact(2) {
            self.bitwise_lookup_bus
                .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
                .eval(builder, should_receive.clone());
        }

        timestamp
    }

    /// Reads the current block's input from memory at `src`, in `RV32_MEMORY_AS`, one
    /// `KECCAK_WORD_SIZE`-byte word per read.
    ///
    /// Which words are read:
    /// - reads happen only on `is_enabled_first_round` rows;
    /// - only words whose first byte is not padding are read.
    ///
    /// Partial words: for a word whose last byte is padding, bytes `1..` of the word read
    /// from memory come from `mem_oc.partial_block` rather than `block_bytes`. Only the
    /// non-padding bytes of every word are constrained to equal `block_bytes`.
    ///
    /// Returns `start_read_timestamp` advanced by one per word, whether or not the word was
    /// actually read.
    pub fn constrain_input_read<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        start_read_timestamp: AB::Expr,
        mem_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_ABSORB_READS],
    ) -> AB::Expr {
        let partial_block = &local.mem_oc.partial_block;
        // Column constrained in `eval` to equal `is_enabled * is_first_round`.
        let is_input = local.instruction.is_enabled_first_round;

        let mut timestamp = start_read_timestamp;
        // Here `i` is the word index within the block.
        for (i, (input, is_padding, mem_aux)) in izip!(
            local.sponge.block_bytes.chunks_exact(KECCAK_WORD_SIZE),
            local.sponge.is_padding_byte.chunks_exact(KECCAK_WORD_SIZE),
            mem_aux
        )
        .enumerate()
        {
            let ptr = local.instruction.src + AB::F::from_canonical_usize(i * KECCAK_WORD_SIZE);
            // Padding is a suffix of the block, so a word whose first byte is padding is all
            // padding and is skipped. Degree 2.
            let count = is_input * not(is_padding[0]);
            // For a read word (first byte not padding), a padding last byte marks a partial read.
            let is_partial_read = *is_padding.last().unwrap();
            // NOTE: the closure parameter `i` (byte index) shadows the word index `i`.
            let word: [_; KECCAK_WORD_SIZE] = from_fn(|i| {
                if i == 0 {
                    // Byte 0 of a word that is read is never padding.
                    input[0].into()
                } else {
                    // Partial read: take the real memory contents from `partial_block`.
                    select(is_partial_read, partial_block[i - 1], input[i])
                }
            });
            // Non-padding bytes of the word must equal the block input (`i` is a byte index).
            // Degree 3, since `word[i]` is degree 2.
            for i in 1..KECCAK_WORD_SIZE {
                let not_padding: AB::Expr = not(is_padding[i]);
                builder.assert_eq(
                    not_padding.clone() * word[i].clone(),
                    not_padding.clone() * input[i],
                );
            }

            self.memory_bridge
                .read(
                    MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_MEMORY_AS), ptr),
                    word,
                    timestamp.clone(),
                    mem_aux,
                )
                .eval(builder, count);

            timestamp += AB::Expr::ONE;
        }
        timestamp
    }

    /// Writes the digest on the last round of the final block of an enabled input.
    ///
    /// The keccak-f AIR's `export` column is pinned to
    /// `is_enabled * is_final_block * is_last_round` and used as the write multiplicity.
    ///
    /// The digest is the first `KECCAK_DIGEST_BYTES` bytes of the post-permutation state,
    /// using the same `lo`/`hi` decomposition that `constrain_absorb` sends through the XOR
    /// lookup. Word `i` is written to `dst + i * KECCAK_WORD_SIZE` at timestamp
    /// `start_write_timestamp + i`.
    pub fn constrain_output_write<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        start_write_timestamp: AB::Expr,
        mem_aux: &[MemoryWriteAuxCols<AB::Var, KECCAK_WORD_SIZE>; KECCAK_DIGEST_WRITES],
    ) {
        let instruction = local.instruction;

        let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
        // Reuse the keccak-f AIR's `export` column as the "write digest" flag.
        builder.assert_eq(
            local.inner.export,
            instruction.is_enabled * is_final_block * local.is_last_round(),
        );
        // Same byte decomposition as `updated_state_bytes` in `constrain_absorb`.
        let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = local.postimage(y, x, limb);
                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
                // Little-endian: low byte first.
                [lo, hi.into()]
            })
        });
        let dst = abstract_compose::<AB::Expr, _>(instruction.dst);
        for (i, digest_bytes) in updated_state_bytes
            .take(KECCAK_DIGEST_BYTES)
            .chunks(KECCAK_WORD_SIZE)
            .into_iter()
            .enumerate()
        {
            // `try_into().unwrap()` expects every chunk to be a full word. This assumes
            // KECCAK_DIGEST_BYTES is a multiple of KECCAK_WORD_SIZE.
            let digest_bytes = digest_bytes.collect_vec();
            let timestamp = start_write_timestamp.clone() + AB::Expr::from_canonical_usize(i);
            self.memory_bridge
                .write(
                    MemoryAddress::new(
                        AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
                        dst.clone() + AB::F::from_canonical_usize(i * KECCAK_WORD_SIZE),
                    ),
                    digest_bytes.try_into().unwrap(),
                    timestamp,
                    &mem_aux[i],
                )
                .eval(builder, local.inner.export)
        }
    }

    /// Timestamp increase reported to the execution bridge for one `KECCAK256` instruction
    /// with length operand `len`:
    /// `len + KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES`.
    ///
    /// NOTE(review): per the constraints in this impl, an input of `len` bytes spans
    /// `len / KECCAK_RATE_BYTES + 1` blocks. Each block advances the timestamp by
    /// `KECCAK_REGISTER_READS + KECCAK_ABSORB_READS`, and the digest adds
    /// `KECCAK_DIGEST_WRITES`. This formula appears intended as an upper bound on that total.
    /// The bound holds while `KECCAK_REGISTER_READS + KECCAK_ABSORB_READS <=
    /// KECCAK_RATE_BYTES` — confirm.
    pub fn timestamp_change<T: FieldAlgebra>(len: impl Into<T>) -> T {
        len.into()
            + T::from_canonical_usize(
                KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES,
            )
    }
}