openvm_ecc_circuit/extension/
cuda.rs1use openvm_algebra_circuit::Rv32ModularGpuBuilder;
2use openvm_circuit::{
3 arch::{
4 AirInventory, ChipInventory, ChipInventoryError, DenseRecordArena, VmBuilder,
5 VmChipComplex, VmProverExtension,
6 },
7 system::cuda::{
8 extensions::{get_inventory_range_checker, get_or_create_bitwise_op_lookup},
9 SystemChipInventoryGPU,
10 },
11};
12use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine, prover_backend::GpuBackend};
13use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
14use openvm_instructions::LocalOpcode;
15use openvm_mod_circuit_builder::ExprBuilderConfig;
16use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config;
17use strum::EnumCount;
18
19use crate::{
20 Rv32WeierstrassConfig, WeierstrassAddNeChipGpu, WeierstrassAir, WeierstrassDoubleChipGpu,
21 WeierstrassExtension,
22};
23
/// Stateless marker type that implements [`VmProverExtension`] for
/// [`WeierstrassExtension`] on the GPU engine ([`GpuBabyBearPoseidon2Engine`]).
///
/// It carries no data; all inputs come from the extension and the chip inventory passed to
/// `extend_prover`.
#[derive(Clone, Debug, Default)]
pub struct EccGpuProverExt;
26
27impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, WeierstrassExtension>
30 for EccGpuProverExt
31{
32 fn extend_prover(
33 &self,
34 extension: &WeierstrassExtension,
35 inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
36 ) -> Result<(), ChipInventoryError> {
37 let pointer_max_bits = inventory.airs().pointer_max_bits();
38 let timestamp_max_bits = inventory.timestamp_max_bits();
39
40 let range_checker = get_inventory_range_checker(inventory);
42
43 let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?;
44
45 for (i, curve) in extension.supported_curves.iter().enumerate() {
46 let start_offset =
47 Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT;
48 let bytes = curve.modulus.bits().div_ceil(8);
49
50 if bytes <= 32 {
51 let config = ExprBuilderConfig {
52 modulus: curve.modulus.clone(),
53 num_limbs: 32,
54 limb_bits: 8,
55 };
56
57 inventory.next_air::<WeierstrassAir<2, 2, 32>>()?;
58 let addne = WeierstrassAddNeChipGpu::<2, 32>::new(
59 range_checker.clone(),
60 bitwise_lu.clone(),
61 config.clone(),
62 start_offset,
63 pointer_max_bits as u32,
64 timestamp_max_bits as u32,
65 );
66 inventory.add_executor_chip(addne);
67
68 inventory.next_air::<WeierstrassAir<1, 2, 32>>()?;
69 let double = WeierstrassDoubleChipGpu::<2, 32>::new(
70 range_checker.clone(),
71 bitwise_lu.clone(),
72 config,
73 start_offset,
74 curve.a.clone(),
75 pointer_max_bits as u32,
76 timestamp_max_bits as u32,
77 );
78 inventory.add_executor_chip(double);
79 } else if bytes <= 48 {
80 let config = ExprBuilderConfig {
81 modulus: curve.modulus.clone(),
82 num_limbs: 48,
83 limb_bits: 8,
84 };
85
86 inventory.next_air::<WeierstrassAir<2, 6, 16>>()?;
87 let addne = WeierstrassAddNeChipGpu::<6, 16>::new(
88 range_checker.clone(),
89 bitwise_lu.clone(),
90 config.clone(),
91 start_offset,
92 pointer_max_bits as u32,
93 timestamp_max_bits as u32,
94 );
95 inventory.add_executor_chip(addne);
96
97 inventory.next_air::<WeierstrassAir<1, 6, 16>>()?;
98 let double = WeierstrassDoubleChipGpu::<6, 16>::new(
99 range_checker.clone(),
100 bitwise_lu.clone(),
101 config,
102 start_offset,
103 curve.a.clone(),
104 pointer_max_bits as u32,
105 timestamp_max_bits as u32,
106 );
107 inventory.add_executor_chip(double);
108 } else {
109 panic!("Modulus too large");
110 }
111 }
112
113 Ok(())
114 }
115}
116
/// Stateless GPU [`VmBuilder`] for [`Rv32WeierstrassConfig`].
///
/// Its `create_chip_complex` builds the chip complex for the config's `modular` part with
/// [`Rv32ModularGpuBuilder`] and then extends it with [`EccGpuProverExt`].
#[derive(Clone, Debug, Default)]
pub struct Rv32WeierstrassGpuBuilder;
119
120type E = GpuBabyBearPoseidon2Engine;
121
122impl VmBuilder<E> for Rv32WeierstrassGpuBuilder {
123 type VmConfig = Rv32WeierstrassConfig;
124 type SystemChipInventory = SystemChipInventoryGPU;
125 type RecordArena = DenseRecordArena;
126
127 fn create_chip_complex(
128 &self,
129 config: &Rv32WeierstrassConfig,
130 circuit: AirInventory<BabyBearPoseidon2Config>,
131 ) -> Result<
132 VmChipComplex<
133 BabyBearPoseidon2Config,
134 Self::RecordArena,
135 GpuBackend,
136 Self::SystemChipInventory,
137 >,
138 ChipInventoryError,
139 > {
140 let mut chip_complex =
141 VmBuilder::<E>::create_chip_complex(&Rv32ModularGpuBuilder, &config.modular, circuit)?;
142 let inventory = &mut chip_complex.inventory;
143 VmProverExtension::<E, _, _>::extend_prover(
144 &EccGpuProverExt,
145 &config.weierstrass,
146 inventory,
147 )?;
148
149 Ok(chip_complex)
150 }
151}