1use std::{array::from_fn, borrow::Borrow, iter::zip};
2
3use itertools::{izip, Itertools};
4use openvm_circuit::{
5 arch::{ExecutionBridge, ExecutionState},
6 system::memory::{
7 offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
8 MemoryAddress,
9 },
10};
11use openvm_circuit_primitives::{
12 bitwise_op_lookup::BitwiseOperationLookupBus,
13 utils::{assert_array_eq, not, select},
14};
15use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS};
16use openvm_keccak256_transpiler::Rv32KeccakOpcode;
17use openvm_rv32im_circuit::adapters::abstract_compose;
18use openvm_stark_backend::{
19 air_builders::sub::SubAirBuilder,
20 interaction::InteractionBuilder,
21 p3_air::{Air, AirBuilder, BaseAir},
22 p3_field::PrimeCharacteristicRing,
23 p3_matrix::Matrix,
24 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
25};
26use p3_keccak_air::{KeccakAir, NUM_KECCAK_COLS as NUM_KECCAK_PERM_COLS, U64_LIMBS};
27
28use super::{
29 columns::{KeccakVmCols, NUM_KECCAK_VM_COLS},
30 KECCAK_ABSORB_READS, KECCAK_DIGEST_BYTES, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES,
31 KECCAK_RATE_U16S, KECCAK_REGISTER_READS, KECCAK_WIDTH_U16S, KECCAK_WORD_SIZE,
32 NUM_ABSORB_ROUNDS,
33};
34
/// AIR for the RV32 `KECCAK256` instruction.
///
/// It evaluates the keccak-f permutation constraints row by row. It also connects the sponge
/// to the VM through the execution bus, the memory bus and the bitwise-operation lookup bus.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct KeccakVmAir {
    /// Used to receive the `KECCAK256` instruction and advance `pc` / timestamp.
    pub execution_bridge: ExecutionBridge,
    /// Used for the register reads, the input-block reads and the digest writes.
    pub memory_bridge: MemoryBridge,
    /// Bus that receives XOR requests (on state/input bytes) and pointer range-check requests.
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    /// Maximum number of bits allowed for `dst`, `src` and `len`.
    ///
    /// The bound is enforced through range checks on their most-significant limbs. It must
    /// not exceed the register width in bits (`RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS`).
    pub ptr_max_bits: usize,
    /// Offset added to `Rv32KeccakOpcode::KECCAK256` when receiving the instruction.
    pub(super) offset: usize,
}
45
46impl<F> BaseAirWithPublicValues<F> for KeccakVmAir {}
47impl<F> PartitionedBaseAir<F> for KeccakVmAir {}
48impl<F> BaseAir<F> for KeccakVmAir {
49 fn width(&self) -> usize {
50 NUM_KECCAK_VM_COLS
51 }
52}
53
/// Top-level constraint evaluation for the Keccak VM chip.
///
/// Rows are grouped into keccak-f permutations (see `is_first_round` / `is_last_round`). Each
/// row also carries the sponge, instruction and memory columns of `KeccakVmCols`. The order in
/// which constraints and interactions are emitted below defines the constraint system, so do
/// not rearrange it casually.
impl<AB: InteractionBuilder> Air<AB> for KeccakVmAir {
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        // Two-row window: the current row and the following row.
        let (local, next) = (
            main.row_slice(0).expect("window should have two elements"),
            main.row_slice(1).expect("window should have two elements"),
        );
        let local: &KeccakVmCols<AB::Var> = (*local).borrow();
        let next: &KeccakVmCols<AB::Var> = (*next).borrow();

        // `is_new_start` is boolean and may only be set on the first round of a permutation.
        builder.assert_bool(local.sponge.is_new_start);
        builder.assert_eq(
            local.sponge.is_new_start,
            local.sponge.is_new_start * local.is_first_round(),
        );
        // `is_enabled_first_round` must equal the product `is_enabled * is_first_round()`.
        builder.assert_eq(
            local.instruction.is_enabled_first_round,
            local.instruction.is_enabled * local.is_first_round(),
        );
        // The trace must begin with a fresh sponge.
        builder
            .when_first_row()
            .assert_one(local.sponge.is_new_start);

