openvm_algebra_circuit/
execution.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_algebra_transpiler::{Fp2Opcode, Rv32ModularArithmeticOpcode};
8use openvm_circuit::{
9    arch::*,
10    system::memory::{online::GuestMemory, POINTER_MAX_BITS},
11};
12use openvm_circuit_primitives::AlignedBytesBorrow;
13use openvm_instructions::{
14    instruction::Instruction,
15    program::DEFAULT_PC_STEP,
16    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
17};
18use openvm_mod_circuit_builder::{run_field_expression_precomputed, FieldExpr};
19use openvm_stark_backend::p3_field::PrimeField32;
20
21use super::FieldExprVecHeapExecutor;
22use crate::fields::{
23    field_operation, fp2_operation, get_field_type, get_fp2_field_type, FieldType, Operation,
24};
25
/// Expands to a `match` on `(field type, operation)` that picks the monomorphized
/// base-field execute function for each listed pair.
///
/// Each `(FieldType variant, Operation variant)` entry of the trailing list becomes one
/// arm that evaluates to `Ok(handler::<_, _, blocks, block_size, false, FIELD, OP>)`,
/// where `FIELD` and `OP` are the variants cast to `u8`. No wildcard arm is emitted,
/// so the list passed by the caller has to make the `match` exhaustive on its own.
macro_rules! generate_field_dispatch {
    (
        $kind:expr,
        $selected_op:expr,
        $num_blocks:expr,
        $bytes_per_block:expr,
        $handler:ident,
        [$(($field:ident, $variant:ident)),* $(,)?]
    ) => {
        match ($kind, $selected_op) {
            $(
                (FieldType::$field, Operation::$variant) => Ok($handler::<
                    _,
                    _,
                    $num_blocks,
                    $bytes_per_block,
                    false,
                    { FieldType::$field as u8 },
                    { Operation::$variant as u8 },
                >),
            )*
        }
    };
}
50
/// Expands to a `match` on `(field type, operation)` that picks the monomorphized
/// Fp2 execute function for each listed pair.
///
/// Each `(FieldType variant, Operation variant)` entry of the trailing list becomes one
/// arm that evaluates to `Ok(handler::<_, _, blocks, block_size, true, FIELD, OP>)`,
/// where `FIELD` and `OP` are the variants cast to `u8`. Any pair that is not listed
/// falls into the wildcard arm, which panics.
macro_rules! generate_fp2_dispatch {
    (
        $kind:expr,
        $selected_op:expr,
        $num_blocks:expr,
        $bytes_per_block:expr,
        $handler:ident,
        [$(($field:ident, $variant:ident)),* $(,)?]
    ) => {
        match ($kind, $selected_op) {
            $(
                (FieldType::$field, Operation::$variant) => Ok($handler::<
                    _,
                    _,
                    $num_blocks,
                    $bytes_per_block,
                    true,
                    { FieldType::$field as u8 },
                    { Operation::$variant as u8 },
                >),
            )*
            // Pairs outside the list have no specialized Fp2 implementation.
            _ => panic!("Unsupported fp2 field")
        }
    };
}
76
/// Chooses the execute function for an already decoded instruction.
///
/// Arguments, in order: the specialized function, the generic fallback, the setup
/// function, the identifier of the pre-compute record and the identifier of the decoded
/// `Option<Operation>`. The expansion relies on `BLOCKS`, `BLOCK_SIZE` and `IS_FP2`
/// being in scope at the call site.
///
/// * `None` selects the setup function.
/// * `Some(op)` looks up the record's prime with `get_fp2_field_type` (when `IS_FP2`) or
///   `get_field_type`; a recognized field selects the specialized function for that
///   `(field, op)` pair, otherwise the generic fallback is selected.
macro_rules! dispatch {
    ($execute_impl:ident,$execute_generic_impl:ident,$execute_setup_impl:ident,$pre_compute:ident,$op:ident) => {
        match $op {
            // Setup opcodes carry no arithmetic operation.
            None => Ok($execute_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>),
            Some(selected_op) => {
                let prime = &$pre_compute.expr.prime;
                if IS_FP2 {
                    match get_fp2_field_type(prime) {
                        Some(known_field) => generate_fp2_dispatch!(
                            known_field,
                            selected_op,
                            BLOCKS,
                            BLOCK_SIZE,
                            $execute_impl,
                            [
                                (BN254Coordinate, Add),
                                (BN254Coordinate, Sub),
                                (BN254Coordinate, Mul),
                                (BN254Coordinate, Div),
                                (BLS12_381Coordinate, Add),
                                (BLS12_381Coordinate, Sub),
                                (BLS12_381Coordinate, Mul),
                                (BLS12_381Coordinate, Div),
                            ]
                        ),
                        None => Ok($execute_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>),
                    }
                } else {
                    match get_field_type(prime) {
                        Some(known_field) => generate_field_dispatch!(
                            known_field,
                            selected_op,
                            BLOCKS,
                            BLOCK_SIZE,
                            $execute_impl,
                            [
                                (K256Coordinate, Add),
                                (K256Coordinate, Sub),
                                (K256Coordinate, Mul),
                                (K256Coordinate, Div),
                                (K256Scalar, Add),
                                (K256Scalar, Sub),
                                (K256Scalar, Mul),
                                (K256Scalar, Div),
                                (P256Coordinate, Add),
                                (P256Coordinate, Sub),
                                (P256Coordinate, Mul),
                                (P256Coordinate, Div),
                                (P256Scalar, Add),
                                (P256Scalar, Sub),
                                (P256Scalar, Mul),
                                (P256Scalar, Div),
                                (BN254Coordinate, Add),
                                (BN254Coordinate, Sub),
                                (BN254Coordinate, Mul),
                                (BN254Coordinate, Div),
                                (BN254Scalar, Add),
                                (BN254Scalar, Sub),
                                (BN254Scalar, Mul),
                                (BN254Scalar, Div),
                                (BLS12_381Coordinate, Add),
                                (BLS12_381Coordinate, Sub),
                                (BLS12_381Coordinate, Mul),
                                (BLS12_381Coordinate, Div),
                                (BLS12_381Scalar, Add),
                                (BLS12_381Scalar, Sub),
                                (BLS12_381Scalar, Mul),
                                (BLS12_381Scalar, Div),
                            ]
                        ),
                        None => Ok($execute_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>),
                    }
                }
            }
        }
    };
}
153
/// Per-instruction record filled in by `pre_compute_impl` and read back by the execute
/// functions in this file.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct FieldExpressionPreCompute<'a> {
    /// Field expression borrowed from the executor; its `prime` drives dispatch and the
    /// setup-time prime check.
    expr: &'a FieldExpr,
    /// Instruction operands `b` and `c` (in that order), truncated to `u8`; used as
    /// register addresses from which the two input pointers are read.
    rs_addrs: [u8; 2],
    /// Instruction operand `a`, truncated to `u8`; used as the register address from
    /// which the output pointer is read.
    a: u8,
    /// Flag index handed to `run_field_expression_precomputed`.
    flag_idx: u8,
}
162
163impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool>
164    FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
165{
166    fn pre_compute_impl<F: PrimeField32>(
167        &'a self,
168        pc: u32,
169        inst: &Instruction<F>,
170        data: &mut FieldExpressionPreCompute<'a>,
171    ) -> Result<Option<Operation>, StaticProgramError> {
172        let Instruction {
173            opcode,
174            a,
175            b,
176            c,
177            d,
178            e,
179            ..
180        } = inst;
181
182        let a = a.as_canonical_u32();
183        let b = b.as_canonical_u32();
184        let c = c.as_canonical_u32();
185        let d = d.as_canonical_u32();
186        let e = e.as_canonical_u32();
187        if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
188            return Err(StaticProgramError::InvalidInstruction(pc));
189        }
190
191        let local_opcode = opcode.local_opcode_idx(self.0.offset);
192
193        let needs_setup = self.0.expr.needs_setup();
194        let mut flag_idx = self.0.expr.num_flags() as u8;
195        if needs_setup {
196            if let Some(opcode_position) = self
197                .0
198                .local_opcode_idx
199                .iter()
200                .position(|&idx| idx == local_opcode)
201            {
202                if opcode_position < self.0.opcode_flag_idx.len() {
203                    flag_idx = self.0.opcode_flag_idx[opcode_position] as u8;
204                }
205            }
206        }
207
208        let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8);
209        *data = FieldExpressionPreCompute {
210            a: a as u8,
211            rs_addrs,
212            expr: &self.0.expr,
213            flag_idx,
214        };
215
216        if IS_FP2 {
217            let is_setup = local_opcode == Fp2Opcode::SETUP_ADDSUB as usize
218                || local_opcode == Fp2Opcode::SETUP_MULDIV as usize;
219
220            let op = if is_setup {
221                None
222            } else {
223                match local_opcode {
224                    x if x == Fp2Opcode::ADD as usize => Some(Operation::Add),
225                    x if x == Fp2Opcode::SUB as usize => Some(Operation::Sub),
226                    x if x == Fp2Opcode::MUL as usize => Some(Operation::Mul),
227                    x if x == Fp2Opcode::DIV as usize => Some(Operation::Div),
228                    _ => unreachable!(),
229                }
230            };
231
232            Ok(op)
233        } else {
234            let is_setup = local_opcode == Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize
235                || local_opcode == Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize;
236
237            let op = if is_setup {
238                None
239            } else {
240                match local_opcode {
241                    x if x == Rv32ModularArithmeticOpcode::ADD as usize => Some(Operation::Add),
242                    x if x == Rv32ModularArithmeticOpcode::SUB as usize => Some(Operation::Sub),
243                    x if x == Rv32ModularArithmeticOpcode::MUL as usize => Some(Operation::Mul),
244                    x if x == Rv32ModularArithmeticOpcode::DIV as usize => Some(Operation::Div),
245                    _ => unreachable!(),
246                }
247            };
248
249            Ok(op)
250        }
251    }
252}
253
254impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool> Executor<F>
255    for FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
256{
257    #[inline(always)]
258    fn pre_compute_size(&self) -> usize {
259        std::mem::size_of::<FieldExpressionPreCompute>()
260    }
261
262    fn pre_compute<Ctx>(
263        &self,
264        pc: u32,
265        inst: &Instruction<F>,
266        data: &mut [u8],
267    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
268    where
269        Ctx: ExecutionCtxTrait,
270    {
271        let pre_compute: &mut FieldExpressionPreCompute = data.borrow_mut();
272        let op = self.pre_compute_impl(pc, inst, pre_compute)?;
273
274        dispatch!(
275            execute_e1_impl,
276            execute_e1_generic_impl,
277            execute_e1_setup_impl,
278            pre_compute,
279            op
280        )
281    }
282
283    #[cfg(feature = "tco")]
284    fn handler<Ctx>(
285        &self,
286        pc: u32,
287        inst: &Instruction<F>,
288        data: &mut [u8],
289    ) -> Result<Handler<F, Ctx>, StaticProgramError>
290    where
291        Ctx: ExecutionCtxTrait,
292    {
293        let pre_compute: &mut FieldExpressionPreCompute = data.borrow_mut();
294        let op = self.pre_compute_impl(pc, inst, pre_compute)?;
295
296        dispatch!(
297            execute_e1_tco_handler,
298            execute_e1_generic_tco_handler,
299            execute_e1_setup_tco_handler,
300            pre_compute,
301            op
302        )
303    }
304}
305
306impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool>
307    MeteredExecutor<F> for FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
308{
309    #[inline(always)]
310    fn metered_pre_compute_size(&self) -> usize {
311        std::mem::size_of::<E2PreCompute<FieldExpressionPreCompute>>()
312    }
313
314    fn metered_pre_compute<Ctx>(
315        &self,
316        chip_idx: usize,
317        pc: u32,
318        inst: &Instruction<F>,
319        data: &mut [u8],
320    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
321    where
322        Ctx: MeteredExecutionCtxTrait,
323    {
324        let pre_compute: &mut E2PreCompute<FieldExpressionPreCompute> = data.borrow_mut();
325        pre_compute.chip_idx = chip_idx as u32;
326
327        let pre_compute_pure = &mut pre_compute.data;
328        let op = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
329
330        dispatch!(
331            execute_e2_impl,
332            execute_e2_generic_impl,
333            execute_e2_setup_impl,
334            pre_compute_pure,
335            op
336        )
337    }
338
339    #[cfg(feature = "tco")]
340    fn metered_handler<Ctx>(
341        &self,
342        chip_idx: usize,
343        pc: u32,
344        inst: &Instruction<F>,
345        data: &mut [u8],
346    ) -> Result<Handler<F, Ctx>, StaticProgramError>
347    where
348        Ctx: MeteredExecutionCtxTrait,
349    {
350        let pre_compute: &mut E2PreCompute<FieldExpressionPreCompute> = data.borrow_mut();
351        pre_compute.chip_idx = chip_idx as u32;
352
353        let pre_compute_pure = &mut pre_compute.data;
354        let op = self.pre_compute_impl(pc, inst, pre_compute_pure)?;
355
356        dispatch!(
357            execute_e2_tco_handler,
358            execute_e2_generic_tco_handler,
359            execute_e2_setup_tco_handler,
360            pre_compute_pure,
361            op
362        )
363    }
364}
365unsafe fn execute_e12_impl<
366    F: PrimeField32,
367    CTX: ExecutionCtxTrait,
368    const BLOCKS: usize,
369    const BLOCK_SIZE: usize,
370    const IS_FP2: bool,
371    const FIELD_TYPE: u8,
372    const OP: u8,
373>(
374    pre_compute: &FieldExpressionPreCompute,
375    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
376) {
377    let rs_vals = pre_compute
378        .rs_addrs
379        .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32)));
380
381    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
382        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
383        from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
384    });
385
386    let output_data = if IS_FP2 {
387        fp2_operation::<FIELD_TYPE, BLOCKS, BLOCK_SIZE, OP>(read_data)
388    } else {
389        field_operation::<FIELD_TYPE, BLOCKS, BLOCK_SIZE, OP>(read_data)
390    };
391
392    let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
393    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
394
395    for (i, block) in output_data.into_iter().enumerate() {
396        vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
397    }
398
399    vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
400    vm_state.instret += 1;
401}
402
403unsafe fn execute_e12_generic_impl<
404    F: PrimeField32,
405    CTX: ExecutionCtxTrait,
406    const BLOCKS: usize,
407    const BLOCK_SIZE: usize,
408>(
409    pre_compute: &FieldExpressionPreCompute,
410    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
411) {
412    let rs_vals = pre_compute
413        .rs_addrs
414        .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32)));
415
416    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
417        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
418        from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
419    });
420    let read_data_dyn: DynArray<u8> = read_data.into();
421
422    let writes = run_field_expression_precomputed::<true>(
423        pre_compute.expr,
424        pre_compute.flag_idx as usize,
425        &read_data_dyn.0,
426    );
427
428    let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
429    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
430
431    let data: [[u8; BLOCK_SIZE]; BLOCKS] = writes.into();
432    for (i, block) in data.into_iter().enumerate() {
433        vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
434    }
435
436    vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
437    vm_state.instret += 1;
438}
439
440unsafe fn execute_e12_setup_impl<
441    F: PrimeField32,
442    CTX: ExecutionCtxTrait,
443    const BLOCKS: usize,
444    const BLOCK_SIZE: usize,
445    const IS_FP2: bool,
446>(
447    pre_compute: &FieldExpressionPreCompute,
448    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
449) {
450    // Read the first input (which should be the prime)
451    let rs_vals = pre_compute
452        .rs_addrs
453        .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32)));
454    let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| {
455        debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
456        from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32))
457    });
458
459    // Extract first field element as the prime
460    let input_prime = if IS_FP2 {
461        BigUint::from_bytes_le(read_data[0][..BLOCKS / 2].as_flattened())
462    } else {
463        BigUint::from_bytes_le(read_data[0].as_flattened())
464    };
465
466    if input_prime != pre_compute.expr.prime {
467        vm_state.exit_code = Err(ExecutionError::Fail {
468            pc: vm_state.pc,
469            msg: "ModularSetup: mismatched prime",
470        });
471        return;
472    }
473
474    let read_data_dyn: DynArray<u8> = read_data.into();
475
476    let writes = run_field_expression_precomputed::<true>(
477        pre_compute.expr,
478        pre_compute.flag_idx as usize,
479        &read_data_dyn.0,
480    );
481
482    let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32));
483    debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS));
484
485    let data: [[u8; BLOCK_SIZE]; BLOCKS] = writes.into();
486    for (i, block) in data.into_iter().enumerate() {
487        vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block);
488    }
489
490    vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
491    vm_state.instret += 1;
492}
493
494#[create_tco_handler]
495unsafe fn execute_e1_setup_impl<
496    F: PrimeField32,
497    CTX: ExecutionCtxTrait,
498    const BLOCKS: usize,
499    const BLOCK_SIZE: usize,
500    const IS_FP2: bool,
501>(
502    pre_compute: &[u8],
503    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
504) {
505    let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow();
506    execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(pre_compute, vm_state);
507}
508
509#[create_tco_handler]
510unsafe fn execute_e2_setup_impl<
511    F: PrimeField32,
512    CTX: MeteredExecutionCtxTrait,
513    const BLOCKS: usize,
514    const BLOCK_SIZE: usize,
515    const IS_FP2: bool,
516>(
517    pre_compute: &[u8],
518    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
519) {
520    let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = pre_compute.borrow();
521    vm_state
522        .ctx
523        .on_height_change(pre_compute.chip_idx as usize, 1);
524    execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(&pre_compute.data, vm_state);
525}
526
527#[create_tco_handler]
528unsafe fn execute_e1_impl<
529    F: PrimeField32,
530    CTX: ExecutionCtxTrait,
531    const BLOCKS: usize,
532    const BLOCK_SIZE: usize,
533    const IS_FP2: bool,
534    const FIELD_TYPE: u8,
535    const OP: u8,
536>(
537    pre_compute: &[u8],
538    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
539) {
540    let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow();
541    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2, FIELD_TYPE, OP>(pre_compute, vm_state);
542}
543
544#[create_tco_handler]
545unsafe fn execute_e2_impl<
546    F: PrimeField32,
547    CTX: MeteredExecutionCtxTrait,
548    const BLOCKS: usize,
549    const BLOCK_SIZE: usize,
550    const IS_FP2: bool,
551    const FIELD_TYPE: u8,
552    const OP: u8,
553>(
554    pre_compute: &[u8],
555    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
556) {
557    let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = pre_compute.borrow();
558    vm_state
559        .ctx
560        .on_height_change(pre_compute.chip_idx as usize, 1);
561    execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2, FIELD_TYPE, OP>(
562        &pre_compute.data,
563        vm_state,
564    );
565}
566
567#[create_tco_handler]
568unsafe fn execute_e1_generic_impl<
569    F: PrimeField32,
570    CTX: ExecutionCtxTrait,
571    const BLOCKS: usize,
572    const BLOCK_SIZE: usize,
573    const IS_FP2: bool,
574>(
575    pre_compute: &[u8],
576    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
577) {
578    let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow();
579    execute_e12_generic_impl::<_, _, BLOCKS, BLOCK_SIZE>(pre_compute, vm_state);
580}
581
582#[create_tco_handler]
583unsafe fn execute_e2_generic_impl<
584    F: PrimeField32,
585    CTX: MeteredExecutionCtxTrait,
586    const BLOCKS: usize,
587    const BLOCK_SIZE: usize,
588    const IS_FP2: bool,
589>(
590    pre_compute: &[u8],
591    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
592) {
593    let pre_compute: &E2PreCompute<FieldExpressionPreCompute> = pre_compute.borrow();
594    vm_state
595        .ctx
596        .on_height_change(pre_compute.chip_idx as usize, 1);
597    execute_e12_generic_impl::<_, _, BLOCKS, BLOCK_SIZE>(&pre_compute.data, vm_state);
598}