1use std::{
2 array::from_fn,
3 borrow::{Borrow, BorrowMut},
4 mem::size_of,
5};
6
7use num_bigint::BigUint;
8use openvm_circuit::{
9 arch::*,
10 system::memory::{online::GuestMemory, POINTER_MAX_BITS},
11};
12use openvm_circuit_primitives::AlignedBytesBorrow;
13use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
14use openvm_instructions::{
15 instruction::Instruction,
16 program::DEFAULT_PC_STEP,
17 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
18};
19use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
20use openvm_stark_backend::p3_field::PrimeField32;
21
22use super::EcDoubleExecutor;
23use crate::weierstrass_chip::curves::{ec_double, get_curve_type, CurveType};
24
/// Per-instruction data decoded once by `pre_compute_impl` and read back by the
/// EC-double execution handlers (`execute_e12_impl` and its E1/E2 wrappers).
///
/// NOTE(review): `#[repr(C)]` together with `AlignedBytesBorrow` suggests this
/// struct is reinterpreted in place from a pre-compute byte buffer, so field
/// order and field types are presumably load-bearing — confirm before changing.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct EcDoublePreCompute<'a> {
    // Field expression borrowed from the executor (`&self.expr`); evaluated on
    // the generic path (unknown curve or setup instruction).
    expr: &'a FieldExpr,
    // Register address (operand `b`) whose value is the pointer to the input.
    rs_addrs: [u8; 1],
    // Register address (operand `a`) whose value is the output pointer.
    a: u8,
    // Flag index handed to `run_field_expression_precomputed`.
    flag_idx: u8,
}
33
34impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> EcDoubleExecutor<BLOCKS, BLOCK_SIZE> {
35 fn pre_compute_impl<F: PrimeField32>(
36 &'a self,
37 pc: u32,
38 inst: &Instruction<F>,
39 data: &mut EcDoublePreCompute<'a>,
40 ) -> Result<bool, StaticProgramError> {
41 let Instruction {
42 opcode, a, b, d, e, ..
43 } = inst;
44
45 let a = a.as_canonical_u32();
47 let b = b.as_canonical_u32();
48 let d = d.as_canonical_u32();
49 let e = e.as_canonical_u32();
50 if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
51 return Err(StaticProgramError::InvalidInstruction(pc));
52 }
53
54 let local_opcode = opcode.local_opcode_idx(self.offset);
55
56 let needs_setup = self.expr.needs_setup();
58 let mut flag_idx = self.expr.num_flags() as u8;
59 if needs_setup {
60 if let Some(opcode_position) = self
62 .local_opcode_idx
63 .iter()
64 .position(|&idx| idx == local_opcode)
65 {
66 if opcode_position < self.opcode_flag_idx.len() {
68 flag_idx = self.opcode_flag_idx[opcode_position] as u8;
69 }
70 }
71 }
72
73 let rs_addrs = [b as u8];
74 *data = EcDoublePreCompute {
75 expr: &self.expr,
76 rs_addrs,
77 a: a as u8,
78 flag_idx,
79 };
80
81 let local_opcode = opcode.local_opcode_idx(self.offset);
82 let is_setup = local_opcode == Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize;
83
84 Ok(is_setup)
85 }
86}
87
// Selects the monomorphised handler for an EC-double instruction.
//
// The curve is detected from the pre-computed expression's prime and its first
// setup value (coefficient `a`). A recognised curve gets a handler specialised
// on `CurveType as u8`. An unrecognised curve gets the `u8::MAX` sentinel, which
// `execute_e12_impl` routes to the generic field-expression path. `$is_setup`
// picks the `IS_SETUP` const parameter.
//
// Expands to an `Ok(..)` expression. `BLOCKS` and `BLOCK_SIZE` must be in scope
// at the expansion site.
macro_rules! dispatch {
    ($execute_impl:ident, $pre_compute:ident, $is_setup:ident) => {{
        let detected_curve = get_curve_type(
            &$pre_compute.expr.builder.prime,
            &$pre_compute.expr.setup_values[0],
        );
        match (detected_curve, $is_setup) {
            (Some(CurveType::K256), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, true>)
            }
            (Some(CurveType::K256), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, false>)
            }
            (Some(CurveType::P256), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, true>)
            }
            (Some(CurveType::P256), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, false>)
            }
            (Some(CurveType::BN254), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, true>)
            }
            (Some(CurveType::BN254), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, false>)
            }
            (Some(CurveType::BLS12_381), true) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { CurveType::BLS12_381 as u8 },
                true,
            >),
            (Some(CurveType::BLS12_381), false) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { CurveType::BLS12_381 as u8 },
                false,
            >),
            (None, true) => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>),
            (None, false) => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>),
        }
    }};
}
138
139impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> Executor<F>
140 for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
141{
142 #[inline(always)]
143 fn pre_compute_size(&self) -> usize {
144 size_of::<EcDoublePreCompute>()
145 }
146
147 #[cfg(not(feature = "tco"))]
148 fn pre_compute<Ctx>(
149 &self,
150 pc: u32,
151 inst: &Instruction<F>,
152 data: &mut [u8],
153 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
154 where
155 Ctx: ExecutionCtxTrait,
156 {
157 let pre_compute: &mut EcDoublePreCompute = data.borrow_mut();
158 let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
159
160 dispatch!(execute_e1_handler, pre_compute, is_setup)
161 }
162
163 #[cfg(feature = "tco")]
164 fn handler<Ctx>(
165 &self,
166 pc: u32,
167 inst: &Instruction<F>,
168 data: &mut [u8],
169 ) -> Result<Handler<F, Ctx>, StaticProgramError>
170 where
171 Ctx: ExecutionCtxTrait,
172 {
173 let pre_compute: &mut EcDoublePreCompute = data.borrow_mut();
174 let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
175
176 dispatch!(execute_e1_handler, pre_compute, is_setup)
177 }
178}
179
180impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> MeteredExecutor<F>
181 for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
182{
183 #[inline(always)]
184 fn metered_pre_compute_size(&self) -> usize {
185 size_of::<E2PreCompute<EcDoublePreCompute>>()
186 }
187
188 #[cfg(not(feature = "tco"))]
189 fn metered_pre_compute<Ctx>(
190 &self,
191 chip_idx: usize,
192 pc: u32,
193 inst: &Instruction<F>,
194 data: &mut [u8],
195 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
196 where
197 Ctx: MeteredExecutionCtxTrait,
198 {
199 let pre_compute: &mut E2PreCompute<EcDoublePreCompute> = data.borrow_mut();
200 pre_compute.chip_idx = chip_idx as u32;
201 let pre_compute_pure = &mut pre_compute.data;
202 let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
203
204 dispatch!(execute_e2_handler, pre_compute_pure, is_setup)
205 }
206
207 #[cfg(feature = "tco")]
208 fn metered_handler<Ctx>(
209 &self,
210 chip_idx: usize,
211 pc: u32,
212 inst: &Instruction<F>,
213 data: &mut [u8],
214 ) -> Result<Handler<F, Ctx>, StaticProgramError>
215 where
216 Ctx: MeteredExecutionCtxTrait,
217 {
218 let pre_compute: &mut E2PreCompute<EcDoublePreCompute> = data.borrow_mut();
219 pre_compute.chip_idx = chip_idx as u32;
220 let pre_compute_pure = &mut pre_compute.data;
221 let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
222
223 dispatch!(execute_e2_handler, pre_compute_pure, is_setup)
224 }
225}
226
227#[inline(always)]
228unsafe fn execute_e12_impl<
229 F: PrimeField32,
230 CTX: ExecutionCtxTrait,
231 const BLOCKS: usize,
232 const BLOCK_SIZE: usize,
233 const CURVE_TYPE: u8,
234 const IS_SETUP: bool,
235>(
236 pre_compute: &EcDoublePreCompute,
237 instret: &mut u64,
238 pc: &mut u32,
239 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
240) -> Result<(), ExecutionError> {
241 let rs_vals = pre_compute
243 .rs_addrs
244 .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));
245
246 let read_data: [[u8; BLOCK_SIZE]; BLOCKS] = {
248 let address = rs_vals[0];
249 debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
250 from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
251 };
252
253 if IS_SETUP {
254 let input_prime = BigUint::from_bytes_le(read_data[..BLOCKS / 2].as_flattened());
255
256 if input_prime != pre_compute.expr.builder.prime {
257 let err = ExecutionError::Fail {
258 pc: *pc,
259 msg: "EcDouble: mismatched prime",
260 };
261 return Err(err);
262 }
263
264 let input_a = BigUint::from_bytes_le(read_data[BLOCKS / 2..].as_flattened());
266 let coeff_a = &pre_compute.expr.setup_values[0];
267 if input_a != *coeff_a {
268 let err = ExecutionError::Fail {
269 pc: *pc,
270 msg: "EcDouble: mismatched coeff_a",
271 };
272 return Err(err);
273 }
274 }
275
276 let output_data = if CURVE_TYPE == u8::MAX || IS_SETUP {
277 let read_data: DynArray<u8> = read_data.into();
278 run_field_expression_precomputed::<true>(
279 pre_compute.expr,
280 pre_compute.flag_idx as usize,
281 &read_data.0,
282 )
283 .into()
284 } else {
285 ec_double::<CURVE_TYPE, BLOCKS, BLOCK_SIZE>(read_data)
286 };
287
288 let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
289 debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
290
291 for (i, block) in output_data.into_iter().enumerate() {
293 exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
294 }
295
296 *pc = pc.wrapping_add(DEFAULT_PC_STEP);
297 *instret += 1;
298
299 Ok(())
300}
301
302#[create_handler]
303#[inline(always)]
304unsafe fn execute_e1_impl<
305 F: PrimeField32,
306 CTX: ExecutionCtxTrait,
307 const BLOCKS: usize,
308 const BLOCK_SIZE: usize,
309 const CURVE_TYPE: u8,
310 const IS_SETUP: bool,
311>(
312 pre_compute: &[u8],
313 instret: &mut u64,
314 pc: &mut u32,
315 _instret_end: u64,
316 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
317) -> Result<(), ExecutionError> {
318 let pre_compute: &EcDoublePreCompute = pre_compute.borrow();
319 execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, IS_SETUP>(
320 pre_compute,
321 instret,
322 pc,
323 exec_state,
324 )
325}
326
327#[create_handler]
328#[inline(always)]
329unsafe fn execute_e2_impl<
330 F: PrimeField32,
331 CTX: MeteredExecutionCtxTrait,
332 const BLOCKS: usize,
333 const BLOCK_SIZE: usize,
334 const CURVE_TYPE: u8,
335 const IS_SETUP: bool,
336>(
337 pre_compute: &[u8],
338 instret: &mut u64,
339 pc: &mut u32,
340 _arg: u64,
341 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
342) -> Result<(), ExecutionError> {
343 let e2_pre_compute: &E2PreCompute<EcDoublePreCompute> = pre_compute.borrow();
344 exec_state
345 .ctx
346 .on_height_change(e2_pre_compute.chip_idx as usize, 1);
347 execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, IS_SETUP>(
348 &e2_pre_compute.data,
349 instret,
350 pc,
351 exec_state,
352 )
353}