1use std::{cell::RefCell, rc::Rc};
2
3use openvm_mod_circuit_builder::{ExprBuilder, FieldVariable, SymbolicExpr};
4
/// An element `c0 + c1·u` of the quadratic extension `Fp2 = Fp[u]/(u^2 + 1)`, expressed
/// symbolically in a shared `ExprBuilder`.
///
/// The relation `u^2 = -1` is the one used by the multiplication, squaring, division and
/// `int_mul` formulas in this file. Each coordinate is a `FieldVariable` handle into the same
/// builder.
#[derive(Clone)]
pub struct Fp2 {
    /// Coefficient of `1`.
    pub c0: FieldVariable,
    /// Coefficient of `u`.
    pub c1: FieldVariable,
}
13
impl Fp2 {
    /// Creates an `Fp2` whose coordinates are two new builder inputs.
    ///
    /// `c0` is registered before `c1`. When a trace row is filled, this element's values are
    /// therefore supplied as `[c0, c1]`, after the values of any inputs created earlier.
    pub fn new(builder: Rc<RefCell<ExprBuilder>>) -> Self {
        let c0 = ExprBuilder::new_input(builder.clone());
        let c1 = ExprBuilder::new_input(builder.clone());
        Fp2 { c0, c1 }
    }

    /// Allocates two fresh builder variables (`c0` first) and wraps them as an `Fp2`.
    ///
    /// Returns `((c0_idx, c1_idx), fp2)`, where the indices identify the new variables in the
    /// builder. No compute expression or constraint is attached here. `div` shows the pattern
    /// of pairing such indices with `set_compute` / `set_constraint`. Presumably callers of
    /// this function must do the same (TODO confirm).
    pub fn new_var(builder: Rc<RefCell<ExprBuilder>>) -> ((usize, usize), Fp2) {
        let (c0_idx, c0) = builder.borrow_mut().new_var();
        let (c1_idx, c1) = builder.borrow_mut().new_var();
        let fp2 = Fp2 {
            c0: FieldVariable::from_var(builder.clone(), c0),
            c1: FieldVariable::from_var(builder.clone(), c1),
        };
        ((c0_idx, c1_idx), fp2)
    }

    /// Saves both coordinates via `FieldVariable::save`, `c0` first.
    ///
    /// Returns the indices reported for them as `[c0_idx, c1_idx]`.
    pub fn save(&mut self) -> [usize; 2] {
        let c0_idx = self.c0.save();
        let c1_idx = self.c1.save();
        [c0_idx, c1_idx]
    }

    /// Calls `FieldVariable::save_output` on `c0` and then on `c1`.
    pub fn save_output(&mut self) {
        self.c0.save_output();
        self.c1.save_output();
    }

    // NOTE(review): all arithmetic below applies operators to `&mut FieldVariable` operands,
    // and `div` explicitly `save`s its arguments. This suggests an operation may rewrite
    // (e.g. save) an operand and register new builder variables as a side effect; confirm in
    // `FieldVariable`. If so, the left-to-right order in which terms are written here
    // determines variable allocation order and must not be rearranged.

    /// Coordinate-wise sum: `(self.c0 + other.c0) + (self.c1 + other.c1)·u`.
    pub fn add(&mut self, other: &mut Fp2) -> Fp2 {
        Fp2 {
            c0: &mut self.c0 + &mut other.c0,
            c1: &mut self.c1 + &mut other.c1,
        }
    }

    /// Coordinate-wise difference: `(self.c0 - other.c0) + (self.c1 - other.c1)·u`.
    pub fn sub(&mut self, other: &mut Fp2) -> Fp2 {
        Fp2 {
            c0: &mut self.c0 - &mut other.c0,
            c1: &mut self.c1 - &mut other.c1,
        }
    }

    /// Product using `u^2 = -1`:
    /// `(a0 + a1·u)(b0 + b1·u) = (a0·b0 - a1·b1) + (a0·b1 + a1·b0)·u`.
    pub fn mul(&mut self, other: &mut Fp2) -> Fp2 {
        let c0 = &mut self.c0 * &mut other.c0 - &mut self.c1 * &mut other.c1;
        let c1 = &mut self.c0 * &mut other.c1 + &mut self.c1 * &mut other.c0;
        Fp2 { c0, c1 }
    }

