openvm_keccak256_circuit/
air.rs

1use std::{array::from_fn, borrow::Borrow, iter::zip};
2
3use itertools::{izip, Itertools};
4use openvm_circuit::{
5    arch::{ExecutionBridge, ExecutionState},
6    system::memory::{
7        offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
8        MemoryAddress,
9    },
10};
11use openvm_circuit_primitives::{
12    bitwise_op_lookup::BitwiseOperationLookupBus,
13    utils::{assert_array_eq, not, select},
14};
15use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS};
16use openvm_keccak256_transpiler::Rv32KeccakOpcode;
17use openvm_rv32im_circuit::adapters::abstract_compose;
18use openvm_stark_backend::{
19    air_builders::sub::SubAirBuilder,
20    interaction::InteractionBuilder,
21    p3_air::{Air, AirBuilder, BaseAir},
22    p3_field::FieldAlgebra,
23    p3_matrix::Matrix,
24    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
25};
26use p3_keccak_air::{KeccakAir, NUM_KECCAK_COLS as NUM_KECCAK_PERM_COLS, U64_LIMBS};
27
28use super::{
29    columns::{KeccakVmCols, NUM_KECCAK_VM_COLS},
30    KECCAK_ABSORB_READS, KECCAK_DIGEST_BYTES, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES,
31    KECCAK_RATE_U16S, KECCAK_REGISTER_READS, KECCAK_WIDTH_U16S, KECCAK_WORD_SIZE,
32    NUM_ABSORB_ROUNDS,
33};
34
/// AIR definition for the Keccak-256 VM chip.
///
/// Holds the bridges and bus that the constraint methods on this type send
/// interactions through, plus two configuration values.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct KeccakVmAir {
    /// Bridge used by `eval_instruction` to constrain execution of the opcode
    /// (`execute_and_increment_pc`).
    pub execution_bridge: ExecutionBridge,
    /// Bridge used for the register reads, the input word reads and the digest word writes.
    pub memory_bridge: MemoryBridge,
    /// Bus to send 8-bit XOR requests to.
    /// `eval_instruction` also sends range-check requests (`send_range`) on this bus.
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    /// Maximum number of bits allowed for an address pointer
    pub ptr_max_bits: usize,
    /// Added to `Rv32KeccakOpcode::KECCAK256 as usize` to form the opcode value
    /// used in `eval_instruction`.
    pub(super) offset: usize,
}
45
// Empty impl: relies entirely on the trait's default methods.
impl<F> BaseAirWithPublicValues<F> for KeccakVmAir {}
// Empty impl: relies entirely on the trait's default methods.
impl<F> PartitionedBaseAir<F> for KeccakVmAir {}
impl<F> BaseAir<F> for KeccakVmAir {
    /// Width of the main trace: always `NUM_KECCAK_VM_COLS`
    /// (presumably the number of cells in one `KeccakVmCols` row — confirm in `columns`).
    fn width(&self) -> usize {
        NUM_KECCAK_VM_COLS
    }
}
53
54impl<AB: InteractionBuilder> Air<AB> for KeccakVmAir {
55    fn eval(&self, builder: &mut AB) {
56        let main = builder.main();
57        let (local, next) = (main.row_slice(0), main.row_slice(1));
58        let local: &KeccakVmCols<AB::Var> = (*local).borrow();
59        let next: &KeccakVmCols<AB::Var> = (*next).borrow();
60
61        builder.assert_bool(local.sponge.is_new_start);
62        builder.assert_eq(
63            local.sponge.is_new_start,
64            local.sponge.is_new_start * local.is_first_round(),
65        );
66        builder.assert_eq(
67            local.instruction.is_enabled_first_round,
68            local.instruction.is_enabled * local.is_first_round(),
69        );
70        // Not strictly necessary:
71        builder
72            .when_first_row()
73            .assert_one(local.sponge.is_new_start);
74
75        self.eval_keccak_f(builder);
76        self.constrain_padding(builder, local, next);
77        self.constrain_consistency_across_rounds(builder, local, next);
78
79        let mem = &local.mem_oc;
80        // Interactions:
81        self.constrain_absorb(builder, local, next);
82        let start_read_timestamp = self.eval_instruction(builder, local, &mem.register_aux);
83        let start_write_timestamp =
84            self.constrain_input_read(builder, local, start_read_timestamp, &mem.absorb_reads);
85        self.constrain_output_write(
86            builder,
87            local,
88            start_write_timestamp.clone(),
89            &mem.digest_writes,
90        );
91
92        self.constrain_block_transition(builder, local, next, start_write_timestamp);
93    }
94}
95
96impl KeccakVmAir {
97    /// Evaluate the keccak-f permutation constraints.
98    ///
99    /// WARNING: The keccak-f AIR columns **must** be the first columns in the main AIR.
100    #[inline]
101    pub fn eval_keccak_f<AB: AirBuilder>(&self, builder: &mut AB) {
102        let keccak_f_air = KeccakAir {};
103        let mut sub_builder =
104            SubAirBuilder::<AB, KeccakAir, AB::Var>::new(builder, 0..NUM_KECCAK_PERM_COLS);
105        keccak_f_air.eval(&mut sub_builder);
106    }
107
108    /// Many columns are expected to be the same between rounds and only change per-block.
109    pub fn constrain_consistency_across_rounds<AB: AirBuilder>(
110        &self,
111        builder: &mut AB,
112        local: &KeccakVmCols<AB::Var>,
113        next: &KeccakVmCols<AB::Var>,
114    ) {
115        let mut transition_builder = builder.when_transition();
116        let mut round_builder = transition_builder.when(not(local.is_last_round()));
117        // Instruction columns
118        local
119            .instruction
120            .assert_eq(&mut round_builder, next.instruction);
121    }
122
123    pub fn constrain_block_transition<AB: AirBuilder>(
124        &self,
125        builder: &mut AB,
126        local: &KeccakVmCols<AB::Var>,
127        next: &KeccakVmCols<AB::Var>,
128        start_write_timestamp: AB::Expr,
129    ) {
130        // When we transition between blocks, if the next block isn't a new block
131        // (this means it's not receiving a new opcode or starting a dummy block)
132        // then we want _parts_ of opcode instruction to stay the same
133        // between blocks.
134        let mut block_transition = builder.when(local.is_last_round() * not(next.is_new_start()));
135        block_transition.assert_eq(local.instruction.pc, next.instruction.pc);
136        block_transition.assert_eq(local.instruction.is_enabled, next.instruction.is_enabled);
137        // dst is only going to be used for writes in the last input block
138        assert_array_eq(
139            &mut block_transition,
140            local.instruction.dst,
141            next.instruction.dst,
142        );
143        // these are not used and hence not necessary, but putting for safety until performance
144        // becomes an issue:
145        block_transition.assert_eq(local.instruction.dst_ptr, next.instruction.dst_ptr);
146        block_transition.assert_eq(local.instruction.src_ptr, next.instruction.src_ptr);
147        block_transition.assert_eq(local.instruction.len_ptr, next.instruction.len_ptr);
148        // no constraint on `instruction.len` because we use `remaining_len` instead
149
150        // Move the src pointer over based on the number of bytes read.
151        // This should always be RATE_BYTES since it's a non-final block.
152        block_transition.assert_eq(
153            next.instruction.src,
154            local.instruction.src + AB::F::from_canonical_usize(KECCAK_RATE_BYTES),
155        );
156        // Advance timestamp by the number of memory accesses from reading
157        // `dst, src, len` and block input bytes.
158        block_transition.assert_eq(next.instruction.start_timestamp, start_write_timestamp);
159        block_transition.assert_eq(
160            next.instruction.remaining_len,
161            local.instruction.remaining_len - AB::F::from_canonical_usize(KECCAK_RATE_BYTES),
162        );
163        // Padding transition is constrained in `constrain_padding`.
164    }
165
    /// Keccak follows the 10*1 padding rule.
    /// See Section 5.1 of <https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.202.pdf>
    /// Note this is the ONLY difference between Keccak and SHA-3
    ///
    /// Constrains the `is_padding_byte` flags, the padded values in `block_bytes`, and how
    /// `remaining_len` evolves, both between rounds of a block and between blocks.
    /// Padding logic is tied to constraints on `is_new_start`.
    ///
    /// `local` and `next` are the current and the following trace row.
    pub fn constrain_padding<AB: AirBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        next: &KeccakVmCols<AB::Var>,
    ) {
        let is_padding_byte = local.sponge.is_padding_byte;
        let block_bytes = &local.sponge.block_bytes;
        let remaining_len = local.remaining_len();

        // is_padding_byte should all be boolean
        for &is_padding_byte in is_padding_byte.iter() {
            builder.assert_bool(is_padding_byte);
        }
        // is_padding_byte should transition from 0 to 1 only once and then stay 1
        for i in 1..KECCAK_RATE_BYTES {
            builder
                .when(is_padding_byte[i - 1])
                .assert_one(is_padding_byte[i]);
        }
        // is_padding_byte must stay the same on all rounds in a block
        // we use next instead of local.step_flags.last() because the last row of the trace overall
        // may not end on a last round.
        // Note: `is_last_round` below is therefore the *next* row's step-0 flag.
        let is_last_round = next.inner.step_flags[0];
        let is_not_last_round = not(is_last_round);
        for i in 0..KECCAK_RATE_BYTES {
            builder.when(is_not_last_round.clone()).assert_eq(
                local.sponge.is_padding_byte[i],
                next.sponge.is_padding_byte[i],
            );
        }

        // Sum of the boolean flags = number of padding bytes in this block.
        let num_padding_bytes = local
            .sponge
            .is_padding_byte
            .iter()
            .fold(AB::Expr::ZERO, |a, &b| a + b);

        // If final rate block of input, then last byte must be padding
        let is_final_block = is_padding_byte[KECCAK_RATE_BYTES - 1];

        // is_padding_byte must be consistent with remaining_len:
        // in the final block, remaining_len = KECCAK_RATE_BYTES - (number of padding bytes)
        builder.when(is_final_block).assert_eq(
            remaining_len,
            AB::Expr::from_canonical_usize(KECCAK_RATE_BYTES) - num_padding_bytes,
        );
        // If this block is not final, when transitioning to next block, remaining len
        // must decrease by `KECCAK_RATE_BYTES`.
        builder
            .when(is_last_round)
            .when(not(is_final_block))
            .assert_eq(
                remaining_len - AB::F::from_canonical_usize(KECCAK_RATE_BYTES),
                next.remaining_len(),
            );
        // To enforce that is_padding_byte must be set appropriately for an input, we require
        // the block before a new start to have padding
        builder
            .when(is_last_round)
            .when(next.is_new_start())
            .assert_one(is_final_block);
        // Make sure there are not repeated padding blocks:
        // a final block must be followed by a new start.
        builder
            .when(is_last_round)
            .when(is_final_block)
            .assert_one(next.is_new_start());
        // The chain above enforces that for an input, the remaining length must decrease by RATE
        // block-by-block until it reaches a final block with padding.

        // ====== Constrain the block_bytes are padded according to is_padding_byte =====

        // If the first padding byte is at the end of the block, then the block has a
        // single padding byte (last flag is 1 while the second-to-last flag is 0)
        let has_single_padding_byte: AB::Expr =
            is_padding_byte[KECCAK_RATE_BYTES - 1] - is_padding_byte[KECCAK_RATE_BYTES - 2];

        // If the row has a single padding byte, then it must be the last byte with
        // value 0b10000001
        builder.when(has_single_padding_byte.clone()).assert_eq(
            block_bytes[KECCAK_RATE_BYTES - 1],
            AB::F::from_canonical_u8(0b10000001),
        );

        let has_multiple_padding_bytes: AB::Expr = not(has_single_padding_byte.clone());
        // Every byte position except the last one, which is handled separately below.
        for i in 0..KECCAK_RATE_BYTES - 1 {
            // 1 exactly at the position where the flags switch from 0 to 1
            // (position 0 counts as first if its flag is already set).
            let is_first_padding_byte: AB::Expr = {
                if i > 0 {
                    is_padding_byte[i] - is_padding_byte[i - 1]
                } else {
                    is_padding_byte[i].into()
                }
            };
            // If the row has multiple padding bytes, the first padding byte must be 0x01
            // because the padding 1*0 is *little-endian*
            builder
                .when(has_multiple_padding_bytes.clone())
                .when(is_first_padding_byte.clone())
                .assert_eq(block_bytes[i], AB::F::from_canonical_u8(0x01));
            // If the row has multiple padding bytes, the other padding bytes
            // except the last one must be 0
            builder
                .when(is_padding_byte[i])
                .when(not::<AB::Expr>(is_first_padding_byte)) // hence never when single padding byte
                .assert_zero(block_bytes[i]);
        }

        // If the row has multiple padding bytes, then the last byte must be 0x80
        // because the padding *01 is *little-endian*
        builder
            .when(is_final_block)
            .when(has_multiple_padding_bytes)
            .assert_eq(
                block_bytes[KECCAK_RATE_BYTES - 1],
                AB::F::from_canonical_u8(0x80),
            );
    }
