1use std::{
2 array::from_fn,
3 borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_algebra_circuit::fields::{get_field_type, FieldType};
8use openvm_circuit::{
9 arch::*,
10 system::memory::{online::GuestMemory, POINTER_MAX_BITS},
11};
12use openvm_circuit_primitives::AlignedBytesBorrow;
13use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
14use openvm_instructions::{
15 instruction::Instruction,
16 program::DEFAULT_PC_STEP,
17 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
18};
19use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
20use openvm_stark_backend::p3_field::PrimeField32;
21
22use super::EcAddNeExecutor;
23use crate::weierstrass_chip::curves::ec_add_ne;
24
/// Pre-decoded operands for one EC_ADD_NE / SETUP_EC_ADD_NE instruction,
/// cached so the interpreter does not re-decode on every execution.
///
/// NOTE: `#[repr(C)]` + `AlignedBytesBorrow` — field order and layout are
/// significant; do not reorder fields.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct EcAddNePreCompute<'a> {
    // Field expression used by the generic (non-specialized) execution path.
    expr: &'a FieldExpr,
    // Register addresses of the two input-point pointers (operands b and c).
    rs_addrs: [u8; 2],
    // Register address of the destination pointer (operand a).
    a: u8,
    // Flag index into the expression's flag list; defaults to
    // `expr.num_flags()` (i.e. "no flag") when none applies.
    flag_idx: u8,
}
33
34impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> EcAddNeExecutor<BLOCKS, BLOCK_SIZE> {
35 fn pre_compute_impl<F: PrimeField32>(
36 &'a self,
37 pc: u32,
38 inst: &Instruction<F>,
39 data: &mut EcAddNePreCompute<'a>,
40 ) -> Result<bool, StaticProgramError> {
41 let Instruction {
42 opcode,
43 a,
44 b,
45 c,
46 d,
47 e,
48 ..
49 } = inst;
50
51 let a = a.as_canonical_u32();
53 let b = b.as_canonical_u32();
54 let c = c.as_canonical_u32();
55 let d = d.as_canonical_u32();
56 let e = e.as_canonical_u32();
57 if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
58 return Err(StaticProgramError::InvalidInstruction(pc));
59 }
60
61 let local_opcode = opcode.local_opcode_idx(self.offset);
62
63 let needs_setup = self.expr.needs_setup();
65 let mut flag_idx = self.expr.num_flags() as u8;
66 if needs_setup {
67 if let Some(opcode_position) = self
69 .local_opcode_idx
70 .iter()
71 .position(|&idx| idx == local_opcode)
72 {
73 if opcode_position < self.opcode_flag_idx.len() {
75 flag_idx = self.opcode_flag_idx[opcode_position] as u8;
76 }
77 }
78 }
79
80 let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8);
81 *data = EcAddNePreCompute {
82 expr: &self.expr,
83 rs_addrs,
84 a: a as u8,
85 flag_idx,
86 };
87
88 let local_opcode = opcode.local_opcode_idx(self.offset);
89 let is_setup = local_opcode == Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize;
90
91 Ok(is_setup)
92 }
93}
94
/// Selects the concrete monomorphized execute function for one decoded
/// instruction.
///
/// `$pre_compute` must expose `.expr` (a `FieldExpr`); `$is_setup` selects the
/// `IS_SETUP` const generic. When the configured modulus corresponds to a
/// known `FieldType`, the implementation specialized for that field is
/// returned; otherwise the generic path is chosen with `u8::MAX` as the
/// field-type marker. Expansion sites must have `BLOCKS` and `BLOCK_SIZE`
/// const generics in scope.
macro_rules! dispatch {
    ($execute_impl:ident, $pre_compute:ident, $is_setup:ident) => {
        // Detect whether the configured modulus is one of the hard-coded fields.
        if let Some(field_type) = {
            let modulus = &$pre_compute.expr.builder.prime;
            get_field_type(modulus)
        } {
            // One arm per (is_setup, field) pair: the field type and the setup
            // flag become const generics of the chosen function.
            match ($is_setup, field_type) {
                (true, FieldType::K256Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::K256Coordinate as u8 },
                    true,
                >),
                (true, FieldType::P256Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::P256Coordinate as u8 },
                    true,
                >),
                (true, FieldType::BN254Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::BN254Coordinate as u8 },
                    true,
                >),
                (true, FieldType::BLS12_381Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::BLS12_381Coordinate as u8 },
                    true,
                >),
                (false, FieldType::K256Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::K256Coordinate as u8 },
                    false,
                >),
                (false, FieldType::P256Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::P256Coordinate as u8 },
                    false,
                >),
                (false, FieldType::BN254Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::BN254Coordinate as u8 },
                    false,
                >),
                (false, FieldType::BLS12_381Coordinate) => Ok($execute_impl::<
                    _,
                    _,
                    BLOCKS,
                    BLOCK_SIZE,
                    { FieldType::BLS12_381Coordinate as u8 },
                    false,
                >),
                // NOTE(review): any other FieldType variant (e.g. scalar
                // fields, if the enum has them) is unsupported here — confirm
                // callers can never reach this with such a modulus.
                _ => panic!("Unsupported field type"),
            }
        } else if $is_setup {
            // Unknown modulus: fall back to the generic FieldExpr evaluator.
            Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>)
        } else {
            Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>)
        }
    };
}
impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> Executor<F>
    for EcAddNeExecutor<BLOCKS, BLOCK_SIZE>
{
    /// Bytes of pre-compute buffer this executor needs per instruction.
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        std::mem::size_of::<EcAddNePreCompute>()
    }

    /// Decodes `inst` into `data` and returns the matching E1 execute
    /// function, monomorphized by `dispatch!` over field type and setup flag.
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        // Reinterpret the caller-provided byte buffer as the pre-compute struct.
        let pre_compute: &mut EcAddNePreCompute = data.borrow_mut();
        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;

        dispatch!(execute_e1_impl, pre_compute, is_setup)
    }

    /// Same as [`Self::pre_compute`] but returns the tail-call-optimized
    /// handler variant (only built with the "tco" feature).
    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let pre_compute: &mut EcAddNePreCompute = data.borrow_mut();
        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;

        dispatch!(execute_e1_tco_handler, pre_compute, is_setup)
    }
}
214
impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> MeteredExecutor<F>
    for EcAddNeExecutor<BLOCKS, BLOCK_SIZE>
{
    /// Bytes of pre-compute buffer needed per instruction in metered (E2)
    /// mode: the E1 struct plus the `E2PreCompute` wrapper (chip index).
    #[inline(always)]
    fn metered_pre_compute_size(&self) -> usize {
        std::mem::size_of::<E2PreCompute<EcAddNePreCompute>>()
    }

    /// Decodes `inst` and returns the matching E2 (metered) execute function.
    /// `chip_idx` is stored so execution can report trace-height changes to
    /// the right chip.
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<EcAddNePreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;

        // Inner (E1-shaped) pre-compute lives inside the E2 wrapper.
        let pre_compute_pure = &mut pre_compute.data;
        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
        dispatch!(execute_e2_impl, pre_compute_pure, is_setup)
    }

    /// Same as [`Self::metered_pre_compute`] but returns the
    /// tail-call-optimized handler variant (only built with the "tco" feature).
    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<EcAddNePreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;

        let pre_compute_pure = &mut pre_compute.data;
        let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
        dispatch!(execute_e2_tco_handler, pre_compute_pure, is_setup)
    }
}
260
/// Shared execution body for both the plain (E1) and metered (E2) paths.
///
/// Reads the two input points from guest memory, computes the
/// point addition (or validates the modulus when `IS_SETUP`), writes the
/// result to the address held in register `a`, then advances pc/instret.
///
/// # Safety
/// `pre_compute` must have been produced by `pre_compute_impl` for the
/// instruction at `vm_state.pc`. Guest-memory address ranges are only
/// checked via `debug_assert!` here — callers must uphold validity in
/// release builds (TODO confirm that invariant at the call sites).
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const FIELD_TYPE: u8,
    const IS_SETUP: bool,