        self.eval_keccak_f(builder);
        self.constrain_padding(builder, local, next);
        self.constrain_consistency_across_rounds(builder, local, next);
        // The final row of the trace may not carry an enabled instruction.
        builder
            .when_last_row()
            .assert_zero(local.instruction.is_enabled);
        let mem = &local.mem_oc;
        // Interactions. Timestamps are threaded through the steps in this order:
        // register reads -> input-block reads -> digest writes.
        self.constrain_absorb(builder, local, next);
        let start_read_timestamp = self.eval_instruction(builder, local, &mem.register_aux);
        let start_write_timestamp =
            self.constrain_input_read(builder, local, start_read_timestamp, &mem.absorb_reads);
        self.constrain_output_write(
            builder,
            local,
            start_write_timestamp.clone(),
            &mem.digest_writes,
        );

        // The next block of the same instruction starts at the timestamp reached after this
        // block's input reads.
        self.constrain_block_transition(builder, local, next, start_write_timestamp);
    }
}
102
103impl KeccakVmAir {
104 #[inline]
108 pub fn eval_keccak_f<AB: AirBuilder>(&self, builder: &mut AB) {
109 let keccak_f_air = KeccakAir {};
110 let mut sub_builder =
111 SubAirBuilder::<AB, KeccakAir, AB::Var>::new(builder, 0..NUM_KECCAK_PERM_COLS);
112 keccak_f_air.eval(&mut sub_builder);
113 }
114
115 pub fn constrain_consistency_across_rounds<AB: AirBuilder<Var: Copy>>(
117 &self,
118 builder: &mut AB,
119 local: &KeccakVmCols<AB::Var>,
120 next: &KeccakVmCols<AB::Var>,
121 ) {
122 let mut transition_builder = builder.when_transition();
123 let mut round_builder = transition_builder.when(not(local.is_last_round()));
124 local
126 .instruction
127 .assert_eq(&mut round_builder, next.instruction);
128 }
129
130 pub fn constrain_block_transition<AB: AirBuilder<Var: Copy>>(
131 &self,
132 builder: &mut AB,
133 local: &KeccakVmCols<AB::Var>,
134 next: &KeccakVmCols<AB::Var>,
135 start_write_timestamp: AB::Expr,
136 ) {
137 let mut block_transition = builder.when(local.is_last_round() * not(next.is_new_start()));
142 block_transition.assert_eq(local.instruction.pc, next.instruction.pc);
143 block_transition.assert_eq(local.instruction.is_enabled, next.instruction.is_enabled);
144 assert_array_eq(
146 &mut block_transition,
147 local.instruction.dst,
148 next.instruction.dst,
149 );
150 block_transition.assert_eq(local.instruction.dst_ptr, next.instruction.dst_ptr);
153 block_transition.assert_eq(local.instruction.src_ptr, next.instruction.src_ptr);
154 block_transition.assert_eq(local.instruction.len_ptr, next.instruction.len_ptr);
155 block_transition.assert_eq(
160 next.instruction.src,
161 local.instruction.src + AB::F::from_usize(KECCAK_RATE_BYTES),
162 );
163 block_transition.assert_eq(next.instruction.start_timestamp, start_write_timestamp);
166 block_transition.assert_eq(
167 next.instruction.remaining_len,
168 local.instruction.remaining_len - AB::F::from_usize(KECCAK_RATE_BYTES),
169 );
170 }
172
173 pub fn constrain_padding<AB: AirBuilder>(
180 &self,
181 builder: &mut AB,
182 local: &KeccakVmCols<AB::Var>,
183 next: &KeccakVmCols<AB::Var>,
184 ) where
185 AB::Var: Copy,
186 {
187 let is_padding_byte = local.sponge.is_padding_byte;
188 let block_bytes = &local.sponge.block_bytes;
189 let remaining_len = local.remaining_len();
190
191 for &is_padding_byte in is_padding_byte.iter() {
193 builder.assert_bool(is_padding_byte);
194 }
195 for i in 1..KECCAK_RATE_BYTES {
197 builder
198 .when(is_padding_byte[i - 1])
199 .assert_one(is_padding_byte[i]);
200 }
201 let is_last_round = next.inner.step_flags[0];
205 let is_not_last_round = not(is_last_round);
206 for i in 0..KECCAK_RATE_BYTES {
207 builder.when(is_not_last_round.clone()).assert_eq(
208 local.sponge.is_padding_byte[i],
209 next.sponge.is_padding_byte[i],
210 );
211 }
212
213 let num_padding_bytes = local
214 .sponge
215 .is_padding_byte
216 .iter()
217 .fold(AB::Expr::ZERO, |a, &b| a + b);
218
219 let is_final_block = is_padding_byte[KECCAK_RATE_BYTES - 1];
221
222 builder.when(is_final_block).assert_eq(
224 remaining_len,
225 AB::Expr::from_usize(KECCAK_RATE_BYTES) - num_padding_bytes,
226 );
227 builder
230 .when(is_last_round)
231 .when(not(is_final_block))
232 .assert_eq(
233 remaining_len - AB::F::from_usize(KECCAK_RATE_BYTES),
234 next.remaining_len(),
235 );
236 builder
239 .when(is_last_round)
240 .when(next.is_new_start())
241 .assert_one(is_final_block);
242 builder
244 .when(is_last_round)
245 .when(is_final_block)
246 .assert_one(next.is_new_start());
247 let has_single_padding_byte: AB::Expr =
255 is_padding_byte[KECCAK_RATE_BYTES - 1] - is_padding_byte[KECCAK_RATE_BYTES - 2];
256
257 builder.when(has_single_padding_byte.clone()).assert_eq(
260 block_bytes[KECCAK_RATE_BYTES - 1],
261 AB::F::from_u8(0b10000001),
262 );
263
264 let has_multiple_padding_bytes: AB::Expr = not(has_single_padding_byte.clone());
265 for i in 0..KECCAK_RATE_BYTES - 1 {
266 let is_first_padding_byte: AB::Expr = {
267 if i > 0 {
268 is_padding_byte[i] - is_padding_byte[i - 1]
269 } else {
270 is_padding_byte[i].into()
271 }
272 };
273 builder
276 .when(has_multiple_padding_bytes.clone())
277 .when(is_first_padding_byte.clone())
278 .assert_eq(block_bytes[i], AB::F::from_u8(0x01));
279 builder
282 .when(is_padding_byte[i])
283 .when(not::<AB::Expr>(is_first_padding_byte)) .assert_zero(block_bytes[i]);
285 }
286
287 builder
290 .when(is_final_block)
291 .when(has_multiple_padding_bytes)
292 .assert_eq(block_bytes[KECCAK_RATE_BYTES - 1], AB::F::from_u8(0x80));
293 }
294
295 pub fn constrain_absorb<AB: InteractionBuilder>(
315 &self,
316 builder: &mut AB,
317 local: &KeccakVmCols<AB::Var>,
318 next: &KeccakVmCols<AB::Var>,
319 ) {
320 let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
321 let y = i / 5;
322 let x = i % 5;
323 (0..U64_LIMBS).flat_map(move |limb| {
324 let state_limb = local.postimage(y, x, limb);
325 let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
326 let lo = state_limb - hi * AB::F::from_u64(1 << 8);
327 [lo, hi.into()]
329 })
330 });
331
332 let post_absorb_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
333 let y = i / 5;
334 let x = i % 5;
335 (0..U64_LIMBS).flat_map(move |limb| {
336 let state_limb = next.inner.preimage[y][x][limb];
337 let hi = next.sponge.state_hi[i * U64_LIMBS + limb];
338 let lo = state_limb - hi * AB::F::from_u64(1 << 8);
339 [lo, hi.into()]
340 })
341 });
342
343 let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
346 for (input, prev, post) in izip!(
347 next.sponge.block_bytes,
348 updated_state_bytes,
349 post_absorb_state_bytes
350 ) {
351 self.bitwise_lookup_bus
359 .send_xor(
360 input * not(is_final_block),
361 prev.clone(),
362 select(is_final_block, prev, post),
363 )
364 .eval(
365 builder,
366 local.is_last_round() * local.instruction.is_enabled,
367 );
368 }
369
370 let local_preimage_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
373 let y = i / 5;
374 let x = i % 5;
375 (0..U64_LIMBS).flat_map(move |limb| {
376 let state_limb = local.inner.preimage[y][x][limb];
377 let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
378 let lo = state_limb - hi * AB::F::from_u64(1 << 8);
379 [lo, hi.into()]
380 })
381 });
382 let mut when_is_new_start =
383 builder.when(local.is_new_start() * local.instruction.is_enabled);
384 for (preimage_byte, block_byte) in zip(local_preimage_bytes, local.sponge.block_bytes) {
385 when_is_new_start.assert_eq(preimage_byte, block_byte);
386 }
387
388 let mut reset_builder = builder.when(local.is_new_start());
390 for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
391 let y = i / U64_LIMBS / 5;
392 let x = (i / U64_LIMBS) % 5;
393 let limb = i % U64_LIMBS;
394 reset_builder.assert_zero(local.inner.preimage[y][x][limb]);
395 }
396 let mut absorb_builder = builder.when(local.is_last_round() * not(next.is_new_start()));
397 for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
398 let y = i / U64_LIMBS / 5;
399 let x = (i / U64_LIMBS) % 5;
400 let limb = i % U64_LIMBS;
401 absorb_builder.assert_eq(local.postimage(y, x, limb), next.inner.preimage[y][x][limb]);
402 }
403 }
404
405 pub fn eval_instruction<AB: InteractionBuilder>(
414 &self,
415 builder: &mut AB,
416 local: &KeccakVmCols<AB::Var>,
417 register_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_REGISTER_READS],
418 ) -> AB::Expr {
419 let instruction = local.instruction;
420 builder.assert_bool(instruction.is_enabled);
421 let should_receive = local.instruction.is_enabled * local.sponge.is_new_start;
428
429 let [dst_ptr, src_ptr, len_ptr] = [
430 instruction.dst_ptr,
431 instruction.src_ptr,
432 instruction.len_ptr,
433 ];
434 let reg_addr_sp = AB::F::ONE;
435 let timestamp_change: AB::Expr = Self::timestamp_change(instruction.remaining_len);
436 self.execution_bridge
437 .execute_and_increment_pc(
438 AB::Expr::from_usize(Rv32KeccakOpcode::KECCAK256 as usize + self.offset),
439 [
440 dst_ptr.into(),
441 src_ptr.into(),
442 len_ptr.into(),
443 reg_addr_sp.into(),
444 AB::Expr::from_u32(RV32_MEMORY_AS),
445 ],
446 ExecutionState::new(instruction.pc, instruction.start_timestamp),
447 timestamp_change,
448 )
449 .eval(builder, should_receive.clone());
450
451 let mut timestamp: AB::Expr = instruction.start_timestamp.into();
452 let recover_limbs = |limbs: [AB::Var; RV32_REGISTER_NUM_LIMBS - 1],
453 val: AB::Var|
454 -> [AB::Expr; RV32_REGISTER_NUM_LIMBS] {
455 from_fn(|i| {
456 if i == 0 {
457 limbs
458 .into_iter()
459 .enumerate()
460 .fold(val.into(), |acc, (j, limb)| {
461 acc - limb * AB::Expr::from_usize(1 << ((j + 1) * RV32_CELL_BITS))
462 })
463 } else {
464 limbs[i - 1].into()
465 }
466 })
467 };
468 let dst_data = instruction.dst.map(Into::into);
471 let src_data = recover_limbs(instruction.src_limbs, instruction.src);
472 let len_data = recover_limbs(instruction.len_limbs, instruction.remaining_len);
473 for (ptr, value, aux) in izip!(
474 [dst_ptr, src_ptr, len_ptr],
475 [dst_data, src_data, len_data],
476 register_aux,
477 ) {
478 self.memory_bridge
479 .read(
480 MemoryAddress::new(reg_addr_sp, ptr),
481 value,
482 timestamp.clone(),
483 aux,
484 )
485 .eval(builder, should_receive.clone());
486
487 timestamp += AB::Expr::ONE;
488 }
489 let need_range_check = [
494 *instruction.dst.last().unwrap(),
495 *instruction.src_limbs.last().unwrap(),
496 *instruction.len_limbs.last().unwrap(),
497 *instruction.len_limbs.last().unwrap(),
498 ];
499 let limb_shift =
500 AB::F::from_usize(1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.ptr_max_bits));
501 for pair in need_range_check.chunks_exact(2) {
502 self.bitwise_lookup_bus
503 .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
504 .eval(builder, should_receive.clone());
505 }
506
507 timestamp
508 }
509
510 pub fn constrain_input_read<AB: InteractionBuilder>(
518 &self,
519 builder: &mut AB,
520 local: &KeccakVmCols<AB::Var>,
521 start_read_timestamp: AB::Expr,
522 mem_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_ABSORB_READS],
523 ) -> AB::Expr {
524 let partial_block = &local.mem_oc.partial_block;
525 let is_input = local.instruction.is_enabled_first_round;
528
529 let mut timestamp = start_read_timestamp;
530 for (i, (input, is_padding, mem_aux)) in izip!(
533 local.sponge.block_bytes.chunks_exact(KECCAK_WORD_SIZE),
534 local.sponge.is_padding_byte.chunks_exact(KECCAK_WORD_SIZE),
535 mem_aux
536 )
537 .enumerate()
538 {
539 let ptr = local.instruction.src + AB::F::from_usize(i * KECCAK_WORD_SIZE);
540 let count = is_input * not(is_padding[0]);
543 let is_partial_read = *is_padding.last().unwrap();
547 let word: [_; KECCAK_WORD_SIZE] = from_fn(|i| {
549 if i == 0 {
550 input[0].into()
552 } else {
553 select(is_partial_read, partial_block[i - 1], input[i])
556 }
557 });
558 for i in 1..KECCAK_WORD_SIZE {
559 let not_padding: AB::Expr = not(is_padding[i]);
560 builder.assert_eq(
563 not_padding.clone() * word[i].clone(),
564 not_padding.clone() * input[i],
565 );
566 }
567
568 self.memory_bridge
569 .read(
570 MemoryAddress::new(AB::Expr::from_u32(RV32_MEMORY_AS), ptr),
571 word, timestamp.clone(),
573 mem_aux,
574 )
575 .eval(builder, count);
576
577 timestamp += AB::Expr::ONE;
578 }
579 timestamp
580 }
581
582 pub fn constrain_output_write<AB: InteractionBuilder>(
583 &self,
584 builder: &mut AB,
585 local: &KeccakVmCols<AB::Var>,
586 start_write_timestamp: AB::Expr,
587 mem_aux: &[MemoryWriteAuxCols<AB::Var, KECCAK_WORD_SIZE>; KECCAK_DIGEST_WRITES],
588 ) {
589 let instruction = local.instruction;
590
591 let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
592 builder.assert_eq(
594 local.inner.export,
595 instruction.is_enabled * is_final_block * local.is_last_round(),
596 );
597 let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
601 let y = i / 5;
602 let x = i % 5;
603 (0..U64_LIMBS).flat_map(move |limb| {
604 let state_limb = local.postimage(y, x, limb);
605 let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
606 let lo = state_limb - hi * AB::F::from_u64(1 << 8);
607 [lo, hi.into()]
609 })
610 });
611 let dst = abstract_compose::<AB::Expr, _>(instruction.dst);
612 for (i, digest_bytes) in updated_state_bytes
613 .take(KECCAK_DIGEST_BYTES)
614 .chunks(KECCAK_WORD_SIZE)
615 .into_iter()
616 .enumerate()
617 {
618 let digest_bytes = digest_bytes.collect_vec();
619 let timestamp = start_write_timestamp.clone() + AB::Expr::from_usize(i);
620 self.memory_bridge
621 .write(
622 MemoryAddress::new(
623 AB::Expr::from_u32(RV32_MEMORY_AS),
624 dst.clone() + AB::F::from_usize(i * KECCAK_WORD_SIZE),
625 ),
626 digest_bytes.try_into().unwrap(),
627 timestamp,
628 &mem_aux[i],
629 )
630 .eval(builder, local.inner.export)
631 }
632 }
633
634 pub fn timestamp_change<T: PrimeCharacteristicRing>(len: impl Into<T>) -> T {
637 len.into()
641 + T::from_usize(KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES)
642 }
643}