    /// Square using `u^2 = -1`: `(a0 + a1·u)^2 = (a0^2 - a1^2) + 2·a0·a1·u`.
    pub fn square(&mut self) -> Fp2 {
        let c0 = self.c0.square() - self.c1.square();
        let c1 = (&mut self.c0 * &mut self.c1).int_mul(2);
        Fp2 { c0, c1 }
    }

    /// Returns `z = self / other` as two new builder variables `(z0, z1)`.
    ///
    /// Write `x = self` and `y = other`. The witness is computed as
    /// `z0 = (x0·y0 + x1·y1) / (y0^2 + y1^2)` and
    /// `z1 = (x1·y0 - x0·y1) / (y0^2 + y1^2)`, i.e. `x·conj(y) / |y|^2`.
    ///
    /// It is constrained by `x0 - y0·z0 + y1·z1 = 0` and `x1 - y1·z0 - y0·z1 = 0`. Together
    /// these say `x = y·z` in `Fp[u]/(u^2 + 1)`.
    ///
    /// Before `z` is allocated, the carry bits of each constraint are estimated. If they exceed
    /// `self.c0.max_carry_bits`, `self` is saved. The constraint is then rebuilt and re-checked,
    /// and `other` is saved if the bound is still exceeded. Both arguments may therefore be
    /// mutated.
    ///
    /// The returned coordinates are already builder variables with compute expressions and
    /// constraints attached, so callers need not `save` them.
    ///
    /// NOTE(review): nothing here guards against `other == 0`, which makes the computed
    /// denominator zero. How the builder handles that is not visible here (TODO confirm).
    pub fn div(&mut self, other: &mut Fp2) -> Fp2 {
        // Copy out the parameters needed for carry-bit estimation. The immutable borrow is
        // released before the builder is borrowed mutably below (and possibly inside `save`).
        let builder = self.c0.builder.borrow();
        let prime = builder.prime.clone();
        let limb_bits = builder.limb_bits;
        let num_limbs = builder.num_limbs;
        let proper_max = builder.proper_max().clone();
        drop(builder);

        // Placeholders for z0/z1, used only to estimate constraint carry bits before the real
        // variables are allocated. Presumably only their bounds matter, not their indices
        // (TODO confirm).
        let fake_z0 = SymbolicExpr::Var(0);
        let fake_z1 = SymbolicExpr::Var(1);

        // The compute expressions are built from the operands' current expressions, before any
        // `save` below can replace them. Presumably this keeps the witness computation the
        // same whether or not a save is triggered.
        let compute_denom = &other.c0.expr * &other.c0.expr + &other.c1.expr * &other.c1.expr;
        let compute_z0_nom = &self.c0.expr * &other.c0.expr + &self.c1.expr * &other.c1.expr;
        let compute_z0 = &compute_z0_nom / &compute_denom;
        let compute_z1_nom = &self.c1.expr * &other.c0.expr - &self.c0.expr * &other.c1.expr;
        let compute_z1 = &compute_z1_nom / &compute_denom;

        // Constraint 1 (x0 = y0·z0 - y1·z1). If its carry bits are too large, save `self`
        // first. Then rebuild it from the possibly-saved `self` and re-check to decide whether
        // `other` must be saved too.
        let constraint1 = &self.c0.expr - &other.c0.expr * &fake_z0 + &other.c1.expr * &fake_z1;
        let carry_bits =
            constraint1.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
        if carry_bits > self.c0.max_carry_bits {
            self.save();
        }
        let constraint1 = &self.c0.expr - &other.c0.expr * &fake_z0 + &other.c1.expr * &fake_z1;
        let carry_bits =
            constraint1.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
        if carry_bits > self.c0.max_carry_bits {
            other.save();
        }

        // Constraint 2 (x1 = y1·z0 + y0·z1), with the same two-step check.
        let constraint2 = &self.c1.expr - &other.c1.expr * &fake_z0 - &other.c0.expr * &fake_z1;
        let carry_bits =
            constraint2.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
        if carry_bits > self.c0.max_carry_bits {
            self.save();
        }
        let constraint2 = &self.c1.expr - &other.c1.expr * &fake_z0 - &other.c0.expr * &fake_z1;
        let carry_bits =
            constraint2.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
        if carry_bits > self.c0.max_carry_bits {
            other.save();
        }

        // Allocate z0/z1 and attach the compute expressions (captured above) and the
        // constraints, rebuilt against the real variables and the possibly-saved operands.
        let mut builder = self.c0.builder.borrow_mut();
        let (z0_idx, z0) = builder.new_var();
        let (z1_idx, z1) = builder.new_var();
        let constraint1 = &self.c0.expr - &other.c0.expr * &z0 + &other.c1.expr * &z1;
        let constraint2 = &self.c1.expr - &other.c1.expr * &z0 - &other.c0.expr * &z1;
        builder.set_compute(z0_idx, compute_z0);
        builder.set_compute(z1_idx, compute_z1);
        builder.set_constraint(z0_idx, constraint1);
        builder.set_constraint(z1_idx, constraint2);
        drop(builder);

        let z0_var = FieldVariable::from_var(self.c0.builder.clone(), z0);
        let z1_var = FieldVariable::from_var(self.c0.builder.clone(), z1);
        Fp2 {
            c0: z0_var,
            c1: z1_var,
        }
    }

