openvm_keccak256_circuit/
air.rs

1use std::{array::from_fn, borrow::Borrow, iter::zip};
2
3use itertools::{izip, Itertools};
4use openvm_circuit::{
5    arch::{ExecutionBridge, ExecutionState},
6    system::memory::{
7        offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
8        MemoryAddress,
9    },
10};
11use openvm_circuit_primitives::{
12    bitwise_op_lookup::BitwiseOperationLookupBus,
13    utils::{assert_array_eq, not, select},
14};
15use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS};
16use openvm_keccak256_transpiler::Rv32KeccakOpcode;
17use openvm_rv32im_circuit::adapters::abstract_compose;
18use openvm_stark_backend::{
19    air_builders::sub::SubAirBuilder,
20    interaction::InteractionBuilder,
21    p3_air::{Air, AirBuilder, BaseAir},
22    p3_field::PrimeCharacteristicRing,
23    p3_matrix::Matrix,
24    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
25};
26use p3_keccak_air::{KeccakAir, NUM_KECCAK_COLS as NUM_KECCAK_PERM_COLS, U64_LIMBS};
27
28use super::{
29    columns::{KeccakVmCols, NUM_KECCAK_VM_COLS},
30    KECCAK_ABSORB_READS, KECCAK_DIGEST_BYTES, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES,
31    KECCAK_RATE_U16S, KECCAK_REGISTER_READS, KECCAK_WIDTH_U16S, KECCAK_WORD_SIZE,
32    NUM_ABSORB_ROUNDS,
33};
34
/// AIR for the KECCAK256 opcode.
///
/// The trace consists of the `p3_keccak_air` keccak-f permutation columns followed by the
/// sponge, instruction and memory columns of [`KeccakVmCols`]. The AIR talks to the execution
/// bus, the memory bus and a bitwise-operation lookup bus.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct KeccakVmAir {
    /// Used to receive the instruction and advance `pc` / timestamp (see `eval_instruction`).
    pub execution_bridge: ExecutionBridge,
    /// Used for the register reads, the input block reads and the digest writes.
    pub memory_bridge: MemoryBridge,
    /// Bus to send 8-bit XOR requests to (also used for the pointer range checks).
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    /// Maximum number of bits allowed for an address pointer
    pub ptr_max_bits: usize,
    /// Added to `Rv32KeccakOpcode::KECCAK256` to form the global opcode that is received.
    pub(super) offset: usize,
}
45
46impl<F> BaseAirWithPublicValues<F> for KeccakVmAir {}
47impl<F> PartitionedBaseAir<F> for KeccakVmAir {}
48impl<F> BaseAir<F> for KeccakVmAir {
49    fn width(&self) -> usize {
50        NUM_KECCAK_VM_COLS
51    }
52}
53
impl<AB: InteractionBuilder> Air<AB> for KeccakVmAir {
    /// Evaluates every constraint and bus interaction of the KECCAK256 AIR on the two-row
    /// window `(local, next)`.
    ///
    /// NOTE(review): the emission order of constraints and interactions presumably shapes the
    /// resulting constraint system (and hence the verifying key). Avoid reordering the calls
    /// below without checking that.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let (local, next) = (
            main.row_slice(0).expect("window should have two elements"),
            main.row_slice(1).expect("window should have two elements"),
        );
        let local: &KeccakVmCols<AB::Var> = (*local).borrow();
        let next: &KeccakVmCols<AB::Var> = (*next).borrow();

        // `is_new_start` is boolean and may only be set on the first round of a block.
        builder.assert_bool(local.sponge.is_new_start);
        builder.assert_eq(
            local.sponge.is_new_start,
            local.sponge.is_new_start * local.is_first_round(),
        );
        // `is_enabled_first_round` is a materialized copy of `is_enabled * is_first_round()`.
        // It is used as the input-read selector in `constrain_input_read`.
        builder.assert_eq(
            local.instruction.is_enabled_first_round,
            local.instruction.is_enabled * local.is_first_round(),
        );
        // Not strictly necessary: the trace starts with a fresh sponge on its first row.
        builder
            .when_first_row()
            .assert_one(local.sponge.is_new_start);

        // Non-interaction constraints: the keccak-f permutation itself, padding / length
        // bookkeeping, and instruction columns held constant within a block.
        self.eval_keccak_f(builder);
        self.constrain_padding(builder, local, next);
        self.constrain_consistency_across_rounds(builder, local, next);
        // Anchor the trace at the end so the wrapped evaluation window cannot leave an
        // enabled hash in-flight without ever reaching a last-round finalization step
        // (the digest is only written when `inner.export` is set, which requires a last round).
        builder
            .when_last_row()
            .assert_zero(local.instruction.is_enabled);
        let mem = &local.mem_oc;
        // Interactions:
        // - XOR lookups absorbing the next block (they also range check the output bytes),
        self.constrain_absorb(builder, local, next);
        // - instruction receive + register reads; returns the timestamp after the register reads,
        let start_read_timestamp = self.eval_instruction(builder, local, &mem.register_aux);
        // - input block reads; returns the timestamp after the block reads,
        let start_write_timestamp =
            self.constrain_input_read(builder, local, start_read_timestamp, &mem.absorb_reads);
        // - digest writes on the final block, starting at `start_write_timestamp`.
        self.constrain_output_write(
            builder,
            local,
            start_write_timestamp.clone(),
            &mem.digest_writes,
        );

        // A continuation block of the same instruction starts at `start_write_timestamp`.
        self.constrain_block_transition(builder, local, next, start_write_timestamp);
    }
}
102
103impl KeccakVmAir {
104    /// Evaluate the keccak-f permutation constraints.
105    ///
106    /// WARNING: The keccak-f AIR columns **must** be the first columns in the main AIR.
107    #[inline]
108    pub fn eval_keccak_f<AB: AirBuilder>(&self, builder: &mut AB) {
109        let keccak_f_air = KeccakAir {};
110        let mut sub_builder =
111            SubAirBuilder::<AB, KeccakAir, AB::Var>::new(builder, 0..NUM_KECCAK_PERM_COLS);
112        keccak_f_air.eval(&mut sub_builder);
113    }
114
115    /// Many columns are expected to be the same between rounds and only change per-block.
116    pub fn constrain_consistency_across_rounds<AB: AirBuilder<Var: Copy>>(
117        &self,
118        builder: &mut AB,
119        local: &KeccakVmCols<AB::Var>,
120        next: &KeccakVmCols<AB::Var>,
121    ) {
122        let mut transition_builder = builder.when_transition();
123        let mut round_builder = transition_builder.when(not(local.is_last_round()));
124        // Instruction columns
125        local
126            .instruction
127            .assert_eq(&mut round_builder, next.instruction);
128    }
129
130    pub fn constrain_block_transition<AB: AirBuilder<Var: Copy>>(
131        &self,
132        builder: &mut AB,
133        local: &KeccakVmCols<AB::Var>,
134        next: &KeccakVmCols<AB::Var>,
135        start_write_timestamp: AB::Expr,
136    ) {
137        // When we transition between blocks, if the next block isn't a new block
138        // (this means it's not receiving a new opcode or starting a dummy block)
139        // then we want _parts_ of opcode instruction to stay the same
140        // between blocks.
141        let mut block_transition = builder.when(local.is_last_round() * not(next.is_new_start()));
142        block_transition.assert_eq(local.instruction.pc, next.instruction.pc);
143        block_transition.assert_eq(local.instruction.is_enabled, next.instruction.is_enabled);
144        // dst is only going to be used for writes in the last input block
145        assert_array_eq(
146            &mut block_transition,
147            local.instruction.dst,
148            next.instruction.dst,
149        );
150        // these are not used and hence not necessary, but putting for safety until performance
151        // becomes an issue:
152        block_transition.assert_eq(local.instruction.dst_ptr, next.instruction.dst_ptr);
153        block_transition.assert_eq(local.instruction.src_ptr, next.instruction.src_ptr);
154        block_transition.assert_eq(local.instruction.len_ptr, next.instruction.len_ptr);
155        // no constraint on `instruction.len` because we use `remaining_len` instead
156
157        // Move the src pointer over based on the number of bytes read.
158        // This should always be RATE_BYTES since it's a non-final block.
159        block_transition.assert_eq(
160            next.instruction.src,
161            local.instruction.src + AB::F::from_usize(KECCAK_RATE_BYTES),
162        );
163        // Advance timestamp by the number of memory accesses from reading
164        // `dst, src, len` and block input bytes.
165        block_transition.assert_eq(next.instruction.start_timestamp, start_write_timestamp);
166        block_transition.assert_eq(
167            next.instruction.remaining_len,
168            local.instruction.remaining_len - AB::F::from_usize(KECCAK_RATE_BYTES),
169        );
170        // Padding transition is constrained in `constrain_padding`.
171    }
172
    /// Keccak follows the 10*1 padding rule.
    /// See Section 5.1 of <https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.202.pdf>
    /// Note this is the ONLY difference between Keccak and SHA-3
    ///
    /// Constrains the padding flags, the padding byte values and the remaining length, both
    /// between rounds of a block and between consecutive blocks. Padding logic is tied to the
    /// constraints on `is_new_start`. A block is the final block of its input exactly when its
    /// last `is_padding_byte` flag is set, and a final block must be followed by a new start.
    pub fn constrain_padding<AB: AirBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        next: &KeccakVmCols<AB::Var>,
    ) where
        AB::Var: Copy,
    {
        let is_padding_byte = local.sponge.is_padding_byte;
        let block_bytes = &local.sponge.block_bytes;
        let remaining_len = local.remaining_len();

        // is_padding_byte should all be boolean
        for &is_padding_byte in is_padding_byte.iter() {
            builder.assert_bool(is_padding_byte);
        }
        // is_padding_byte should transition from 0 to 1 only once and then stay 1,
        // i.e. the flags look like 0...01...1 and the padding is a suffix of the block.
        for i in 1..KECCAK_RATE_BYTES {
            builder
                .when(is_padding_byte[i - 1])
                .assert_one(is_padding_byte[i]);
        }
        // is_padding_byte must stay the same on all rounds in a block.
        // We use `next.inner.step_flags[0]` (the next row is a block's first round) instead of
        // `local.step_flags.last()` because the last row of the trace overall may not end on a
        // last round. On that row the window wraps and `next` is the first row of the trace.
        let is_last_round = next.inner.step_flags[0];
        let is_not_last_round = not(is_last_round);
        for i in 0..KECCAK_RATE_BYTES {
            builder.when(is_not_last_round.clone()).assert_eq(
                local.sponge.is_padding_byte[i],
                next.sponge.is_padding_byte[i],
            );
        }

        let num_padding_bytes = local
            .sponge
            .is_padding_byte
            .iter()
            .fold(AB::Expr::ZERO, |a, &b| a + b);

        // If final rate block of input, then last byte must be padding. Conversely, a block
        // whose last byte is padding is treated as the final block.
        let is_final_block = is_padding_byte[KECCAK_RATE_BYTES - 1];

        // is_padding_byte must be consistent with remaining_len: in the final block, the
        // non-padding prefix holds exactly the remaining input bytes.
        builder.when(is_final_block).assert_eq(
            remaining_len,
            AB::Expr::from_usize(KECCAK_RATE_BYTES) - num_padding_bytes,
        );
        // If this block is not final, when transitioning to next block, remaining len
        // must decrease by `KECCAK_RATE_BYTES`.
        builder
            .when(is_last_round)
            .when(not(is_final_block))
            .assert_eq(
                remaining_len - AB::F::from_usize(KECCAK_RATE_BYTES),
                next.remaining_len(),
            );
        // To enforce that is_padding_byte must be set appropriately for an input, we require
        // the block before a new start to have padding
        builder
            .when(is_last_round)
            .when(next.is_new_start())
            .assert_one(is_final_block);
        // Make sure there are not repeated padding blocks: a final block must be followed by
        // a new start.
        builder
            .when(is_last_round)
            .when(is_final_block)
            .assert_one(next.is_new_start());
        // The chain above enforces that for an input, the remaining length must decrease by RATE
        // block-by-block until it reaches a final block with padding.

        // ====== Constrain the block_bytes are padded according to is_padding_byte =====

        // If the first padding byte is at the end of the block, then the block has a
        // single padding byte. Because the flags form a 0...01...1 suffix, this difference is
        // 1 exactly when only the last flag is set, and 0 otherwise.
        let has_single_padding_byte: AB::Expr =
            is_padding_byte[KECCAK_RATE_BYTES - 1] - is_padding_byte[KECCAK_RATE_BYTES - 2];

        // If the row has a single padding byte, then it must be the last byte with
        // value 0b10000001 (the leading and trailing `1` of 10*1 share one byte).
        builder.when(has_single_padding_byte.clone()).assert_eq(
            block_bytes[KECCAK_RATE_BYTES - 1],
            AB::F::from_u8(0b10000001),
        );

        // Note: this is also 1 for blocks with no padding at all. Every constraint below that
        // uses it is additionally gated by a padding flag.
        let has_multiple_padding_bytes: AB::Expr = not(has_single_padding_byte.clone());
        for i in 0..KECCAK_RATE_BYTES - 1 {
            // 1 exactly at the first padding position of the block, 0 elsewhere.
            let is_first_padding_byte: AB::Expr = {
                if i > 0 {
                    is_padding_byte[i] - is_padding_byte[i - 1]
                } else {
                    is_padding_byte[i].into()
                }
            };
            // If the row has multiple padding bytes, the first padding byte must be 0x01
            // because the padding 1*0 is *little-endian*
            builder
                .when(has_multiple_padding_bytes.clone())
                .when(is_first_padding_byte.clone())
                .assert_eq(block_bytes[i], AB::F::from_u8(0x01));
            // If the row has multiple padding bytes, the other padding bytes
            // except the last one must be 0
            builder
                .when(is_padding_byte[i])
                .when(not::<AB::Expr>(is_first_padding_byte)) // hence never when single padding byte
                .assert_zero(block_bytes[i]);
        }

        // If the row has multiple padding bytes, then the last byte must be 0x80
        // because the padding *01 is *little-endian*
        builder
            .when(is_final_block)
            .when(has_multiple_padding_bytes)
            .assert_eq(block_bytes[KECCAK_RATE_BYTES - 1], AB::F::from_u8(0x80));
    }
