halo2curves_axiom/ff_ext/
inverse.rs

1use core::cmp::PartialEq;
2use std::ops::{Add, Mul, Neg, Sub};
3
/// Wrapping signed big integer with `B * L` bits.
///
/// The value is held in two's complement form as `L` little-endian chunks.
/// Each chunk carries `B` significant bits. All arithmetic defined for this
/// type wraps around.
#[derive(Clone)]
struct CInt<const B: usize, const L: usize>(pub [u64; L]);

impl<const B: usize, const L: usize> CInt<B, L> {
    /// Chunk mask: exactly the `B` least significant bits are set
    pub const MASK: u64 = !0u64 >> (64 - B);

    /// The number -1: every significant bit of every chunk is set
    pub const MINUS_ONE: Self = Self([Self::MASK; L]);

    /// The number 0
    pub const ZERO: Self = Self([0; L]);

    /// The number 1
    pub const ONE: Self = {
        let mut limbs = [0; L];
        limbs[0] = 1;
        Self(limbs)
    };

    /// Arithmetic right shift by exactly `B` bits.
    ///
    /// Every chunk moves one position down and the lowest chunk is dropped.
    /// The vacated top chunk is filled with copies of the sign bit.
    pub fn shift(&self) -> Self {
        let fill = if self.is_negative() { Self::MASK } else { 0 };
        let mut limbs = [fill; L];
        limbs[..L - 1].copy_from_slice(&self.0[1..]);
        Self(limbs)
    }

    /// The least significant chunk, i.e. the lowest `B` bits of the number
    pub fn lowest(&self) -> u64 {
        self.0[0]
    }

