// openvm_algebra_circuit/fp2.rs

1use std::{cell::RefCell, rc::Rc};
2
3use openvm_mod_circuit_builder::{ExprBuilder, FieldVariable, SymbolicExpr};
4
5/// Quadratic field extension of `Fp` defined by `Fp2 = Fp[u]/(1 + u^2)`. Assumes that `-1` is not a
6/// quadratic residue in `Fp`, which is equivalent to `p` being congruent to `3 (mod 4)`.
7/// Extends Mod Builder to work with Fp2 variables.
8#[derive(Clone)]
9pub struct Fp2 {
10    pub c0: FieldVariable,
11    pub c1: FieldVariable,
12}
13
14impl Fp2 {
15    pub fn new(builder: Rc<RefCell<ExprBuilder>>) -> Self {
16        let c0 = ExprBuilder::new_input(builder.clone());
17        let c1 = ExprBuilder::new_input(builder.clone());
18        Fp2 { c0, c1 }
19    }
20
21    pub fn new_var(builder: Rc<RefCell<ExprBuilder>>) -> ((usize, usize), Fp2) {
22        let (c0_idx, c0) = builder.borrow_mut().new_var();
23        let (c1_idx, c1) = builder.borrow_mut().new_var();
24        let fp2 = Fp2 {
25            c0: FieldVariable::from_var(builder.clone(), c0),
26            c1: FieldVariable::from_var(builder.clone(), c1),
27        };
28        ((c0_idx, c1_idx), fp2)
29    }
30
31    pub fn save(&mut self) -> [usize; 2] {
32        let c0_idx = self.c0.save();
33        let c1_idx = self.c1.save();
34        [c0_idx, c1_idx]
35    }
36
37    pub fn save_output(&mut self) {
38        self.c0.save_output();
39        self.c1.save_output();
40    }
41
42    pub fn add(&mut self, other: &mut Fp2) -> Fp2 {
43        Fp2 {
44            c0: &mut self.c0 + &mut other.c0,
45            c1: &mut self.c1 + &mut other.c1,
46        }
47    }
48
49    pub fn sub(&mut self, other: &mut Fp2) -> Fp2 {
50        Fp2 {
51            c0: &mut self.c0 - &mut other.c0,
52            c1: &mut self.c1 - &mut other.c1,
53        }
54    }
55
56    pub fn mul(&mut self, other: &mut Fp2) -> Fp2 {
57        let c0 = &mut self.c0 * &mut other.c0 - &mut self.c1 * &mut other.c1;
58        let c1 = &mut self.c0 * &mut other.c1 + &mut self.c1 * &mut other.c0;
59        Fp2 { c0, c1 }
60    }
61
62    pub fn square(&mut self) -> Fp2 {
63        let c0 = self.c0.square() - self.c1.square();
64        let c1 = (&mut self.c0 * &mut self.c1).int_mul(2);
65        Fp2 { c0, c1 }
66    }
67
68    pub fn div(&mut self, other: &mut Fp2) -> Fp2 {
69        let builder = self.c0.builder.borrow();
70        let prime = builder.prime.clone();
71        let limb_bits = builder.limb_bits;
72        let num_limbs = builder.num_limbs;
73        let proper_max = builder.proper_max().clone();
74        drop(builder);
75
76        // These are dummy variables, will be replaced later so the index within it doesn't matter.
77        // We use these to check if we need to save self/other first.
78        let fake_z0 = SymbolicExpr::Var(0);
79        let fake_z1 = SymbolicExpr::Var(1);
80
81        // Compute should not be affected by whether auto save is triggered.
82        // So we must do compute first.
83        // Compute z0
84        let compute_denom = &other.c0.expr * &other.c0.expr + &other.c1.expr * &other.c1.expr;
85        let compute_z0_nom = &self.c0.expr * &other.c0.expr + &self.c1.expr * &other.c1.expr;
86        let compute_z0 = &compute_z0_nom / &compute_denom;
87        // Compute z1
88        let compute_z1_nom = &self.c1.expr * &other.c0.expr - &self.c0.expr * &other.c1.expr;
89        let compute_z1 = &compute_z1_nom / &compute_denom;
90
91        // We will constrain
92        //  (1) x0 = y0*z0 - y1*z1 and
93        //  (2) x1 = y1*z0 + y0*z1
94        // which implies z0 and z1 are computed as above.
95        // Observe (1)*y0 + (2)*y1 yields x0*y0 + x1*y1 = z0(y0^2 + y1^2) and so z0 = (x0*y0 +
96        // x1*y1) / (y0^2 + y1^2) as needed. Observe (1)*(-y1) + (2)*y0 yields x1*y0 - x0*y1
97        // = z1(y0^2 + y1^2) and so z1 = (x1*y0 - x0*y1) / (y0^2 + y1^2) as needed.
98
99        // Constraint 1: x0 = y0*z0 - y1*z1
100        let constraint1 = &self.c0.expr - &other.c0.expr * &fake_z0 + &other.c1.expr * &fake_z1;
101        let carry_bits =
102            constraint1.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
103        if carry_bits > self.c0.max_carry_bits {
104            self.save();
105        }
106        let constraint1 = &self.c0.expr - &other.c0.expr * &fake_z0 + &other.c1.expr * &fake_z1;
107        let carry_bits =
108            constraint1.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
109        if carry_bits > self.c0.max_carry_bits {
110            other.save();
111        }
112
113        // Constraint 2: x1 = y1*z0 + y0*z1
114        let constraint2 = &self.c1.expr - &other.c1.expr * &fake_z0 - &other.c0.expr * &fake_z1;
115        let carry_bits =
116            constraint2.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
117        if carry_bits > self.c0.max_carry_bits {
118            self.save();
119        }
120        let constraint2 = &self.c1.expr - &other.c1.expr * &fake_z0 - &other.c0.expr * &fake_z1;
121        let carry_bits =
122            constraint2.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
123        if carry_bits > self.c0.max_carry_bits {
124            other.save();
125        }
126
127        let mut builder = self.c0.builder.borrow_mut();
128        let (z0_idx, z0) = builder.new_var();
129        let (z1_idx, z1) = builder.new_var();
130        let constraint1 = &self.c0.expr - &other.c0.expr * &z0 + &other.c1.expr * &z1;
131        let constraint2 = &self.c1.expr - &other.c1.expr * &z0 - &other.c0.expr * &z1;
132        builder.set_compute(z0_idx, compute_z0);
133        builder.set_compute(z1_idx, compute_z1);
134        builder.set_constraint(z0_idx, constraint1);
135        builder.set_constraint(z1_idx, constraint2);
136        drop(builder);
137
138        let z0_var = FieldVariable::from_var(self.c0.builder.clone(), z0);
139        let z1_var = FieldVariable::from_var(self.c0.builder.clone(), z1);
140        Fp2 {
141            c0: z0_var,
142            c1: z1_var,
143        }
144    }
145
146    pub fn scalar_mul(&mut self, fp: &mut FieldVariable) -> Fp2 {
147        Fp2 {
148            c0: &mut self.c0 * fp,
149            c1: &mut self.c1 * fp,
150        }
151    }
152
153    pub fn int_add(&mut self, c: [isize; 2]) -> Fp2 {
154        Fp2 {
155            c0: self.c0.int_add(c[0]),
156            c1: self.c1.int_add(c[1]),
157        }
158    }
159
160    // c is like a Fp2, but with both c0 and c1 being very small numbers.
161    pub fn int_mul(&mut self, c: [isize; 2]) -> Fp2 {
162        Fp2 {
163            c0: self.c0.int_mul(c[0]) - self.c1.int_mul(c[1]),
164            c1: self.c0.int_mul(c[1]) + self.c1.int_mul(c[0]),
165        }
166    }
167
168    pub fn neg(&mut self) -> Fp2 {
169        self.int_mul([-1, 0])
170    }
171
172    pub fn select(flag_id: usize, a: &Fp2, b: &Fp2) -> Fp2 {
173        Fp2 {
174            c0: FieldVariable::select(flag_id, &a.c0, &b.c0),
175            c1: FieldVariable::select(flag_id, &a.c1, &b.c1),
176        }
177    }
178}
179
#[cfg(test)]
mod tests {
    use std::{cell::RefCell, rc::Rc};

