openvm_ecc_circuit/weierstrass_chip/double/
cuda.rs

1use std::sync::Arc;
2
3use derive_new::new;
4use num_bigint::BigUint;
5use openvm_circuit::arch::{AdapterCoreLayout, DenseRecordArena, RecordSeeker};
6use openvm_circuit_primitives::{
7    bitwise_op_lookup::BitwiseOperationLookupChipGPU, var_range::VariableRangeCheckerChipGPU,
8};
9use openvm_cuda_backend::{chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F};
10use openvm_cuda_common::copy::MemCopyH2D;
11use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
12use openvm_instructions::riscv::RV32_CELL_BITS;
13use openvm_mod_circuit_builder::{
14    ExprBuilderConfig, FieldExpressionChipGPU, FieldExpressionCoreAir, FieldExpressionMetadata,
15};
16use openvm_rv32_adapters::{Rv32VecHeapAdapterCols, Rv32VecHeapAdapterExecutor};
17use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
18
19use crate::{ec_double_ne_expr, EccRecord};
20
/// GPU-side chip for the Weierstrass point-doubling field expression.
///
/// The fields hold everything `generate_proving_ctx` needs to rebuild the doubling expression
/// via `ec_double_ne_expr` and to construct the `FieldExpressionChipGPU` that produces the trace.
///
/// `BLOCKS` and `BLOCK_SIZE` are forwarded as const parameters to the `Rv32VecHeapAdapter*`
/// types and to `EccRecord`. The `new` constructor is derived, so its parameters follow the
/// field order below.
#[derive(new)]
pub struct WeierstrassDoubleChipGpu<const BLOCKS: usize, const BLOCK_SIZE: usize> {
    /// Shared GPU variable range checker. Its `cpu_chip` must be set: the range bus is read
    /// from it when the proving context is generated.
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    /// Shared GPU bitwise-operation lookup chip, forwarded to `FieldExpressionChipGPU::new`.
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    /// Expression-builder configuration passed (cloned) to `ec_double_ne_expr`.
    pub config: ExprBuilderConfig,
    /// Passed to `FieldExpressionCoreAir::new` together with the local opcode indices.
    /// NOTE(review): presumably the global opcode offset of this chip — confirm.
    pub offset: usize,
    /// Passed (cloned) to `ec_double_ne_expr`.
    /// NOTE(review): presumably the curve coefficient `a` — confirm.
    pub a_biguint: BigUint,
    /// Forwarded unchanged to `FieldExpressionChipGPU::new`.
    pub pointer_max_bits: u32,
    /// Forwarded unchanged to `FieldExpressionChipGPU::new`.
    pub timestamp_max_bits: u32,
}
31
32impl<const BLOCKS: usize, const BLOCK_SIZE: usize> Chip<DenseRecordArena, GpuBackend>
33    for WeierstrassDoubleChipGpu<BLOCKS, BLOCK_SIZE>
34{
35    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
36        let range_bus = self.range_checker.cpu_chip.as_ref().unwrap().bus();
37        let expr = ec_double_ne_expr(self.config.clone(), range_bus, self.a_biguint.clone());
38
39        let total_input_limbs = expr.builder.num_input * expr.canonical_num_limbs();
40        let layout = AdapterCoreLayout::with_metadata(FieldExpressionMetadata::<
41            F,
42            Rv32VecHeapAdapterExecutor<1, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
43        >::new(total_input_limbs));
44
45        let record_size = RecordSeeker::<
46            DenseRecordArena,
47            EccRecord<1, BLOCKS, BLOCK_SIZE>,
48            _,
49        >::get_aligned_record_size(&layout);
50
51        let records = arena.allocated();
52        if records.is_empty() {
53            return get_empty_air_proving_ctx::<GpuBackend>();
54        }
55        debug_assert_eq!(records.len() % record_size, 0);
56
57        let num_records = records.len() / record_size;
58
59        let local_opcode_idx = vec![
60            Rv32WeierstrassOpcode::EC_DOUBLE as usize,
61            Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize,
62        ];
63
64        let air = FieldExpressionCoreAir::new(expr, self.offset, local_opcode_idx, vec![]);
65
66        let adapter_width =
67            Rv32VecHeapAdapterCols::<F, 1, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>::width();
68
69        let d_records = records.to_device().unwrap();
70
71        let field_expr_chip = FieldExpressionChipGPU::new(
72            air,
73            d_records,
74            num_records,
75            record_size,
76            adapter_width,
77            BLOCKS,
78            self.range_checker.clone(),
79            self.bitwise_lookup.clone(),
80            self.pointer_max_bits,
81            self.timestamp_max_bits,
82        );
83
84        let d_trace = field_expr_chip.generate_field_trace();
85
86        AirProvingContext::simple_no_pis(d_trace)
87    }
88}