use std::{
    array::from_fn,
    borrow::{Borrow, BorrowMut},
};

use num_bigint::BigUint;
use openvm_algebra_transpiler::{Fp2Opcode, Rv32ModularArithmeticOpcode};
use openvm_circuit::{
    arch::*,
    system::memory::{online::GuestMemory, POINTER_MAX_BITS},
};
use openvm_circuit_primitives::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
};
use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
use openvm_stark_backend::p3_field::PrimeField32;

use super::FieldExprVecHeapExecutor;
use crate::fields::{
    field_operation, fp2_operation, get_field_type, get_fp2_field_type, FieldType, Operation,
};

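/// Expands to a `match` over `(FieldType, Operation)` pairs, selecting an execute handler
/// monomorphized on the concrete field and operation (with `IS_FP2 = false`).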
macro_rules! generate_field_dispatch {
    (
        $field_type:expr,
        $op:expr,
        $blocks:expr,
        $block_size:expr,
        $execute_fn:ident,
        [$(($curve:ident, $operation:ident)),* $(,)?]
    ) => {
        match ($field_type, $op) {
            $(
                (FieldType::$curve, Operation::$operation) => Ok($execute_fn::<
                    _,
                    _,
                    $blocks,
                    $block_size,
                    false,
                    { FieldType::$curve as u8 },
                    { Operation::$operation as u8 },
                >),
            )*
        }
    };
}

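/// Same as `generate_field_dispatch!`, but for degree-2 extension fields (`IS_FP2 = true`).
/// Only the BN254 and BLS12-381 coordinate fields are specialized; the catch-all arm panics
/// on field types that have no Fp2 specialization.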
macro_rules! generate_fp2_dispatch {
    (
        $field_type:expr,
        $op:expr,
        $blocks:expr,
        $block_size:expr,
        $execute_fn:ident,
        [$(($curve:ident, $operation:ident)),* $(,)?]
    ) => {
        match ($field_type, $op) {
            $(
                (FieldType::$curve, Operation::$operation) => Ok($execute_fn::<
                    _,
                    _,
                    $blocks,
                    $block_size,
                    true,
                    { FieldType::$curve as u8 },
                    { Operation::$operation as u8 },
                >),
            )*
            _ => panic!("Unsupported fp2 field"),
        }
    };
}

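/// Picks the execute handler for a decoded instruction: a field/operation-specialized handler
/// when the modulus matches a known field, the generic field-expression handler otherwise, and
/// the setup handler when `$op` is `None` (setup opcodes).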
macro_rules! dispatch {
    ($execute_impl:ident,$execute_generic_impl:ident,$execute_setup_impl:ident,$pre_compute:ident,$op:ident) => {
        if let Some(op) = $op {
            let modulus = &$pre_compute.expr.prime;
            if IS_FP2 {
                if let Some(field_type) = get_fp2_field_type(modulus) {
                    generate_fp2_dispatch!(
                        field_type,
                        op,
                        BLOCKS,
                        BLOCK_SIZE,
                        $execute_impl,
                        [
                            (BN254Coordinate, Add),
                            (BN254Coordinate, Sub),
                            (BN254Coordinate, Mul),
                            (BN254Coordinate, Div),
                            (BLS12_381Coordinate, Add),
                            (BLS12_381Coordinate, Sub),
                            (BLS12_381Coordinate, Mul),
                            (BLS12_381Coordinate, Div),
                        ]
                    )
                } else {
                    Ok($execute_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>)
                }
            } else if let Some(field_type) = get_field_type(modulus) {
                generate_field_dispatch!(
                    field_type,
                    op,
                    BLOCKS,
                    BLOCK_SIZE,
                    $execute_impl,
                    [
                        (K256Coordinate, Add),
                        (K256Coordinate, Sub),
                        (K256Coordinate, Mul),
                        (K256Coordinate, Div),
                        (K256Scalar, Add),
                        (K256Scalar, Sub),
                        (K256Scalar, Mul),
                        (K256Scalar, Div),
                        (P256Coordinate, Add),
                        (P256Coordinate, Sub),
                        (P256Coordinate, Mul),
                        (P256Coordinate, Div),
                        (P256Scalar, Add),
                        (P256Scalar, Sub),
                        (P256Scalar, Mul),
                        (P256Scalar, Div),
                        (BN254Coordinate, Add),
                        (BN254Coordinate, Sub),
                        (BN254Coordinate, Mul),
                        (BN254Coordinate, Div),
                        (BN254Scalar, Add),
                        (BN254Scalar, Sub),
                        (BN254Scalar, Mul),
                        (BN254Scalar, Div),
                        (BLS12_381Coordinate, Add),
                        (BLS12_381Coordinate, Sub),
                        (BLS12_381Coordinate, Mul),
                        (BLS12_381Coordinate, Div),
                        (BLS12_381Scalar, Add),
                        (BLS12_381Scalar, Sub),
                        (BLS12_381Scalar, Mul),
                        (BLS12_381Scalar, Div),
                    ]
                )
            } else {
                Ok($execute_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>)
            }
        } else {
            Ok($execute_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>)
        }
    };
}

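/// Pre-computed data for a modular (or Fp2) arithmetic instruction: the field expression, the
/// two registers holding the source pointers, the register `a` holding the destination pointer,
/// and the flag index passed to the field expression to select the operation.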
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct FieldExpressionPreCompute<'a> {
    expr: &'a FieldExpr,
    rs_addrs: [u8; 2],
    a: u8,
    flag_idx: u8,
}

impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool>
    FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
{
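    /// Decodes `inst` into `data` and returns the arithmetic operation it performs,
    /// or `None` if it is a setup instruction.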
    fn pre_compute_impl<F: PrimeField32>(
        &'a self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut FieldExpressionPreCompute<'a>,
    ) -> Result<Option<Operation>, StaticProgramError> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            ..
        } = inst;

        let a = a.as_canonical_u32();
        let b = b.as_canonical_u32();
        let c = c.as_canonical_u32();
        let d = d.as_canonical_u32();
        let e = e.as_canonical_u32();
        if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }

        let local_opcode = opcode.local_opcode_idx(self.0.offset);

        let needs_setup = self.0.expr.needs_setup();
        let mut flag_idx = self.0.expr.num_flags() as u8;
        if needs_setup {
            if let Some(opcode_position) = self
                .0
                .local_opcode_idx
                .iter()
                .position(|&idx| idx == local_opcode)
            {
                if opcode_position < self.0.opcode_flag_idx.len() {
                    flag_idx = self.0.opcode_flag_idx[opcode_position] as u8;
                }
            }
        }

        let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8);
        *data = FieldExpressionPreCompute {
            a: a as u8,
            rs_addrs,
            expr: &self.0.expr,
            flag_idx,
        };

        if IS_FP2 {
            let is_setup = local_opcode == Fp2Opcode::SETUP_ADDSUB as usize
                || local_opcode == Fp2Opcode::SETUP_MULDIV as usize;

            let op = if is_setup {
                None
            } else {
                match local_opcode {
                    x if x == Fp2Opcode::ADD as usize => Some(Operation::Add),
                    x if x == Fp2Opcode::SUB as usize => Some(Operation::Sub),
                    x if x == Fp2Opcode::MUL as usize => Some(Operation::Mul),
                    x if x == Fp2Opcode::DIV as usize => Some(Operation::Div),
                    _ => unreachable!(),
                }
            };

            Ok(op)
        } else {
            let is_setup = local_opcode == Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize
                || local_opcode == Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize;

            let op = if is_setup {
                None
            } else {
                match local_opcode {
                    x if x == Rv32ModularArithmeticOpcode::ADD as usize => Some(Operation::Add),
                    x if x == Rv32ModularArithmeticOpcode::SUB as usize => Some(Operation::Sub),
                    x if x == Rv32ModularArithmeticOpcode::MUL as usize => Some(Operation::Mul),
                    x if x == Rv32ModularArithmeticOpcode::DIV as usize => Some(Operation::Div),
                    _ => unreachable!(),
                }
            };

            Ok(op)
        }
    }
}

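// Plain (unmetered) interpreter integration: decode the instruction once into
// `FieldExpressionPreCompute`, then return the matching e1 handler via `dispatch!`.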
impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool>
    InterpreterExecutor<F> for FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
{
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        std::mem::size_of::<FieldExpressionPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let pre_compute: &mut FieldExpressionPreCompute = data.borrow_mut();
        let op = self.pre_compute_impl(pc, inst, pre_compute)?;

        dispatch!(
            execute_e1_handler,
            execute_e1_generic_handler,
            execute_e1_setup_handler,
            pre_compute,
            op
        )
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let pre_compute: &mut FieldExpressionPreCompute = data.borrow_mut();
        let op = self.pre_compute_impl(pc, inst, pre_compute)?;

        dispatch!(
            execute_e1_handler,
            execute_e1_generic_handler,
            execute_e1_setup_handler,
            pre_compute,
            op
        )
    }
}

#[cfg(feature = "aot")]
impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool>
    AotExecutor<F> for FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
{
}

impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool>
    InterpreterMeteredExecutor<F> for FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
{
    #[inline(always)]
    fn metered_pre_compute_size(&self) -> usize {
        std::mem::size_of::<E2PreCompute<FieldExpressionPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<FieldExpressionPreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;

        let pre_compute_pure = &mut pre_compute.data;
        let op = self.pre_compute_impl(pc, inst, pre_compute_pure)?;

        dispatch!(
            execute_e2_handler,
            execute_e2_generic_handler,
            execute_e2_setup_handler,
            pre_compute_pure,
            op
        )
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<FieldExpressionPreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;

        let pre_compute_pure = &mut pre_compute.data;
        let op = self.pre_compute_impl(pc, inst, pre_compute_pure)?;

        dispatch!(
            execute_e2_handler,
            execute_e2_generic_handler,
            execute_e2_setup_handler,
            pre_compute_pure,
            op
        )
    }
}

#[cfg(feature = "aot")]
impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool>
    AotMeteredExecutor<F> for FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
{
}

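/// Core execution for a specialized field/Fp2 operation: reads the two operand pointers from
/// the registers in `rs_addrs`, loads `BLOCKS` blocks of `BLOCK_SIZE` bytes for each operand,
/// computes the result with the const-generic field routines, writes it to the address held in
/// register `a`, and advances the PC.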
#[inline(always)]
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
    const FIELD_TYPE: u8,
    const OP: u8,
>(
    pre_compute: &FieldExpressionPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs_vals = pre_compute
        .rs_addrs
        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));

    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
        from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
    });

    let output_data = if IS_FP2 {
        fp2_operation::<FIELD_TYPE, BLOCKS, BLOCK_SIZE, OP>(read_data)
    } else {
        field_operation::<FIELD_TYPE, BLOCKS, BLOCK_SIZE, OP>(read_data)
    };

    let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));

    for (i, block) in output_data.into_iter().enumerate() {
        exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
    }

    let pc = exec_state.pc();
    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
}

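/// Same as [`execute_e12_impl`], but for moduli without a specialized implementation: the result
/// is computed by evaluating the precomputed `FieldExpr` with the stored flag index.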
#[inline(always)]
unsafe fn execute_e12_generic_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
>(
    pre_compute: &FieldExpressionPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs_vals = pre_compute
        .rs_addrs
        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));

    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
        from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
    });
    let read_data_dyn: DynArray<u8> = read_data.into();

    let writes = run_field_expression_precomputed::<true>(
        pre_compute.expr,
        pre_compute.flag_idx as usize,
        &read_data_dyn.0,
    );

    let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));

    let data: [[u8; BLOCK_SIZE]; BLOCKS] = writes.into();
    for (i, block) in data.into_iter().enumerate() {
        exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
    }

    let pc = exec_state.pc();
    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
}

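/// Handles setup instructions: checks that the prime encoded in the first operand matches the
/// expression's modulus (failing with an execution error on mismatch) before running the field
/// expression and writing the output.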
#[inline(always)]
unsafe fn execute_e12_setup_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
>(
    pre_compute: &FieldExpressionPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) -> Result<(), ExecutionError> {
    let pc = exec_state.pc();
    let rs_vals = pre_compute
        .rs_addrs
        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));
    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
        from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
    });

    let input_prime = if IS_FP2 {
        BigUint::from_bytes_le(read_data[0][..BLOCKS / 2].as_flattened())
    } else {
        BigUint::from_bytes_le(read_data[0].as_flattened())
    };

    if input_prime != pre_compute.expr.prime {
        let err = ExecutionError::Fail {
            pc,
            msg: "ModularSetup: mismatched prime",
        };
        return Err(err);
    }

    let read_data_dyn: DynArray<u8> = read_data.into();

    let writes = run_field_expression_precomputed::<true>(
        pre_compute.expr,
        pre_compute.flag_idx as usize,
        &read_data_dyn.0,
    );

    let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));

    let data: [[u8; BLOCK_SIZE]; BLOCKS] = writes.into();
    for (i, block) in data.into_iter().enumerate() {
        exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
    }

    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));

    Ok(())
}

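// The `execute_e1_*` / `execute_e2_*` functions below are thin wrappers that reinterpret the raw
// pre-compute buffer and delegate to the shared `execute_e12_*` implementations; the `e2` variants
// additionally bump the height of their chip by one row for metering.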
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_setup_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) -> Result<(), ExecutionError> {
    let pre_compute: &FieldExpressionPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<FieldExpressionPreCompute>()).borrow();
    execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(pre_compute, exec_state)
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e2_setup_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) -> Result<(), ExecutionError> {
    let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = std::slice::from_raw_parts(
        pre_compute,
        size_of::<E2PreCompute<FieldExpressionPreCompute>>(),
    )
    .borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(&pre_compute.data, exec_state)
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
    const FIELD_TYPE: u8,
    const OP: u8,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &FieldExpressionPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<FieldExpressionPreCompute>()).borrow();
    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2, FIELD_TYPE, OP>(pre_compute, exec_state);
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
    const FIELD_TYPE: u8,
    const OP: u8,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = std::slice::from_raw_parts(
        pre_compute,
        size_of::<E2PreCompute<FieldExpressionPreCompute>>(),
    )
    .borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2, FIELD_TYPE, OP>(
        &pre_compute.data,
        exec_state,
    );
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e1_generic_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &FieldExpressionPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<FieldExpressionPreCompute>()).borrow();
    execute_e12_generic_impl::<_, _, BLOCKS, BLOCK_SIZE>(pre_compute, exec_state);
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e2_generic_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const BLOCKS: usize,
    const BLOCK_SIZE: usize,
    const IS_FP2: bool,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = std::slice::from_raw_parts(
        pre_compute,
        size_of::<E2PreCompute<FieldExpressionPreCompute>>(),
    )
    .borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_generic_impl::<_, _, BLOCKS, BLOCK_SIZE>(&pre_compute.data, exec_state);
}