1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4 ops::{Add, Mul, Sub},
5};
6
7use itertools::izip;
8use openvm_circuit::arch::{
9 AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
10 VmCoreAir, VmCoreChip,
11};
12use openvm_circuit_primitives_derive::AlignedBorrow;
13use openvm_instructions::{instruction::Instruction, LocalOpcode};
14use openvm_native_compiler::FieldExtensionOpcode::{self, *};
15use openvm_stark_backend::{
16 interaction::InteractionBuilder,
17 p3_air::BaseAir,
18 p3_field::{Field, FieldAlgebra, PrimeField32},
19 rap::BaseAirWithPublicValues,
20};
21use serde::{Deserialize, Serialize};
22
/// Reduction constant of the degree-4 extension: the formulas in [`FieldExtension`]
/// replace `X^4` with `BETA`, i.e. elements live in `F[X] / (X^4 - BETA)`.
/// NOTE(review): correctness of division assumes `X^4 - BETA` is irreducible over the
/// base field actually used — confirm for the configured field.
pub const BETA: usize = 11;
/// Number of base-field coefficients in one extension element.
pub const EXT_DEG: usize = 4;
25
/// Trace columns of the field-extension core (one row per executed instruction).
///
/// `#[repr(C)]` together with `AlignedBorrow` lets a row slice (`&[T]` / `&mut [T]`) be
/// borrowed as this struct, so the field order below *is* the column order and must not
/// be rearranged.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct FieldExtensionCoreCols<T> {
    /// Result written back by the instruction (`x = y op z`).
    pub x: [T; EXT_DEG],
    /// First operand.
    pub y: [T; EXT_DEG],
    /// Second operand (the divisor for `BBE4DIV`).
    pub z: [T; EXT_DEG],

    /// Boolean selector for `FE4ADD`.
    pub is_add: T,
    /// Boolean selector for `FE4SUB`.
    pub is_sub: T,
    /// Boolean selector for `BBE4MUL`.
    pub is_mul: T,
    /// Boolean selector for `BBE4DIV`.
    pub is_div: T,
    /// Witnessed inverse of `z` on division rows; all zeros on every other row.
    pub divisor_inv: [T; EXT_DEG],
}
40
41#[derive(Copy, Clone, Debug)]
42pub struct FieldExtensionCoreAir {}
43
44impl<F: Field> BaseAir<F> for FieldExtensionCoreAir {
45 fn width(&self) -> usize {
46 FieldExtensionCoreCols::<F>::width()
47 }
48}
49
50impl<F: Field> BaseAirWithPublicValues<F> for FieldExtensionCoreAir {}
51
impl<AB, I> VmCoreAir<AB, I> for FieldExtensionCoreAir
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; EXT_DEG]; 2]>,
    I::Writes: From<[[AB::Expr; EXT_DEG]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Constrains one core row and returns the context handed to the adapter.
    ///
    /// Constraints emitted:
    /// - every opcode flag is boolean, and their sum (`is_valid`) is boolean, so at most
    ///   one flag is set;
    /// - `x` equals the flag-selected candidate result (`y + z`, `y - z`, `y * z`, or
    ///   `y * divisor_inv`); on rows with no flag set this forces `x = 0`;
    /// - `z * divisor_inv` equals `is_div` in the constant coordinate and `0` in the
    ///   others, i.e. `z * divisor_inv = 1` on division rows (so `divisor_inv = z^{-1}`
    ///   and `z` must be invertible) and `z * divisor_inv = 0` otherwise.
    ///
    /// The adapter receives reads `[y, z]`, writes `[x]`, `is_valid`, and the opcode
    /// reconstructed from the flags.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &FieldExtensionCoreCols<_> = local_core.borrow();

        // `flags`, `opcodes` and `results` are index-aligned: entry i of each array
        // belongs to the same operation.
        let flags = [cols.is_add, cols.is_sub, cols.is_mul, cols.is_div];
        let opcodes = [FE4ADD, FE4SUB, BBE4MUL, BBE4DIV];
        let results = [
            FieldExtension::add(cols.y, cols.z),
            FieldExtension::subtract(cols.y, cols.z),
            FieldExtension::multiply(cols.y, cols.z),
            // Division is expressed as multiplication by the witnessed inverse of `z`;
            // that inverse is checked by the `z * divisor_inv` constraint below.
            FieldExtension::multiply(cols.y, cols.divisor_inv),
        ];

        // Accumulators for sum(flag), sum(flag * local_opcode) and sum(flag * result).
        let mut is_valid = AB::Expr::ZERO;
        let mut expected_opcode = AB::Expr::ZERO;
        let mut expected_result = [
            AB::Expr::ZERO,
            AB::Expr::ZERO,
            AB::Expr::ZERO,
            AB::Expr::ZERO,
        ];
        for (flag, opcode, result) in izip!(flags, opcodes, results) {
            builder.assert_bool(flag);

            is_valid += flag.into();
            expected_opcode += flag * AB::F::from_canonical_usize(opcode.local_usize());

            for (j, result_part) in result.into_iter().enumerate() {
                expected_result[j] += flag * result_part;
            }
        }

        for (x_j, expected_result_j) in izip!(cols.x, expected_result) {
            builder.assert_eq(x_j, expected_result_j);
        }
        // Booleanity of the sum rules out two flags being set at once.
        builder.assert_bool(is_valid.clone());

        // z * divisor_inv must equal the extension element (is_div, 0, 0, 0).
        let z_times_z_inv = FieldExtension::multiply(cols.z, cols.divisor_inv);
        for (i, prod_i) in z_times_z_inv.into_iter().enumerate() {
            if i == 0 {
                builder.assert_eq(cols.is_div, prod_i);
            } else {
                builder.assert_zero(prod_i);
            }
        }

        AdapterAirContext {
            // NOTE(review): `None` presumably lets the adapter use its default next pc.
            to_pc: None,
            reads: [cols.y.map(Into::into), cols.z.map(Into::into)].into(),
            writes: [cols.x.map(Into::into)].into(),
            instruction: MinimalInstruction {
                is_valid,
                // `expected_opcode` is class-local; presumably this converts it to the
                // global opcode using `start_offset` — confirm in `VmCoreAir`.
                opcode: VmCoreAir::<AB, I>::expr_to_global_expr(self, expected_opcode),
            }
            .into(),
        }
    }

    /// Offset of this opcode class within the global opcode space.
    fn start_offset(&self) -> usize {
        FieldExtensionOpcode::CLASS_OFFSET
    }
}
134
/// Per-instruction record produced by `execute_instruction` and consumed by
/// `generate_trace_row` to fill one [`FieldExtensionCoreCols`] row.
#[repr(C)]
#[derive(Debug, Serialize, Deserialize)]
pub struct FieldExtensionRecord<F> {
    /// Decoded class-local opcode.
    pub opcode: FieldExtensionOpcode,
    /// Result `y op z`.
    pub x: [F; EXT_DEG],
    /// First operand.
    pub y: [F; EXT_DEG],
    /// Second operand.
    pub z: [F; EXT_DEG],
}
143
144pub struct FieldExtensionCoreChip {
145 pub air: FieldExtensionCoreAir,
146}
147
148impl FieldExtensionCoreChip {
149 pub fn new() -> Self {
150 Self {
151 air: FieldExtensionCoreAir {},
152 }
153 }
154}
155
156impl Default for FieldExtensionCoreChip {
157 fn default() -> Self {
158 Self::new()
159 }
160}
161
162impl<F: PrimeField32, I: VmAdapterInterface<F>> VmCoreChip<F, I> for FieldExtensionCoreChip
163where
164 I::Reads: Into<[[F; EXT_DEG]; 2]>,
165 I::Writes: From<[[F; EXT_DEG]; 1]>,
166{
167 type Record = FieldExtensionRecord<F>;
168 type Air = FieldExtensionCoreAir;
169
170 #[allow(clippy::type_complexity)]
171 fn execute_instruction(
172 &self,
173 instruction: &Instruction<F>,
174 _from_pc: u32,
175 reads: I::Reads,
176 ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
177 let Instruction { opcode, .. } = instruction;
178 let local_opcode_idx = opcode.local_opcode_idx(FieldExtensionOpcode::CLASS_OFFSET);
179
180 let data: [[F; EXT_DEG]; 2] = reads.into();
181 let y: [F; EXT_DEG] = data[0];
182 let z: [F; EXT_DEG] = data[1];
183
184 let x = FieldExtension::solve(FieldExtensionOpcode::from_usize(local_opcode_idx), y, z)
185 .unwrap();
186
187 let output = AdapterRuntimeContext {
188 to_pc: None,
189 writes: [x].into(),
190 };
191
192 let record = Self::Record {
193 opcode: FieldExtensionOpcode::from_usize(local_opcode_idx),
194 x,
195 y,
196 z,
197 };
198
199 Ok((output, record))
200 }
201
202 fn get_opcode_name(&self, opcode: usize) -> String {
203 format!(
204 "{:?}",
205 FieldExtensionOpcode::from_usize(opcode - FieldExtensionOpcode::CLASS_OFFSET)
206 )
207 }
208
209 fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
210 let FieldExtensionRecord { opcode, x, y, z } = record;
211 let cols: &mut FieldExtensionCoreCols<_> = row_slice.borrow_mut();
212 cols.x = x;
213 cols.y = y;
214 cols.z = z;
215 cols.is_add = F::from_bool(opcode == FieldExtensionOpcode::FE4ADD);
216 cols.is_sub = F::from_bool(opcode == FieldExtensionOpcode::FE4SUB);
217 cols.is_mul = F::from_bool(opcode == FieldExtensionOpcode::BBE4MUL);
218 cols.is_div = F::from_bool(opcode == FieldExtensionOpcode::BBE4DIV);
219 cols.divisor_inv = if opcode == FieldExtensionOpcode::BBE4DIV {
220 FieldExtension::invert(z)
221 } else {
222 [F::ZERO; EXT_DEG]
223 };
224 }
225
226 fn air(&self) -> &Self::Air {
227 &self.air
228 }
229}
230
/// Namespace for arithmetic on elements of `F[X] / (X^4 - BETA)`. An element is
/// represented as a coefficient array `[c0, c1, c2, c3]`, where `c_i` is the
/// coefficient of `X^i`.
pub struct FieldExtension;
impl FieldExtension {
    /// Evaluates `x op y` natively for the given opcode.
    ///
    /// The match is exhaustive and every arm returns `Some`, so this never returns
    /// `None`. `BBE4DIV` with `y = 0` panics inside [`Self::invert`] (presumably via
    /// `Field::inverse`).
    pub(super) fn solve<F: Field>(
        opcode: FieldExtensionOpcode,
        x: [F; EXT_DEG],
        y: [F; EXT_DEG],
    ) -> Option<[F; EXT_DEG]> {
        match opcode {
            FieldExtensionOpcode::FE4ADD => Some(Self::add(x, y)),
            FieldExtensionOpcode::FE4SUB => Some(Self::subtract(x, y)),
            FieldExtensionOpcode::BBE4MUL => Some(Self::multiply(x, y)),
            FieldExtensionOpcode::BBE4DIV => Some(Self::divide(x, y)),
        }
    }