294
    /// Constrain state transition between keccak-f permutations is valid absorb of input bytes.
    /// The end-state in last round is given by `a_prime_prime_prime()` in `u16` limbs.
    /// The pre-state is given by `preimage` also in `u16` limbs.
    /// The input `block_bytes` will be given as **bytes**.
    ///
    /// We will XOR `block_bytes` with `a_prime_prime_prime()` and constrain to be `next.preimage`.
    /// This will be done using 8-bit XOR lookup in a separate AIR via interactions.
    /// This will require decomposing `u16` into bytes.
    /// Note that the XOR lookup automatically range checks its inputs to be bytes.
    ///
    /// We use the following trick to keep `u16` limbs and avoid changing
    /// the `keccak-f` AIR itself:
    /// if we already have a 16-bit limb `x` and we also provide a 8-bit limb
    /// `hi = x >> 8`, assuming `x` and `hi` have been range checked,
    /// we can use the expression `lo = x - hi * 256` for the low byte.
    /// If `lo` is range checked to `8`-bits, this constrains a valid byte
    ///  decomposition of `x` into `hi, lo`.
    /// This means in terms of trace cells, it is equivalent to provide
    /// `x, hi` versus `hi, lo`.
    ///
    /// The same `state_hi` column of a row is read against `postimage` (relevant on a block's
    /// last round, where the XOR lookups fire) and against `preimage` (relevant on a block's
    /// first round).
    ///
    /// The state outside the rate is constrained too. It must be zero on a new start and is
    /// carried over unchanged from one block's output to the next block's input.
    pub fn constrain_absorb<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        next: &KeccakVmCols<AB::Var>,
    ) {
        // Bytes (lo, hi per u16 limb) of the lanes `0..NUM_ABSORB_ROUNDS` of this row's
        // keccak-f output.
        let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = local.postimage(y, x, limb);
                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_u64(1 << 8);
                // Conversion from bytes to u64 is little-endian
                [lo, hi.into()]
            })
        });

        // Bytes of the same lanes of the next row's keccak-f input.
        let post_absorb_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = next.inner.preimage[y][x][limb];
                let hi = next.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_u64(1 << 8);
                [lo, hi.into()]
            })
        });

        // We xor on last round of each block, even if it is a final block,
        // because we use xor to range check the output bytes (= updated_state_bytes)
        let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
        for (input, prev, post) in izip!(
            next.sponge.block_bytes,
            updated_state_bytes,
            post_absorb_state_bytes
        ) {
            // Add new send interaction to lookup (x, y, x ^ y) where x, y, z
            // will all be range checked to be 8-bits (assuming the bus is
            // received by an 8-bit xor chip).

            // When absorb, input ^ prev = post
            // Otherwise, 0 ^ prev = prev
            // The interaction fields are degree 2, leading to degree 3 constraint
            self.bitwise_lookup_bus
                .send_xor(
                    input * not(is_final_block),
                    prev.clone(),
                    select(is_final_block, prev, post),
                )
                .eval(
                    builder,
                    local.is_last_round() * local.instruction.is_enabled,
                );
        }

        // We separately constrain that when(local.is_new_start), the preimage (u16s) equals the
        // block bytes. A fresh sponge absorbs its first block into a zero state (the outside-rate
        // part is zeroed below), so the rate part of the preimage is just the block.
        let local_preimage_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = local.inner.preimage[y][x][limb];
                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_u64(1 << 8);
                [lo, hi.into()]
            })
        });
        let mut when_is_new_start =
            builder.when(local.is_new_start() * local.instruction.is_enabled);
        for (preimage_byte, block_byte) in zip(local_preimage_bytes, local.sponge.block_bytes) {
            when_is_new_start.assert_eq(preimage_byte, block_byte);
        }

        // constrain transition on the state outside rate (u16 limb indices
        // `KECCAK_RATE_U16S..KECCAK_WIDTH_U16S`): it is zero on every new start, including
        // disabled (dummy) blocks, since this is not gated by `is_enabled` ...
        let mut reset_builder = builder.when(local.is_new_start());
        for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
            let y = i / U64_LIMBS / 5;
            let x = (i / U64_LIMBS) % 5;
            let limb = i % U64_LIMBS;
            reset_builder.assert_zero(local.inner.preimage[y][x][limb]);
        }
        // ... and copied unchanged from this block's output to the next block's input when the
        // same input continues.
        let mut absorb_builder = builder.when(local.is_last_round() * not(next.is_new_start()));
        for i in KECCAK_RATE_U16S..KECCAK_WIDTH_U16S {
            let y = i / U64_LIMBS / 5;
            let x = (i / U64_LIMBS) % 5;
            let limb = i % U64_LIMBS;
            absorb_builder.assert_eq(local.postimage(y, x, limb), next.inner.preimage[y][x][limb]);
        }
    }