    /// Whether the number is below zero, judged by the top chunk exceeding
    /// the largest non-negative top-chunk value
    pub fn is_negative(&self) -> bool {
        let top = self.0[L - 1];
        top > (Self::MASK >> 1)
    }
}
49
50impl<const B: usize, const L: usize> PartialEq for CInt<B, L> {
51    fn eq(&self, other: &Self) -> bool {
52        self.0 == other.0
53    }
54}
55
56impl<const B: usize, const L: usize> Add for &CInt<B, L> {
57    type Output = CInt<B, L>;
58    fn add(self, other: Self) -> Self::Output {
59        let (mut data, mut carry) = ([0; L], 0);
60        for (i, d) in data.iter_mut().enumerate().take(L) {
61            let sum = self.0[i] + other.0[i] + carry;
62            *d = sum & CInt::<B, L>::MASK;
63            carry = sum >> B;
64        }
65        CInt::<B, L>(data)
66    }
67}
68
69impl<const B: usize, const L: usize> Add<&CInt<B, L>> for CInt<B, L> {
70    type Output = CInt<B, L>;
71    fn add(self, other: &Self) -> Self::Output {
72        &self + other
73    }
74}
75
76impl<const B: usize, const L: usize> Add for CInt<B, L> {
77    type Output = CInt<B, L>;
78    fn add(self, other: Self) -> Self::Output {
79        &self + &other
80    }
81}
82
83impl<const B: usize, const L: usize> Sub for &CInt<B, L> {
84    type Output = CInt<B, L>;
85    fn sub(self, other: Self) -> Self::Output {
86        // For the two's complement code the additive negation is the result of
87        // adding 1 to the bitwise inverted argument's representation. Thus, for
88        // any encoded integers x and y we have x - y = x + !y + 1, where "!" is
89        // the bitwise inversion and addition is done according to the rules of
90        // the code. The algorithm below uses this formula and is the modified
91        // addition algorithm, where the carry flag is initialized with 1 and
92        // the chunks of the second argument are bitwise inverted
93        let (mut data, mut carry) = ([0; L], 1);
94        for (i, d) in data.iter_mut().enumerate().take(L) {
95            let sum = self.0[i] + (other.0[i] ^ CInt::<B, L>::MASK) + carry;
96            *d = sum & CInt::<B, L>::MASK;
97            carry = sum >> B;
98        }
99        CInt::<B, L>(data)
100    }
101}
102
103impl<const B: usize, const L: usize> Sub<&CInt<B, L>> for CInt<B, L> {
104    type Output = CInt<B, L>;
105    fn sub(self, other: &Self) -> Self::Output {
106        &self - other
107    }
108}
109
110impl<const B: usize, const L: usize> Sub for CInt<B, L> {
111    type Output = CInt<B, L>;
112    fn sub(self, other: Self) -> Self::Output {
113        &self - &other
114    }
115}
116
117impl<const B: usize, const L: usize> Neg for &CInt<B, L> {
118    type Output = CInt<B, L>;
119    fn neg(self) -> Self::Output {
120        // For the two's complement code the additive negation is the result
121        // of adding 1 to the bitwise inverted argument's representation
122        let (mut data, mut carry) = ([0; L], 1);
123        for (i, d) in data.iter_mut().enumerate().take(L) {
124            let sum = (self.0[i] ^ CInt::<B, L>::MASK) + carry;
125            *d = sum & CInt::<B, L>::MASK;
126            carry = sum >> B;
127        }
128        CInt::<B, L>(data)
129    }
130}
131
132impl<const B: usize, const L: usize> Neg for CInt<B, L> {
133    type Output = CInt<B, L>;
134    fn neg(self) -> Self::Output {
135        -&self
136    }
137}
138
139impl<const B: usize, const L: usize> Mul for &CInt<B, L> {
140    type Output = CInt<B, L>;
141    fn mul(self, other: Self) -> Self::Output {
142        let mut data = [0; L];
143        for i in 0..L {
144            let mut carry = 0;
145            for k in 0..(L - i) {
146                let sum = (data[i + k] as u128)
147                    + (carry as u128)
148                    + (self.0[i] as u128) * (other.0[k] as u128);
149                data[i + k] = sum as u64 & CInt::<B, L>::MASK;
150                carry = (sum >> B) as u64;
151            }
152        }
153        CInt::<B, L>(data)
154    }
155}
156
157impl<const B: usize, const L: usize> Mul<&CInt<B, L>> for CInt<B, L> {
158    type Output = CInt<B, L>;
159    fn mul(self, other: &Self) -> Self::Output {
160        &self * other
161    }
162}
163
164impl<const B: usize, const L: usize> Mul for CInt<B, L> {
165    type Output = CInt<B, L>;
166    fn mul(self, other: Self) -> Self::Output {
167        &self * &other
168    }
169}
170
171impl<const B: usize, const L: usize> Mul<i64> for &CInt<B, L> {
172    type Output = CInt<B, L>;
173    fn mul(self, other: i64) -> Self::Output {
174        let mut data = [0; L];
175        // If the short multiplicand is non-negative, the standard multiplication
176        // algorithm is performed. Otherwise, the product of the additively negated
177        // multiplicands is found as follows. Since for the two's complement code
178        // the additive negation is the result of adding 1 to the bitwise inverted
179        // argument's representation, for any encoded integers x and y we have
180        // x * y = (-x) * (-y) = (!x + 1) * (-y) = !x * (-y) + (-y),  where "!" is
181        // the bitwise inversion and arithmetic operations are performed according
182        // to the rules of the code. If the short multiplicand is negative, the
183        // algorithm below uses this formula by substituting the short multiplicand
184        // for y and turns into the modified standard multiplication algorithm,
185        // where the carry flag is initialized with the additively negated short
186        // multiplicand and the chunks of the long multiplicand are bitwise inverted
187        let (other, mut carry, mask) = if other < 0 {
188            (-other, -other as u64, CInt::<B, L>::MASK)
189        } else {
190            (other, 0, 0)
191        };
192        for (i, d) in data.iter_mut().enumerate().take(L) {
193            let sum = (carry as u128) + ((self.0[i] ^ mask) as u128) * (other as u128);
194            *d = sum as u64 & CInt::<B, L>::MASK;
195            carry = (sum >> B) as u64;
196        }
197        CInt::<B, L>(data)
198    }
199}
200
201impl<const B: usize, const L: usize> Mul<i64> for CInt<B, L> {
202    type Output = CInt<B, L>;
203    fn mul(self, other: i64) -> Self::Output {
204        &self * other
205    }
206}
207
208impl<const B: usize, const L: usize> Mul<&CInt<B, L>> for i64 {
209    type Output = CInt<B, L>;
210    fn mul(self, other: &CInt<B, L>) -> Self::Output {
211        other * self
212    }
213}
214
215impl<const B: usize, const L: usize> Mul<CInt<B, L>> for i64 {
216    type Output = CInt<B, L>;
217    fn mul(self, other: CInt<B, L>) -> Self::Output {
218        other * self
219    }
220}
221
222/// Type of the modular multiplicative inverter based on the Bernstein-Yang method.
223/// The inverter can be created for a specified modulus M and adjusting parameter A
224/// to compute the adjusted multiplicative inverses of positive integers, i.e. for
225/// computing (1 / x) * A (mod M) for a positive integer x.
226///
227/// The adjusting parameter allows computing the multiplicative inverses in the case of
228/// using the Montgomery representation for the input or the expected output. If R is
229/// the Montgomery factor, the multiplicative inverses in the appropriate representation
230/// can be computed provided that the value of A is chosen as follows:
231/// - A = 1, if both the input and the expected output are in the standard form
232/// - A = R^2 mod M, if both the input and the expected output are in the Montgomery form
233/// - A = R mod M, if either the input or the expected output is in the Montgomery form,
234/// but not both of them
235///
236/// The public methods of this type receive and return unsigned big integers as arrays of
237/// 64-bit chunks, the ordering of which is little-endian. Both the modulus and the integer
238/// to be inverted should not exceed 2 ^ (62 * L - 64)
239///
240/// For better understanding the implementation, the following resources are recommended:
241/// - D. Bernstein, B.-Y. Yang, "Fast constant-time gcd computation and modular inversion",
242/// <https://gcd.cr.yp.to/safegcd-20190413.pdf>
243/// - P. Wuille, "The safegcd implementation in libsecp256k1 explained",
244/// <https://github.com/bitcoin-core/secp256k1/blob/master/doc/safegcd_implementation.md>
245pub struct BYInverter<const L: usize> {
246    /// Modulus
247    modulus: CInt<62, L>,
248
249    /// Adjusting parameter
250    adjuster: CInt<62, L>,
251
252    /// Multiplicative inverse of the modulus modulo 2^62
253    inverse: i64,
254}
255
/// 2x2 Bernstein-Yang transition matrix, scaled up by a factor of 2^62 so that
/// all of its entries are integers. Row 0 yields the new f and row 1 the new g.
type Matrix = [[i64; 2]; 2];
258
impl<const L: usize> BYInverter<L> {
    /// Returns the Bernstein-Yang transition matrix multiplied by 2^62 and the new value
    /// of the delta variable for the 62 basic steps of the Bernstein-Yang method, which
    /// are to be performed sequentially for specified initial values of f, g and delta
    ///
    /// Only the lowest 62-bit chunks of "f" and "g" are read. In the returned matrix,
    /// row 0 produces the new f and row 1 the new g (see "fg" and "de").
    fn jump(f: &CInt<62, L>, g: &CInt<62, L>, mut delta: i64) -> (i64, Matrix) {
        // "steps" counts the basic steps still to perform in this jump. "g" is an i128,
        // presumably so that "g += w * f" below (w < 32) cannot overflow.
        let (mut steps, mut f, mut g) = (62, f.lowest() as i64, g.lowest() as i128);
        // Start from the identity matrix
        let mut t: Matrix = [[1, 0], [0, 1]];

        loop {
            // Consume up to "steps" trailing zero bits of g in one go. Each one is a basic
            // step that halves g and increments delta. The matrix is scaled by 2^62 (see
            // "fg"), so halving g is recorded by doubling the f row instead. For g == 0,
            // trailing_zeros() is 128, so all remaining steps are consumed.
            let zeros = steps.min(g.trailing_zeros() as i64);
            (steps, delta, g) = (steps - zeros, delta + zeros, g >> zeros);
            t[0] = [t[0][0] << zeros, t[0][1] << zeros];

            if steps == 0 {
                break;
            }
            // Steps remain, so all trailing zeros were stripped and g is odd here.
            // For a positive delta, (delta, f, g) becomes (-delta, g, -f). The matrix
            // rows are exchanged the same way, with the former f row negated.
            if delta > 0 {
                (delta, f, g) = (-delta, g as i64, -f as i128);
                (t[0], t[1]) = (t[1], [-t[0][0], -t[0][1]]);
            }

            // The formula (3 * x) xor 28 = -1 / x (mod 32) for an odd integer x
            // in the two's complement code has been derived from the formula
            // (3 * x) xor 2 = 1 / x (mod 32) attributed to Peter Montgomery
            //
            // Hence w = -g / f (mod 2^k) with k = min(steps, 1 - delta, 5), given that f
            // is odd as the formula requires. Then g + w * f is divisible by 2^k, so the
            // next iteration strips at least k zeros.
            let mask = (1 << steps.min(1 - delta).min(5)) - 1;
            let w = (g as i64).wrapping_mul(f.wrapping_mul(3) ^ 28) & mask;

            // Apply g += w * f to the matrix (g row += w * f row) and to g itself
            t[1] = [t[0][0] * w + t[1][0], t[0][1] * w + t[1][1]];
            g += w as i128 * f as i128;
        }

        (delta, t)
    }

