openvm_ecc_circuit/extension/
cuda.rs

1use openvm_algebra_circuit::Rv32ModularGpuBuilder;
2use openvm_circuit::{
3    arch::{
4        AirInventory, ChipInventory, ChipInventoryError, DenseRecordArena, VmBuilder,
5        VmChipComplex, VmProverExtension,
6    },
7    system::cuda::{
8        extensions::{get_inventory_range_checker, get_or_create_bitwise_op_lookup},
9        SystemChipInventoryGPU,
10    },
11};
12use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine, prover_backend::GpuBackend};
13use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
14use openvm_instructions::LocalOpcode;
15use openvm_mod_circuit_builder::ExprBuilderConfig;
16use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config;
17use strum::EnumCount;
18
19use crate::{
20    Rv32WeierstrassConfig, WeierstrassAddNeChipGpu, WeierstrassAir, WeierstrassDoubleChipGpu,
21    WeierstrassExtension,
22};
23
/// Marker type whose `VmProverExtension` implementation adds the GPU Weierstrass chips
/// (add-ne and double) for a `WeierstrassExtension` to a chip inventory.
///
/// The type carries no data; it only selects the trait implementation.
#[derive(Clone, Debug, Default)]
pub struct EccGpuProverExt;
26
27// This implementation is specific to GpuBackend because the lookup chips
28// (VariableRangeCheckerChipGPU, BitwiseOperationLookupChipGPU) are specific to GpuBackend.
29impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, WeierstrassExtension>
30    for EccGpuProverExt
31{
32    fn extend_prover(
33        &self,
34        extension: &WeierstrassExtension,
35        inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
36    ) -> Result<(), ChipInventoryError> {
37        let pointer_max_bits = inventory.airs().pointer_max_bits();
38        let timestamp_max_bits = inventory.timestamp_max_bits();
39
40        // Range checker should always exist in inventory
41        let range_checker = get_inventory_range_checker(inventory);
42
43        let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?;
44
45        for (i, curve) in extension.supported_curves.iter().enumerate() {
46            let start_offset =
47                Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT;
48            let bytes = curve.modulus.bits().div_ceil(8);
49
50            if bytes <= 32 {
51                let config = ExprBuilderConfig {
52                    modulus: curve.modulus.clone(),
53                    num_limbs: 32,
54                    limb_bits: 8,
55                };
56
57                inventory.next_air::<WeierstrassAir<2, 2, 32>>()?;
58                let addne = WeierstrassAddNeChipGpu::<2, 32>::new(
59                    range_checker.clone(),
60                    bitwise_lu.clone(),
61                    config.clone(),
62                    start_offset,
63                    pointer_max_bits as u32,
64                    timestamp_max_bits as u32,
65                );
66                inventory.add_executor_chip(addne);
67
68                inventory.next_air::<WeierstrassAir<1, 2, 32>>()?;
69                let double = WeierstrassDoubleChipGpu::<2, 32>::new(
70                    range_checker.clone(),
71                    bitwise_lu.clone(),
72                    config,
73                    start_offset,
74                    curve.a.clone(),
75                    pointer_max_bits as u32,
76                    timestamp_max_bits as u32,
77                );
78                inventory.add_executor_chip(double);
79            } else if bytes <= 48 {
80                let config = ExprBuilderConfig {
81                    modulus: curve.modulus.clone(),
82                    num_limbs: 48,
83                    limb_bits: 8,
84                };
85
86                inventory.next_air::<WeierstrassAir<2, 6, 16>>()?;
87                let addne = WeierstrassAddNeChipGpu::<6, 16>::new(
88                    range_checker.clone(),
89                    bitwise_lu.clone(),
90                    config.clone(),
91                    start_offset,
92                    pointer_max_bits as u32,
93                    timestamp_max_bits as u32,
94                );
95                inventory.add_executor_chip(addne);
96
97                inventory.next_air::<WeierstrassAir<1, 6, 16>>()?;
98                let double = WeierstrassDoubleChipGpu::<6, 16>::new(
99                    range_checker.clone(),
100                    bitwise_lu.clone(),
101                    config,
102                    start_offset,
103                    curve.a.clone(),
104                    pointer_max_bits as u32,
105                    timestamp_max_bits as u32,
106                );
107                inventory.add_executor_chip(double);
108            } else {
109                panic!("Modulus too large");
110            }
111        }
112
113        Ok(())
114    }
115}
116
/// Marker type whose `VmBuilder` implementation builds the GPU chip complex for an
/// `Rv32WeierstrassConfig`.
///
/// The type carries no data; it only selects the trait implementation.
#[derive(Clone, Debug, Default)]
pub struct Rv32WeierstrassGpuBuilder;
119
120type E = GpuBabyBearPoseidon2Engine;
121
122impl VmBuilder<E> for Rv32WeierstrassGpuBuilder {
123    type VmConfig = Rv32WeierstrassConfig;
124    type SystemChipInventory = SystemChipInventoryGPU;
125    type RecordArena = DenseRecordArena;
126
127    fn create_chip_complex(
128        &self,
129        config: &Rv32WeierstrassConfig,
130        circuit: AirInventory<BabyBearPoseidon2Config>,
131    ) -> Result<
132        VmChipComplex<
133            BabyBearPoseidon2Config,
134            Self::RecordArena,
135            GpuBackend,
136            Self::SystemChipInventory,
137        >,
138        ChipInventoryError,
139    > {
140        let mut chip_complex =
141            VmBuilder::<E>::create_chip_complex(&Rv32ModularGpuBuilder, &config.modular, circuit)?;
142        let inventory = &mut chip_complex.inventory;
143        VmProverExtension::<E, _, _>::extend_prover(
144            &EccGpuProverExt,
145            &config.weierstrass,
146            inventory,
147        )?;
148
149        Ok(chip_complex)
150    }
151}