openvm_keccak256_circuit/cuda/
mod.rs

1use std::{iter::repeat_n, sync::Arc};
2
3use derive_new::new;
4use openvm_circuit::arch::{DenseRecordArena, MultiRowLayout, RecordSeeker};
5use openvm_circuit_primitives::{
6    bitwise_op_lookup::BitwiseOperationLookupChipGPU, utils::next_power_of_two_or_zero,
7    var_range::VariableRangeCheckerChipGPU,
8};
9use openvm_cuda_backend::{
10    base::DeviceMatrix, chip::get_empty_air_proving_ctx, prelude::F, prover_backend::GpuBackend,
11};
12use openvm_cuda_common::{copy::MemCopyH2D, d_buffer::DeviceBuffer};
13use openvm_instructions::riscv::RV32_CELL_BITS;
14use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
15use p3_keccak_air::NUM_ROUNDS;
16
17use crate::{
18    columns::NUM_KECCAK_VM_COLS,
19    trace::{KeccakVmMetadata, KeccakVmRecordMut},
20    utils::num_keccak_f,
21};
22
23mod cuda_abi;
24use cuda_abi::keccak256::*;
25
26#[derive(new)]
27pub struct Keccak256ChipGpu {
28    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
29    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
30    pub ptr_max_bits: u32,
31    pub timestamp_max_bits: u32,
32}
33
34impl Chip<DenseRecordArena, GpuBackend> for Keccak256ChipGpu {
35    fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
36        let records = arena.allocated_mut();
37        if records.is_empty() {
38            return get_empty_air_proving_ctx::<GpuBackend>();
39        }
40
41        let mut record_offsets = Vec::<usize>::new();
42        let mut block_to_record_idx = Vec::<u32>::new();
43        let mut block_offsets = Vec::<u32>::new();
44        let mut offset_so_far = 0;
45        let mut num_blocks_so_far = 0;
46        while offset_so_far < records.len() {
47            record_offsets.push(offset_so_far);
48            block_offsets.push(num_blocks_so_far);
49
50            let record = RecordSeeker::<
51                DenseRecordArena,
52                KeccakVmRecordMut,
53                MultiRowLayout<KeccakVmMetadata>,
54            >::get_record_at(&mut offset_so_far, records);
55
56            let num_blocks = num_keccak_f(record.inner.len as usize);
57            let record_idx = record_offsets.len() - 1;
58            block_to_record_idx.extend(repeat_n(record_idx as u32, num_blocks));
59            num_blocks_so_far += num_blocks as u32;
60        }
61        assert_eq!(num_blocks_so_far as usize, block_to_record_idx.len());
62        assert_eq!(offset_so_far, records.len());
63        assert_eq!(block_offsets.len(), record_offsets.len());
64
65        let records_num = record_offsets.len();
66        let d_records = records.to_device().unwrap();
67        let d_record_offsets = record_offsets.to_device().unwrap();
68        let d_block_offsets = block_offsets.to_device().unwrap();
69        let d_block_to_record_idx = block_to_record_idx.to_device().unwrap();
70
71        let rows_used = num_blocks_so_far as usize * NUM_ROUNDS;
72        let trace_height = next_power_of_two_or_zero(rows_used);
73        let trace = DeviceMatrix::<F>::with_capacity(trace_height, NUM_KECCAK_VM_COLS);
74
75        // We store state + keccakf(state) for each block
76        let states_num = 2 * num_blocks_so_far as usize;
77        let d_states = DeviceBuffer::<u64>::with_capacity(states_num * 25);
78
79        unsafe {
80            keccakf(
81                &d_records,
82                records_num,
83                &d_record_offsets,
84                &d_block_offsets,
85                num_blocks_so_far,
86                &d_states,
87                &self.bitwise_lookup.count,
88                RV32_CELL_BITS,
89            )
90            .unwrap();
91
92            p3_tracegen(trace.buffer(), trace_height, num_blocks_so_far, &d_states).unwrap();
93
94            tracegen(
95                trace.buffer(),
96                trace_height,
97                &d_records,
98                records_num,
99                &d_record_offsets,
100                &d_block_offsets,
101                &d_block_to_record_idx,
102                num_blocks_so_far,
103                &d_states,
104                rows_used,
105                self.ptr_max_bits,
106                &self.range_checker.count,
107                &self.bitwise_lookup.count,
108                RV32_CELL_BITS,
109                self.timestamp_max_bits,
110            )
111            .unwrap();
112        }
113
114        AirProvingContext::simple_no_pis(trace)
115    }
116}