openvm_algebra_circuit/
fp2.rs

1use std::{cell::RefCell, rc::Rc};
2
3use openvm_mod_circuit_builder::{ExprBuilder, FieldVariable, SymbolicExpr};
4
/// Quadratic field extension of `Fp` defined by `Fp2 = Fp[u]/(1 + u^2)`, i.e. `u^2 = -1`.
///
/// Assumes that `-1` is not a quadratic residue in `Fp`, which is equivalent to `p` being
/// congruent to `3 (mod 4)`; this is what makes `Fp[u]/(1 + u^2)` a field.
///
/// Extends Mod Builder to work with Fp2 variables: an element `c0 + c1*u` is represented by
/// two `FieldVariable`s living in the same `ExprBuilder`.
#[derive(Clone)]
pub struct Fp2 {
    /// Coefficient of `1` (the "real" part).
    pub c0: FieldVariable,
    /// Coefficient of `u`.
    pub c1: FieldVariable,
}
12
13impl Fp2 {
14    pub fn new(builder: Rc<RefCell<ExprBuilder>>) -> Self {
15        let c0 = ExprBuilder::new_input(builder.clone());
16        let c1 = ExprBuilder::new_input(builder.clone());
17        Fp2 { c0, c1 }
18    }
19
20    pub fn new_var(builder: Rc<RefCell<ExprBuilder>>) -> ((usize, usize), Fp2) {
21        let (c0_idx, c0) = builder.borrow_mut().new_var();
22        let (c1_idx, c1) = builder.borrow_mut().new_var();
23        let fp2 = Fp2 {
24            c0: FieldVariable::from_var(builder.clone(), c0),
25            c1: FieldVariable::from_var(builder.clone(), c1),
26        };
27        ((c0_idx, c1_idx), fp2)
28    }
29
30    pub fn save(&mut self) -> [usize; 2] {
31        let c0_idx = self.c0.save();
32        let c1_idx = self.c1.save();
33        [c0_idx, c1_idx]
34    }
35
36    pub fn save_output(&mut self) {
37        self.c0.save_output();
38        self.c1.save_output();
39    }
40
41    pub fn add(&mut self, other: &mut Fp2) -> Fp2 {
42        Fp2 {
43            c0: &mut self.c0 + &mut other.c0,
44            c1: &mut self.c1 + &mut other.c1,
45        }
46    }
47
48    pub fn sub(&mut self, other: &mut Fp2) -> Fp2 {
49        Fp2 {
50            c0: &mut self.c0 - &mut other.c0,
51            c1: &mut self.c1 - &mut other.c1,
52        }
53    }
54
55    pub fn mul(&mut self, other: &mut Fp2) -> Fp2 {
56        let c0 = &mut self.c0 * &mut other.c0 - &mut self.c1 * &mut other.c1;
57        let c1 = &mut self.c0 * &mut other.c1 + &mut self.c1 * &mut other.c0;
58        Fp2 { c0, c1 }
59    }
60
61    pub fn square(&mut self) -> Fp2 {
62        let c0 = self.c0.square() - self.c1.square();
63        let c1 = (&mut self.c0 * &mut self.c1).int_mul(2);
64        Fp2 { c0, c1 }
65    }
66
67    pub fn div(&mut self, other: &mut Fp2) -> Fp2 {
68        let builder = self.c0.builder.borrow();
69        let prime = builder.prime.clone();
70        let limb_bits = builder.limb_bits;
71        let num_limbs = builder.num_limbs;
72        let proper_max = builder.proper_max().clone();
73        drop(builder);
74
75        // These are dummy variables, will be replaced later so the index within it doesn't matter.
76        // We use these to check if we need to save self/other first.
77        let fake_z0 = SymbolicExpr::Var(0);
78        let fake_z1 = SymbolicExpr::Var(1);
79
80        // Compute should not be affected by whether auto save is triggered.
81        // So we must do compute first.
82        // Compute z0
83        let compute_denom = &other.c0.expr * &other.c0.expr + &other.c1.expr * &other.c1.expr;
84        let compute_z0_nom = &self.c0.expr * &other.c0.expr + &self.c1.expr * &other.c1.expr;
85        let compute_z0 = &compute_z0_nom / &compute_denom;
86        // Compute z1
87        let compute_z1_nom = &self.c1.expr * &other.c0.expr - &self.c0.expr * &other.c1.expr;
88        let compute_z1 = &compute_z1_nom / &compute_denom;
89
90        // We will constrain
91        //  (1) x0 = y0*z0 - y1*z1 and
92        //  (2) x1 = y1*z0 + y0*z1
93        // which implies z0 and z1 are computed as above.
94        // Observe (1)*y0 + (2)*y1 yields x0*y0 + x1*y1 = z0(y0^2 + y1^2) and so z0 = (x0*y0 + x1*y1) / (y0^2 + y1^2) as needed.
95        // Observe (1)*(-y1) + (2)*y0 yields x1*y0 - x0*y1 = z1(y0^2 + y1^2) and so z1 = (x1*y0 - x0*y1) / (y0^2 + y1^2) as needed.
96
97        // Constraint 1: x0 = y0*z0 - y1*z1
98        let constraint1 = &self.c0.expr - &other.c0.expr * &fake_z0 + &other.c1.expr * &fake_z1;
99        let carry_bits =
100            constraint1.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
101        if carry_bits > self.c0.max_carry_bits {
102            self.save();
103        }
104        let constraint1 = &self.c0.expr - &other.c0.expr * &fake_z0 + &other.c1.expr * &fake_z1;
105        let carry_bits =
106            constraint1.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
107        if carry_bits > self.c0.max_carry_bits {
108            other.save();
109        }
110
111        // Constraint 2: x1 = y1*z0 + y0*z1
112        let constraint2 = &self.c1.expr - &other.c1.expr * &fake_z0 - &other.c0.expr * &fake_z1;
113        let carry_bits =
114            constraint2.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
115        if carry_bits > self.c0.max_carry_bits {
116            self.save();
117        }
118        let constraint2 = &self.c1.expr - &other.c1.expr * &fake_z0 - &other.c0.expr * &fake_z1;
119        let carry_bits =
120            constraint2.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
121        if carry_bits > self.c0.max_carry_bits {
122            other.save();
123        }
124
125        let mut builder = self.c0.builder.borrow_mut();
126        let (z0_idx, z0) = builder.new_var();
127        let (z1_idx, z1) = builder.new_var();
128        let constraint1 = &self.c0.expr - &other.c0.expr * &z0 + &other.c1.expr * &z1;
129        let constraint2 = &self.c1.expr - &other.c1.expr * &z0 - &other.c0.expr * &z1;
130        builder.set_compute(z0_idx, compute_z0);
131        builder.set_compute(z1_idx, compute_z1);
132        builder.set_constraint(z0_idx, constraint1);
133        builder.set_constraint(z1_idx, constraint2);
134        drop(builder);
135
136        let z0_var = FieldVariable::from_var(self.c0.builder.clone(), z0);
137        let z1_var = FieldVariable::from_var(self.c0.builder.clone(), z1);
138        Fp2 {
139            c0: z0_var,
140            c1: z1_var,
141        }
142    }
143
144    pub fn scalar_mul(&mut self, fp: &mut FieldVariable) -> Fp2 {
145        Fp2 {
146            c0: &mut self.c0 * fp,
147            c1: &mut self.c1 * fp,
148        }
149    }
150
151    pub fn int_add(&mut self, c: [isize; 2]) -> Fp2 {
152        Fp2 {
153            c0: self.c0.int_add(c[0]),
154            c1: self.c1.int_add(c[1]),
155        }
156    }
157
158    // c is like a Fp2, but with both c0 and c1 being very small numbers.
159    pub fn int_mul(&mut self, c: [isize; 2]) -> Fp2 {
160        Fp2 {
161            c0: self.c0.int_mul(c[0]) - self.c1.int_mul(c[1]),
162            c1: self.c0.int_mul(c[1]) + self.c1.int_mul(c[0]),
163        }
164    }
165
166    pub fn neg(&mut self) -> Fp2 {
167        self.int_mul([-1, 0])
168    }
169
170    pub fn select(flag_id: usize, a: &Fp2, b: &Fp2) -> Fp2 {
171        Fp2 {
172            c0: FieldVariable::select(flag_id, &a.c0, &b.c0),
173            c1: FieldVariable::select(flag_id, &a.c1, &b.c1),
174        }
175    }
176}
177
#[cfg(test)]
mod tests {
    use halo2curves_axiom::bn256::Fq2;
    use num_bigint::BigUint;
    use openvm_circuit_primitives::TraceSubRowGenerator;
    use openvm_mod_circuit_builder::{test_utils::*, FieldExpr, FieldExprCols};
    use openvm_pairing_guest::bn254::BN254_MODULUS;
    use openvm_stark_backend::{
        p3_air::BaseAir, p3_field::FieldAlgebra, p3_matrix::dense::RowMajorMatrix,
    };
    use openvm_stark_sdk::{
        any_rap_arc_vec, config::baby_bear_blake3::BabyBearBlake3Engine, engine::StarkFriEngine,
        p3_baby_bear::BabyBear,
    };

