openvm_ecc_circuit/weierstrass_chip/double/
execution.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4    mem::size_of,
5};
6
7use num_bigint::BigUint;
8use openvm_circuit::{
9    arch::*,
10    system::memory::{online::GuestMemory, POINTER_MAX_BITS},
11};
12use openvm_circuit_primitives::AlignedBytesBorrow;
13use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
14use openvm_instructions::{
15    instruction::Instruction,
16    program::DEFAULT_PC_STEP,
17    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
18};
19use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
20use openvm_stark_backend::p3_field::PrimeField32;
21
22use super::EcDoubleExecutor;
23use crate::weierstrass_chip::curves::{ec_double, get_curve_type, CurveType};
24
/// Per-instruction data cached ahead of execution for the EC double opcode.
///
/// Instances are written into, and later read back from, a raw byte buffer via
/// the `AlignedBytesBorrow` derive, so the `#[repr(C)]` field order must stay fixed.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct EcDoublePreCompute<'a> {
    /// Field expression borrowed from the executor; supplies the modulus
    /// (`builder.prime`), `setup_values[0]` (the `a` coefficient) and the
    /// generic evaluation path.
    expr: &'a FieldExpr,
    /// Instruction operand `b` truncated to `u8`: the register whose value is
    /// the memory address of the input blocks.
    rs_addrs: [u8; 1],
    /// Instruction operand `a` truncated to `u8`: the register whose value is
    /// the memory address the result is written to.
    a: u8,
    /// Flag index handed to `run_field_expression_precomputed`.
    flag_idx: u8,
}
33
34impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> EcDoubleExecutor<BLOCKS, BLOCK_SIZE> {
35    fn pre_compute_impl<F: PrimeField32>(
36        &'a self,
37        pc: u32,
38        inst: &Instruction<F>,
39        data: &mut EcDoublePreCompute<'a>,
40    ) -> Result<bool, StaticProgramError> {
41        let Instruction {
42            opcode, a, b, d, e, ..
43        } = inst;
44
45        // Validate instruction format
46        let a = a.as_canonical_u32();
47        let b = b.as_canonical_u32();
48        let d = d.as_canonical_u32();
49        let e = e.as_canonical_u32();
50        if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
51            return Err(StaticProgramError::InvalidInstruction(pc));
52        }
53
54        let local_opcode = opcode.local_opcode_idx(self.offset);
55
56        // Pre-compute flag_idx
57        let needs_setup = self.expr.needs_setup();
58        let mut flag_idx = self.expr.num_flags() as u8;
59        if needs_setup {
60            // Find which opcode this is in our local_opcode_idx list
61            if let Some(opcode_position) = self
62                .local_opcode_idx
63                .iter()
64                .position(|&idx| idx == local_opcode)
65            {
66                // If this is NOT the last opcode (setup), get the corresponding flag_idx
67                if opcode_position < self.opcode_flag_idx.len() {
68                    flag_idx = self.opcode_flag_idx[opcode_position] as u8;
69                }
70            }
71        }
72
73        let rs_addrs = [b as u8];
74        *data = EcDoublePreCompute {
75            expr: &self.expr,
76            rs_addrs,
77            a: a as u8,
78            flag_idx,
79        };
80
81        let local_opcode = opcode.local_opcode_idx(self.offset);
82        let is_setup = local_opcode == Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize;
83
84        Ok(is_setup)
85    }
86}
87
// Picks the monomorphized instance of `$execute_impl` for this instruction.
//
// The curve is looked up from the expression's modulus and its first setup
// value (the `a` coefficient). A recognized curve selects the instance tagged
// with that curve's discriminant; an unrecognized one selects the instance
// tagged `u8::MAX`. `$is_setup` chooses the `IS_SETUP` const argument.
// `BLOCKS` and `BLOCK_SIZE` are taken from the scope of the call site.
macro_rules! dispatch {
    ($execute_impl:ident,$pre_compute:ident,$is_setup:ident) => {
        match (
            get_curve_type(
                &$pre_compute.expr.builder.prime,
                &$pre_compute.expr.setup_values[0],
            ),
            $is_setup,
        ) {
            // Setup instruction on a recognized curve.
            (Some(CurveType::K256), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, true>)
            }
            (Some(CurveType::P256), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, true>)
            }
            (Some(CurveType::BN254), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, true>)
            }
            (Some(CurveType::BLS12_381), true) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { CurveType::BLS12_381 as u8 },
                true,
            >),
            // Regular instruction on a recognized curve.
            (Some(CurveType::K256), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, false>)
            }
            (Some(CurveType::P256), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, false>)
            }
            (Some(CurveType::BN254), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, false>)
            }
            (Some(CurveType::BLS12_381), false) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { CurveType::BLS12_381 as u8 },
                false,
            >),
            // Unrecognized curve.
            (None, true) => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>),
            (None, false) => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>),
        }
    };
}
138
139impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> InterpreterExecutor<F>
140    for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
141{
142    #[inline(always)]
143    fn pre_compute_size(&self) -> usize {
144        size_of::<EcDoublePreCompute>()
145    }
146
147    #[cfg(not(feature = "tco"))]
148    fn pre_compute<Ctx>(
149        &self,
150        pc: u32,
151        inst: &Instruction<F>,
152        data: &mut [u8],
153    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
154    where
155        Ctx: ExecutionCtxTrait,
156    {
157        let pre_compute: &mut EcDoublePreCompute = data.borrow_mut();
158        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
159
160        dispatch!(execute_e1_handler, pre_compute, is_setup)
161    }
162
163    #[cfg(feature = "tco")]
164    fn handler<Ctx>(
165        &self,
166        pc: u32,
167        inst: &Instruction<F>,
168        data: &mut [u8],
169    ) -> Result<Handler<F, Ctx>, StaticProgramError>
170    where
171        Ctx: ExecutionCtxTrait,
172    {
173        let pre_compute: &mut EcDoublePreCompute = data.borrow_mut();
174        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
175
176        dispatch!(execute_e1_handler, pre_compute, is_setup)
177    }
178}
179
/// Opts `EcDoubleExecutor` into `AotExecutor` without overriding any trait item.
#[cfg(feature = "aot")]
impl<F, const BLOCKS: usize, const BLOCK_SIZE: usize> AotExecutor<F>
    for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
where
    F: PrimeField32,
{
}
185
186impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> InterpreterMeteredExecutor<F>
187    for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
188{
189    #[inline(always)]
190    fn metered_pre_compute_size(&self) -> usize {
191        size_of::<E2PreCompute<EcDoublePreCompute>>()
192    }
193
194    #[cfg(not(feature = "tco"))]
195    fn metered_pre_compute<Ctx>(
196        &self,
197        chip_idx: usize,
198        pc: u32,
199        inst: &Instruction<F>,
200        data: &mut [u8],
201    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
202    where
203        Ctx: MeteredExecutionCtxTrait,
204    {
205        let pre_compute: &mut E2PreCompute<EcDoublePreCompute> = data.borrow_mut();
206        pre_compute.chip_idx = chip_idx as u32;
207        let pre_compute_pure = &mut pre_compute.data;
208        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
209
210        dispatch!(execute_e2_handler, pre_compute_pure, is_setup)
211    }
212
213    #[cfg(feature = "tco")]
214    fn metered_handler<Ctx>(
215        &self,
216        chip_idx: usize,
217        pc: u32,
218        inst: &Instruction<F>,
219        data: &mut [u8],
220    ) -> Result<Handler<F, Ctx>, StaticProgramError>
221    where
222        Ctx: MeteredExecutionCtxTrait,
223    {
224        let pre_compute: &mut E2PreCompute<EcDoublePreCompute> = data.borrow_mut();
225        pre_compute.chip_idx = chip_idx as u32;
226        let pre_compute_pure = &mut pre_compute.data;
227        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
228
229        dispatch!(execute_e2_handler, pre_compute_pure, is_setup)
230    }
231}
232
/// Opts `EcDoubleExecutor` into `AotMeteredExecutor` without overriding any
/// trait item.
#[cfg(feature = "aot")]
impl<F, const BLOCKS: usize, const BLOCK_SIZE: usize> AotMeteredExecutor<F>
    for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
where
    F: PrimeField32,
{
}
238
239#[inline(always)]
240unsafe fn execute_e12_impl<
241    F: PrimeField32,
242    CTX: ExecutionCtxTrait,
243    const BLOCKS: usize,
244    const BLOCK_SIZE: usize,
245    const CURVE_TYPE: u8,
246    const IS_SETUP: bool,
247>(
248    pre_compute: &EcDoublePreCompute,
249    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
250) -> Result<(), ExecutionError> {
251    let pc = exec_state.pc();
252    // Read register values
253    let rs_vals = pre_compute
254        .rs_addrs
255        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));
256
257    // Read memory values for the point
258    let read_data: [[u8; BLOCK_SIZE]; BLOCKS] = {
259        let address = rs_vals[0];
260        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
261        from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
262    };
263
264    if IS_SETUP {
265        let input_prime = BigUint::from_bytes_le(read_data[..BLOCKS / 2].as_flattened());
266
267        if input_prime != pre_compute.expr.builder.prime {
268            let err = ExecutionError::Fail {
269                pc,
270                msg: "EcDouble: mismatched prime",
271            };
272            return Err(err);
273        }
274
275        // Extract second field element as the a coefficient
276        let input_a = BigUint::from_bytes_le(read_data[BLOCKS / 2..].as_flattened());
277        let coeff_a = &pre_compute.expr.setup_values[0];
278        if input_a != *coeff_a {
279            let err = ExecutionError::Fail {
280                pc,
281                msg: "EcDouble: mismatched coeff_a",
282            };
283            return Err(err);
284        }
285    }
286
287    let output_data = if CURVE_TYPE == u8::MAX || IS_SETUP {
288        let read_data: DynArray<u8> = read_data.into();
289        run_field_expression_precomputed::<true>(
290            pre_compute.expr,
291            pre_compute.flag_idx as usize,
292            &read_data.0,
293        )
294        .into()
295    } else {
296        ec_double::<CURVE_TYPE, BLOCKS, BLOCK_SIZE>(read_data)
297    };
298
299    let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
300    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
301
302    // Write output data to memory
303    for (i, block) in output_data.into_iter().enumerate() {
304        exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
305    }
306
307    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
308
309    Ok(())
310}
311
312#[create_handler]
313#[inline(always)]
314unsafe fn execute_e1_impl<
315    F: PrimeField32,
316    CTX: ExecutionCtxTrait,
317    const BLOCKS: usize,
318    const BLOCK_SIZE: usize,
319    const CURVE_TYPE: u8,
320    const IS_SETUP: bool,
321>(
322    pre_compute: *const u8,
323    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
324) -> Result<(), ExecutionError> {
325    let pre_compute: &EcDoublePreCompute =
326        std::slice::from_raw_parts(pre_compute, size_of::<EcDoublePreCompute>()).borrow();
327    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, IS_SETUP>(pre_compute, exec_state)
328}
329
330#[create_handler]
331#[inline(always)]
332unsafe fn execute_e2_impl<
333    F: PrimeField32,
334    CTX: MeteredExecutionCtxTrait,
335    const BLOCKS: usize,
336    const BLOCK_SIZE: usize,
337    const CURVE_TYPE: u8,
338    const IS_SETUP: bool,
339>(
340    pre_compute: *const u8,
341    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
342) -> Result<(), ExecutionError> {
343    let e2_pre_compute: &E2PreCompute<EcDoublePreCompute> =
344        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<EcDoublePreCompute>>())
345            .borrow();
346    exec_state
347        .ctx
348        .on_height_change(e2_pre_compute.chip_idx as usize, 1);
349    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, IS_SETUP>(
350        &e2_pre_compute.data,
351        exec_state,
352    )
353}