// openvm_native_circuit/extension/cuda.rs

1use openvm_circuit::{
2    arch::{ChipInventory, ChipInventoryError, DenseRecordArena, VmProverExtension},
3    system::cuda::extensions::get_inventory_range_checker,
4};
5use openvm_cuda_backend::engine::GpuBabyBearPoseidon2Engine;
6use openvm_native_compiler::BLOCK_LOAD_STORE_SIZE;
7use openvm_stark_sdk::{
8    config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear,
9};
10
11use crate::{
12    branch_eq::{NativeBranchEqAir, NativeBranchEqChipGpu},
13    castf::{CastFAir, CastFChipGpu},
14    field_arithmetic::{FieldArithmeticAir, FieldArithmeticChipGpu},
15    field_extension::{FieldExtensionAir, FieldExtensionChipGpu},
16    fri::{FriReducedOpeningAir, FriReducedOpeningChipGpu},
17    jal_rangecheck::{JalRangeCheckAir, JalRangeCheckGpu},
18    loadstore::{NativeLoadStoreAir, NativeLoadStoreChipGpu},
19    poseidon2::{air::NativePoseidon2Air, NativePoseidon2ChipGpu},
20    CastFExtension, GpuBackend, Native,
21};
22
23pub struct NativeGpuProverExt;
24// This implementation is specific to GpuBackend because the lookup chips
25// (VariableRangeCheckerChipGPU, BitwiseOperationLookupChipGPU) are specific to GpuBackend.
26impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Native>
27    for NativeGpuProverExt
28{
29    fn extend_prover(
30        &self,
31        _: &Native,
32        inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
33    ) -> Result<(), ChipInventoryError> {
34        let timestamp_max_bits = inventory.timestamp_max_bits();
35        let range_checker = get_inventory_range_checker(inventory);
36
37        // These calls to next_air are not strictly necessary to construct the chips, but provide a
38        // safeguard to ensure that chip construction matches the circuit definition
39        inventory.next_air::<NativeLoadStoreAir<1>>()?;
40        let load_store =
41            NativeLoadStoreChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits);
42        inventory.add_executor_chip(load_store);
43
44        inventory.next_air::<NativeLoadStoreAir<BLOCK_LOAD_STORE_SIZE>>()?;
45        let block_load_store = NativeLoadStoreChipGpu::<BLOCK_LOAD_STORE_SIZE>::new(
46            range_checker.clone(),
47            timestamp_max_bits,
48        );
49        inventory.add_executor_chip(block_load_store);
50
51        inventory.next_air::<NativeBranchEqAir>()?;
52        let branch_eq = NativeBranchEqChipGpu::new(range_checker.clone(), timestamp_max_bits);
53
54        inventory.add_executor_chip(branch_eq);
55
56        inventory.next_air::<JalRangeCheckAir>()?;
57        let jal_rangecheck = JalRangeCheckGpu::new(range_checker.clone(), timestamp_max_bits);
58        inventory.add_executor_chip(jal_rangecheck);
59
60        inventory.next_air::<FieldArithmeticAir>()?;
61        let field_arithmetic =
62            FieldArithmeticChipGpu::new(range_checker.clone(), timestamp_max_bits);
63        inventory.add_executor_chip(field_arithmetic);
64
65        inventory.next_air::<FieldExtensionAir>()?;
66        let field_extension = FieldExtensionChipGpu::new(range_checker.clone(), timestamp_max_bits);
67        inventory.add_executor_chip(field_extension);
68
69        inventory.next_air::<FriReducedOpeningAir>()?;
70        let fri_reduced_opening =
71            FriReducedOpeningChipGpu::new(range_checker.clone(), timestamp_max_bits);
72        inventory.add_executor_chip(fri_reduced_opening);
73
74        inventory.next_air::<NativePoseidon2Air<BabyBear, 1>>()?;
75        let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits);
76        inventory.add_executor_chip(poseidon2);
77
78        Ok(())
79    }
80}
81
82impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, CastFExtension>
83    for NativeGpuProverExt
84{
85    fn extend_prover(
86        &self,
87        _: &CastFExtension,
88        inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
89    ) -> Result<(), ChipInventoryError> {
90        let timestamp_max_bits = inventory.timestamp_max_bits();
91        let range_checker = get_inventory_range_checker(inventory);
92
93        inventory.next_air::<CastFAir>()?;
94        let castf = CastFChipGpu::new(range_checker, timestamp_max_bits);
95        inventory.add_executor_chip(castf);
96
97        Ok(())
98    }
99}