    use super::Fp2;

    /// Flattens `x` and `y` into the circuit input order `[x.c0, x.c1, y.c0, y.c1]`.
    fn two_fp2_input(x: &Fq2, y: &Fq2) -> Vec<BigUint> {
        [x.c0, x.c1, y.c0, y.c1]
            .iter()
            .map(|&coeff| bn254_fq_to_biguint(coeff))
            .collect()
    }

    /// Asserts that the row holds exactly one `Fp2` worth of variables (two field elements)
    /// and that they evaluate to the coordinates of `expected`.
    fn assert_single_fp2_var(vars: &[Vec<BabyBear>], expected: &Fq2) {
        assert_eq!(vars.len(), 2);
        let actual_c0 = evaluate_biguint(&vars[0], LIMB_BITS);
        let actual_c1 = evaluate_biguint(&vars[1], LIMB_BITS);
        let wanted_c0 = bn254_fq_to_biguint(expected.c0);
        let wanted_c1 = bn254_fq_to_biguint(expected.c1);
        assert_eq!(actual_c0, wanted_c0);
        assert_eq!(actual_c1, wanted_c1);
    }

    /// Builds a circuit applying `fp2_fn` to two `Fp2` inputs (optionally saving the result).
    /// It then fills one row from fixed pseudo-random inputs, compares the output variables
    /// with `fq2_fn`, and runs the STARK verifier on the trace.
    fn test_fp2(
        fp2_fn: impl Fn(&mut Fp2, &mut Fp2) -> Fp2,
        fq2_fn: impl Fn(&Fq2, &Fq2) -> Fq2,
        save_result: bool,
    ) {
        let modulus = BN254_MODULUS.clone();
        let (range_checker, expr_builder) = setup(&modulus);

        let mut lhs = Fp2::new(expr_builder.clone());
        let mut rhs = Fp2::new(expr_builder.clone());
        let mut output = fp2_fn(&mut lhs, &mut rhs);
        if save_result {
            output.save();
        }

        let snapshot = expr_builder.borrow().clone();
        let air = FieldExpr::new(snapshot, range_checker.bus(), false);
        let width = BaseAir::<BabyBear>::width(&air);

        let lhs_val = bn254_fq2_random(1);
        let rhs_val = bn254_fq2_random(5);
        let expected = fq2_fn(&lhs_val, &rhs_val);
        let inputs = two_fp2_input(&lhs_val, &rhs_val);

        let mut subrow = BabyBear::zero_vec(width);
        air.generate_subrow((&range_checker, inputs, vec![]), &mut subrow);
        let FieldExprCols { vars, .. } = air.load_vars(&subrow);
        let main_trace = RowMajorMatrix::new(subrow, width);
        let range_trace = range_checker.generate_trace();
        assert_single_fp2_var(&vars, &expected);

        BabyBearBlake3Engine::run_simple_test_no_pis_fast(
            any_rap_arc_vec![air, range_checker.air],
            vec![main_trace, range_trace],
        )
        .expect("Verification failed");
    }

