// openvm_ecc_circuit/weierstrass_chip/add_ne/execution.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_algebra_circuit::fields::{get_field_type, FieldType};
8use openvm_circuit::{
9    arch::*,
10    system::memory::{online::GuestMemory, POINTER_MAX_BITS},
11};
12use openvm_circuit_primitives::AlignedBytesBorrow;
13use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
14use openvm_instructions::{
15    instruction::Instruction,
16    program::DEFAULT_PC_STEP,
17    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
18};
19use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
20use openvm_stark_backend::p3_field::PrimeField32;
21
22use super::EcAddNeExecutor;
23use crate::weierstrass_chip::curves::ec_add_ne;
24
/// Per-instruction data decoded once by `pre_compute_impl` and read by the
/// execution handlers in this file.
///
/// The handlers obtain this record by calling `borrow`/`borrow_mut` on a raw
/// pre-compute byte buffer (`AlignedBytesBorrow`), and the struct is `#[repr(C)]`,
/// so the declaration order of the fields defines its in-buffer layout.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct EcAddNePreCompute<'a> {
    /// Field expression owned by the executor. It is evaluated on the generic
    /// execution path, and its prime is used for setup validation and field-type
    /// dispatch.
    expr: &'a FieldExpr,
    /// Register addresses (instruction operands `b`, `c`) whose values are the
    /// memory pointers of the two input operands.
    rs_addrs: [u8; 2],
    /// Register address (instruction operand `a`) whose value is the memory pointer
    /// the result is written to.
    a: u8,
    /// Flag index passed to `run_field_expression_precomputed`.
    flag_idx: u8,
}
33
34impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> EcAddNeExecutor<BLOCKS, BLOCK_SIZE> {
35    fn pre_compute_impl<F: PrimeField32>(
36        &'a self,
37        pc: u32,
38        inst: &Instruction<F>,
39        data: &mut EcAddNePreCompute<'a>,
40    ) -> Result<bool, StaticProgramError> {
41        let Instruction {
42            opcode,
43            a,
44            b,
45            c,
46            d,
47            e,
48            ..
49        } = inst;
50
51        // Validate instruction format
52        let a = a.as_canonical_u32();
53        let b = b.as_canonical_u32();
54        let c = c.as_canonical_u32();
55        let d = d.as_canonical_u32();
56        let e = e.as_canonical_u32();
57        if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
58            return Err(StaticProgramError::InvalidInstruction(pc));
59        }
60
61        let local_opcode = opcode.local_opcode_idx(self.offset);
62
63        // Pre-compute flag_idx
64        let needs_setup = self.expr.needs_setup();
65        let mut flag_idx = self.expr.num_flags() as u8;
66        if needs_setup {
67            // Find which opcode this is in our local_opcode_idx list
68            if let Some(opcode_position) = self
69                .local_opcode_idx
70                .iter()
71                .position(|&idx| idx == local_opcode)
72            {
73                // If this is NOT the last opcode (setup), get the corresponding flag_idx
74                if opcode_position < self.opcode_flag_idx.len() {
75                    flag_idx = self.opcode_flag_idx[opcode_position] as u8;
76                }
77            }
78        }
79
80        let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8);
81        *data = EcAddNePreCompute {
82            expr: &self.expr,
83            rs_addrs,
84            a: a as u8,
85            flag_idx,
86        };
87
88        let local_opcode = opcode.local_opcode_idx(self.offset);
89        let is_setup = local_opcode == Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize;
90
91        Ok(is_setup)
92    }
93}
94
// Selects the monomorphized execution handler (`$execute_impl`) for an `EcAddNe`
// instruction and evaluates to `Ok(handler)`.
//
// - Setup instructions always take the generic field-expression path in
//   `execute_e12_impl`, which branches on `FIELD_TYPE == u8::MAX || IS_SETUP`.
//   They are therefore instantiated only with `FIELD_TYPE = u8::MAX`; per-field
//   setup instantiations would behave identically.
// - Non-setup instructions whose modulus is one of the supported coordinate fields
//   use the specialized `ec_add_ne::<FIELD_TYPE, ..>` implementation.
// - Any other modulus falls back to the generic path. This includes moduli
//   `get_field_type` does not know, and moduli it knows but that have no
//   specialized implementation here (e.g. a scalar field). Such moduli must not
//   panic, because the generic path handles any prime.
macro_rules! dispatch {
    ($execute_impl:ident, $pre_compute:ident, $is_setup:ident) => {
        if $is_setup {
            Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>)
        } else {
            match get_field_type(&$pre_compute.expr.builder.prime) {
                Some(FieldType::K256Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::K256Coordinate as u8 },
                    false,
                >),
                Some(FieldType::P256Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::P256Coordinate as u8 },
                    false,
                >),
                Some(FieldType::BN254Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::BN254Coordinate as u8 },
                    false,
                >),
                Some(FieldType::BLS12_381Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::BLS12_381Coordinate as u8 },
                    false,
                >),
                // Unknown modulus, or a known field without a specialized
                // implementation: use the generic field-expression path.
                _ => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>),
            }
        }
    };
}
175impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> Executor<F>
176    for EcAddNeExecutor<BLOCKS, BLOCK_SIZE>
177{
178    #[inline(always)]
179    fn pre_compute_size(&self) -> usize {
180        std::mem::size_of::<EcAddNePreCompute>()
181    }
182
183    #[cfg(not(feature = "tco"))]
184    fn pre_compute<Ctx>(
185        &self,
186        pc: u32,
187        inst: &Instruction<F>,
188        data: &mut [u8],
189    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
190    where
191        Ctx: ExecutionCtxTrait,
192    {
193        let pre_compute: &mut EcAddNePreCompute = data.borrow_mut();
194        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
195
196        dispatch!(execute_e1_handler, pre_compute, is_setup)
197    }
198
199    #[cfg(feature = "tco")]
200    fn handler<Ctx>(
201        &self,
202        pc: u32,
203        inst: &Instruction<F>,
204        data: &mut [u8],
205    ) -> Result<Handler<F, Ctx>, StaticProgramError>
206    where
207        Ctx: ExecutionCtxTrait,
208    {
209        let pre_compute: &mut EcAddNePreCompute = data.borrow_mut();
210        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
211
212        dispatch!(execute_e1_handler, pre_compute, is_setup)
213    }
214}
215
216impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> MeteredExecutor<F>
217    for EcAddNeExecutor<BLOCKS, BLOCK_SIZE>
218{
219    #[inline(always)]
220    fn metered_pre_compute_size(&self) -> usize {
221        std::mem::size_of::<E2PreCompute<EcAddNePreCompute>>()
222    }
223
224    #[cfg(not(feature = "tco"))]
225    fn metered_pre_compute<Ctx>(
226        &self,
227        chip_idx: usize,
228        pc: u32,
229        inst: &Instruction<F>,
230        data: &mut [u8],
231    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
232    where
233        Ctx: MeteredExecutionCtxTrait,
234    {
235        let pre_compute: &mut E2PreCompute<EcAddNePreCompute> = data.borrow_mut();
236        pre_compute.chip_idx = chip_idx as u32;
237
238        let pre_compute_pure = &mut pre_compute.data;
239        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
240        dispatch!(execute_e2_handler, pre_compute_pure, is_setup)
241    }
242
243    #[cfg(feature = "tco")]
244    fn metered_handler<Ctx>(
245        &self,
246        chip_idx: usize,
247        pc: u32,
248        inst: &Instruction<F>,
249        data: &mut [u8],
250    ) -> Result<Handler<F, Ctx>, StaticProgramError>
251    where
252        Ctx: MeteredExecutionCtxTrait,
253    {
254        let pre_compute: &mut E2PreCompute<EcAddNePreCompute> = data.borrow_mut();
255        pre_compute.chip_idx = chip_idx as u32;
256
257        let pre_compute_pure = &mut pre_compute.data;
258        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
259        dispatch!(execute_e2_handler, pre_compute_pure, is_setup)
260    }
261}
262
/// Execution body shared by the E1 and E2 handlers for `EcAddNe` and its setup opcode.
///
/// Steps:
/// 1. Reads the two registers at `pre_compute.rs_addrs`, each as a little-endian `u32`.
/// 2. Uses those values as memory pointers and reads `BLOCKS` blocks of `BLOCK_SIZE`
///    bytes from each.
/// 3. If `IS_SETUP`, interprets the first `BLOCKS / 2` blocks of the first input as a
///    little-endian integer. It must equal `pre_compute.expr.prime`; otherwise
///    `ExecutionError::Fail` is returned, `pc`/`instret` are left untouched and
///    nothing is written.
/// 4. Computes the output:
///    - with the generic field-expression evaluator when `FIELD_TYPE == u8::MAX` or
///      `IS_SETUP`;
///    - otherwise with the specialized `ec_add_ne::<FIELD_TYPE, ..>`.
/// 5. Writes the output in `BLOCK_SIZE` chunks to the memory pointer held in register
///    `pre_compute.a`.
/// 6. Advances `pc` by `DEFAULT_PC_STEP` and increments `instret`.
///
/// In debug builds, it asserts that every accessed pointer range stays below
/// `1 << POINTER_MAX_BITS`.
///
/// # Safety
///
/// NOTE(review): the function is `unsafe` but states no contract here. Presumably it
/// relies on the guest-memory accessors of `exec_state`; confirm the caller-side
/// requirements against `VmExecState::vm_read`/`vm_write`.
#[inline(always)]
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const FIELD_TYPE: u8,
    const IS_SETUP: bool,
>(
    pre_compute: &EcAddNePreCompute,
    instret: &mut u64,
    pc: &mut u32,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) -> Result<(), ExecutionError> {
    // Read register values (pointers to the two input points), little-endian
    let rs_vals = pre_compute
        .rs_addrs
        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));

    // Read memory values for both points: BLOCKS consecutive blocks per pointer
    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
        from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
    });

    // Setup: the first half of the first input must encode this chip's prime.
    if IS_SETUP {
        let input_prime = BigUint::from_bytes_le(read_data[0][..BLOCKS / 2].as_flattened());
        if input_prime != pre_compute.expr.prime {
            let err = ExecutionError::Fail {
                pc: *pc,
                msg: "EcAddNe: mismatched prime",
            };
            return Err(err);
        }
    }

    // Generic field-expression evaluation for unspecialized fields and for setup;
    // the specialized `ec_add_ne` otherwise.
    let output_data = if FIELD_TYPE == u8::MAX || IS_SETUP {
        let read_data: DynArray<u8> = read_data.into();
        run_field_expression_precomputed::<true>(
            pre_compute.expr,
            pre_compute.flag_idx as usize,
            &read_data.0,
        )
        .into()
    } else {
        ec_add_ne::<FIELD_TYPE, BLOCKS, BLOCK_SIZE>(read_data)
    };

    // The destination pointer register is read only after both inputs were read.
    let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));

    // Write output data to memory
    for (i, block) in output_data.into_iter().enumerate() {
        exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
    }

    *pc = pc.wrapping_add(DEFAULT_PC_STEP);
    *instret += 1;

    Ok(())
}
324
// Non-metered (E1) handler body for `EcAddNe`.
//
// Reinterprets the `pre_compute` bytes as the `EcAddNePreCompute` record. That record
// is written by `Executor::pre_compute` / `Executor::handler`. The function then
// delegates to `execute_e12_impl` with the same const parameters.
// `_instret_end` is not used here.
//
// NOTE(review): `#[create_handler]` presumably generates the `execute_e1_handler`
// referenced by `dispatch!`; confirm against the macro's definition.
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const FIELD_TYPE: u8,
    const IS_SETUP: bool,
