1use std::{array::from_fn, borrow::Borrow, iter::zip};
2
3use itertools::{izip, Itertools};
4use openvm_circuit::{
5 arch::{ExecutionBridge, ExecutionState},
6 system::memory::{
7 offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
8 MemoryAddress,
9 },
10};
11use openvm_circuit_primitives::{
12 bitwise_op_lookup::BitwiseOperationLookupBus,
13 utils::{assert_array_eq, not, select},
14};
15use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS};
16use openvm_keccak256_transpiler::Rv32KeccakOpcode;
17use openvm_rv32im_circuit::adapters::abstract_compose;
18use openvm_stark_backend::{
19 air_builders::sub::SubAirBuilder,
20 interaction::InteractionBuilder,
21 p3_air::{Air, AirBuilder, BaseAir},
22 p3_field::FieldAlgebra,
23 p3_matrix::Matrix,
24 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
25};
26use p3_keccak_air::{KeccakAir, NUM_KECCAK_COLS as NUM_KECCAK_PERM_COLS, U64_LIMBS};
27
28use super::{
29 columns::{KeccakVmCols, NUM_KECCAK_VM_COLS},
30 KECCAK_ABSORB_READS, KECCAK_DIGEST_BYTES, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES,
31 KECCAK_RATE_U16S, KECCAK_REGISTER_READS, KECCAK_WIDTH_U16S, KECCAK_WORD_SIZE,
32 NUM_ABSORB_ROUNDS,
33};
34
/// AIR for the RV32 `KECCAK256` opcode.
///
/// It combines the keccak-f permutation (`p3_keccak_air::KeccakAir`) with:
/// - sponge padding and absorb constraints;
/// - instruction execution on the execution bus;
/// - memory accesses: register reads, input-byte reads and digest writes.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct KeccakVmAir {
    /// Used to receive the instruction on the execution bus and advance `pc`/timestamp.
    pub execution_bridge: ExecutionBridge,
    /// Used for the register reads, the input-byte reads, and the digest writes.
    pub memory_bridge: MemoryBridge,
    /// Lookup bus for the XOR requests on state bytes and the range-check requests
    /// on the most-significant register limbs.
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    /// Maximum number of bits allowed for the `dst`, `src` and `len` register values;
    /// enforced by range-checking each value's most-significant limb.
    pub ptr_max_bits: usize,
    /// Added to `Rv32KeccakOpcode::KECCAK256` to form the opcode this AIR executes.
    pub(super) offset: usize,
}
45
// Both impls below rely entirely on the traits' default methods.
// NOTE(review): presumably this means no public values and a single, unpartitioned
// main trace — confirm against `openvm_stark_backend::rap`.
impl<F> BaseAirWithPublicValues<F> for KeccakVmAir {}
impl<F> PartitionedBaseAir<F> for KeccakVmAir {}
impl<F> BaseAir<F> for KeccakVmAir {
    /// Number of main-trace columns: `NUM_KECCAK_VM_COLS`.
    fn width(&self) -> usize {
        NUM_KECCAK_VM_COLS
    }
}
53
54impl<AB: InteractionBuilder> Air<AB> for KeccakVmAir {
55 fn eval(&self, builder: &mut AB) {
56 let main = builder.main();
57 let (local, next) = (main.row_slice(0), main.row_slice(1));
58 let local: &KeccakVmCols<AB::Var> = (*local).borrow();
59 let next: &KeccakVmCols<AB::Var> = (*next).borrow();
60
61 builder.assert_bool(local.sponge.is_new_start);
62 builder.assert_eq(
63 local.sponge.is_new_start,
64 local.sponge.is_new_start * local.is_first_round(),
65 );
66 builder.assert_eq(
67 local.instruction.is_enabled_first_round,
68 local.instruction.is_enabled * local.is_first_round(),
69 );
70 builder
72 .when_first_row()
73 .assert_one(local.sponge.is_new_start);
74
75 self.eval_keccak_f(builder);
76 self.constrain_padding(builder, local, next);
77 self.constrain_consistency_across_rounds(builder, local, next);
78
79 let mem = &local.mem_oc;
80 self.constrain_absorb(builder, local, next);
82 let start_read_timestamp = self.eval_instruction(builder, local, &mem.register_aux);
83 let start_write_timestamp =
84 self.constrain_input_read(builder, local, start_read_timestamp, &mem.absorb_reads);
85 self.constrain_output_write(
86 builder,
87 local,
88 start_write_timestamp.clone(),
89 &mem.digest_writes,
90 );
91
92 self.constrain_block_transition(builder, local, next, start_write_timestamp);
93 }
94}
95
impl KeccakVmAir {
    /// Evaluates the keccak-f permutation AIR (`p3_keccak_air::KeccakAir`).
    ///
    /// It runs on the leading `NUM_KECCAK_PERM_COLS` columns of the trace, through a
    /// `SubAirBuilder` restricted to that column range.
    #[inline]
    pub fn eval_keccak_f<AB: AirBuilder>(&self, builder: &mut AB) {
        let keccak_f_air = KeccakAir {};
        let mut sub_builder =
            SubAirBuilder::<AB, KeccakAir, AB::Var>::new(builder, 0..NUM_KECCAK_PERM_COLS);
        keccak_f_air.eval(&mut sub_builder);
    }

