openvm_sha256_circuit/sha256_chip/
cuda.rs

1// crates/tracegen/src/extensions/sha256/mod.rs
2
3use std::{iter::repeat_n, sync::Arc};
4
5use derive_new::new;
6use openvm_circuit::{
7    arch::{DenseRecordArena, MultiRowLayout, RecordSeeker},
8    utils::next_power_of_two_or_zero,
9};
10use openvm_circuit_primitives::{
11    bitwise_op_lookup::BitwiseOperationLookupChipGPU, var_range::VariableRangeCheckerChipGPU,
12};
13use openvm_cuda_backend::{
14    base::DeviceMatrix, chip::get_empty_air_proving_ctx, prelude::F, prover_backend::GpuBackend,
15};
16use openvm_cuda_common::{copy::MemCopyH2D, d_buffer::DeviceBuffer};
17use openvm_instructions::riscv::RV32_CELL_BITS;
18use openvm_sha256_air::{get_sha256_num_blocks, SHA256_HASH_WORDS, SHA256_ROWS_PER_BLOCK};
19use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
20
21use crate::{
22    cuda_abi::sha256::{
23        sha256_fill_invalid_rows, sha256_first_pass_tracegen, sha256_hash_computation,
24        sha256_second_pass_dependencies,
25    },
26    Sha256VmMetadata, Sha256VmRecordMut, SHA256VM_WIDTH,
27};
28
29// ===== SHA256 GPU CHIP IMPLEMENTATION =====
30#[derive(new)]
31pub struct Sha256VmChipGpu {
32    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
33    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
34    pub ptr_max_bits: u32,
35    pub timestamp_max_bits: u32,
36}
37
38impl Chip<DenseRecordArena, GpuBackend> for Sha256VmChipGpu {
39    fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
40        let records = arena.allocated_mut();
41        if records.is_empty() {
42            return get_empty_air_proving_ctx::<GpuBackend>();
43        }
44
45        let mut record_offsets = Vec::<usize>::new();
46        let mut block_to_record_idx = Vec::<u32>::new();
47        let mut block_offsets = Vec::<u32>::new();
48        let mut offset_so_far = 0;
49        let mut num_blocks_so_far: u32 = 0;
50
51        while offset_so_far < records.len() {
52            record_offsets.push(offset_so_far);
53            block_offsets.push(num_blocks_so_far);
54
55            let record = RecordSeeker::<
56                DenseRecordArena,
57                Sha256VmRecordMut,
58                MultiRowLayout<Sha256VmMetadata>,
59            >::get_record_at(&mut offset_so_far, records);
60
61            let num_blocks = get_sha256_num_blocks(record.inner.len);
62            let record_idx = record_offsets.len() - 1;
63
64            block_to_record_idx.extend(repeat_n(record_idx as u32, num_blocks as usize));
65            num_blocks_so_far += num_blocks;
66        }
67
68        assert_eq!(num_blocks_so_far as usize, block_to_record_idx.len());
69        assert_eq!(offset_so_far, records.len());
70        assert_eq!(block_offsets.len(), record_offsets.len());
71
72        let d_records = records.to_device().unwrap();
73        let d_record_offsets = record_offsets.to_device().unwrap();
74        let d_block_offsets = block_offsets.to_device().unwrap();
75        let d_block_to_record_idx = block_to_record_idx.to_device().unwrap();
76
77        let d_prev_hashes = DeviceBuffer::<u32>::with_capacity(
78            num_blocks_so_far as usize * SHA256_HASH_WORDS, // 8 words per SHA256 hash block
79        );
80
81        unsafe {
82            sha256_hash_computation(
83                &d_records,
84                record_offsets.len(),
85                &d_record_offsets,
86                &d_block_offsets,
87                &d_prev_hashes,
88                num_blocks_so_far,
89            )
90            .expect("Hash computation kernel failed");
91        }
92
93        let rows_used = num_blocks_so_far as usize * SHA256_ROWS_PER_BLOCK;
94        let trace_height = next_power_of_two_or_zero(rows_used);
95        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, SHA256VM_WIDTH);
96
97        unsafe {
98            sha256_first_pass_tracegen(
99                d_trace.buffer(),
100                trace_height,
101                &d_records,
102                record_offsets.len(),
103                &d_record_offsets,
104                &d_block_offsets,
105                &d_block_to_record_idx,
106                num_blocks_so_far,
107                &d_prev_hashes,
108                self.ptr_max_bits,
109                &self.range_checker.count,
110                &self.bitwise_lookup.count,
111                RV32_CELL_BITS as u32,
112                self.timestamp_max_bits,
113            )
114            .expect("First pass trace generation failed");
115        }
116
117        unsafe {
118            sha256_fill_invalid_rows(d_trace.buffer(), trace_height, rows_used)
119                .expect("Invalid rows filling failed");
120        }
121
122        unsafe {
123            sha256_second_pass_dependencies(d_trace.buffer(), trace_height, rows_used)
124                .expect("Second pass trace generation failed");
125        }
126
127        AirProvingContext::simple_no_pis(d_trace)
128    }
129}