openvm_ecc_circuit/weierstrass_chip/add_ne/
cuda.rs

1use std::sync::Arc;
2
3use derive_new::new;
4use openvm_circuit::arch::{AdapterCoreLayout, DenseRecordArena, RecordSeeker};
5use openvm_circuit_primitives::{
6    bitwise_op_lookup::BitwiseOperationLookupChipGPU, var_range::VariableRangeCheckerChipGPU,
7};
8use openvm_cuda_backend::{chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F};
9use openvm_cuda_common::copy::MemCopyH2D;
10use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
11use openvm_instructions::riscv::RV32_CELL_BITS;
12use openvm_mod_circuit_builder::{
13    ExprBuilderConfig, FieldExpressionChipGPU, FieldExpressionCoreAir, FieldExpressionMetadata,
14};
15use openvm_rv32_adapters::{Rv32VecHeapAdapterCols, Rv32VecHeapAdapterExecutor};
16use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
17
18use crate::{ec_add_ne_expr, EccRecord};
19
20#[derive(new)]
21pub struct WeierstrassAddNeChipGpu<const BLOCKS: usize, const BLOCK_SIZE: usize> {
22    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
23    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
24    pub config: ExprBuilderConfig,
25    pub offset: usize,
26    pub pointer_max_bits: u32,
27    pub timestamp_max_bits: u32,
28}
29
30impl<const BLOCKS: usize, const BLOCK_SIZE: usize> Chip<DenseRecordArena, GpuBackend>
31    for WeierstrassAddNeChipGpu<BLOCKS, BLOCK_SIZE>
32{
33    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
34        let range_bus = self.range_checker.cpu_chip.as_ref().unwrap().bus();
35        let expr = ec_add_ne_expr(self.config.clone(), range_bus);
36
37        let total_input_limbs = expr.builder.num_input * expr.canonical_num_limbs();
38        let layout = AdapterCoreLayout::with_metadata(FieldExpressionMetadata::<
39            F,
40            Rv32VecHeapAdapterExecutor<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
41        >::new(total_input_limbs));
42
43        let record_size = RecordSeeker::<
44            DenseRecordArena,
45            EccRecord<2, BLOCKS, BLOCK_SIZE>,
46            _,
47        >::get_aligned_record_size(&layout);
48
49        let records = arena.allocated();
50        if records.is_empty() {
51            return get_empty_air_proving_ctx::<GpuBackend>();
52        }
53        debug_assert_eq!(records.len() % record_size, 0);
54
55        let num_records = records.len() / record_size;
56
57        let local_opcode_idx = vec![
58            Rv32WeierstrassOpcode::EC_ADD_NE as usize,
59            Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize,
60        ];
61
62        let air = FieldExpressionCoreAir::new(expr, self.offset, local_opcode_idx, vec![]);
63
64        let adapter_width =
65            Rv32VecHeapAdapterCols::<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>::width();
66
67        let d_records = records.to_device().unwrap();
68
69        let field_expr_chip = FieldExpressionChipGPU::new(
70            air,
71            d_records,
72            num_records,
73            record_size,
74            adapter_width,
75            BLOCKS,
76            self.range_checker.clone(),
77            self.bitwise_lookup.clone(),
78            self.pointer_max_bits,
79            self.timestamp_max_bits,
80        );
81
82        let d_trace = field_expr_chip.generate_field_trace();
83
84        AirProvingContext::simple_no_pis(d_trace)
85    }
86}