    /// Multiplies both coordinates by the base-field element `fp`.
    pub fn scalar_mul(&mut self, fp: &mut FieldVariable) -> Fp2 {
        Fp2 {
            c0: &mut self.c0 * fp,
            c1: &mut self.c1 * fp,
        }
    }

    /// Adds the integer constant `c[0] + c[1]·u`, i.e. `c[0]` to `c0` and `c[1]` to `c1`.
    pub fn int_add(&mut self, c: [isize; 2]) -> Fp2 {
        Fp2 {
            c0: self.c0.int_add(c[0]),
            c1: self.c1.int_add(c[1]),
        }
    }

    /// Multiplies by the integer constant `k0 + k1·u` (with `c = [k0, k1]`), using `u^2 = -1`:
    /// `(a0 + a1·u)(k0 + k1·u) = (a0·k0 - a1·k1) + (a0·k1 + a1·k0)·u`.
    pub fn int_mul(&mut self, c: [isize; 2]) -> Fp2 {
        Fp2 {
            c0: self.c0.int_mul(c[0]) - self.c1.int_mul(c[1]),
            c1: self.c0.int_mul(c[1]) + self.c1.int_mul(c[0]),
        }
    }

    /// Returns `-self`, implemented as `int_mul([-1, 0])`.
    ///
    /// NOTE(review): this builds zero-scaled terms (`c1·0` and `c0·0`) into the resulting
    /// expressions. Negating each coordinate directly would give smaller expressions, but
    /// would also change the generated constraints.
    pub fn neg(&mut self) -> Fp2 {
        self.int_mul([-1, 0])
    }

    /// Coordinate-wise `FieldVariable::select` between `a` and `b`, controlled by the builder
    /// flag `flag_id`.
    ///
    /// Which operand is chosen for which flag value is defined by `FieldVariable::select`.
    /// Presumably `a` is chosen when the flag is set (TODO confirm).
    pub fn select(flag_id: usize, a: &Fp2, b: &Fp2) -> Fp2 {
        Fp2 {
            c0: FieldVariable::select(flag_id, &a.c0, &b.c0),
            c1: FieldVariable::select(flag_id, &a.c1, &b.c1),
        }
    }
}
179
#[cfg(test)]
mod tests {
    use halo2curves_axiom::bn256::Fq2;
    use num_bigint::BigUint;
    use openvm_circuit_primitives::TraceSubRowGenerator;
    use openvm_mod_circuit_builder::{test_utils::*, FieldExpr, FieldExprCols};
    use openvm_pairing_guest::bn254::BN254_MODULUS;
    use openvm_stark_backend::{
        p3_air::BaseAir, p3_field::FieldAlgebra, p3_matrix::dense::RowMajorMatrix,
    };
    use openvm_stark_sdk::{
        any_rap_arc_vec, config::baby_bear_blake3::BabyBearBlake3Engine, engine::StarkFriEngine,
        p3_baby_bear::BabyBear,
    };

    use super::Fp2;