    #[test]
    fn test_fp2_add() {
        test_fp2(Fp2::add, |a, b| a + b, true);
    }

    #[test]
    fn test_fp2_sub() {
        test_fp2(Fp2::sub, |a, b| a - b, true);
    }

    #[test]
    fn test_fp2_mul() {
        test_fp2(Fp2::mul, |a, b| a * b, true);
    }

    #[test]
    fn test_fp2_div() {
        // `div` allocates and constrains its own result variables, so no extra save.
        test_fp2(Fp2::div, |a, b| a * b.invert().unwrap(), false);
    }

    #[test]
    fn test_fp2_div2() {
        let modulus = BN254_MODULUS.clone();
        let (range_checker, expr_builder) = setup(&modulus);

        let mut x = Fp2::new(expr_builder.clone());
        let mut y = Fp2::new(expr_builder.clone());
        let mut z = Fp2::new(expr_builder.clone());
        let mut product = x.mul(&mut y);
        // No explicit save needed: `div` saves its result itself.
        let _quotient = product.div(&mut z);

        let snapshot = expr_builder.borrow().clone();
        let air = FieldExpr::new(snapshot, range_checker.bus(), false);
        let width = BaseAir::<BabyBear>::width(&air);

        let x_val = bn254_fq2_random(5);
        let y_val = bn254_fq2_random(15);
        let z_val = bn254_fq2_random(95);
        let expected = z_val.invert().unwrap() * x_val * y_val;
        let mut inputs = two_fp2_input(&x_val, &y_val);
        inputs.extend(
            [z_val.c0, z_val.c1]
                .iter()
                .map(|&coeff| bn254_fq_to_biguint(coeff)),
        );

        let mut subrow = BabyBear::zero_vec(width);
        air.generate_subrow((&range_checker, inputs, vec![]), &mut subrow);
        let FieldExprCols { vars, .. } = air.load_vars(&subrow);
        let main_trace = RowMajorMatrix::new(subrow, width);
        let range_trace = range_checker.generate_trace();
        assert_single_fp2_var(&vars, &expected);

        BabyBearBlake3Engine::run_simple_test_no_pis_fast(
            any_rap_arc_vec![air, range_checker.air],
            vec![main_trace, range_trace],
        )
        .expect("Verification failed");
    }
}