>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _instret_end: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) -> Result<(), ExecutionError> {
    let pre_compute: &EcAddNePreCompute = pre_compute.borrow();
    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, FIELD_TYPE, IS_SETUP>(
        pre_compute,
        instret,
        pc,
        exec_state,
    )
}
349
// Metered (E2) handler body for `EcAddNe`.
//
// Steps:
// 1. Reinterprets the `pre_compute` bytes as the `E2PreCompute<EcAddNePreCompute>`
//    record, which is written by `MeteredExecutor::metered_pre_compute` /
//    `metered_handler`.
// 2. Reports a height change of 1 for chip `chip_idx` to the metering context.
// 3. Delegates to `execute_e12_impl`.
//
// The height change is reported before execution, so it is recorded even if
// `execute_e12_impl` returns an error. `_arg` is not used here.
//
// NOTE(review): `#[create_handler]` presumably generates the `execute_e2_handler`
// referenced by `dispatch!`; confirm against the macro's definition.
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const FIELD_TYPE: u8,
    const IS_SETUP: bool,
>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _arg: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) -> Result<(), ExecutionError> {
    let e2_pre_compute: &E2PreCompute<EcAddNePreCompute> = pre_compute.borrow();
    exec_state
        .ctx
        .on_height_change(e2_pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, FIELD_TYPE, IS_SETUP>(
        &e2_pre_compute.data,
        instret,
        pc,
        exec_state,
    )
}