// openvm_sha256_circuit/sha256_chip/cuda.rs
use std::{iter::repeat_n, sync::Arc};
4
5use derive_new::new;
6use openvm_circuit::{
7 arch::{DenseRecordArena, MultiRowLayout, RecordSeeker},
8 utils::next_power_of_two_or_zero,
9};
10use openvm_circuit_primitives::{
11 bitwise_op_lookup::BitwiseOperationLookupChipGPU, var_range::VariableRangeCheckerChipGPU,
12};
13use openvm_cuda_backend::{
14 base::DeviceMatrix, chip::get_empty_air_proving_ctx, prelude::F, prover_backend::GpuBackend,
15};
16use openvm_cuda_common::{copy::MemCopyH2D, d_buffer::DeviceBuffer};
17use openvm_instructions::riscv::RV32_CELL_BITS;
18use openvm_sha256_air::{get_sha256_num_blocks, SHA256_HASH_WORDS, SHA256_ROWS_PER_BLOCK};
19use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
20
21use crate::{
22 cuda_abi::sha256::{
23 sha256_fill_invalid_rows, sha256_first_pass_tracegen, sha256_hash_computation,
24 sha256_second_pass_dependencies,
25 },
26 Sha256VmMetadata, Sha256VmRecordMut, SHA256VM_WIDTH,
27};
28
/// GPU-backed SHA-256 VM chip: owns the shared lookup chips and the
/// configuration bits needed to generate the SHA-256 trace on device.
/// `#[derive(new)]` provides a positional `Sha256VmChipGpu::new(..)` constructor.
#[derive(new)]
pub struct Sha256VmChipGpu {
    // Shared range-checker chip; its `count` buffer is handed to the
    // first-pass tracegen kernel so lookups are tallied on device.
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    // Shared bitwise-operation lookup chip (RV32_CELL_BITS-wide limbs);
    // its `count` buffer is likewise updated by the tracegen kernel.
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    // Maximum number of bits in a memory pointer; forwarded to tracegen
    // for range-checking pointer limbs.
    pub ptr_max_bits: u32,
    // Maximum number of bits in a timestamp; forwarded to tracegen.
    pub timestamp_max_bits: u32,
}
37
impl Chip<DenseRecordArena, GpuBackend> for Sha256VmChipGpu {
    /// Generates the AIR proving context for all SHA-256 records in `arena`.
    ///
    /// Pipeline:
    /// 1. Walk the dense record arena on the host to index every record
    ///    (byte offset, starting block, and block -> record mapping).
    /// 2. Copy records + indices to device and run the hash-computation
    ///    kernel to produce each block's previous-hash state.
    /// 3. Run first-pass tracegen, fill padding rows past `rows_used`, then
    ///    resolve cross-row dependencies in a second pass.
    fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
        let records = arena.allocated_mut();
        // No records executed: emit an empty proving context (zero-height trace).
        if records.is_empty() {
            return get_empty_air_proving_ctx::<GpuBackend>();
        }

        // Host-side indices built while scanning the arena:
        // - record_offsets[i]: byte offset of record i within `records`
        // - block_offsets[i]: index of record i's first SHA-256 block
        // - block_to_record_idx[b]: which record block b belongs to
        let mut record_offsets = Vec::<usize>::new();
        let mut block_to_record_idx = Vec::<u32>::new();
        let mut block_offsets = Vec::<u32>::new();
        let mut offset_so_far = 0;
        let mut num_blocks_so_far: u32 = 0;

