1use std::{array::from_fn, borrow::Borrow, iter::zip};
2
3use itertools::{izip, Itertools};
4use openvm_circuit::{
5 arch::{ExecutionBridge, ExecutionState},
6 system::memory::{
7 offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
8 MemoryAddress,
9 },
10};
11use openvm_circuit_primitives::{
12 bitwise_op_lookup::BitwiseOperationLookupBus,
13 utils::{assert_array_eq, not, select},
14};
15use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS};
16use openvm_keccak256_transpiler::Rv32KeccakOpcode;
17use openvm_rv32im_circuit::adapters::abstract_compose;
18use openvm_stark_backend::{
19 air_builders::sub::SubAirBuilder,
20 interaction::InteractionBuilder,
21 p3_air::{Air, AirBuilder, BaseAir},
22 p3_field::PrimeCharacteristicRing,
23 p3_matrix::Matrix,
24 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
25};
26use p3_keccak_air::{KeccakAir, NUM_KECCAK_COLS as NUM_KECCAK_PERM_COLS, U64_LIMBS};
27
28use super::{
29 columns::{KeccakVmCols, NUM_KECCAK_VM_COLS},
30 KECCAK_ABSORB_READS, KECCAK_DIGEST_BYTES, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES,
31 KECCAK_RATE_U16S, KECCAK_REGISTER_READS, KECCAK_WIDTH_U16S, KECCAK_WORD_SIZE,
32 NUM_ABSORB_ROUNDS,
33};
34
35#[derive(Clone, Copy, Debug, derive_new::new)]
36pub struct KeccakVmAir {
37 pub execution_bridge: ExecutionBridge,
38 pub memory_bridge: MemoryBridge,
39 pub bitwise_lookup_bus: BitwiseOperationLookupBus,
41 pub ptr_max_bits: usize,
43 pub(super) offset: usize,
44}
45
46impl<F> BaseAirWithPublicValues<F> for KeccakVmAir {}
47impl<F> PartitionedBaseAir<F> for KeccakVmAir {}
48impl<F> BaseAir<F> for KeccakVmAir {
49 fn width(&self) -> usize {
50 NUM_KECCAK_VM_COLS
51 }
52}
53
54impl<AB: InteractionBuilder> Air<AB> for KeccakVmAir {
55 fn eval(&self, builder: &mut AB) {
56 let main = builder.main();
57 let (local, next) = (
58 main.row_slice(0).expect("window should have two elements"),
59 main.row_slice(1).expect("window should have two elements"),
60 );
61 let local: &KeccakVmCols<AB::Var> = (*local).borrow();
62 let next: &KeccakVmCols<AB::Var> = (*next).borrow();
63
64 builder.assert_bool(local.sponge.is_new_start);
65 builder.assert_eq(
66 local.sponge.is_new_start,
67 local.sponge.is_new_start * local.is_first_round(),
68 );
69 builder.assert_eq(
70 local.instruction.is_enabled_first_round,
71 local.instruction.is_enabled * local.is_first_round(),
72 );
73 builder
75 .when_first_row()
76 .assert_one(local.sponge.is_new_start);
77
78 self.eval_keccak_f(builder);
79 self.constrain_padding(builder, local, next);
80 self.constrain_consistency_across_rounds(builder, local, next);
81
82 let mem = &local.mem_oc;
83 self.constrain_absorb(builder, local, next);
85 let start_read_timestamp = self.eval_instruction(builder, local, &mem.register_aux);
86 let start_write_timestamp =
87 self.constrain_input_read(builder, local, start_read_timestamp, &mem.absorb_reads);
88 self.constrain_output_write(
89 builder,
90 local,
91 start_write_timestamp.clone(),
92 &mem.digest_writes,
93 );
94
95 self.constrain_block_transition(builder, local, next, start_write_timestamp);
96 }
97}
98
99impl KeccakVmAir {
100 #[inline]
104 pub fn eval_keccak_f<AB: AirBuilder>(&self, builder: &mut AB) {
105 let keccak_f_air = KeccakAir {};
106 let mut sub_builder =
107 SubAirBuilder::<AB, KeccakAir, AB::Var>::new(builder, 0..NUM_KECCAK_PERM_COLS);
108 keccak_f_air.eval(&mut sub_builder);
109 }
110
111 pub fn constrain_consistency_across_rounds<AB: AirBuilder<Var: Copy>>(
113 &self,
114 builder: &mut AB,
115 local: &KeccakVmCols<AB::Var>,
116 next: &KeccakVmCols<AB::Var>,
117 ) {
118 let mut transition_builder = builder.when_transition();
119 let mut round_builder = transition_builder.when(not(local.is_last_round()));
120 local
122 .instruction
123 .assert_eq(&mut round_builder, next.instruction);
124 }
125
126 pub fn constrain_block_transition<AB: AirBuilder<Var: Copy>>(
127 &self,
128 builder: &mut AB,
129 local: &KeccakVmCols<AB::Var>,
130 next: &KeccakVmCols<AB::Var>,
131 start_write_timestamp: AB::Expr,
132 ) {
133 let mut block_transition = builder.when(local.is_last_round() * not(next.is_new_start()));
138 block_transition.assert_eq(local.instruction.pc, next.instruction.pc);
139 block_transition.assert_eq(local.instruction.is_enabled, next.instruction.is_enabled);
140 assert_array_eq(
142 &mut block_transition,
143 local.instruction.dst,
144 next.instruction.dst,
145 );
146 block_transition.assert_eq(local.instruction.dst_ptr, next.instruction.dst_ptr);
149 block_transition.assert_eq(local.instruction.src_ptr, next.instruction.src_ptr);
150 block_transition.assert_eq(local.instruction.len_ptr, next.instruction.len_ptr);
151 block_transition.assert_eq(
156 next.instruction.src,
157 local.instruction.src + AB::F::from_usize(KECCAK_RATE_BYTES),
158 );
159 block_transition.assert_eq(next.instruction.start_timestamp, start_write_timestamp);
162 block_transition.assert_eq(
163 next.instruction.remaining_len,
164 local.instruction.remaining_len - AB::F::from_usize(KECCAK_RATE_BYTES),
165 );
166 }
168
169 pub fn constrain_padding<AB: AirBuilder>(
176 &self,
177 builder: &mut AB,
178 local: &KeccakVmCols<AB::Var>,
179 next: &KeccakVmCols<AB::Var>,
180 ) where
181 AB::Var: Copy,
182 {
183 let is_padding_byte = local.sponge.is_padding_byte;
184 let block_bytes = &local.sponge.block_bytes;
185 let remaining_len = local.remaining_len();
186
187 for &is_padding_byte in is_padding_byte.iter() {
189 builder.assert_bool(is_padding_byte);
190 }
191 for i in 1..KECCAK_RATE_BYTES {
193 builder
194 .when(is_padding_byte[i - 1])
195 .assert_one(is_padding_byte[i]);
196 }
197 let is_last_round = next.inner.step_flags[0];
201 let is_not_last_round = not(is_last_round);
202 for i in 0..KECCAK_RATE_BYTES {
203 builder.when(is_not_last_round.clone()).assert_eq(
204 local.sponge.is_padding_byte[i],
205 next.sponge.is_padding_byte[i],
206 );
207 }
208
209 let num_padding_bytes = local
210 .sponge
211 .is_padding_byte
212 .iter()
213 .fold(AB::Expr::ZERO, |a, &b| a + b);
214
215 let is_final_block = is_padding_byte[KECCAK_RATE_BYTES - 1];
217
218 builder.when(is_final_block).assert_eq(
220 remaining_len,
221 AB::Expr::from_usize(KECCAK_RATE_BYTES) - num_padding_bytes,
222 );
223 builder
226 .when(is_last_round)
227 .when(not(is_final_block))
228 .assert_eq(
229 remaining_len - AB::F::from_usize(KECCAK_RATE_BYTES),
230 next.remaining_len(),
231 );
232 builder
235 .when(is_last_round)
236 .when(next.is_new_start())
237 .assert_one(is_final_block);
238 builder
240 .when(is_last_round)
241 .when(is_final_block)
242 .assert_one(next.is_new_start());
243 let has_single_padding_byte: AB::Expr =
251 is_padding_byte[KECCAK_RATE_BYTES - 1] - is_padding_byte[KECCAK_RATE_BYTES - 2];
252
253 builder.when(has_single_padding_byte.clone()).assert_eq(
256 block_bytes[KECCAK_RATE_BYTES - 1],
257 AB::F::from_u8(0b10000001),
258 );
259
260 let has_multiple_padding_bytes: AB::Expr = not(has_single_padding_byte.clone());
261 for i in 0..KECCAK_RATE_BYTES - 1 {
262 let is_first_padding_byte: AB::Expr = {
263 if i > 0 {
264 is_padding_byte[i] - is_padding_byte[i - 1]
265 } else {
266 is_padding_byte[i].into()
267 }
268 };
269 builder
272 .when(has_multiple_padding_bytes.clone())
273 .when(is_first_padding_byte.clone())
274 .assert_eq(block_bytes[i], AB::F::from_u8(0x01));
275 builder
278 .when(is_padding_byte[i])
279 .when(not::<AB::Expr>(is_first_padding_byte)) .assert_zero(block_bytes[i]);
281 }
282
283 builder
286 .when(is_final_block)
287 .when(has_multiple_padding_bytes)
288 .assert_eq(block_bytes[KECCAK_RATE_BYTES - 1], AB::F::from_u8(0x80));
289 }
290
291 pub fn constrain_absorb<AB: InteractionBuilder>(
311 &self,
312 builder: &mut AB,
313 local: &KeccakVmCols<AB::Var>,
314 next: &KeccakVmCols<AB::Var>,
315 ) {
316 let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
317 let y = i / 5;
318 let x = i % 5;
319 (0..U64_LIMBS).flat_map(move |limb| {
320 let state_limb = local.postimage(y, x, limb);
321 let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
322 let lo = state_limb - hi * AB::F::from_u64(1 << 8);
323 [lo, hi.into()]
325 })
326 });
327
328 let post_absorb_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
329 let y = i / 5;
330 let x = i % 5;
331 (0..U64_LIMBS).flat_map(move |limb| {
332 let state_limb = next.inner.preimage[y][x][limb];
333 let hi = next.sponge.state_hi[i * U64_LIMBS + limb];
334 let lo = state_limb - hi * AB::F::from_u64(1 << 8);
335 [lo, hi.into()]
336 })
337 });
338
339 let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
342 for (input, prev, post) in izip!(
343 next.sponge.block_bytes,
344 updated_state_bytes,
345 post_absorb_state_bytes
346 ) {
347 self.bitwise_lookup_bus
355 .send_xor(
356 input * not(is_final_block),
357 prev.clone(),
358 select(is_final_block, prev, post),
359 )
360 .eval(
361 builder,
362 local.is_last_round() * local.instruction.is_enabled,
363 );
364 }
365
366 let local_preimage_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
369 let y = i / 5;
370 let x = i % 5;
371 (0..U64_LIMBS).flat_map(move |limb| {
372 let state_limb = local.inner.preimage[y][x][limb];
373 let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
374 let lo = state_limb - hi * AB::F::from_u64(1 << 8);
375 [lo, hi.into()]
376 })
377 });
378 let mut when_is_new_start =
379 builder.when(local.is_new_start() * local.instruction.is_enabled);
380 for (preimage_byte, block_byte) in zip(local_preimage_bytes, local.sponge.block_bytes) {
381 when_is_new_start.assert_eq(preimage_byte, block_byte);
382 }
383
384 let mut reset_builder = builder.when(local.is_new_start());
386 for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
387 let y = i / U64_LIMBS / 5;
388 let x = (i / U64_LIMBS) % 5;
389 let limb = i % U64_LIMBS;
390 reset_builder.assert_zero(local.inner.preimage[y][x][limb]);
391 }
392 let mut absorb_builder = builder.when(local.is_last_round() * not(next.is_new_start()));
393 for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
394 let y = i / U64_LIMBS / 5;
395 let x = (i / U64_LIMBS) % 5;
396 let limb = i % U64_LIMBS;
397 absorb_builder.assert_eq(local.postimage(y, x, limb), next.inner.preimage[y][x][limb]);
398 }
399 }
400
401 pub fn eval_instruction<AB: InteractionBuilder>(
410 &self,
411 builder: &mut AB,
412 local: &KeccakVmCols<AB::Var>,
413 register_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_REGISTER_READS],
414 ) -> AB::Expr {
415 let instruction = local.instruction;
416 let should_receive = local.instruction.is_enabled * local.sponge.is_new_start;
423
424 let [dst_ptr, src_ptr, len_ptr] = [
425 instruction.dst_ptr,
426 instruction.src_ptr,
427 instruction.len_ptr,
428 ];
429 let reg_addr_sp = AB::F::ONE;
430 let timestamp_change: AB::Expr = Self::timestamp_change(instruction.remaining_len);
431 self.execution_bridge
432 .execute_and_increment_pc(
433 AB::Expr::from_usize(Rv32KeccakOpcode::KECCAK256 as usize + self.offset),
434 [
435 dst_ptr.into(),
436 src_ptr.into(),
437 len_ptr.into(),
438 reg_addr_sp.into(),
439 AB::Expr::from_u32(RV32_MEMORY_AS),
440 ],
441 ExecutionState::new(instruction.pc, instruction.start_timestamp),
442 timestamp_change,
443 )
444 .eval(builder, should_receive.clone());
445
446 let mut timestamp: AB::Expr = instruction.start_timestamp.into();
447 let recover_limbs = |limbs: [AB::Var; RV32_REGISTER_NUM_LIMBS - 1],
448 val: AB::Var|
449 -> [AB::Expr; RV32_REGISTER_NUM_LIMBS] {
450 from_fn(|i| {
451 if i == 0 {
452 limbs
453 .into_iter()
454 .enumerate()
455 .fold(val.into(), |acc, (j, limb)| {
456 acc - limb * AB::Expr::from_usize(1 << ((j + 1) * RV32_CELL_BITS))
457 })
458 } else {
459 limbs[i - 1].into()
460 }
461 })
462 };
463 let dst_data = instruction.dst.map(Into::into);
466 let src_data = recover_limbs(instruction.src_limbs, instruction.src);
467 let len_data = recover_limbs(instruction.len_limbs, instruction.remaining_len);
468 for (ptr, value, aux) in izip!(
469 [dst_ptr, src_ptr, len_ptr],
470 [dst_data, src_data, len_data],
471 register_aux,
472 ) {
473 self.memory_bridge
474 .read(
475 MemoryAddress::new(reg_addr_sp, ptr),
476 value,
477 timestamp.clone(),
478 aux,
479 )
480 .eval(builder, should_receive.clone());
481
482 timestamp += AB::Expr::ONE;
483 }
484 let need_range_check = [
489 *instruction.dst.last().unwrap(),
490 *instruction.src_limbs.last().unwrap(),
491 *instruction.len_limbs.last().unwrap(),
492 *instruction.len_limbs.last().unwrap(),
493 ];
494 let limb_shift =
495 AB::F::from_usize(1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.ptr_max_bits));
496 for pair in need_range_check.chunks_exact(2) {
497 self.bitwise_lookup_bus
498 .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
499 .eval(builder, should_receive.clone());
500 }
501
502 timestamp
503 }
504
505 pub fn constrain_input_read<AB: InteractionBuilder>(
513 &self,
514 builder: &mut AB,
515 local: &KeccakVmCols<AB::Var>,
516 start_read_timestamp: AB::Expr,
517 mem_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_ABSORB_READS],
518 ) -> AB::Expr {
519 let partial_block = &local.mem_oc.partial_block;
520 let is_input = local.instruction.is_enabled_first_round;
523
524 let mut timestamp = start_read_timestamp;
525 for (i, (input, is_padding, mem_aux)) in izip!(
528 local.sponge.block_bytes.chunks_exact(KECCAK_WORD_SIZE),
529 local.sponge.is_padding_byte.chunks_exact(KECCAK_WORD_SIZE),
530 mem_aux
531 )
532 .enumerate()
533 {
534 let ptr = local.instruction.src + AB::F::from_usize(i * KECCAK_WORD_SIZE);
535 let count = is_input * not(is_padding[0]);
538 let is_partial_read = *is_padding.last().unwrap();
542 let word: [_; KECCAK_WORD_SIZE] = from_fn(|i| {
544 if i == 0 {
545 input[0].into()
547 } else {
548 select(is_partial_read, partial_block[i - 1], input[i])
551 }
552 });
553 for i in 1..KECCAK_WORD_SIZE {
554 let not_padding: AB::Expr = not(is_padding[i]);
555 builder.assert_eq(
558 not_padding.clone() * word[i].clone(),
559 not_padding.clone() * input[i],
560 );
561 }
562
563 self.memory_bridge
564 .read(
565 MemoryAddress::new(AB::Expr::from_u32(RV32_MEMORY_AS), ptr),
566 word, timestamp.clone(),
568 mem_aux,
569 )
570 .eval(builder, count);
571
572 timestamp += AB::Expr::ONE;
573 }
574 timestamp
575 }
576
577 pub fn constrain_output_write<AB: InteractionBuilder>(
578 &self,
579 builder: &mut AB,
580 local: &KeccakVmCols<AB::Var>,
581 start_write_timestamp: AB::Expr,
582 mem_aux: &[MemoryWriteAuxCols<AB::Var, KECCAK_WORD_SIZE>; KECCAK_DIGEST_WRITES],
583 ) {
584 let instruction = local.instruction;
585
586 let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
587 builder.assert_eq(
589 local.inner.export,
590 instruction.is_enabled * is_final_block * local.is_last_round(),
591 );
592 let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
596 let y = i / 5;
597 let x = i % 5;
598 (0..U64_LIMBS).flat_map(move |limb| {
599 let state_limb = local.postimage(y, x, limb);
600 let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
601 let lo = state_limb - hi * AB::F::from_u64(1 << 8);
602 [lo, hi.into()]
604 })
605 });
606 let dst = abstract_compose::<AB::Expr, _>(instruction.dst);
607 for (i, digest_bytes) in updated_state_bytes
608 .take(KECCAK_DIGEST_BYTES)
609 .chunks(KECCAK_WORD_SIZE)
610 .into_iter()
611 .enumerate()
612 {
613 let digest_bytes = digest_bytes.collect_vec();
614 let timestamp = start_write_timestamp.clone() + AB::Expr::from_usize(i);
615 self.memory_bridge
616 .write(
617 MemoryAddress::new(
618 AB::Expr::from_u32(RV32_MEMORY_AS),
619 dst.clone() + AB::F::from_usize(i * KECCAK_WORD_SIZE),
620 ),
621 digest_bytes.try_into().unwrap(),
622 timestamp,
623 &mem_aux[i],
624 )
625 .eval(builder, local.inner.export)
626 }
627 }
628
629 pub fn timestamp_change<T: PrimeCharacteristicRing>(len: impl Into<T>) -> T {
632 len.into()
636 + T::from_usize(KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES)
637 }
638}