openvm_native_circuit/field_extension/
core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4    ops::{Add, Mul, Sub},
5};
6
7use itertools::izip;
8use openvm_circuit::arch::{
9    AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
10    VmCoreAir, VmCoreChip,
11};
12use openvm_circuit_primitives_derive::AlignedBorrow;
13use openvm_instructions::{instruction::Instruction, LocalOpcode};
14use openvm_native_compiler::FieldExtensionOpcode::{self, *};
15use openvm_stark_backend::{
16    interaction::InteractionBuilder,
17    p3_air::BaseAir,
18    p3_field::{Field, FieldAlgebra, PrimeField32},
19    rap::BaseAirWithPublicValues,
20};
21use serde::{Deserialize, Serialize};
22
/// Reduction constant of the extension: products are reduced with `X^4 = BETA`
/// (see `FieldExtension::multiply`).
pub const BETA: usize = 11;
/// Degree of the extension, i.e. the number of base-field coefficients per element.
pub const EXT_DEG: usize = 4;
25
26#[repr(C)]
27#[derive(AlignedBorrow)]
28pub struct FieldExtensionCoreCols<T> {
29    pub x: [T; EXT_DEG],
30    pub y: [T; EXT_DEG],
31    pub z: [T; EXT_DEG],
32
33    pub is_add: T,
34    pub is_sub: T,
35    pub is_mul: T,
36    pub is_div: T,
37    /// `divisor_inv` is y.inverse() when opcode is FDIV and zero otherwise.
38    pub divisor_inv: [T; EXT_DEG],
39}
40
/// AIR constraining `x = y op z` for the `FE4ADD`, `FE4SUB`, `BBE4MUL` and `BBE4DIV` opcodes.
///
/// The struct has no fields. Its trace width is that of [`FieldExtensionCoreCols`].
#[derive(Copy, Clone, Debug)]
pub struct FieldExtensionCoreAir {}
43
44impl<F: Field> BaseAir<F> for FieldExtensionCoreAir {
45    fn width(&self) -> usize {
46        FieldExtensionCoreCols::<F>::width()
47    }
48}
49
// This AIR declares no public values; the trait's default methods are used as-is.
impl<F: Field> BaseAirWithPublicValues<F> for FieldExtensionCoreAir {}
51
52impl<AB, I> VmCoreAir<AB, I> for FieldExtensionCoreAir
53where
54    AB: InteractionBuilder,
55    I: VmAdapterInterface<AB::Expr>,
56    I::Reads: From<[[AB::Expr; EXT_DEG]; 2]>,
57    I::Writes: From<[[AB::Expr; EXT_DEG]; 1]>,
58    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
59{
60    fn eval(
61        &self,
62        builder: &mut AB,
63        local_core: &[AB::Var],
64        _from_pc: AB::Var,
65    ) -> AdapterAirContext<AB::Expr, I> {
66        let cols: &FieldExtensionCoreCols<_> = local_core.borrow();
67
68        let flags = [cols.is_add, cols.is_sub, cols.is_mul, cols.is_div];
69        let opcodes = [FE4ADD, FE4SUB, BBE4MUL, BBE4DIV];
70        let results = [
71            FieldExtension::add(cols.y, cols.z),
72            FieldExtension::subtract(cols.y, cols.z),
73            FieldExtension::multiply(cols.y, cols.z),
74            FieldExtension::multiply(cols.y, cols.divisor_inv),
75        ];
76
77        // Imposing the following constraints:
78        // - Each flag in `flags` is a boolean.
79        // - Exactly one flag in `flags` is true.
80        // - The inner product of the `flags` and `opcodes` equals `io.opcode`.
81        // - The inner product of the `flags` and `results[:,j]` equals `io.z[j]` for each `j`.
82        // - If `is_div` is true, then `aux.divisor_inv` correctly represents the inverse of `io.y`.
83
84        let mut is_valid = AB::Expr::ZERO;
85        let mut expected_opcode = AB::Expr::ZERO;
86        let mut expected_result = [
87            AB::Expr::ZERO,
88            AB::Expr::ZERO,
89            AB::Expr::ZERO,
90            AB::Expr::ZERO,
91        ];
92        for (flag, opcode, result) in izip!(flags, opcodes, results) {
93            builder.assert_bool(flag);
94
95            is_valid += flag.into();
96            expected_opcode += flag * AB::F::from_canonical_usize(opcode.local_usize());
97
98            for (j, result_part) in result.into_iter().enumerate() {
99                expected_result[j] += flag * result_part;
100            }
101        }
102
103        for (x_j, expected_result_j) in izip!(cols.x, expected_result) {
104            builder.assert_eq(x_j, expected_result_j);
105        }
106        builder.assert_bool(is_valid.clone());
107
108        // constrain aux.divisor_inv: z * z^(-1) = 1
109        let z_times_z_inv = FieldExtension::multiply(cols.z, cols.divisor_inv);
110        for (i, prod_i) in z_times_z_inv.into_iter().enumerate() {
111            if i == 0 {
112                builder.assert_eq(cols.is_div, prod_i);
113            } else {
114                builder.assert_zero(prod_i);
115            }
116        }
117
118        AdapterAirContext {
119            to_pc: None,
120            reads: [cols.y.map(Into::into), cols.z.map(Into::into)].into(),
121            writes: [cols.x.map(Into::into)].into(),
122            instruction: MinimalInstruction {
123                is_valid,
124                opcode: VmCoreAir::<AB, I>::expr_to_global_expr(self, expected_opcode),
125            }
126            .into(),
127        }
128    }
129
130    fn start_offset(&self) -> usize {
131        FieldExtensionOpcode::CLASS_OFFSET
132    }
133}
134
/// Per-instruction record built by `execute_instruction` and consumed by `generate_trace_row`.
///
/// It holds the decoded opcode together with the operands and result of `x = y op z`.
#[repr(C)]
#[derive(Debug, Serialize, Deserialize)]
pub struct FieldExtensionRecord<F> {
    /// Decoded local opcode.
    pub opcode: FieldExtensionOpcode,
    /// Result of the operation.
    pub x: [F; EXT_DEG],
    /// First operand.
    pub y: [F; EXT_DEG],
    /// Second operand (divisor for `BBE4DIV`).
    pub z: [F; EXT_DEG],
}
143
/// Core chip that executes field-extension instructions and generates their trace rows.
pub struct FieldExtensionCoreChip {
    /// The AIR for this chip (a field-less struct), returned by `VmCoreChip::air`.
    pub air: FieldExtensionCoreAir,
}
147
148impl FieldExtensionCoreChip {
149    pub fn new() -> Self {
150        Self {
151            air: FieldExtensionCoreAir {},
152        }
153    }
154}
155
156impl Default for FieldExtensionCoreChip {
157    fn default() -> Self {
158        Self::new()
159    }
160}
161
162impl<F: PrimeField32, I: VmAdapterInterface<F>> VmCoreChip<F, I> for FieldExtensionCoreChip
163where
164    I::Reads: Into<[[F; EXT_DEG]; 2]>,
165    I::Writes: From<[[F; EXT_DEG]; 1]>,
166{
167    type Record = FieldExtensionRecord<F>;
168    type Air = FieldExtensionCoreAir;
169
170    #[allow(clippy::type_complexity)]
171    fn execute_instruction(
172        &self,
173        instruction: &Instruction<F>,
174        _from_pc: u32,
175        reads: I::Reads,
176    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
177        let Instruction { opcode, .. } = instruction;
178        let local_opcode_idx = opcode.local_opcode_idx(FieldExtensionOpcode::CLASS_OFFSET);
179
180        let data: [[F; EXT_DEG]; 2] = reads.into();
181        let y: [F; EXT_DEG] = data[0];
182        let z: [F; EXT_DEG] = data[1];
183
184        let x = FieldExtension::solve(FieldExtensionOpcode::from_usize(local_opcode_idx), y, z)
185            .unwrap();
186
187        let output = AdapterRuntimeContext {
188            to_pc: None,
189            writes: [x].into(),
190        };
191
192        let record = Self::Record {
193            opcode: FieldExtensionOpcode::from_usize(local_opcode_idx),
194            x,
195            y,
196            z,
197        };
198
199        Ok((output, record))
200    }
201
202    fn get_opcode_name(&self, opcode: usize) -> String {
203        format!(
204            "{:?}",
205            FieldExtensionOpcode::from_usize(opcode - FieldExtensionOpcode::CLASS_OFFSET)
206        )
207    }
208
209    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
210        let FieldExtensionRecord { opcode, x, y, z } = record;
211        let cols: &mut FieldExtensionCoreCols<_> = row_slice.borrow_mut();
212        cols.x = x;
213        cols.y = y;
214        cols.z = z;
215        cols.is_add = F::from_bool(opcode == FieldExtensionOpcode::FE4ADD);
216        cols.is_sub = F::from_bool(opcode == FieldExtensionOpcode::FE4SUB);
217        cols.is_mul = F::from_bool(opcode == FieldExtensionOpcode::BBE4MUL);
218        cols.is_div = F::from_bool(opcode == FieldExtensionOpcode::BBE4DIV);
219        cols.divisor_inv = if opcode == FieldExtensionOpcode::BBE4DIV {
220            FieldExtension::invert(z)
221        } else {
222            [F::ZERO; EXT_DEG]
223        };
224    }
225
226    fn air(&self) -> &Self::Air {
227        &self.air
228    }
229}
230
231pub struct FieldExtension;
232impl FieldExtension {
233    pub(super) fn solve<F: Field>(
234        opcode: FieldExtensionOpcode,
235        x: [F; EXT_DEG],
236        y: [F; EXT_DEG],
237    ) -> Option<[F; EXT_DEG]> {
238        match opcode {
239            FieldExtensionOpcode::FE4ADD => Some(Self::add(x, y)),
240            FieldExtensionOpcode::FE4SUB => Some(Self::subtract(x, y)),
241            FieldExtensionOpcode::BBE4MUL => Some(Self::multiply(x, y)),
242            FieldExtensionOpcode::BBE4DIV => Some(Self::divide(x, y)),
243        }
244    }
245
246    pub(crate) fn add<V, E>(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG]
247    where
248        V: Copy,
249        V: Add<V, Output = E>,
250    {
251        array::from_fn(|i| x[i] + y[i])
252    }
253
254    pub(crate) fn subtract<V, E>(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG]
255    where
256        V: Copy,
257        V: Sub<V, Output = E>,
258    {
259        array::from_fn(|i| x[i] - y[i])
260    }
261
262    pub(crate) fn multiply<V, E>(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG]
263    where
264        E: FieldAlgebra,
265        V: Copy,
266        V: Mul<V, Output = E>,
267        E: Mul<V, Output = E>,
268        V: Add<V, Output = E>,
269        E: Add<V, Output = E>,
270    {
271        let [x0, x1, x2, x3] = x;
272        let [y0, y1, y2, y3] = y;
273        [
274            x0 * y0 + (x1 * y3 + x2 * y2 + x3 * y1) * E::from_canonical_usize(BETA),
275            x0 * y1 + x1 * y0 + (x2 * y3 + x3 * y2) * E::from_canonical_usize(BETA),
276            x0 * y2 + x1 * y1 + x2 * y0 + (x3 * y3) * E::from_canonical_usize(BETA),
277            x0 * y3 + x1 * y2 + x2 * y1 + x3 * y0,
278        ]
279    }
280
281    pub(crate) fn divide<F: Field>(x: [F; EXT_DEG], y: [F; EXT_DEG]) -> [F; EXT_DEG] {
282        Self::multiply(x, Self::invert(y))
283    }
284
285    pub(crate) fn invert<F: Field>(a: [F; EXT_DEG]) -> [F; EXT_DEG] {
286        // Let a = (a0, a1, a2, a3) represent the element we want to invert.
287        // Define a' = (a0, -a1, a2, -a3).  By construction, the product b = a * a' will have zero
288        // degree-1 and degree-3 coefficients.
289        // Let b = (b0, 0, b2, 0) and define b' = (b0, 0, -b2, 0).
290        // Note that c = b * b' = b0^2 - BETA * b2^2, which is an element of the base field.
291        // Therefore, the inverse of a is 1 / a = a' / (a * a') = a' * b' / (b * b') = a' * b' / c.
292
293        let [a0, a1, a2, a3] = a;
294
295        let beta = F::from_canonical_usize(BETA);
296
297        let mut b0 = a0 * a0 - beta * (F::TWO * a1 * a3 - a2 * a2);
298        let mut b2 = F::TWO * a0 * a2 - a1 * a1 - beta * a3 * a3;
299
300        let c = b0 * b0 - beta * b2 * b2;
301        let inv_c = c.inverse();
302
303        b0 *= inv_c;
304        b2 *= inv_c;
305
306        [
307            a0 * b0 - a2 * b2 * beta,
308            -a1 * b0 + a3 * b2 * beta,
309            -a0 * b2 + a2 * b0,
310            a1 * b2 - a3 * b0,
311        ]
312    }
313}