openvm_ecc_circuit/weierstrass_chip/add_ne/
execution.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_algebra_circuit::fields::{get_field_type, FieldType};
8use openvm_circuit::{
9    arch::*,
10    system::memory::{online::GuestMemory, POINTER_MAX_BITS},
11};
12use openvm_circuit_primitives::AlignedBytesBorrow;
13use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
14use openvm_instructions::{
15    instruction::Instruction,
16    program::DEFAULT_PC_STEP,
17    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
18};
19use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
20use openvm_stark_backend::p3_field::PrimeField32;
21
22use super::EcAddNeExecutor;
23use crate::weierstrass_chip::curves::ec_add_ne;
24
25#[derive(AlignedBytesBorrow, Clone)]
26#[repr(C)]
27struct EcAddNePreCompute<'a> {
28    expr: &'a FieldExpr,
29    rs_addrs: [u8; 2],
30    a: u8,
31    flag_idx: u8,
32}
33
34impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> EcAddNeExecutor<BLOCKS, BLOCK_SIZE> {
35    fn pre_compute_impl<F: PrimeField32>(
36        &'a self,
37        pc: u32,
38        inst: &Instruction<F>,
39        data: &mut EcAddNePreCompute<'a>,
40    ) -> Result<bool, StaticProgramError> {
41        let Instruction {
42            opcode,
43            a,
44            b,
45            c,
46            d,
47            e,
48            ..
49        } = inst;
50
51        // Validate instruction format
52        let a = a.as_canonical_u32();
53        let b = b.as_canonical_u32();
54        let c = c.as_canonical_u32();
55        let d = d.as_canonical_u32();
56        let e = e.as_canonical_u32();
57        if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
58            return Err(StaticProgramError::InvalidInstruction(pc));
59        }
60
61        let local_opcode = opcode.local_opcode_idx(self.offset);
62
63        // Pre-compute flag_idx
64        let needs_setup = self.expr.needs_setup();
65        let mut flag_idx = self.expr.num_flags() as u8;
66        if needs_setup {
67            // Find which opcode this is in our local_opcode_idx list
68            if let Some(opcode_position) = self
69                .local_opcode_idx
70                .iter()
71                .position(|&idx| idx == local_opcode)
72            {
73                // If this is NOT the last opcode (setup), get the corresponding flag_idx
74                if opcode_position < self.opcode_flag_idx.len() {
75                    flag_idx = self.opcode_flag_idx[opcode_position] as u8;
76                }
77            }
78        }
79
80        let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8);
81        *data = EcAddNePreCompute {
82            expr: &self.expr,
83            rs_addrs,
84            a: a as u8,
85            flag_idx,
86        };
87
88        let local_opcode = opcode.local_opcode_idx(self.offset);
89        let is_setup = local_opcode == Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize;
90
91        Ok(is_setup)
92    }
93}
94
/// Selects the monomorphized handler for an instruction.
///
/// - `$execute_impl` is the handler name.
/// - `$pre_compute` is the already-populated `EcAddNePreCompute`.
/// - `$is_setup` is the flag returned by `pre_compute_impl`.
///
/// Expands to `Ok(handler)` and expects `BLOCKS` and `BLOCK_SIZE` to be in scope.
///
/// Setup instructions always take the generic field-expression path in `execute_e12_impl`,
/// which checks `FIELD_TYPE == u8::MAX || IS_SETUP`. They are therefore dispatched with
/// `FIELD_TYPE = u8::MAX` rather than being specialized per field.
///
/// Non-setup instructions are specialized only for the four coordinate field types listed
/// below. Every other modulus falls back to the generic path instead of panicking. That
/// includes fields `get_field_type` recognizes but that have no specialized EC routine here.
macro_rules! dispatch {
    ($execute_impl:ident, $pre_compute:ident, $is_setup:ident) => {{
        if $is_setup {
            Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>)
        } else {
            match get_field_type(&$pre_compute.expr.builder.prime) {
                Some(FieldType::K256Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::K256Coordinate as u8 },
                    false,
                >),
                Some(FieldType::P256Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::P256Coordinate as u8 },
                    false,
                >),
                Some(FieldType::BN254Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::BN254Coordinate as u8 },
                    false,
                >),
                Some(FieldType::BLS12_381Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::BLS12_381Coordinate as u8 },
                    false,
                >),
                // Unknown modulus, or a known field without a specialized EC routine:
                // the generic expression evaluator is correct for any prime.
                _ => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>),
            }
        }
    }};
}
175impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> InterpreterExecutor<F>
176    for EcAddNeExecutor<BLOCKS, BLOCK_SIZE>
177{
178    #[inline(always)]
179    fn pre_compute_size(&self) -> usize {
180        std::mem::size_of::<EcAddNePreCompute>()
181    }
182
183    #[cfg(not(feature = "tco"))]
184    fn pre_compute<Ctx>(
185        &self,
186        pc: u32,
187        inst: &Instruction<F>,
188        data: &mut [u8],
189    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
190    where
191        Ctx: ExecutionCtxTrait,
192    {
193        let pre_compute: &mut EcAddNePreCompute = data.borrow_mut();
194        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
195
196        dispatch!(execute_e1_handler, pre_compute, is_setup)
197    }
198
199    #[cfg(feature = "tco")]
200    fn handler<Ctx>(
201        &self,
202        pc: u32,
203        inst: &Instruction<F>,
204        data: &mut [u8],
205    ) -> Result<Handler<F, Ctx>, StaticProgramError>
206    where
207        Ctx: ExecutionCtxTrait,
208    {
209        let pre_compute: &mut EcAddNePreCompute = data.borrow_mut();
210        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
211
212        dispatch!(execute_e1_handler, pre_compute, is_setup)
213    }
214}
215
// AOT execution: the empty body means every `AotExecutor` method uses the trait's default.
#[cfg(feature = "aot")]
impl<F, const BLOCKS: usize, const BLOCK_SIZE: usize> AotExecutor<F>
    for EcAddNeExecutor<BLOCKS, BLOCK_SIZE>
where
    F: PrimeField32,
{
}
221
222impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> InterpreterMeteredExecutor<F>
223    for EcAddNeExecutor<BLOCKS, BLOCK_SIZE>
224{
225    #[inline(always)]
226    fn metered_pre_compute_size(&self) -> usize {
227        std::mem::size_of::<E2PreCompute<EcAddNePreCompute>>()
228    }
229
230    #[cfg(not(feature = "tco"))]
231    fn metered_pre_compute<Ctx>(
232        &self,
233        chip_idx: usize,
234        pc: u32,
235        inst: &Instruction<F>,
236        data: &mut [u8],
237    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
238    where
239        Ctx: MeteredExecutionCtxTrait,
240    {
241        let pre_compute: &mut E2PreCompute<EcAddNePreCompute> = data.borrow_mut();
242        pre_compute.chip_idx = chip_idx as u32;
243
244        let pre_compute_pure = &mut pre_compute.data;
245        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
246        dispatch!(execute_e2_handler, pre_compute_pure, is_setup)
247    }
248
249    #[cfg(feature = "tco")]
250    fn metered_handler<Ctx>(
251        &self,
252        chip_idx: usize,
253        pc: u32,
254        inst: &Instruction<F>,
255        data: &mut [u8],
256    ) -> Result<Handler<F, Ctx>, StaticProgramError>
257    where
258        Ctx: MeteredExecutionCtxTrait,
259    {
260        let pre_compute: &mut E2PreCompute<EcAddNePreCompute> = data.borrow_mut();
261        pre_compute.chip_idx = chip_idx as u32;
262
263        let pre_compute_pure = &mut pre_compute.data;
264        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
265        dispatch!(execute_e2_handler, pre_compute_pure, is_setup)
266    }
267}
// Metered AOT execution: the empty body means every `AotMeteredExecutor` method uses the
// trait's default.
#[cfg(feature = "aot")]
impl<F, const BLOCKS: usize, const BLOCK_SIZE: usize> AotMeteredExecutor<F>
    for EcAddNeExecutor<BLOCKS, BLOCK_SIZE>
where
    F: PrimeField32,
{
}
273
274#[inline(always)]
275unsafe fn execute_e12_impl<
276    F: PrimeField32,
277    CTX: ExecutionCtxTrait,
278    const BLOCKS: usize,
279    const BLOCK_SIZE: usize,
280    const FIELD_TYPE: u8,
281    const IS_SETUP: bool,
282>(
283    pre_compute: &EcAddNePreCompute,
284    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
285) -> Result<(), ExecutionError> {
286    let pc = exec_state.pc();
287    // Read register values
288    let rs_vals = pre_compute
289        .rs_addrs
290        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));
291
292    // Read memory values for both points
293    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
294        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
295        from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
296    });
297
298    if IS_SETUP {
299        let input_prime = BigUint::from_bytes_le(read_data[0][..BLOCKS / 2].as_flattened());
300        if input_prime != pre_compute.expr.prime {
301            let err = ExecutionError::Fail {
302                pc,
303                msg: "EcAddNe: mismatched prime",
304            };
305            return Err(err);
306        }
307    }
308
309    let output_data = if FIELD_TYPE == u8::MAX || IS_SETUP {
310        let read_data: DynArray<u8> = read_data.into();
311        run_field_expression_precomputed::<true>(
312            pre_compute.expr,
313            pre_compute.flag_idx as usize,
314            &read_data.0,
315        )
316        .into()
317    } else {
318        ec_add_ne::<FIELD_TYPE, BLOCKS, BLOCK_SIZE>(read_data)
319    };
320
321    let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
322    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
323
324    // Write output data to memory
325    for (i, block) in output_data.into_iter().enumerate() {
326        exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
327    }
328
329    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
330
331    Ok(())
332}
333
334#[create_handler]
335#[inline(always)]
336unsafe fn execute_e1_impl<
337    F: PrimeField32,
338    CTX: ExecutionCtxTrait,
339    const BLOCKS: usize,
340    const BLOCK_SIZE: usize,
341    const FIELD_TYPE: u8,
342    const IS_SETUP: bool,
343>(
344    pre_compute: *const u8,
345    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
346) -> Result<(), ExecutionError> {
347    let pre_compute: &EcAddNePreCompute =
348        std::slice::from_raw_parts(pre_compute, size_of::<EcAddNePreCompute>()).borrow();
349    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, FIELD_TYPE, IS_SETUP>(pre_compute, exec_state)
350}
351
352#[create_handler]
353#[inline(always)]
354unsafe fn execute_e2_impl<
355    F: PrimeField32,
356    CTX: MeteredExecutionCtxTrait,
357    const BLOCKS: usize,
358    const BLOCK_SIZE: usize,
359    const FIELD_TYPE: u8,
360    const IS_SETUP: bool,
361>(
362    pre_compute: *const u8,
363    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
364) -> Result<(), ExecutionError> {
365    let e2_pre_compute: &E2PreCompute<EcAddNePreCompute> =
366        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<EcAddNePreCompute>>())
367            .borrow();
368    exec_state
369        .ctx
370        .on_height_change(e2_pre_compute.chip_idx as usize, 1);
371    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, FIELD_TYPE, IS_SETUP>(
372        &e2_pre_compute.data,
373        exec_state,
374    )
375}