halo2curves/ff_ext/
inverse.rs

1use core::cmp::PartialEq;
2use std::ops::{Add, Mul, Neg, Sub};
3
/// Signed big integer of (B * L) bits kept in the two's complement code.
///
/// The value is stored as L chunks of B bits each, one chunk per `u64`,
/// in little-endian order. All arithmetic on this type is wrapping.
#[derive(Clone)]
struct CInt<const B: usize, const L: usize>(pub [u64; L]);

impl<const B: usize, const L: usize> CInt<B, L> {
    /// Chunk mask: the B lowest bits are set and no others
    pub const MASK: u64 = u64::MAX >> (64 - B);

    /// The value -1 (every chunk has all of its B bits set)
    pub const MINUS_ONE: Self = Self([Self::MASK; L]);

    /// The value 0
    pub const ZERO: Self = Self([0; L]);

    /// The value 1
    pub const ONE: Self = {
        let mut chunks = [0; L];
        chunks[0] = 1;
        Self(chunks)
    };

    /// Returns the current number arithmetically shifted
    /// to the right by B bits, i.e. by one whole chunk
    pub fn shift(&self) -> Self {
        // Pre-fill with the sign extension so the top chunk is already right;
        // all lower chunks are then overwritten from the source
        let fill = if self.is_negative() { Self::MASK } else { 0 };
        let mut chunks = [fill; L];
        chunks[..L - 1].copy_from_slice(&self.0[1..]);
        Self(chunks)
    }

    /// Returns the least significant chunk of the current number
    pub fn lowest(&self) -> u64 {
        let Self(chunks) = self;
        chunks[0]
    }

