1use std::{
2 array::{self, from_fn},
3 borrow::{Borrow, BorrowMut},
4 cmp::min,
5};
6
7use openvm_circuit::{
8 arch::*,
9 system::memory::{
10 offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord},
11 online::TracingMemory,
12 MemoryAuxColsFactory,
13 },
14};
15use openvm_circuit_primitives::AlignedBytesBorrow;
16use openvm_instructions::{
17 instruction::Instruction,
18 program::DEFAULT_PC_STEP,
19 riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
20 LocalOpcode,
21};
22use openvm_keccak256_transpiler::Rv32KeccakOpcode;
23use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write};
24use openvm_stark_backend::{
25 p3_field::PrimeField32,
26 p3_matrix::{dense::RowMajorMatrix, Matrix},
27 p3_maybe_rayon::prelude::*,
28};
29use p3_keccak_air::{
30 generate_trace_rows, NUM_KECCAK_COLS as NUM_KECCAK_PERM_COLS, NUM_ROUNDS, U64_LIMBS,
31};
32use tiny_keccak::keccakf;
33
34use super::{
35 columns::KeccakVmCols, KECCAK_ABSORB_READS, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES,
36 KECCAK_REGISTER_READS, NUM_ABSORB_ROUNDS,
37};
38use crate::{
39 columns::NUM_KECCAK_VM_COLS,
40 utils::{keccak256, keccak_f, num_keccak_f},
41 KeccakVmExecutor, KeccakVmFiller, KECCAK_DIGEST_BYTES, KECCAK_RATE_U16S, KECCAK_WORD_SIZE,
42};
43
44#[derive(Clone, Copy)]
45pub struct KeccakVmMetadata {
46 pub len: usize,
47}
48
49impl MultiRowMetadata for KeccakVmMetadata {
50 #[inline(always)]
51 fn get_num_rows(&self) -> usize {
52 num_keccak_f(self.len) * NUM_ROUNDS
53 }
54}
55
56pub(crate) type KeccakVmRecordLayout = MultiRowLayout<KeccakVmMetadata>;
57
58#[repr(C)]
59#[derive(AlignedBytesBorrow, Debug, Clone)]
60pub struct KeccakVmRecordHeader {
61 pub from_pc: u32,
62 pub timestamp: u32,
63 pub rd_ptr: u32,
64 pub rs1_ptr: u32,
65 pub rs2_ptr: u32,
66 pub dst: u32,
67 pub src: u32,
68 pub len: u32,
69
70 pub register_reads_aux: [MemoryReadAuxRecord; KECCAK_REGISTER_READS],
71 pub write_aux: [MemoryWriteBytesAuxRecord<KECCAK_WORD_SIZE>; KECCAK_DIGEST_WRITES],
72}
73
74pub struct KeccakVmRecordMut<'a> {
75 pub inner: &'a mut KeccakVmRecordHeader,
76 pub input: &'a mut [u8],
78 pub read_aux: &'a mut [MemoryReadAuxRecord],
79}
80
81impl<'a> CustomBorrow<'a, KeccakVmRecordMut<'a>, KeccakVmRecordLayout> for [u8] {
87 fn custom_borrow(&'a mut self, layout: KeccakVmRecordLayout) -> KeccakVmRecordMut<'a> {
88 let (record_buf, rest) =
92 unsafe { self.split_at_mut_unchecked(size_of::<KeccakVmRecordHeader>()) };
93
94 let num_reads = layout.metadata.len.div_ceil(KECCAK_WORD_SIZE);
95 let (input, rest) = unsafe { rest.split_at_mut_unchecked(num_reads * KECCAK_WORD_SIZE) };
101 let (_, read_aux_buf, _) = unsafe { rest.align_to_mut::<MemoryReadAuxRecord>() };
102 KeccakVmRecordMut {
103 inner: record_buf.borrow_mut(),
104 input,
105 read_aux: &mut read_aux_buf[..num_reads],
106 }
107 }
108
109 unsafe fn extract_layout(&self) -> KeccakVmRecordLayout {
110 let header: &KeccakVmRecordHeader = self.borrow();
111 KeccakVmRecordLayout {
112 metadata: KeccakVmMetadata {
113 len: header.len as usize,
114 },
115 }
116 }
117}
118
119impl SizedRecord<KeccakVmRecordLayout> for KeccakVmRecordMut<'_> {
120 fn size(layout: &KeccakVmRecordLayout) -> usize {
121 let num_reads = layout.metadata.len.div_ceil(KECCAK_WORD_SIZE);
122 let mut total_len = size_of::<KeccakVmRecordHeader>();
123 total_len += num_reads * KECCAK_WORD_SIZE;
124 total_len = total_len.next_multiple_of(align_of::<MemoryReadAuxRecord>());
126 total_len += num_reads * size_of::<MemoryReadAuxRecord>();
127 total_len
128 }
129
130 fn alignment(_layout: &KeccakVmRecordLayout) -> usize {
131 align_of::<KeccakVmRecordHeader>()
132 }
133}
134
135impl<F, RA> PreflightExecutor<F, RA> for KeccakVmExecutor
136where
137 F: PrimeField32,
138 for<'buf> RA: RecordArena<'buf, KeccakVmRecordLayout, KeccakVmRecordMut<'buf>>,
139{
140 fn get_opcode_name(&self, _: usize) -> String {
141 format!("{:?}", Rv32KeccakOpcode::KECCAK256)
142 }
143
144 fn execute(
145 &self,
146 state: VmStateMut<F, TracingMemory, RA>,
147 instruction: &Instruction<F>,
148 ) -> Result<(), ExecutionError> {
149 let &Instruction {
150 opcode,
151 a,
152 b,
153 c,
154 d,
155 e,
156 ..
157 } = instruction;
158 debug_assert_eq!(opcode, Rv32KeccakOpcode::KECCAK256.global_opcode());
159 debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
160 debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
161
162 let len = read_rv32_register(state.memory.data(), c.as_canonical_u32()) as usize;
164
165 let num_reads = len.div_ceil(KECCAK_WORD_SIZE);
166 let num_blocks = num_keccak_f(len);
167 let record = state
168 .ctx
169 .alloc(KeccakVmRecordLayout::new(KeccakVmMetadata { len }));
170
171 record.inner.from_pc = *state.pc;
172 record.inner.timestamp = state.memory.timestamp();
173 record.inner.rd_ptr = a.as_canonical_u32();
174 record.inner.rs1_ptr = b.as_canonical_u32();
175 record.inner.rs2_ptr = c.as_canonical_u32();
176
177 record.inner.dst = u32::from_le_bytes(tracing_read(
178 state.memory,
179 RV32_REGISTER_AS,
180 record.inner.rd_ptr,
181 &mut record.inner.register_reads_aux[0].prev_timestamp,
182 ));
183 record.inner.src = u32::from_le_bytes(tracing_read(
184 state.memory,
185 RV32_REGISTER_AS,
186 record.inner.rs1_ptr,
187 &mut record.inner.register_reads_aux[1].prev_timestamp,
188 ));
189 record.inner.len = u32::from_le_bytes(tracing_read(
190 state.memory,
191 RV32_REGISTER_AS,
192 record.inner.rs2_ptr,
193 &mut record.inner.register_reads_aux[2].prev_timestamp,
194 ));
195
196 debug_assert!(record.inner.src as usize + len <= (1 << self.pointer_max_bits));
197 debug_assert!(
198 record.inner.dst as usize + KECCAK_DIGEST_BYTES <= (1 << self.pointer_max_bits)
199 );
200 debug_assert!(record.inner.len < (1 << self.pointer_max_bits));
202
203 for idx in 0..num_reads {
204 if idx % KECCAK_ABSORB_READS == 0 && idx != 0 {
205 state
208 .memory
209 .increment_timestamp_by(KECCAK_REGISTER_READS as u32);
210 }
211 let read = tracing_read::<KECCAK_WORD_SIZE>(
212 state.memory,
213 RV32_MEMORY_AS,
214 record.inner.src + (idx * KECCAK_WORD_SIZE) as u32,
215 &mut record.read_aux[idx].prev_timestamp,
216 );
217 record.input[idx * KECCAK_WORD_SIZE..(idx + 1) * KECCAK_WORD_SIZE]
218 .copy_from_slice(&read);
219 }
220
221 state.memory.timestamp = record.inner.timestamp
223 + (num_blocks * (KECCAK_ABSORB_READS + KECCAK_REGISTER_READS)) as u32;
224
225 let digest = keccak256(&record.input[..len]);
226 for (i, word) in digest.chunks_exact(KECCAK_WORD_SIZE).enumerate() {
227 tracing_write::<KECCAK_WORD_SIZE>(
228 state.memory,
229 RV32_MEMORY_AS,
230 record.inner.dst + (i * KECCAK_WORD_SIZE) as u32,
231 word.try_into().unwrap(),
232 &mut record.inner.write_aux[i].prev_timestamp,
233 &mut record.inner.write_aux[i].prev_data,
234 );
235 }
236
237 state.memory.timestamp = record.inner.timestamp
239 + (len + KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES) as u32;
240 *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
241 Ok(())
242 }
243}
244
245impl<F: PrimeField32> TraceFiller<F> for KeccakVmFiller {
246 fn fill_trace(
247 &self,
248 mem_helper: &MemoryAuxColsFactory<F>,
249 trace_matrix: &mut RowMajorMatrix<F>,
250 rows_used: usize,
251 ) {
252 if rows_used == 0 {
253 return;
254 }
255
256 let mut chunks = Vec::with_capacity(trace_matrix.height() / NUM_ROUNDS);
257 let mut sizes = Vec::with_capacity(trace_matrix.height() / NUM_ROUNDS);
258 let mut trace = &mut trace_matrix.values[..];
259 let mut num_blocks_so_far = 0;
260
261 loop {
264 if num_blocks_so_far * NUM_ROUNDS >= rows_used {
265 chunks.push(trace);
267 sizes.push((0, 0));
268 break;
269 } else {
270 let record: &KeccakVmRecordHeader =
275 unsafe { get_record_from_slice(&mut trace, ()) };
276 let num_blocks = num_keccak_f(record.len as usize);
277 let (chunk, rest) =
278 trace.split_at_mut(NUM_KECCAK_VM_COLS * NUM_ROUNDS * num_blocks);
279 chunks.push(chunk);
280 sizes.push((num_blocks, record.len as usize));
281 num_blocks_so_far += num_blocks;
282 trace = rest;
283 }
284 }
285
286 chunks
291 .par_iter_mut()
292 .zip(sizes.par_iter())
293 .for_each(|(slice, (num_blocks, len))| {
294 if *num_blocks == 0 {
295 let p3_trace: RowMajorMatrix<F> = generate_trace_rows(vec![[0u64; 25]; 1], 0);
300
301 slice
302 .par_chunks_exact_mut(NUM_KECCAK_VM_COLS)
303 .enumerate()
304 .for_each(|(row_idx, row)| {
305 let idx = row_idx % NUM_ROUNDS;
306 row[..NUM_KECCAK_PERM_COLS].copy_from_slice(
307 &p3_trace.values
308 [idx * NUM_KECCAK_PERM_COLS..(idx + 1) * NUM_KECCAK_PERM_COLS],
309 );
310
311 unsafe {
322 std::ptr::write_bytes(
323 row.as_mut_ptr().add(NUM_KECCAK_PERM_COLS) as *mut u8,
324 0,
325 (NUM_KECCAK_VM_COLS - NUM_KECCAK_PERM_COLS) * size_of::<F>(),
326 );
327 }
328 let cols: &mut KeccakVmCols<F> = row.borrow_mut();
329 cols.sponge.is_new_start = F::from_bool(idx == 0);
331 cols.sponge.block_bytes[0] = F::ONE;
332 cols.sponge.block_bytes[KECCAK_RATE_BYTES - 1] =
333 F::from_canonical_u32(0x80);
334 cols.sponge.is_padding_byte = [F::ONE; KECCAK_RATE_BYTES];
335 });
336 return;
337 }
338
339 let num_reads = len.div_ceil(KECCAK_WORD_SIZE);
340 let read_len = num_reads * KECCAK_WORD_SIZE;
341
342 let record: KeccakVmRecordMut = unsafe {
349 get_record_from_slice(
350 slice,
351 KeccakVmRecordLayout::new(KeccakVmMetadata { len: *len }),
352 )
353 };
354
355 let mut read_aux_records = Vec::with_capacity(num_reads);
358 read_aux_records.extend_from_slice(record.read_aux);
359 let vm_record = record.inner.clone();
360 let partial_block = if read_len != *len {
361 record.input[read_len - KECCAK_WORD_SIZE + 1..]
362 .try_into()
363 .unwrap()
364 } else {
365 [0u8; KECCAK_WORD_SIZE - 1]
366 }
367 .map(F::from_canonical_u8);
368 let mut input = Vec::with_capacity(*num_blocks * KECCAK_RATE_BYTES);
369 input.extend_from_slice(&record.input[..*len]);
370 input.push(0x01);
372 input.resize(input.capacity(), 0);
373 *input.last_mut().unwrap() += 0x80;
374
375 let mut states = Vec::with_capacity(*num_blocks);
376 let mut state = [0u64; 25];
377
378 input
379 .chunks_exact(KECCAK_RATE_BYTES)
380 .enumerate()
381 .for_each(|(idx, chunk)| {
382 for (bytes, s) in chunk.chunks_exact(8).zip(state.iter_mut()) {
384 for (i, &byte) in bytes.iter().enumerate() {
386 let s_byte = (*s >> (i * 8)) as u8;
387 if idx != 0 {
389 self.bitwise_lookup_chip
390 .request_xor(byte as u32, s_byte as u32);
391 }
392 *s ^= (byte as u64) << (i * 8);
393 }
394 }
395 states.push(state);
396 keccakf(&mut state);
397 });
398
399 slice
400 .par_chunks_exact_mut(NUM_ROUNDS * NUM_KECCAK_VM_COLS)
401 .enumerate()
402 .for_each(|(block_idx, block_slice)| {
403 let state = from_fn(|i| {
407 let x = i / 5;
408 let y = i % 5;
409 states[block_idx][x + 5 * y]
410 });
411
412 let p3_trace: RowMajorMatrix<F> = generate_trace_rows(vec![state], 0);
417 let input_offset = block_idx * KECCAK_RATE_BYTES;
418 let start_timestamp = vm_record.timestamp
419 + (block_idx * (KECCAK_REGISTER_READS + KECCAK_ABSORB_READS)) as u32;
420 let rem_len = *len - input_offset;
421
422 block_slice
423 .par_chunks_exact_mut(NUM_KECCAK_VM_COLS)
424 .enumerate()
425 .zip(p3_trace.values.par_chunks(NUM_KECCAK_PERM_COLS))
426 .for_each(|((row_idx, row), p3_row)| {
427 row[..NUM_KECCAK_PERM_COLS].copy_from_slice(p3_row);
431
432 let cols: &mut KeccakVmCols<F> = row.borrow_mut();
433 cols.sponge.is_new_start =
435 F::from_bool(block_idx == 0 && row_idx == 0);
436 if rem_len < KECCAK_RATE_BYTES {
437 cols.sponge.is_padding_byte[..rem_len].fill(F::ZERO);
438 cols.sponge.is_padding_byte[rem_len..].fill(F::ONE);
439 } else {
440 cols.sponge.is_padding_byte = [F::ZERO; KECCAK_RATE_BYTES];
441 }
442 cols.sponge.block_bytes = array::from_fn(|i| {
443 F::from_canonical_u8(input[input_offset + i])
444 });
445 if row_idx == 0 {
446 cols.sponge.state_hi = from_fn(|i| {
447 F::from_canonical_u8(
448 (states[block_idx][i / U64_LIMBS]
449 >> ((i % U64_LIMBS) * 16 + 8))
450 as u8,
451 )
452 });
453 } else if row_idx == NUM_ROUNDS - 1 {
454 let state = keccak_f(states[block_idx]);
455 cols.sponge.state_hi = from_fn(|i| {
456 F::from_canonical_u8(
457 (state[i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8))
458 as u8,
459 )
460 });
461 if block_idx == num_blocks - 1 {
462 cols.inner.export = F::ONE;
463 for s in state.into_iter().take(NUM_ABSORB_ROUNDS) {
464 for s_byte in s.to_le_bytes() {
465 self.bitwise_lookup_chip
466 .request_xor(0, s_byte as u32);
467 }
468 }
469 }
470 } else {
471 cols.sponge.state_hi = [F::ZERO; KECCAK_RATE_U16S];
472 }
473
474 cols.instruction.pc = F::from_canonical_u32(vm_record.from_pc);
476 cols.instruction.is_enabled = F::ONE;
477 cols.instruction.is_enabled_first_round =
478 F::from_bool(row_idx == 0);
479 cols.instruction.start_timestamp =
480 F::from_canonical_u32(start_timestamp);
481 cols.instruction.dst_ptr = F::from_canonical_u32(vm_record.rd_ptr);
482 cols.instruction.src_ptr = F::from_canonical_u32(vm_record.rs1_ptr);
483 cols.instruction.len_ptr = F::from_canonical_u32(vm_record.rs2_ptr);
484 cols.instruction.dst =
485 vm_record.dst.to_le_bytes().map(F::from_canonical_u8);
486
487 let src = vm_record.src + (block_idx * KECCAK_RATE_BYTES) as u32;
488 cols.instruction.src = F::from_canonical_u32(src);
489 cols.instruction.src_limbs.copy_from_slice(
490 &src.to_le_bytes().map(F::from_canonical_u8)[1..],
491 );
492 cols.instruction.len_limbs.copy_from_slice(
493 &(rem_len as u32).to_le_bytes().map(F::from_canonical_u8)[1..],
494 );
495 cols.instruction.remaining_len =
496 F::from_canonical_u32(rem_len as u32);
497
498 if row_idx == 0 && block_idx == 0 {
500 for ((i, cols), vm_record) in cols
501 .mem_oc
502 .register_aux
503 .iter_mut()
504 .enumerate()
505 .zip(vm_record.register_reads_aux.iter())
506 {
507 mem_helper.fill(
508 vm_record.prev_timestamp,
509 start_timestamp + i as u32,
510 cols.as_mut(),
511 );
512 }
513
514 let msl_rshift = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
515 let msl_lshift = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS
516 - self.pointer_max_bits;
517 self.bitwise_lookup_chip.request_range(
519 (vm_record.dst >> msl_rshift) << msl_lshift,
520 (vm_record.src >> msl_rshift) << msl_lshift,
521 );
522 self.bitwise_lookup_chip.request_range(
523 (vm_record.len >> msl_rshift) << msl_lshift,
524 (vm_record.len >> msl_rshift) << msl_lshift,
525 );
526 } else {
527 cols.mem_oc.register_aux.par_iter_mut().for_each(|aux| {
528 mem_helper.fill_zero(aux.as_mut());
529 });
530 }
531
532 if row_idx == 0 {
534 let reads_offs = block_idx * KECCAK_ABSORB_READS;
535 let num_reads = min(
536 rem_len.div_ceil(KECCAK_WORD_SIZE),
537 KECCAK_ABSORB_READS,
538 );
539 let start_timestamp =
540 start_timestamp + KECCAK_REGISTER_READS as u32;
541 for i in 0..num_reads {
542 mem_helper.fill(
543 read_aux_records[i + reads_offs].prev_timestamp,
544 start_timestamp + i as u32,
545 cols.mem_oc.absorb_reads[i].as_mut(),
546 );
547 }
548 for i in num_reads..KECCAK_ABSORB_READS {
549 mem_helper.fill_zero(cols.mem_oc.absorb_reads[i].as_mut());
550 }
551 } else {
552 cols.mem_oc.absorb_reads.par_iter_mut().for_each(|aux| {
553 mem_helper.fill_zero(aux.as_mut());
554 });
555 }
556
557 if block_idx == num_blocks - 1 && row_idx == NUM_ROUNDS - 1 {
558 let timestamp = start_timestamp
559 + (KECCAK_ABSORB_READS + KECCAK_REGISTER_READS) as u32;
560 cols.mem_oc
561 .digest_writes
562 .par_iter_mut()
563 .enumerate()
564 .zip(vm_record.write_aux.par_iter())
565 .for_each(|((i, cols), vm_record)| {
566 cols.set_prev_data(
567 vm_record.prev_data.map(F::from_canonical_u8),
568 );
569 mem_helper.fill(
570 vm_record.prev_timestamp,
571 timestamp + i as u32,
572 cols.as_mut(),
573 );
574 });
575 } else {
576 cols.mem_oc.digest_writes.par_iter_mut().for_each(|aux| {
577 aux.set_prev_data([F::ZERO; KECCAK_WORD_SIZE]);
578 mem_helper.fill_zero(aux.as_mut());
579 });
580 }
581
582 if block_idx == num_blocks - 1 {
584 cols.mem_oc.partial_block = partial_block;
585 } else {
586 cols.mem_oc.partial_block = [F::ZERO; KECCAK_WORD_SIZE - 1];
587 }
588 });
589 });
590 });
591 }
592}