// openvm_pairing_guest/halo2curves_shims/bls12_381/final_exp.rs
use halo2curves_axiom::bls12_381::{Fq, Fq12, Fq2};
use num_bigint::BigUint;
use openvm_ecc_guest::{
    algebra::{ExpBytes, Field},
    AffinePoint,
};

use super::{Bls12_381, FINAL_EXP_FACTOR, LAMBDA, POLY_FACTOR};
use crate::pairing::{FinalExp, MultiMillerLoop};

// `P` and `Q` follow the usual pairing notation for the input point lists, hence the
// non-snake-case parameter names.
#[allow(non_snake_case)]
impl FinalExp for Bls12_381 {
    type Fp = Fq;
    type Fp2 = Fq2;
    type Fp12 = Fq12;

    /// Asserts that the final exponentiation of the Miller loop output `f` equals one, without
    /// computing the final exponentiation directly.
    ///
    /// A hint `(c, s)` is derived from `f` via `final_exp_hint`. The Miller loop over `P`, `Q` is
    /// then recomputed with `c` embedded, and `fc * c^{-q} * s == 1` is checked. `f` itself is
    /// only used to derive the hint; the equality is checked against the recomputed Miller loop.
    ///
    /// # Panics
    /// Panics if the check fails or if the residue witness `c` is not invertible.
    fn assert_final_exp_is_one(
        f: &Self::Fp12,
        P: &[AffinePoint<Self::Fp>],
        Q: &[AffinePoint<Self::Fp2>],
    ) {
        let (c, s) = Self::final_exp_hint(f);

        // The hint satisfies f * s = c^{q - x}, where x = -0xd201000000010000 is the BLS12-381
        // curve seed. Rearranging:
        //   f * s = c^q * c^{-x}
        //   f * c^x * c^{-q} * s = 1
        // where fc = f * c^x is produced by the Miller loop with an embedded exponentiation.
        // Because the seed is negative, the Miller loop computation ends with a conjugation, so
        // the value fed into it must be the conjugate c' of c. Substituting y = -x gives
        //   f * c'^{-y} * c^{-q} * s = 1,
        // which is why conj(c)^{-1} is passed to the embedded Miller loop below.
        let c_inv = c.invert().unwrap();
        // Conjugation is a field automorphism of Fq12, so conj(c)^{-1} == conj(c^{-1}).
        // Reusing c^{-1} avoids a second Fq12 inversion.
        let c_conj_inv = c_inv.conjugate();
        // c^{-q}, obtained by applying the Frobenius map to c^{-1}.
        let c_q_inv = c_inv.frobenius_map();

        // fc = f_{Miller,x,Q}(P) * c^{x}
        // where
        //   fc = conjugate( f_{Miller,-x,Q}(P) * c'^{-x} ), with c' denoting the conjugate of c
        let fc = Self::multi_miller_loop_embedded_exp(P, Q, Some(c_conj_inv));

        assert_eq!(fc * c_q_inv * s, Fq12::ONE);
    }

    /// Computes the final-exponentiation hint for the Miller loop output `f`.
    ///
    /// Returns `(c, s)`: the residue witness `c` and the scaling factor `s`, chosen so that
    /// `f * s = c^λ` with `λ = q - x` the optimal exponent (`LAMBDA`).
    ///
    /// Adapted from the gnark implementation:
    /// https://github.com/Consensys/gnark/blob/af754dd1c47a92be375930ae1abfbd134c5310d8/std/algebra/emulated/fields_bls12381/hints.go#L273
    ///
    /// # Panics
    /// Panics if one of the required modular inverses of the (constant) exponents does not exist.
    fn final_exp_hint(f: &Self::Fp12) -> (Self::Fp12, Self::Fp12) {
        // 1. Get the p-th root inverse.
        let pth_exp = FINAL_EXP_FACTOR.clone() * BigUint::from(27u32);
        let pth_root = f.exp_bytes(true, &pth_exp.to_bytes_be());
        let root_pth_inv = if pth_root == Fq12::ONE {
            Fq12::ONE
        } else {
            // `modinv` already returns a value in [0, POLY_FACTOR), so no further reduction
            // is needed; the negative exponent is applied via `exp_bytes(false, ..)`.
            let exp_inv = pth_exp
                .modinv(&POLY_FACTOR)
                .expect("27 * FINAL_EXP_FACTOR must be invertible modulo POLY_FACTOR");
            pth_root.exp_bytes(false, &exp_inv.to_bytes_be())
        };

        // 2.1. Get the order of the 3rd primitive root. Let
        //      root = f^(POLY_FACTOR * FINAL_EXP_FACTOR) and find the smallest k in {0, 1, 2, 3}
        //      with root^(3^k) == 1. If there is no such k, k = 0 and no correction is applied.
        let three = BigUint::from(3u32);
        let three_be = three.to_bytes_be();
        let poly_final_exp = POLY_FACTOR.clone() * FINAL_EXP_FACTOR.clone();
        // Computed once here and reused in step 2.2.
        let root = f.exp_bytes(true, &poly_final_exp.to_bytes_be());
        let mut power = root;
        let mut k: u32 = 0;
        while power != Fq12::ONE && k < 3 {
            power = power.exp_bytes(true, &three_be);
            k += 1;
        }
        let order_3rd_power = if power == Fq12::ONE { k } else { 0 };

        // 2.2. Get the 27th root inverse. Only the order 3^k of `root` matters here: any exponent
        //      congruent to the inverse modulo 3^k yields the same power of `root`.
        let root_27th_inv = if order_3rd_power == 0 {
            Fq12::ONE
        } else {
            let order_3rd = three.pow(order_3rd_power);
            let exp_inv = poly_final_exp
                .modinv(&order_3rd)
                .expect("POLY_FACTOR * FINAL_EXP_FACTOR must be coprime to 3");
            root.exp_bytes(false, &exp_inv.to_bytes_be())
        };

        // 2.3. Shift the Miller loop result so that millerLoop * scalingFactor
        //      is of order FINAL_EXP_FACTOR.
        let s = root_pth_inv * root_27th_inv;
        let f = f * s;

        // 3. Get the residue witness c = (f * s)^(λ^{-1} mod FINAL_EXP_FACTOR),
        //    where λ = q - x is the optimal exponent.
        let lambda_inv = LAMBDA
            .modinv(&FINAL_EXP_FACTOR)
            .expect("LAMBDA must be invertible modulo FINAL_EXP_FACTOR");
        let c = f.exp_bytes(true, &lambda_inv.to_bytes_be());

        (c, s)
    }
}