openvm_algebra_circuit/
execution.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_algebra_transpiler::{Fp2Opcode, Rv32ModularArithmeticOpcode};
8use openvm_circuit::{
9    arch::*,
10    system::memory::{online::GuestMemory, POINTER_MAX_BITS},
11};
12use openvm_circuit_primitives::AlignedBytesBorrow;
13use openvm_instructions::{
14    instruction::Instruction,
15    program::DEFAULT_PC_STEP,
16    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
17};
18use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
19use openvm_stark_backend::p3_field::PrimeField32;
20
21use super::FieldExprVecHeapExecutor;
22use crate::fields::{
23    field_operation, fp2_operation, get_field_type, get_fp2_field_type, FieldType, Operation,
24};
25
/// Picks the specialised prime-field (non-Fp2) execute function for a `(field, operation)` pair.
///
/// Expands to a `match` on `($field_type, $op)` with one arm per listed
/// `(FieldType variant, Operation variant)` pair. Each arm evaluates to `Ok` of `$execute_fn`
/// instantiated with `$blocks`, `$block_size`, `IS_FP2 = false`, and the pair encoded as `u8`
/// const generics. No wildcard arm is emitted, so the pair list has to be exhaustive for the
/// expansion to compile.
macro_rules! generate_field_dispatch {
    (
        $field_type:expr,
        $op:expr,
        $blocks:expr,
        $block_size:expr,
        $execute_fn:ident,
        [$(($field:ident, $arith:ident)),* $(,)?]
    ) => {
        match ($field_type, $op) {
            $(
                (FieldType::$field, Operation::$arith) => Ok(
                    $execute_fn::<
                        _,
                        _,
                        $blocks,
                        $block_size,
                        false,
                        { FieldType::$field as u8 },
                        { Operation::$arith as u8 },
                    >
                ),
            )*
        }
    };
}
50
/// Picks the specialised Fp2 execute function for a `(field, operation)` pair.
///
/// Expands to a `match` on `($field_type, $op)` with one arm per listed
/// `(FieldType variant, Operation variant)` pair. Each arm evaluates to `Ok` of `$execute_fn`
/// instantiated with `$blocks`, `$block_size`, `IS_FP2 = true`, and the pair encoded as `u8`
/// const generics.
///
/// # Panics
///
/// The expansion panics at runtime for any pair that is not listed.
macro_rules! generate_fp2_dispatch {
    (
        $field_type:expr,
        $op:expr,
        $blocks:expr,
        $block_size:expr,
        $execute_fn:ident,
        [$(($field:ident, $arith:ident)),* $(,)?]
    ) => {
        match ($field_type, $op) {
            $(
                (FieldType::$field, Operation::$arith) => Ok(
                    $execute_fn::<
                        _,
                        _,
                        $blocks,
                        $block_size,
                        true,
                        { FieldType::$field as u8 },
                        { Operation::$arith as u8 },
                    >
                ),
            )*
            _ => panic!("Unsupported fp2 field")
        }
    };
}
76
/// Chooses the execute function for an instruction whose operation was already classified.
///
/// * `$op` is `None` for setup opcodes: yields `$execute_setup_impl`.
/// * `$op` is `Some(..)` and the modulus (`$pre_compute.expr.prime`) is recognised by
///   `get_fp2_field_type` / `get_field_type`: yields the specialised `$execute_impl`.
/// * otherwise: yields `$execute_generic_impl`.
///
/// The expansion refers to `BLOCKS`, `BLOCK_SIZE` and `IS_FP2`, which must be in scope at the
/// call site, and evaluates to a `Result` wrapping the chosen function.
macro_rules! dispatch {
    ($execute_impl:ident,$execute_generic_impl:ident,$execute_setup_impl:ident,$pre_compute:ident,$op:ident) => {
        match $op {
            None => Ok($execute_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>),
            Some(operation) => {
                let prime = &$pre_compute.expr.prime;
                if IS_FP2 {
                    match get_fp2_field_type(prime) {
                        Some(known_field) => generate_fp2_dispatch!(
                            known_field,
                            operation,
                            BLOCKS,
                            BLOCK_SIZE,
                            $execute_impl,
                            [
                                (BN254Coordinate, Add),
                                (BN254Coordinate, Sub),
                                (BN254Coordinate, Mul),
                                (BN254Coordinate, Div),
                                (BLS12_381Coordinate, Add),
                                (BLS12_381Coordinate, Sub),
                                (BLS12_381Coordinate, Mul),
                                (BLS12_381Coordinate, Div),
                            ]
                        ),
                        None => Ok($execute_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>),
                    }
                } else {
                    match get_field_type(prime) {
                        Some(known_field) => generate_field_dispatch!(
                            known_field,
                            operation,
                            BLOCKS,
                            BLOCK_SIZE,
                            $execute_impl,
                            [
                                (K256Coordinate, Add),
                                (K256Coordinate, Sub),
                                (K256Coordinate, Mul),
                                (K256Coordinate, Div),
                                (K256Scalar, Add),
                                (K256Scalar, Sub),
                                (K256Scalar, Mul),
                                (K256Scalar, Div),
                                (P256Coordinate, Add),
                                (P256Coordinate, Sub),
                                (P256Coordinate, Mul),
                                (P256Coordinate, Div),
                                (P256Scalar, Add),
                                (P256Scalar, Sub),
                                (P256Scalar, Mul),
                                (P256Scalar, Div),
                                (BN254Coordinate, Add),
                                (BN254Coordinate, Sub),
                                (BN254Coordinate, Mul),
                                (BN254Coordinate, Div),
                                (BN254Scalar, Add),
                                (BN254Scalar, Sub),
                                (BN254Scalar, Mul),
                                (BN254Scalar, Div),
                                (BLS12_381Coordinate, Add),
                                (BLS12_381Coordinate, Sub),
                                (BLS12_381Coordinate, Mul),
                                (BLS12_381Coordinate, Div),
                                (BLS12_381Scalar, Add),
                                (BLS12_381Scalar, Sub),
                                (BLS12_381Scalar, Mul),
                                (BLS12_381Scalar, Div),
                            ]
                        ),
                        None => Ok($execute_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>),
                    }
                }
            }
        }
    };
}
153
/// Per-instruction data filled in by `pre_compute_impl` and later reinterpreted from a raw
/// byte buffer by the execute handlers in this file.
///
/// The struct is `#[repr(C)]`, so the field order below is part of its layout and must not be
/// rearranged.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct FieldExpressionPreCompute<'a> {
    /// The executor's field expression (`&self.0.expr`). Its `prime` drives handler selection
    /// and the setup check; the generic and setup paths evaluate it directly.
    expr: &'a FieldExpr,
    /// Instruction operands `b` and `c`, narrowed to `u8`. Each is used as a register address
    /// whose 32-bit value is the base pointer of one input operand in memory.
    rs_addrs: [u8; 2],
    /// Instruction operand `a`, narrowed to `u8`. Used as the register address whose 32-bit
    /// value is the base pointer the result is written to.
    a: u8,
    /// Flag index handed to `run_field_expression_precomputed` by the generic and setup paths.
    flag_idx: u8,
}
162
163impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool>
164    FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
165{
166    fn pre_compute_impl<F: PrimeField32>(
167        &'a self,
168        pc: u32,
169        inst: &Instruction<F>,
170        data: &mut FieldExpressionPreCompute<'a>,
171    ) -> Result<Option<Operation>, StaticProgramError> {
172        let Instruction {
173            opcode,
174            a,
175            b,
176            c,
177            d,
178            e,
179            ..
180        } = inst;
181
182        let a = a.as_canonical_u32();
183        let b = b.as_canonical_u32();
184        let c = c.as_canonical_u32();
185        let d = d.as_canonical_u32();
186        let e = e.as_canonical_u32();
187        if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
188            return Err(StaticProgramError::InvalidInstruction(pc));
189        }
190
191        let local_opcode = opcode.local_opcode_idx(self.0.offset);
192
193        let needs_setup = self.0.expr.needs_setup();
194        let mut flag_idx = self.0.expr.num_flags() as u8;
195        if needs_setup {
196            if let Some(opcode_position) = self
197                .0
198                .local_opcode_idx
199                .iter()
200                .position(|&idx| idx == local_opcode)
201            {
202                if opcode_position < self.0.opcode_flag_idx.len() {
203                    flag_idx = self.0.opcode_flag_idx[opcode_position] as u8;
204                }
205            }
206        }
207
208        let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8);
209        *data = FieldExpressionPreCompute {
210            a: a as u8,
211            rs_addrs,
212            expr: &self.0.expr,
213            flag_idx,
214        };
215
216        if IS_FP2 {
217            let is_setup = local_opcode == Fp2Opcode::SETUP_ADDSUB as usize
218                || local_opcode == Fp2Opcode::SETUP_MULDIV as usize;
219
220            let op = if is_setup {
221                None
222            } else {
223                match local_opcode {
224                    x if x == Fp2Opcode::ADD as usize => Some(Operation::Add),
225                    x if x == Fp2Opcode::SUB as usize => Some(Operation::Sub),
226                    x if x == Fp2Opcode::MUL as usize => Some(Operation::Mul),
227                    x if x == Fp2Opcode::DIV as usize => Some(Operation::Div),
228                    _ => unreachable!(),
229                }
230            };
231
232            Ok(op)
233        } else {
234            let is_setup = local_opcode == Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize
235                || local_opcode == Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize;
236
237            let op = if is_setup {
238                None
239            } else {
240                match local_opcode {
241                    x if x == Rv32ModularArithmeticOpcode::ADD as usize => Some(Operation::Add),
242                    x if x == Rv32ModularArithmeticOpcode::SUB as usize => Some(Operation::Sub),
243                    x if x == Rv32ModularArithmeticOpcode::MUL as usize => Some(Operation::Mul),
244                    x if x == Rv32ModularArithmeticOpcode::DIV as usize => Some(Operation::Div),
245                    _ => unreachable!(),
246                }
247            };
248
249            Ok(op)
250        }
251    }
252}
253
254impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool> Executor<F>
255    for FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
256{
257    #[inline(always)]
258    fn pre_compute_size(&self) -> usize {
259        std::mem::size_of::<FieldExpressionPreCompute>()
260    }
261
262    #[cfg(not(feature = "tco"))]
263    fn pre_compute<Ctx>(
264        &self,
265        pc: u32,
266        inst: &Instruction<F>,
267        data: &mut [u8],
268    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
269    where
270        Ctx: ExecutionCtxTrait,
271    {
272        let pre_compute: &mut FieldExpressionPreCompute = data.borrow_mut();
273        let op = self.pre_compute_impl(pc, inst, pre_compute)?;
274
275        dispatch!(
276            execute_e1_handler,
277            execute_e1_generic_handler,
278            execute_e1_setup_handler,
279            pre_compute,
280            op
281        )
282    }
283
284    #[cfg(feature = "tco")]
285    fn handler<Ctx>(
286        &self,
287        pc: u32,
288        inst: &Instruction<F>,
289        data: &mut [u8],
290    ) -> Result<Handler<F, Ctx>, StaticProgramError>
291    where
292        Ctx: ExecutionCtxTrait,
293    {
294        let pre_compute: &mut FieldExpressionPreCompute = data.borrow_mut();
295        let op = self.pre_compute_impl(pc, inst, pre_compute)?;
296
297        dispatch!(
298            execute_e1_handler,
299            execute_e1_generic_handler,
300            execute_e1_setup_handler,
301            pre_compute,
302            op
303        )
304    }
305}
306
307impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool>
308    MeteredExecutor<F> for FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
309{
310    #[inline(always)]
311    fn metered_pre_compute_size(&self) -> usize {
312        std::mem::size_of::<E2PreCompute<FieldExpressionPreCompute>>()
313    }
314
315    #[cfg(not(feature = "tco"))]
316    fn metered_pre_compute<Ctx>(
317        &self,
318        chip_idx: usize,
319        pc: u32,
320        inst: &Instruction<F>,
321        data: &mut [u8],
322    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
323    where
324        Ctx: MeteredExecutionCtxTrait,
325    {
326        let pre_compute: &mut E2PreCompute<FieldExpressionPreCompute> = data.borrow_mut();
327        pre_compute.chip_idx = chip_idx as u32;
328
329        let pre_compute_pure = &mut pre_compute.data;
330        let op = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
331
332        dispatch!(
333            execute_e2_handler,
334            execute_e2_generic_handler,
335            execute_e2_setup_handler,
336            pre_compute_pure,
337            op
338        )
339    }
340
341    #[cfg(feature = "tco")]
342    fn metered_handler<Ctx>(
343        &self,
344        chip_idx: usize,
345        pc: u32,
346        inst: &Instruction<F>,
347        data: &mut [u8],
348    ) -> Result<Handler<F, Ctx>, StaticProgramError>
349    where
350        Ctx: MeteredExecutionCtxTrait,
351    {
352        let pre_compute: &mut E2PreCompute<FieldExpressionPreCompute> = data.borrow_mut();
353        pre_compute.chip_idx = chip_idx as u32;
354
355        let pre_compute_pure = &mut pre_compute.data;
356        let op = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
357
358        dispatch!(
359            execute_e2_handler,
360            execute_e2_generic_handler,
361            execute_e2_setup_handler,
362            pre_compute_pure,
363            op
364        )
365    }
366}
367
368#[inline(always)]
369unsafe fn execute_e12_impl<
370    F: PrimeField32,
371    CTX: ExecutionCtxTrait,
372    const BLOCKS: usize,
373    const BLOCK_SIZE: usize,
374    const IS_FP2: bool,
375    const FIELD_TYPE: u8,
376    const OP: u8,
377>(
378    pre_compute: &FieldExpressionPreCompute,
379    instret: &mut u64,
380    pc: &mut u32,
381    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
382) {
383    let rs_vals = pre_compute
384        .rs_addrs
385        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));
386
387    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
388        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
389        from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
390    });
391
392    let output_data = if IS_FP2 {
393        fp2_operation::<FIELD_TYPE, BLOCKS, BLOCK_SIZE, OP>(read_data)
394    } else {
395        field_operation::<FIELD_TYPE, BLOCKS, BLOCK_SIZE, OP>(read_data)
396    };
397
398    let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
399    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
400
401    for (i, block) in output_data.into_iter().enumerate() {
402        exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
403    }
404
405    *pc = pc.wrapping_add(DEFAULT_PC_STEP);
406    *instret += 1;
407}
408
409#[inline(always)]
410unsafe fn execute_e12_generic_impl<
411    F: PrimeField32,
412    CTX: ExecutionCtxTrait,
413    const BLOCKS: usize,
414    const BLOCK_SIZE: usize,
415>(
416    pre_compute: &FieldExpressionPreCompute,
417    instret: &mut u64,
418    pc: &mut u32,
419    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
420) {
421    let rs_vals = pre_compute
422        .rs_addrs
423        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));
424
425    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
426        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
427        from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
428    });
429    let read_data_dyn: DynArray<u8> = read_data.into();
430
431    let writes = run_field_expression_precomputed::<true>(
432        pre_compute.expr,
433        pre_compute.flag_idx as usize,
434        &read_data_dyn.0,
435    );
436
437    let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
438    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
439
440    let data: [[u8; BLOCK_SIZE]; BLOCKS] = writes.into();
441    for (i, block) in data.into_iter().enumerate() {
442        exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
443    }
444
445    *pc = pc.wrapping_add(DEFAULT_PC_STEP);
446    *instret += 1;
447}
448
449#[inline(always)]
450unsafe fn execute_e12_setup_impl<
451    F: PrimeField32,
452    CTX: ExecutionCtxTrait,
453    const BLOCKS: usize,
454    const BLOCK_SIZE: usize,
455    const IS_FP2: bool,
456>(
457    pre_compute: &FieldExpressionPreCompute,
458    instret: &mut u64,
459    pc: &mut u32,
460    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
461) -> Result<(), ExecutionError> {
462    // Read the first input (which should be the prime)
463    let rs_vals = pre_compute
464        .rs_addrs
465        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));
466    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
467        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
468        from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
469    });
470
471    // Extract first field element as the prime
472    let input_prime = if IS_FP2 {
473        BigUint::from_bytes_le(read_data[0][..BLOCKS / 2].as_flattened())
474    } else {
475        BigUint::from_bytes_le(read_data[0].as_flattened())
476    };
477
478    if input_prime != pre_compute.expr.prime {
479        let err = ExecutionError::Fail {
480            pc: *pc,
481            msg: "ModularSetup: mismatched prime",
482        };
483        return Err(err);
484    }
485
486    let read_data_dyn: DynArray<u8> = read_data.into();
487
488    let writes = run_field_expression_precomputed::<true>(
489        pre_compute.expr,
490        pre_compute.flag_idx as usize,
491        &read_data_dyn.0,
492    );
493
494    let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
495    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
496
497    let data: [[u8; BLOCK_SIZE]; BLOCKS] = writes.into();
498    for (i, block) in data.into_iter().enumerate() {
499        exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
500    }
501
502    *pc = pc.wrapping_add(DEFAULT_PC_STEP);
503    *instret += 1;
504
505    Ok(())
506}
507
508#[create_handler]
509#[inline(always)]
510unsafe fn execute_e1_setup_impl<
511    F: PrimeField32,
512    CTX: ExecutionCtxTrait,
513    const BLOCKS: usize,
514    const BLOCK_SIZE: usize,
515    const IS_FP2: bool,
516>(
517    pre_compute: &[u8],
518    instret: &mut u64,
519    pc: &mut u32,
520    _instret_end: u64,
521    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
522) -> Result<(), ExecutionError> {
523    let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow();
524    execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(pre_compute, instret, pc, exec_state)
525}
526
527#[create_handler]
528#[inline(always)]
529unsafe fn execute_e2_setup_impl<
530    F: PrimeField32,
531    CTX: MeteredExecutionCtxTrait,
532    const BLOCKS: usize,
533    const BLOCK_SIZE: usize,
534    const IS_FP2: bool,
535>(
536    pre_compute: &[u8],
537    instret: &mut u64,
538    pc: &mut u32,
539    _arg: u64,
540    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
541) -> Result<(), ExecutionError> {
542    let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = pre_compute.borrow();
543    exec_state
544        .ctx
545        .on_height_change(pre_compute.chip_idx as usize, 1);
546    execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(
547        &pre_compute.data,
548        instret,
549        pc,
550        exec_state,
551    )
552}
553
554#[create_handler]
555#[inline(always)]
556unsafe fn execute_e1_impl<
557    F: PrimeField32,
558    CTX: ExecutionCtxTrait,
559    const BLOCKS: usize,
560    const BLOCK_SIZE: usize,
561    const IS_FP2: bool,
562    const FIELD_TYPE: u8,
563    const OP: u8,
564>(
565    pre_compute: &[u8],
566    instret: &mut u64,
567    pc: &mut u32,
568    _instret_end: u64,
569    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
570) {
571    let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow();
572    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2, FIELD_TYPE, OP>(
573        pre_compute,
574        instret,
575        pc,
576        exec_state,
577    );
578}
579
580#[create_handler]
581#[inline(always)]
582unsafe fn execute_e2_impl<
583    F: PrimeField32,
584    CTX: MeteredExecutionCtxTrait,
585    const BLOCKS: usize,
586    const BLOCK_SIZE: usize,
587    const IS_FP2: bool,
588    const FIELD_TYPE: u8,
589    const OP: u8,
590>(
591    pre_compute: &[u8],
592    instret: &mut u64,
593    pc: &mut u32,
594    _arg: u64,
595    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
596) {
597    let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = pre_compute.borrow();
598    exec_state
599        .ctx
600        .on_height_change(pre_compute.chip_idx as usize, 1);
601    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2, FIELD_TYPE, OP>(
602        &pre_compute.data,
603        instret,
604        pc,
605        exec_state,
606    );
607}
608
609#[create_handler]
610#[inline(always)]
611unsafe fn execute_e1_generic_impl<
612    F: PrimeField32,
613    CTX: ExecutionCtxTrait,
614    const BLOCKS: usize,
615    const BLOCK_SIZE: usize,
616    const IS_FP2: bool,
617>(
618    pre_compute: &[u8],
619    instret: &mut u64,
620    pc: &mut u32,
621    _instret_end: u64,
622    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
623) {
624    let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow();
625    execute_e12_generic_impl::<_, _, BLOCKS, BLOCK_SIZE>(pre_compute, instret, pc, exec_state);
626}
627
628#[create_handler]
629#[inline(always)]
630unsafe fn execute_e2_generic_impl<
631    F: PrimeField32,
632    CTX: MeteredExecutionCtxTrait,
633    const BLOCKS: usize,
634    const BLOCK_SIZE: usize,
635    const IS_FP2: bool,
636>(
637    pre_compute: &[u8],
638    instret: &mut u64,
639    pc: &mut u32,
640    _arg: u64,
641    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
642) {
643    let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = pre_compute.borrow();
644    exec_state
645        .ctx
646        .on_height_change(pre_compute.chip_idx as usize, 1);
647    execute_e12_generic_impl::<_, _, BLOCKS, BLOCK_SIZE>(
648        &pre_compute.data,
649        instret,
650        pc,
651        exec_state,
652    );
653}