404
    /// Receive the instruction itself on program bus. Send+receive on execution bus.
    /// Then does memory read in addr space 1 to get `dst, src, len` from memory.
    ///
    /// Adds range check interactions for the most significant limbs of the register values
    /// using BitwiseOperationLookupBus.
    ///
    /// Returns `start_read_timestamp`, i.e. `instruction.start_timestamp` advanced by one per
    /// register read. It is only relevant when `local.instruction.is_enabled`.
    /// Note that `start_read_timestamp` is a linear expression.
    pub fn eval_instruction<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        register_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_REGISTER_READS],
    ) -> AB::Expr {
        let instruction = local.instruction;
        builder.assert_bool(instruction.is_enabled);
        // Only receive opcode if:
        // - enabled row (not dummy row)
        // - first round of block (implied by `is_new_start`, which `eval` restricts to first
        //   rounds)
        // - is_new_start
        // `should_receive` is a product of two columns (degree 2). It is the `count` of every
        // interaction below, so it adds to each interaction's constraint degree.
        let should_receive = local.instruction.is_enabled * local.sponge.is_new_start;

        let [dst_ptr, src_ptr, len_ptr] = [
            instruction.dst_ptr,
            instruction.src_ptr,
            instruction.len_ptr,
        ];
        // Registers are read from address space 1.
        let reg_addr_sp = AB::F::ONE;
        // On the receiving row, `remaining_len` is the value read from the `len` register below,
        // so this is the upper bound for the whole instruction (see `timestamp_change`).
        let timestamp_change: AB::Expr = Self::timestamp_change(instruction.remaining_len);
        self.execution_bridge
            .execute_and_increment_pc(
                AB::Expr::from_usize(Rv32KeccakOpcode::KECCAK256 as usize + self.offset),
                [
                    dst_ptr.into(),
                    src_ptr.into(),
                    len_ptr.into(),
                    reg_addr_sp.into(),
                    AB::Expr::from_u32(RV32_MEMORY_AS),
                ],
                ExecutionState::new(instruction.pc, instruction.start_timestamp),
                timestamp_change,
            )
            .eval(builder, should_receive.clone());

        let mut timestamp: AB::Expr = instruction.start_timestamp.into();
        // Rebuilds all little-endian limbs of a register from its upper limbs and the composed
        // value `val`: limb 0 is `val - sum_j limbs[j] * 2^((j + 1) * RV32_CELL_BITS)`, and
        // limbs 1.. are the given upper limbs.
        let recover_limbs = |limbs: [AB::Var; RV32_REGISTER_NUM_LIMBS - 1],
                             val: AB::Var|
         -> [AB::Expr; RV32_REGISTER_NUM_LIMBS] {
            from_fn(|i| {
                if i == 0 {
                    limbs
                        .into_iter()
                        .enumerate()
                        .fold(val.into(), |acc, (j, limb)| {
                            acc - limb * AB::Expr::from_usize(1 << ((j + 1) * RV32_CELL_BITS))
                        })
                } else {
                    limbs[i - 1].into()
                }
            })
        };
        // Only when it is an input do we want to do memory read for
        // dst <- word[a]_d, src <- word[b]_d
        // `dst` is stored limb-by-limb. `src` and `len` are stored as composed values
        // (`src`, `remaining_len`) plus their upper limbs.
        let dst_data = instruction.dst.map(Into::into);
        let src_data = recover_limbs(instruction.src_limbs, instruction.src);
        let len_data = recover_limbs(instruction.len_limbs, instruction.remaining_len);
        for (ptr, value, aux) in izip!(
            [dst_ptr, src_ptr, len_ptr],
            [dst_data, src_data, len_data],
            register_aux,
        ) {
            self.memory_bridge
                .read(
                    MemoryAddress::new(reg_addr_sp, ptr),
                    value,
                    timestamp.clone(),
                    aux,
                )
                .eval(builder, should_receive.clone());

            // One timestamp per register read.
            timestamp += AB::Expr::ONE;
        }
        // See Rv32VecHeapAdapterAir
        // repeat len for even number: range checks are sent in pairs, so `len`'s most
        // significant limb is listed twice to fill the second pair.
        // Each most significant limb is scaled by `limb_shift`. Assuming `send_range` checks
        // byte-sized values, this bounds each composed register value by `ptr_max_bits` bits.
        // We range check `len` to `max_ptr_bits` to ensure `remaining_len` doesn't overflow.
        // We could range check it to some other size, but `max_ptr_bits` is convenient.
        let need_range_check = [
            *instruction.dst.last().unwrap(),
            *instruction.src_limbs.last().unwrap(),
            *instruction.len_limbs.last().unwrap(),
            *instruction.len_limbs.last().unwrap(),
        ];
        let limb_shift =
            AB::F::from_usize(1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.ptr_max_bits));
        for pair in need_range_check.chunks_exact(2) {
            self.bitwise_lookup_bus
                .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
                .eval(builder, should_receive.clone());
        }

        timestamp
    }
