openvm_keccak256_circuit/
air.rs

1use std::{array::from_fn, borrow::Borrow, iter::zip};
2
3use itertools::{izip, Itertools};
4use openvm_circuit::{
5    arch::{ExecutionBridge, ExecutionState},
6    system::memory::{
7        offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
8        MemoryAddress,
9    },
10};
11use openvm_circuit_primitives::{
12    bitwise_op_lookup::BitwiseOperationLookupBus,
13    utils::{assert_array_eq, not, select},
14};
15use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS};
16use openvm_keccak256_transpiler::Rv32KeccakOpcode;
17use openvm_rv32im_circuit::adapters::abstract_compose;
18use openvm_stark_backend::{
19    air_builders::sub::SubAirBuilder,
20    interaction::InteractionBuilder,
21    p3_air::{Air, AirBuilder, BaseAir},
22    p3_field::FieldAlgebra,
23    p3_matrix::Matrix,
24    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
25};
26use p3_keccak_air::{KeccakAir, NUM_KECCAK_COLS as NUM_KECCAK_PERM_COLS, U64_LIMBS};
27
28use super::{
29    columns::{KeccakVmCols, NUM_KECCAK_VM_COLS},
30    KECCAK_ABSORB_READS, KECCAK_DIGEST_BYTES, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES,
31    KECCAK_RATE_U16S, KECCAK_REGISTER_READS, KECCAK_WIDTH_U16S, KECCAK_WORD_SIZE,
32    NUM_ABSORB_ROUNDS,
33};
34
35#[derive(Clone, Copy, Debug, derive_new::new)]
36pub struct KeccakVmAir {
37    pub execution_bridge: ExecutionBridge,
38    pub memory_bridge: MemoryBridge,
39    /// Bus to send 8-bit XOR requests to.
40    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
41    /// Maximum number of bits allowed for an address pointer
42    pub ptr_max_bits: usize,
43    pub(super) offset: usize,
44}
45
46impl<F> BaseAirWithPublicValues<F> for KeccakVmAir {}
47impl<F> PartitionedBaseAir<F> for KeccakVmAir {}
48impl<F> BaseAir<F> for KeccakVmAir {
49    fn width(&self) -> usize {
50        NUM_KECCAK_VM_COLS
51    }
52}
53
54impl<AB: InteractionBuilder> Air<AB> for KeccakVmAir {
55    fn eval(&self, builder: &mut AB) {
56        let main = builder.main();
57        let (local, next) = (main.row_slice(0), main.row_slice(1));
58        let local: &KeccakVmCols<AB::Var> = (*local).borrow();
59        let next: &KeccakVmCols<AB::Var> = (*next).borrow();
60
61        builder.assert_bool(local.sponge.is_new_start);
62        builder.assert_eq(
63            local.sponge.is_new_start,
64            local.sponge.is_new_start * local.is_first_round(),
65        );
66        builder.assert_eq(
67            local.instruction.is_enabled_first_round,
68            local.instruction.is_enabled * local.is_first_round(),
69        );
70        // Not strictly necessary:
71        builder
72            .when_first_row()
73            .assert_one(local.sponge.is_new_start);
74
75        self.eval_keccak_f(builder);
76        self.constrain_padding(builder, local, next);
77        self.constrain_consistency_across_rounds(builder, local, next);
78
79        let mem = &local.mem_oc;
80        // Interactions:
81        self.constrain_absorb(builder, local, next);
82        let start_read_timestamp = self.eval_instruction(builder, local, &mem.register_aux);
83        let start_write_timestamp =
84            self.constrain_input_read(builder, local, start_read_timestamp, &mem.absorb_reads);
85        self.constrain_output_write(
86            builder,
87            local,
88            start_write_timestamp.clone(),
89            &mem.digest_writes,
90        );
91
92        self.constrain_block_transition(builder, local, next, start_write_timestamp);
93    }
94}
95
96impl KeccakVmAir {
97    /// Evaluate the keccak-f permutation constraints.
98    ///
99    /// WARNING: The keccak-f AIR columns **must** be the first columns in the main AIR.
100    #[inline]
101    pub fn eval_keccak_f<AB: AirBuilder>(&self, builder: &mut AB) {
102        let keccak_f_air = KeccakAir {};
103        let mut sub_builder =
104            SubAirBuilder::<AB, KeccakAir, AB::Var>::new(builder, 0..NUM_KECCAK_PERM_COLS);
105        keccak_f_air.eval(&mut sub_builder);
106    }
107
108    /// Many columns are expected to be the same between rounds and only change per-block.
109    pub fn constrain_consistency_across_rounds<AB: AirBuilder>(
110        &self,
111        builder: &mut AB,
112        local: &KeccakVmCols<AB::Var>,
113        next: &KeccakVmCols<AB::Var>,
114    ) {
115        let mut transition_builder = builder.when_transition();
116        let mut round_builder = transition_builder.when(not(local.is_last_round()));
117        // Instruction columns
118        local
119            .instruction
120            .assert_eq(&mut round_builder, next.instruction);
121    }
122
123    pub fn constrain_block_transition<AB: AirBuilder>(
124        &self,
125        builder: &mut AB,
126        local: &KeccakVmCols<AB::Var>,
127        next: &KeccakVmCols<AB::Var>,
128        start_write_timestamp: AB::Expr,
129    ) {
130        // When we transition between blocks, if the next block isn't a new block
131        // (this means it's not receiving a new opcode or starting a dummy block)
132        // then we want _parts_ of opcode instruction to stay the same
133        // between blocks.
134        let mut block_transition = builder.when(local.is_last_round() * not(next.is_new_start()));
135        block_transition.assert_eq(local.instruction.pc, next.instruction.pc);
136        block_transition.assert_eq(local.instruction.is_enabled, next.instruction.is_enabled);
137        // dst is only going to be used for writes in the last input block
138        assert_array_eq(
139            &mut block_transition,
140            local.instruction.dst,
141            next.instruction.dst,
142        );
143        // these are not used and hence not necessary, but putting for safety until performance becomes an issue:
144        block_transition.assert_eq(local.instruction.dst_ptr, next.instruction.dst_ptr);
145        block_transition.assert_eq(local.instruction.src_ptr, next.instruction.src_ptr);
146        block_transition.assert_eq(local.instruction.len_ptr, next.instruction.len_ptr);
147        // no constraint on `instruction.len` because we use `remaining_len` instead
148
149        // Move the src pointer over based on the number of bytes read.
150        // This should always be RATE_BYTES since it's a non-final block.
151        block_transition.assert_eq(
152            next.instruction.src,
153            local.instruction.src + AB::F::from_canonical_usize(KECCAK_RATE_BYTES),
154        );
155        // Advance timestamp by the number of memory accesses from reading
156        // `dst, src, len` and block input bytes.
157        block_transition.assert_eq(next.instruction.start_timestamp, start_write_timestamp);
158        block_transition.assert_eq(
159            next.instruction.remaining_len,
160            local.instruction.remaining_len - AB::F::from_canonical_usize(KECCAK_RATE_BYTES),
161        );
162        // Padding transition is constrained in `constrain_padding`.
163    }
164
165    /// Keccak follows the 10*1 padding rule.
166    /// See Section 5.1 of <https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.202.pdf>
167    /// Note this is the ONLY difference between Keccak and SHA-3
168    ///
169    /// Constrains padding constraints and length between rounds and
170    /// between blocks. Padding logic is tied to constraints on `is_new_start`.
171    pub fn constrain_padding<AB: AirBuilder>(
172        &self,
173        builder: &mut AB,
174        local: &KeccakVmCols<AB::Var>,
175        next: &KeccakVmCols<AB::Var>,
176    ) {
177        let is_padding_byte = local.sponge.is_padding_byte;
178        let block_bytes = &local.sponge.block_bytes;
179        let remaining_len = local.remaining_len();
180
181        // is_padding_byte should all be boolean
182        for &is_padding_byte in is_padding_byte.iter() {
183            builder.assert_bool(is_padding_byte);
184        }
185        // is_padding_byte should transition from 0 to 1 only once and then stay 1
186        for i in 1..KECCAK_RATE_BYTES {
187            builder
188                .when(is_padding_byte[i - 1])
189                .assert_one(is_padding_byte[i]);
190        }
191        // is_padding_byte must stay the same on all rounds in a block
192        // we use next instead of local.step_flags.last() because the last row of the trace overall may not
193        // end on a last round
194        let is_last_round = next.inner.step_flags[0];
195        let is_not_last_round = not(is_last_round);
196        for i in 0..KECCAK_RATE_BYTES {
197            builder.when(is_not_last_round.clone()).assert_eq(
198                local.sponge.is_padding_byte[i],
199                next.sponge.is_padding_byte[i],
200            );
201        }
202
203        let num_padding_bytes = local
204            .sponge
205            .is_padding_byte
206            .iter()
207            .fold(AB::Expr::ZERO, |a, &b| a + b);
208
209        // If final rate block of input, then last byte must be padding
210        let is_final_block = is_padding_byte[KECCAK_RATE_BYTES - 1];
211
212        // is_padding_byte must be consistent with remaining_len
213        builder.when(is_final_block).assert_eq(
214            remaining_len,
215            AB::Expr::from_canonical_usize(KECCAK_RATE_BYTES) - num_padding_bytes,
216        );
217        // If this block is not final, when transitioning to next block, remaining len
218        // must decrease by `KECCAK_RATE_BYTES`.
219        builder
220            .when(is_last_round)
221            .when(not(is_final_block))
222            .assert_eq(
223                remaining_len - AB::F::from_canonical_usize(KECCAK_RATE_BYTES),
224                next.remaining_len(),
225            );
226        // To enforce that is_padding_byte must be set appropriately for an input, we require
227        // the block before a new start to have padding
228        builder
229            .when(is_last_round)
230            .when(next.is_new_start())
231            .assert_one(is_final_block);
232        // Make sure there are not repeated padding blocks
233        builder
234            .when(is_last_round)
235            .when(is_final_block)
236            .assert_one(next.is_new_start());
237        // The chain above enforces that for an input, the remaining length must decrease by RATE
238        // block-by-block until it reaches a final block with padding.
239
240        // ====== Constrain the block_bytes are padded according to is_padding_byte =====
241
242        // If the first padding byte is at the end of the block, then the block has a
243        // single padding byte
244        let has_single_padding_byte: AB::Expr =
245            is_padding_byte[KECCAK_RATE_BYTES - 1] - is_padding_byte[KECCAK_RATE_BYTES - 2];
246
247        // If the row has a single padding byte, then it must be the last byte with
248        // value 0b10000001
249        builder.when(has_single_padding_byte.clone()).assert_eq(
250            block_bytes[KECCAK_RATE_BYTES - 1],
251            AB::F::from_canonical_u8(0b10000001),
252        );
253
254        let has_multiple_padding_bytes: AB::Expr = not(has_single_padding_byte.clone());
255        for i in 0..KECCAK_RATE_BYTES - 1 {
256            let is_first_padding_byte: AB::Expr = {
257                if i > 0 {
258                    is_padding_byte[i] - is_padding_byte[i - 1]
259                } else {
260                    is_padding_byte[i].into()
261                }
262            };
263            // If the row has multiple padding bytes, the first padding byte must be 0x01
264            // because the padding 1*0 is *little-endian*
265            builder
266                .when(has_multiple_padding_bytes.clone())
267                .when(is_first_padding_byte.clone())
268                .assert_eq(block_bytes[i], AB::F::from_canonical_u8(0x01));
269            // If the row has multiple padding bytes, the other padding bytes
270            // except the last one must be 0
271            builder
272                .when(is_padding_byte[i])
273                .when(not::<AB::Expr>(is_first_padding_byte)) // hence never when single padding byte
274                .assert_zero(block_bytes[i]);
275        }
276
277        // If the row has multiple padding bytes, then the last byte must be 0x80
278        // because the padding *01 is *little-endian*
279        builder
280            .when(is_final_block)
281            .when(has_multiple_padding_bytes)
282            .assert_eq(
283                block_bytes[KECCAK_RATE_BYTES - 1],
284                AB::F::from_canonical_u8(0x80),
285            );
286    }
287
288    /// Constrain state transition between keccak-f permutations is valid absorb of input bytes.
289    /// The end-state in last round is given by `a_prime_prime_prime()` in `u16` limbs.
290    /// The pre-state is given by `preimage` also in `u16` limbs.
291    /// The input `block_bytes` will be given as **bytes**.
292    ///
293    /// We will XOR `block_bytes` with `a_prime_prime_prime()` and constrain to be `next.preimage`.
294    /// This will be done using 8-bit XOR lookup in a separate AIR via interactions.
295    /// This will require decomposing `u16` into bytes.
296    /// Note that the XOR lookup automatically range checks its inputs to be bytes.
297    ///
298    /// We use the following trick to keep `u16` limbs and avoid changing
299    /// the `keccak-f` AIR itself:
300    /// if we already have a 16-bit limb `x` and we also provide a 8-bit limb
301    /// `hi = x >> 8`, assuming `x` and `hi` have been range checked,
302    /// we can use the expression `lo = x - hi * 256` for the low byte.
303    /// If `lo` is range checked to `8`-bits, this constrains a valid byte
304    ///  decomposition of `x` into `hi, lo`.
305    /// This means in terms of trace cells, it is equivalent to provide
306    /// `x, hi` versus `hi, lo`.
307    pub fn constrain_absorb<AB: InteractionBuilder>(
308        &self,
309        builder: &mut AB,
310        local: &KeccakVmCols<AB::Var>,
311        next: &KeccakVmCols<AB::Var>,
312    ) {
313        let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
314            let y = i / 5;
315            let x = i % 5;
316            (0..U64_LIMBS).flat_map(move |limb| {
317                let state_limb = local.postimage(y, x, limb);
318                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
319                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
320                // Conversion from bytes to u64 is little-endian
321                [lo, hi.into()]
322            })
323        });
324
325        let post_absorb_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
326            let y = i / 5;
327            let x = i % 5;
328            (0..U64_LIMBS).flat_map(move |limb| {
329                let state_limb = next.inner.preimage[y][x][limb];
330                let hi = next.sponge.state_hi[i * U64_LIMBS + limb];
331                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
332                [lo, hi.into()]
333            })
334        });
335
336        // We xor on last round of each block, even if it is a final block,
337        // because we use xor to range check the output bytes (= updated_state_bytes)
338        let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
339        for (input, prev, post) in izip!(
340            next.sponge.block_bytes,
341            updated_state_bytes,
342            post_absorb_state_bytes
343        ) {
344            // Add new send interaction to lookup (x, y, x ^ y) where x, y, z
345            // will all be range checked to be 8-bits (assuming the bus is
346            // received by an 8-bit xor chip).
347
348            // When absorb, input ^ prev = post
349            // Otherwise, 0 ^ prev = prev
350            // The interaction fields are degree 2, leading to degree 3 constraint
351            self.bitwise_lookup_bus
352                .send_xor(
353                    input * not(is_final_block),
354                    prev.clone(),
355                    select(is_final_block, prev, post),
356                )
357                .eval(
358                    builder,
359                    local.is_last_round() * local.instruction.is_enabled,
360                );
361        }
362
363        // We separately constrain that when(local.is_new_start), the preimage (u16s) equals the block bytes
364        let local_preimage_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
365            let y = i / 5;
366            let x = i % 5;
367            (0..U64_LIMBS).flat_map(move |limb| {
368                let state_limb = local.inner.preimage[y][x][limb];
369                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
370                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
371                [lo, hi.into()]
372            })
373        });
374        let mut when_is_new_start =
375            builder.when(local.is_new_start() * local.instruction.is_enabled);
376        for (preimage_byte, block_byte) in zip(local_preimage_bytes, local.sponge.block_bytes) {
377            when_is_new_start.assert_eq(preimage_byte, block_byte);
378        }
379
380        // constrain transition on the state outside rate
381        let mut reset_builder = builder.when(local.is_new_start());
382        for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
383            let y = i / U64_LIMBS / 5;
384            let x = (i / U64_LIMBS) % 5;
385            let limb = i % U64_LIMBS;
386            reset_builder.assert_zero(local.inner.preimage[y][x][limb]);
387        }
388        let mut absorb_builder = builder.when(local.is_last_round() * not(next.is_new_start()));
389        for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
390            let y = i / U64_LIMBS / 5;
391            let x = (i / U64_LIMBS) % 5;
392            let limb = i % U64_LIMBS;
393            absorb_builder.assert_eq(local.postimage(y, x, limb), next.inner.preimage[y][x][limb]);
394        }
395    }
396
397    /// Receive the instruction itself on program bus. Send+receive on execution bus.
398    /// Then does memory read in addr space 1 to get `dst, src, len` from memory.
399    ///
400    /// Adds range check interactions for the most significant limbs of the register values
401    /// using BitwiseOperationLookupBus.
402    ///
403    /// Returns `start_read_timestamp` which is only relevant when `local.instruction.is_enabled`.
404    /// Note that `start_read_timestamp` is a linear expression.
405    pub fn eval_instruction<AB: InteractionBuilder>(
406        &self,
407        builder: &mut AB,
408        local: &KeccakVmCols<AB::Var>,
409        register_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_REGISTER_READS],
410    ) -> AB::Expr {
411        let instruction = local.instruction;
412        // Only receive opcode if:
413        // - enabled row (not dummy row)
414        // - first round of block
415        // - is_new_start
416        // Note this is degree 3, which results in quotient degree 2 if used
417        // as `count` in interaction
418        let should_receive = local.instruction.is_enabled * local.sponge.is_new_start;
419
420        let [dst_ptr, src_ptr, len_ptr] = [
421            instruction.dst_ptr,
422            instruction.src_ptr,
423            instruction.len_ptr,
424        ];
425        let reg_addr_sp = AB::F::ONE;
426        let timestamp_change: AB::Expr = Self::timestamp_change(instruction.remaining_len);
427        self.execution_bridge
428            .execute_and_increment_pc(
429                AB::Expr::from_canonical_usize(Rv32KeccakOpcode::KECCAK256 as usize + self.offset),
430                [
431                    dst_ptr.into(),
432                    src_ptr.into(),
433                    len_ptr.into(),
434                    reg_addr_sp.into(),
435                    AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
436                ],
437                ExecutionState::new(instruction.pc, instruction.start_timestamp),
438                timestamp_change,
439            )
440            .eval(builder, should_receive.clone());
441
442        let mut timestamp: AB::Expr = instruction.start_timestamp.into();
443        let recover_limbs = |limbs: [AB::Var; RV32_REGISTER_NUM_LIMBS - 1],
444                             val: AB::Var|
445         -> [AB::Expr; RV32_REGISTER_NUM_LIMBS] {
446            from_fn(|i| {
447                if i == 0 {
448                    limbs
449                        .into_iter()
450                        .enumerate()
451                        .fold(val.into(), |acc, (j, limb)| {
452                            acc - limb
453                                * AB::Expr::from_canonical_usize(1 << ((j + 1) * RV32_CELL_BITS))
454                        })
455                } else {
456                    limbs[i - 1].into()
457                }
458            })
459        };
460        // Only when it is an input do we want to do memory read for
461        // dst <- word[a]_d, src <- word[b]_d
462        let dst_data = instruction.dst.map(Into::into);
463        let src_data = recover_limbs(instruction.src_limbs, instruction.src);
464        let len_data = recover_limbs(instruction.len_limbs, instruction.remaining_len);
465        for (ptr, value, aux) in izip!(
466            [dst_ptr, src_ptr, len_ptr],
467            [dst_data, src_data, len_data],
468            register_aux,
469        ) {
470            self.memory_bridge
471                .read(
472                    MemoryAddress::new(reg_addr_sp, ptr),
473                    value,
474                    timestamp.clone(),
475                    aux,
476                )
477                .eval(builder, should_receive.clone());
478
479            timestamp += AB::Expr::ONE;
480        }
481        // See Rv32VecHeapAdapterAir
482        // repeat len for even number
483        // We range check `len` to `max_ptr_bits` to ensure `remaining_len` doesn't overflow.
484        // We could range check it to some other size, but `max_ptr_bits` is convenient.
485        let need_range_check = [
486            *instruction.dst.last().unwrap(),
487            *instruction.src_limbs.last().unwrap(),
488            *instruction.len_limbs.last().unwrap(),
489            *instruction.len_limbs.last().unwrap(),
490        ];
491        let limb_shift = AB::F::from_canonical_usize(
492            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.ptr_max_bits),
493        );
494        for pair in need_range_check.chunks_exact(2) {
495            self.bitwise_lookup_bus
496                .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
497                .eval(builder, should_receive.clone());
498        }
499
500        timestamp
501    }
502
503    /// Constrain reading the input as `block_bytes` from memory.
504    /// Reads input based on `is_padding_byte`.
505    /// Constrains timestamp transitions between blocks if input crosses blocks.
506    ///
507    /// Expects `start_read_timestamp` to be a linear expression.
508    /// Returns the `start_write_timestamp` which is the timestamp to start from
509    /// for writing digest to memory.
510    pub fn constrain_input_read<AB: InteractionBuilder>(
511        &self,
512        builder: &mut AB,
513        local: &KeccakVmCols<AB::Var>,
514        start_read_timestamp: AB::Expr,
515        mem_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_ABSORB_READS],
516    ) -> AB::Expr {
517        let partial_block = &local.mem_oc.partial_block;
518        // Only read input from memory when it is an opcode-related row
519        // and only on the first round of block
520        let is_input = local.instruction.is_enabled_first_round;
521
522        let mut timestamp = start_read_timestamp;
523        // read `state` into `word[src + ...]_e`
524        // iterator of state as u16:
525        for (i, (input, is_padding, mem_aux)) in izip!(
526            local.sponge.block_bytes.chunks_exact(KECCAK_WORD_SIZE),
527            local.sponge.is_padding_byte.chunks_exact(KECCAK_WORD_SIZE),
528            mem_aux
529        )
530        .enumerate()
531        {
532            let ptr = local.instruction.src + AB::F::from_canonical_usize(i * KECCAK_WORD_SIZE);
533            // Only read block i if it is not entirely padding bytes
534            // count is degree 2
535            let count = is_input * not(is_padding[0]);
536            // The memory block read is partial if first byte is not padding but the last byte is padding. Since `count` is only 1 when first byte isn't padding, use check just if last byte is padding.
537            let is_partial_read = *is_padding.last().unwrap();
538            // word is degree 2
539            let word: [_; KECCAK_WORD_SIZE] = from_fn(|i| {
540                if i == 0 {
541                    // first byte is always ok
542                    input[0].into()
543                } else {
544                    // use `partial_block` if this is a partial read, otherwise use the normal input block
545                    select(is_partial_read, partial_block[i - 1], input[i])
546                }
547            });
548            for i in 1..KECCAK_WORD_SIZE {
549                let not_padding: AB::Expr = not(is_padding[i]);
550                // When not a padding byte, the word byte and input byte must be equal
551                // This is constraint degree 3
552                builder.assert_eq(
553                    not_padding.clone() * word[i].clone(),
554                    not_padding.clone() * input[i],
555                );
556            }
557
558            self.memory_bridge
559                .read(
560                    MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_MEMORY_AS), ptr),
561                    word, // degree 2
562                    timestamp.clone(),
563                    mem_aux,
564                )
565                .eval(builder, count);
566
567            timestamp += AB::Expr::ONE;
568        }
569        timestamp
570    }
571
572    pub fn constrain_output_write<AB: InteractionBuilder>(
573        &self,
574        builder: &mut AB,
575        local: &KeccakVmCols<AB::Var>,
576        start_write_timestamp: AB::Expr,
577        mem_aux: &[MemoryWriteAuxCols<AB::Var, KECCAK_WORD_SIZE>; KECCAK_DIGEST_WRITES],
578    ) {
579        let instruction = local.instruction;
580
581        let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
582        // since keccak-f AIR has this column, we might as well use it
583        builder.assert_eq(
584            local.inner.export,
585            instruction.is_enabled * is_final_block * local.is_last_round(),
586        );
587        // See `constrain_absorb` on how we derive the postimage bytes from u16 limbs
588        // **SAFETY:** we always XOR the final state with 0 in `constrain_absorb`,
589        // so the output bytes **are** range checked.
590        let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
591            let y = i / 5;
592            let x = i % 5;
593            (0..U64_LIMBS).flat_map(move |limb| {
594                let state_limb = local.postimage(y, x, limb);
595                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
596                let lo = state_limb - hi * AB::F::from_canonical_u64(1 << 8);
597                // Conversion from bytes to u64 is little-endian
598                [lo, hi.into()]
599            })
600        });
601        let dst = abstract_compose::<AB::Expr, _>(instruction.dst);
602        for (i, digest_bytes) in updated_state_bytes
603            .take(KECCAK_DIGEST_BYTES)
604            .chunks(KECCAK_WORD_SIZE)
605            .into_iter()
606            .enumerate()
607        {
608            let digest_bytes = digest_bytes.collect_vec();
609            let timestamp = start_write_timestamp.clone() + AB::Expr::from_canonical_usize(i);
610            self.memory_bridge
611                .write(
612                    MemoryAddress::new(
613                        AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
614                        dst.clone() + AB::F::from_canonical_usize(i * KECCAK_WORD_SIZE),
615                    ),
616                    digest_bytes.try_into().unwrap(),
617                    timestamp,
618                    &mem_aux[i],
619                )
620                .eval(builder, local.inner.export)
621        }
622    }
623
624    /// Amount to advance timestamp by after execution of one opcode instruction.
625    /// This is an upper bound dependent on the length `len` operand, which is unbounded.
626    pub fn timestamp_change<T: FieldAlgebra>(len: impl Into<T>) -> T {
627        // actual number is ceil(len / 136) * (3 + 17) + KECCAK_DIGEST_WRITES
628        // digest writes only done on last row of multi-block
629        // add another KECCAK_ABSORB_READS to round up so we don't deal with padding
630        len.into()
631            + T::from_canonical_usize(
632                KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES,
633            )
634    }
635}