1use std::{cell::RefCell, rc::Rc};
2
3use openvm_mod_circuit_builder::{ExprBuilder, FieldVariable, SymbolicExpr};
4
/// An element of a quadratic extension of the base field, represented as `c0 + c1 * u`,
/// where both coefficients are base-field circuit variables.
///
/// The arithmetic implemented for this type (see `mul`) uses `u^2 = -1`, i.e. it models
/// `Fp[u] / (u^2 + 1)`. NOTE(review): this is only a field when `-1` is a quadratic
/// non-residue modulo the prime; the tests in this file check it against BN254's `Fq2`.
/// Confirm before using other moduli.
#[derive(Clone)]
pub struct Fp2 {
    /// Coefficient of `1`.
    pub c0: FieldVariable,
    /// Coefficient of `u`.
    pub c1: FieldVariable,
}
12
// Symbolic arithmetic on `Fp2` values. Each operation builds expressions through the shared
// `ExprBuilder`. Methods take their operands by `&mut` because the underlying
// `FieldVariable` operations (and `save`) may modify them, so the order in which operations
// are issued presumably affects which builder variables get allocated.
impl Fp2 {
    /// Creates an `Fp2` whose two coefficients are fresh circuit inputs.
    ///
    /// The inputs are registered with `builder` in the order `c0`, then `c1`. The tests in
    /// this file supply input values in the matching order `[x.c0, x.c1, y.c0, y.c1, ...]`.
    pub fn new(builder: Rc<RefCell<ExprBuilder>>) -> Self {
        let c0 = ExprBuilder::new_input(builder.clone());
        let c1 = ExprBuilder::new_input(builder.clone());
        Fp2 { c0, c1 }
    }

    /// Allocates two new builder variables (for `c0`, then `c1`) and wraps them as an `Fp2`.
    ///
    /// Returns `((c0_idx, c1_idx), fp2)`, where the indices are the ones reported by
    /// `ExprBuilder::new_var`.
    ///
    /// NOTE(review): no compute or constraint is set for the new variables here. Callers
    /// presumably must register them, as `div` does with `set_compute` / `set_constraint`.
    pub fn new_var(builder: Rc<RefCell<ExprBuilder>>) -> ((usize, usize), Fp2) {
        // Each `borrow_mut()` lives only for its own statement, so the two mutable borrows
        // never overlap.
        let (c0_idx, c0) = builder.borrow_mut().new_var();
        let (c1_idx, c1) = builder.borrow_mut().new_var();
        let fp2 = Fp2 {
            c0: FieldVariable::from_var(builder.clone(), c0),
            c1: FieldVariable::from_var(builder.clone(), c1),
        };
        ((c0_idx, c1_idx), fp2)
    }

    /// Calls `FieldVariable::save` on `c0`, then on `c1`.
    ///
    /// Returns the two resulting variable indices as `[c0_idx, c1_idx]`.
    pub fn save(&mut self) -> [usize; 2] {
        let c0_idx = self.c0.save();
        let c1_idx = self.c1.save();
        [c0_idx, c1_idx]
    }

    /// Calls `FieldVariable::save_output` on `c0`, then on `c1`.
    pub fn save_output(&mut self) {
        self.c0.save_output();
        self.c1.save_output();
    }

    /// Returns `self + other`, computed coefficient-wise.
    pub fn add(&mut self, other: &mut Fp2) -> Fp2 {
        Fp2 {
            c0: &mut self.c0 + &mut other.c0,
            c1: &mut self.c1 + &mut other.c1,
        }
    }

    /// Returns `self - other`, computed coefficient-wise.
    pub fn sub(&mut self, other: &mut Fp2) -> Fp2 {
        Fp2 {
            c0: &mut self.c0 - &mut other.c0,
            c1: &mut self.c1 - &mut other.c1,
        }
    }

    /// Returns `self * other`.
    ///
    /// With `self = a0 + a1*u` and `other = b0 + b1*u`, the result is
    /// `(a0*b0 - a1*b1) + (a0*b1 + a1*b0)*u`, which encodes the reduction rule `u^2 = -1`.
    pub fn mul(&mut self, other: &mut Fp2) -> Fp2 {
        let c0 = &mut self.c0 * &mut other.c0 - &mut self.c1 * &mut other.c1;
        let c1 = &mut self.c0 * &mut other.c1 + &mut self.c1 * &mut other.c0;
        Fp2 { c0, c1 }
    }

