1use std::{
2 array::from_fn,
3 borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_algebra_transpiler::{Fp2Opcode, Rv32ModularArithmeticOpcode};
8use openvm_circuit::{
9 arch::*,
10 system::memory::{online::GuestMemory, POINTER_MAX_BITS},
11};
12use openvm_circuit_primitives::AlignedBytesBorrow;
13use openvm_instructions::{
14 instruction::Instruction,
15 program::DEFAULT_PC_STEP,
16 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
17};
18use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
19use openvm_stark_backend::p3_field::PrimeField32;
20
21use super::FieldExprVecHeapExecutor;
22use crate::fields::{
23 field_operation, fp2_operation, get_field_type, get_fp2_field_type, FieldType, Operation,
24};
25
/// Expands to `Ok(handler)` where `handler` is `$handler` monomorphized for the
/// runtime `(field, operation)` pair, with `IS_FP2 = false`.
///
/// Each `(Variant, Op)` entry in the list becomes one match arm that bakes
/// `FieldType::Variant as u8` and `Operation::Op as u8` into the handler's const
/// generics. There is no fallback arm, so the list must cover every combination
/// of the matched values.
macro_rules! generate_field_dispatch {
    (
        $field:expr,
        $operation:expr,
        $num_blocks:expr,
        $block_bytes:expr,
        $handler:ident,
        [$(($field_variant:ident, $op_variant:ident)),* $(,)?]
    ) => {
        match ($field, $operation) {
            $(
                (FieldType::$field_variant, Operation::$op_variant) => Ok($handler::<
                    _,
                    _,
                    $num_blocks,
                    $block_bytes,
                    false,
                    { FieldType::$field_variant as u8 },
                    { Operation::$op_variant as u8 },
                >),
            )*
        }
    };
}
50
/// Expands to `Ok(handler)` where `handler` is `$handler` monomorphized for the
/// runtime `(field, operation)` pair, with `IS_FP2 = true`.
///
/// Only the listed `(Variant, Op)` pairs get a specialized arm. Any other pair
/// panics with "Unsupported fp2 field".
macro_rules! generate_fp2_dispatch {
    (
        $field:expr,
        $operation:expr,
        $num_blocks:expr,
        $block_bytes:expr,
        $handler:ident,
        [$(($field_variant:ident, $op_variant:ident)),* $(,)?]
    ) => {
        match ($field, $operation) {
            $(
                (FieldType::$field_variant, Operation::$op_variant) => Ok($handler::<
                    _,
                    _,
                    $num_blocks,
                    $block_bytes,
                    true,
                    { FieldType::$field_variant as u8 },
                    { Operation::$op_variant as u8 },
                >),
            )*
            _ => panic!("Unsupported fp2 field")
        }
    };
}
76
/// Chooses the handler for a decoded instruction and expands to
/// `Result<handler, _>`.
///
/// - `$maybe_op == None` (a `SETUP_*` opcode): the `$setup` handler.
/// - `Some(op)` with a modulus recognised by `get_fp2_field_type` (when `IS_FP2`) or
///   `get_field_type` (otherwise): `$specialized`, monomorphized for that field and op.
/// - `Some(op)` with any other modulus: the `$generic` handler.
///
/// The call site must have `IS_FP2`, `BLOCKS` and `BLOCK_SIZE` in scope.
/// `$precompute` must expose `.expr.prime`.
macro_rules! dispatch {
    (
        $specialized:ident,
        $generic:ident,
        $setup:ident,
        $precompute:ident,
        $maybe_op:ident
    ) => {
        match $maybe_op {
            None => Ok($setup::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>),
            Some(op) => {
                let modulus = &$precompute.expr.prime;
                if IS_FP2 {
                    match get_fp2_field_type(modulus) {
                        Some(field_type) => generate_fp2_dispatch!(
                            field_type,
                            op,
                            BLOCKS,
                            BLOCK_SIZE,
                            $specialized,
                            [
                                (BN254Coordinate, Add),
                                (BN254Coordinate, Sub),
                                (BN254Coordinate, Mul),
                                (BN254Coordinate, Div),
                                (BLS12_381Coordinate, Add),
                                (BLS12_381Coordinate, Sub),
                                (BLS12_381Coordinate, Mul),
                                (BLS12_381Coordinate, Div),
                            ]
                        ),
                        None => Ok($generic::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>),
                    }
                } else {
                    match get_field_type(modulus) {
                        Some(field_type) => generate_field_dispatch!(
                            field_type,
                            op,
                            BLOCKS,
                            BLOCK_SIZE,
                            $specialized,
                            [
                                (K256Coordinate, Add),
                                (K256Coordinate, Sub),
                                (K256Coordinate, Mul),
                                (K256Coordinate, Div),
                                (K256Scalar, Add),
                                (K256Scalar, Sub),
                                (K256Scalar, Mul),
                                (K256Scalar, Div),
                                (P256Coordinate, Add),
                                (P256Coordinate, Sub),
                                (P256Coordinate, Mul),
                                (P256Coordinate, Div),
                                (P256Scalar, Add),
                                (P256Scalar, Sub),
                                (P256Scalar, Mul),
                                (P256Scalar, Div),
                                (BN254Coordinate, Add),
                                (BN254Coordinate, Sub),
                                (BN254Coordinate, Mul),
                                (BN254Coordinate, Div),
                                (BN254Scalar, Add),
                                (BN254Scalar, Sub),
                                (BN254Scalar, Mul),
                                (BN254Scalar, Div),
                                (BLS12_381Coordinate, Add),
                                (BLS12_381Coordinate, Sub),
                                (BLS12_381Coordinate, Mul),
                                (BLS12_381Coordinate, Div),
                                (BLS12_381Scalar, Add),
                                (BLS12_381Scalar, Sub),
                                (BLS12_381Scalar, Mul),
                                (BLS12_381Scalar, Div),
                            ]
                        ),
                        None => Ok($generic::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>),
                    }
                }
            }
        }
    };
}
153
/// Instruction data decoded once by `pre_compute_impl` and read by the
/// `execute_*` handlers in this file.
///
/// The executors obtain this struct by borrowing it straight out of a raw byte
/// buffer (`data.borrow_mut()` / `pre_compute.borrow()`), via `AlignedBytesBorrow`.
/// NOTE(review): `#[repr(C)]` presumably keeps the layout fixed for that
/// reinterpretation. Confirm against `AlignedBytesBorrow`'s requirements before
/// reordering fields.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct FieldExpressionPreCompute<'a> {
    /// The executor's field expression (`&self.0.expr`). It supplies the modulus
    /// (`expr.prime`) and is evaluated by the generic and setup handlers.
    expr: &'a FieldExpr,
    /// Register pointers taken from operands `b` and `c` (truncated to `u8`).
    /// Each register holds the little-endian `u32` memory address of one input.
    rs_addrs: [u8; 2],
    /// Register pointer taken from operand `a` (truncated to `u8`). The register
    /// holds the memory address the result is written to.
    a: u8,
    /// Flag index passed to `run_field_expression_precomputed`.
    flag_idx: u8,
}
162
163impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool>
164 FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
165{
166 fn pre_compute_impl<F: PrimeField32>(
167 &'a self,
168 pc: u32,
169 inst: &Instruction<F>,
170 data: &mut FieldExpressionPreCompute<'a>,
171 ) -> Result<Option<Operation>, StaticProgramError> {
172 let Instruction {
173 opcode,
174 a,
175 b,
176 c,
177 d,
178 e,
179 ..
180 } = inst;
181
182 let a = a.as_canonical_u32();
183 let b = b.as_canonical_u32();
184 let c = c.as_canonical_u32();
185 let d = d.as_canonical_u32();
186 let e = e.as_canonical_u32();
187 if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
188 return Err(StaticProgramError::InvalidInstruction(pc));
189 }
190
191 let local_opcode = opcode.local_opcode_idx(self.0.offset);
192
193 let needs_setup = self.0.expr.needs_setup();
194 let mut flag_idx = self.0.expr.num_flags() as u8;
195 if needs_setup {
196 if let Some(opcode_position) = self
197 .0
198 .local_opcode_idx
199 .iter()
200 .position(|&idx| idx == local_opcode)
201 {
202 if opcode_position < self.0.opcode_flag_idx.len() {
203 flag_idx = self.0.opcode_flag_idx[opcode_position] as u8;
204 }
205 }
206 }
207
208 let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8);
209 *data = FieldExpressionPreCompute {
210 a: a as u8,
211 rs_addrs,
212 expr: &self.0.expr,
213 flag_idx,
214 };
215
216 if IS_FP2 {
217 let is_setup = local_opcode == Fp2Opcode::SETUP_ADDSUB as usize
218 || local_opcode == Fp2Opcode::SETUP_MULDIV as usize;
219
220 let op = if is_setup {
221 None
222 } else {
223 match local_opcode {
224 x if x == Fp2Opcode::ADD as usize => Some(Operation::Add),
225 x if x == Fp2Opcode::SUB as usize => Some(Operation::Sub),
226 x if x == Fp2Opcode::MUL as usize => Some(Operation::Mul),
227 x if x == Fp2Opcode::DIV as usize => Some(Operation::Div),
228 _ => unreachable!(),
229 }
230 };
231
232 Ok(op)
233 } else {
234 let is_setup = local_opcode == Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize
235 || local_opcode == Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize;
236
237 let op = if is_setup {
238 None
239 } else {
240 match local_opcode {
241 x if x == Rv32ModularArithmeticOpcode::ADD as usize => Some(Operation::Add),
242 x if x == Rv32ModularArithmeticOpcode::SUB as usize => Some(Operation::Sub),
243 x if x == Rv32ModularArithmeticOpcode::MUL as usize => Some(Operation::Mul),
244 x if x == Rv32ModularArithmeticOpcode::DIV as usize => Some(Operation::Div),
245 _ => unreachable!(),
246 }
247 };
248
249 Ok(op)
250 }
251 }
252}
253
254impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool> Executor<F>
255 for FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
256{
257 #[inline(always)]
258 fn pre_compute_size(&self) -> usize {
259 std::mem::size_of::<FieldExpressionPreCompute>()
260 }
261
262 fn pre_compute<Ctx>(
263 &self,
264 pc: u32,
265 inst: &Instruction<F>,
266 data: &mut [u8],
267 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
268 where
269 Ctx: ExecutionCtxTrait,
270 {
271 let pre_compute: &mut FieldExpressionPreCompute = data.borrow_mut();
272 let op = self.pre_compute_impl(pc, inst, pre_compute)?;
273
274 dispatch!(
275 execute_e1_impl,
276 execute_e1_generic_impl,
277 execute_e1_setup_impl,
278 pre_compute,
279 op
280 )
281 }
282
283 #[cfg(feature = "tco")]
284 fn handler<Ctx>(
285 &self,
286 pc: u32,
287 inst: &Instruction<F>,
288 data: &mut [u8],
289 ) -> Result<Handler<F, Ctx>, StaticProgramError>
290 where
291 Ctx: ExecutionCtxTrait,
292 {
293 let pre_compute: &mut FieldExpressionPreCompute = data.borrow_mut();
294 let op = self.pre_compute_impl(pc, inst, pre_compute)?;
295
296 dispatch!(
297 execute_e1_tco_handler,
298 execute_e1_generic_tco_handler,
299 execute_e1_setup_tco_handler,
300 pre_compute,
301 op
302 )
303 }
304}
305
306impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool>
307 MeteredExecutor<F> for FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
308{
309 #[inline(always)]
310 fn metered_pre_compute_size(&self) -> usize {
311 std::mem::size_of::<E2PreCompute<FieldExpressionPreCompute>>()
312 }
313
314 fn metered_pre_compute<Ctx>(
315 &self,
316 chip_idx: usize,
317 pc: u32,
318 inst: &Instruction<F>,
319 data: &mut [u8],
320 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
321 where
322 Ctx: MeteredExecutionCtxTrait,
323 {
324 let pre_compute: &mut E2PreCompute<FieldExpressionPreCompute> = data.borrow_mut();
325 pre_compute.chip_idx = chip_idx as u32;
326
327 let pre_compute_pure = &mut pre_compute.data;
328 let op = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
329
330 dispatch!(
331 execute_e2_impl,
332 execute_e2_generic_impl,
333 execute_e2_setup_impl,
334 pre_compute_pure,
335 op
336 )
337 }
338
339 #[cfg(feature = "tco")]
340 fn metered_handler<Ctx>(
341 &self,
342 chip_idx: usize,
343 pc: u32,
344 inst: &Instruction<F>,
345 data: &mut [u8],
346 ) -> Result<Handler<F, Ctx>, StaticProgramError>
347 where
348 Ctx: MeteredExecutionCtxTrait,
349 {
350 let pre_compute: &mut E2PreCompute<FieldExpressionPreCompute> = data.borrow_mut();
351 pre_compute.chip_idx = chip_idx as u32;
352
353 let pre_compute_pure = &mut pre_compute.data;
354 let op = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
355
356 dispatch!(
357 execute_e2_tco_handler,
358 execute_e2_generic_tco_handler,
359 execute_e2_setup_tco_handler,
360 pre_compute_pure,
361 op
362 )
363 }
364}
365unsafe fn execute_e12_impl<
366 F: PrimeField32,
367 CTX: ExecutionCtxTrait,
368 const BLOCKS: usize,
369 const BLOCK_SIZE: usize,
370 const IS_FP2: bool,
371 const FIELD_TYPE: u8,
372 const OP: u8,
373>(
374 pre_compute: &FieldExpressionPreCompute,
375 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
376) {
377 let rs_vals = pre_compute
378 .rs_addrs
379 .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32)));
380
381 let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
382 debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
383 from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
384 });
385
386 let output_data = if IS_FP2 {
387 fp2_operation::<FIELD_TYPE, BLOCKS, BLOCK_SIZE, OP>(read_data)
388 } else {
389 field_operation::<FIELD_TYPE, BLOCKS, BLOCK_SIZE, OP>(read_data)
390 };
391
392 let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
393 debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
394
395 for (i, block) in output_data.into_iter().enumerate() {
396 vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
397 }
398
399 vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
400 vm_state.instret += 1;
401}
402
403unsafe fn execute_e12_generic_impl<
404 F: PrimeField32,
405 CTX: ExecutionCtxTrait,
406 const BLOCKS: usize,
407 const BLOCK_SIZE: usize,
408>(
409 pre_compute: &FieldExpressionPreCompute,
410 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
411) {
412 let rs_vals = pre_compute
413 .rs_addrs
414 .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32)));
415
416 let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
417 debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
418 from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
419 });
420 let read_data_dyn: DynArray<u8> = read_data.into();
421
422 let writes = run_field_expression_precomputed::<true>(
423 pre_compute.expr,
424 pre_compute.flag_idx as usize,
425 &read_data_dyn.0,
426 );
427
428 let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
429 debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
430
431 let data: [[u8; BLOCK_SIZE]; BLOCKS] = writes.into();
432 for (i, block) in data.into_iter().enumerate() {
433 vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
434 }
435
436 vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
437 vm_state.instret += 1;
438}
439
440unsafe fn execute_e12_setup_impl<
441 F: PrimeField32,
442 CTX: ExecutionCtxTrait,
443 const BLOCKS: usize,
444 const BLOCK_SIZE: usize,
445 const IS_FP2: bool,
446>(
447 pre_compute: &FieldExpressionPreCompute,
448 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
449) {
450 let rs_vals = pre_compute
452 .rs_addrs
453 .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32)));
454 let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
455 debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
456 from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
457 });
458
459 let input_prime = if IS_FP2 {
461 BigUint::from_bytes_le(read_data[0][..BLOCKS / 2].as_flattened())
462 } else {
463 BigUint::from_bytes_le(read_data[0].as_flattened())
464 };
465
466 if input_prime != pre_compute.expr.prime {
467 vm_state.exit_code = Err(ExecutionError::Fail {
468 pc: vm_state.pc,
469 msg: "ModularSetup: mismatched prime",
470 });
471 return;
472 }
473
474 let read_data_dyn: DynArray<u8> = read_data.into();
475
476 let writes = run_field_expression_precomputed::<true>(
477 pre_compute.expr,
478 pre_compute.flag_idx as usize,
479 &read_data_dyn.0,
480 );
481
482 let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
483 debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
484
485 let data: [[u8; BLOCK_SIZE]; BLOCKS] = writes.into();
486 for (i, block) in data.into_iter().enumerate() {
487 vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
488 }
489
490 vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
491 vm_state.instret += 1;
492}
493
494#[create_tco_handler]
495unsafe fn execute_e1_setup_impl<
496 F: PrimeField32,
497 CTX: ExecutionCtxTrait,
498 const BLOCKS: usize,
499 const BLOCK_SIZE: usize,
500 const IS_FP2: bool,
501>(
502 pre_compute: &[u8],
503 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
504) {
505 let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow();
506 execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(pre_compute, vm_state);
507}
508
509#[create_tco_handler]
510unsafe fn execute_e2_setup_impl<
511 F: PrimeField32,
512 CTX: MeteredExecutionCtxTrait,
513 const BLOCKS: usize,
514 const BLOCK_SIZE: usize,
515 const IS_FP2: bool,
516>(
517 pre_compute: &[u8],
518 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
519) {
520 let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = pre_compute.borrow();
521 vm_state
522 .ctx
523 .on_height_change(pre_compute.chip_idx as usize, 1);
524 execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(&pre_compute.data, vm_state);
525}
526
527#[create_tco_handler]
528unsafe fn execute_e1_impl<
529 F: PrimeField32,
530 CTX: ExecutionCtxTrait,
531 const BLOCKS: usize,
532 const BLOCK_SIZE: usize,
533 const IS_FP2: bool,
534 const FIELD_TYPE: u8,
535 const OP: u8,
536>(
537 pre_compute: &[u8],
538 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
539) {
540 let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow();
541 execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2, FIELD_TYPE, OP>(pre_compute, vm_state);
542}
543
544#[create_tco_handler]
545unsafe fn execute_e2_impl<
546 F: PrimeField32,
547 CTX: MeteredExecutionCtxTrait,
548 const BLOCKS: usize,
549 const BLOCK_SIZE: usize,
550 const IS_FP2: bool,
551 const FIELD_TYPE: u8,
552 const OP: u8,
553>(
554 pre_compute: &[u8],
555 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
556) {
557 let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = pre_compute.borrow();
558 vm_state
559 .ctx
560 .on_height_change(pre_compute.chip_idx as usize, 1);
561 execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2, FIELD_TYPE, OP>(
562 &pre_compute.data,
563 vm_state,
564 );
565}
566
567#[create_tco_handler]
568unsafe fn execute_e1_generic_impl<
569 F: PrimeField32,
570 CTX: ExecutionCtxTrait,
571 const BLOCKS: usize,
572 const BLOCK_SIZE: usize,
573 const IS_FP2: bool,
574>(
575 pre_compute: &[u8],
576 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
577) {
578 let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow();
579 execute_e12_generic_impl::<_, _, BLOCKS, BLOCK_SIZE>(pre_compute, vm_state);
580}
581
582#[create_tco_handler]
583unsafe fn execute_e2_generic_impl<
584 F: PrimeField32,
585 CTX: MeteredExecutionCtxTrait,
586 const BLOCKS: usize,
587 const BLOCK_SIZE: usize,
588 const IS_FP2: bool,
589>(
590 pre_compute: &[u8],
591 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
592) {
593 let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = pre_compute.borrow();
594 vm_state
595 .ctx
596 .on_height_change(pre_compute.chip_idx as usize, 1);
597 execute_e12_generic_impl::<_, _, BLOCKS, BLOCK_SIZE>(&pre_compute.data, vm_state);
598}