openvm_keccak256_circuit/
air.rs

1use std::{array::from_fn, borrow::Borrow, iter::zip};
2
3use itertools::{izip, Itertools};
4use openvm_circuit::{
5    arch::{ExecutionBridge, ExecutionState},
6    system::memory::{
7        offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
8        MemoryAddress,
9    },
10};
11use openvm_circuit_primitives::{
12    bitwise_op_lookup::BitwiseOperationLookupBus,
13    utils::{assert_array_eq, not, select},
14};
15use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS};
16use openvm_keccak256_transpiler::Rv32KeccakOpcode;
17use openvm_rv32im_circuit::adapters::abstract_compose;
18use openvm_stark_backend::{
19    air_builders::sub::SubAirBuilder,
20    interaction::InteractionBuilder,
21    p3_air::{Air, AirBuilder, BaseAir},
22    p3_field::PrimeCharacteristicRing,
23    p3_matrix::Matrix,
24    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
25};
26use p3_keccak_air::{KeccakAir, NUM_KECCAK_COLS as NUM_KECCAK_PERM_COLS, U64_LIMBS};
27
28use super::{
29    columns::{KeccakVmCols, NUM_KECCAK_VM_COLS},
30    KECCAK_ABSORB_READS, KECCAK_DIGEST_BYTES, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES,
31    KECCAK_RATE_U16S, KECCAK_REGISTER_READS, KECCAK_WIDTH_U16S, KECCAK_WORD_SIZE,
32    NUM_ABSORB_ROUNDS,
33};
34
35#[derive(Clone, Copy, Debug, derive_new::new)]
36pub struct KeccakVmAir {
37    pub execution_bridge: ExecutionBridge,
38    pub memory_bridge: MemoryBridge,
39    /// Bus to send 8-bit XOR requests to.
40    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
41    /// Maximum number of bits allowed for an address pointer
42    pub ptr_max_bits: usize,
43    pub(super) offset: usize,
44}
45
46impl<F> BaseAirWithPublicValues<F> for KeccakVmAir {}
47impl<F> PartitionedBaseAir<F> for KeccakVmAir {}
48impl<F> BaseAir<F> for KeccakVmAir {
49    fn width(&self) -> usize {
50        NUM_KECCAK_VM_COLS
51    }
52}
53
54impl<AB: InteractionBuilder> Air<AB> for KeccakVmAir {
55    fn eval(&self, builder: &mut AB) {
56        let main = builder.main();
57        let (local, next) = (
58            main.row_slice(0).expect("window should have two elements"),
59            main.row_slice(1).expect("window should have two elements"),
60        );
61        let local: &KeccakVmCols<AB::Var> = (*local).borrow();
62        let next: &KeccakVmCols<AB::Var> = (*next).borrow();
63
64        builder.assert_bool(local.sponge.is_new_start);
65        builder.assert_eq(
66            local.sponge.is_new_start,
67            local.sponge.is_new_start * local.is_first_round(),
68        );
69        builder.assert_eq(
70            local.instruction.is_enabled_first_round,
71            local.instruction.is_enabled * local.is_first_round(),
72        );
73        // Not strictly necessary:
74        builder
75            .when_first_row()
76            .assert_one(local.sponge.is_new_start);
77
78        self.eval_keccak_f(builder);
79        self.constrain_padding(builder, local, next);
80        self.constrain_consistency_across_rounds(builder, local, next);
81
82        let mem = &local.mem_oc;
83        // Interactions:
84        self.constrain_absorb(builder, local, next);
85        let start_read_timestamp = self.eval_instruction(builder, local, &mem.register_aux);
86        let start_write_timestamp =
87            self.constrain_input_read(builder, local, start_read_timestamp, &mem.absorb_reads);
88        self.constrain_output_write(
89            builder,
90            local,
91            start_write_timestamp.clone(),
92            &mem.digest_writes,
93        );
94
95        self.constrain_block_transition(builder, local, next, start_write_timestamp);
96    }
97}
98
99impl KeccakVmAir {
100    /// Evaluate the keccak-f permutation constraints.
101    ///
102    /// WARNING: The keccak-f AIR columns **must** be the first columns in the main AIR.
103    #[inline]
104    pub fn eval_keccak_f<AB: AirBuilder>(&self, builder: &mut AB) {
105        let keccak_f_air = KeccakAir {};
106        let mut sub_builder =
107            SubAirBuilder::<AB, KeccakAir, AB::Var>::new(builder, 0..NUM_KECCAK_PERM_COLS);
108        keccak_f_air.eval(&mut sub_builder);
109    }
110
111    /// Many columns are expected to be the same between rounds and only change per-block.
112    pub fn constrain_consistency_across_rounds<AB: AirBuilder<Var: Copy>>(
113        &self,
114        builder: &mut AB,
115        local: &KeccakVmCols<AB::Var>,
116        next: &KeccakVmCols<AB::Var>,
117    ) {
118        let mut transition_builder = builder.when_transition();
119        let mut round_builder = transition_builder.when(not(local.is_last_round()));
120        // Instruction columns
121        local
122            .instruction
123            .assert_eq(&mut round_builder, next.instruction);
124    }
125
126    pub fn constrain_block_transition<AB: AirBuilder<Var: Copy>>(
127        &self,
128        builder: &mut AB,
129        local: &KeccakVmCols<AB::Var>,
130        next: &KeccakVmCols<AB::Var>,
131        start_write_timestamp: AB::Expr,
132    ) {
133        // When we transition between blocks, if the next block isn't a new block
134        // (this means it's not receiving a new opcode or starting a dummy block)
135        // then we want _parts_ of opcode instruction to stay the same
136        // between blocks.
137        let mut block_transition = builder.when(local.is_last_round() * not(next.is_new_start()));
138        block_transition.assert_eq(local.instruction.pc, next.instruction.pc);
139        block_transition.assert_eq(local.instruction.is_enabled, next.instruction.is_enabled);
140        // dst is only going to be used for writes in the last input block
141        assert_array_eq(
142            &mut block_transition,
143            local.instruction.dst,
144            next.instruction.dst,
145        );
146        // these are not used and hence not necessary, but putting for safety until performance
147        // becomes an issue:
148        block_transition.assert_eq(local.instruction.dst_ptr, next.instruction.dst_ptr);
149        block_transition.assert_eq(local.instruction.src_ptr, next.instruction.src_ptr);
150        block_transition.assert_eq(local.instruction.len_ptr, next.instruction.len_ptr);
151        // no constraint on `instruction.len` because we use `remaining_len` instead
152
153        // Move the src pointer over based on the number of bytes read.
154        // This should always be RATE_BYTES since it's a non-final block.
155        block_transition.assert_eq(
156            next.instruction.src,
157            local.instruction.src + AB::F::from_usize(KECCAK_RATE_BYTES),
158        );
159        // Advance timestamp by the number of memory accesses from reading
160        // `dst, src, len` and block input bytes.
161        block_transition.assert_eq(next.instruction.start_timestamp, start_write_timestamp);
162        block_transition.assert_eq(
163            next.instruction.remaining_len,
164            local.instruction.remaining_len - AB::F::from_usize(KECCAK_RATE_BYTES),
165        );
166        // Padding transition is constrained in `constrain_padding`.
167    }
168
169    /// Keccak follows the 10*1 padding rule.
170    /// See Section 5.1 of <https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.202.pdf>
171    /// Note this is the ONLY difference between Keccak and SHA-3
172    ///
173    /// Constrains padding constraints and length between rounds and
174    /// between blocks. Padding logic is tied to constraints on `is_new_start`.
175    pub fn constrain_padding<AB: AirBuilder>(
176        &self,
177        builder: &mut AB,
178        local: &KeccakVmCols<AB::Var>,
179        next: &KeccakVmCols<AB::Var>,
180    ) where
181        AB::Var: Copy,
182    {
183        let is_padding_byte = local.sponge.is_padding_byte;
184        let block_bytes = &local.sponge.block_bytes;
185        let remaining_len = local.remaining_len();
186
187        // is_padding_byte should all be boolean
188        for &is_padding_byte in is_padding_byte.iter() {
189            builder.assert_bool(is_padding_byte);
190        }
191        // is_padding_byte should transition from 0 to 1 only once and then stay 1
192        for i in 1..KECCAK_RATE_BYTES {
193            builder
194                .when(is_padding_byte[i - 1])
195                .assert_one(is_padding_byte[i]);
196        }
197        // is_padding_byte must stay the same on all rounds in a block
198        // we use next instead of local.step_flags.last() because the last row of the trace overall
199        // may not end on a last round
200        let is_last_round = next.inner.step_flags[0];
201        let is_not_last_round = not(is_last_round);
202        for i in 0..KECCAK_RATE_BYTES {
203            builder.when(is_not_last_round.clone()).assert_eq(
204                local.sponge.is_padding_byte[i],
205                next.sponge.is_padding_byte[i],
206            );
207        }
208
209        let num_padding_bytes = local
210            .sponge
211            .is_padding_byte
212            .iter()
213            .fold(AB::Expr::ZERO, |a, &b| a + b);
214
215        // If final rate block of input, then last byte must be padding
216        let is_final_block = is_padding_byte[KECCAK_RATE_BYTES - 1];
217
218        // is_padding_byte must be consistent with remaining_len
219        builder.when(is_final_block).assert_eq(
220            remaining_len,
221            AB::Expr::from_usize(KECCAK_RATE_BYTES) - num_padding_bytes,
222        );
223        // If this block is not final, when transitioning to next block, remaining len
224        // must decrease by `KECCAK_RATE_BYTES`.
225        builder
226            .when(is_last_round)
227            .when(not(is_final_block))
228            .assert_eq(
229                remaining_len - AB::F::from_usize(KECCAK_RATE_BYTES),
230                next.remaining_len(),
231            );
232        // To enforce that is_padding_byte must be set appropriately for an input, we require
233        // the block before a new start to have padding
234        builder
235            .when(is_last_round)
236            .when(next.is_new_start())
237            .assert_one(is_final_block);
238        // Make sure there are not repeated padding blocks
239        builder
240            .when(is_last_round)
241            .when(is_final_block)
242            .assert_one(next.is_new_start());
243        // The chain above enforces that for an input, the remaining length must decrease by RATE
244        // block-by-block until it reaches a final block with padding.
245
246        // ====== Constrain the block_bytes are padded according to is_padding_byte =====
247
248        // If the first padding byte is at the end of the block, then the block has a
249        // single padding byte
250        let has_single_padding_byte: AB::Expr =
251            is_padding_byte[KECCAK_RATE_BYTES - 1] - is_padding_byte[KECCAK_RATE_BYTES - 2];
252
253        // If the row has a single padding byte, then it must be the last byte with
254        // value 0b10000001
255        builder.when(has_single_padding_byte.clone()).assert_eq(
256            block_bytes[KECCAK_RATE_BYTES - 1],
257            AB::F::from_u8(0b10000001),
258        );
259
260        let has_multiple_padding_bytes: AB::Expr = not(has_single_padding_byte.clone());
261        for i in 0..KECCAK_RATE_BYTES - 1 {
262            let is_first_padding_byte: AB::Expr = {
263                if i > 0 {
264                    is_padding_byte[i] - is_padding_byte[i - 1]
265                } else {
266                    is_padding_byte[i].into()
267                }
268            };
269            // If the row has multiple padding bytes, the first padding byte must be 0x01
270            // because the padding 1*0 is *little-endian*
271            builder
272                .when(has_multiple_padding_bytes.clone())
273                .when(is_first_padding_byte.clone())
274                .assert_eq(block_bytes[i], AB::F::from_u8(0x01));
275            // If the row has multiple padding bytes, the other padding bytes
276            // except the last one must be 0
277            builder
278                .when(is_padding_byte[i])
279                .when(not::<AB::Expr>(is_first_padding_byte)) // hence never when single padding byte
280                .assert_zero(block_bytes[i]);
281        }
282
283        // If the row has multiple padding bytes, then the last byte must be 0x80
284        // because the padding *01 is *little-endian*
285        builder
286            .when(is_final_block)
287            .when(has_multiple_padding_bytes)
288            .assert_eq(block_bytes[KECCAK_RATE_BYTES - 1], AB::F::from_u8(0x80));
289    }
290
291    /// Constrain state transition between keccak-f permutations is valid absorb of input bytes.
292    /// The end-state in last round is given by `a_prime_prime_prime()` in `u16` limbs.
293    /// The pre-state is given by `preimage` also in `u16` limbs.
294    /// The input `block_bytes` will be given as **bytes**.
295    ///
296    /// We will XOR `block_bytes` with `a_prime_prime_prime()` and constrain to be `next.preimage`.
297    /// This will be done using 8-bit XOR lookup in a separate AIR via interactions.
298    /// This will require decomposing `u16` into bytes.
299    /// Note that the XOR lookup automatically range checks its inputs to be bytes.
300    ///
301    /// We use the following trick to keep `u16` limbs and avoid changing
302    /// the `keccak-f` AIR itself:
303    /// if we already have a 16-bit limb `x` and we also provide a 8-bit limb
304    /// `hi = x >> 8`, assuming `x` and `hi` have been range checked,
305    /// we can use the expression `lo = x - hi * 256` for the low byte.
306    /// If `lo` is range checked to `8`-bits, this constrains a valid byte
307    ///  decomposition of `x` into `hi, lo`.
308    /// This means in terms of trace cells, it is equivalent to provide
309    /// `x, hi` versus `hi, lo`.
310    pub fn constrain_absorb<AB: InteractionBuilder>(
311        &self,
312        builder: &mut AB,
313        local: &KeccakVmCols<AB::Var>,
314        next: &KeccakVmCols<AB::Var>,
315    ) {
316        let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
317            let y = i / 5;
318            let x = i % 5;
319            (0..U64_LIMBS).flat_map(move |limb| {
320                let state_limb = local.postimage(y, x, limb);
321                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
322                let lo = state_limb - hi * AB::F::from_u64(1 << 8);
323                // Conversion from bytes to u64 is little-endian
324                [lo, hi.into()]
325            })
326        });
327
328        let post_absorb_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
329            let y = i / 5;
330            let x = i % 5;
331            (0..U64_LIMBS).flat_map(move |limb| {
332                let state_limb = next.inner.preimage[y][x][limb];
333                let hi = next.sponge.state_hi[i * U64_LIMBS + limb];
334                let lo = state_limb - hi * AB::F::from_u64(1 << 8);
335                [lo, hi.into()]
336            })
337        });
338
339        // We xor on last round of each block, even if it is a final block,
340        // because we use xor to range check the output bytes (= updated_state_bytes)
341        let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
342        for (input, prev, post) in izip!(
343            next.sponge.block_bytes,
344            updated_state_bytes,
345            post_absorb_state_bytes
346        ) {
347            // Add new send interaction to lookup (x, y, x ^ y) where x, y, z
348            // will all be range checked to be 8-bits (assuming the bus is
349            // received by an 8-bit xor chip).
350
351            // When absorb, input ^ prev = post
352            // Otherwise, 0 ^ prev = prev
353            // The interaction fields are degree 2, leading to degree 3 constraint
354            self.bitwise_lookup_bus
355                .send_xor(
356                    input * not(is_final_block),
357                    prev.clone(),
358                    select(is_final_block, prev, post),
359                )
360                .eval(
361                    builder,
362                    local.is_last_round() * local.instruction.is_enabled,
363                );
364        }
365
366        // We separately constrain that when(local.is_new_start), the preimage (u16s) equals the
367        // block bytes
368        let local_preimage_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
369            let y = i / 5;
370            let x = i % 5;
371            (0..U64_LIMBS).flat_map(move |limb| {
372                let state_limb = local.inner.preimage[y][x][limb];
373                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
374                let lo = state_limb - hi * AB::F::from_u64(1 << 8);
375                [lo, hi.into()]
376            })
377        });
378        let mut when_is_new_start =
379            builder.when(local.is_new_start() * local.instruction.is_enabled);
380        for (preimage_byte, block_byte) in zip(local_preimage_bytes, local.sponge.block_bytes) {
381            when_is_new_start.assert_eq(preimage_byte, block_byte);
382        }
383
384        // constrain transition on the state outside rate
385        let mut reset_builder = builder.when(local.is_new_start());
386        for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
387            let y = i / U64_LIMBS / 5;
388            let x = (i / U64_LIMBS) % 5;
389            let limb = i % U64_LIMBS;
390            reset_builder.assert_zero(local.inner.preimage[y][x][limb]);
391        }
392        let mut absorb_builder = builder.when(local.is_last_round() * not(next.is_new_start()));
393        for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
394            let y = i / U64_LIMBS / 5;
395            let x = (i / U64_LIMBS) % 5;
396            let limb = i % U64_LIMBS;
397            absorb_builder.assert_eq(local.postimage(y, x, limb), next.inner.preimage[y][x][limb]);
398        }
399    }
400
401    /// Receive the instruction itself on program bus. Send+receive on execution bus.
402    /// Then does memory read in addr space 1 to get `dst, src, len` from memory.
403    ///
404    /// Adds range check interactions for the most significant limbs of the register values
405    /// using BitwiseOperationLookupBus.
406    ///
407    /// Returns `start_read_timestamp` which is only relevant when `local.instruction.is_enabled`.
408    /// Note that `start_read_timestamp` is a linear expression.
409    pub fn eval_instruction<AB: InteractionBuilder>(
410        &self,
411        builder: &mut AB,
412        local: &KeccakVmCols<AB::Var>,
413        register_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_REGISTER_READS],
414    ) -> AB::Expr {
415        let instruction = local.instruction;
416        // Only receive opcode if:
417        // - enabled row (not dummy row)
418        // - first round of block
419        // - is_new_start
420        // Note this is degree 3, which results in quotient degree 2 if used
421        // as `count` in interaction
422        let should_receive = local.instruction.is_enabled * local.sponge.is_new_start;
423
424        let [dst_ptr, src_ptr, len_ptr] = [
425            instruction.dst_ptr,
426            instruction.src_ptr,
427            instruction.len_ptr,
428        ];
429        let reg_addr_sp = AB::F::ONE;
430        let timestamp_change: AB::Expr = Self::timestamp_change(instruction.remaining_len);
431        self.execution_bridge
432            .execute_and_increment_pc(
433                AB::Expr::from_usize(Rv32KeccakOpcode::KECCAK256 as usize + self.offset),
434                [
435                    dst_ptr.into(),
436                    src_ptr.into(),
437                    len_ptr.into(),
438                    reg_addr_sp.into(),
439                    AB::Expr::from_u32(RV32_MEMORY_AS),
440                ],
441                ExecutionState::new(instruction.pc, instruction.start_timestamp),
442                timestamp_change,
443            )
444            .eval(builder, should_receive.clone());
445
446        let mut timestamp: AB::Expr = instruction.start_timestamp.into();
447        let recover_limbs = |limbs: [AB::Var; RV32_REGISTER_NUM_LIMBS - 1],
448                             val: AB::Var|
449         -> [AB::Expr; RV32_REGISTER_NUM_LIMBS] {
450            from_fn(|i| {
451                if i == 0 {
452                    limbs
453                        .into_iter()
454                        .enumerate()
455                        .fold(val.into(), |acc, (j, limb)| {
456                            acc - limb * AB::Expr::from_usize(1 << ((j + 1) * RV32_CELL_BITS))
457                        })
458                } else {
459                    limbs[i - 1].into()
460                }
461            })
462        };
463        // Only when it is an input do we want to do memory read for
464        // dst <- word[a]_d, src <- word[b]_d
465        let dst_data = instruction.dst.map(Into::into);
466        let src_data = recover_limbs(instruction.src_limbs, instruction.src);
467        let len_data = recover_limbs(instruction.len_limbs, instruction.remaining_len);
468        for (ptr, value, aux) in izip!(
469            [dst_ptr, src_ptr, len_ptr],
470            [dst_data, src_data, len_data],
471            register_aux,
472        ) {
473            self.memory_bridge
474                .read(
475                    MemoryAddress::new(reg_addr_sp, ptr),
476                    value,
477                    timestamp.clone(),
478                    aux,
479                )
480                .eval(builder, should_receive.clone());
481
482            timestamp += AB::Expr::ONE;
483        }
484        // See Rv32VecHeapAdapterAir
485        // repeat len for even number
486        // We range check `len` to `max_ptr_bits` to ensure `remaining_len` doesn't overflow.
487        // We could range check it to some other size, but `max_ptr_bits` is convenient.
488        let need_range_check = [
489            *instruction.dst.last().unwrap(),
490            *instruction.src_limbs.last().unwrap(),
491            *instruction.len_limbs.last().unwrap(),
492            *instruction.len_limbs.last().unwrap(),
493        ];
494        let limb_shift =
495            AB::F::from_usize(1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.ptr_max_bits));
496        for pair in need_range_check.chunks_exact(2) {
497            self.bitwise_lookup_bus
498                .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
499                .eval(builder, should_receive.clone());
500        }
501
502        timestamp
503    }
504
505    /// Constrain reading the input as `block_bytes` from memory.
506    /// Reads input based on `is_padding_byte`.
507    /// Constrains timestamp transitions between blocks if input crosses blocks.
508    ///
509    /// Expects `start_read_timestamp` to be a linear expression.
510    /// Returns the `start_write_timestamp` which is the timestamp to start from
511    /// for writing digest to memory.
512    pub fn constrain_input_read<AB: InteractionBuilder>(
513        &self,
514        builder: &mut AB,
515        local: &KeccakVmCols<AB::Var>,
516        start_read_timestamp: AB::Expr,
517        mem_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_ABSORB_READS],
518    ) -> AB::Expr {
519        let partial_block = &local.mem_oc.partial_block;
520        // Only read input from memory when it is an opcode-related row
521        // and only on the first round of block
522        let is_input = local.instruction.is_enabled_first_round;
523
524        let mut timestamp = start_read_timestamp;
525        // read `state` into `word[src + ...]_e`
526        // iterator of state as u16:
527        for (i, (input, is_padding, mem_aux)) in izip!(
528            local.sponge.block_bytes.chunks_exact(KECCAK_WORD_SIZE),
529            local.sponge.is_padding_byte.chunks_exact(KECCAK_WORD_SIZE),
530            mem_aux
531        )
532        .enumerate()
533        {
534            let ptr = local.instruction.src + AB::F::from_usize(i * KECCAK_WORD_SIZE);
535            // Only read block i if it is not entirely padding bytes
536            // count is degree 2
537            let count = is_input * not(is_padding[0]);
538            // The memory block read is partial if first byte is not padding but the last byte is
539            // padding. Since `count` is only 1 when first byte isn't padding, use check just if
540            // last byte is padding.
541            let is_partial_read = *is_padding.last().unwrap();
542            // word is degree 2
543            let word: [_; KECCAK_WORD_SIZE] = from_fn(|i| {
544                if i == 0 {
545                    // first byte is always ok
546                    input[0].into()
547                } else {
548                    // use `partial_block` if this is a partial read, otherwise use the normal input
549                    // block
550                    select(is_partial_read, partial_block[i - 1], input[i])
551                }
552            });
553            for i in 1..KECCAK_WORD_SIZE {
554                let not_padding: AB::Expr = not(is_padding[i]);
555                // When not a padding byte, the word byte and input byte must be equal
556                // This is constraint degree 3
557                builder.assert_eq(
558                    not_padding.clone() * word[i].clone(),
559                    not_padding.clone() * input[i],
560                );
561            }
562
563            self.memory_bridge
564                .read(
565                    MemoryAddress::new(AB::Expr::from_u32(RV32_MEMORY_AS), ptr),
566                    word, // degree 2
567                    timestamp.clone(),
568                    mem_aux,
569                )
570                .eval(builder, count);
571
572            timestamp += AB::Expr::ONE;
573        }
574        timestamp
575    }
576
577    pub fn constrain_output_write<AB: InteractionBuilder>(
578        &self,
579        builder: &mut AB,
580        local: &KeccakVmCols<AB::Var>,
581        start_write_timestamp: AB::Expr,
582        mem_aux: &[MemoryWriteAuxCols<AB::Var, KECCAK_WORD_SIZE>; KECCAK_DIGEST_WRITES],
583    ) {
584        let instruction = local.instruction;
585
586        let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
587        // since keccak-f AIR has this column, we might as well use it
588        builder.assert_eq(
589            local.inner.export,
590            instruction.is_enabled * is_final_block * local.is_last_round(),
591        );
592        // See `constrain_absorb` on how we derive the postimage bytes from u16 limbs
593        // **SAFETY:** we always XOR the final state with 0 in `constrain_absorb`,
594        // so the output bytes **are** range checked.
595        let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
596            let y = i / 5;
597            let x = i % 5;
598            (0..U64_LIMBS).flat_map(move |limb| {
599                let state_limb = local.postimage(y, x, limb);
600                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
601                let lo = state_limb - hi * AB::F::from_u64(1 << 8);
602                // Conversion from bytes to u64 is little-endian
603                [lo, hi.into()]
604            })
605        });
606        let dst = abstract_compose::<AB::Expr, _>(instruction.dst);
607        for (i, digest_bytes) in updated_state_bytes
608            .take(KECCAK_DIGEST_BYTES)
609            .chunks(KECCAK_WORD_SIZE)
610            .into_iter()
611            .enumerate()
612        {
613            let digest_bytes = digest_bytes.collect_vec();
614            let timestamp = start_write_timestamp.clone() + AB::Expr::from_usize(i);
615            self.memory_bridge
616                .write(
617                    MemoryAddress::new(
618                        AB::Expr::from_u32(RV32_MEMORY_AS),
619                        dst.clone() + AB::F::from_usize(i * KECCAK_WORD_SIZE),
620                    ),
621                    digest_bytes.try_into().unwrap(),
622                    timestamp,
623                    &mem_aux[i],
624                )
625                .eval(builder, local.inner.export)
626        }
627    }
628
629    /// Amount to advance timestamp by after execution of one opcode instruction.
630    /// This is an upper bound dependent on the length `len` operand, which is unbounded.
631    pub fn timestamp_change<T: PrimeCharacteristicRing>(len: impl Into<T>) -> T {
632        // actual number is ceil(len / 136) * (3 + 17) + KECCAK_DIGEST_WRITES
633        // digest writes only done on last row of multi-block
634        // add another KECCAK_ABSORB_READS to round up so we don't deal with padding
635        len.into()
636            + T::from_usize(KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES)
637    }
638}