    /// Flattens concrete operands into builder input order: `c0` then `c1` of each operand.
    fn fq2_inputs(operands: &[Fq2]) -> Vec<BigUint> {
        operands
            .iter()
            .flat_map(|v| [bn254_fq_to_biguint(v.c0), bn254_fq_to_biguint(v.c1)])
            .collect()
    }

    /// End-to-end check of one symbolic `Fp2` computation:
    /// - creates one `Fp2` input per entry of `operands` (in order);
    /// - lets `build` combine them, and saves the result when `save_result` is set;
    /// - fills a single trace row from `operands`;
    /// - asserts that the row holds exactly two variables equal to `expected`;
    /// - runs the STARK verifier over the row and the range-checker trace.
    fn check_fp2_circuit(
        build: impl FnOnce(&mut [Fp2]) -> Fp2,
        operands: &[Fq2],
        expected: Fq2,
        save_result: bool,
    ) {
        let prime = BN254_MODULUS.clone();
        let (range_checker, builder) = setup(&prime);

        // Symbolic inputs are created in operand order so they line up with `fq2_inputs`.
        let mut symbolic: Vec<Fp2> = operands
            .iter()
            .map(|_| Fp2::new(builder.clone()))
            .collect();
        let mut result = build(symbolic.as_mut_slice());
        if save_result {
            result.save();
        }

        let builder = builder.borrow().clone();
        let air = FieldExpr::new(builder, range_checker.bus(), false);
        let width = BaseAir::<BabyBear>::width(&air);

        let mut row = BabyBear::zero_vec(width);
        air.generate_subrow((&range_checker, fq2_inputs(operands), vec![]), &mut row);
        let FieldExprCols { vars, .. } = air.load_vars(&row);
        let trace = RowMajorMatrix::new(row, width);
        let range_trace = range_checker.generate_trace();

        assert_eq!(vars.len(), 2);
        assert_eq!(
            evaluate_biguint(&vars[0], LIMB_BITS),
            bn254_fq_to_biguint(expected.c0)
        );
        assert_eq!(
            evaluate_biguint(&vars[1], LIMB_BITS),
            bn254_fq_to_biguint(expected.c1)
        );

        BabyBearBlake3Engine::run_simple_test_no_pis_fast(
            any_rap_arc_vec![air, range_checker.air],
            vec![trace, range_trace],
        )
        .expect("Verification failed");
    }

    /// Checks a binary `Fp2` operation against its `Fq2` reference on fixed random operands.
    fn test_fp2(
        fp2_fn: impl Fn(&mut Fp2, &mut Fp2) -> Fp2,
        fq2_fn: impl Fn(&Fq2, &Fq2) -> Fq2,
        save_result: bool,
    ) {
        let x = bn254_fq2_random(1);
        let y = bn254_fq2_random(5);
        check_fp2_circuit(
            |v| {
                let (x_var, y_var) = v.split_at_mut(1);
                fp2_fn(&mut x_var[0], &mut y_var[0])
            },
            &[x, y],
            fq2_fn(&x, &y),
            save_result,
        );
    }

    #[test]
    fn test_fp2_add() {
        test_fp2(Fp2::add, |x, y| x + y, true);
    }

    #[test]
    fn test_fp2_sub() {
        test_fp2(Fp2::sub, |x, y| x - y, true);
    }

    #[test]
    fn test_fp2_mul() {
        test_fp2(Fp2::mul, |x, y| x * y, true);
    }

    #[test]
    fn test_fp2_div() {
        test_fp2(Fp2::div, |x, y| x * y.invert().unwrap(), false);
    }

    /// Division whose numerator is itself an unsaved product: `(x * y) / z`.
    #[test]
    fn test_fp2_div2() {
        let x = bn254_fq2_random(5);
        let y = bn254_fq2_random(15);
        let z = bn254_fq2_random(95);
        let expected = z.invert().unwrap() * x * y;
        check_fp2_circuit(
            |v| {
                let (xy_vars, z_var) = v.split_at_mut(2);
                let (x_var, y_var) = xy_vars.split_at_mut(1);
                let mut xy = x_var[0].mul(&mut y_var[0]);
                xy.div(&mut z_var[0])
            },
            &[x, y, z],
            expected,
            false,
        );
    }
}