1use std::{
2 array::from_fn,
3 borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_circuit::{
8 arch::*,
9 system::memory::{online::GuestMemory, POINTER_MAX_BITS},
10};
11use openvm_circuit_primitives::AlignedBytesBorrow;
12use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
13use openvm_instructions::{
14 instruction::Instruction,
15 program::DEFAULT_PC_STEP,
16 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
17};
18use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
19use openvm_stark_backend::p3_field::PrimeField32;
20
21use super::EcDoubleExecutor;
22use crate::weierstrass_chip::curves::{ec_double, get_curve_type, CurveType};
23
/// Per-instruction data for EC point doubling.
///
/// Filled in by `EcDoubleExecutor::pre_compute_impl` and read back (from the raw pre-compute
/// byte buffer) by the execute functions.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct EcDoublePreCompute<'a> {
    /// Field expression evaluated for the doubling; its `builder.prime` and `setup_values[0]`
    /// are also what SETUP instructions are validated against.
    expr: &'a FieldExpr,
    /// Register pointer (instruction operand `b`, truncated to `u8`) whose value is the base
    /// address of the input operand in memory.
    rs_addrs: [u8; 1],
    /// Register pointer (instruction operand `a`, truncated to `u8`) whose value is the base
    /// address the result is written to.
    a: u8,
    /// Flag index passed to `run_field_expression_precomputed`.
    flag_idx: u8,
}
32
33impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> EcDoubleExecutor<BLOCKS, BLOCK_SIZE> {
34 fn pre_compute_impl<F: PrimeField32>(
35 &'a self,
36 pc: u32,
37 inst: &Instruction<F>,
38 data: &mut EcDoublePreCompute<'a>,
39 ) -> Result<bool, StaticProgramError> {
40 let Instruction {
41 opcode, a, b, d, e, ..
42 } = inst;
43
44 let a = a.as_canonical_u32();
46 let b = b.as_canonical_u32();
47 let d = d.as_canonical_u32();
48 let e = e.as_canonical_u32();
49 if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
50 return Err(StaticProgramError::InvalidInstruction(pc));
51 }
52
53 let local_opcode = opcode.local_opcode_idx(self.offset);
54
55 let needs_setup = self.expr.needs_setup();
57 let mut flag_idx = self.expr.num_flags() as u8;
58 if needs_setup {
59 if let Some(opcode_position) = self
61 .local_opcode_idx
62 .iter()
63 .position(|&idx| idx == local_opcode)
64 {
65 if opcode_position < self.opcode_flag_idx.len() {
67 flag_idx = self.opcode_flag_idx[opcode_position] as u8;
68 }
69 }
70 }
71
72 let rs_addrs = [b as u8];
73 *data = EcDoublePreCompute {
74 expr: &self.expr,
75 rs_addrs,
76 a: a as u8,
77 flag_idx,
78 };
79
80 let local_opcode = opcode.local_opcode_idx(self.offset);
81 let is_setup = local_opcode == Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize;
82
83 Ok(is_setup)
84 }
85}
86
/// Selects the monomorphized execute function for a filled-in `EcDoublePreCompute`.
///
/// * `$execute_impl` — generic execute function to instantiate.
/// * `$pre_compute` — identifier bound to the (already filled) `EcDoublePreCompute`.
/// * `$is_setup` — `bool`; `true` for the `SETUP_EC_DOUBLE` opcode.
///
/// The curve is identified by `get_curve_type` from the expression's prime and its first setup
/// value. A recognized curve is baked in as `CURVE_TYPE = curve as u8`; an unrecognized one is
/// instantiated with `CURVE_TYPE = u8::MAX`. The expansion evaluates to `Ok(<fn>)`.
macro_rules! dispatch {
    ($execute_impl:ident, $pre_compute:ident, $is_setup:ident) => {{
        let detected_curve = get_curve_type(
            &$pre_compute.expr.builder.prime,
            &$pre_compute.expr.setup_values[0],
        );
        match (detected_curve, $is_setup) {
            (Some(CurveType::K256), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, true>)
            }
            (Some(CurveType::K256), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, false>)
            }
            (Some(CurveType::P256), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, true>)
            }
            (Some(CurveType::P256), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, false>)
            }
            (Some(CurveType::BN254), true) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, true>)
            }
            (Some(CurveType::BN254), false) => {
                Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, false>)
            }
            (Some(CurveType::BLS12_381), true) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { CurveType::BLS12_381 as u8 },
                true,
            >),
            (Some(CurveType::BLS12_381), false) => Ok($execute_impl::<
                _,
                _,
                BLOCKS,
                BLOCK_SIZE,
                { CurveType::BLS12_381 as u8 },
                false,
            >),
            // Unrecognized curve: fall back to the generic field-expression path.
            (None, true) => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>),
            (None, false) => Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>),
        }
    }};
}
137
138impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> Executor<F>
139 for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
140{
141 #[inline(always)]
142 fn pre_compute_size(&self) -> usize {
143 std::mem::size_of::<EcDoublePreCompute>()
144 }
145
146 fn pre_compute<Ctx>(
147 &self,
148 pc: u32,
149 inst: &Instruction<F>,
150 data: &mut [u8],
151 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
152 where
153 Ctx: ExecutionCtxTrait,
154 {
155 let pre_compute: &mut EcDoublePreCompute = data.borrow_mut();
156 let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
157
158 dispatch!(execute_e1_impl, pre_compute, is_setup)
159 }
160
161 #[cfg(feature = "tco")]
162 fn handler<Ctx>(
163 &self,
164 pc: u32,
165 inst: &Instruction<F>,
166 data: &mut [u8],
167 ) -> Result<Handler<F, Ctx>, StaticProgramError>
168 where
169 Ctx: ExecutionCtxTrait,
170 {
171 let pre_compute: &mut EcDoublePreCompute = data.borrow_mut();
172 let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
173
174 dispatch!(execute_e1_tco_handler, pre_compute, is_setup)
175 }
176}
177
178impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize> MeteredExecutor<F>
179 for EcDoubleExecutor<BLOCKS, BLOCK_SIZE>
180{
181 #[inline(always)]
182 fn metered_pre_compute_size(&self) -> usize {
183 std::mem::size_of::<E2PreCompute<EcDoublePreCompute>>()
184 }
185
186 fn metered_pre_compute<Ctx>(
187 &self,
188 chip_idx: usize,
189 pc: u32,
190 inst: &Instruction<F>,
191 data: &mut [u8],
192 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
193 where
194 Ctx: MeteredExecutionCtxTrait,
195 {
196 let pre_compute: &mut E2PreCompute<EcDoublePreCompute> = data.borrow_mut();
197 pre_compute.chip_idx = chip_idx as u32;
198 let pre_compute_pure = &mut pre_compute.data;
199 let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
200
201 dispatch!(execute_e2_impl, pre_compute_pure, is_setup)
202 }
203
204 #[cfg(feature = "tco")]
205 fn metered_handler<Ctx>(
206 &self,
207 chip_idx: usize,
208 pc: u32,
209 inst: &Instruction<F>,
210 data: &mut [u8],
211 ) -> Result<Handler<F, Ctx>, StaticProgramError>
212 where
213 Ctx: MeteredExecutionCtxTrait,
214 {
215 let pre_compute: &mut E2PreCompute<EcDoublePreCompute> = data.borrow_mut();
216 pre_compute.chip_idx = chip_idx as u32;
217 let pre_compute_pure = &mut pre_compute.data;
218 let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
219
220 dispatch!(execute_e2_tco_handler, pre_compute_pure, is_setup)
221 }
222}
223
224unsafe fn execute_e12_impl<
225 F: PrimeField32,
226 CTX: ExecutionCtxTrait,
227 const BLOCKS: usize,
228 const BLOCK_SIZE: usize,
229 const CURVE_TYPE: u8,
230 const IS_SETUP: bool,
231>(
232 pre_compute: &EcDoublePreCompute,
233 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
234) {
235 let rs_vals = pre_compute
237 .rs_addrs
238 .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32)));
239
240 let read_data: [[u8; BLOCK_SIZE]; BLOCKS] = {
242 let address = rs_vals[0];
243 debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
244 from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
245 };
246
247 if IS_SETUP {
248 let input_prime = BigUint::from_bytes_le(read_data[..BLOCKS / 2].as_flattened());
249
250 if input_prime != pre_compute.expr.builder.prime {
251 vm_state.exit_code = Err(ExecutionError::Fail {
252 pc: vm_state.pc,
253 msg: "EcDouble: mismatched prime",
254 });
255 return;
256 }
257
258 let input_a = BigUint::from_bytes_le(read_data[BLOCKS / 2..].as_flattened());
260 let coeff_a = &pre_compute.expr.setup_values[0];
261 if input_a != *coeff_a {
262 vm_state.exit_code = Err(ExecutionError::Fail {
263 pc: vm_state.pc,
264 msg: "EcDouble: mismatched coeff_a",
265 });
266 return;
267 }
268 }
269
270 let output_data = if CURVE_TYPE == u8::MAX || IS_SETUP {
271 let read_data: DynArray<u8> = read_data.into();
272 run_field_expression_precomputed::<true>(
273 pre_compute.expr,
274 pre_compute.flag_idx as usize,
275 &read_data.0,
276 )
277 .into()
278 } else {
279 ec_double::<CURVE_TYPE, BLOCKS, BLOCK_SIZE>(read_data)
280 };
281
282 let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
283 debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
284
285 for (i, block) in output_data.into_iter().enumerate() {
287 vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
288 }
289
290 vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
291 vm_state.instret += 1;
292}
293
294#[create_tco_handler]
295unsafe fn execute_e1_impl<
296 F: PrimeField32,
297 CTX: ExecutionCtxTrait,
298 const BLOCKS: usize,
299 const BLOCK_SIZE: usize,
300 const CURVE_TYPE: u8,
301 const IS_SETUP: bool,
302>(
303 pre_compute: &[u8],
304 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
305) {
306 let pre_compute: &EcDoublePreCompute = pre_compute.borrow();
307 execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, IS_SETUP>(pre_compute, vm_state);
308}
309
310#[create_tco_handler]
311unsafe fn execute_e2_impl<
312 F: PrimeField32,
313 CTX: MeteredExecutionCtxTrait,
314 const BLOCKS: usize,
315 const BLOCK_SIZE: usize,
316 const CURVE_TYPE: u8,
317 const IS_SETUP: bool,
318>(
319 pre_compute: &[u8],
320 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
321) {
322 let e2_pre_compute: &E2PreCompute<EcDoublePreCompute> = pre_compute.borrow();
323 vm_state
324 .ctx
325 .on_height_change(e2_pre_compute.chip_idx as usize, 1);
326 execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, IS_SETUP>(
327 &e2_pre_compute.data,
328 vm_state,
329 );
330}