    /// Returns `self^2 = (c0^2 - c1^2) + (2*c0*c1)*u`, using `u^2 = -1` as in `mul`.
    pub fn square(&mut self) -> Fp2 {
        let c0 = self.c0.square() - self.c1.square();
        let c1 = (&mut self.c0 * &mut self.c1).int_mul(2);
        Fp2 { c0, c1 }
    }

    /// Returns `self / other` as two newly allocated, constrained builder variables.
    ///
    /// Write `self = a0 + a1*u`, `other = b0 + b1*u` and the result as `z = z0 + z1*u`.
    /// The witness is computed as:
    ///
    /// ```text
    /// z0 = (a0*b0 + a1*b1) / (b0^2 + b1^2)
    /// z1 = (a1*b0 - a0*b1) / (b0^2 + b1^2)
    /// ```
    ///
    /// Correctness (`self == other * z`) is enforced by the two constraints:
    ///
    /// ```text
    /// a0 - b0*z0 + b1*z1 = 0
    /// a1 - b1*z0 - b0*z1 = 0
    /// ```
    ///
    /// Before the real constraints are built, their carry bits are estimated using placeholder
    /// variables. When an estimate exceeds `self.c0.max_carry_bits`, `self` and/or `other` are
    /// saved, so both operands may be modified by this call.
    ///
    /// NOTE(review): nothing in this method guards against `other == 0`; the compute step
    /// divides by `b0^2 + b1^2` unconditionally.
    pub fn div(&mut self, other: &mut Fp2) -> Fp2 {
        // Copy the parameters needed for carry-bit estimation out of the builder. The shared
        // borrow is dropped before any `save()` and before the `borrow_mut()` further down.
        let builder = self.c0.builder.borrow();
        let prime = builder.prime.clone();
        let limb_bits = builder.limb_bits;
        let num_limbs = builder.num_limbs;
        let proper_max = builder.proper_max().clone();
        drop(builder);

        // Placeholder stand-ins for z0/z1, used only to estimate the constraints' carry bits.
        // They are never registered with the builder. Presumably the particular indices do not
        // affect the estimate — TODO confirm against `constraint_carry_bits_with_pq`.
        let fake_z0 = SymbolicExpr::Var(0);
        let fake_z1 = SymbolicExpr::Var(1);

        // Build the compute (witness) expressions first, from the operands' current
        // expressions. This fixes them before any of the `save()` calls below can modify
        // `self` or `other`.
        let compute_denom = &other.c0.expr * &other.c0.expr + &other.c1.expr * &other.c1.expr;
        let compute_z0_nom = &self.c0.expr * &other.c0.expr + &self.c1.expr * &other.c1.expr;
        let compute_z0 = &compute_z0_nom / &compute_denom;
        let compute_z1_nom = &self.c1.expr * &other.c0.expr - &self.c0.expr * &other.c1.expr;
        let compute_z1 = &compute_z1_nom / &compute_denom;

        // Constraint 1: a0 - b0*z0 + b1*z1 = 0.
        // If its estimated carry bits are too large, save `self`. Then rebuild the expression
        // (with the possibly saved `self`), re-estimate, and save `other` if it is still too
        // large.
        let constraint1 = &self.c0.expr - &other.c0.expr * &fake_z0 + &other.c1.expr * &fake_z1;
        let carry_bits =
            constraint1.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
        if carry_bits > self.c0.max_carry_bits {
            self.save();
        }
        let constraint1 = &self.c0.expr - &other.c0.expr * &fake_z0 + &other.c1.expr * &fake_z1;
        let carry_bits =
            constraint1.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
        if carry_bits > self.c0.max_carry_bits {
            other.save();
        }

        // Constraint 2: a1 - b1*z0 - b0*z1 = 0, using the same save-self-then-other strategy.
        let constraint2 = &self.c1.expr - &other.c1.expr * &fake_z0 - &other.c0.expr * &fake_z1;
        let carry_bits =
            constraint2.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
        if carry_bits > self.c0.max_carry_bits {
            self.save();
        }
        let constraint2 = &self.c1.expr - &other.c1.expr * &fake_z0 - &other.c0.expr * &fake_z1;
        let carry_bits =
            constraint2.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
        if carry_bits > self.c0.max_carry_bits {
            other.save();
        }

        // Allocate the real z0/z1 variables and register their compute and constraint
        // expressions. The constraints are rebuilt here so that they use the operands'
        // expressions as they stand after any saving above.
        let mut builder = self.c0.builder.borrow_mut();
        let (z0_idx, z0) = builder.new_var();
        let (z1_idx, z1) = builder.new_var();
        let constraint1 = &self.c0.expr - &other.c0.expr * &z0 + &other.c1.expr * &z1;
        let constraint2 = &self.c1.expr - &other.c1.expr * &z0 - &other.c0.expr * &z1;
        builder.set_compute(z0_idx, compute_z0);
        builder.set_compute(z1_idx, compute_z1);
        builder.set_constraint(z0_idx, constraint1);
        builder.set_constraint(z1_idx, constraint2);
        // Release the mutable borrow before wrapping the new variables.
        drop(builder);

        let z0_var = FieldVariable::from_var(self.c0.builder.clone(), z0);
        let z1_var = FieldVariable::from_var(self.c0.builder.clone(), z1);
        Fp2 {
            c0: z0_var,
            c1: z1_var,
        }
    }

