openvm_keccak256_circuit/
trace.rs
1use std::{array::from_fn, borrow::BorrowMut, sync::Arc};
2
3use openvm_circuit::system::memory::RecordId;
4use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
5use openvm_stark_backend::{
6 config::{StarkGenericConfig, Val},
7 p3_air::BaseAir,
8 p3_field::{FieldAlgebra, PrimeField32},
9 p3_matrix::{dense::RowMajorMatrix, Matrix},
10 p3_maybe_rayon::prelude::*,
11 prover::types::AirProofInput,
12 rap::get_air_name,
13 AirRef, Chip, ChipUsageGetter,
14};
15use p3_keccak_air::{
16 generate_trace_rows, NUM_KECCAK_COLS as NUM_KECCAK_PERM_COLS, NUM_ROUNDS, U64_LIMBS,
17};
18use tiny_keccak::keccakf;
19
20use super::{
21 columns::{KeccakInstructionCols, KeccakVmCols},
22 KeccakVmChip, KECCAK_ABSORB_READS, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES, KECCAK_RATE_U16S,
23 KECCAK_REGISTER_READS, NUM_ABSORB_ROUNDS,
24};
25
26impl<SC: StarkGenericConfig> Chip<SC> for KeccakVmChip<Val<SC>>
27where
28 Val<SC>: PrimeField32,
29{
30 fn air(&self) -> AirRef<SC> {
31 Arc::new(self.air)
32 }
33
34 fn generate_air_proof_input(self) -> AirProofInput<SC> {
35 let trace_width = self.trace_width();
36 let records = self.records;
37 let total_num_blocks: usize = records.iter().map(|r| r.input_blocks.len()).sum();
38 let mut states = Vec::with_capacity(total_num_blocks);
39 let mut instruction_blocks = Vec::with_capacity(total_num_blocks);
40 let memory = self.offline_memory.lock().unwrap();
41
42 #[derive(Clone)]
43 struct StateDiff {
44 pre_hi: [u8; KECCAK_RATE_U16S],
46 post_hi: [u8; KECCAK_RATE_U16S],
48 register_reads: Option<[RecordId; KECCAK_REGISTER_READS]>,
50 digest_writes: Option<[RecordId; KECCAK_DIGEST_WRITES]>,
52 }
53
54 impl Default for StateDiff {
55 fn default() -> Self {
56 Self {
57 pre_hi: [0; KECCAK_RATE_U16S],
58 post_hi: [0; KECCAK_RATE_U16S],
59 register_reads: None,
60 digest_writes: None,
61 }
62 }
63 }
64
65 let mut state: [u64; 25];
67 for record in records {
68 let dst_read = memory.record_by_id(record.dst_read);
69 let src_read = memory.record_by_id(record.src_read);
70 let len_read = memory.record_by_id(record.len_read);
71
72 state = [0u64; 25];
73 let src_limbs: [_; RV32_REGISTER_NUM_LIMBS - 1] = src_read.data_slice()
74 [1..RV32_REGISTER_NUM_LIMBS]
75 .try_into()
76 .unwrap();
77 let len_limbs: [_; RV32_REGISTER_NUM_LIMBS - 1] = len_read.data_slice()
78 [1..RV32_REGISTER_NUM_LIMBS]
79 .try_into()
80 .unwrap();
81 let mut instruction = KeccakInstructionCols {
82 pc: record.pc,
83 is_enabled: Val::<SC>::ONE,
84 is_enabled_first_round: Val::<SC>::ZERO,
85 start_timestamp: Val::<SC>::from_canonical_u32(dst_read.timestamp),
86 dst_ptr: dst_read.pointer,
87 src_ptr: src_read.pointer,
88 len_ptr: len_read.pointer,
89 dst: dst_read.data_slice().try_into().unwrap(),
90 src_limbs,
91 src: Val::<SC>::from_canonical_usize(record.input_blocks[0].src),
92 len_limbs,
93 remaining_len: Val::<SC>::from_canonical_usize(
94 record.input_blocks[0].remaining_len,
95 ),
96 };
97 let num_blocks = record.input_blocks.len();
98 for (idx, block) in record.input_blocks.into_iter().enumerate() {
99 for (bytes, s) in block.padded_bytes.chunks_exact(8).zip(state.iter_mut()) {
101 for (i, &byte) in bytes.iter().enumerate() {
103 let s_byte = (*s >> (i * 8)) as u8;
104 if idx != 0 {
106 self.bitwise_lookup_chip
107 .request_xor(byte as u32, s_byte as u32);
108 }
109 *s ^= (byte as u64) << (i * 8);
110 }
111 }
112 let pre_hi: [u8; KECCAK_RATE_U16S] =
113 from_fn(|i| (state[i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8)) as u8);
114 states.push(state);
115 keccakf(&mut state);
116 let post_hi: [u8; KECCAK_RATE_U16S] =
117 from_fn(|i| (state[i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8)) as u8);
118 if idx == num_blocks - 1 {
120 for s in state.into_iter().take(NUM_ABSORB_ROUNDS) {
121 for s_byte in s.to_le_bytes() {
122 self.bitwise_lookup_chip.request_xor(0, s_byte as u32);
123 }
124 }
125 }
126 let register_reads =
127 (idx == 0).then_some([record.dst_read, record.src_read, record.len_read]);
128 let digest_writes = (idx == num_blocks - 1).then_some(record.digest_writes);
129 let diff = StateDiff {
130 pre_hi,
131 post_hi,
132 register_reads,
133 digest_writes,
134 };
135 instruction_blocks.push((instruction, diff, block));
136 instruction.remaining_len -= Val::<SC>::from_canonical_usize(KECCAK_RATE_BYTES);
137 instruction.src += Val::<SC>::from_canonical_usize(KECCAK_RATE_BYTES);
138 instruction.start_timestamp +=
139 Val::<SC>::from_canonical_usize(KECCAK_REGISTER_READS + KECCAK_ABSORB_READS);
140 }
141 }
142
143 let p3_states = states
147 .into_iter()
148 .map(|state| {
149 from_fn(|i| {
151 let x = i / 5;
152 let y = i % 5;
153 state[x + 5 * y]
154 })
155 })
156 .collect();
157 let p3_keccak_trace: RowMajorMatrix<Val<SC>> = generate_trace_rows(p3_states, 0);
158 let num_rows = p3_keccak_trace.height();
159 let num_blocks = num_rows.div_ceil(NUM_ROUNDS);
161 instruction_blocks.resize(num_blocks, Default::default());
163
164 let aux_cols_factory = memory.aux_cols_factory();
165
166 let mut trace =
168 RowMajorMatrix::new(Val::<SC>::zero_vec(num_rows * trace_width), trace_width);
169 let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.ptr_max_bits;
170
171 trace
172 .values
173 .par_chunks_mut(trace_width * NUM_ROUNDS)
174 .zip(
175 p3_keccak_trace
176 .values
177 .par_chunks(NUM_KECCAK_PERM_COLS * NUM_ROUNDS),
178 )
179 .zip(instruction_blocks.into_par_iter())
180 .for_each(|((rows, p3_keccak_mat), (instruction, diff, block))| {
181 let height = rows.len() / trace_width;
182 for (row, p3_keccak_row) in rows
183 .chunks_exact_mut(trace_width)
184 .zip(p3_keccak_mat.chunks_exact(NUM_KECCAK_PERM_COLS))
185 {
186 row[..NUM_KECCAK_PERM_COLS].copy_from_slice(p3_keccak_row);
188 let row_mut: &mut KeccakVmCols<Val<SC>> = row.borrow_mut();
189 row_mut.instruction = instruction;
190
191 row_mut.sponge.block_bytes =
192 block.padded_bytes.map(Val::<SC>::from_canonical_u8);
193 if let Some(partial_read_idx) = block.partial_read_idx {
194 let partial_read = memory.record_by_id(block.reads[partial_read_idx]);
195 row_mut
196 .mem_oc
197 .partial_block
198 .copy_from_slice(&partial_read.data_slice()[1..]);
199 }
200 for (i, is_padding) in row_mut.sponge.is_padding_byte.iter_mut().enumerate() {
201 *is_padding = Val::<SC>::from_bool(i >= block.remaining_len);
202 }
203 }
204 let first_row: &mut KeccakVmCols<Val<SC>> = rows[..trace_width].borrow_mut();
205 first_row.sponge.is_new_start = Val::<SC>::from_bool(block.is_new_start);
206 first_row.sponge.state_hi = diff.pre_hi.map(Val::<SC>::from_canonical_u8);
207 first_row.instruction.is_enabled_first_round = first_row.instruction.is_enabled;
208 if let Some(register_reads) = diff.register_reads {
210 let need_range_check = [
211 ®ister_reads[0], ®ister_reads[1], ®ister_reads[2], ®ister_reads[2],
215 ]
216 .map(|r| {
217 memory
218 .record_by_id(*r)
219 .data_slice()
220 .last()
221 .unwrap()
222 .as_canonical_u32()
223 });
224 for bytes in need_range_check.chunks(2) {
225 self.bitwise_lookup_chip.request_range(
226 bytes[0] << limb_shift_bits,
227 bytes[1] << limb_shift_bits,
228 );
229 }
230 for (i, id) in register_reads.into_iter().enumerate() {
231 aux_cols_factory.generate_read_aux(
232 memory.record_by_id(id),
233 &mut first_row.mem_oc.register_aux[i],
234 );
235 }
236 }
237 for (i, id) in block.reads.into_iter().enumerate() {
238 aux_cols_factory.generate_read_aux(
239 memory.record_by_id(id),
240 &mut first_row.mem_oc.absorb_reads[i],
241 );
242 }
243
244 let last_row: &mut KeccakVmCols<Val<SC>> =
245 rows[(height - 1) * trace_width..].borrow_mut();
246 last_row.sponge.state_hi = diff.post_hi.map(Val::<SC>::from_canonical_u8);
247 last_row.inner.export = instruction.is_enabled
248 * Val::<SC>::from_bool(block.remaining_len < KECCAK_RATE_BYTES);
249 if let Some(digest_writes) = diff.digest_writes {
250 for (i, record_id) in digest_writes.into_iter().enumerate() {
251 let record = memory.record_by_id(record_id);
252 aux_cols_factory
253 .generate_write_aux(record, &mut last_row.mem_oc.digest_writes[i]);
254 }
255 }
256 });
257
258 AirProofInput::simple_no_pis(trace)
259 }
260}
261
262impl<F: PrimeField32> ChipUsageGetter for KeccakVmChip<F> {
263 fn air_name(&self) -> String {
264 get_air_name(&self.air)
265 }
266 fn current_trace_height(&self) -> usize {
267 let num_blocks: usize = self.records.iter().map(|r| r.input_blocks.len()).sum();
268 num_blocks * NUM_ROUNDS
269 }
270
271 fn trace_width(&self) -> usize {
272 BaseAir::<F>::width(&self.air)
273 }
274}