openvm_algebra_circuit/fp2_chip/cuda/
addsub.rs1use std::sync::Arc;
2
3use derive_new::new;
4use openvm_algebra_transpiler::Fp2Opcode;
5use openvm_circuit::arch::{AdapterCoreLayout, DenseRecordArena, RecordSeeker};
6use openvm_circuit_primitives::{
7 bitwise_op_lookup::BitwiseOperationLookupChipGPU, var_range::VariableRangeCheckerChipGPU,
8};
9use openvm_cuda_backend::{chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F};
10use openvm_cuda_common::copy::MemCopyH2D;
11use openvm_instructions::riscv::RV32_CELL_BITS;
12use openvm_mod_circuit_builder::{
13 ExprBuilderConfig, FieldExpressionChipGPU, FieldExpressionCoreAir, FieldExpressionMetadata,
14};
15use openvm_rv32_adapters::{Rv32VecHeapAdapterCols, Rv32VecHeapAdapterExecutor};
16use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
17
18use crate::{fp2_chip::fp2_addsub_expr, AlgebraRecord};
19
20#[derive(new)]
21pub struct Fp2AddSubChipGpu<const BLOCKS: usize, const BLOCK_SIZE: usize> {
22 pub range_checker: Arc<VariableRangeCheckerChipGPU>,
23 pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
24 pub config: ExprBuilderConfig,
25 pub offset: usize,
26 pub pointer_max_bits: u32,
27 pub timestamp_max_bits: u32,
28}
29
30impl<const BLOCKS: usize, const BLOCK_SIZE: usize> Chip<DenseRecordArena, GpuBackend>
31 for Fp2AddSubChipGpu<BLOCKS, BLOCK_SIZE>
32{
33 fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
34 let range_bus = self.range_checker.cpu_chip.as_ref().unwrap().bus();
35 let (expr, is_add_flag, is_sub_flag) = fp2_addsub_expr(self.config.clone(), range_bus);
36
37 let total_input_limbs = expr.builder.num_input * expr.canonical_num_limbs();
38 let layout = AdapterCoreLayout::with_metadata(FieldExpressionMetadata::<
39 F,
40 Rv32VecHeapAdapterExecutor<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
41 >::new(total_input_limbs));
42
43 let record_size = RecordSeeker::<
44 DenseRecordArena,
45 AlgebraRecord<2, BLOCKS, BLOCK_SIZE>,
46 _,
47 >::get_aligned_record_size(&layout);
48
49 let records = arena.allocated();
50 if records.is_empty() {
51 return get_empty_air_proving_ctx::<GpuBackend>();
52 }
53 debug_assert_eq!(records.len() % record_size, 0);
54
55 let num_records = records.len() / record_size;
56
57 let local_opcode_idx = vec![
58 Fp2Opcode::ADD as usize,
59 Fp2Opcode::SUB as usize,
60 Fp2Opcode::SETUP_ADDSUB as usize,
61 ];
62 let opcode_flag_idx = vec![is_add_flag, is_sub_flag];
63
64 let air = FieldExpressionCoreAir::new(expr, self.offset, local_opcode_idx, opcode_flag_idx);
65
66 let adapter_width =
67 Rv32VecHeapAdapterCols::<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>::width();
68
69 let d_records = records.to_device().unwrap();
70
71 let field_expr_chip = FieldExpressionChipGPU::new(
72 air,
73 d_records,
74 num_records,
75 record_size,
76 adapter_width,
77 BLOCKS,
78 self.range_checker.clone(),
79 self.bitwise_lookup.clone(),
80 self.pointer_max_bits,
81 self.timestamp_max_bits,
82 );
83
84 let d_trace = field_expr_chip.generate_field_trace();
85
86 AirProvingContext::simple_no_pis(d_trace)
87 }
88}