openvm_keccak256_circuit/cuda/
mod.rs1use std::{iter::repeat_n, sync::Arc};
2
3use derive_new::new;
4use openvm_circuit::arch::{DenseRecordArena, MultiRowLayout, RecordSeeker};
5use openvm_circuit_primitives::{
6 bitwise_op_lookup::BitwiseOperationLookupChipGPU, utils::next_power_of_two_or_zero,
7 var_range::VariableRangeCheckerChipGPU,
8};
9use openvm_cuda_backend::{
10 base::DeviceMatrix, chip::get_empty_air_proving_ctx, prelude::F, prover_backend::GpuBackend,
11};
12use openvm_cuda_common::{copy::MemCopyH2D, d_buffer::DeviceBuffer};
13use openvm_instructions::riscv::RV32_CELL_BITS;
14use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
15use p3_keccak_air::NUM_ROUNDS;
16
17use crate::{
18 columns::NUM_KECCAK_VM_COLS,
19 trace::{KeccakVmMetadata, KeccakVmRecordMut},
20 utils::num_keccak_f,
21};
22
23mod cuda_abi;
24use cuda_abi::keccak256::*;
25
/// GPU trace generator for the Keccak-256 VM chip.
///
/// Holds shared handles to the lookup chips whose device-side `count` buffers are
/// passed to this chip's CUDA kernels, plus the bit widths forwarded to `tracegen`.
/// `#[derive(new)]` generates `new(range_checker, bitwise_lookup, ptr_max_bits,
/// timestamp_max_bits)` with arguments in field order.
#[derive(new)]
pub struct Keccak256ChipGpu {
    /// Shared variable range checker; its `count` buffer is passed to `tracegen`.
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    /// Shared bitwise-operation lookup over `RV32_CELL_BITS`-bit operands; its `count`
    /// buffer is passed to both `keccakf` and `tracegen`.
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    /// Forwarded to `tracegen`; presumably the maximum bit width of memory pointers
    /// (NOTE(review): confirm against the kernel).
    pub ptr_max_bits: u32,
    /// Forwarded to `tracegen`; presumably the maximum bit width of timestamps
    /// (NOTE(review): confirm against the kernel).
    pub timestamp_max_bits: u32,
}
33
34impl Chip<DenseRecordArena, GpuBackend> for Keccak256ChipGpu {
35 fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
36 let records = arena.allocated_mut();
37 if records.is_empty() {
38 return get_empty_air_proving_ctx::<GpuBackend>();
39 }
40
41 let mut record_offsets = Vec::<usize>::new();
42 let mut block_to_record_idx = Vec::<u32>::new();
43 let mut block_offsets = Vec::<u32>::new();
44 let mut offset_so_far = 0;
45 let mut num_blocks_so_far = 0;
46 while offset_so_far < records.len() {
47 record_offsets.push(offset_so_far);
48 block_offsets.push(num_blocks_so_far);
49
50 let record = RecordSeeker::<
51 DenseRecordArena,
52 KeccakVmRecordMut,
53 MultiRowLayout<KeccakVmMetadata>,
54 >::get_record_at(&mut offset_so_far, records);
55
56 let num_blocks = num_keccak_f(record.inner.len as usize);
57 let record_idx = record_offsets.len() - 1;
58 block_to_record_idx.extend(repeat_n(record_idx as u32, num_blocks));
59 num_blocks_so_far += num_blocks as u32;
60 }
61 assert_eq!(num_blocks_so_far as usize, block_to_record_idx.len());
62 assert_eq!(offset_so_far, records.len());
63 assert_eq!(block_offsets.len(), record_offsets.len());
64
65 let records_num = record_offsets.len();
66 let d_records = records.to_device().unwrap();
67 let d_record_offsets = record_offsets.to_device().unwrap();
68 let d_block_offsets = block_offsets.to_device().unwrap();
69 let d_block_to_record_idx = block_to_record_idx.to_device().unwrap();
70
71 let rows_used = num_blocks_so_far as usize * NUM_ROUNDS;
72 let trace_height = next_power_of_two_or_zero(rows_used);
73 let trace = DeviceMatrix::<F>::with_capacity(trace_height, NUM_KECCAK_VM_COLS);
74
75 let states_num = 2 * num_blocks_so_far as usize;
77 let d_states = DeviceBuffer::<u64>::with_capacity(states_num * 25);
78
79 unsafe {
80 keccakf(
81 &d_records,
82 records_num,
83 &d_record_offsets,
84 &d_block_offsets,
85 num_blocks_so_far,
86 &d_states,
87 &self.bitwise_lookup.count,
88 RV32_CELL_BITS,
89 )
90 .unwrap();
91
92 p3_tracegen(trace.buffer(), trace_height, num_blocks_so_far, &d_states).unwrap();
93
94 tracegen(
95 trace.buffer(),
96 trace_height,
97 &d_records,
98 records_num,
99 &d_record_offsets,
100 &d_block_offsets,
101 &d_block_to_record_idx,
102 num_blocks_so_far,
103 &d_states,
104 rows_used,
105 self.ptr_max_bits,
106 &self.range_checker.count,
107 &self.bitwise_lookup.count,
108 RV32_CELL_BITS,
109 self.timestamp_max_bits,
110 )
111 .unwrap();
112 }
113
114 AirProvingContext::simple_no_pis(trace)
115 }
116}