openvm_ecc_circuit/weierstrass_chip/add_ne/execution.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_algebra_circuit::fields::{get_field_type, FieldType};
8use openvm_circuit::{
9    arch::*,
10    system::memory::{online::GuestMemory, POINTER_MAX_BITS},
11};
12use openvm_circuit_primitives::AlignedBytesBorrow;
13use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
14use openvm_instructions::{
15    instruction::Instruction,
16    program::DEFAULT_PC_STEP,
17    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
18};
19use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
20use openvm_stark_backend::p3_field::PrimeField32;
21
22use super::EcAddNeExecutor;
23use crate::weierstrass_chip::curves::ec_add_ne;
24
25#[derive(AlignedBytesBorrow, Clone)]
26#[repr(C)]
27struct EcAddNePreCompute<'a> {
28    expr: &'a FieldExpr,
29    rs_addrs: [u8; 2],
30    a: u8,
31    flag_idx: u8,
32}
33
34impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> EcAddNeExecutor<BLOCKS, BLOCK_SIZE> {
35    fn pre_compute_impl<F: PrimeField32>(
36        &'a self,
37        pc: u32,
38        inst: &Instruction<F>,
39        data: &mut EcAddNePreCompute<'a>,
40    ) -> Result<bool, StaticProgramError> {
41        let Instruction {
42            opcode,
43            a,
44            b,
45            c,
46            d,
47            e,
48            ..
49        } = inst;
50
51        // Validate instruction format
52        let a = a.as_canonical_u32();
53        let b = b.as_canonical_u32();
54        let c = c.as_canonical_u32();
55        let d = d.as_canonical_u32();
56        let e = e.as_canonical_u32();
57        if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
58            return Err(StaticProgramError::InvalidInstruction(pc));
59        }
60
61        let local_opcode = opcode.local_opcode_idx(self.offset);
62
63        // Pre-compute flag_idx
64        let needs_setup = self.expr.needs_setup();
65        let mut flag_idx = self.expr.num_flags() as u8;
66        if needs_setup {
67            // Find which opcode this is in our local_opcode_idx list
68            if let Some(opcode_position) = self
69                .local_opcode_idx
70                .iter()
71                .position(|&idx| idx == local_opcode)
72            {
73                // If this is NOT the last opcode (setup), get the corresponding flag_idx
74                if opcode_position < self.opcode_flag_idx.len() {
75                    flag_idx = self.opcode_flag_idx[opcode_position] as u8;
76                }
77            }
78        }
79
80        let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8);
81        *data = EcAddNePreCompute {
82            expr: &self.expr,
83            rs_addrs,
84            a: a as u8,
85            flag_idx,
86        };
87
88        let local_opcode = opcode.local_opcode_idx(self.offset);
89        let is_setup = local_opcode == Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize;
90
91        Ok(is_setup)
92    }
93}
94
/// Selects the monomorphized execute function for one instruction.
///
/// Expands to an `Ok(...)` expression wrapping `$execute_impl::<_, _, BLOCKS, BLOCK_SIZE,
/// FIELD_TYPE, IS_SETUP>`, so it must be used where `BLOCKS` and `BLOCK_SIZE` are in scope
/// and the expected return type fixes the function-pointer type.
///
/// - `$execute_impl`: the generic execute function to instantiate.
/// - `$pre_compute`: an `EcAddNePreCompute` binding; its `expr.builder.prime` picks the field.
/// - `$is_setup`: runtime `bool` lifted into the `IS_SETUP` const parameter.
///
/// Moduli with a specialized implementation (K256, P256, BN254, BLS12-381 coordinate
/// fields) get their `FieldType` as `FIELD_TYPE`. Every other modulus falls back to
/// `u8::MAX`, which selects the generic field-expression path. This includes moduli that
/// `get_field_type` recognizes but that have no specialized `ec_add_ne` path, which
/// previously caused a panic.
macro_rules! dispatch {
    // Internal rule: choose FIELD_TYPE for a fixed `IS_SETUP` literal.
    (@field $execute_impl:ident, $pre_compute:ident, $setup:tt) => {
        match get_field_type(&$pre_compute.expr.builder.prime) {
            Some(FieldType::K256Coordinate) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { FieldType::K256Coordinate as u8 },
                $setup,
            >),
            Some(FieldType::P256Coordinate) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { FieldType::P256Coordinate as u8 },
                $setup,
            >),
            Some(FieldType::BN254Coordinate) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { FieldType::BN254Coordinate as u8 },
                $setup,
            >),
            Some(FieldType::BLS12_381Coordinate) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { FieldType::BLS12_381Coordinate as u8 },
                $setup,
            >),
            // Unknown moduli, and known field types without a specialized implementation,
            // use the generic field-expression path instead of panicking.
            _ => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, $setup>),
        }
    };
    ($execute_impl:ident, $pre_compute:ident, $is_setup:ident) => {
        if $is_setup {
            dispatch!(@field $execute_impl, $pre_compute, true)
        } else {
            dispatch!(@field $execute_impl, $pre_compute, false)
        }
    };
}
175impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> Executor<F>
176    for EcAddNeExecutor<BLOCKS, BLOCK_SIZE>
177{
178    #[inline(always)]
179    fn pre_compute_size(&self) -> usize {
180        std::mem::size_of::<EcAddNePreCompute>()
181    }
182
183    fn pre_compute<Ctx>(
184        &self,
185        pc: u32,
186        inst: &Instruction<F>,
187        data: &mut [u8],
188    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
189    where
190        Ctx: ExecutionCtxTrait,
191    {
192        let pre_compute: &mut EcAddNePreCompute = data.borrow_mut();
193        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
194
195        dispatch!(execute_e1_impl, pre_compute, is_setup)
196    }
197
198    #[cfg(feature = "tco")]
199    fn handler<Ctx>(
200        &self,
201        pc: u32,
202        inst: &Instruction<F>,
203        data: &mut [u8],
204    ) -> Result<Handler<F, Ctx>, StaticProgramError>
205    where
206        Ctx: ExecutionCtxTrait,
207    {
208        let pre_compute: &mut EcAddNePreCompute = data.borrow_mut();
209        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
210
211        dispatch!(execute_e1_tco_handler, pre_compute, is_setup)
212    }
213}
214
215impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> MeteredExecutor<F>
216    for EcAddNeExecutor<BLOCKS, BLOCK_SIZE>
217{
218    #[inline(always)]
219    fn metered_pre_compute_size(&self) -> usize {
220        std::mem::size_of::<E2PreCompute<EcAddNePreCompute>>()
221    }
222
223    fn metered_pre_compute<Ctx>(
224        &self,
225        chip_idx: usize,
226        pc: u32,
227        inst: &Instruction<F>,
228        data: &mut [u8],
229    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
230    where
231        Ctx: MeteredExecutionCtxTrait,
232    {
233        let pre_compute: &mut E2PreCompute<EcAddNePreCompute> = data.borrow_mut();
234        pre_compute.chip_idx = chip_idx as u32;
235
236        let pre_compute_pure = &mut pre_compute.data;
237        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
238        dispatch!(execute_e2_impl, pre_compute_pure, is_setup)
239    }
240
241    #[cfg(feature = "tco")]
242    fn metered_handler<Ctx>(
243        &self,
244        chip_idx: usize,
245        pc: u32,
246        inst: &Instruction<F>,
247        data: &mut [u8],
248    ) -> Result<Handler<F, Ctx>, StaticProgramError>
249    where
250        Ctx: MeteredExecutionCtxTrait,
251    {
252        let pre_compute: &mut E2PreCompute<EcAddNePreCompute> = data.borrow_mut();
253        pre_compute.chip_idx = chip_idx as u32;
254
255        let pre_compute_pure = &mut pre_compute.data;
256        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
257        dispatch!(execute_e2_tco_handler, pre_compute_pure, is_setup)
258    }
259}
260
261unsafe fn execute_e12_impl<
262    F: PrimeField32,
263    CTX: ExecutionCtxTrait,
264    const BLOCKS: usize,
265    const BLOCK_SIZE: usize,
266    const FIELD_TYPE: u8,
267    const IS_SETUP: bool,
268>(
269    pre_compute: &EcAddNePreCompute,
270    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
271) {
272    // Read register values
273    let rs_vals = pre_compute
274        .rs_addrs
275        .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32)));
276
277    // Read memory values for both points
278    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
279        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
280        from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
281    });
282
283    if IS_SETUP {
284        let input_prime = BigUint::from_bytes_le(read_data[0][..BLOCKS / 2].as_flattened());
285        if input_prime != pre_compute.expr.prime {
286            vm_state.exit_code = Err(ExecutionError::Fail {
287                pc: vm_state.pc,
288                msg: "EcAddNe: mismatched prime",
289            });
290            return;
291        }
292    }
293
294    let output_data = if FIELD_TYPE == u8::MAX || IS_SETUP {
295        let read_data: DynArray<u8> = read_data.into();
296        run_field_expression_precomputed::<true>(
297            pre_compute.expr,
298            pre_compute.flag_idx as usize,
299            &read_data.0,
300        )
301        .into()
302    } else {
303        ec_add_ne::<FIELD_TYPE, BLOCKS, BLOCK_SIZE>(read_data)
304    };
305
306    let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
307    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
308
309    // Write output data to memory
310    for (i, block) in output_data.into_iter().enumerate() {
311        vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
312    }
313
314    vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
315    vm_state.instret += 1;
316}
317
318#[create_tco_handler]
319unsafe fn execute_e1_impl<
320    F: PrimeField32,
321    CTX: ExecutionCtxTrait,
322    const BLOCKS: usize,
323    const BLOCK_SIZE: usize,
324    const FIELD_TYPE: u8,
325    const IS_SETUP: bool,
326>(
327    pre_compute: &[u8],
328    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
329) {
330    let pre_compute: &EcAddNePreCompute = pre_compute.borrow();
331    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, FIELD_TYPE, IS_SETUP>(pre_compute, vm_state);
332}
333
334#[create_tco_handler]
335unsafe fn execute_e2_impl<
336    F: PrimeField32,
337    CTX: MeteredExecutionCtxTrait,
338    const BLOCKS: usize,
339    const BLOCK_SIZE: usize,
340    const FIELD_TYPE: u8,
341    const IS_SETUP: bool,
342>(
343    pre_compute: &[u8],
344    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
345) {
346    let e2_pre_compute: &E2PreCompute<EcAddNePreCompute> = pre_compute.borrow();
347    vm_state
348        .ctx
349        .on_height_change(e2_pre_compute.chip_idx as usize, 1);
350    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, FIELD_TYPE, IS_SETUP>(
351        &e2_pre_compute.data,
352        vm_state,
353    );
354}