288
    /// Constrain state transition between keccak-f permutations is valid absorb of input bytes.
    /// The end-state in last round is given by `a_prime_prime_prime()` in `u16` limbs.
    /// The pre-state is given by `preimage` also in `u16` limbs.
    /// The input `block_bytes` will be given as **bytes**.
    ///
    /// We will XOR `block_bytes` with `a_prime_prime_prime()` and constrain to be `next.preimage`.
    /// This will be done using 8-bit XOR lookup in a separate AIR via interactions.
    /// This will require decomposing `u16` into bytes.
    /// Note that the XOR lookup automatically range checks its inputs to be bytes.
    ///
    /// We use the following trick to keep `u16` limbs and avoid changing
    /// the `keccak-f` AIR itself:
    /// if we already have a 16-bit limb `x` and we also provide a 8-bit limb
    /// `hi = x >> 8`, assuming `x` and `hi` have been range checked,
    /// we can use the expression `lo = x - hi * 256` for the low byte.
    /// If `lo` is range checked to `8`-bits, this constrains a valid byte
    ///  decomposition of `x` into `hi, lo`.
    /// This means in terms of trace cells, it is equivalent to provide
    /// `x, hi` versus `hi, lo`.
    ///
    /// `local` and `next` are the current and the following trace row.
    pub fn constrain_absorb<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        next: &KeccakVmCols<AB::Var>,
    ) {
        // Bytes of `local`'s postimage, as `[lo, hi]` per u16 limb.
        let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = local.postimage(y, x, limb);
                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
                // Conversion from bytes to u64 is little-endian
                [lo, hi.into()]
            })
        });

        // Bytes of `next`'s preimage, decomposed the same way using `next.sponge.state_hi`.
        let post_absorb_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = next.inner.preimage[y][x][limb];
                let hi = next.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
                [lo, hi.into()]
            })
        });

        // We xor on last round of each block, even if it is a final block,
        // because we use xor to range check the output bytes (= updated_state_bytes)
        let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
        // `izip!` stops at the shortest of the three sequences.
        for (input, prev, post) in izip!(
            next.sponge.block_bytes,
            updated_state_bytes,
            post_absorb_state_bytes
        ) {
            // Add new send interaction to lookup (x, y, z = x ^ y) where x, y, z
            // will all be range checked to be 8-bits (assuming the bus is
            // received by an 8-bit xor chip).

            // When absorb, input ^ prev = post
            // Otherwise, 0 ^ prev = prev
            // The interaction fields are degree 2, leading to degree 3 constraint
            self.bitwise_lookup_bus
                .send_xor(
                    input * not(is_final_block),
                    prev.clone(),
                    select(is_final_block, prev, post),
                )
                .eval(
                    builder,
                    // Only sent on the last round of an enabled block.
                    local.is_last_round() * local.instruction.is_enabled,
                );
        }

        // We separately constrain that when(local.is_new_start), the preimage (u16s) equals the
        // block bytes
        let local_preimage_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = local.inner.preimage[y][x][limb];
                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
                [lo, hi.into()]
            })
        });
        let mut when_is_new_start =
            builder.when(local.is_new_start() * local.instruction.is_enabled);
        for (preimage_byte, block_byte) in zip(local_preimage_bytes, local.sponge.block_bytes) {
            when_is_new_start.assert_eq(preimage_byte, block_byte);
        }

        // constrain transition on the state outside rate:
        // on a new start those limbs must be zero ...
        let mut reset_builder = builder.when(local.is_new_start());
        for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
            let y = i / U64_LIMBS / 5;
            let x = (i / U64_LIMBS) % 5;
            let limb = i % U64_LIMBS;
            reset_builder.assert_zero(local.inner.preimage[y][x][limb]);
        }
        // ... and when continuing the same input into the next block, they are carried over
        // unchanged from this block's postimage.
        let mut absorb_builder = builder.when(local.is_last_round() * not(next.is_new_start()));
        for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
            let y = i / U64_LIMBS / 5;
            let x = (i / U64_LIMBS) % 5;
            let limb = i % U64_LIMBS;
            absorb_builder.assert_eq(local.postimage(y, x, limb), next.inner.preimage[y][x][limb]);
        }
    }