    /// Returns "true" iff the current number is negative, which is
    /// the case when the top chunk is greater than `MASK >> 1`
    pub fn is_negative(&self) -> bool {
        let largest_non_negative = Self::MASK >> 1;
        self.0[L - 1] > largest_non_negative
    }
}
49
50impl<const B: usize, const L: usize> PartialEq for CInt<B, L> {
51    fn eq(&self, other: &Self) -> bool {
52        self.0 == other.0
53    }
54}
55
56impl<const B: usize, const L: usize> Add for &CInt<B, L> {
57    type Output = CInt<B, L>;
58    fn add(self, other: Self) -> Self::Output {
59        let (mut data, mut carry) = ([0; L], 0);
60        for (i, d) in data.iter_mut().enumerate().take(L) {
61            let sum = self.0[i] + other.0[i] + carry;
62            *d = sum & CInt::<B, L>::MASK;
63            carry = sum >> B;
64        }
65        CInt::<B, L>(data)
66    }
67}
68
69impl<const B: usize, const L: usize> Add<&CInt<B, L>> for CInt<B, L> {
70    type Output = CInt<B, L>;
71    fn add(self, other: &Self) -> Self::Output {
72        &self + other
73    }
74}
75
76impl<const B: usize, const L: usize> Add for CInt<B, L> {
77    type Output = CInt<B, L>;
78    fn add(self, other: Self) -> Self::Output {
79        &self + &other
80    }
81}
82
83impl<const B: usize, const L: usize> Sub for &CInt<B, L> {
84    type Output = CInt<B, L>;
85    fn sub(self, other: Self) -> Self::Output {
86        // For the two's complement code the additive negation is the result of
87        // adding 1 to the bitwise inverted argument's representation. Thus, for
88        // any encoded integers x and y we have x - y = x + !y + 1, where "!" is
89        // the bitwise inversion and addition is done according to the rules of
90        // the code. The algorithm below uses this formula and is the modified
91        // addition algorithm, where the carry flag is initialized with 1 and
92        // the chunks of the second argument are bitwise inverted
93        let (mut data, mut carry) = ([0; L], 1);
94        for (i, d) in data.iter_mut().enumerate().take(L) {
95            let sum = self.0[i] + (other.0[i] ^ CInt::<B, L>::MASK) + carry;
96            *d = sum & CInt::<B, L>::MASK;
97            carry = sum >> B;
98        }
99        CInt::<B, L>(data)
100    }
101}
102
103impl<const B: usize, const L: usize> Sub<&CInt<B, L>> for CInt<B, L> {
104    type Output = CInt<B, L>;
105    fn sub(self, other: &Self) -> Self::Output {
106        &self - other
107    }
108}
109
110impl<const B: usize, const L: usize> Sub for CInt<B, L> {
111    type Output = CInt<B, L>;
112    fn sub(self, other: Self) -> Self::Output {
113        &self - &other
114    }
115}
116
117impl<const B: usize, const L: usize> Neg for &CInt<B, L> {
118    type Output = CInt<B, L>;
119    fn neg(self) -> Self::Output {
120        // For the two's complement code the additive negation is the result
121        // of adding 1 to the bitwise inverted argument's representation
122        let (mut data, mut carry) = ([0; L], 1);
123        for (i, d) in data.iter_mut().enumerate().take(L) {
124            let sum = (self.0[i] ^ CInt::<B, L>::MASK) + carry;
125            *d = sum & CInt::<B, L>::MASK;
126            carry = sum >> B;
127        }
128        CInt::<B, L>(data)
129    }
130}
131
132impl<const B: usize, const L: usize> Neg for CInt<B, L> {
133    type Output = CInt<B, L>;
134    fn neg(self) -> Self::Output {
135        -&self
136    }
137}
138
139impl<const B: usize, const L: usize> Mul for &CInt<B, L> {
140    type Output = CInt<B, L>;
141    fn mul(self, other: Self) -> Self::Output {
142        let mut data = [0; L];
143        for i in 0..L {
144            let mut carry = 0;
145            for k in 0..(L - i) {
146                let sum = (data[i + k] as u128)
147                    + (carry as u128)
148                    + (self.0[i] as u128) * (other.0[k] as u128);
149                data[i + k] = sum as u64 & CInt::<B, L>::MASK;
150                carry = (sum >> B) as u64;
151            }
152        }
153        CInt::<B, L>(data)
154    }
155}
156
157impl<const B: usize, const L: usize> Mul<&CInt<B, L>> for CInt<B, L> {
158    type Output = CInt<B, L>;
159    fn mul(self, other: &Self) -> Self::Output {
160        &self * other
161    }
162}
163
164impl<const B: usize, const L: usize> Mul for CInt<B, L> {
165    type Output = CInt<B, L>;
166    fn mul(self, other: Self) -> Self::Output {
167        &self * &other
168    }
169}
170
171impl<const B: usize, const L: usize> Mul<i64> for &CInt<B, L> {
172    type Output = CInt<B, L>;
173    fn mul(self, other: i64) -> Self::Output {
174        let mut data = [0; L];
175        // If the short multiplicand is non-negative, the standard multiplication
176        // algorithm is performed. Otherwise, the product of the additively negated
177        // multiplicands is found as follows. Since for the two's complement code
178        // the additive negation is the result of adding 1 to the bitwise inverted
179        // argument's representation, for any encoded integers x and y we have
180        // x * y = (-x) * (-y) = (!x + 1) * (-y) = !x * (-y) + (-y),  where "!" is
181        // the bitwise inversion and arithmetic operations are performed according
182        // to the rules of the code. If the short multiplicand is negative, the
183        // algorithm below uses this formula by substituting the short multiplicand
184        // for y and turns into the modified standard multiplication algorithm,
185        // where the carry flag is initialized with the additively negated short
186        // multiplicand and the chunks of the long multiplicand are bitwise inverted
187        let (other, mut carry, mask) = if other < 0 {
188            (-other, -other as u64, CInt::<B, L>::MASK)
189        } else {
190            (other, 0, 0)
191        };
192        for (i, d) in data.iter_mut().enumerate().take(L) {
193            let sum = (carry as u128) + ((self.0[i] ^ mask) as u128) * (other as u128);
194            *d = sum as u64 & CInt::<B, L>::MASK;
195            carry = (sum >> B) as u64;
196        }
197        CInt::<B, L>(data)
198    }
199}
200
201impl<const B: usize, const L: usize> Mul<i64> for CInt<B, L> {
202    type Output = CInt<B, L>;
203    fn mul(self, other: i64) -> Self::Output {
204        &self * other
205    }
206}
207
208impl<const B: usize, const L: usize> Mul<&CInt<B, L>> for i64 {
209    type Output = CInt<B, L>;
210    fn mul(self, other: &CInt<B, L>) -> Self::Output {
211        other * self
212    }
213}
214
215impl<const B: usize, const L: usize> Mul<CInt<B, L>> for i64 {
216    type Output = CInt<B, L>;
217    fn mul(self, other: CInt<B, L>) -> Self::Output {
218        other * self
219    }
220}
221
222/// Type of the modular multiplicative inverter based on the Bernstein-Yang
223/// method. The inverter can be created for a specified modulus M and adjusting
224/// parameter A to compute the adjusted multiplicative inverses of positive
225/// integers, i.e. for computing (1 / x) * A (mod M) for a positive integer x.
226///
227/// The adjusting parameter allows computing the multiplicative inverses in the
228/// case of using the Montgomery representation for the input or the expected
229/// output. If R is the Montgomery factor, the multiplicative inverses in the
230/// appropriate representation can be computed provided that the value of A is
231/// chosen as follows:
232/// - A = 1, if both the input and the expected output are in the standard form
233/// - A = R^2 mod M, if both the input and the expected output are in the
234///   Montgomery form
235/// - A = R mod M, if either the input or the expected output is in the
236///   Montgomery form,
237/// but not both of them
238///
239/// The public methods of this type receive and return unsigned big integers as
240/// arrays of 64-bit chunks, the ordering of which is little-endian. Both the
241/// modulus and the integer to be inverted should not exceed 2 ^ (62 * L - 64)
242///
243/// For better understanding the implementation, the following resources are
244/// recommended:
245/// - D. Bernstein, B.-Y. Yang, "Fast constant-time gcd computation and modular
246///   inversion",
247/// <https://gcd.cr.yp.to/safegcd-20190413.pdf>
248/// - P. Wuille, "The safegcd implementation in libsecp256k1 explained",
249/// <https://github.com/bitcoin-core/secp256k1/blob/master/doc/safegcd_implementation.md>
250pub struct BYInverter<const L: usize> {
251    /// Modulus
252    modulus: CInt<62, L>,
253
254    /// Adjusting parameter
255    adjuster: CInt<62, L>,
256
257    /// Multiplicative inverse of the modulus modulo 2^62
258    inverse: i64,
259}
260
/// 2x2 matrix of `i64` entries, indexed as `[row][column]`, that holds
/// a Bernstein-Yang transition matrix multiplied by 2^62
type Matrix = [[i64; 2]; 2];
263
264impl<const L: usize> BYInverter<L> {
265    /// Returns the Bernstein-Yang transition matrix multiplied by 2^62 and the
266    /// new value of the delta variable for the 62 basic steps of the
267    /// Bernstein-Yang method, which are to be performed sequentially for
268    /// specified initial values of f, g and delta
269    fn jump(f: &CInt<62, L>, g: &CInt<62, L>, mut delta: i64) -> (i64, Matrix) {
270        let (mut steps, mut f, mut g) = (62, f.lowest() as i64, g.lowest() as i128);
271        let mut t: Matrix = [[1, 0], [0, 1]];
272
273        loop {
274            let zeros = steps.min(g.trailing_zeros() as i64);
275            (steps, delta, g) = (steps - zeros, delta + zeros, g >> zeros);
276            t[0] = [t[0][0] << zeros, t[0][1] << zeros];
277
278            if steps == 0 {
279                break;
280            }
281            if delta > 0 {
282                (delta, f, g) = (-delta, g as i64, -f as i128);
283                (t[0], t[1]) = (t[1], [-t[0][0], -t[0][1]]);
284            }
285
286            // The formula (3 * x) xor 28 = -1 / x (mod 32) for an odd integer x
287            // in the two's complement code has been derived from the formula
288            // (3 * x) xor 2 = 1 / x (mod 32) attributed to Peter Montgomery
289            let mask = (1 << steps.min(1 - delta).min(5)) - 1;
290            let w = (g as i64).wrapping_mul(f.wrapping_mul(3) ^ 28) & mask;
291
292            t[1] = [t[0][0] * w + t[1][0], t[0][1] * w + t[1][1]];
293            g += w as i128 * f as i128;
294        }
295
296        (delta, t)
297    }
298
299    /// Returns the updated values of the variables f and g for specified
300    /// initial ones and Bernstein-Yang transition matrix multiplied by
301    /// 2^62. The returned vector is "matrix * (f, g)' / 2^62", where "'" is the
302    /// transpose operator
303    fn fg(f: CInt<62, L>, g: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) {
304        (
305            (t[0][0] * &f + t[0][1] * &g).shift(),
306            (t[1][0] * &f + t[1][1] * &g).shift(),
307        )
308    }
309
310    /// Returns the updated values of the variables d and e for specified
311    /// initial ones and Bernstein-Yang transition matrix multiplied by
312    /// 2^62. The returned vector is congruent modulo M to "matrix * (d, e)' /
313    /// 2^62 (mod M)", where M is the modulus the inverter was created for
314    /// and "'" stands for the transpose operator. Both the input and output
315    /// values lie in the interval (-2 * M, M)
316    fn de(&self, d: CInt<62, L>, e: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) {
317        let mask = CInt::<62, L>::MASK as i64;
318        let mut md = t[0][0] * d.is_negative() as i64 + t[0][1] * e.is_negative() as i64;
319        let mut me = t[1][0] * d.is_negative() as i64 + t[1][1] * e.is_negative() as i64;
320
321        let cd = t[0][0]
322            .wrapping_mul(d.lowest() as i64)
323            .wrapping_add(t[0][1].wrapping_mul(e.lowest() as i64))
324            & mask;
325        let ce = t[1][0]
326            .wrapping_mul(d.lowest() as i64)
327            .wrapping_add(t[1][1].wrapping_mul(e.lowest() as i64))
328            & mask;
329
330        md -= (self.inverse.wrapping_mul(cd).wrapping_add(md)) & mask;
331        me -= (self.inverse.wrapping_mul(ce).wrapping_add(me)) & mask;
332
333        let cd = t[0][0] * &d + t[0][1] * &e + md * &self.modulus;
334        let ce = t[1][0] * &d + t[1][1] * &e + me * &self.modulus;
335
336        (cd.shift(), ce.shift())
337    }
338
339    /// Returns either "value (mod M)" or "-value (mod M)", where M is the
340    /// modulus the inverter was created for, depending on "negate", which
341    /// determines the presence of "-" in the used formula. The input
342    /// integer lies in the interval (-2 * M, M)
343    fn norm(&self, mut value: CInt<62, L>, negate: bool) -> CInt<62, L> {
344        if value.is_negative() {
345            value = value + &self.modulus;
346        }
347
348        if negate {
349            value = -value;
350        }
351
352        if value.is_negative() {
353            value = value + &self.modulus;
354        }
355
356        value
357    }
358
359    /// Returns a big unsigned integer as an array of O-bit chunks, which is
360    /// equal modulo 2 ^ (O * S) to the input big unsigned integer stored as
361    /// an array of I-bit chunks. The ordering of the chunks in these arrays
362    /// is little-endian
363    const fn convert<const I: usize, const O: usize, const S: usize>(input: &[u64]) -> [u64; S] {
364        // This function is defined because the method "min" of the usize type is not
365        // constant
366        const fn min(a: usize, b: usize) -> usize {
367            if a > b {
368                b
369            } else {
370                a
371            }
372        }
373
374        let (total, mut output, mut bits) = (min(input.len() * I, S * O), [0; S], 0);
375
376        while bits < total {
377            let (i, o) = (bits % I, bits % O);
378            output[bits / O] |= (input[bits / I] >> i) << o;
379            bits += min(I - i, O - o);
380        }
381
382        let mask = u64::MAX >> (64 - O);
383        let mut filled = total / O + if total % O > 0 { 1 } else { 0 };
384
385        while filled > 0 {
386            filled -= 1;
387            output[filled] &= mask;
388        }
389
390        output
391    }
392
393    /// Returns the multiplicative inverse of the argument modulo 2^62. The
394    /// implementation is based on the Hurchalla's method for computing the
395    /// multiplicative inverse modulo a power of two. For better
396    /// understanding the implementation, the following paper is recommended:
397    /// J. Hurchalla, "An Improved Integer Multiplicative Inverse (modulo 2^w)",
398    /// <https://arxiv.org/pdf/2204.04342.pdf>
399    const fn inv(value: u64) -> i64 {
400        let x = value.wrapping_mul(3) ^ 2;
401        let y = 1u64.wrapping_sub(x.wrapping_mul(value));
402        let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
403        let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
404        let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
405        (x.wrapping_mul(y.wrapping_add(1)) & CInt::<62, L>::MASK) as i64
406    }
407
408    /// Creates the inverter for specified modulus and adjusting parameter
409    pub const fn new(modulus: &[u64], adjuster: &[u64]) -> Self {
410        Self {
411            modulus: CInt::<62, L>(Self::convert::<64, 62, L>(modulus)),
412            adjuster: CInt::<62, L>(Self::convert::<64, 62, L>(adjuster)),
413            inverse: Self::inv(modulus[0]),
414        }
415    }
416
417    /// Returns either the adjusted modular multiplicative inverse for the
418    /// argument or None depending on invertibility of the argument, i.e.
419    /// its coprimality with the modulus
420    pub fn invert<const S: usize>(&self, value: &[u64]) -> Option<[u64; S]> {
421        let (mut d, mut e) = (CInt::ZERO, self.adjuster.clone());
422        let mut g = CInt::<62, L>(Self::convert::<64, 62, L>(value));
423        let (mut delta, mut f) = (1, self.modulus.clone());
424        let mut matrix;
425        while g != CInt::ZERO {
426            (delta, matrix) = Self::jump(&f, &g, delta);
427            (f, g) = Self::fg(f, g, matrix);
428            (d, e) = self.de(d, e, matrix);
429        }
430        // At this point the absolute value of "f" equals the greatest common divisor
431        // of the integer to be inverted and the modulus the inverter was created for.
432        // Thus, if "f" is neither 1 nor -1, then the sought inverse does not exist
433        let antiunit = f == CInt::MINUS_ONE;
434        if (f != CInt::ONE) && !antiunit {
435            return None;
436        }
437        Some(Self::convert::<62, 64, S>(&self.norm(d, antiunit).0))
438    }
439}