1use std::{
2 array::from_fn,
3 borrow::{Borrow, BorrowMut},
4 mem::size_of,
5};
6
7use num_bigint::BigUint;
8use openvm_circuit::{
9 arch::*,
10 system::memory::{online::GuestMemory, POINTER_MAX_BITS},
11};
12use openvm_circuit_primitives::AlignedBytesBorrow;
13use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
14use openvm_instructions::{
15 instruction::Instruction,
16 program::DEFAULT_PC_STEP,
17 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
18};
19use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
20use openvm_stark_backend::p3_field::PrimeField32;
21
22use super::EcDoubleExecutor;
23use crate::weierstrass_chip::curves::{ec_double, get_curve_type, CurveType};
24
/// Per-instruction data for EC double, decoded once by `pre_compute_impl` and then
/// reinterpreted from the interpreter's raw byte buffer by the execute handlers
/// (hence `AlignedBytesBorrow` and the fixed `#[repr(C)]` layout).
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct EcDoublePreCompute<'a> {
    /// Field expression for the doubling. It is evaluated on the generic path and also
    /// supplies the prime and `setup_values[0]` (the `a` coefficient) used for curve
    /// detection and setup validation.
    expr: &'a FieldExpr,
    /// Register address (instruction operand `b`) holding the guest pointer to the input point.
    rs_addrs: [u8; 1],
    /// Register address (instruction operand `a`) holding the guest pointer the result is written to.
    a: u8,
    /// Flag index forwarded to `run_field_expression_precomputed`.
    flag_idx: u8,
}
33
34impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> EcDoubleExecutor<BLOCKS, BLOCK_SIZE> {
35 fn pre_compute_impl<F: PrimeField32>(
36 &'a self,
37 pc: u32,
38 inst: &Instruction<F>,
39 data: &mut EcDoublePreCompute<'a>,
40 ) -> Result<bool, StaticProgramError> {
41 let Instruction {
42 opcode, a, b, d, e, ..
43 } = inst;
44
45 let a = a.as_canonical_u32();
47 let b = b.as_canonical_u32();
48 let d = d.as_canonical_u32();
49 let e = e.as_canonical_u32();
50 if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
51 return Err(StaticProgramError::InvalidInstruction(pc));
52 }
53
54 let local_opcode = opcode.local_opcode_idx(self.offset);
55
56 let needs_setup = self.expr.needs_setup();
58 let mut flag_idx = self.expr.num_flags() as u8;
59 if needs_setup {
60 if let Some(opcode_position) = self
62 .local_opcode_idx
63 .iter()
64 .position(|&idx| idx == local_opcode)
65 {
66 if opcode_position < self.opcode_flag_idx.len() {
68 flag_idx = self.opcode_flag_idx[opcode_position] as u8;
69 }
70 }
71 }
72
73 let rs_addrs = [b as u8];
74 *data = EcDoublePreCompute {
75 expr: &self.expr,
76 rs_addrs,
77 a: a as u8,
78 flag_idx,
79 };
80
81 let local_opcode = opcode.local_opcode_idx(self.offset);
82 let is_setup = local_opcode == Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize;
83
84 Ok(is_setup)
85 }
86}
87
// Picks the monomorphized handler for a decoded EC double instruction.
//
// `$execute_impl` names the handler to instantiate. `$pre_compute` is the decoded
// `EcDoublePreCompute`. `$is_setup` says whether the instruction is the setup variant.
// If `get_curve_type` recognizes the modulus and `a` coefficient, the curve is baked in
// through the `CURVE_TYPE` const parameter. Otherwise `u8::MAX` is used, which
// `execute_e12_impl` routes through the generic field-expression evaluator.
//
// The expansion is a `Result` expression whose handler generics (`_, _`) are inferred
// from the enclosing function's return type, so it must be used in tail position.
macro_rules! dispatch {
    ($execute_impl:ident, $pre_compute:ident, $is_setup:ident) => {{
        let detected_curve = get_curve_type(
            &$pre_compute.expr.builder.prime,
            &$pre_compute.expr.setup_values[0],
        );
        match (detected_curve, $is_setup) {
            (Some(CurveType::K256), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, true>)
            }
            (Some(CurveType::K256), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, false>)
            }
            (Some(CurveType::P256), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, true>)
            }
            (Some(CurveType::P256), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, false>)
            }
            (Some(CurveType::BN254), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, true>)
            }
            (Some(CurveType::BN254), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, false>)
            }
            (Some(CurveType::BLS12_381), true) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { CurveType::BLS12_381 as u8 },
                true,
            >),
            (Some(CurveType::BLS12_381), false) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { CurveType::BLS12_381 as u8 },
                false,
            >),
            // Unrecognized curve: fall back to the generic evaluator.
            (None, true) => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>),
            (None, false) => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>),
        }
    }};
}
138
139impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> InterpreterExecutor<F>
140 for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
141{
142 #[inline(always)]
143 fn pre_compute_size(&self) -> usize {
144 size_of::<EcDoublePreCompute>()
145 }
146
147 #[cfg(not(feature = "tco"))]
148 fn pre_compute<Ctx>(
149 &self,
150 pc: u32,
151 inst: &Instruction<F>,
152 data: &mut [u8],
153 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
154 where
155 Ctx: ExecutionCtxTrait,
156 {
157 let pre_compute: &mut EcDoublePreCompute = data.borrow_mut();
158 let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
159
160 dispatch!(execute_e1_handler, pre_compute, is_setup)
161 }
162
163 #[cfg(feature = "tco")]
164 fn handler<Ctx>(
165 &self,
166 pc: u32,
167 inst: &Instruction<F>,
168 data: &mut [u8],
169 ) -> Result<Handler<F, Ctx>, StaticProgramError>
170 where
171 Ctx: ExecutionCtxTrait,
172 {
173 let pre_compute: &mut EcDoublePreCompute = data.borrow_mut();
174 let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
175
176 dispatch!(execute_e1_handler, pre_compute, is_setup)
177 }
178}
179
// AOT executor support for EC double. The impl body is empty, so every `AotExecutor`
// method uses the trait's default implementation. Those defaults are not visible here;
// confirm what they do, e.g. whether they fall back to the interpreter.
#[cfg(feature = "aot")]
impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> AotExecutor<F>
    for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
{
}
185
186impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> InterpreterMeteredExecutor<F>
187 for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
188{
189 #[inline(always)]
190 fn metered_pre_compute_size(&self) -> usize {
191 size_of::<E2PreCompute<EcDoublePreCompute>>()
192 }
193
194 #[cfg(not(feature = "tco"))]
195 fn metered_pre_compute<Ctx>(
196 &self,
197 chip_idx: usize,
198 pc: u32,
199 inst: &Instruction<F>,
200 data: &mut [u8],
201 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
202 where
203 Ctx: MeteredExecutionCtxTrait,
204 {
205 let pre_compute: &mut E2PreCompute<EcDoublePreCompute> = data.borrow_mut();
206 pre_compute.chip_idx = chip_idx as u32;
207 let pre_compute_pure = &mut pre_compute.data;
208 let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
209
210 dispatch!(execute_e2_handler, pre_compute_pure, is_setup)
211 }
212
213 #[cfg(feature = "tco")]
214 fn metered_handler<Ctx>(
215 &self,
216 chip_idx: usize,
217 pc: u32,
218 inst: &Instruction<F>,
219 data: &mut [u8],
220 ) -> Result<Handler<F, Ctx>, StaticProgramError>
221 where
222 Ctx: MeteredExecutionCtxTrait,
223 {
224 let pre_compute: &mut E2PreCompute<EcDoublePreCompute> = data.borrow_mut();
225 pre_compute.chip_idx = chip_idx as u32;
226 let pre_compute_pure = &mut pre_compute.data;
227 let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
228
229 dispatch!(execute_e2_handler, pre_compute_pure, is_setup)
230 }
231}
232
// Metered AOT executor support for EC double. The impl body is empty, so every
// `AotMeteredExecutor` method uses the trait's default implementation. Those defaults
// are not visible here; confirm their behavior.
#[cfg(feature = "aot")]
impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> AotMeteredExecutor<F>
    for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
{
}
238
239#[inline(always)]
240unsafe fn execute_e12_impl<
241 F: PrimeField32,
242 CTX: ExecutionCtxTrait,
243 const BLOCKS: usize,
244 const BLOCK_SIZE: usize,
245 const CURVE_TYPE: u8,
246 const IS_SETUP: bool,
247>(
248 pre_compute: &EcDoublePreCompute,
249 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
250) -> Result<(), ExecutionError> {
251 let pc = exec_state.pc();
252 let rs_vals = pre_compute
254 .rs_addrs
255 .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));
256
257 let read_data: [[u8; BLOCK_SIZE]; BLOCKS] = {
259 let address = rs_vals[0];
260 debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
261 from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
262 };
263
264 if IS_SETUP {
265 let input_prime = BigUint::from_bytes_le(read_data[..BLOCKS / 2].as_flattened());
266
267 if input_prime != pre_compute.expr.builder.prime {
268 let err = ExecutionError::Fail {
269 pc,
270 msg: "EcDouble: mismatched prime",
271 };
272 return Err(err);
273 }
274
275 let input_a = BigUint::from_bytes_le(read_data[BLOCKS / 2..].as_flattened());
277 let coeff_a = &pre_compute.expr.setup_values[0];
278 if input_a != *coeff_a {
279 let err = ExecutionError::Fail {
280 pc,
281 msg: "EcDouble: mismatched coeff_a",
282 };
283 return Err(err);
284 }
285 }
286
287 let output_data = if CURVE_TYPE == u8::MAX || IS_SETUP {
288 let read_data: DynArray<u8> = read_data.into();
289 run_field_expression_precomputed::<true>(
290 pre_compute.expr,
291 pre_compute.flag_idx as usize,
292 &read_data.0,
293 )
294 .into()
295 } else {
296 ec_double::<CURVE_TYPE, BLOCKS, BLOCK_SIZE>(read_data)
297 };
298
299 let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
300 debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
301
302 for (i, block) in output_data.into_iter().enumerate() {
304 exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
305 }
306
307 exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
308
309 Ok(())
310}
311
312#[create_handler]
313#[inline(always)]
314unsafe fn execute_e1_impl<
315 F: PrimeField32,
316 CTX: ExecutionCtxTrait,
317 const BLOCKS: usize,
318 const BLOCK_SIZE: usize,
319 const CURVE_TYPE: u8,
320 const IS_SETUP: bool,
321>(
322 pre_compute: *const u8,
323 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
324) -> Result<(), ExecutionError> {
325 let pre_compute: &EcDoublePreCompute =
326 std::slice::from_raw_parts(pre_compute, size_of::<EcDoublePreCompute>()).borrow();
327 execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, IS_SETUP>(pre_compute, exec_state)
328}
329
330#[create_handler]
331#[inline(always)]
332unsafe fn execute_e2_impl<
333 F: PrimeField32,
334 CTX: MeteredExecutionCtxTrait,
335 const BLOCKS: usize,
336 const BLOCK_SIZE: usize,
337 const CURVE_TYPE: u8,
338 const IS_SETUP: bool,
339>(
340 pre_compute: *const u8,
341 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
342) -> Result<(), ExecutionError> {
343 let e2_pre_compute: &E2PreCompute<EcDoublePreCompute> =
344 std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<EcDoublePreCompute>>())
345 .borrow();
346 exec_state
347 .ctx
348 .on_height_change(e2_pre_compute.chip_idx as usize, 1);
349 execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, IS_SETUP>(
350 &e2_pre_compute.data,
351 exec_state,
352 )
353}