398
    /// Receive the instruction itself on program bus. Send+receive on execution bus.
    /// Then does memory read in addr space 1 to get `dst, src, len` from memory.
    ///
    /// Adds range check interactions for the most significant limbs of the register values
    /// using BitwiseOperationLookupBus.
    ///
    /// `register_aux` holds one auxiliary column set per register read, in the order
    /// `dst, src, len`.
    ///
    /// Returns `start_read_timestamp` which is only relevant when `local.instruction.is_enabled`.
    /// Note that `start_read_timestamp` is a linear expression:
    /// `start_timestamp` plus one per register read.
    pub fn eval_instruction<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        register_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_REGISTER_READS],
    ) -> AB::Expr {
        let instruction = local.instruction;
        // Only receive opcode if:
        // - enabled row (not dummy row)
        // - first round of block
        // - is_new_start
        // NOTE(review): as written this is a product of two trace columns, i.e. degree 2;
        // the first-round condition is not a separate factor here (`Air::eval` constrains
        // `is_new_start` to be zero outside a first round). An earlier version of this
        // comment said "degree 3" — confirm the intended interaction degree.
        let should_receive = local.instruction.is_enabled * local.sponge.is_new_start;

        let [dst_ptr, src_ptr, len_ptr] = [
            instruction.dst_ptr,
            instruction.src_ptr,
            instruction.len_ptr,
        ];
        // Address space used for the register reads below.
        let reg_addr_sp = AB::F::ONE;
        let timestamp_change: AB::Expr = Self::timestamp_change(instruction.remaining_len);
        self.execution_bridge
            .execute_and_increment_pc(
                AB::Expr::from_canonical_usize(Rv32KeccakOpcode::KECCAK256 as usize + self.offset),
                [
                    dst_ptr.into(),
                    src_ptr.into(),
                    len_ptr.into(),
                    reg_addr_sp.into(),
                    AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
                ],
                ExecutionState::new(instruction.pc, instruction.start_timestamp),
                timestamp_change,
            )
            .eval(builder, should_receive.clone());

        let mut timestamp: AB::Expr = instruction.start_timestamp.into();
        // Rebuilds all register limbs from the composed value `val` and its upper limbs:
        // limb 0 = val - sum_j limbs[j] * 2^((j + 1) * RV32_CELL_BITS); limb i > 0 = limbs[i - 1].
        let recover_limbs = |limbs: [AB::Var; RV32_REGISTER_NUM_LIMBS - 1],
                             val: AB::Var|
         -> [AB::Expr; RV32_REGISTER_NUM_LIMBS] {
            from_fn(|i| {
                if i == 0 {
                    limbs
                        .into_iter()
                        .enumerate()
                        .fold(val.into(), |acc, (j, limb)| {
                            acc - limb
                                * AB::Expr::from_canonical_usize(1 << ((j + 1) * RV32_CELL_BITS))
                        })
                } else {
                    limbs[i - 1].into()
                }
            })
        };
        // Only when it is an input do we want to do memory read for
        // dst <- word[a]_d, src <- word[b]_d
        let dst_data = instruction.dst.map(Into::into);
        let src_data = recover_limbs(instruction.src_limbs, instruction.src);
        let len_data = recover_limbs(instruction.len_limbs, instruction.remaining_len);
        // One register read per operand, at consecutive timestamps.
        for (ptr, value, aux) in izip!(
            [dst_ptr, src_ptr, len_ptr],
            [dst_data, src_data, len_data],
            register_aux,
        ) {
            self.memory_bridge
                .read(
                    MemoryAddress::new(reg_addr_sp, ptr),
                    value,
                    timestamp.clone(),
                    aux,
                )
                .eval(builder, should_receive.clone());

            timestamp += AB::Expr::ONE;
        }
        // See Rv32VecHeapAdapterAir
        // repeat len for even number
        // We range check `len` to `ptr_max_bits` to ensure `remaining_len` doesn't overflow.
        // We could range check it to some other size, but `ptr_max_bits` is convenient.
        let need_range_check = [
            *instruction.dst.last().unwrap(),
            *instruction.src_limbs.last().unwrap(),
            *instruction.len_limbs.last().unwrap(),
            *instruction.len_limbs.last().unwrap(),
        ];
        // Scale the most significant limbs by
        // 2^(RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - ptr_max_bits) before range checking
        // (presumably so that the bus's fixed-width check bounds the full value to
        // `ptr_max_bits` bits — confirm against the lookup chip).
        let limb_shift = AB::F::from_canonical_usize(
            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.ptr_max_bits),
        );
        // The range-check request takes two values at a time.
        for pair in need_range_check.chunks_exact(2) {
            self.bitwise_lookup_bus
                .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
                .eval(builder, should_receive.clone());
        }

        timestamp
    }