    /// Returns `self * fp` for a base-field variable `fp`, scaling each coefficient.
    pub fn scalar_mul(&mut self, fp: &mut FieldVariable) -> Fp2 {
        Fp2 {
            c0: &mut self.c0 * fp,
            c1: &mut self.c1 * fp,
        }
    }

    /// Adds the integer constant `c[0] + c[1]*u`, i.e. `c[0]` to `c0` and `c[1]` to `c1`.
    pub fn int_add(&mut self, c: [isize; 2]) -> Fp2 {
        Fp2 {
            c0: self.c0.int_add(c[0]),
            c1: self.c1.int_add(c[1]),
        }
    }

    /// Multiplies by the integer constant `k = c[0] + c[1]*u`.
    ///
    /// Uses `(c0 + c1*u) * (k0 + k1*u) = (c0*k0 - c1*k1) + (c0*k1 + c1*k0)*u`, which relies
    /// on `u^2 = -1`, the same convention as `mul`.
    pub fn int_mul(&mut self, c: [isize; 2]) -> Fp2 {
        Fp2 {
            c0: self.c0.int_mul(c[0]) - self.c1.int_mul(c[1]),
            c1: self.c0.int_mul(c[1]) + self.c1.int_mul(c[0]),
        }
    }

    /// Returns `-self`, computed as multiplication by the constant `-1 + 0*u`.
    pub fn neg(&mut self) -> Fp2 {
        self.int_mul([-1, 0])
    }

    /// Coefficient-wise `FieldVariable::select` between `a` and `b`, controlled by the flag
    /// identified by `flag_id`.
    ///
    /// NOTE(review): which of `a`/`b` is chosen for which flag value is decided by
    /// `FieldVariable::select`; it is not visible here — confirm before relying on it.
    pub fn select(flag_id: usize, a: &Fp2, b: &Fp2) -> Fp2 {
        Fp2 {
            c0: FieldVariable::select(flag_id, &a.c0, &b.c0),
            c1: FieldVariable::select(flag_id, &a.c1, &b.c1),
        }
    }
}
177
#[cfg(test)]
mod tests {
    use halo2curves_axiom::bn256::Fq2;
    use num_bigint::BigUint;
    use openvm_circuit_primitives::TraceSubRowGenerator;
    use openvm_mod_circuit_builder::{test_utils::*, FieldExpr, FieldExprCols};
    use openvm_pairing_guest::bn254::BN254_MODULUS;
    use openvm_stark_backend::{
        p3_air::BaseAir, p3_field::FieldAlgebra, p3_matrix::dense::RowMajorMatrix,
    };
    use openvm_stark_sdk::{
        any_rap_arc_vec, config::baby_bear_blake3::BabyBearBlake3Engine, engine::StarkFriEngine,
        p3_baby_bear::BabyBear,
    };

    use super::Fp2;

