openvm_native_circuit/extension/
cuda.rs1use openvm_circuit::{
2 arch::{ChipInventory, ChipInventoryError, DenseRecordArena, VmProverExtension},
3 system::cuda::extensions::get_inventory_range_checker,
4};
5use openvm_cuda_backend::engine::GpuBabyBearPoseidon2Engine;
6use openvm_native_compiler::BLOCK_LOAD_STORE_SIZE;
7use openvm_stark_sdk::{
8 config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear,
9};
10
11use crate::{
12 branch_eq::{NativeBranchEqAir, NativeBranchEqChipGpu},
13 castf::{CastFAir, CastFChipGpu},
14 field_arithmetic::{FieldArithmeticAir, FieldArithmeticChipGpu},
15 field_extension::{FieldExtensionAir, FieldExtensionChipGpu},
16 fri::{FriReducedOpeningAir, FriReducedOpeningChipGpu},
17 jal_rangecheck::{JalRangeCheckAir, JalRangeCheckGpu},
18 loadstore::{NativeLoadStoreAir, NativeLoadStoreChipGpu},
19 poseidon2::{air::NativePoseidon2Air, NativePoseidon2ChipGpu},
20 CastFExtension, GpuBackend, Native,
21};

/// GPU prover extension for the native-field VM extensions.
///
/// This zero-sized marker type implements [`VmProverExtension`] for both the
/// [`Native`] extension and the [`CastFExtension`]. Each implementation registers
/// the CUDA-backed executor chips with a [`ChipInventory`] that uses the
/// [`DenseRecordArena`] record arena.
pub struct NativeGpuProverExt;
24impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Native>
27 for NativeGpuProverExt
28{
29 fn extend_prover(
30 &self,
31 _: &Native,
32 inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
33 ) -> Result<(), ChipInventoryError> {
34 let timestamp_max_bits = inventory.timestamp_max_bits();
35 let range_checker = get_inventory_range_checker(inventory);
36
37 inventory.next_air::<NativeLoadStoreAir<1>>()?;
40 let load_store =
41 NativeLoadStoreChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits);
42 inventory.add_executor_chip(load_store);
43
44 inventory.next_air::<NativeLoadStoreAir<BLOCK_LOAD_STORE_SIZE>>()?;
45 let block_load_store = NativeLoadStoreChipGpu::<BLOCK_LOAD_STORE_SIZE>::new(
46 range_checker.clone(),
47 timestamp_max_bits,
48 );
49 inventory.add_executor_chip(block_load_store);
50
51 inventory.next_air::<NativeBranchEqAir>()?;
52 let branch_eq = NativeBranchEqChipGpu::new(range_checker.clone(), timestamp_max_bits);
53
54 inventory.add_executor_chip(branch_eq);
55
56 inventory.next_air::<JalRangeCheckAir>()?;
57 let jal_rangecheck = JalRangeCheckGpu::new(range_checker.clone(), timestamp_max_bits);
58 inventory.add_executor_chip(jal_rangecheck);
59
60 inventory.next_air::<FieldArithmeticAir>()?;
61 let field_arithmetic =
62 FieldArithmeticChipGpu::new(range_checker.clone(), timestamp_max_bits);
63 inventory.add_executor_chip(field_arithmetic);
64
65 inventory.next_air::<FieldExtensionAir>()?;
66 let field_extension = FieldExtensionChipGpu::new(range_checker.clone(), timestamp_max_bits);
67 inventory.add_executor_chip(field_extension);
68
69 inventory.next_air::<FriReducedOpeningAir>()?;
70 let fri_reduced_opening =
71 FriReducedOpeningChipGpu::new(range_checker.clone(), timestamp_max_bits);
72 inventory.add_executor_chip(fri_reduced_opening);
73
74 inventory.next_air::<NativePoseidon2Air<BabyBear, 1>>()?;
75 let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits);
76 inventory.add_executor_chip(poseidon2);
77
78 Ok(())
79 }
80}
81
82impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, CastFExtension>
83 for NativeGpuProverExt
84{
85 fn extend_prover(
86 &self,
87 _: &CastFExtension,
88 inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
89 ) -> Result<(), ChipInventoryError> {
90 let timestamp_max_bits = inventory.timestamp_max_bits();
91 let range_checker = get_inventory_range_checker(inventory);
92
93 inventory.next_air::<CastFAir>()?;
94 let castf = CastFChipGpu::new(range_checker, timestamp_max_bits);
95 inventory.add_executor_chip(castf);
96
97 Ok(())
98 }
99}