// openvm_keccak256_circuit/trace.rs

use std::{array::from_fn, borrow::BorrowMut, sync::Arc};

use openvm_circuit::system::memory::RecordId;
use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
use openvm_stark_backend::{
    config::{StarkGenericConfig, Val},
    p3_air::BaseAir,
    p3_field::{FieldAlgebra, PrimeField32},
    p3_matrix::{dense::RowMajorMatrix, Matrix},
    p3_maybe_rayon::prelude::*,
    prover::types::AirProofInput,
    rap::get_air_name,
    AirRef, Chip, ChipUsageGetter,
};
use p3_keccak_air::{
    generate_trace_rows, NUM_KECCAK_COLS as NUM_KECCAK_PERM_COLS, NUM_ROUNDS, U64_LIMBS,
};
use tiny_keccak::keccakf;

use super::{
    columns::{KeccakInstructionCols, KeccakVmCols},
    KeccakVmChip, KECCAK_ABSORB_READS, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES, KECCAK_RATE_U16S,
    KECCAK_REGISTER_READS, NUM_ABSORB_ROUNDS,
};
25
26impl<SC: StarkGenericConfig> Chip<SC> for KeccakVmChip<Val<SC>>
27where
28    Val<SC>: PrimeField32,
29{
30    fn air(&self) -> AirRef<SC> {
31        Arc::new(self.air)
32    }
33
34    fn generate_air_proof_input(self) -> AirProofInput<SC> {
35        let trace_width = self.trace_width();
36        let records = self.records;
37        let total_num_blocks: usize = records.iter().map(|r| r.input_blocks.len()).sum();
38        let mut states = Vec::with_capacity(total_num_blocks);
39        let mut instruction_blocks = Vec::with_capacity(total_num_blocks);
40        let memory = self.offline_memory.lock().unwrap();
41
42        #[derive(Clone)]
43        struct StateDiff {
44            /// hi-byte of pre-state
45            pre_hi: [u8; KECCAK_RATE_U16S],
46            /// hi-byte of post-state
47            post_hi: [u8; KECCAK_RATE_U16S],
48            /// if first block
49            register_reads: Option<[RecordId; KECCAK_REGISTER_READS]>,
50            /// if last block
51            digest_writes: Option<[RecordId; KECCAK_DIGEST_WRITES]>,
52        }
53
54        impl Default for StateDiff {
55            fn default() -> Self {
56                Self {
57                    pre_hi: [0; KECCAK_RATE_U16S],
58                    post_hi: [0; KECCAK_RATE_U16S],
59                    register_reads: None,
60                    digest_writes: None,
61                }
62            }
63        }
64
65        // prepare the states
66        let mut state: [u64; 25];
67        for record in records {
68            let dst_read = memory.record_by_id(record.dst_read);
69            let src_read = memory.record_by_id(record.src_read);
70            let len_read = memory.record_by_id(record.len_read);
71
72            state = [0u64; 25];
73            let src_limbs: [_; RV32_REGISTER_NUM_LIMBS - 1] = src_read.data_slice()
74                [1..RV32_REGISTER_NUM_LIMBS]
75                .try_into()
76                .unwrap();
77            let len_limbs: [_; RV32_REGISTER_NUM_LIMBS - 1] = len_read.data_slice()
78                [1..RV32_REGISTER_NUM_LIMBS]
79                .try_into()
80                .unwrap();
81            let mut instruction = KeccakInstructionCols {
82                pc: record.pc,
83                is_enabled: Val::<SC>::ONE,
84                is_enabled_first_round: Val::<SC>::ZERO,
85                start_timestamp: Val::<SC>::from_canonical_u32(dst_read.timestamp),
86                dst_ptr: dst_read.pointer,
87                src_ptr: src_read.pointer,
88                len_ptr: len_read.pointer,
89                dst: dst_read.data_slice().try_into().unwrap(),
90                src_limbs,
91                src: Val::<SC>::from_canonical_usize(record.input_blocks[0].src),
92                len_limbs,
93                remaining_len: Val::<SC>::from_canonical_usize(
94                    record.input_blocks[0].remaining_len,
95                ),
96            };
97            let num_blocks = record.input_blocks.len();
98            for (idx, block) in record.input_blocks.into_iter().enumerate() {
99                // absorb
100                for (bytes, s) in block.padded_bytes.chunks_exact(8).zip(state.iter_mut()) {
101                    // u64 <-> bytes conversion is little-endian
102                    for (i, &byte) in bytes.iter().enumerate() {
103                        let s_byte = (*s >> (i * 8)) as u8;
104                        // Update bitwise lookup (i.e. xor) chip state: order matters!
105                        if idx != 0 {
106                            self.bitwise_lookup_chip
107                                .request_xor(byte as u32, s_byte as u32);
108                        }
109                        *s ^= (byte as u64) << (i * 8);
110                    }
111                }
112                let pre_hi: [u8; KECCAK_RATE_U16S] =
113                    from_fn(|i| (state[i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8)) as u8);
114                states.push(state);
115                keccakf(&mut state);
116                let post_hi: [u8; KECCAK_RATE_U16S] =
117                    from_fn(|i| (state[i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8)) as u8);
118                // Range check the final state
119                if idx == num_blocks - 1 {
120                    for s in state.into_iter().take(NUM_ABSORB_ROUNDS) {
121                        for s_byte in s.to_le_bytes() {
122                            self.bitwise_lookup_chip.request_xor(0, s_byte as u32);
123                        }
124                    }
125                }
126                let register_reads =
127                    (idx == 0).then_some([record.dst_read, record.src_read, record.len_read]);
128                let digest_writes = (idx == num_blocks - 1).then_some(record.digest_writes);
129                let diff = StateDiff {
130                    pre_hi,
131                    post_hi,
132                    register_reads,
133                    digest_writes,
134                };
135                instruction_blocks.push((instruction, diff, block));
136                instruction.remaining_len -= Val::<SC>::from_canonical_usize(KECCAK_RATE_BYTES);
137                instruction.src += Val::<SC>::from_canonical_usize(KECCAK_RATE_BYTES);
138                instruction.start_timestamp +=
139                    Val::<SC>::from_canonical_usize(KECCAK_REGISTER_READS + KECCAK_ABSORB_READS);
140            }
141        }
142
143        // We need to transpose state matrices due to a plonky3 issue: https://github.com/Plonky3/Plonky3/issues/672
144        // Note: the fix for this issue will be a commit after the major Field crate refactor PR https://github.com/Plonky3/Plonky3/pull/640
145        //       which will require a significant refactor to switch to.
146        let p3_states = states
147            .into_iter()
148            .map(|state| {
149                // transpose of 5x5 matrix
150                from_fn(|i| {
151                    let x = i / 5;
152                    let y = i % 5;
153                    state[x + 5 * y]
154                })
155            })
156            .collect();
157        let p3_keccak_trace: RowMajorMatrix<Val<SC>> = generate_trace_rows(p3_states, 0);
158        let num_rows = p3_keccak_trace.height();
159        // Every `NUM_ROUNDS` rows corresponds to one input block
160        let num_blocks = num_rows.div_ceil(NUM_ROUNDS);
161        // Resize with dummy `is_enabled = 0`
162        instruction_blocks.resize(num_blocks, Default::default());
163
164        let aux_cols_factory = memory.aux_cols_factory();
165
166        // Use unsafe alignment so we can parallelly write to the matrix
167        let mut trace =
168            RowMajorMatrix::new(Val::<SC>::zero_vec(num_rows * trace_width), trace_width);
169        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.ptr_max_bits;
170
171        trace
172            .values
173            .par_chunks_mut(trace_width * NUM_ROUNDS)
174            .zip(
175                p3_keccak_trace
176                    .values
177                    .par_chunks(NUM_KECCAK_PERM_COLS * NUM_ROUNDS),
178            )
179            .zip(instruction_blocks.into_par_iter())
180            .for_each(|((rows, p3_keccak_mat), (instruction, diff, block))| {
181                let height = rows.len() / trace_width;
182                for (row, p3_keccak_row) in rows
183                    .chunks_exact_mut(trace_width)
184                    .zip(p3_keccak_mat.chunks_exact(NUM_KECCAK_PERM_COLS))
185                {
186                    // Safety: `KeccakPermCols` **must** be the first field in `KeccakVmCols`
187                    row[..NUM_KECCAK_PERM_COLS].copy_from_slice(p3_keccak_row);
188                    let row_mut: &mut KeccakVmCols<Val<SC>> = row.borrow_mut();
189                    row_mut.instruction = instruction;
190
191                    row_mut.sponge.block_bytes =
192                        block.padded_bytes.map(Val::<SC>::from_canonical_u8);
193                    if let Some(partial_read_idx) = block.partial_read_idx {
194                        let partial_read = memory.record_by_id(block.reads[partial_read_idx]);
195                        row_mut
196                            .mem_oc
197                            .partial_block
198                            .copy_from_slice(&partial_read.data_slice()[1..]);
199                    }
200                    for (i, is_padding) in row_mut.sponge.is_padding_byte.iter_mut().enumerate() {
201                        *is_padding = Val::<SC>::from_bool(i >= block.remaining_len);
202                    }
203                }
204                let first_row: &mut KeccakVmCols<Val<SC>> = rows[..trace_width].borrow_mut();
205                first_row.sponge.is_new_start = Val::<SC>::from_bool(block.is_new_start);
206                first_row.sponge.state_hi = diff.pre_hi.map(Val::<SC>::from_canonical_u8);
207                first_row.instruction.is_enabled_first_round = first_row.instruction.is_enabled;
208                // Make memory access aux columns. Any aux column not explicitly defined defaults to all 0s
209                if let Some(register_reads) = diff.register_reads {
210                    let need_range_check = [
211                        &register_reads[0], // dst
212                        &register_reads[1], // src
213                        &register_reads[2], // len
214                        &register_reads[2],
215                    ]
216                    .map(|r| {
217                        memory
218                            .record_by_id(*r)
219                            .data_slice()
220                            .last()
221                            .unwrap()
222                            .as_canonical_u32()
223                    });
224                    for bytes in need_range_check.chunks(2) {
225                        self.bitwise_lookup_chip.request_range(
226                            bytes[0] << limb_shift_bits,
227                            bytes[1] << limb_shift_bits,
228                        );
229                    }
230                    for (i, id) in register_reads.into_iter().enumerate() {
231                        aux_cols_factory.generate_read_aux(
232                            memory.record_by_id(id),
233                            &mut first_row.mem_oc.register_aux[i],
234                        );
235                    }
236                }
237                for (i, id) in block.reads.into_iter().enumerate() {
238                    aux_cols_factory.generate_read_aux(
239                        memory.record_by_id(id),
240                        &mut first_row.mem_oc.absorb_reads[i],
241                    );
242                }
243
244                let last_row: &mut KeccakVmCols<Val<SC>> =
245                    rows[(height - 1) * trace_width..].borrow_mut();
246                last_row.sponge.state_hi = diff.post_hi.map(Val::<SC>::from_canonical_u8);
247                last_row.inner.export = instruction.is_enabled
248                    * Val::<SC>::from_bool(block.remaining_len < KECCAK_RATE_BYTES);
249                if let Some(digest_writes) = diff.digest_writes {
250                    for (i, record_id) in digest_writes.into_iter().enumerate() {
251                        let record = memory.record_by_id(record_id);
252                        aux_cols_factory
253                            .generate_write_aux(record, &mut last_row.mem_oc.digest_writes[i]);
254                    }
255                }
256            });
257
258        AirProofInput::simple_no_pis(trace)
259    }
260}
261
262impl<F: PrimeField32> ChipUsageGetter for KeccakVmChip<F> {
263    fn air_name(&self) -> String {
264        get_air_name(&self.air)
265    }
266    fn current_trace_height(&self) -> usize {
267        let num_blocks: usize = self.records.iter().map(|r| r.input_blocks.len()).sum();
268        num_blocks * NUM_ROUNDS
269    }
270
271    fn trace_width(&self) -> usize {
272        BaseAir::<F>::width(&self.air)
273    }
274}