    /// Returns the updated values of the variables f and g for specified initial ones and Bernstein-Yang transition
    /// matrix multiplied by 2^62. The returned vector is "matrix * (f, g)' / 2^62", where "'" is the transpose operator
    ///
    /// The division by 2^62 is done with "shift", i.e. an arithmetic right shift by one 62-bit chunk.
    fn fg(f: CInt<62, L>, g: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) {
        (
            (t[0][0] * &f + t[0][1] * &g).shift(),
            (t[1][0] * &f + t[1][1] * &g).shift(),
        )
    }

    /// Returns the updated values of the variables d and e for specified initial ones and Bernstein-Yang transition
    /// matrix multiplied by 2^62. The returned vector is congruent modulo M to "matrix * (d, e)' / 2^62 (mod M)",
    /// where M is the modulus the inverter was created for and "'" stands for the transpose operator. Both the input
    /// and output values lie in the interval (-2 * M, M)
    fn de(&self, d: CInt<62, L>, e: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) {
        let mask = CInt::<62, L>::MASK as i64;
        // Initial multiples of M to add: the matrix coefficient of every negative input.
        // NOTE(review): this looks like the libsecp256k1 technique for keeping the output
        // within (-2 * M, M); see the Wuille write-up linked on BYInverter.
        let mut md = t[0][0] * d.is_negative() as i64 + t[0][1] * e.is_negative() as i64;
        let mut me = t[1][0] * d.is_negative() as i64 + t[1][1] * e.is_negative() as i64;

        // Lowest 62 bits of the unreduced results "t[0] * (d, e)'" and "t[1] * (d, e)'"
        let cd = t[0][0]
            .wrapping_mul(d.lowest() as i64)
            .wrapping_add(t[0][1].wrapping_mul(e.lowest() as i64))
            & mask;
        let ce = t[1][0]
            .wrapping_mul(d.lowest() as i64)
            .wrapping_add(t[1][1].wrapping_mul(e.lowest() as i64))
            & mask;

        // Adjust md and me so that md = -cd / M and me = -ce / M (mod 2^62), using
        // "inverse" = 1 / M (mod 2^62). Then cd + md * M and ce + me * M below are
        // divisible by 2^62.
        md -= (self.inverse.wrapping_mul(cd).wrapping_add(md)) & mask;
        me -= (self.inverse.wrapping_mul(ce).wrapping_add(me)) & mask;

        let cd = t[0][0] * &d + t[0][1] * &e + md * &self.modulus;
        let ce = t[1][0] * &d + t[1][1] * &e + me * &self.modulus;

        // The lowest chunks are zero now, so "shift" divides by 2^62 exactly
        (cd.shift(), ce.shift())
    }

