openvm_native_circuit/field_arithmetic/core.rs
1use std::borrow::{Borrow, BorrowMut};
2
3use itertools::izip;
4use openvm_circuit::arch::{
5 AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
6 VmCoreAir, VmCoreChip,
7};
8use openvm_circuit_primitives_derive::AlignedBorrow;
9use openvm_instructions::{instruction::Instruction, LocalOpcode};
10use openvm_native_compiler::FieldArithmeticOpcode::{self, *};
11use openvm_stark_backend::{
12 interaction::InteractionBuilder,
13 p3_air::BaseAir,
14 p3_field::{Field, FieldAlgebra, PrimeField32},
15 rap::BaseAirWithPublicValues,
16};
17use serde::{Deserialize, Serialize};
18
/// Trace columns owned by the field-arithmetic core chip, one row per executed instruction.
///
/// `#[repr(C)]` fixes the field order. A row slice is reinterpreted as this struct via
/// `AlignedBorrow` (`borrow` in the AIR, `borrow_mut` in trace generation), so the field
/// order IS the column order. Do not reorder fields.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct FieldArithmeticCoreCols<T> {
    /// Result written back by the adapter: `a = b op c`.
    pub a: T,
    /// First operand (first adapter read).
    pub b: T,
    /// Second operand (second adapter read); the divisor for DIV.
    pub c: T,

    // Opcode selector flags. Each is constrained boolean and their sum is constrained
    // boolean, so at most one is set per row (none on padding rows).
    pub is_add: T,
    pub is_sub: T,
    pub is_mul: T,
    pub is_div: T,
    /// Inverse of `c` on DIV rows. Trace generation sets it to zero on all other rows; the AIR
    /// itself only enforces `is_div == c * divisor_inv`.
    pub divisor_inv: T,
}
33
34#[derive(Copy, Clone, Debug)]
35pub struct FieldArithmeticCoreAir {}
36
37impl<F: Field> BaseAir<F> for FieldArithmeticCoreAir {
38 fn width(&self) -> usize {
39 FieldArithmeticCoreCols::<F>::width()
40 }
41}
42
43impl<F: Field> BaseAirWithPublicValues<F> for FieldArithmeticCoreAir {}
44
/// Constraints for the field-arithmetic core: `a = b op c` for `op` in {ADD, SUB, MUL, DIV},
/// selected by boolean flags of which at most one may be set.
///
/// NOTE: the order and shape of the `builder` calls below define the constraint system; do not
/// reorder or algebraically rewrite them casually.
impl<AB, I> VmCoreAir<AB, I> for FieldArithmeticCoreAir
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; 1]; 2]>,
    I::Writes: From<[[AB::Expr; 1]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &FieldArithmeticCoreCols<_> = local_core.borrow();

        let a = cols.a;
        let b = cols.b;
        let c = cols.c;

        // Parallel arrays: flag i selects opcode i and candidate result i.
        let flags = [cols.is_add, cols.is_sub, cols.is_mul, cols.is_div];
        let opcodes = [ADD, SUB, MUL, DIV];
        // DIV is expressed as a multiplication by `divisor_inv`, which is tied to `c` by the
        // `is_div == c * divisor_inv` constraint below.
        let results = [b + c, b - c, b * c, b * cols.divisor_inv];

        let mut is_valid = AB::Expr::ZERO;
        let mut expected_opcode = AB::Expr::ZERO;
        let mut expected_result = AB::Expr::ZERO;
        for (flag, opcode, result) in izip!(flags, opcodes, results) {
            builder.assert_bool(flag);

            // Flag-weighted sums: with at most one flag set, these pick out the selected
            // opcode/result (or zero on a row with no flag set).
            is_valid += flag.into();
            expected_opcode += flag * AB::Expr::from_canonical_u32(opcode as u32);
            expected_result += flag * result;
        }
        builder.assert_eq(a, expected_result);
        // A boolean sum of boolean flags means at most one flag is set.
        builder.assert_bool(is_valid.clone());
        // On DIV rows this forces `c * divisor_inv == 1`, i.e. `c != 0` and `divisor_inv = 1/c`,
        // so division by zero cannot be proven. On other rows it forces `c * divisor_inv == 0`.
        builder.assert_eq(cols.is_div, c * cols.divisor_inv);

        AdapterAirContext {
            to_pc: None,
            reads: [[cols.b.into()], [cols.c.into()]].into(),
            writes: [[cols.a.into()]].into(),
            instruction: MinimalInstruction {
                is_valid,
                // NOTE(review): presumably converts the class-local opcode into the global
                // opcode using `start_offset()`; confirm against `VmCoreAir`.
                opcode: VmCoreAir::<AB, I>::expr_to_global_expr(self, expected_opcode),
            }
            .into(),
        }
    }

    /// Offset of this opcode class within the global opcode space.
    fn start_offset(&self) -> usize {
        FieldArithmeticOpcode::CLASS_OFFSET
    }
}
106
/// Per-instruction data produced by `execute_instruction` and consumed by
/// `generate_trace_row`.
// Field order is significant: the struct is `#[repr(C)]` and serialized via serde.
#[repr(C)]
#[derive(Debug, Serialize, Deserialize)]
pub struct FieldArithmeticRecord<F> {
    /// Local (class-relative) opcode that was executed.
    pub opcode: FieldArithmeticOpcode,
    /// Result `b op c`.
    pub a: F,
    /// First operand.
    pub b: F,
    /// Second operand (divisor for DIV).
    pub c: F,
}
115
116pub struct FieldArithmeticCoreChip {
117 pub air: FieldArithmeticCoreAir,
118}
119
120impl FieldArithmeticCoreChip {
121 pub fn new() -> Self {
122 Self {
123 air: FieldArithmeticCoreAir {},
124 }
125 }
126}
127
128impl Default for FieldArithmeticCoreChip {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl<F: PrimeField32, I: VmAdapterInterface<F>> VmCoreChip<F, I> for FieldArithmeticCoreChip
135where
136 I::Reads: Into<[[F; 1]; 2]>,
137 I::Writes: From<[[F; 1]; 1]>,
138{
139 type Record = FieldArithmeticRecord<F>;
140 type Air = FieldArithmeticCoreAir;
141
142 #[allow(clippy::type_complexity)]
143 fn execute_instruction(
144 &self,
145 instruction: &Instruction<F>,
146 _from_pc: u32,
147 reads: I::Reads,
148 ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
149 let Instruction { opcode, .. } = instruction;
150 let local_opcode = FieldArithmeticOpcode::from_usize(
151 opcode.local_opcode_idx(FieldArithmeticOpcode::CLASS_OFFSET),
152 );
153
154 let data: [[F; 1]; 2] = reads.into();
155 let b = data[0][0];
156 let c = data[1][0];
157 let a = FieldArithmetic::run_field_arithmetic(local_opcode, b, c).unwrap();
158
159 let output: AdapterRuntimeContext<F, I> = AdapterRuntimeContext {
160 to_pc: None,
161 writes: [[a]].into(),
162 };
163
164 let record = Self::Record {
165 opcode: local_opcode,
166 a,
167 b,
168 c,
169 };
170
171 Ok((output, record))
172 }
173
174 fn get_opcode_name(&self, opcode: usize) -> String {
175 format!(
176 "{:?}",
177 FieldArithmeticOpcode::from_usize(opcode - FieldArithmeticOpcode::CLASS_OFFSET)
178 )
179 }
180
181 fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
182 let FieldArithmeticRecord { opcode, a, b, c } = record;
183 let row_slice: &mut FieldArithmeticCoreCols<_> = row_slice.borrow_mut();
184 row_slice.a = a;
185 row_slice.b = b;
186 row_slice.c = c;
187
188 row_slice.is_add = F::from_bool(opcode == FieldArithmeticOpcode::ADD);
189 row_slice.is_sub = F::from_bool(opcode == FieldArithmeticOpcode::SUB);
190 row_slice.is_mul = F::from_bool(opcode == FieldArithmeticOpcode::MUL);
191 row_slice.is_div = F::from_bool(opcode == FieldArithmeticOpcode::DIV);
192 row_slice.divisor_inv = if opcode == FieldArithmeticOpcode::DIV {
193 c.inverse()
194 } else {
195 F::ZERO
196 };
197 }
198
199 fn air(&self) -> &Self::Air {
200 &self.air
201 }
202}
203
204pub struct FieldArithmetic;
205impl FieldArithmetic {
206 pub(super) fn run_field_arithmetic<F: Field>(
207 opcode: FieldArithmeticOpcode,
208 b: F,
209 c: F,
210 ) -> Option<F> {
211 match opcode {
212 FieldArithmeticOpcode::ADD => Some(b + c),
213 FieldArithmeticOpcode::SUB => Some(b - c),
214 FieldArithmeticOpcode::MUL => Some(b * c),
215 FieldArithmeticOpcode::DIV => {
216 if c.is_zero() {
217 None
218 } else {
219 Some(b * c.inverse())
220 }
221 }
222 }
223 }
224}