// openvm_ecc_circuit/weierstrass_chip/double/execution.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4    mem::size_of,
5};
6
7use num_bigint::BigUint;
8use openvm_circuit::{
9    arch::*,
10    system::memory::{online::GuestMemory, POINTER_MAX_BITS},
11};
12use openvm_circuit_primitives::AlignedBytesBorrow;
13use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
14use openvm_instructions::{
15    instruction::Instruction,
16    program::DEFAULT_PC_STEP,
17    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
18};
19use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
20use openvm_stark_backend::p3_field::PrimeField32;
21
22use super::EcDoubleExecutor;
23use crate::weierstrass_chip::curves::{ec_double, get_curve_type, CurveType};
24
/// Per-instruction record written by `pre_compute_impl` and read back by the E1/E2 handlers.
///
/// The record is reinterpreted in place from the raw pre-compute byte buffer
/// (`data.borrow_mut()` / `pre_compute.borrow()` on `[u8]`), which is why it is `#[repr(C)]`
/// and derives `AlignedBytesBorrow`; its layout is tied to `pre_compute_size`.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct EcDoublePreCompute<'a> {
    /// Field expression evaluated for this instruction (borrowed from the executor's `expr`).
    expr: &'a FieldExpr,
    /// Address in `RV32_REGISTER_AS` of the register holding the pointer to the input point
    /// (instruction operand `b`).
    rs_addrs: [u8; 1],
    /// Address in `RV32_REGISTER_AS` of the register holding the pointer where the output is
    /// written (instruction operand `a`).
    a: u8,
    /// Flag index forwarded to `run_field_expression_precomputed`; equals `expr.num_flags()`
    /// when no opcode-specific flag applies.
    flag_idx: u8,
}
33
34impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> EcDoubleExecutor<BLOCKS, BLOCK_SIZE> {
35    fn pre_compute_impl<F: PrimeField32>(
36        &'a self,
37        pc: u32,
38        inst: &Instruction<F>,
39        data: &mut EcDoublePreCompute<'a>,
40    ) -> Result<bool, StaticProgramError> {
41        let Instruction {
42            opcode, a, b, d, e, ..
43        } = inst;
44
45        // Validate instruction format
46        let a = a.as_canonical_u32();
47        let b = b.as_canonical_u32();
48        let d = d.as_canonical_u32();
49        let e = e.as_canonical_u32();
50        if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
51            return Err(StaticProgramError::InvalidInstruction(pc));
52        }
53
54        let local_opcode = opcode.local_opcode_idx(self.offset);
55
56        // Pre-compute flag_idx
57        let needs_setup = self.expr.needs_setup();
58        let mut flag_idx = self.expr.num_flags() as u8;
59        if needs_setup {
60            // Find which opcode this is in our local_opcode_idx list
61            if let Some(opcode_position) = self
62                .local_opcode_idx
63                .iter()
64                .position(|&idx| idx == local_opcode)
65            {
66                // If this is NOT the last opcode (setup), get the corresponding flag_idx
67                if opcode_position < self.opcode_flag_idx.len() {
68                    flag_idx = self.opcode_flag_idx[opcode_position] as u8;
69                }
70            }
71        }
72
73        let rs_addrs = [b as u8];
74        *data = EcDoublePreCompute {
75            expr: &self.expr,
76            rs_addrs,
77            a: a as u8,
78            flag_idx,
79        };
80
81        let local_opcode = opcode.local_opcode_idx(self.offset);
82        let is_setup = local_opcode == Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize;
83
84        Ok(is_setup)
85    }
86}
87
/// Picks the monomorphised handler for an EC-double instruction.
///
/// Expands to `Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE, SETUP>)`:
/// - `CURVE` is the id that `get_curve_type` reports for the expression's prime and first
///   setup value, or `u8::MAX` when no curve is recognised.
/// - `SETUP` is the runtime value of `$is_setup`.
///
/// `BLOCKS` and `BLOCK_SIZE` must be in scope at the call site.
macro_rules! dispatch {
    ($execute_impl:ident, $pre_compute:ident, $is_setup:ident) => {{
        let curve = get_curve_type(
            &$pre_compute.expr.builder.prime,
            &$pre_compute.expr.setup_values[0],
        );
        match (curve, $is_setup) {
            (Some(CurveType::K256), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, true>)
            }
            (Some(CurveType::K256), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, false>)
            }
            (Some(CurveType::P256), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, true>)
            }
            (Some(CurveType::P256), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, false>)
            }
            (Some(CurveType::BN254), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, true>)
            }
            (Some(CurveType::BN254), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, false>)
            }
            (Some(CurveType::BLS12_381), true) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { CurveType::BLS12_381 as u8 },
                true,
            >),
            (Some(CurveType::BLS12_381), false) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { CurveType::BLS12_381 as u8 },
                false,
            >),
            // Unrecognised curve: fall back to the generic field-expression path.
            (None, true) => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>),
            (None, false) => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>),
        }
    }};
}
138
139impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> Executor<F>
140    for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
141{
142    #[inline(always)]
143    fn pre_compute_size(&self) -> usize {
144        size_of::<EcDoublePreCompute>()
145    }
146
147    #[cfg(not(feature = "tco"))]
148    fn pre_compute<Ctx>(
149        &self,
150        pc: u32,
151        inst: &Instruction<F>,
152        data: &mut [u8],
153    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
154    where
155        Ctx: ExecutionCtxTrait,
156    {
157        let pre_compute: &mut EcDoublePreCompute = data.borrow_mut();
158        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
159
160        dispatch!(execute_e1_handler, pre_compute, is_setup)
161    }
162
163    #[cfg(feature = "tco")]
164    fn handler<Ctx>(
165        &self,
166        pc: u32,
167        inst: &Instruction<F>,
168        data: &mut [u8],
169    ) -> Result<Handler<F, Ctx>, StaticProgramError>
170    where
171        Ctx: ExecutionCtxTrait,
172    {
173        let pre_compute: &mut EcDoublePreCompute = data.borrow_mut();
174        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
175
176        dispatch!(execute_e1_handler, pre_compute, is_setup)
177    }
178}
179
180impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> MeteredExecutor<F>
181    for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
182{
183    #[inline(always)]
184    fn metered_pre_compute_size(&self) -> usize {
185        size_of::<E2PreCompute<EcDoublePreCompute>>()
186    }
187
188    #[cfg(not(feature = "tco"))]
189    fn metered_pre_compute<Ctx>(
190        &self,
191        chip_idx: usize,
192        pc: u32,
193        inst: &Instruction<F>,
194        data: &mut [u8],
195    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
196    where
197        Ctx: MeteredExecutionCtxTrait,
198    {
199        let pre_compute: &mut E2PreCompute<EcDoublePreCompute> = data.borrow_mut();
200        pre_compute.chip_idx = chip_idx as u32;
201        let pre_compute_pure = &mut pre_compute.data;
202        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
203
204        dispatch!(execute_e2_handler, pre_compute_pure, is_setup)
205    }
206
207    #[cfg(feature = "tco")]
208    fn metered_handler<Ctx>(
209        &self,
210        chip_idx: usize,
211        pc: u32,
212        inst: &Instruction<F>,
213        data: &mut [u8],
214    ) -> Result<Handler<F, Ctx>, StaticProgramError>
215    where
216        Ctx: MeteredExecutionCtxTrait,
217    {
218        let pre_compute: &mut E2PreCompute<EcDoublePreCompute> = data.borrow_mut();
219        pre_compute.chip_idx = chip_idx as u32;
220        let pre_compute_pure = &mut pre_compute.data;
221        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
222
223        dispatch!(execute_e2_handler, pre_compute_pure, is_setup)
224    }
225}
226
227#[inline(always)]
228unsafe fn execute_e12_impl<
229    F: PrimeField32,
230    CTX: ExecutionCtxTrait,
231    const BLOCKS: usize,
232    const BLOCK_SIZE: usize,
233    const CURVE_TYPE: u8,
234    const IS_SETUP: bool,
235>(
236    pre_compute: &EcDoublePreCompute,
237    instret: &mut u64,
238    pc: &mut u32,
239    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
240) -> Result<(), ExecutionError> {
241    // Read register values
242    let rs_vals = pre_compute
243        .rs_addrs
244        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));
245
246    // Read memory values for the point
247    let read_data: [[u8; BLOCK_SIZE]; BLOCKS] = {
248        let address = rs_vals[0];
249        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
250        from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
251    };
252
253    if IS_SETUP {
254        let input_prime = BigUint::from_bytes_le(read_data[..BLOCKS / 2].as_flattened());
255
256        if input_prime != pre_compute.expr.builder.prime {
257            let err = ExecutionError::Fail {
258                pc: *pc,
259                msg: "EcDouble: mismatched prime",
260            };
261            return Err(err);
262        }
263
264        // Extract second field element as the a coefficient
265        let input_a = BigUint::from_bytes_le(read_data[BLOCKS / 2..].as_flattened());
266        let coeff_a = &pre_compute.expr.setup_values[0];
267        if input_a != *coeff_a {
268            let err = ExecutionError::Fail {
269                pc: *pc,
270                msg: "EcDouble: mismatched coeff_a",
271            };
272            return Err(err);
273        }
274    }
275
276    let output_data = if CURVE_TYPE == u8::MAX || IS_SETUP {
277        let read_data: DynArray<u8> = read_data.into();
278        run_field_expression_precomputed::<true>(
279            pre_compute.expr,
280            pre_compute.flag_idx as usize,
281            &read_data.0,
282        )
283        .into()
284    } else {
285        ec_double::<CURVE_TYPE, BLOCKS, BLOCK_SIZE>(read_data)
286    };
287
288    let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
289    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
290
291    // Write output data to memory
292    for (i, block) in output_data.into_iter().enumerate() {
293        exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
294    }
295
296    *pc = pc.wrapping_add(DEFAULT_PC_STEP);
297    *instret += 1;
298
299    Ok(())
300}
301
302#[create_handler]
303#[inline(always)]
304unsafe fn execute_e1_impl<
305    F: PrimeField32,
306    CTX: ExecutionCtxTrait,
307    const BLOCKS: usize,
308    const BLOCK_SIZE: usize,
309    const CURVE_TYPE: u8,
310    const IS_SETUP: bool,
311>(
312    pre_compute: &[u8],
313    instret: &mut u64,
314    pc: &mut u32,
315    _instret_end: u64,
316    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
317) -> Result<(), ExecutionError> {
318    let pre_compute: &EcDoublePreCompute = pre_compute.borrow();
319    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, IS_SETUP>(
320        pre_compute,
321        instret,
322        pc,
323        exec_state,
324    )
325}
326
327#[create_handler]
328#[inline(always)]
329unsafe fn execute_e2_impl<
330    F: PrimeField32,
331    CTX: MeteredExecutionCtxTrait,
332    const BLOCKS: usize,
333    const BLOCK_SIZE: usize,
334    const CURVE_TYPE: u8,
335    const IS_SETUP: bool,
336>(
337    pre_compute: &[u8],
338    instret: &mut u64,
339    pc: &mut u32,
340    _arg: u64,
341    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
342) -> Result<(), ExecutionError> {
343    let e2_pre_compute: &E2PreCompute<EcDoublePreCompute> = pre_compute.borrow();
344    exec_state
345        .ctx
346        .on_height_change(e2_pre_compute.chip_idx as usize, 1);
347    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, IS_SETUP>(
348        &e2_pre_compute.data,
349        instret,
350        pc,
351        exec_state,
352    )
353}