openvm_ecc_circuit/weierstrass_chip/double/mod.rs

1use std::{cell::RefCell, rc::Rc};
2
3use derive_more::derive::{Deref, DerefMut};
4use num_bigint::BigUint;
5use num_traits::One;
6use openvm_circuit::{
7    arch::*,
8    system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper},
9};
10use openvm_circuit_derive::PreflightExecutor;
11use openvm_circuit_primitives::{
12    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
13    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
14};
15use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
16use openvm_instructions::riscv::RV32_CELL_BITS;
17use openvm_mod_circuit_builder::{
18    ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, FieldExpressionExecutor,
19    FieldExpressionFiller, FieldVariable,
20};
21use openvm_rv32_adapters::{
22    Rv32VecHeapAdapterAir, Rv32VecHeapAdapterExecutor, Rv32VecHeapAdapterFiller,
23};
24
25use super::{WeierstrassAir, WeierstrassChip};
26
27#[cfg(feature = "cuda")]
28mod cuda;
29mod execution;
30
31#[cfg(feature = "cuda")]
32pub use cuda::*;
33
/// Builds the [`FieldExpr`] for affine point doubling on a short Weierstrass
/// curve `y^2 = x^3 + a*x + b` over the coordinate field described by `config`
/// (`b` does not enter the doubling formulas).
///
/// Inputs, in creation order: `x1`, `y1`. Outputs, in save order: `x3`, `y3`, where
///
/// ```text
/// lambda = (3 * x1^2 + a) / (2 * y1)
/// x3     = lambda^2 - 2 * x1
/// y3     = lambda * (x1 - x3) - y1
/// ```
///
/// A flag chooses the denominator. When the flag is set (doubling), the
/// denominator is `2 * y1`. When the flag is clear (the setup opcode), it is the
/// constant `1`, so the expression never divides by an arbitrary `y1`.
///
/// The coefficient `a` is embedded as a builder constant. It is also passed as the
/// sole setup value to `FieldExpr::new_with_setup_values`.
///
/// # Arguments
/// * `config` - limb/modulus configuration of the coordinate field; checked with
///   `check_valid` first (presumably panics on an invalid config — TODO confirm).
/// * `range_bus` - range-checker bus; its `range_max_bits` sizes the builder and the
///   bus itself is handed to the resulting `FieldExpr`.
/// * `a_biguint` - the curve coefficient `a`.
///
/// NOTE(review): nothing here guards the doubling case `y1 == 0` (a 2-torsion
/// point), where the denominator `2 * y1` is zero; callers must exclude it.
pub fn ec_double_ne_expr(
    config: ExprBuilderConfig, // The coordinate field.
    range_bus: VariableRangeCheckerBus,
    a_biguint: BigUint,
) -> FieldExpr {
    config.check_valid();
    let builder = ExprBuilder::new(config, range_bus.range_max_bits);
    // Rc<RefCell<_>> so the same builder can be handed (via `clone`) to every
    // `new_input` / `new_const` call below.
    let builder = Rc::new(RefCell::new(builder));

    // Statement order below is significant: inputs, constants, the flag and any
    // intermediates are registered with the builder in the order they are created.
    let mut x1 = ExprBuilder::new_input(builder.clone());
    let mut y1 = ExprBuilder::new_input(builder.clone());
    let a = ExprBuilder::new_const(builder.clone(), a_biguint.clone());
    let is_double_flag = (*builder).borrow_mut().new_flag();
    // We need to prevent divide by zero when not double flag
    // (equivalently, when it is the setup opcode)
    let lambda_denom = FieldVariable::select(
        is_double_flag,
        &y1.int_mul(2),
        &ExprBuilder::new_const(builder.clone(), BigUint::one()),
    );
    // Tangent slope: (3 * x1^2 + a) / (2 * y1) in the doubling case.
    let mut lambda = (x1.square().int_mul(3) + a) / lambda_denom;
    let mut x3 = lambda.square() - x1.int_mul(2);
    x3.save_output();
    let mut y3 = lambda * (x1 - x3.clone()) - y1;
    y3.save_output();

    // Copy the finished builder out of the RefCell to hand it to `FieldExpr`.
    let builder = (*builder).borrow().clone();
    // NOTE(review): the `true` argument presumably marks this expression as having a
    // setup opcode — confirm against `FieldExpr::new_with_setup_values`.
    FieldExpr::new_with_setup_values(builder, range_bus, true, vec![a_biguint])
}
63
/// Preflight executor for the EC point-doubling chip. It is a newtype over a
/// [`FieldExpressionExecutor`] driven by an [`Rv32VecHeapAdapterExecutor`], and it
/// dereferences to that inner executor. Instances are built by `get_ec_double_step`.
///
/// - `BLOCK_SIZE`: number of cells read or written per memory access; must be a power of 2.
/// - `BLOCKS`: number of such blocks needed to hold one input or output point.
///
/// For example, for bls12_381, `BLOCK_SIZE = 16`. Each coordinate then takes 3 blocks, and
/// an affine point has two coordinates, so `BLOCKS = 6`. For secp256k1, `BLOCK_SIZE = 32`
/// and `BLOCKS = 2`.
#[derive(Clone, PreflightExecutor, Deref, DerefMut)]
pub struct EcDoubleExecutor<const BLOCKS: usize, const BLOCK_SIZE: usize>(
    // NOTE(review): adapter params presumably read as <number of input pointers = 1,
    // blocks per read, blocks per write, read size, write size> — confirm against
    // `Rv32VecHeapAdapterExecutor`.
    FieldExpressionExecutor<Rv32VecHeapAdapterExecutor<1, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>>,
);
72
73fn gen_base_expr(
74    config: ExprBuilderConfig,
75    range_checker_bus: VariableRangeCheckerBus,
76    a_biguint: BigUint,
77) -> (FieldExpr, Vec<usize>) {
78    let expr = ec_double_ne_expr(config, range_checker_bus, a_biguint);
79
80    let local_opcode_idx = vec![
81        Rv32WeierstrassOpcode::EC_DOUBLE as usize,
82        Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize,
83    ];
84
85    (expr, local_opcode_idx)
86}
87
88#[allow(clippy::too_many_arguments)]
89pub fn get_ec_double_air<const BLOCKS: usize, const BLOCK_SIZE: usize>(
90    exec_bridge: ExecutionBridge,
91    mem_bridge: MemoryBridge,
92    config: ExprBuilderConfig,
93    range_checker_bus: VariableRangeCheckerBus,
94    bitwise_lookup_bus: BitwiseOperationLookupBus,
95    pointer_max_bits: usize,
96    offset: usize,
97    a_biguint: BigUint,
98) -> WeierstrassAir<1, BLOCKS, BLOCK_SIZE> {
99    let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus, a_biguint);
100    WeierstrassAir::new(
101        Rv32VecHeapAdapterAir::new(
102            exec_bridge,
103            mem_bridge,
104            bitwise_lookup_bus,
105            pointer_max_bits,
106        ),
107        FieldExpressionCoreAir::new(expr.clone(), offset, local_opcode_idx.clone(), vec![]),
108    )
109}
110
111pub fn get_ec_double_step<const BLOCKS: usize, const BLOCK_SIZE: usize>(
112    config: ExprBuilderConfig,
113    range_checker_bus: VariableRangeCheckerBus,
114    pointer_max_bits: usize,
115    offset: usize,
116    a_biguint: BigUint,
117) -> EcDoubleExecutor<BLOCKS, BLOCK_SIZE> {
118    let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus, a_biguint);
119    EcDoubleExecutor(FieldExpressionExecutor::new(
120        Rv32VecHeapAdapterExecutor::new(pointer_max_bits),
121        expr,
122        offset,
123        local_opcode_idx,
124        vec![],
125        "EcDouble",
126    ))
127}
128
129pub fn get_ec_double_chip<F, const BLOCKS: usize, const BLOCK_SIZE: usize>(
130    config: ExprBuilderConfig,
131    mem_helper: SharedMemoryHelper<F>,
132    range_checker: SharedVariableRangeCheckerChip,
133    bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
134    pointer_max_bits: usize,
135    a_biguint: BigUint,
136) -> WeierstrassChip<F, 1, BLOCKS, BLOCK_SIZE> {
137    let (expr, local_opcode_idx) = gen_base_expr(config, range_checker.bus(), a_biguint);
138    WeierstrassChip::new(
139        FieldExpressionFiller::new(
140            Rv32VecHeapAdapterFiller::new(pointer_max_bits, bitwise_lookup_chip),
141            expr,
142            local_opcode_idx,
143            vec![],
144            range_checker,
145            true,
146        ),
147        mem_helper,
148    )
149}