openvm_pairing_guest/halo2curves_shims/bls12_381/final_exp.rs

1use halo2curves_axiom::bls12_381::{Fq, Fq12, Fq2};
2use num_bigint::BigUint;
3use openvm_ecc_guest::{
4    algebra::{ExpBytes, Field},
5    AffinePoint,
6};
7
8use super::{Bls12_381, FINAL_EXP_FACTOR, LAMBDA, POLY_FACTOR};
9use crate::pairing::{FinalExp, MultiMillerLoop};
10
11// The paper only describes the implementation for Bn254, so we use the gnark implementation for
12// Bls12_381.
13#[allow(non_snake_case)]
14impl FinalExp for Bls12_381 {
15    type Fp = Fq;
16    type Fp2 = Fq2;
17    type Fp12 = Fq12;
18
19    // Adapted from the gnark implementation:
20    // https://github.com/Consensys/gnark/blob/af754dd1c47a92be375930ae1abfbd134c5310d8/std/algebra/emulated/fields_bls12381/e12_pairing.go#L394C1-L395C1
21    fn assert_final_exp_is_one(
22        f: &Self::Fp12,
23        P: &[AffinePoint<Self::Fp>],
24        Q: &[AffinePoint<Self::Fp2>],
25    ) {
26        let (c, s) = Self::final_exp_hint(f);
27
28        // The gnark implementation checks that f * s = c^{q - x} where x is the curve seed.
29        // We check an equivalent condition: f * c^x * c^-q * s = 1.
30        // This is because we can compute f * c^x by embedding the c^x computation in the miller
31        // loop.
32
33        // Since the Bls12_381 curve has a negative seed, the miller loop for Bls12_381 is computed
34        // as f_{Miller,x,Q}(P) = conjugate( f_{Miller,-x,Q}(P) * c^{-x} ).
35        // We will pass in the conjugate inverse of c into the miller loop so that we compute
36        // fc = f_{Miller,x,Q}(P)
37        //    = conjugate( f_{Miller,-x,Q}(P) * c'^{-x} )  (where c' is the conjugate inverse of c)
38        //    = f_{Miller,x,Q}(P) * c^x
39        let c_conj_inv = c.conjugate().invert().unwrap();
40        let c_inv = c.invert().unwrap();
41        let c_q_inv = c_inv.frobenius_map();
42        let fc = Self::multi_miller_loop_embedded_exp(P, Q, Some(c_conj_inv));
43
44        assert_eq!(fc * c_q_inv * s, Fq12::ONE);
45    }
46
47    // Adapted from the gnark implementation:
48    // https://github.com/Consensys/gnark/blob/af754dd1c47a92be375930ae1abfbd134c5310d8/std/algebra/emulated/fields_bls12381/hints.go#L273
49    // returns c (residueWitness) and s (scalingFactor)
50    // The Gnark implementation is based on https://eprint.iacr.org/2024/640.pdf
51    fn final_exp_hint(f: &Self::Fp12) -> (Self::Fp12, Self::Fp12) {
52        // 1. get p-th root inverse
53        let mut exp = FINAL_EXP_FACTOR.clone() * BigUint::from(27u32);
54        let mut root = f.exp_bytes(true, &exp.to_bytes_be());
55        let root_pth_inv: Fq12;
56        if root == Fq12::ONE {
57            root_pth_inv = Fq12::ONE;
58        } else {
59            let exp_inv = exp.modinv(&POLY_FACTOR.clone()).unwrap();
60            exp = exp_inv % POLY_FACTOR.clone();
61            root_pth_inv = root.exp_bytes(false, &exp.to_bytes_be());
62        }
63
64        // 2.1. get order of 3rd primitive root
65        let three = BigUint::from(3u32);
66        let mut order_3rd_power: u32 = 0;
67        exp = POLY_FACTOR.clone() * FINAL_EXP_FACTOR.clone();
68
69        root = f.exp_bytes(true, &exp.to_bytes_be());
70        let three_be = three.to_bytes_be();
71        // NOTE[yj]: we can probably remove this first check as an optimization since we initizlize
72        // order_3rd_power to 0
73        if root == Fq12::ONE {
74            order_3rd_power = 0;
75        }
76        root = root.exp_bytes(true, &three_be);
77        if root == Fq12::ONE {
78            order_3rd_power = 1;
79        }
80        root = root.exp_bytes(true, &three_be);
81        if root == Fq12::ONE {
82            order_3rd_power = 2;
83        }
84        root = root.exp_bytes(true, &three_be);
85        if root == Fq12::ONE {
86            order_3rd_power = 3;
87        }
88
89        // 2.2. get 27th root inverse
90        let root_27th_inv: Fq12;
91        if order_3rd_power == 0 {
92            root_27th_inv = Fq12::ONE;
93        } else {
94            let order_3rd = three.pow(order_3rd_power);
95            exp = POLY_FACTOR.clone() * FINAL_EXP_FACTOR.clone();
96            root = f.exp_bytes(true, &exp.to_bytes_be());
97            let exp_inv = exp.modinv(&order_3rd).unwrap();
98            exp = exp_inv % order_3rd;
99            root_27th_inv = root.exp_bytes(false, &exp.to_bytes_be());
100        }
101
102        // 2.3. shift the Miller loop result so that millerLoop * scalingFactor
103        // is of order finalExpFactor
104        let s = root_pth_inv * root_27th_inv;
105        let f = f * s;
106
107        // 3. get the witness residue
108        // lambda = q - u, the optimal exponent
109        exp = LAMBDA.clone().modinv(&FINAL_EXP_FACTOR.clone()).unwrap();
110        let c = f.exp_bytes(true, &exp.to_bytes_be());
111
112        (c, s)
113    }
114}