openvm_circuit/arch/testing/
cuda.rs

1use std::sync::Arc;
2
3use openvm_circuit_primitives::{
4    bitwise_op_lookup::{
5        BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip,
6        BitwiseOperationLookupChipGPU, SharedBitwiseOperationLookupChip,
7    },
8    range_tuple::{
9        RangeTupleCheckerAir, RangeTupleCheckerBus, RangeTupleCheckerChip,
10        RangeTupleCheckerChipGPU, SharedRangeTupleCheckerChip,
11    },
12    var_range::{
13        SharedVariableRangeCheckerChip, VariableRangeCheckerAir, VariableRangeCheckerBus,
14        VariableRangeCheckerChip, VariableRangeCheckerChipGPU,
15    },
16};
17use openvm_cuda_backend::{
18    data_transporter::assert_eq_host_and_device_matrix,
19    engine::GpuBabyBearPoseidon2Engine,
20    prover_backend::GpuBackend,
21    types::{F, SC},
22};
23use openvm_instructions::{program::PC_BITS, riscv::RV32_REGISTER_AS};
24use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubAir};
25use openvm_stark_backend::{
26    config::Val,
27    interaction::{LookupBus, PermutationCheckBus},
28    p3_air::BaseAir,
29    p3_field::{PrimeCharacteristicRing, PrimeField32},
30    prover::{cpu::CpuBackend, types::AirProvingContext},
31    rap::AnyRap,
32    utils::disable_debug_builder,
33    verifier::VerificationError,
34    AirRef, Chip,
35};
36use openvm_stark_sdk::{
37    config::{setup_tracing_with_log_level, FriParameters},
38    engine::{StarkFriEngine, VerificationDataWithFriParams},
39};
40use rand::{rngs::StdRng, Rng, SeedableRng};
41use tracing::Level;
42
43#[cfg(feature = "metrics")]
44use crate::metrics::VmMetrics;
45use crate::{
46    arch::{
47        instructions::instruction::Instruction,
48        testing::{
49            default_tracing_memory, default_var_range_checker_bus, dummy_memory_helper,
50            execution::{air::ExecutionDummyAir, DeviceExecutionTester},
51            memory::DeviceMemoryTester,
52            program::{air::ProgramDummyAir, DeviceProgramTester},
53            TestBuilder, TestChipHarness, EXECUTION_BUS, MEMORY_BUS, MEMORY_MERKLE_BUS,
54            POSEIDON2_DIRECT_BUS, READ_INSTRUCTION_BUS,
55        },
56        Arena, DenseRecordArena, ExecutionBridge, ExecutionBus, ExecutionState, MatrixRecordArena,
57        MemoryConfig, PreflightExecutor, Streams, VmStateMut,
58    },
59    system::{
60        cuda::{poseidon2::Poseidon2PeripheryChipGPU, DIGEST_WIDTH},
61        memory::{
62            offline_checker::{MemoryBridge, MemoryBus},
63            MemoryAirInventory, SharedMemoryHelper,
64        },
65        poseidon2::air::Poseidon2PeripheryAir,
66        program::ProgramBus,
67        SystemPort,
68    },
69    utils::next_power_of_two_or_zero,
70};
71
72pub struct GpuTestChipHarness<F, Executor, AIR, GpuChip, CpuChip> {
73    pub executor: Executor,
74    pub air: AIR,
75    pub gpu_chip: GpuChip,
76    pub cpu_chip: CpuChip,
77    pub dense_arena: DenseRecordArena,
78    pub matrix_arena: MatrixRecordArena<F>,
79}
80
81impl<F, Executor, AIR, GpuChip, CpuChip> GpuTestChipHarness<F, Executor, AIR, GpuChip, CpuChip>
82where
83    F: PrimeField32,
84    AIR: BaseAir<F>,
85{
86    pub fn with_capacity(
87        executor: Executor,
88        air: AIR,
89        gpu_chip: GpuChip,
90        cpu_chip: CpuChip,
91        height: usize,
92    ) -> Self {
93        let width = air.width();
94        let height = next_power_of_two_or_zero(height);
95        let dense_arena = DenseRecordArena::with_capacity(height, width);
96        let matrix_arena = MatrixRecordArena::with_capacity(height, width);
97        Self {
98            executor,
99            air,
100            gpu_chip,
101            cpu_chip,
102            dense_arena,
103            matrix_arena,
104        }
105    }
106}
107
108impl TestBuilder<F> for GpuChipTestBuilder {
109    fn execute<E, RA>(&mut self, executor: &mut E, arena: &mut RA, instruction: &Instruction<F>)
110    where
111        E: PreflightExecutor<F, RA>,
112        RA: Arena,
113    {
114        let initial_pc = self.rng.random_range(0..(1 << PC_BITS));
115        self.execute_with_pc(executor, arena, instruction, initial_pc);
116    }
117
118    fn execute_with_pc<E, RA>(
119        &mut self,
120        executor: &mut E,
121        arena: &mut RA,
122        instruction: &Instruction<F>,
123        initial_pc: u32,
124    ) where
125        E: PreflightExecutor<F, RA>,
126        RA: Arena,
127    {
128        let initial_state = ExecutionState {
129            pc: initial_pc,
130            timestamp: self.memory.memory.timestamp(),
131        };
132        tracing::debug!("initial_timestamp={}", initial_state.timestamp);
133
134        let mut pc = initial_pc;
135        let state_mut = VmStateMut::new(
136            &mut pc,
137            &mut self.memory.memory,
138            &mut self.streams,
139            &mut self.rng,
140            &mut self.custom_pvs,
141            arena,
142            #[cfg(feature = "metrics")]
143            &mut self.metrics,
144        );
145
146        executor
147            .execute(state_mut, instruction)
148            .expect("Expected the execution not to fail");
149        let final_state = ExecutionState {
150            pc,
151            timestamp: self.memory.memory.timestamp(),
152        };
153
154        self.program.execute(instruction, &initial_state);
155        self.execution.execute(initial_state, final_state);
156    }
157
158    fn read_cell(&mut self, address_space: usize, pointer: usize) -> F {
159        self.read::<1>(address_space, pointer)[0]
160    }
161
162    fn write_cell(&mut self, address_space: usize, pointer: usize, value: F) {
163        self.write(address_space, pointer, [value]);
164    }
165
166    fn read<const N: usize>(&mut self, address_space: usize, pointer: usize) -> [F; N] {
167        self.memory.read(address_space, pointer)
168    }
169
170    fn write<const N: usize>(&mut self, address_space: usize, pointer: usize, value: [F; N]) {
171        self.memory.write(address_space, pointer, value);
172    }
173
174    fn write_usize<const N: usize>(
175        &mut self,
176        address_space: usize,
177        pointer: usize,
178        value: [usize; N],
179    ) {
180        self.write(address_space, pointer, value.map(F::from_usize));
181    }
182
183    fn address_bits(&self) -> usize {
184        self.memory.config.pointer_max_bits
185    }
186
187    fn last_to_pc(&self) -> F {
188        self.execution.0.last_to_pc()
189    }
190
191    fn last_from_pc(&self) -> F {
192        self.execution.0.last_from_pc()
193    }
194
195    fn execution_final_state(&self) -> ExecutionState<F> {
196        self.execution.0.records.last().unwrap().final_state
197    }
198
199    fn streams_mut(&mut self) -> &mut Streams<F> {
200        &mut self.streams
201    }
202
203    fn get_default_register(&mut self, increment: usize) -> usize {
204        self.default_register += increment;
205        self.default_register - increment
206    }
207
208    fn get_default_pointer(&mut self, increment: usize) -> usize {
209        self.default_pointer += increment;
210        self.default_pointer - increment
211    }
212
213    fn write_heap_pointer_default(
214        &mut self,
215        reg_increment: usize,
216        pointer_increment: usize,
217    ) -> (usize, usize) {
218        let register = self.get_default_register(reg_increment);
219        let pointer = self.get_default_pointer(pointer_increment);
220        self.write(1, register, pointer.to_le_bytes().map(F::from_u8));
221        (register, pointer)
222    }
223
224    fn write_heap_default<const NUM_LIMBS: usize>(
225        &mut self,
226        reg_increment: usize,
227        pointer_increment: usize,
228        writes: Vec<[F; NUM_LIMBS]>,
229    ) -> (usize, usize) {
230        let register = self.get_default_register(reg_increment);
231        let pointer = self.get_default_pointer(pointer_increment);
232        self.write_heap(register, pointer, writes);
233        (register, pointer)
234    }
235}
236
237pub struct GpuChipTestBuilder {
238    pub memory: DeviceMemoryTester,
239    pub execution: DeviceExecutionTester,
240    pub program: DeviceProgramTester,
241    pub streams: Streams<F>,
242
243    var_range_checker: Arc<VariableRangeCheckerChipGPU>,
244    bitwise_op_lookup: Option<Arc<BitwiseOperationLookupChipGPU<8>>>,
245    range_tuple_checker: Option<Arc<RangeTupleCheckerChipGPU<2>>>,
246
247    rng: StdRng,
248    pub custom_pvs: Vec<Option<F>>,
249    default_register: usize,
250    default_pointer: usize,
251    #[cfg(feature = "metrics")]
252    metrics: VmMetrics,
253}
254
255impl Default for GpuChipTestBuilder {
256    fn default() -> Self {
257        let mut mem_config = MemoryConfig::default();
258        // Currently tests still use gen_pointer for the full 1<<29 range of address space 1.
259        mem_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 1 << 29;
260        Self::volatile(mem_config, default_var_range_checker_bus())
261    }
262}
263
264impl GpuChipTestBuilder {
265    pub fn new() -> Self {
266        Self::default()
267    }
268
269    pub fn new_persistent() -> Self {
270        let mut mem_config = MemoryConfig::default();
271        // Currently tests still use gen_pointer for the full 1<<29 range of address space 1.
272        mem_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 1 << 29;
273        Self::persistent(mem_config, default_var_range_checker_bus())
274    }
275
276    pub fn volatile(mem_config: MemoryConfig, bus: VariableRangeCheckerBus) -> Self {
277        setup_tracing_with_log_level(Level::INFO);
278        let mem_bus = MemoryBus::new(MEMORY_BUS);
279        let range_checker = Arc::new(VariableRangeCheckerChipGPU::hybrid(Arc::new(
280            VariableRangeCheckerChip::new(bus),
281        )));
282        Self {
283            memory: DeviceMemoryTester::volatile(
284                default_tracing_memory(&mem_config, 1),
285                mem_bus,
286                mem_config,
287                range_checker.clone(),
288            ),
289            execution: DeviceExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)),
290            program: DeviceProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)),
291            streams: Default::default(),
292            var_range_checker: range_checker,
293            bitwise_op_lookup: None,
294            range_tuple_checker: None,
295            rng: StdRng::seed_from_u64(0),
296            custom_pvs: Vec::new(),
297            default_register: 0,
298            default_pointer: 0,
299            #[cfg(feature = "metrics")]
300            metrics: VmMetrics::default(),
301        }
302    }
303
304    pub fn persistent(mem_config: MemoryConfig, bus: VariableRangeCheckerBus) -> Self {
305        setup_tracing_with_log_level(Level::INFO);
306        let mem_bus = MemoryBus::new(MEMORY_BUS);
307        let range_checker = Arc::new(VariableRangeCheckerChipGPU::hybrid(Arc::new(
308            VariableRangeCheckerChip::new(bus),
309        )));
310        Self {
311            memory: DeviceMemoryTester::persistent(
312                default_tracing_memory(&mem_config, DIGEST_WIDTH),
313                mem_bus,
314                mem_config,
315                range_checker.clone(),
316            ),
317            execution: DeviceExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)),
318            program: DeviceProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)),
319            streams: Default::default(),
320            var_range_checker: range_checker,
321            bitwise_op_lookup: None,
322            range_tuple_checker: None,
323            rng: StdRng::seed_from_u64(0),
324            custom_pvs: Vec::new(),
325            default_register: 0,
326            default_pointer: 0,
327            #[cfg(feature = "metrics")]
328            metrics: VmMetrics::default(),
329        }
330    }
331
332    pub fn with_bitwise_op_lookup(mut self, bus: BitwiseOperationLookupBus) -> Self {
333        self.bitwise_op_lookup = Some(Arc::new(BitwiseOperationLookupChipGPU::hybrid(Arc::new(
334            BitwiseOperationLookupChip::new(bus),
335        ))));
336        self
337    }
338
339    pub fn with_range_tuple_checker(mut self, bus: RangeTupleCheckerBus<2>) -> Self {
340        self.range_tuple_checker = Some(Arc::new(RangeTupleCheckerChipGPU::hybrid(Arc::new(
341            RangeTupleCheckerChip::new(bus),
342        ))));
343        self
344    }
345
346    pub fn execute_harness<E, A, C, RA: Arena>(
347        &mut self,
348        harness: &mut TestChipHarness<F, E, A, C, RA>,
349        instruction: &Instruction<F>,
350    ) where
351        E: PreflightExecutor<F, RA>,
352    {
353        self.execute(&mut harness.executor, &mut harness.arena, instruction);
354    }
355
356    pub fn execute_with_pc_harness<E, A, C, RA: Arena>(
357        &mut self,
358        harness: &mut TestChipHarness<F, E, A, C, RA>,
359        instruction: &Instruction<F>,
360        initial_pc: u32,
361    ) where
362        E: PreflightExecutor<F, RA>,
363    {
364        self.execute_with_pc(
365            &mut harness.executor,
366            &mut harness.arena,
367            instruction,
368            initial_pc,
369        );
370    }
371
372    pub fn write_heap<const NUM_LIMBS: usize>(
373        &mut self,
374        register: usize,
375        pointer: usize,
376        writes: Vec<[F; NUM_LIMBS]>,
377    ) {
378        self.write(1usize, register, pointer.to_le_bytes().map(F::from_u8));
379        if NUM_LIMBS.is_power_of_two() {
380            for (i, &write) in writes.iter().enumerate() {
381                self.write(2usize, pointer + i * NUM_LIMBS, write);
382            }
383        } else {
384            for (i, &write) in writes.iter().enumerate() {
385                let ptr = pointer + i * NUM_LIMBS;
386                for j in (0..NUM_LIMBS).step_by(4) {
387                    self.write::<4>(2usize, ptr + j, write[j..j + 4].try_into().unwrap());
388                }
389            }
390        }
391    }
392
393    pub fn system_port(&self) -> SystemPort {
394        SystemPort {
395            execution_bus: self.execution_bus(),
396            program_bus: self.program_bus(),
397            memory_bridge: self.memory_bridge(),
398        }
399    }
400    pub fn execution_bridge(&self) -> ExecutionBridge {
401        ExecutionBridge::new(self.execution.bus(), self.program.bus())
402    }
403
404    pub fn memory_bridge(&self) -> MemoryBridge {
405        self.memory.memory_bridge()
406    }
407
408    pub fn execution_bus(&self) -> ExecutionBus {
409        self.execution.bus()
410    }
411
412    pub fn program_bus(&self) -> ProgramBus {
413        self.program.bus()
414    }
415
416    pub fn memory_bus(&self) -> MemoryBus {
417        self.memory.mem_bus
418    }
419
420    pub fn rng(&mut self) -> &mut StdRng {
421        &mut self.rng
422    }
423
424    pub fn range_checker(&self) -> Arc<VariableRangeCheckerChipGPU> {
425        self.var_range_checker.clone()
426    }
427
428    pub fn bitwise_op_lookup(&self) -> Arc<BitwiseOperationLookupChipGPU<8>> {
429        self.bitwise_op_lookup
430            .clone()
431            .expect("Initialize GpuChipTestBuilder with .with_bitwise_op_lookup()")
432    }
433
434    pub fn range_tuple_checker(&self) -> Arc<RangeTupleCheckerChipGPU<2>> {
435        self.range_tuple_checker
436            .clone()
437            .expect("Initialize GpuChipTestBuilder with .with_range_tuple_checker()")
438    }
439
440    // WARNING: This CPU chip is meant for hybrid chip use, its usage WILL
441    // result in altered tracegen. For a dummy primitive chip for trace
442    // comparison, see utils::dummy_range_checker.
443    pub fn cpu_range_checker(&self) -> SharedVariableRangeCheckerChip {
444        self.var_range_checker.cpu_chip.clone().unwrap()
445    }
446
447    // WARNING: This CPU chip is meant for hybrid chip use, its usage WILL
448    // result in altered tracegen. For a dummy primitive chip for trace
449    // comparison, see utils::dummy_bitwise_op_lookup.
450    pub fn cpu_bitwise_op_lookup(&self) -> SharedBitwiseOperationLookupChip<8> {
451        self.bitwise_op_lookup
452            .as_ref()
453            .expect("Initialize GpuChipTestBuilder with .with_bitwise_op_lookup()")
454            .cpu_chip
455            .clone()
456            .unwrap()
457    }
458
459    // WARNING: This CPU chip is meant for hybrid chip use, its usage WILL
460    // result in altered tracegen. For a dummy primitive chip for trace
461    // comparison, see utils::dummy_range_tuple_checker.
462    pub fn cpu_range_tuple_checker(&self) -> SharedRangeTupleCheckerChip<2> {
463        self.range_tuple_checker
464            .as_ref()
465            .expect("Initialize GpuChipTestBuilder with .with_range_tuple_checker()")
466            .cpu_chip
467            .clone()
468            .unwrap()
469    }
470
471    // WARNING: This utility is meant for hybrid chip use, its usage WILL
472    // result in altered tracegen. For use during trace comparison, see
473    // utils::dummy_memory_helper.
474    pub fn cpu_memory_helper(&self) -> SharedMemoryHelper<F> {
475        SharedMemoryHelper::new(
476            self.cpu_range_checker(),
477            self.memory.config.timestamp_max_bits,
478        )
479    }
480
481    // See [cpu_memory_helper]. Use this utility for creation of CPU chips that
482    // are meant for tracegen comparison purposes which should not update other
483    // periphery chips (e.g., range checker).
484    pub fn dummy_memory_helper(&self) -> SharedMemoryHelper<F> {
485        dummy_memory_helper(self.cpu_range_checker().bus(), self.timestamp_max_bits())
486    }
487
488    pub fn timestamp_max_bits(&self) -> usize {
489        self.memory.config.timestamp_max_bits
490    }
491
492    pub fn build(self) -> GpuChipTester {
493        GpuChipTester {
494            var_range_checker: Some(self.var_range_checker),
495            bitwise_op_lookup: self.bitwise_op_lookup,
496            range_tuple_checker: self.range_tuple_checker,
497            memory: Some(self.memory),
498            ..Default::default()
499        }
500        .load(
501            ExecutionDummyAir::new(self.execution.bus()),
502            self.execution,
503            (),
504        )
505        .load(ProgramDummyAir::new(self.program.bus()), self.program, ())
506    }
507}
508
509#[derive(Default)]
510pub struct GpuChipTester {
511    pub airs: Vec<AirRef<SC>>,
512    pub ctxs: Vec<AirProvingContext<GpuBackend>>,
513    pub memory: Option<DeviceMemoryTester>,
514    pub var_range_checker: Option<Arc<VariableRangeCheckerChipGPU>>,
515    pub bitwise_op_lookup: Option<Arc<BitwiseOperationLookupChipGPU<8>>>,
516    pub range_tuple_checker: Option<Arc<RangeTupleCheckerChipGPU<2>>>,
517}
518
519impl GpuChipTester {
520    pub fn load<A, G, RA>(mut self, air: A, gpu_chip: G, gpu_arena: RA) -> Self
521    where
522        A: AnyRap<SC> + 'static,
523        G: Chip<RA, GpuBackend>,
524    {
525        let proving_ctx = gpu_chip.generate_proving_ctx(gpu_arena);
526        if proving_ctx.common_main.is_some() {
527            self = self.load_air_proving_ctx(Arc::new(air) as AirRef<SC>, proving_ctx);
528        }
529        self
530    }
531
532    pub fn load_harness<E, A, G, RA>(self, harness: TestChipHarness<F, E, A, G, RA>) -> Self
533    where
534        A: AnyRap<SC> + 'static,
535        G: Chip<RA, GpuBackend>,
536    {
537        self.load(harness.air, harness.chip, harness.arena)
538    }
539
540    pub fn load_periphery<A, G>(self, air: A, gpu_chip: G) -> Self
541    where
542        A: AnyRap<SC> + 'static,
543        G: Chip<(), GpuBackend>,
544    {
545        self.load(air, gpu_chip, ())
546    }
547
548    pub fn load_air_proving_ctx(
549        mut self,
550        air: AirRef<SC>,
551        proving_ctx: AirProvingContext<GpuBackend>,
552    ) -> Self {
553        #[cfg(feature = "touchemall")]
554        {
555            use openvm_cuda_backend::engine::check_trace_validity;
556
557            check_trace_validity(&proving_ctx, &air.name());
558        }
559        self.airs.push(air);
560        self.ctxs.push(proving_ctx);
561        self
562    }
563
564    pub fn load_and_compare<A, G, RA, C, CRA>(
565        mut self,
566        air: A,
567        gpu_chip: G,
568        gpu_arena: RA,
569        cpu_chip: C,
570        cpu_arena: CRA,
571    ) -> Self
572    where
573        A: AnyRap<SC> + 'static,
574        C: Chip<CRA, CpuBackend<SC>>,
575        G: Chip<RA, GpuBackend>,
576    {
577        let proving_ctx = gpu_chip.generate_proving_ctx(gpu_arena);
578        let expected_trace = cpu_chip.generate_proving_ctx(cpu_arena).common_main;
579        if proving_ctx.common_main.is_none() {
580            assert!(expected_trace.is_none());
581            return self;
582        }
583        #[cfg(feature = "touchemall")]
584        {
585            use openvm_cuda_backend::engine::check_trace_validity;
586
587            check_trace_validity(&proving_ctx, &air.name());
588        }
589        assert_eq_host_and_device_matrix(
590            expected_trace.unwrap(),
591            proving_ctx.common_main.as_ref().unwrap(),
592        );
593        self.airs.push(Arc::new(air) as AirRef<SC>);
594        self.ctxs.push(proving_ctx);
595        self
596    }
597
598    pub fn load_gpu_harness<E, A, GpuChip, CpuChip>(
599        self,
600        harness: GpuTestChipHarness<Val<SC>, E, A, GpuChip, CpuChip>,
601    ) -> Self
602    where
603        A: AnyRap<SC> + 'static,
604        CpuChip: Chip<MatrixRecordArena<Val<SC>>, CpuBackend<SC>>,
605        GpuChip: Chip<DenseRecordArena, GpuBackend>,
606    {
607        self.load_and_compare(
608            harness.air,
609            harness.gpu_chip,
610            harness.dense_arena,
611            harness.cpu_chip,
612            harness.matrix_arena,
613        )
614    }
615
616    pub fn finalize(mut self) -> Self {
617        if let Some(mut memory_tester) = self.memory.take() {
618            let is_persistent = memory_tester.inventory.continuation_enabled();
619            let touched_memory = memory_tester.memory.finalize::<F>(is_persistent);
620            let memory_bridge = memory_tester.memory_bridge();
621
622            for chip in memory_tester.chip_for_block.into_values() {
623                self = self.load_periphery(chip.0.air, chip);
624            }
625
626            let airs = MemoryAirInventory::<SC>::new(
627                memory_bridge,
628                &memory_tester.config,
629                memory_tester.range_bus,
630                is_persistent.then_some((
631                    PermutationCheckBus::new(MEMORY_MERKLE_BUS),
632                    PermutationCheckBus::new(POSEIDON2_DIRECT_BUS),
633                )),
634            )
635            .into_airs();
636            let ctxs = memory_tester
637                .inventory
638                .generate_proving_ctxs(memory_tester.memory.access_adapter_records, touched_memory);
639            for (air, ctx) in airs
640                .into_iter()
641                .zip(ctxs)
642                .filter(|(_, ctx)| ctx.common_main.is_some())
643            {
644                self = self.load_air_proving_ctx(air, ctx);
645            }
646
647            if let Some(hasher_chip) = memory_tester.hasher_chip {
648                let air: AirRef<SC> = match hasher_chip.as_ref() {
649                    Poseidon2PeripheryChipGPU::Register0(_) => {
650                        let config = Poseidon2Config::default();
651                        Arc::new(Poseidon2PeripheryAir::new(
652                            Arc::new(Poseidon2SubAir::<F, 0>::new(config.constants.into())),
653                            LookupBus::new(POSEIDON2_DIRECT_BUS),
654                        ))
655                    }
656                    Poseidon2PeripheryChipGPU::Register1(_) => {
657                        let config = Poseidon2Config::default();
658                        Arc::new(Poseidon2PeripheryAir::new(
659                            Arc::new(Poseidon2SubAir::<F, 1>::new(config.constants.into())),
660                            LookupBus::new(POSEIDON2_DIRECT_BUS),
661                        ))
662                    }
663                };
664                let ctx = hasher_chip.generate_proving_ctx(());
665                self = self.load_air_proving_ctx(air, ctx);
666            }
667        }
668        if let Some(var_range_checker) = self.var_range_checker.take() {
669            self = self.load_periphery(
670                VariableRangeCheckerAir::new(var_range_checker.cpu_chip.as_ref().unwrap().bus()),
671                var_range_checker,
672            );
673        }
674        if let Some(bitwise_op_lookup) = self.bitwise_op_lookup.take() {
675            self = self.load_periphery(
676                BitwiseOperationLookupAir::<8>::new(
677                    bitwise_op_lookup.cpu_chip.as_ref().unwrap().bus(),
678                ),
679                bitwise_op_lookup,
680            );
681        }
682        if let Some(range_tuple_checker) = self.range_tuple_checker.take() {
683            self = self.load_periphery(
684                RangeTupleCheckerAir {
685                    bus: *range_tuple_checker.cpu_chip.as_ref().unwrap().bus(),
686                },
687                range_tuple_checker,
688            );
689        }
690        self
691    }
692
693    pub fn test<P: Fn() -> GpuBabyBearPoseidon2Engine>(
694        self,
695        engine_provider: P,
696    ) -> Result<VerificationDataWithFriParams<SC>, VerificationError> {
697        engine_provider().run_test(self.airs, self.ctxs)
698    }
699
700    pub fn simple_test(self) -> Result<VerificationDataWithFriParams<SC>, VerificationError> {
701        self.test(|| GpuBabyBearPoseidon2Engine::new(FriParameters::new_for_testing(1)))
702    }
703
704    pub fn simple_test_with_expected_error(self, expected_error: VerificationError) {
705        disable_debug_builder();
706        let msg = format!(
707            "Expected verification to fail with {:?}, but it didn't",
708            &expected_error
709        );
710        let result = self.simple_test();
711        assert_eq!(result.err(), Some(expected_error), "{msg}");
712    }
713}