    /// Flattens `Fq2` values into the circuit's input layout `[v0.c0, v0.c1, v1.c0, v1.c1, ...]`.
    ///
    /// This matches the order in which `Fp2::new` registers its two inputs.
    fn fq2_inputs(values: &[Fq2]) -> Vec<BigUint> {
        values
            .iter()
            .flat_map(|v| [bn254_fq_to_biguint(v.c0), bn254_fq_to_biguint(v.c1)])
            .collect()
    }

    /// Shared driver for the `Fp2` circuit tests.
    ///
    /// Steps:
    /// 1. Sets up a BN254 expression builder and creates `num_inputs` fresh `Fp2` inputs.
    /// 2. Lets `build` wire up the circuit on those inputs.
    /// 3. Fills one trace row from `inputs` and checks that the circuit's two variables equal
    ///    `expected`.
    /// 4. Proves and verifies the row together with the range checker.
    fn run_fp2_circuit(
        num_inputs: usize,
        build: impl FnOnce(&mut [Fp2]),
        inputs: Vec<BigUint>,
        expected: Fq2,
    ) {
        let prime = BN254_MODULUS.clone();
        let (range_checker, builder) = setup(&prime);

        // Inputs are created in order, so the `inputs` vector must follow the same order.
        let mut operands: Vec<Fp2> = (0..num_inputs).map(|_| Fp2::new(builder.clone())).collect();
        build(operands.as_mut_slice());

        let air = FieldExpr::new(builder.borrow().clone(), range_checker.bus(), false);
        let width = BaseAir::<BabyBear>::width(&air);

        let mut row = BabyBear::zero_vec(width);
        air.generate_subrow((&range_checker, inputs, vec![]), &mut row);
        let FieldExprCols { vars, .. } = air.load_vars(&row);
        assert_eq!(vars.len(), 2);
        assert_eq!(
            evaluate_biguint(&vars[0], LIMB_BITS),
            bn254_fq_to_biguint(expected.c0)
        );
        assert_eq!(
            evaluate_biguint(&vars[1], LIMB_BITS),
            bn254_fq_to_biguint(expected.c1)
        );

        let trace = RowMajorMatrix::new(row, width);
        let range_trace = range_checker.generate_trace();
        BabyBearBlake3Engine::run_simple_test_no_pis_fast(
            any_rap_arc_vec![air, range_checker.air],
            vec![trace, range_trace],
        )
        .expect("Verification failed");
    }

    /// Checks a binary `Fp2` gadget against its `Fq2` reference on fixed pseudo-random inputs.
    ///
    /// `save_result` must be true when the gadget's output is an unsaved expression, so that
    /// the circuit ends up with exactly two variables.
    fn test_fp2(
        fp2_fn: impl Fn(&mut Fp2, &mut Fp2) -> Fp2,
        fq2_fn: impl Fn(&Fq2, &Fq2) -> Fq2,
        save_result: bool,
    ) {
        let x = bn254_fq2_random(1);
        let y = bn254_fq2_random(5);
        run_fp2_circuit(
            2,
            |operands| {
                let [lhs, rhs] = operands else {
                    unreachable!()
                };
                let mut result = fp2_fn(lhs, rhs);
                if save_result {
                    result.save();
                }
            },
            fq2_inputs(&[x, y]),
            fq2_fn(&x, &y),
        );
    }

    #[test]
    fn test_fp2_add() {
        test_fp2(Fp2::add, |x, y| x + y, true);
    }

    #[test]
    fn test_fp2_sub() {
        test_fp2(Fp2::sub, |x, y| x - y, true);
    }

    #[test]
    fn test_fp2_mul() {
        test_fp2(Fp2::mul, |x, y| x * y, true);
    }

    #[test]
    fn test_fp2_div() {
        // `div` already produces saved variables, so the result is not saved again.
        test_fp2(Fp2::div, |x, y| x * y.invert().unwrap(), false);
    }

    #[test]
    fn test_fp2_div2() {
        // Divides an unsaved product by a third input: (x * y) / z.
        let x = bn254_fq2_random(5);
        let y = bn254_fq2_random(15);
        let z = bn254_fq2_random(95);
        run_fp2_circuit(
            3,
            |operands| {
                let [x_var, y_var, z_var] = operands else {
                    unreachable!()
                };
                let mut xy = x_var.mul(y_var);
                let _quotient = xy.div(z_var);
            },
            fq2_inputs(&[x, y, z]),
            z.invert().unwrap() * x * y,
        );
    }
}