509
    /// Constrain reading the input as `block_bytes` from memory.
    /// Reads input based on `is_padding_byte`: words that are entirely padding are not read,
    /// and the word that contains the first padding byte is read through `partial_block`.
    ///
    /// Expects `start_read_timestamp` to be a linear expression.
    /// Returns the `start_write_timestamp`, which is the timestamp to start from
    /// for writing digest to memory. The same value is also used by
    /// `constrain_block_transition` as the start timestamp of a continuation block.
    pub fn constrain_input_read<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        start_read_timestamp: AB::Expr,
        mem_aux: &[MemoryReadAuxCols<AB::Var>; KECCAK_ABSORB_READS],
    ) -> AB::Expr {
        let partial_block = &local.mem_oc.partial_block;
        // Only read input from memory when it is an opcode-related row
        // and only on the first round of block
        let is_input = local.instruction.is_enabled_first_round;

        let mut timestamp = start_read_timestamp;
        // read the block bytes from `word[src + ...]_e`,
        // one `KECCAK_WORD_SIZE`-byte word at a time:
        for (i, (input, is_padding, mem_aux)) in izip!(
            local.sponge.block_bytes.chunks_exact(KECCAK_WORD_SIZE),
            local.sponge.is_padding_byte.chunks_exact(KECCAK_WORD_SIZE),
            mem_aux
        )
        .enumerate()
        {
            let ptr = local.instruction.src + AB::F::from_usize(i * KECCAK_WORD_SIZE);
            // Only read block i if it is not entirely padding bytes
            // (padding is a suffix of the block, so checking the first byte suffices).
            // count is degree 2
            let count = is_input * not(is_padding[0]);
            // The memory block read is partial if first byte is not padding but the last byte is
            // padding. Since `count` is only 1 when first byte isn't padding, use check just if
            // last byte is padding.
            let is_partial_read = *is_padding.last().unwrap();
            // word is degree 2
            // For a partial read, bytes 1.. of the word come from `partial_block`, which holds
            // the actual memory contents where `block_bytes` holds padding values.
            let word: [_; KECCAK_WORD_SIZE] = from_fn(|i| {
                if i == 0 {
                    // first byte is always ok (it is never padding when the read happens)
                    input[0].into()
                } else {
                    // use `partial_block` if this is a partial read, otherwise use the normal input
                    // block
                    select(is_partial_read, partial_block[i - 1], input[i])
                }
            });
            for i in 1..KECCAK_WORD_SIZE {
                let not_padding: AB::Expr = not(is_padding[i]);
                // When not a padding byte, the word byte and input byte must be equal
                // This is constraint degree 3
                builder.assert_eq(
                    not_padding.clone() * word[i].clone(),
                    not_padding.clone() * input[i],
                );
            }

            self.memory_bridge
                .read(
                    MemoryAddress::new(AB::Expr::from_u32(RV32_MEMORY_AS), ptr),
                    word, // degree 2
                    timestamp.clone(),
                    mem_aux,
                )
                .eval(builder, count);

            // The timestamp advances for every word, even when the read is skipped.
            timestamp += AB::Expr::ONE;
        }
        timestamp
    }
