openvm_algebra_circuit/modular_chip/cuda/
is_eq.rs

1use std::sync::Arc;
2
3use derive_new::new;
4use num_bigint::BigUint;
5use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero};
6use openvm_circuit_primitives::{
7    bigint::utils::big_uint_to_limbs, bitwise_op_lookup::BitwiseOperationLookupChipGPU,
8    var_range::VariableRangeCheckerChipGPU,
9};
10use openvm_cuda_backend::{
11    base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
12};
13use openvm_cuda_common::copy::MemCopyH2D;
14use openvm_instructions::riscv::RV32_CELL_BITS;
15use openvm_rv32_adapters::{Rv32IsEqualModAdapterCols, Rv32IsEqualModAdapterRecord};
16use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
17
18use super::cuda_abi::is_eq_cuda::tracegen as modular_is_equal_tracegen;
19use crate::modular_chip::{ModularIsEqualCoreCols, ModularIsEqualRecord};
20
21#[derive(new)]
22pub struct ModularIsEqualChipGpu<
23    const NUM_LANES: usize,
24    const LANE_SIZE: usize,
25    const TOTAL_LIMBS: usize,
26> {
27    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
28    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
29    pub modulus: BigUint,
30    pub pointer_max_bits: u32,
31    pub timestamp_max_bits: u32,
32}
33
34impl<const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_LIMBS: usize>
35    Chip<DenseRecordArena, GpuBackend>
36    for ModularIsEqualChipGpu<NUM_LANES, LANE_SIZE, TOTAL_LIMBS>
37{
38    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
39        const LIMB_BITS: usize = 8;
40
41        let record_size = size_of::<(
42            Rv32IsEqualModAdapterRecord<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>,
43            ModularIsEqualRecord<TOTAL_LIMBS>,
44        )>();
45
46        let records = arena.allocated();
47        if records.is_empty() {
48            return get_empty_air_proving_ctx::<GpuBackend>();
49        }
50        debug_assert_eq!(records.len() % record_size, 0);
51
52        let trace_width = Rv32IsEqualModAdapterCols::<F, 2, NUM_LANES, LANE_SIZE>::width()
53            + ModularIsEqualCoreCols::<F, TOTAL_LIMBS>::width();
54        let trace_height = next_power_of_two_or_zero(records.len() / record_size);
55
56        let modulus_vec = big_uint_to_limbs(&self.modulus, LIMB_BITS);
57        assert!(modulus_vec.len() <= TOTAL_LIMBS);
58        let mut modulus_limbs = vec![0u8; TOTAL_LIMBS];
59        for (i, &limb) in modulus_vec.iter().enumerate() {
60            modulus_limbs[i] = limb as u8;
61        }
62
63        let d_records = records.to_device().unwrap();
64        let d_modulus = modulus_limbs.to_device().unwrap();
65        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);
66
67        unsafe {
68            modular_is_equal_tracegen(
69                d_trace.buffer(),
70                trace_height,
71                &d_records,
72                &d_modulus,
73                TOTAL_LIMBS,
74                NUM_LANES,
75                LANE_SIZE,
76                &self.range_checker.count,
77                &self.bitwise_lookup.count,
78                self.pointer_max_bits,
79                self.timestamp_max_bits,
80            )
81            .unwrap();
82        }
83
84        AirProvingContext::simple_no_pis(d_trace)
85    }
86}