    use halo2curves_axiom::bn256::Fq2;
    use num_bigint::BigUint;
    use openvm_circuit_primitives::TraceSubRowGenerator;
    use openvm_mod_circuit_builder::{test_utils::*, ExprBuilder, FieldExpr, FieldExprCols};
    use openvm_pairing_guest::bn254::BN254_MODULUS;
    use openvm_stark_backend::{
        p3_air::BaseAir, p3_field::FieldAlgebra, p3_matrix::dense::RowMajorMatrix,
    };
    use openvm_stark_sdk::{
        any_rap_arc_vec, config::baby_bear_blake3::BabyBearBlake3Engine, engine::StarkFriEngine,
        p3_baby_bear::BabyBear,
    };

    use super::Fp2;

    /// Flattens `values` into circuit inputs, `c0` before `c1` for each element, matching the
    /// order in which `Fp2::new` allocates its inputs.
    fn fq2_inputs(values: &[Fq2]) -> Vec<BigUint> {
        values
            .iter()
            .flat_map(|v| [bn254_fq_to_biguint(v.c0), bn254_fq_to_biguint(v.c1)])
            .collect()
    }

    /// Builds a BN254 circuit via `build`, fills one trace row from `inputs`, asserts that the
    /// row's two variables hold `expected`, and runs the STARK prover/verifier on the trace.
    fn check_fp2_circuit(
        build: impl FnOnce(&Rc<RefCell<ExprBuilder>>),
        inputs: Vec<BigUint>,
        expected: Fq2,
    ) {
        let prime = BN254_MODULUS.clone();
        let (range_checker, builder) = setup(&prime);
        build(&builder);

        let air = FieldExpr::new(builder.borrow().clone(), range_checker.bus(), false);
        let width = BaseAir::<BabyBear>::width(&air);

        let mut row = BabyBear::zero_vec(width);
        air.generate_subrow((&range_checker, inputs, vec![]), &mut row);
        let FieldExprCols { vars, .. } = air.load_vars(&row);
        let trace = RowMajorMatrix::new(row, width);
        let range_trace = range_checker.generate_trace();

        assert_eq!(vars.len(), 2);
        let got_c0 = evaluate_biguint(&vars[0], LIMB_BITS);
        let got_c1 = evaluate_biguint(&vars[1], LIMB_BITS);
        assert_eq!(got_c0, bn254_fq_to_biguint(expected.c0));
        assert_eq!(got_c1, bn254_fq_to_biguint(expected.c1));

        BabyBearBlake3Engine::run_simple_test_no_pis_fast(
            any_rap_arc_vec![air, range_checker.air],
            vec![trace, range_trace],
        )
        .expect("Verification failed");
    }