    /// Returns either "value (mod M)" or "-value (mod M)", where M is the modulus the
    /// inverter was created for, depending on "negate", which determines the presence
    /// of "-" in the used formula. The input integer lies in the interval (-2 * M, M)
    ///
    /// For inputs in that interval, the result lies in [0, M).
    fn norm(&self, mut value: CInt<62, L>, negate: bool) -> CInt<62, L> {
        // (-2 * M, M) -> (-M, M)
        if value.is_negative() {
            value = value + &self.modulus;
        }

        // Negation keeps the value within (-M, M)
        if negate {
            value = -value;
        }

        // (-M, M) -> [0, M)
        if value.is_negative() {
            value = value + &self.modulus;
        }

        value
    }

    /// Returns a big unsigned integer as an array of O-bit chunks, which is equal modulo
    /// 2 ^ (O * S) to the input big unsigned integer stored as an array of I-bit chunks.
    /// The ordering of the chunks in these arrays is little-endian
    const fn convert<const I: usize, const O: usize, const S: usize>(input: &[u64]) -> [u64; S] {
        // This function is defined because the method "min" of the usize type is not constant
        const fn min(a: usize, b: usize) -> usize {
            if a > b {
                b
            } else {
                a
            }
        }

        // Number of bits to transfer: bounded by both the input and the output capacity
        let (total, mut output, mut bits) = (min(input.len() * I, S * O), [0; S], 0);

        // Each iteration moves the run of bits that stays inside both the current input
        // chunk and the current output chunk. Bits shifted past the output chunk's O bits
        // are cleared by the masking loop below.
        while bits < total {
            let (i, o) = (bits % I, bits % O);
            output[bits / O] |= (input[bits / I] >> i) << o;
            bits += min(I - i, O - o);
        }

        // Clear everything above the lowest O bits in every output chunk that was written
        let mask = u64::MAX >> (64 - O);
        let mut filled = total / O + if total % O > 0 { 1 } else { 0 };

        while filled > 0 {
            filled -= 1;
            output[filled] &= mask;
        }

        output
    }

