openvm_native_circuit/field_arithmetic/
core.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use itertools::izip;
4use openvm_circuit::arch::{
5    AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
6    VmCoreAir, VmCoreChip,
7};
8use openvm_circuit_primitives_derive::AlignedBorrow;
9use openvm_instructions::{instruction::Instruction, LocalOpcode};
10use openvm_native_compiler::FieldArithmeticOpcode::{self, *};
11use openvm_stark_backend::{
12    interaction::InteractionBuilder,
13    p3_air::BaseAir,
14    p3_field::{Field, FieldAlgebra, PrimeField32},
15    rap::BaseAirWithPublicValues,
16};
17use serde::{Deserialize, Serialize};
18
19#[repr(C)]
20#[derive(AlignedBorrow)]
21pub struct FieldArithmeticCoreCols<T> {
22    pub a: T,
23    pub b: T,
24    pub c: T,
25
26    pub is_add: T,
27    pub is_sub: T,
28    pub is_mul: T,
29    pub is_div: T,
30    /// `divisor_inv` is y.inverse() when opcode is FDIV and zero otherwise.
31    pub divisor_inv: T,
32}
33
34#[derive(Copy, Clone, Debug)]
35pub struct FieldArithmeticCoreAir {}
36
37impl<F: Field> BaseAir<F> for FieldArithmeticCoreAir {
38    fn width(&self) -> usize {
39        FieldArithmeticCoreCols::<F>::width()
40    }
41}
42
43impl<F: Field> BaseAirWithPublicValues<F> for FieldArithmeticCoreAir {}
44
/// Core constraints for native field arithmetic (`ADD`, `SUB`, `MUL`, `DIV`).
///
/// The adapter supplies two operand reads (`b`, `c`) and one write (`a`). This
/// core constrains `a` to equal the selected operation applied to `b` and `c`,
/// and reports the selected opcode and `is_valid` back to the adapter.
impl<AB, I> VmCoreAir<AB, I> for FieldArithmeticCoreAir
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; 1]; 2]>,
    I::Writes: From<[[AB::Expr; 1]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Emits the core constraints for one row.
    ///
    /// Returns the adapter context: reads `[b], [c]`, write `[a]`, and the
    /// `is_valid` / opcode expressions. The order of the `builder` calls below
    /// determines the order of the emitted constraints.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &FieldArithmeticCoreCols<_> = local_core.borrow();

        let a = cols.a;
        let b = cols.b;
        let c = cols.c;

        // Parallel arrays: `flags[i]` selects `opcodes[i]`, whose value is `results[i]`.
        let flags = [cols.is_add, cols.is_sub, cols.is_mul, cols.is_div];
        let opcodes = [ADD, SUB, MUL, DIV];
        // Division is not a polynomial operation, so it is expressed as multiplication
        // by the witnessed inverse `divisor_inv` (validated below).
        let results = [b + c, b - c, b * c, b * cols.divisor_inv];

        // Imposing the following constraints:
        // - Each flag in `flags` is a boolean.
        // - At most one flag in `flags` is true: their sum `is_valid` must be boolean.
        //   (All flags zero gives `is_valid = 0`, i.e. no instruction on this row.)
        // - The inner product of `flags` and `opcodes` is the local opcode handed to
        //   the adapter.
        // - The inner product of `flags` and `results` equals `a`.
        // - `is_div == c * divisor_inv`: when `is_div` is true, `divisor_inv` is the
        //   multiplicative inverse of `c` (so `c != 0`); otherwise `c * divisor_inv == 0`.

        let mut is_valid = AB::Expr::ZERO;
        let mut expected_opcode = AB::Expr::ZERO;
        let mut expected_result = AB::Expr::ZERO;
        for (flag, opcode, result) in izip!(flags, opcodes, results) {
            builder.assert_bool(flag);

            is_valid += flag.into();
            expected_opcode += flag * AB::Expr::from_canonical_u32(opcode as u32);
            expected_result += flag * result;
        }
        builder.assert_eq(a, expected_result);
        builder.assert_bool(is_valid.clone());
        builder.assert_eq(cols.is_div, c * cols.divisor_inv);

        AdapterAirContext {
            // This core does not choose the next pc itself.
            to_pc: None,
            reads: [[cols.b.into()], [cols.c.into()]].into(),
            writes: [[cols.a.into()]].into(),
            instruction: MinimalInstruction {
                is_valid,
                // Converts the local (class-relative) opcode expression to the global
                // opcode; presumably offsets by `start_offset()` — see `VmCoreAir`.
                opcode: VmCoreAir::<AB, I>::expr_to_global_expr(self, expected_opcode),
            }
            .into(),
        }
    }

    /// Opcodes handled by this core are numbered relative to
    /// `FieldArithmeticOpcode::CLASS_OFFSET`.
    fn start_offset(&self) -> usize {
        FieldArithmeticOpcode::CLASS_OFFSET
    }
}
106
107#[repr(C)]
108#[derive(Debug, Serialize, Deserialize)]
109pub struct FieldArithmeticRecord<F> {
110    pub opcode: FieldArithmeticOpcode,
111    pub a: F,
112    pub b: F,
113    pub c: F,
114}
115
116pub struct FieldArithmeticCoreChip {
117    pub air: FieldArithmeticCoreAir,
118}
119
120impl FieldArithmeticCoreChip {
121    pub fn new() -> Self {
122        Self {
123            air: FieldArithmeticCoreAir {},
124        }
125    }
126}
127
128impl Default for FieldArithmeticCoreChip {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl<F: PrimeField32, I: VmAdapterInterface<F>> VmCoreChip<F, I> for FieldArithmeticCoreChip
135where
136    I::Reads: Into<[[F; 1]; 2]>,
137    I::Writes: From<[[F; 1]; 1]>,
138{
139    type Record = FieldArithmeticRecord<F>;
140    type Air = FieldArithmeticCoreAir;
141
142    #[allow(clippy::type_complexity)]
143    fn execute_instruction(
144        &self,
145        instruction: &Instruction<F>,
146        _from_pc: u32,
147        reads: I::Reads,
148    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
149        let Instruction { opcode, .. } = instruction;
150        let local_opcode = FieldArithmeticOpcode::from_usize(
151            opcode.local_opcode_idx(FieldArithmeticOpcode::CLASS_OFFSET),
152        );
153
154        let data: [[F; 1]; 2] = reads.into();
155        let b = data[0][0];
156        let c = data[1][0];
157        let a = FieldArithmetic::run_field_arithmetic(local_opcode, b, c).unwrap();
158
159        let output: AdapterRuntimeContext<F, I> = AdapterRuntimeContext {
160            to_pc: None,
161            writes: [[a]].into(),
162        };
163
164        let record = Self::Record {
165            opcode: local_opcode,
166            a,
167            b,
168            c,
169        };
170
171        Ok((output, record))
172    }
173
174    fn get_opcode_name(&self, opcode: usize) -> String {
175        format!(
176            "{:?}",
177            FieldArithmeticOpcode::from_usize(opcode - FieldArithmeticOpcode::CLASS_OFFSET)
178        )
179    }
180
181    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
182        let FieldArithmeticRecord { opcode, a, b, c } = record;
183        let row_slice: &mut FieldArithmeticCoreCols<_> = row_slice.borrow_mut();
184        row_slice.a = a;
185        row_slice.b = b;
186        row_slice.c = c;
187
188        row_slice.is_add = F::from_bool(opcode == FieldArithmeticOpcode::ADD);
189        row_slice.is_sub = F::from_bool(opcode == FieldArithmeticOpcode::SUB);
190        row_slice.is_mul = F::from_bool(opcode == FieldArithmeticOpcode::MUL);
191        row_slice.is_div = F::from_bool(opcode == FieldArithmeticOpcode::DIV);
192        row_slice.divisor_inv = if opcode == FieldArithmeticOpcode::DIV {
193            c.inverse()
194        } else {
195            F::ZERO
196        };
197    }
198
199    fn air(&self) -> &Self::Air {
200        &self.air
201    }
202}
203
204pub struct FieldArithmetic;
205impl FieldArithmetic {
206    pub(super) fn run_field_arithmetic<F: Field>(
207        opcode: FieldArithmeticOpcode,
208        b: F,
209        c: F,
210    ) -> Option<F> {
211        match opcode {
212            FieldArithmeticOpcode::ADD => Some(b + c),
213            FieldArithmeticOpcode::SUB => Some(b - c),
214            FieldArithmeticOpcode::MUL => Some(b * c),
215            FieldArithmeticOpcode::DIV => {
216                if c.is_zero() {
217                    None
218                } else {
219                    Some(b * c.inverse())
220                }
221            }
222        }
223    }
224}