use std::{
    array::from_fn,
    borrow::{Borrow, BorrowMut},
};

use num_bigint::BigUint;
use openvm_algebra_transpiler::{Fp2Opcode, Rv32ModularArithmeticOpcode};
use openvm_circuit::{
    arch::*,
    system::memory::{online::GuestMemory, POINTER_MAX_BITS},
};
use openvm_circuit_primitives::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
};
use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
use openvm_stark_backend::p3_field::PrimeField32;

use super::FieldExprVecHeapExecutor;
use crate::fields::{
    field_operation, fp2_operation, get_field_type, get_fp2_field_type, FieldType, Operation,
};

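/// Expands to an exhaustive `match` over `(FieldType, Operation)`, yielding `Ok` of the given
/// execute function monomorphized for that field/operation pair (with `IS_FP2 = false`).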
macro_rules! generate_field_dispatch {
    (
        $field_type:expr,
        $op:expr,
        $blocks:expr,
        $block_size:expr,
        $execute_fn:ident,
        [$(($curve:ident, $operation:ident)),* $(,)?]
    ) => {
        match ($field_type, $op) {
            $(
                (FieldType::$curve, Operation::$operation) => Ok($execute_fn::<
                    _,
                    _,
                    $blocks,
                    $block_size,
                    false,
                    { FieldType::$curve as u8 },
                    { Operation::$operation as u8 },
                >),
            )*
        }
    };
}

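/// Fp2 counterpart of `generate_field_dispatch!`: the handler is monomorphized with
/// `IS_FP2 = true`, and a catch-all arm panics since only the BN254 and BLS12-381 coordinate
/// fields are supported.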
macro_rules! generate_fp2_dispatch {
    (
        $field_type:expr,
        $op:expr,
        $blocks:expr,
        $block_size:expr,
        $execute_fn:ident,
        [$(($curve:ident, $operation:ident)),* $(,)?]
    ) => {
        match ($field_type, $op) {
            $(
                (FieldType::$curve, Operation::$operation) => Ok($execute_fn::<
                    _,
                    _,
                    $blocks,
                    $block_size,
                    true,
                    { FieldType::$curve as u8 },
                    { Operation::$operation as u8 },
                >),
            )*
            _ => panic!("Unsupported fp2 field")
        }
    };
}

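/// Selects the execute function for a decoded instruction: the setup handler when `$op` is
/// `None`, a field/operation-specialized handler when the modulus matches a known field, and
/// the generic `FieldExpr`-driven handler otherwise.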
macro_rules! dispatch {
    ($execute_impl:ident,$execute_generic_impl:ident,$execute_setup_impl:ident,$pre_compute:ident,$op:ident) => {
        if let Some(op) = $op {
            let modulus = &$pre_compute.expr.prime;
            if IS_FP2 {
                if let Some(field_type) = get_fp2_field_type(modulus) {
                    generate_fp2_dispatch!(
                        field_type,
                        op,
                        BLOCKS,
                        BLOCK_SIZE,
                        $execute_impl,
                        [
                            (BN254Coordinate, Add),
                            (BN254Coordinate, Sub),
                            (BN254Coordinate, Mul),
                            (BN254Coordinate, Div),
                            (BLS12_381Coordinate, Add),
                            (BLS12_381Coordinate, Sub),
                            (BLS12_381Coordinate, Mul),
                            (BLS12_381Coordinate, Div),
                        ]
                    )
                } else {
                    Ok($execute_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>)
                }
            } else if let Some(field_type) = get_field_type(modulus) {
                generate_field_dispatch!(
                    field_type,
                    op,
                    BLOCKS,
                    BLOCK_SIZE,
                    $execute_impl,
                    [
                        (K256Coordinate, Add),
                        (K256Coordinate, Sub),
                        (K256Coordinate, Mul),
                        (K256Coordinate, Div),
                        (K256Scalar, Add),
                        (K256Scalar, Sub),
                        (K256Scalar, Mul),
                        (K256Scalar, Div),
                        (P256Coordinate, Add),
                        (P256Coordinate, Sub),
                        (P256Coordinate, Mul),
                        (P256Coordinate, Div),
                        (P256Scalar, Add),
                        (P256Scalar, Sub),
                        (P256Scalar, Mul),
                        (P256Scalar, Div),
                        (BN254Coordinate, Add),
                        (BN254Coordinate, Sub),
                        (BN254Coordinate, Mul),
                        (BN254Coordinate, Div),
                        (BN254Scalar, Add),
                        (BN254Scalar, Sub),
                        (BN254Scalar, Mul),
                        (BN254Scalar, Div),
                        (BLS12_381Coordinate, Add),
                        (BLS12_381Coordinate, Sub),
                        (BLS12_381Coordinate, Mul),
                        (BLS12_381Coordinate, Div),
                        (BLS12_381Scalar, Add),
                        (BLS12_381Scalar, Sub),
                        (BLS12_381Scalar, Mul),
                        (BLS12_381Scalar, Div),
                    ]
                )
            } else {
                Ok($execute_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>)
            }
        } else {
            Ok($execute_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>)
        }
    };
}

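/// Per-instruction pre-computed data: the two source register addresses holding the input
/// pointers, the destination register address `a`, the flag index selecting the operation in
/// the field expression, and a reference to the chip's `FieldExpr`.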
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct FieldExpressionPreCompute<'a> {
    expr: &'a FieldExpr,
    rs_addrs: [u8; 2],
    a: u8,
    flag_idx: u8,
}

impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool>
    FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
{
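    /// Decodes the instruction into `data` and returns the operation to execute, or `None`
    /// when the opcode is a `SETUP_*` variant (which is handled by the setup path).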
    fn pre_compute_impl<F: PrimeField32>(
        &'a self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut FieldExpressionPreCompute<'a>,
    ) -> Result<Option<Operation>, StaticProgramError> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            ..
        } = inst;

        let a = a.as_canonical_u32();
        let b = b.as_canonical_u32();
        let c = c.as_canonical_u32();
        let d = d.as_canonical_u32();
        let e = e.as_canonical_u32();
        if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }

        let local_opcode = opcode.local_opcode_idx(self.0.offset);

        let needs_setup = self.0.expr.needs_setup();
        let mut flag_idx = self.0.expr.num_flags() as u8;
        if needs_setup {
            if let Some(opcode_position) = self
                .0
                .local_opcode_idx
                .iter()
                .position(|&idx| idx == local_opcode)
            {
                if opcode_position < self.0.opcode_flag_idx.len() {
                    flag_idx = self.0.opcode_flag_idx[opcode_position] as u8;
                }
            }
        }

        let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8);
        *data = FieldExpressionPreCompute {
            a: a as u8,
            rs_addrs,
            expr: &self.0.expr,
            flag_idx,
        };

        if IS_FP2 {
            let is_setup = local_opcode == Fp2Opcode::SETUP_ADDSUB as usize
                || local_opcode == Fp2Opcode::SETUP_MULDIV as usize;

            let op = if is_setup {
                None
            } else {
                match local_opcode {
                    x if x == Fp2Opcode::ADD as usize => Some(Operation::Add),
                    x if x == Fp2Opcode::SUB as usize => Some(Operation::Sub),
                    x if x == Fp2Opcode::MUL as usize => Some(Operation::Mul),
                    x if x == Fp2Opcode::DIV as usize => Some(Operation::Div),
                    _ => unreachable!(),
                }
            };

            Ok(op)
        } else {
            let is_setup = local_opcode == Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize
                || local_opcode == Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize;

            let op = if is_setup {
                None
            } else {
                match local_opcode {
                    x if x == Rv32ModularArithmeticOpcode::ADD as usize => Some(Operation::Add),
                    x if x == Rv32ModularArithmeticOpcode::SUB as usize => Some(Operation::Sub),
                    x if x == Rv32ModularArithmeticOpcode::MUL as usize => Some(Operation::Mul),
                    x if x == Rv32ModularArithmeticOpcode::DIV as usize => Some(Operation::Div),
                    _ => unreachable!(),
                }
            };

            Ok(op)
        }
    }
}

impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool> Executor<F>
    for FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
{
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        std::mem::size_of::<FieldExpressionPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let pre_compute: &mut FieldExpressionPreCompute = data.borrow_mut();
        let op = self.pre_compute_impl(pc, inst, pre_compute)?;

        dispatch!(
            execute_e1_handler,
            execute_e1_generic_handler,
            execute_e1_setup_handler,
            pre_compute,
            op
        )
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let pre_compute: &mut FieldExpressionPreCompute = data.borrow_mut();
        let op = self.pre_compute_impl(pc, inst, pre_compute)?;

        dispatch!(
            execute_e1_handler,
            execute_e1_generic_handler,
            execute_e1_setup_handler,
            pre_compute,
            op
        )
    }
}

impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool>
    MeteredExecutor<F> for FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
{
    #[inline(always)]
    fn metered_pre_compute_size(&self) -> usize {
        std::mem::size_of::<E2PreCompute<FieldExpressionPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<FieldExpressionPreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;

        let pre_compute_pure = &mut pre_compute.data;
        let op = self.pre_compute_impl(pc, inst, pre_compute_pure)?;

        dispatch!(
            execute_e2_handler,
            execute_e2_generic_handler,
            execute_e2_setup_handler,
            pre_compute_pure,
            op
        )
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<FieldExpressionPreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;

        let pre_compute_pure = &mut pre_compute.data;
        let op = self.pre_compute_impl(pc, inst, pre_compute_pure)?;

        dispatch!(
            execute_e2_handler,
            execute_e2_generic_handler,
            execute_e2_setup_handler,
            pre_compute_pure,
            op
        )
    }
}

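/// Specialized execution path for a recognized field: reads both operands from memory,
/// applies the const-dispatched `field_operation`/`fp2_operation`, and writes the result to
/// the destination pointed to by register `a`.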
#[inline(always)]
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
    const FIELD_TYPE: u8,
    const OP: u8,
>(
    pre_compute: &FieldExpressionPreCompute,
    instret: &mut u64,
    pc: &mut u32,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs_vals = pre_compute
        .rs_addrs
        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));

    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
        from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
    });

    let output_data = if IS_FP2 {
        fp2_operation::<FIELD_TYPE, BLOCKS, BLOCK_SIZE, OP>(read_data)
    } else {
        field_operation::<FIELD_TYPE, BLOCKS, BLOCK_SIZE, OP>(read_data)
    };

    let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));

    for (i, block) in output_data.into_iter().enumerate() {
        exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
    }

    *pc = pc.wrapping_add(DEFAULT_PC_STEP);
    *instret += 1;
}

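/// Generic execution path used when the modulus has no specialized implementation: the result
/// is computed by evaluating the chip's `FieldExpr` via `run_field_expression_precomputed`.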
#[inline(always)]
unsafe fn execute_e12_generic_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
>(
    pre_compute: &FieldExpressionPreCompute,
    instret: &mut u64,
    pc: &mut u32,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs_vals = pre_compute
        .rs_addrs
        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));

    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
        from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
    });
    let read_data_dyn: DynArray<u8> = read_data.into();

    let writes = run_field_expression_precomputed::<true>(
        pre_compute.expr,
        pre_compute.flag_idx as usize,
        &read_data_dyn.0,
    );

    let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));

    let data: [[u8; BLOCK_SIZE]; BLOCKS] = writes.into();
    for (i, block) in data.into_iter().enumerate() {
        exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
    }

    *pc = pc.wrapping_add(DEFAULT_PC_STEP);
    *instret += 1;
}

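/// Setup execution path: verifies that the prime encoded in the first input operand matches
/// the chip's configured modulus, failing with an `ExecutionError` otherwise, then evaluates
/// the field expression and writes the result.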
#[inline(always)]
unsafe fn execute_e12_setup_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
>(
    pre_compute: &FieldExpressionPreCompute,
    instret: &mut u64,
    pc: &mut u32,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) -> Result<(), ExecutionError> {
    let rs_vals = pre_compute
        .rs_addrs
        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));
    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
        from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
    });

    let input_prime = if IS_FP2 {
        BigUint::from_bytes_le(read_data[0][..BLOCKS / 2].as_flattened())
    } else {
        BigUint::from_bytes_le(read_data[0].as_flattened())
    };

    if input_prime != pre_compute.expr.prime {
        let err = ExecutionError::Fail {
            pc: *pc,
            msg: "ModularSetup: mismatched prime",
        };
        return Err(err);
    }

    let read_data_dyn: DynArray<u8> = read_data.into();

    let writes = run_field_expression_precomputed::<true>(
        pre_compute.expr,
        pre_compute.flag_idx as usize,
        &read_data_dyn.0,
    );

    let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));

    let data: [[u8; BLOCK_SIZE]; BLOCKS] = writes.into();
    for (i, block) in data.into_iter().enumerate() {
        exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
    }

    *pc = pc.wrapping_add(DEFAULT_PC_STEP);
    *instret += 1;

    Ok(())
}

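// The handlers below are thin wrappers over the `execute_e12_*` implementations: the `e1`
// variants serve the plain `Executor` path, while the `e2` variants serve the metered
// `MeteredExecutor` path and report a height change of 1 to their chip before delegating.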
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_setup_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _instret_end: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) -> Result<(), ExecutionError> {
    let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow();
    execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(pre_compute, instret, pc, exec_state)
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e2_setup_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _arg: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) -> Result<(), ExecutionError> {
    let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = pre_compute.borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(
        &pre_compute.data,
        instret,
        pc,
        exec_state,
    )
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
    const FIELD_TYPE: u8,
    const OP: u8,
>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _instret_end: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow();
    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2, FIELD_TYPE, OP>(
        pre_compute,
        instret,
        pc,
        exec_state,
    );
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
    const FIELD_TYPE: u8,
    const OP: u8,
>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _arg: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = pre_compute.borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2, FIELD_TYPE, OP>(
        &pre_compute.data,
        instret,
        pc,
        exec_state,
    );
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e1_generic_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _instret_end: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow();
    execute_e12_generic_impl::<_, _, BLOCKS, BLOCK_SIZE>(pre_compute, instret, pc, exec_state);
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e2_generic_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _arg: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = pre_compute.borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_generic_impl::<_, _, BLOCKS, BLOCK_SIZE>(
        &pre_compute.data,
        instret,
        pc,
        exec_state,
    );
}