openvm_keccak256_circuit/
lib.rs

1//! Stateful keccak256 hasher. Handles full keccak sponge (padding, absorb, keccak-f) on
2//! variable length inputs read from VM memory.
3use std::{
4    array::from_fn,
5    cmp::min,
6    sync::{Arc, Mutex},
7};
8
9use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip;
10use openvm_stark_backend::p3_field::PrimeField32;
11use serde::{Deserialize, Serialize};
12use serde_big_array::BigArray;
13use tiny_keccak::{Hasher, Keccak};
14use utils::num_keccak_f;
15
16pub mod air;
17pub mod columns;
18pub mod trace;
19pub mod utils;
20
21mod extension;
22pub use extension::*;
23
24#[cfg(test)]
25mod tests;
26
27pub use air::KeccakVmAir;
28use openvm_circuit::{
29    arch::{ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor},
30    system::{
31        memory::{offline_checker::MemoryBridge, MemoryController, OfflineMemory, RecordId},
32        program::ProgramBus,
33    },
34};
35use openvm_instructions::{
36    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode,
37};
38use openvm_keccak256_transpiler::Rv32KeccakOpcode;
39use openvm_rv32im_circuit::adapters::read_rv32_register;
40
41// ==== Constants for register/memory adapter ====
42/// Register reads to get dst, src, len
43const KECCAK_REGISTER_READS: usize = 3;
44/// Number of cells to read/write in a single memory access
45const KECCAK_WORD_SIZE: usize = 4;
46/// Memory reads for absorb per row
47const KECCAK_ABSORB_READS: usize = KECCAK_RATE_BYTES / KECCAK_WORD_SIZE;
48/// Memory writes for digest per row
49const KECCAK_DIGEST_WRITES: usize = KECCAK_DIGEST_BYTES / KECCAK_WORD_SIZE;
50
// ==== Do not change these constants! ====
/// Bytes in the rate portion of the sponge.
pub const KECCAK_RATE_BYTES: usize = 136;
/// Rate size counted in 16-bit limbs.
pub const KECCAK_RATE_U16S: usize = KECCAK_RATE_BYTES / 2;
/// Absorb rounds per block: one for each 64-bit word of the rate.
pub const NUM_ABSORB_ROUNDS: usize = KECCAK_RATE_BYTES / 8;

/// Bytes in the capacity portion of the sponge.
pub const KECCAK_CAPACITY_BYTES: usize = 64;
/// Capacity size counted in 16-bit limbs.
pub const KECCAK_CAPACITY_U16S: usize = KECCAK_CAPACITY_BYTES / 2;

/// Size of the whole sponge state in bytes (rate bytes plus capacity bytes).
pub const KECCAK_WIDTH_BYTES: usize = 200;
/// Size of the whole sponge state in 16-bit limbs.
pub const KECCAK_WIDTH_U16S: usize = KECCAK_WIDTH_BYTES / 2;