    /// Checks the circuit operation `fp2_fn` against the reference `fq2_fn` on fixed seeded
    /// operands. `save_result` saves the returned value into two variables; leave it `false`
    /// for operations (like `div`) whose result is already made of fresh variables.
    fn test_fp2(
        fp2_fn: impl Fn(&mut Fp2, &mut Fp2) -> Fp2,
        fq2_fn: impl Fn(&Fq2, &Fq2) -> Fq2,
        save_result: bool,
    ) {
        let x = bn254_fq2_random(1);
        let y = bn254_fq2_random(5);
        check_fp2_circuit(
            |builder| {
                let mut x_fp2 = Fp2::new(builder.clone());
                let mut y_fp2 = Fp2::new(builder.clone());
                let mut result = fp2_fn(&mut x_fp2, &mut y_fp2);
                if save_result {
                    result.save();
                }
            },
            fq2_inputs(&[x, y]),
            fq2_fn(&x, &y),
        );
    }

    #[test]
    fn test_fp2_add() {
        test_fp2(Fp2::add, |x, y| x + y, true);
    }

    #[test]
    fn test_fp2_sub() {
        test_fp2(Fp2::sub, |x, y| x - y, true);
    }

    #[test]
    fn test_fp2_mul() {
        test_fp2(Fp2::mul, |x, y| x * y, true);
    }

    #[test]
    fn test_fp2_div() {
        test_fp2(Fp2::div, |x, y| x * y.invert().unwrap(), false);
    }

    #[test]
    fn test_fp2_div2() {
        let x = bn254_fq2_random(5);
        let y = bn254_fq2_random(15);
        let z = bn254_fq2_random(95);
        check_fp2_circuit(
            |builder| {
                let mut x_fp2 = Fp2::new(builder.clone());
                let mut y_fp2 = Fp2::new(builder.clone());
                let mut z_fp2 = Fp2::new(builder.clone());
                let mut xy = x_fp2.mul(&mut y_fp2);
                // No explicit save: `div` already allocates its quotient as new variables.
                let _r = xy.div(&mut z_fp2);
            },
            fq2_inputs(&[x, y, z]),
            z.invert().unwrap() * x * y,
        );
    }
}