    /// Coefficient-wise sum `x + y`.
    ///
    /// Generic over `V`/`E` so it serves both native field values and symbolic AIR
    /// variables/expressions (see `eval`).
    pub(crate) fn add<V, E>(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG]
    where
        V: Copy,
        V: Add<V, Output = E>,
    {
        array::from_fn(|i| x[i] + y[i])
    }

    /// Coefficient-wise difference `x - y`. Generic like [`Self::add`].
    pub(crate) fn subtract<V, E>(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG]
    where
        V: Copy,
        V: Sub<V, Output = E>,
    {
        array::from_fn(|i| x[i] - y[i])
    }

    /// Product `x * y` reduced modulo `X^4 - BETA`.
    ///
    /// This is schoolbook multiplication in which every product landing on `X^{4+k}` is
    /// folded back to `X^k` with a factor of `BETA`. It is generic like [`Self::add`].
    pub(crate) fn multiply<V, E>(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG]
    where
        E: FieldAlgebra,
        V: Copy,
        V: Mul<V, Output = E>,
        E: Mul<V, Output = E>,
        V: Add<V, Output = E>,
        E: Add<V, Output = E>,
    {
        let [x0, x1, x2, x3] = x;
        let [y0, y1, y2, y3] = y;
        [
            x0 * y0 + (x1 * y3 + x2 * y2 + x3 * y1) * E::from_canonical_usize(BETA),
            x0 * y1 + x1 * y0 + (x2 * y3 + x3 * y2) * E::from_canonical_usize(BETA),
            x0 * y2 + x1 * y1 + x2 * y0 + (x3 * y3) * E::from_canonical_usize(BETA),
            x0 * y3 + x1 * y2 + x2 * y1 + x3 * y0,
        ]
    }