    /// Keeps the instruction columns identical between `local` and `next` while `local`
    /// is not the last round of its keccak-f permutation. The effect is that these
    /// columns are constant across all rounds of one block.
    ///
    /// This applies only on transitions (`when_transition`), so the wrap-around from the
    /// last trace row back to the first is not constrained here.
    pub fn constrain_consistency_across_rounds<AB: AirBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        next: &KeccakVmCols<AB::Var>,
    ) {
        let mut transition_builder = builder.when_transition();
        let mut round_builder = transition_builder.when(not(local.is_last_round()));
        // NOTE(review): `instruction.assert_eq` is defined with the column types (not
        // visible here); presumably it equates every instruction column.
        local
            .instruction
            .assert_eq(&mut round_builder, next.instruction);
    }

    /// Constrains how instruction columns carry over from one block to the next block of
    /// the same message. This applies when `local` is the last round and `next` does not
    /// set `is_new_start`.
    ///
    /// - `pc`, `is_enabled`, `dst` and the three register pointers are copied unchanged.
    /// - `src` advances by `KECCAK_RATE_BYTES` and `remaining_len` shrinks by the same amount.
    /// - The next block's `start_timestamp` must equal `start_write_timestamp`. The caller
    ///   (`eval`) passes the timestamp reached after this block's register-read and
    ///   input-read slots.
    pub fn constrain_block_transition<AB: AirBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        next: &KeccakVmCols<AB::Var>,
        start_write_timestamp: AB::Expr,
    ) {
        let mut block_transition = builder.when(local.is_last_round() * not(next.is_new_start()));
        block_transition.assert_eq(local.instruction.pc, next.instruction.pc);
        block_transition.assert_eq(local.instruction.is_enabled, next.instruction.is_enabled);
        // `dst` must survive to the final block, where the digest is written to it.
        assert_array_eq(
            &mut block_transition,
            local.instruction.dst,
            next.instruction.dst,
        );
        block_transition.assert_eq(local.instruction.dst_ptr, next.instruction.dst_ptr);
        block_transition.assert_eq(local.instruction.src_ptr, next.instruction.src_ptr);
        block_transition.assert_eq(local.instruction.len_ptr, next.instruction.len_ptr);
        // `src_limbs`/`len_limbs` are not carried over here. In this impl they are used
        // only by the register reads, which are counted by `is_enabled * is_new_start`.
        // The next block of a non-final block consumed exactly one full rate of input.
        block_transition.assert_eq(
            next.instruction.src,
            local.instruction.src + AB::F::from_canonical_usize(KECCAK_RATE_BYTES),
        );
        block_transition.assert_eq(next.instruction.start_timestamp, start_write_timestamp);
        block_transition.assert_eq(
            next.instruction.remaining_len,
            local.instruction.remaining_len - AB::F::from_canonical_usize(KECCAK_RATE_BYTES),
        );
    }

    /// Constrains the `is_padding_byte` mask, the padded `block_bytes`, and the
    /// `remaining_len` bookkeeping across rounds and blocks.
    ///
    /// A block is "final" exactly when its last byte is padding. The enforced byte
    /// pattern is Keccak's `pad10*1` rule in little-endian byte form:
    /// - a single padding byte is `0x81`;
    /// - otherwise, the first padding byte is `0x01`, intermediate padding bytes are
    ///   `0x00`, and the last byte is `0x80`.
    pub fn constrain_padding<AB: AirBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        next: &KeccakVmCols<AB::Var>,
    ) {
        let is_padding_byte = local.sponge.is_padding_byte;
        let block_bytes = &local.sponge.block_bytes;
        let remaining_len = local.remaining_len();

        // Every padding flag is boolean.
        for &is_padding_byte in is_padding_byte.iter() {
            builder.assert_bool(is_padding_byte);
        }
        // Once a padding byte appears, every later byte is padding: the mask has the form 0..01..1.
        for i in 1..KECCAK_RATE_BYTES {
            builder
                .when(is_padding_byte[i - 1])
                .assert_one(is_padding_byte[i]);
        }
        // `local` is treated as the last round of its permutation when `next` is on round 0.
        // NOTE(review): this deliberately reads `next` rather than a flag on `local`; the
        // reason (e.g. a trace that does not end on a last round) is not visible here.
        let is_last_round = next.inner.step_flags[0];
        let is_not_last_round = not(is_last_round);
        // The padding mask is constant across the rounds of one block.
        for i in 0..KECCAK_RATE_BYTES {
            builder.when(is_not_last_round.clone()).assert_eq(
                local.sponge.is_padding_byte[i],
                next.sponge.is_padding_byte[i],
            );
        }

        // Number of padding bytes in this block.
        let num_padding_bytes = local
            .sponge
            .is_padding_byte
            .iter()
            .fold(AB::Expr::ZERO, |a, &b| a + b);

        // Final block <=> its last byte is padding (the mask is monotone).
        let is_final_block = is_padding_byte[KECCAK_RATE_BYTES - 1];

        // In a final block, `remaining_len` is exactly the count of non-padding bytes.
        builder.when(is_final_block).assert_eq(
            remaining_len,
            AB::Expr::from_canonical_usize(KECCAK_RATE_BYTES) - num_padding_bytes,
        );
        // Leaving a non-final block, `remaining_len` drops by one full rate.
        builder
            .when(is_last_round)
            .when(not(is_final_block))
            .assert_eq(
                remaining_len - AB::F::from_canonical_usize(KECCAK_RATE_BYTES),
                next.remaining_len(),
            );
        // A new message may only start right after a final block...
        builder
            .when(is_last_round)
            .when(next.is_new_start())
            .assert_one(is_final_block);
        // ...and a final block must be followed by a new start (no extra padding-only block).
        builder
            .when(is_last_round)
            .when(is_final_block)
            .assert_one(next.is_new_start());

        // 1 iff only the last byte is padding. By monotonicity this difference is 0 or 1.
        let has_single_padding_byte: AB::Expr =
            is_padding_byte[KECCAK_RATE_BYTES - 1] - is_padding_byte[KECCAK_RATE_BYTES - 2];

        // A lone padding byte carries both pad bits: 0b10000001 (0x81).
        builder.when(has_single_padding_byte.clone()).assert_eq(
            block_bytes[KECCAK_RATE_BYTES - 1],
            AB::F::from_canonical_u8(0b10000001),
        );

        // Also 1 for non-final blocks, hence the extra `is_final_block` guard at the end.
        let has_multiple_padding_bytes: AB::Expr = not(has_single_padding_byte.clone());
        for i in 0..KECCAK_RATE_BYTES - 1 {
            // 1 exactly at the first padding position (by monotonicity of the mask).
            let is_first_padding_byte: AB::Expr = {
                if i > 0 {
                    is_padding_byte[i] - is_padding_byte[i - 1]
                } else {
                    is_padding_byte[i].into()
                }
            };
            // With several padding bytes, the first one is 0x01.
            builder
                .when(has_multiple_padding_bytes.clone())
                .when(is_first_padding_byte.clone())
                .assert_eq(block_bytes[i], AB::F::from_canonical_u8(0x01));
            // Any other padding byte before the last position is zero.
            builder
                .when(is_padding_byte[i])
                .when(not::<AB::Expr>(is_first_padding_byte))
                .assert_zero(block_bytes[i]);
        }

        // With several padding bytes, the last byte of a final block is 0x80.
        builder
            .when(is_final_block)
            .when(has_multiple_padding_bytes)
            .assert_eq(
                block_bytes[KECCAK_RATE_BYTES - 1],
                AB::F::from_canonical_u8(0x80),
            );
    }

    /// Constrains the sponge absorb between consecutive keccak-f permutations, and the
    /// initial state of a new message.
    ///
    /// State limbs are `u16`s. Each limb is split into bytes `[lo, hi]`, with `hi` taken
    /// from `sponge.state_hi` and `lo = limb - 256 * hi`.
    ///
    /// On the last round of an enabled row, one XOR lookup is sent per zipped
    /// (next `block_bytes`, postimage byte, next preimage byte) triple:
    /// - non-final block: `input ^ postimage_byte == next_preimage_byte`;
    /// - final block: `0 ^ postimage_byte == postimage_byte`.
    ///
    /// NOTE(review): the final-block lookup presumably exists so that the lookup chip still
    /// range-checks the postimage (digest) bytes — confirm against the receiving chip.
    ///
    /// Limbs `KECCAK_RATE_U16S..KECCAK_WIDTH_U16S` (outside the rate) are zero on a new
    /// start and are copied unchanged (postimage -> next preimage) between blocks.
    pub fn constrain_absorb<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        next: &KeccakVmCols<AB::Var>,
    ) {
        // Bytes of `local`'s post-permutation state, lane order (y, x) = (i / 5, i % 5).
        let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = local.postimage(y, x, limb);
                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
                [lo, hi.into()]
            })
        });

        // Bytes of `next`'s pre-permutation state, in the same order.
        let post_absorb_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = next.inner.preimage[y][x][limb];
                let hi = next.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
                [lo, hi.into()]
            })
        });

        let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
        // `izip!` stops at the shortest of the three iterators.
        for (input, prev, post) in izip!(
            next.sponge.block_bytes,
            updated_state_bytes,
            post_absorb_state_bytes
        ) {
            // XOR lookup (input, prev, result). On a final block, input is forced to 0 and
            // the result to `prev`.
            self.bitwise_lookup_bus
                .send_xor(
                    input * not(is_final_block),
                    prev.clone(),
                    select(is_final_block, prev, post),
                )
                .eval(
                    builder,
                    local.is_last_round() * local.instruction.is_enabled,
                );
        }

        // On a new start, the rate part of the preimage equals the block bytes. Starting
        // from a zero state, XOR with the input is the input itself.
        let local_preimage_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = local.inner.preimage[y][x][limb];
                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
                [lo, hi.into()]
            })
        });
        let mut when_is_new_start =
            builder.when(local.is_new_start() * local.instruction.is_enabled);
        for (preimage_byte, block_byte) in zip(local_preimage_bytes, local.sponge.block_bytes) {
            when_is_new_start.assert_eq(preimage_byte, block_byte);
        }

        // Limbs outside the rate start at zero for every new message (enabled or not).
        let mut reset_builder = builder.when(local.is_new_start());
        for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
            let y = i / U64_LIMBS / 5;
            let x = (i / U64_LIMBS) % 5;
            let limb = i % U64_LIMBS;
            reset_builder.assert_zero(local.inner.preimage[y][x][limb]);
        }
        // ...and pass through unchanged between blocks of the same message.
        let mut absorb_builder = builder.when(local.is_last_round() * not(next.is_new_start()));
        for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
            let y = i / U64_LIMBS / 5;
            let x = (i / U64_LIMBS) % 5;
            let limb = i % U64_LIMBS;
            absorb_builder.assert_eq(local.postimage(y, x, limb), next.inner.preimage[y][x][limb]);
        }
    }

    /// Receives the `KECCAK256` instruction on the execution bus and reads the `dst`,
    /// `src` and `len` registers.
    ///
    /// Every interaction here is counted by `is_enabled * is_new_start`.
    ///
    /// Register reads:
    /// - use address space 1 (`reg_addr_sp`);
    /// - occur at consecutive timestamps starting at `instruction.start_timestamp`.
    ///
    /// The most-significant limbs of `dst`, `src` and `len` are sent for range checks,
    /// scaled by `2^(32 - ptr_max_bits)`. Assuming `send_range` checks operands against
    /// `RV32_CELL_BITS` bits, this bounds each value by `2^ptr_max_bits`.
    ///
    /// Returns the timestamp immediately after the register-read slots.
    pub fn eval_instruction<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        register_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_REGISTER_READS],
    ) -> AB::Expr {
        let instruction = local.instruction;
        // Interaction count: nonzero only on enabled rows that start a new message.
        let should_receive = local.instruction.is_enabled * local.sponge.is_new_start;

        let [dst_ptr, src_ptr, len_ptr] = [
            instruction.dst_ptr,
            instruction.src_ptr,
            instruction.len_ptr,
        ];
        // Address space of the registers holding `dst`, `src`, `len`.
        let reg_addr_sp = AB::F::ONE;
        let timestamp_change: AB::Expr = Self::timestamp_change(instruction.remaining_len);
        // Operands: the three register pointers, the register address space, and the
        // memory address space used for the input/digest accesses.
        self.execution_bridge
            .execute_and_increment_pc(
                AB::Expr::from_canonical_usize(Rv32KeccakOpcode::KECCAK256 as usize + self.offset),
                [
                    dst_ptr.into(),
                    src_ptr.into(),
                    len_ptr.into(),
                    reg_addr_sp.into(),
                    AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
                ],
                ExecutionState::new(instruction.pc, instruction.start_timestamp),
                timestamp_change,
            )
            .eval(builder, should_receive.clone());

        let mut timestamp: AB::Expr = instruction.start_timestamp.into();
        // Rebuilds all `RV32_REGISTER_NUM_LIMBS` limbs from the full value `val` and its
        // upper limbs: limb 0 = val - sum_j limbs[j] * 2^((j + 1) * RV32_CELL_BITS).
        let recover_limbs = |limbs: [AB::Var; RV32_REGISTER_NUM_LIMBS - 1],
                             val: AB::Var|
         -> [AB::Expr; RV32_REGISTER_NUM_LIMBS] {
            from_fn(|i| {
                if i == 0 {
                    limbs
                        .into_iter()
                        .enumerate()
                        .fold(val.into(), |acc, (j, limb)| {
                            acc - limb
                                * AB::Expr::from_canonical_usize(1 << ((j + 1) * RV32_CELL_BITS))
                        })
                } else {
                    limbs[i - 1].into()
                }
            })
        };
        // `dst` is stored as full limbs. `src`/`len` are stored as a value plus upper limbs.
        // `remaining_len` stands in for `len`; these reads only fire on a new start, where
        // it is presumably still the full register value.
        let dst_data = instruction.dst.map(Into::into);
        let src_data = recover_limbs(instruction.src_limbs, instruction.src);
        let len_data = recover_limbs(instruction.len_limbs, instruction.remaining_len);
        for (ptr, value, aux) in izip!(
            [dst_ptr, src_ptr, len_ptr],
            [dst_data, src_data, len_data],
            register_aux,
        ) {
            self.memory_bridge
                .read(
                    MemoryAddress::new(reg_addr_sp, ptr),
                    value,
                    timestamp.clone(),
                    aux,
                )
                .eval(builder, should_receive.clone());

            timestamp += AB::Expr::ONE;
        }
        // `len`'s top limb appears twice so the array pairs up for the two-operand `send_range`.
        let need_range_check = [
            *instruction.dst.last().unwrap(),
            *instruction.src_limbs.last().unwrap(),
            *instruction.len_limbs.last().unwrap(),
            *instruction.len_limbs.last().unwrap(),
        ];
        let limb_shift = AB::F::from_canonical_usize(
            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.ptr_max_bits),
        );
        for pair in need_range_check.chunks_exact(2) {
            self.bitwise_lookup_bus
                .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
                .eval(builder, should_receive.clone());
        }

        timestamp
    }

    /// Reads the block's input bytes from memory, one `KECCAK_WORD_SIZE`-byte word per
    /// access, at `instruction.src + i * KECCAK_WORD_SIZE`.
    ///
    /// When reads fire:
    /// - only on rows with `is_enabled_first_round`;
    /// - a word whose first byte is padding is skipped entirely.
    ///
    /// When a word's last byte is padding (a partial word), bytes 1.. of the word sent to
    /// memory come from `mem_oc.partial_block`. Non-padding positions must still equal
    /// `block_bytes`, so the memory word may differ from `block_bytes` only at padding
    /// positions.
    ///
    /// The timestamp advances by one per word slot, whether or not the read fires, starting
    /// at `start_read_timestamp`. Returns the timestamp after the last slot.
    pub fn constrain_input_read<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        start_read_timestamp: AB::Expr,
        mem_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_ABSORB_READS],
    ) -> AB::Expr {
        let partial_block = &local.mem_oc.partial_block;
        let is_input = local.instruction.is_enabled_first_round;

        let mut timestamp = start_read_timestamp;
        for (i, (input, is_padding, mem_aux)) in izip!(
            local.sponge.block_bytes.chunks_exact(KECCAK_WORD_SIZE),
            local.sponge.is_padding_byte.chunks_exact(KECCAK_WORD_SIZE),
            mem_aux
        )
        .enumerate()
        {
            let ptr = local.instruction.src + AB::F::from_canonical_usize(i * KECCAK_WORD_SIZE);
            // Read only if the word's first byte is real input.
            let count = is_input * not(is_padding[0]);
            // Given `count` != 0, the word is partial iff its last byte is padding.
            let is_partial_read = *is_padding.last().unwrap();
            // The inner `i` below indexes bytes within the word.
            let word: [_; KECCAK_WORD_SIZE] = from_fn(|i| {
                if i == 0 {
                    input[0].into()
                } else {
                    select(is_partial_read, partial_block[i - 1], input[i])
                }
            });
            // Non-padding bytes of the memory word must match the block bytes.
            for i in 1..KECCAK_WORD_SIZE {
                let not_padding: AB::Expr = not(is_padding[i]);
                builder.assert_eq(
                    not_padding.clone() * word[i].clone(),
                    not_padding.clone() * input[i],
                );
            }

            self.memory_bridge
                .read(
                    MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_MEMORY_AS), ptr),
                    word,
                    timestamp.clone(),
                    mem_aux,
                )
                .eval(builder, count);

            timestamp += AB::Expr::ONE;
        }
        timestamp
    }

    /// Writes the first `KECCAK_DIGEST_BYTES` bytes of the post-permutation state to
    /// memory at `dst`.
    ///
    /// The bytes go out as `KECCAK_WORD_SIZE`-byte words, at timestamps
    /// `start_write_timestamp + i`.
    ///
    /// The keccak-f AIR's `export` flag is pinned to
    /// `is_enabled * is_final_block * is_last_round` and used as the write count. The
    /// digest is therefore written only on the last round of the final block of an enabled
    /// message.
    pub fn constrain_output_write<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        start_write_timestamp: AB::Expr,
        mem_aux: &[MemoryWriteAuxCols<AB::Var, KECCAK_WORD_SIZE>; KECCAK_DIGEST_WRITES],
    ) {
        let instruction = local.instruction;

        let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
        builder.assert_eq(
            local.inner.export,
            instruction.is_enabled * is_final_block * local.is_last_round(),
        );
        // Same `[lo, hi]` byte split of the postimage `u16` limbs as in `constrain_absorb`.
        let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = local.postimage(y, x, limb);
                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
                [lo, hi.into()]
            })
        });
        let dst = abstract_compose::<AB::Expr, _>(instruction.dst);
        for (i, digest_bytes) in updated_state_bytes
            .take(KECCAK_DIGEST_BYTES)
            .chunks(KECCAK_WORD_SIZE)
            .into_iter()
            .enumerate()
        {
            let digest_bytes = digest_bytes.collect_vec();
            let timestamp = start_write_timestamp.clone() + AB::Expr::from_canonical_usize(i);
            self.memory_bridge
                .write(
                    MemoryAddress::new(
                        AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
                        dst.clone() + AB::F::from_canonical_usize(i * KECCAK_WORD_SIZE),
                    ),
                    digest_bytes.try_into().unwrap(),
                    timestamp,
                    &mem_aux[i],
                )
                .eval(builder, local.inner.export)
        }
    }

    /// Timestamp delta sent to the execution bridge for one `KECCAK256` instruction:
    /// `len + KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES`.
    ///
    /// NOTE(review): this looks like an upper bound on the accesses actually performed,
    /// not an exact count. Each block consumes `KECCAK_REGISTER_READS + KECCAK_ABSORB_READS`
    /// timestamp slots (see `constrain_block_transition`). Confirm against the trace
    /// generator before changing it.
    pub fn timestamp_change<T: FieldAlgebra>(len: impl Into<T>) -> T {
        len.into()
            + T::from_canonical_usize(
                KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES,
            )
    }
}