>(
    pre_compute: &EcAddNePreCompute,
    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // Each rs register holds a little-endian guest-memory address.
    let rs_vals = pre_compute
        .rs_addrs
        .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32)));

    // Read BLOCKS consecutive blocks per input point (the point's coordinate
    // limbs; presumably x followed by y — confirm against the chip layout).
    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
        from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
    });

    if IS_SETUP {
        // Setup instructions carry the modulus in the first half of the first
        // operand's blocks; abort execution if it mismatches the configured prime.
        let input_prime = BigUint::from_bytes_le(read_data[0][..BLOCKS / 2].as_flattened());
        if input_prime != pre_compute.expr.prime {
            vm_state.exit_code = Err(ExecutionError::Fail {
                pc: vm_state.pc,
                msg: "EcAddNe: mismatched prime",
            });
            return;
        }
    }

    // Generic-path (FIELD_TYPE == u8::MAX) and setup instructions go through
    // the FieldExpr evaluator; known fields use the specialized native add.
    let output_data = if FIELD_TYPE == u8::MAX || IS_SETUP {
        let read_data: DynArray<u8> = read_data.into();
        run_field_expression_precomputed::<true>(
            pre_compute.expr,
            pre_compute.flag_idx as usize,
            &read_data.0,
        )
        .into()
    } else {
        ec_add_ne::<FIELD_TYPE, BLOCKS, BLOCK_SIZE>(read_data)
    };

    // Destination pointer comes from register `a`; write result block-by-block.
    let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));

    for (i, block) in output_data.into_iter().enumerate() {
        vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
    }

    // Advance to the next instruction and count it as retired.
    vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
    vm_state.instret += 1;
}
317
318#[create_tco_handler]
319unsafe fn execute_e1_impl<
320 F: PrimeField32,
321 CTX: ExecutionCtxTrait,
322 const BLOCKS: usize,
323 const BLOCK_SIZE: usize,
324 const FIELD_TYPE: u8,
325 const IS_SETUP: bool,
326>(
327 pre_compute: &[u8],
328 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
329) {
330 let pre_compute: &EcAddNePreCompute = pre_compute.borrow();
331 execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, FIELD_TYPE, IS_SETUP>(pre_compute, vm_state);
332}
333
334#[create_tco_handler]
335unsafe fn execute_e2_impl<
336 F: PrimeField32,
337 CTX: MeteredExecutionCtxTrait,
338 const BLOCKS: usize,
339 const BLOCK_SIZE: usize,
340 const FIELD_TYPE: u8,
341 const IS_SETUP: bool,
342>(
343 pre_compute: &[u8],
344 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
345) {
346 let e2_pre_compute: &E2PreCompute<EcAddNePreCompute> = pre_compute.borrow();
347 vm_state
348 .ctx
349 .on_height_change(e2_pre_compute.chip_idx as usize, 1);
350 execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, FIELD_TYPE, IS_SETUP>(
351 &e2_pre_compute.data,
352 vm_state,
353 );
354}