/// Digest bytes produced by the squeezing phase.
pub const KECCAK_DIGEST_BYTES: usize = 32;
/// Digest size counted in 64-bit limbs.
pub const KECCAK_DIGEST_U64S: usize = KECCAK_DIGEST_BYTES / 8;
71
72pub struct KeccakVmChip<F: PrimeField32> {
73    pub air: KeccakVmAir,
74    /// IO and memory data necessary for each opcode call
75    pub records: Vec<KeccakRecord<F>>,
76    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>,
77
78    offset: usize,
79
80    offline_memory: Arc<Mutex<OfflineMemory<F>>>,
81}
82
83#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
84pub struct KeccakRecord<F> {
85    pub pc: F,
86    pub dst_read: RecordId,
87    pub src_read: RecordId,
88    pub len_read: RecordId,
89    pub input_blocks: Vec<KeccakInputBlock>,
90    pub digest_writes: [RecordId; KECCAK_DIGEST_WRITES],
91}
92
93#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
94pub struct KeccakInputBlock {
95    /// Memory reads for non-padding bytes in this block. Length is at most [KECCAK_RATE_BYTES / KECCAK_WORD_SIZE].
96    pub reads: Vec<RecordId>,
97    /// Index in `reads` of the memory read for < KECCAK_WORD_SIZE bytes, if any.
98    pub partial_read_idx: Option<usize>,
99    /// Bytes with padding. Can be derived from `bytes_read` but we store for convenience.
100    #[serde(with = "BigArray")]
101    pub padded_bytes: [u8; KECCAK_RATE_BYTES],
102    pub remaining_len: usize,
103    pub src: usize,
104    pub is_new_start: bool,
105}
106
107impl<F: PrimeField32> KeccakVmChip<F> {
108    pub fn new(
109        execution_bus: ExecutionBus,
110        program_bus: ProgramBus,
111        memory_bridge: MemoryBridge,
112        address_bits: usize,
113        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>,
114        offset: usize,
115        offline_memory: Arc<Mutex<OfflineMemory<F>>>,
116    ) -> Self {
117        Self {
118            air: KeccakVmAir::new(
119                ExecutionBridge::new(execution_bus, program_bus),
120                memory_bridge,
121                bitwise_lookup_chip.bus(),
122                address_bits,
123                offset,
124            ),
125            bitwise_lookup_chip,
126            records: Vec::new(),
127            offset,
128            offline_memory,
129        }
130    }
131}
132
133impl<F: PrimeField32> InstructionExecutor<F> for KeccakVmChip<F> {
134    fn execute(
135        &mut self,
136        memory: &mut MemoryController<F>,
137        instruction: &Instruction<F>,
138        from_state: ExecutionState<u32>,
139    ) -> Result<ExecutionState<u32>, ExecutionError> {
140        let &Instruction {
141            opcode,
142            a,
143            b,
144            c,
145            d,
146            e,
147            ..
148        } = instruction;
149        let local_opcode = Rv32KeccakOpcode::from_usize(opcode.local_opcode_idx(self.offset));
150        debug_assert_eq!(local_opcode, Rv32KeccakOpcode::KECCAK256);
151
152        let mut timestamp_delta = 3;
153        let (dst_read, dst) = read_rv32_register(memory, d, a);
154        let (src_read, src) = read_rv32_register(memory, d, b);
155        let (len_read, len) = read_rv32_register(memory, d, c);
156        #[cfg(debug_assertions)]
157        {
158            assert!(dst < (1 << self.air.ptr_max_bits));
159            assert!(src < (1 << self.air.ptr_max_bits));
160            assert!(len < (1 << self.air.ptr_max_bits));
161        }
162
163        let mut remaining_len = len as usize;
164        let num_blocks = num_keccak_f(remaining_len);
165        let mut input_blocks = Vec::with_capacity(num_blocks);
166        let mut hasher = Keccak::v256();
167        let mut src = src as usize;
168
169        for block_idx in 0..num_blocks {
170            if block_idx != 0 {
171                memory.increment_timestamp_by(KECCAK_REGISTER_READS as u32);
172                timestamp_delta += KECCAK_REGISTER_READS as u32;
173            }
174            let mut reads = Vec::with_capacity(KECCAK_RATE_BYTES);
175
176            let mut partial_read_idx = None;
177            let mut bytes = [0u8; KECCAK_RATE_BYTES];
178            for i in (0..KECCAK_RATE_BYTES).step_by(KECCAK_WORD_SIZE) {
179                if i < remaining_len {
180                    let read =
181                        memory.read::<RV32_REGISTER_NUM_LIMBS>(e, F::from_canonical_usize(src + i));
182
183                    let chunk = read.1.map(|x| {
184                        x.as_canonical_u32()
185                            .try_into()
186                            .expect("Memory cell not a byte")
187                    });
188                    let copy_len = min(KECCAK_WORD_SIZE, remaining_len - i);
189                    if copy_len != KECCAK_WORD_SIZE {
190                        partial_read_idx = Some(reads.len());
191                    }
192                    bytes[i..i + copy_len].copy_from_slice(&chunk[..copy_len]);
193                    reads.push(read.0);
194                } else {
195                    memory.increment_timestamp();
196                }
197                timestamp_delta += 1;
198            }
199
200            let mut block = KeccakInputBlock {
201                reads,
202                partial_read_idx,
203                padded_bytes: bytes,
204                remaining_len,
205                src,
206                is_new_start: block_idx == 0,
207            };
208            if block_idx != num_blocks - 1 {
209                src += KECCAK_RATE_BYTES;
210                remaining_len -= KECCAK_RATE_BYTES;
211                hasher.update(&block.padded_bytes);
212            } else {
213                // handle padding here since it is convenient
214                debug_assert!(remaining_len < KECCAK_RATE_BYTES);
215                hasher.update(&block.padded_bytes[..remaining_len]);
216
217                if remaining_len == KECCAK_RATE_BYTES - 1 {
218                    block.padded_bytes[remaining_len] = 0b1000_0001;
219                } else {
220                    block.padded_bytes[remaining_len] = 0x01;
221                    block.padded_bytes[KECCAK_RATE_BYTES - 1] = 0x80;
222                }
223            }
224            input_blocks.push(block);
225        }
226        let mut output = [0u8; 32];
227        hasher.finalize(&mut output);
228        let dst = dst as usize;
229        let digest_writes: [_; KECCAK_DIGEST_WRITES] = from_fn(|i| {
230            timestamp_delta += 1;
231            memory
232                .write::<KECCAK_WORD_SIZE>(
233                    e,
234                    F::from_canonical_usize(dst + i * KECCAK_WORD_SIZE),
235                    from_fn(|j| F::from_canonical_u8(output[i * KECCAK_WORD_SIZE + j])),
236                )
237                .0
238        });
239        tracing::trace!("[runtime] keccak256 output: {:?}", output);
240
241        let record = KeccakRecord {
242            pc: F::from_canonical_u32(from_state.pc),
243            dst_read,
244            src_read,
245            len_read,
246            input_blocks,
247            digest_writes,
248        };
249
250        // Add the events to chip state for later trace generation usage
251        self.records.push(record);
252
253        // NOTE: Check this is consistent with KeccakVmAir::timestamp_change (we don't use it to avoid
254        // unnecessary conversions here)
255        let total_timestamp_delta =
256            len + (KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES) as u32;
257        memory.increment_timestamp_by(total_timestamp_delta - timestamp_delta);
258
259        Ok(ExecutionState {
260            pc: from_state.pc + DEFAULT_PC_STEP,
261            timestamp: from_state.timestamp + total_timestamp_delta,
262        })
263    }
264
265    fn get_opcode_name(&self, _: usize) -> String {
266        "KECCAK256".to_string()
267    }
268}
269
270impl Default for KeccakInputBlock {
271    fn default() -> Self {
272        // Padding for empty byte array so padding constraints still hold
273        let mut padded_bytes = [0u8; KECCAK_RATE_BYTES];
274        padded_bytes[0] = 0x01;
275        *padded_bytes.last_mut().unwrap() = 0x80;
276        Self {
277            padded_bytes,
278            partial_read_idx: None,
279            remaining_len: 0,
280            is_new_start: true,
281            reads: Vec::new(),
282            src: 0,
283        }
284    }
285}