581
    /// Writes the digest to memory at `dst` on the last round of an enabled final block.
    ///
    /// Pins `inner.export` to `is_enabled * is_final_block * is_last_round()` and uses it as the
    /// count of the digest writes. The first `KECCAK_DIGEST_BYTES` bytes of this row's keccak-f
    /// output are split into `KECCAK_WORD_SIZE`-byte words. Word `i` is written to
    /// `dst + i * KECCAK_WORD_SIZE` at timestamp `start_write_timestamp + i`, using
    /// `mem_aux[i]`.
    pub fn constrain_output_write<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &KeccakVmCols<AB::Var>,
        start_write_timestamp: AB::Expr,
        mem_aux: &[MemoryWriteAuxCols<AB::Var, KECCAK_WORD_SIZE>; KECCAK_DIGEST_WRITES],
    ) {
        let instruction = local.instruction;

        let is_final_block = *local.sponge.is_padding_byte.last().unwrap();
        // since keccak-f AIR has this column, we might as well use it
        builder.assert_eq(
            local.inner.export,
            instruction.is_enabled * is_final_block * local.is_last_round(),
        );
        // See `constrain_absorb` on how we derive the postimage bytes from u16 limbs
        // **SAFETY:** we always XOR the final state with 0 in `constrain_absorb`,
        // so the output bytes **are** range checked.
        let updated_state_bytes = (0..NUM_ABSORB_ROUNDS).flat_map(|i| {
            let y = i / 5;
            let x = i % 5;
            (0..U64_LIMBS).flat_map(move |limb| {
                let state_limb = local.postimage(y, x, limb);
                let hi = local.sponge.state_hi[i * U64_LIMBS + limb];
                let lo = state_limb - hi * AB::F::from_u64(1 << 8);
                // Conversion from bytes to u64 is little-endian
                [lo, hi.into()]
            })
        });
        // Compose the `dst` register limbs into a single pointer expression.
        let dst = abstract_compose::<AB::Expr, _>(instruction.dst);
        // `chunks` is itertools' lazy chunking. `try_into().unwrap()` below relies on
        // `KECCAK_DIGEST_BYTES` being a multiple of `KECCAK_WORD_SIZE`.
        for (i, digest_bytes) in updated_state_bytes
            .take(KECCAK_DIGEST_BYTES)
            .chunks(KECCAK_WORD_SIZE)
            .into_iter()
            .enumerate()
        {
            let digest_bytes = digest_bytes.collect_vec();
            let timestamp = start_write_timestamp.clone() + AB::Expr::from_usize(i);
            self.memory_bridge
                .write(
                    MemoryAddress::new(
                        AB::Expr::from_u32(RV32_MEMORY_AS),
                        dst.clone() + AB::F::from_usize(i * KECCAK_WORD_SIZE),
                    ),
                    digest_bytes.try_into().unwrap(),
                    timestamp,
                    &mem_aux[i],
                )
                .eval(builder, local.inner.export)
        }
    }
633
634    /// Amount to advance timestamp by after execution of one opcode instruction.
635    /// This is an upper bound dependent on the length `len` operand, which is unbounded.
636    pub fn timestamp_change<T: PrimeCharacteristicRing>(len: impl Into<T>) -> T {
637        // actual number is ceil(len / 136) * (3 + 17) + KECCAK_DIGEST_WRITES
638        // digest writes only done on last row of multi-block
639        // add another KECCAK_ABSORB_READS to round up so we don't deal with padding
640        len.into()
641            + T::from_usize(KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES)
642    }
643}