504
505    /// Constrain reading the input as `block_bytes` from memory.
506    /// Reads input based on `is_padding_byte`.
507    /// Constrains timestamp transitions between blocks if input crosses blocks.
508    ///
509    /// Expects `start_read_timestamp` to be a linear expression.
510    /// Returns the `start_write_timestamp` which is the timestamp to start from
511    /// for writing digest to memory.
512    pub fn constrain_input_read<AB: InteractionBuilder>(
513        &self,
514        builder: &mut AB,
515        local: &KeccakVmCols<AB::Var>,
516        start_read_timestamp: AB::Expr,
517        mem_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_ABSORB_READS],
518    ) -> AB::Expr {
519        let partial_block = &local.mem_oc.partial_block;
520        // Only read input from memory when it is an opcode-related row
521        // and only on the first round of block
522        let is_input = local.instruction.is_enabled_first_round;
523
524        let mut timestamp = start_read_timestamp;
525        // read `state` into `word[src + ...]_e`
526        // iterator of state as u16:
527        for (i, (input, is_padding, mem_aux)) in izip!(
528            local.sponge.block_bytes.chunks_exact(KECCAK_WORD_SIZE),
529            local.sponge.is_padding_byte.chunks_exact(KECCAK_WORD_SIZE),
530            mem_aux
531        )
532        .enumerate()
533        {
534            let ptr = local.instruction.src + AB::F::from_canonical_usize(i * KECCAK_WORD_SIZE);
535            // Only read block i if it is not entirely padding bytes
536            // count is degree 2
537            let count = is_input * not(is_padding[0]);
538            // The memory block read is partial if first byte is not padding but the last byte is
539            // padding. Since `count` is only 1 when first byte isn't padding, use check just if
540            // last byte is padding.
541            let is_partial_read = *is_padding.last().unwrap();
542            // word is degree 2
543            let word: [_; KECCAK_WORD_SIZE] = from_fn(|i| {
544                if i == 0 {
545                    // first byte is always ok
546                    input[0].into()
547                } else {
548                    // use `partial_block` if this is a partial read, otherwise use the normal input
549                    // block
550                    select(is_partial_read, partial_block[i - 1], input[i])
551                }
552            });
553            for i in 1..KECCAK_WORD_SIZE {
554                let not_padding: AB::Expr = not(is_padding[i]);
555                // When not a padding byte, the word byte and input byte must be equal
556                // This is constraint degree 3
557                builder.assert_eq(
558                    not_padding.clone() * word[i].clone(),
559                    not_padding.clone() * input[i],
560                );
561            }
562
563            self.memory_bridge
564                .read(
565                    MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_MEMORY_AS), ptr),
566                    word, // degree 2
567                    timestamp.clone(),
568                    mem_aux,
569                )
570                .eval(builder, count);
571
572            timestamp += AB::Expr::ONE;
573        }
574        timestamp
575    }
576
577    pub fn constrain_output_write<AB: InteractionBuilder>(
578        &self,
579        builder: &mut AB,
580        local: &KeccakVmCols<AB::Var>,
581        start_write_timestamp: AB::Expr,
582        mem_aux: &[MemoryWriteAuxCols<AB::Var, KECCAK_WORD_SIZE>; KECCAK_DIGEST_WRITES],
583    ) {
584        let instruction = local.instruction;
585
586        let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
587        // since keccak-f AIR has this column, we might as well use it
588        builder.assert_eq(
589            local.inner.export,
590            instruction.is_enabled * is_final_block * local.is_last_round(),
591        );
592        // See `constrain_absorb` on how we derive the postimage bytes from u16 limbs
593        // **SAFETY:** we always XOR the final state with 0 in `constrain_absorb`,
594        // so the output bytes **are** range checked.
595        let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
596            let y = i / 5;
597            let x = i % 5;
598            (0..U64_LIMBS).flat_map(move |limb| {
599                let state_limb = local.postimage(y, x, limb);
600                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
601                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
602                // Conversion from bytes to u64 is little-endian
603                [lo, hi.into()]
604            })
605        });
606        let dst = abstract_compose::<AB::Expr, _>(instruction.dst);
607        for (i, digest_bytes) in updated_state_bytes
608            .take(KECCAK_DIGEST_BYTES)
609            .chunks(KECCAK_WORD_SIZE)
610            .into_iter()
611            .enumerate()
612        {
613            let digest_bytes = digest_bytes.collect_vec();
614            let timestamp = start_write_timestamp.clone() + AB::Expr::from_canonical_usize(i);
615            self.memory_bridge
616                .write(
617                    MemoryAddress::new(
618                        AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
619                        dst.clone() + AB::F::from_canonical_usize(i * KECCAK_WORD_SIZE),
620                    ),
621                    digest_bytes.try_into().unwrap(),
622                    timestamp,
623                    &mem_aux[i],
624                )
625                .eval(builder, local.inner.export)
626        }
627    }
628
629    /// Amount to advance timestamp by after execution of one opcode instruction.
630    /// This is an upper bound dependent on the length `len` operand, which is unbounded.
631    pub fn timestamp_change<T: FieldAlgebra>(len: impl Into<T>) -> T {
632        // actual number is ceil(len / 136) * (3 + 17) + KECCAK_DIGEST_WRITES
633        // digest writes only done on last row of multi-block
634        // add another KECCAK_ABSORB_READS to round up so we don't deal with padding
635        len.into()
636            + T::from_canonical_usize(
637                KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES,
638            )
639    }
640}