use std::{
    array::from_fn,
    borrow::{Borrow, BorrowMut},
};

use num_bigint::BigUint;
use openvm_algebra_circuit::fields::{get_field_type, FieldType};
use openvm_circuit::{
    arch::*,
    system::memory::{online::GuestMemory, POINTER_MAX_BITS},
};
use openvm_circuit_primitives::AlignedBytesBorrow;
use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
};
use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
use openvm_stark_backend::p3_field::PrimeField32;

use super::EcAddNeExecutor;
use crate::weierstrass_chip::curves::ec_add_ne;

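/// Operands decoded once at pre-compute time so the execution handlers can run without
/// re-parsing the instruction: the chip's field expression, the two source registers holding
/// the input point pointers, the destination register, and the setup flag index.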
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct EcAddNePreCompute<'a> {
    expr: &'a FieldExpr,
    rs_addrs: [u8; 2],
    a: u8,
    flag_idx: u8,
}

impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> EcAddNeExecutor<BLOCKS, BLOCK_SIZE> {
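    /// Decodes `inst`, checks that it uses the register/memory address spaces, and fills
    /// `data`. Returns `Ok(true)` iff the instruction is `SETUP_EC_ADD_NE`.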
    fn pre_compute_impl<F: PrimeField32>(
        &'a self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut EcAddNePreCompute<'a>,
    ) -> Result<bool, StaticProgramError> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            ..
        } = inst;

        let a = a.as_canonical_u32();
        let b = b.as_canonical_u32();
        let c = c.as_canonical_u32();
        let d = d.as_canonical_u32();
        let e = e.as_canonical_u32();
        if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }

        let local_opcode = opcode.local_opcode_idx(self.offset);

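        // If the chip requires a setup instruction, map this opcode to the flag that selects
        // its branch inside the field expression; otherwise keep the default (`num_flags`).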
        let needs_setup = self.expr.needs_setup();
        let mut flag_idx = self.expr.num_flags() as u8;
        if needs_setup {
            if let Some(opcode_position) = self
                .local_opcode_idx
                .iter()
                .position(|&idx| idx == local_opcode)
            {
                if opcode_position < self.opcode_flag_idx.len() {
                    flag_idx = self.opcode_flag_idx[opcode_position] as u8;
                }
            }
        }

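        // `b` and `c` are the registers holding pointers to the two input points; `a` is the
        // register holding the pointer to the output point.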
        let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8);
        *data = EcAddNePreCompute {
            expr: &self.expr,
            rs_addrs,
            a: a as u8,
            flag_idx,
        };

        let is_setup = local_opcode == Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize;

        Ok(is_setup)
    }
}

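/// Selects a monomorphized handler based on whether the modulus is one of the natively
/// supported coordinate fields and whether the instruction is a setup instruction. Unknown
/// moduli fall back to the generic `FieldExpr` path (`FIELD_TYPE = u8::MAX`).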
macro_rules! dispatch {
    ($execute_impl:ident, $pre_compute:ident, $is_setup:ident) => {
        if let Some(field_type) = {
            let modulus = &$pre_compute.expr.builder.prime;
            get_field_type(modulus)
        } {
            match ($is_setup, field_type) {
                (true, FieldType::K256Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::K256Coordinate as u8 },
                    true,
                >),
                (true, FieldType::P256Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::P256Coordinate as u8 },
                    true,
                >),
                (true, FieldType::BN254Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::BN254Coordinate as u8 },
                    true,
                >),
                (true, FieldType::BLS12_381Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::BLS12_381Coordinate as u8 },
                    true,
                >),
                (false, FieldType::K256Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::K256Coordinate as u8 },
                    false,
                >),
                (false, FieldType::P256Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::P256Coordinate as u8 },
                    false,
                >),
                (false, FieldType::BN254Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::BN254Coordinate as u8 },
                    false,
                >),
                (false, FieldType::BLS12_381Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::BLS12_381Coordinate as u8 },
                    false,
                >),
                _ => panic!("Unsupported field type"),
            }
        } else if $is_setup {
            Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>)
        } else {
            Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>)
        }
    };
}

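// Unmetered (E1) execution: pre-computes the operand data and returns the handler that
// performs the elliptic-curve addition.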
impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> Executor<F>
    for EcAddNeExecutor<BLOCKS, BLOCK_SIZE>
{
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        std::mem::size_of::<EcAddNePreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let pre_compute: &mut EcAddNePreCompute = data.borrow_mut();
        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;

        dispatch!(execute_e1_impl, pre_compute, is_setup)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let pre_compute: &mut EcAddNePreCompute = data.borrow_mut();
        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;

        dispatch!(execute_e1_handler, pre_compute, is_setup)
    }
}

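// Metered (E2) execution: identical to the E1 path, but each instruction also records one row
// of trace height for this chip via `on_height_change`.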
impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> MeteredExecutor<F>
    for EcAddNeExecutor<BLOCKS, BLOCK_SIZE>
{
    #[inline(always)]
    fn metered_pre_compute_size(&self) -> usize {
        std::mem::size_of::<E2PreCompute<EcAddNePreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<EcAddNePreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;

        let pre_compute_pure = &mut pre_compute.data;
        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
        dispatch!(execute_e2_impl, pre_compute_pure, is_setup)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<EcAddNePreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;

        let pre_compute_pure = &mut pre_compute.data;
        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
        dispatch!(execute_e2_handler, pre_compute_pure, is_setup)
    }
}

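/// Shared execution body for the E1 and E2 handlers: reads the two point pointers from
/// registers, loads both points from memory, verifies the encoded prime for setup
/// instructions, computes the sum (specialized curve arithmetic when `FIELD_TYPE` names a
/// supported field, generic field-expression evaluation otherwise), writes the result to the
/// destination pointer, and advances `pc` and `instret`.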
#[inline(always)]
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const FIELD_TYPE: u8,
    const IS_SETUP: bool,
>(
    pre_compute: &EcAddNePreCompute,
    instret: &mut u64,
    pc: &mut u32,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) -> Result<(), ExecutionError> {
    let rs_vals = pre_compute
        .rs_addrs
        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));

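    // Load the two input points; each point occupies BLOCKS blocks of BLOCK_SIZE bytes.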
    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
        from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
    });

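    // Setup instructions carry the modulus in the first half of the first input; reject the
    // instruction if it does not match the prime the chip was configured with.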
    if IS_SETUP {
        let input_prime = BigUint::from_bytes_le(read_data[0][..BLOCKS / 2].as_flattened());
        if input_prime != pre_compute.expr.prime {
            let err = ExecutionError::Fail {
                pc: *pc,
                msg: "EcAddNe: mismatched prime",
            };
            return Err(err);
        }
    }

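    // Compute the point addition: use the specialized implementation for known fields, and the
    // generic field expression for unknown moduli or setup instructions.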
    let output_data = if FIELD_TYPE == u8::MAX || IS_SETUP {
        let read_data: DynArray<u8> = read_data.into();
        run_field_expression_precomputed::<true>(
            pre_compute.expr,
            pre_compute.flag_idx as usize,
            &read_data.0,
        )
        .into()
    } else {
        ec_add_ne::<FIELD_TYPE, BLOCKS, BLOCK_SIZE>(read_data)
    };

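    // Write the resulting point to the memory location given by the destination register.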
    let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));

    for (i, block) in output_data.into_iter().enumerate() {
        exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
    }

    *pc = pc.wrapping_add(DEFAULT_PC_STEP);
    *instret += 1;

    Ok(())
}

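/// E1 (unmetered) entry point: reinterprets the pre-compute bytes and delegates to
/// `execute_e12_impl`.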
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const FIELD_TYPE: u8,
    const IS_SETUP: bool,
>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _instret_end: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) -> Result<(), ExecutionError> {
    let pre_compute: &EcAddNePreCompute = pre_compute.borrow();
    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, FIELD_TYPE, IS_SETUP>(
        pre_compute,
        instret,
        pc,
        exec_state,
    )
}

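/// E2 (metered) entry point: records one row of trace height for this chip, then delegates to
/// `execute_e12_impl`.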
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const FIELD_TYPE: u8,
    const IS_SETUP: bool,
>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _arg: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) -> Result<(), ExecutionError> {
    let e2_pre_compute: &E2PreCompute<EcAddNePreCompute> = pre_compute.borrow();
    exec_state
        .ctx
        .on_height_change(e2_pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, FIELD_TYPE, IS_SETUP>(
        &e2_pre_compute.data,
        instret,
        pc,
        exec_state,
    )
}