    /// Returns the multiplicative inverse of the argument modulo 2^62. The implementation is based
    /// on the Hurchalla's method for computing the multiplicative inverse modulo a power of two.
    /// For better understanding the implementation, the following paper is recommended:
    /// J. Hurchalla, "An Improved Integer Multiplicative Inverse (modulo 2^w)",
    /// https://arxiv.org/pdf/2204.04342.pdf
    ///
    /// The argument must be odd, since even numbers have no inverse modulo 2^62.
    const fn inv(value: u64) -> i64 {
        let x = value.wrapping_mul(3) ^ 2;
        let y = 1u64.wrapping_sub(x.wrapping_mul(value));
        let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
        let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
        let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
        (x.wrapping_mul(y.wrapping_add(1)) & CInt::<62, L>::MASK) as i64
    }

    /// Creates the inverter for specified modulus and adjusting parameter
    ///
    /// Both arguments are little-endian arrays of 64-bit chunks. They are re-chunked into
    /// 62-bit chunks; bits beyond 62 * L are dropped.
    /// NOTE(review): "modulus[0]" panics for an empty slice, and "inv" (like the formula in
    /// "jump") assumes the modulus is odd. Callers presumably pass an odd prime; confirm.
    pub const fn new(modulus: &[u64], adjuster: &[u64]) -> Self {
        Self {
            modulus: CInt::<62, L>(Self::convert::<64, 62, L>(modulus)),
            adjuster: CInt::<62, L>(Self::convert::<64, 62, L>(adjuster)),
            inverse: Self::inv(modulus[0]),
        }
    }

    /// Returns either the adjusted modular multiplicative inverse for the argument or None
    /// depending on invertibility of the argument, i.e. its coprimality with the modulus
    ///
    /// The argument is a little-endian array of 64-bit chunks; the result is returned as
    /// S little-endian 64-bit chunks. A zero argument yields None, unless the modulus is 1.
    pub fn invert<const S: usize>(&self, value: &[u64]) -> Option<[u64; S]> {
        // (d, e) start as (0, A), (f, g) as (M, value), and delta as 1
        let (mut d, mut e) = (CInt::ZERO, self.adjuster.clone());
        let mut g = CInt::<62, L>(Self::convert::<64, 62, L>(value));
        let (mut delta, mut f) = (1, self.modulus.clone());
        let mut matrix;
        // Each iteration performs 62 basic steps at once: their combined matrix is derived
        // from the low bits of f and g, then applied to (f, g) and to (d, e)
        while g != CInt::ZERO {
            (delta, matrix) = Self::jump(&f, &g, delta);
            (f, g) = Self::fg(f, g, matrix);
            (d, e) = self.de(d, e, matrix);
        }
        // At this point the absolute value of "f" equals the greatest common divisor
        // of the integer to be inverted and the modulus the inverter was created for.
        // Thus, if "f" is neither 1 nor -1, then the sought inverse does not exist
        let antiunit = f == CInt::MINUS_ONE;
        if (f != CInt::ONE) && !antiunit {
            return None;
        }
        // Reduce d into [0, M), negating it when f ended as -1, and re-chunk to 64 bits
        Some(Self::convert::<62, 64, S>(&self.norm(d, antiunit).0))
    }
}