        // Records are variable-length, so we must walk them sequentially;
        // `get_record_at` advances `offset_so_far` past the record it reads.
        while offset_so_far < records.len() {
            record_offsets.push(offset_so_far);
            block_offsets.push(num_blocks_so_far);

            let record = RecordSeeker::<
                DenseRecordArena,
                Sha256VmRecordMut,
                MultiRowLayout<Sha256VmMetadata>,
            >::get_record_at(&mut offset_so_far, records);

            // Number of 512-bit SHA-256 blocks this message occupies
            // (includes padding, per `get_sha256_num_blocks`).
            let num_blocks = get_sha256_num_blocks(record.inner.len);
            let record_idx = record_offsets.len() - 1;

            // Map each of this record's blocks back to the record index.
            block_to_record_idx.extend(repeat_n(record_idx as u32, num_blocks as usize));
            num_blocks_so_far += num_blocks;
        }

        // Sanity checks: the scan must have consumed the arena exactly and
        // produced mutually consistent indices.
        assert_eq!(num_blocks_so_far as usize, block_to_record_idx.len());
        assert_eq!(offset_so_far, records.len());
        assert_eq!(block_offsets.len(), record_offsets.len());

        // Copy raw records and index tables host -> device.
        let d_records = records.to_device().unwrap();
        let d_record_offsets = record_offsets.to_device().unwrap();
        let d_block_offsets = block_offsets.to_device().unwrap();
        let d_block_to_record_idx = block_to_record_idx.to_device().unwrap();

        // Device scratch buffer: one SHA256_HASH_WORDS-word hash state per block,
        // written by the hash kernel and consumed by first-pass tracegen.
        let d_prev_hashes = DeviceBuffer::<u32>::with_capacity(
            num_blocks_so_far as usize * SHA256_HASH_WORDS, );

        // SAFETY: FFI launch of a CUDA kernel. Buffer lengths match the
        // counts passed alongside them (`record_offsets.len()` records,
        // `num_blocks_so_far` blocks), as asserted above; `d_prev_hashes`
        // has capacity for one hash state per block. Kernel-side contract
        // is defined in `cuda_abi::sha256` — not visible here.
        unsafe {
            sha256_hash_computation(
                &d_records,
                record_offsets.len(),
                &d_record_offsets,
                &d_block_offsets,
                &d_prev_hashes,
                num_blocks_so_far,
            )
            .expect("Hash computation kernel failed");
        }

        // Each block expands to SHA256_ROWS_PER_BLOCK trace rows; the trace
        // height is padded to a power of two (zero if no rows).
        let rows_used = num_blocks_so_far as usize * SHA256_ROWS_PER_BLOCK;
        let trace_height = next_power_of_two_or_zero(rows_used);
        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, SHA256VM_WIDTH);

        // SAFETY: FFI kernel launch; trace buffer was allocated with
        // `trace_height` x SHA256VM_WIDTH capacity, and the index/hash
        // buffers are sized per the counts passed (asserted above). The
        // lookup-count buffers come from the shared chips and are updated
        // atomically on device — presumed per the kernel's contract.
        unsafe {
            sha256_first_pass_tracegen(
                d_trace.buffer(),
                trace_height,
                &d_records,
                record_offsets.len(),
                &d_record_offsets,
                &d_block_offsets,
                &d_block_to_record_idx,
                num_blocks_so_far,
                &d_prev_hashes,
                self.ptr_max_bits,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                RV32_CELL_BITS as u32,
                self.timestamp_max_bits,
            )
            .expect("First pass trace generation failed");
        }

        // SAFETY: FFI kernel launch over the same trace buffer; fills the
        // padding rows in [rows_used, trace_height) with valid dummy rows.
        unsafe {
            sha256_fill_invalid_rows(d_trace.buffer(), trace_height, rows_used)
                .expect("Invalid rows filling failed");
        }

        // SAFETY: FFI kernel launch; second pass only reads/writes within
        // the `trace_height`-row buffer produced by the passes above. Must
        // run after first-pass tracegen and row filling (row-order data
        // dependencies resolved here).
        unsafe {
            sha256_second_pass_dependencies(d_trace.buffer(), trace_height, rows_used)
                .expect("Second pass trace generation failed");
        }

        // No public values for this AIR.
        AirProvingContext::simple_no_pis(d_trace)
    }
}
129}