    /// Quotient `x / y`, computed as `x * y^{-1}`. Panics if `y = 0` (see [`Self::invert`]).
    pub(crate) fn divide<F: Field>(x: [F; EXT_DEG], y: [F; EXT_DEG]) -> [F; EXT_DEG] {
        Self::multiply(x, Self::invert(y))
    }

    /// Multiplicative inverse of `a` in `F[X] / (X^4 - BETA)`.
    ///
    /// The inverse is computed with two conjugation steps:
    /// 1. `a(X) * a(-X) = b0 + b2 X^2`, an element with only even powers.
    /// 2. Viewing `b0 + b2 X^2` as `b0 + b2 Y` with `Y^2 = BETA`, its norm is
    ///    `c = b0^2 - BETA * b2^2`, a base-field element.
    ///
    /// It follows that `a^{-1} = a(-X) * (b0 - b2 X^2) / c`. The code scales `b0, b2` by
    /// `1/c` and then expands the product. The only base-field inversion is of `c`, so
    /// `a = 0` (hence `c = 0`) presumably panics in `Field::inverse`.
    pub(crate) fn invert<F: Field>(a: [F; EXT_DEG]) -> [F; EXT_DEG] {
        let [a0, a1, a2, a3] = a;

        let beta = F::from_canonical_usize(BETA);

        // Even-power coefficients of a(X) * a(-X).
        let mut b0 = a0 * a0 - beta * (F::TWO * a1 * a3 - a2 * a2);
        let mut b2 = F::TWO * a0 * a2 - a1 * a1 - beta * a3 * a3;

        // Norm of b0 + b2 * X^2 down to the base field.
        let c = b0 * b0 - beta * b2 * b2;
        let inv_c = c.inverse();

        b0 *= inv_c;
        b2 *= inv_c;

        // Coefficients of a(-X) * (b0 - b2 X^2), reduced with X^4 = BETA.
        [
            a0 * b0 - a2 * b2 * beta,
            -a1 * b0 + a3 * b2 * beta,
            -a0 * b2 + a2 * b0,
            a1 * b2 - a3 * b0,
        ]
    }
}