openvm_algebra_circuit/modular_chip/cuda/
is_eq.rs1use std::sync::Arc;
2
3use derive_new::new;
4use num_bigint::BigUint;
5use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero};
6use openvm_circuit_primitives::{
7 bigint::utils::big_uint_to_limbs, bitwise_op_lookup::BitwiseOperationLookupChipGPU,
8 var_range::VariableRangeCheckerChipGPU,
9};
10use openvm_cuda_backend::{
11 base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
12};
13use openvm_cuda_common::copy::MemCopyH2D;
14use openvm_instructions::riscv::RV32_CELL_BITS;
15use openvm_rv32_adapters::{Rv32IsEqualModAdapterCols, Rv32IsEqualModAdapterRecord};
16use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
17
18use super::cuda_abi::is_eq_cuda::tracegen as modular_is_equal_tracegen;
19use crate::modular_chip::{ModularIsEqualCoreCols, ModularIsEqualRecord};
20
21#[derive(new)]
22pub struct ModularIsEqualChipGpu<
23 const NUM_LANES: usize,
24 const LANE_SIZE: usize,
25 const TOTAL_LIMBS: usize,
26> {
27 pub range_checker: Arc<VariableRangeCheckerChipGPU>,
28 pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
29 pub modulus: BigUint,
30 pub pointer_max_bits: u32,
31 pub timestamp_max_bits: u32,
32}
33
34impl<const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_LIMBS: usize>
35 Chip<DenseRecordArena, GpuBackend>
36 for ModularIsEqualChipGpu<NUM_LANES, LANE_SIZE, TOTAL_LIMBS>
37{
38 fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
39 const LIMB_BITS: usize = 8;
40
41 let record_size = size_of::<(
42 Rv32IsEqualModAdapterRecord<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>,
43 ModularIsEqualRecord<TOTAL_LIMBS>,
44 )>();
45
46 let records = arena.allocated();
47 if records.is_empty() {
48 return get_empty_air_proving_ctx::<GpuBackend>();
49 }
50 debug_assert_eq!(records.len() % record_size, 0);
51
52 let trace_width = Rv32IsEqualModAdapterCols::<F, 2, NUM_LANES, LANE_SIZE>::width()
53 + ModularIsEqualCoreCols::<F, TOTAL_LIMBS>::width();
54 let trace_height = next_power_of_two_or_zero(records.len() / record_size);
55
56 let modulus_vec = big_uint_to_limbs(&self.modulus, LIMB_BITS);
57 assert!(modulus_vec.len() <= TOTAL_LIMBS);
58 let mut modulus_limbs = vec![0u8; TOTAL_LIMBS];
59 for (i, &limb) in modulus_vec.iter().enumerate() {
60 modulus_limbs[i] = limb as u8;
61 }
62
63 let d_records = records.to_device().unwrap();
64 let d_modulus = modulus_limbs.to_device().unwrap();
65 let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);
66
67 unsafe {
68 modular_is_equal_tracegen(
69 d_trace.buffer(),
70 trace_height,
71 &d_records,
72 &d_modulus,
73 TOTAL_LIMBS,
74 NUM_LANES,
75 LANE_SIZE,
76 &self.range_checker.count,
77 &self.bitwise_lookup.count,
78 self.pointer_max_bits,
79 self.timestamp_max_bits,
80 )
81 .unwrap();
82 }
83
84 AirProvingContext::simple_no_pis(d_trace)
85 }
86}