openvm_ecc_circuit/weierstrass_chip/double/
execution.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_circuit::{
8    arch::*,
9    system::memory::{online::GuestMemory, POINTER_MAX_BITS},
10};
11use openvm_circuit_primitives::AlignedBytesBorrow;
12use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
13use openvm_instructions::{
14    instruction::Instruction,
15    program::DEFAULT_PC_STEP,
16    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
17};
18use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
19use openvm_stark_backend::p3_field::PrimeField32;
20
21use super::EcDoubleExecutor;
22use crate::weierstrass_chip::curves::{ec_double, get_curve_type, CurveType};
23
24#[derive(AlignedBytesBorrow, Clone)]
25#[repr(C)]
26struct EcDoublePreCompute<'a> {
27    expr: &'a FieldExpr,
28    rs_addrs: [u8; 1],
29    a: u8,
30    flag_idx: u8,
31}
32
33impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> EcDoubleExecutor<BLOCKS, BLOCK_SIZE> {
34    fn pre_compute_impl<F: PrimeField32>(
35        &'a self,
36        pc: u32,
37        inst: &Instruction<F>,
38        data: &mut EcDoublePreCompute<'a>,
39    ) -> Result<bool, StaticProgramError> {
40        let Instruction {
41            opcode, a, b, d, e, ..
42        } = inst;
43
44        // Validate instruction format
45        let a = a.as_canonical_u32();
46        let b = b.as_canonical_u32();
47        let d = d.as_canonical_u32();
48        let e = e.as_canonical_u32();
49        if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
50            return Err(StaticProgramError::InvalidInstruction(pc));
51        }
52
53        let local_opcode = opcode.local_opcode_idx(self.offset);
54
55        // Pre-compute flag_idx
56        let needs_setup = self.expr.needs_setup();
57        let mut flag_idx = self.expr.num_flags() as u8;
58        if needs_setup {
59            // Find which opcode this is in our local_opcode_idx list
60            if let Some(opcode_position) = self
61                .local_opcode_idx
62                .iter()
63                .position(|&idx| idx == local_opcode)
64            {
65                // If this is NOT the last opcode (setup), get the corresponding flag_idx
66                if opcode_position < self.opcode_flag_idx.len() {
67                    flag_idx = self.opcode_flag_idx[opcode_position] as u8;
68                }
69            }
70        }
71
72        let rs_addrs = [b as u8];
73        *data = EcDoublePreCompute {
74            expr: &self.expr,
75            rs_addrs,
76            a: a as u8,
77            flag_idx,
78        };
79
80        let local_opcode = opcode.local_opcode_idx(self.offset);
81        let is_setup = local_opcode == Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize;
82
83        Ok(is_setup)
84    }
85}
86
/// Picks the instantiation of `$execute_impl` that matches the curve described by
/// `$pre_compute` and the `$is_setup` flag, and wraps it in `Ok`.
///
/// The curve is looked up with `get_curve_type` from the expression's prime and its first
/// setup value. When no known curve is reported, `u8::MAX` is used as the `CURVE_TYPE`
/// const argument. `BLOCKS` and `BLOCK_SIZE` are taken from the call site.
macro_rules! dispatch {
    ($execute_impl:ident,$pre_compute:ident,$is_setup:ident) => {
        match (
            get_curve_type(
                &$pre_compute.expr.builder.prime,
                &$pre_compute.expr.setup_values[0],
            ),
            $is_setup,
        ) {
            // Setup instruction on a recognised curve.
            (Some(CurveType::K256), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, true>)
            }
            (Some(CurveType::P256), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, true>)
            }
            (Some(CurveType::BN254), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, true>)
            }
            (Some(CurveType::BLS12_381), true) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { CurveType::BLS12_381 as u8 },
                true,
            >),
            // Regular doubling on a recognised curve.
            (Some(CurveType::K256), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, false>)
            }
            (Some(CurveType::P256), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, false>)
            }
            (Some(CurveType::BN254), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, false>)
            }
            (Some(CurveType::BLS12_381), false) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { CurveType::BLS12_381 as u8 },
                false,
            >),
            // No recognised curve: `u8::MAX` marks the curve type as unknown.
            (None, true) => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>),
            (None, false) => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>),
        }
    };
}
137
138impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> Executor<F>
139    for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
140{
141    #[inline(always)]
142    fn pre_compute_size(&self) -> usize {
143        std::mem::size_of::<EcDoublePreCompute>()
144    }
145
146    fn pre_compute<Ctx>(
147        &self,
148        pc: u32,
149        inst: &Instruction<F>,
150        data: &mut [u8],
151    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
152    where
153        Ctx: ExecutionCtxTrait,
154    {
155        let pre_compute: &mut EcDoublePreCompute = data.borrow_mut();
156        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
157
158        dispatch!(execute_e1_impl, pre_compute, is_setup)
159    }
160
161    #[cfg(feature = "tco")]
162    fn handler<Ctx>(
163        &self,
164        pc: u32,
165        inst: &Instruction<F>,
166        data: &mut [u8],
167    ) -> Result<Handler<F, Ctx>, StaticProgramError>
168    where
169        Ctx: ExecutionCtxTrait,
170    {
171        let pre_compute: &mut EcDoublePreCompute = data.borrow_mut();
172        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
173
174        dispatch!(execute_e1_tco_handler, pre_compute, is_setup)
175    }
176}
177
178impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> MeteredExecutor<F>
179    for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
180{
181    #[inline(always)]
182    fn metered_pre_compute_size(&self) -> usize {
183        std::mem::size_of::<E2PreCompute<EcDoublePreCompute>>()
184    }
185
186    fn metered_pre_compute<Ctx>(
187        &self,
188        chip_idx: usize,
189        pc: u32,
190        inst: &Instruction<F>,
191        data: &mut [u8],
192    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
193    where
194        Ctx: MeteredExecutionCtxTrait,
195    {
196        let pre_compute: &mut E2PreCompute<EcDoublePreCompute> = data.borrow_mut();
197        pre_compute.chip_idx = chip_idx as u32;
198        let pre_compute_pure = &mut pre_compute.data;
199        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
200
201        dispatch!(execute_e2_impl, pre_compute_pure, is_setup)
202    }
203
204    #[cfg(feature = "tco")]
205    fn metered_handler<Ctx>(
206        &self,
207        chip_idx: usize,
208        pc: u32,
209        inst: &Instruction<F>,
210        data: &mut [u8],
211    ) -> Result<Handler<F, Ctx>, StaticProgramError>
212    where
213        Ctx: MeteredExecutionCtxTrait,
214    {
215        let pre_compute: &mut E2PreCompute<EcDoublePreCompute> = data.borrow_mut();
216        pre_compute.chip_idx = chip_idx as u32;
217        let pre_compute_pure = &mut pre_compute.data;
218        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
219
220        dispatch!(execute_e2_tco_handler, pre_compute_pure, is_setup)
221    }
222}
223
224unsafe fn execute_e12_impl<
225    F: PrimeField32,
226    CTX: ExecutionCtxTrait,
227    const BLOCKS: usize,
228    const BLOCK_SIZE: usize,
229    const CURVE_TYPE: u8,
230    const IS_SETUP: bool,
231>(
232    pre_compute: &EcDoublePreCompute,
233    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
234) {
235    // Read register values
236    let rs_vals = pre_compute
237        .rs_addrs
238        .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32)));
239
240    // Read memory values for the point
241    let read_data: [[u8; BLOCK_SIZE]; BLOCKS] = {
242        let address = rs_vals[0];
243        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
244        from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
245    };
246
247    if IS_SETUP {
248        let input_prime = BigUint::from_bytes_le(read_data[..BLOCKS / 2].as_flattened());
249
250        if input_prime != pre_compute.expr.builder.prime {
251            vm_state.exit_code = Err(ExecutionError::Fail {
252                pc: vm_state.pc,
253                msg: "EcDouble: mismatched prime",
254            });
255            return;
256        }
257
258        // Extract second field element as the a coefficient
259        let input_a = BigUint::from_bytes_le(read_data[BLOCKS / 2..].as_flattened());
260        let coeff_a = &pre_compute.expr.setup_values[0];
261        if input_a != *coeff_a {
262            vm_state.exit_code = Err(ExecutionError::Fail {
263                pc: vm_state.pc,
264                msg: "EcDouble: mismatched coeff_a",
265            });
266            return;
267        }
268    }
269
270    let output_data = if CURVE_TYPE == u8::MAX || IS_SETUP {
271        let read_data: DynArray<u8> = read_data.into();
272        run_field_expression_precomputed::<true>(
273            pre_compute.expr,
274            pre_compute.flag_idx as usize,
275            &read_data.0,
276        )
277        .into()
278    } else {
279        ec_double::<CURVE_TYPE, BLOCKS, BLOCK_SIZE>(read_data)
280    };
281
282    let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
283    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
284
285    // Write output data to memory
286    for (i, block) in output_data.into_iter().enumerate() {
287        vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
288    }
289
290    vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
291    vm_state.instret += 1;
292}
293
294#[create_tco_handler]
295unsafe fn execute_e1_impl<
296    F: PrimeField32,
297    CTX: ExecutionCtxTrait,
298    const BLOCKS: usize,
299    const BLOCK_SIZE: usize,
300    const CURVE_TYPE: u8,
301    const IS_SETUP: bool,
302>(
303    pre_compute: &[u8],
304    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
305) {
306    let pre_compute: &EcDoublePreCompute = pre_compute.borrow();
307    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, IS_SETUP>(pre_compute, vm_state);
308}
309
310#[create_tco_handler]
311unsafe fn execute_e2_impl<
312    F: PrimeField32,
313    CTX: MeteredExecutionCtxTrait,
314    const BLOCKS: usize,
315    const BLOCK_SIZE: usize,
316    const CURVE_TYPE: u8,
317    const IS_SETUP: bool,
318>(
319    pre_compute: &[u8],
320    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
321) {
322    let e2_pre_compute: &E2PreCompute<EcDoublePreCompute> = pre_compute.borrow();
323    vm_state
324        .ctx
325        .on_height_change(e2_pre_compute.chip_idx as usize, 1);
326    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, IS_SETUP>(
327        &e2_pre_compute.data,
328        vm_state,
329    );
330}