openvm_bigint_circuit/extension/
cuda.rs

1use openvm_circuit::{
2    arch::DenseRecordArena,
3    system::cuda::{
4        extensions::{
5            get_inventory_range_checker, get_or_create_bitwise_op_lookup, SystemGpuBuilder,
6        },
7        SystemChipInventoryGPU,
8    },
9};
10use openvm_circuit_primitives::range_tuple::RangeTupleCheckerChipGPU;
11use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine, prover_backend::GpuBackend};
12use openvm_rv32im_circuit::Rv32ImGpuProverExt;
13use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config;
14
15use super::*;
16
17pub struct Int256GpuProverExt;
18
19// This implementation is specific to GpuBackend because the lookup chips
20// (VariableRangeCheckerChipGPU, BitwiseOperationLookupChipGPU) are specific to GpuBackend.
21impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Int256>
22    for Int256GpuProverExt
23{
24    fn extend_prover(
25        &self,
26        extension: &Int256,
27        inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
28    ) -> Result<(), ChipInventoryError> {
29        let pointer_max_bits = inventory.airs().pointer_max_bits();
30        let timestamp_max_bits = inventory.timestamp_max_bits();
31
32        let range_checker = get_inventory_range_checker(inventory);
33        let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?;
34
35        let range_tuple_checker = {
36            let existing_chip = inventory
37                .find_chip::<Arc<RangeTupleCheckerChipGPU<2>>>()
38                .find(|c| {
39                    c.sizes[0] >= extension.range_tuple_checker_sizes[0]
40                        && c.sizes[1] >= extension.range_tuple_checker_sizes[1]
41                });
42            if let Some(chip) = existing_chip {
43                chip.clone()
44            } else {
45                inventory.next_air::<RangeTupleCheckerAir<2>>()?;
46                let chip = Arc::new(RangeTupleCheckerChipGPU::new(
47                    extension.range_tuple_checker_sizes,
48                ));
49                inventory.add_periphery_chip(chip.clone());
50                chip
51            }
52        };
53
54        inventory.next_air::<Rv32BaseAlu256Air>()?;
55        let base_alu = BaseAlu256ChipGpu::new(
56            range_checker.clone(),
57            bitwise_lu.clone(),
58            pointer_max_bits,
59            timestamp_max_bits,
60        );
61        inventory.add_executor_chip(base_alu);
62
63        inventory.next_air::<Rv32LessThan256Air>()?;
64        let lt = LessThan256ChipGpu::new(
65            range_checker.clone(),
66            bitwise_lu.clone(),
67            pointer_max_bits,
68            timestamp_max_bits,
69        );
70        inventory.add_executor_chip(lt);
71
72        inventory.next_air::<Rv32BranchEqual256Air>()?;
73        let beq = BranchEqual256ChipGpu::new(
74            range_checker.clone(),
75            bitwise_lu.clone(),
76            pointer_max_bits,
77            timestamp_max_bits,
78        );
79        inventory.add_executor_chip(beq);
80
81        inventory.next_air::<Rv32BranchLessThan256Air>()?;
82        let blt = BranchLessThan256ChipGpu::new(
83            range_checker.clone(),
84            bitwise_lu.clone(),
85            pointer_max_bits,
86            timestamp_max_bits,
87        );
88        inventory.add_executor_chip(blt);
89
90        inventory.next_air::<Rv32Multiplication256Air>()?;
91        let mult = Multiplication256ChipGpu::new(
92            range_checker.clone(),
93            bitwise_lu.clone(),
94            range_tuple_checker.clone(),
95            pointer_max_bits,
96            timestamp_max_bits,
97        );
98        inventory.add_executor_chip(mult);
99
100        inventory.next_air::<Rv32Shift256Air>()?;
101        let shift = Shift256ChipGpu::new(
102            range_checker.clone(),
103            bitwise_lu.clone(),
104            pointer_max_bits,
105            timestamp_max_bits,
106        );
107        inventory.add_executor_chip(shift);
108
109        Ok(())
110    }
111}
112
113#[derive(Clone)]
114pub struct Int256Rv32GpuBuilder;
115
116type E = GpuBabyBearPoseidon2Engine;
117
118impl VmBuilder<E> for Int256Rv32GpuBuilder {
119    type VmConfig = Int256Rv32Config;
120    type SystemChipInventory = SystemChipInventoryGPU;
121    type RecordArena = DenseRecordArena;
122
123    fn create_chip_complex(
124        &self,
125        config: &Int256Rv32Config,
126        circuit: AirInventory<<E as StarkEngine>::SC>,
127    ) -> Result<
128        VmChipComplex<
129            <E as StarkEngine>::SC,
130            Self::RecordArena,
131            <E as StarkEngine>::PB,
132            Self::SystemChipInventory,
133        >,
134        ChipInventoryError,
135    > {
136        let mut chip_complex =
137            VmBuilder::<E>::create_chip_complex(&SystemGpuBuilder, &config.system, circuit)?;
138        let inventory = &mut chip_complex.inventory;
139        VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.rv32i, inventory)?;
140        VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.rv32m, inventory)?;
141        VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.io, inventory)?;
142        VmProverExtension::<E, _, _>::extend_prover(
143            &Int256GpuProverExt,
144            &config.bigint,
145            inventory,
146        )?;
147        Ok(chip_complex)
148    }
149}