halo2curves/ff_ext/
jacobi.rs

1use core::cmp::PartialEq;
2use std::ops::{Add, Mul, Neg, Shr, Sub};
3
/// Signed big integer of (64 * L) bits kept in two's complement form.
/// The value is stored as an array of 64-bit chunks in little-endian
/// order (chunk 0 is the least significant one). All the arithmetic
/// operations defined for this type wrap around on overflow
#[derive(Clone)]
pub struct LInt<const L: usize>([u64; L]);

impl<const L: usize> LInt<L> {
    /// Representation of -1 (every bit is set)
    pub const MINUS_ONE: Self = Self([u64::MAX; L]);

    /// Representation of 0 (every bit is cleared)
    pub const ZERO: Self = Self([0; L]);

    /// Representation of 1 (only the lowest bit is set)
    pub const ONE: Self = {
        let mut chunks = [0; L];
        chunks[0] = 1;
        Self(chunks)
    };

    /// Builds a number from the given little-endian chunk sequence, padding
    /// it with zero chunks up to length L.
    ///
    /// # Panics
    ///
    /// Panics if the input sequence is longer than L
    pub fn new(data: &[u64]) -> Self {
        let mut chunks = [0u64; L];
        chunks[..data.len()].copy_from_slice(data);
        Self(chunks)
    }

    /// Returns "true" iff the number is negative, i.e. iff the highest bit
    /// of the most significant chunk is set
    #[inline]
    pub fn is_negative(&self) -> bool {
        self.0[L - 1] >> (u64::BITS - 1) == 1
    }

    /// Adds two chunks and an incoming carry bit. The first element of the
    /// returned tuple is the sum modulo 2^64, the second one is the outgoing
    /// carry, i.e. it tells whether the exact sum is no less than 2^64
    #[inline]
    fn sum(first: u64, second: u64, carry: bool) -> (u64, bool) {
        // At most one of the two partial additions can overflow, so the
        // outgoing carry is simply the disjunction of both overflow flags
        let (partial, overflowed_chunks) = first.overflowing_add(second);
        let (total, overflowed_carry) = partial.overflowing_add(u64::from(carry));
        (total, overflowed_chunks || overflowed_carry)
    }

    /// Returns "(low, high)", where "high * 2^64 + low = first * second + carry
    /// + summand". The exact value always fits into 128 bits
    #[inline]
    fn prodsum(first: u64, second: u64, summand: u64, carry: u64) -> (u64, u64) {
        let wide =
            u128::from(first) * u128::from(second) + u128::from(carry) + u128::from(summand);
        (wide as u64, (wide >> u64::BITS) as u64)
    }
}
61
62impl<const L: usize> PartialEq for LInt<L> {
63    fn eq(&self, other: &Self) -> bool {
64        self.0 == other.0
65    }
66}
67
68impl<const L: usize> Shr<u32> for &LInt<L> {
69    type Output = LInt<L>;
70    /// Returns the result of applying the arithmetic right shift to the current
71    /// number. The specified bit quantity the number is shifted by must lie
72    /// in {1, 2, ..., 63}. For the quantities outside of the range, the
73    /// behavior of the method is undefined
74    fn shr(self, bits: u32) -> Self::Output {
75        debug_assert!(
76            (bits > 0) && (bits < 64),
77            "Cannot shift by 0 or more than 63 bits!"
78        );
79        let (mut data, right) = ([0; L], u64::BITS - bits);
80
81        for (i, d) in data.iter_mut().enumerate().take(L - 1) {
82            *d = (self.0[i] >> bits) | (self.0[i + 1] << right);
83        }
84        data[L - 1] = self.0[L - 1] >> bits;
85        if self.is_negative() {
86            data[L - 1] |= u64::MAX << right;
87        }
88        LInt::<L>(data)
89    }
90}
91
92impl<const L: usize> Shr<u32> for LInt<L> {
93    type Output = LInt<L>;
94    fn shr(self, bits: u32) -> Self::Output {
95        &self >> bits
96    }
97}
98
99impl<const L: usize> Add for &LInt<L> {
100    type Output = LInt<L>;
101    fn add(self, other: Self) -> Self::Output {
102        let (mut data, mut carry) = ([0; L], false);
103        for (i, d) in data.iter_mut().enumerate().take(L) {
104            (*d, carry) = Self::Output::sum(self.0[i], other.0[i], carry);
105        }
106        LInt::<L>(data)
107    }
108}
109
110impl<const L: usize> Add<&LInt<L>> for LInt<L> {
111    type Output = LInt<L>;
112    fn add(self, other: &Self) -> Self::Output {
113        &self + other
114    }
115}
116
117impl<const L: usize> Add for LInt<L> {
118    type Output = LInt<L>;
119    fn add(self, other: Self) -> Self::Output {
120        &self + &other
121    }
122}
123
124impl<const L: usize> Sub for &LInt<L> {
125    type Output = LInt<L>;
126    fn sub(self, other: Self) -> Self::Output {
127        // For the two's complement code the additive negation is the result of
128        // adding 1 to the bitwise inverted argument's representation. Thus, for
129        // any encoded integers x and y we have x - y = x + !y + 1, where "!" is
130        // the bitwise inversion and addition is done according to the rules of
131        // the code. The algorithm below uses this formula and is the modified
132        // addition algorithm, where the carry flag is initialized with "true"
133        // and the chunks of the second argument are bitwise inverted
134        let (mut data, mut carry) = ([0; L], true);
135        for (i, d) in data.iter_mut().enumerate().take(L) {
136            (*d, carry) = Self::Output::sum(self.0[i], !other.0[i], carry);
137        }
138        LInt::<L>(data)
139    }
140}
141
142impl<const L: usize> Sub<&LInt<L>> for LInt<L> {
143    type Output = LInt<L>;
144    fn sub(self, other: &Self) -> Self::Output {
145        &self - other
146    }
147}
148
149impl<const L: usize> Sub for LInt<L> {
150    type Output = LInt<L>;
151    fn sub(self, other: Self) -> Self::Output {
152        &self - &other
153    }
154}
155
156impl<const L: usize> Neg for &LInt<L> {
157    type Output = LInt<L>;
158    fn neg(self) -> Self::Output {
159        // For the two's complement code the additive negation is the result
160        // of adding 1 to the bitwise inverted argument's representation
161        let (mut data, mut carry) = ([0; L], true);
162        for (i, d) in data.iter_mut().enumerate().take(L) {
163            (*d, carry) = (!self.0[i]).overflowing_add(carry as u64);
164        }
165        LInt::<L>(data)
166    }
167}
168
169impl<const L: usize> Neg for LInt<L> {
170    type Output = LInt<L>;
171    fn neg(self) -> Self::Output {
172        -&self
173    }
174}
175
176impl<const L: usize> Mul for &LInt<L> {
177    type Output = LInt<L>;
178    fn mul(self, other: Self) -> Self::Output {
179        let mut data = [0; L];
180        for i in 0..L {
181            let mut carry = 0;
182            for k in 0..(L - i) {
183                (data[i + k], carry) =
184                    Self::Output::prodsum(self.0[i], other.0[k], data[i + k], carry);
185            }
186        }
187        LInt::<L>(data)
188    }
189}
190
191impl<const L: usize> Mul<&LInt<L>> for LInt<L> {
192    type Output = LInt<L>;
193    fn mul(self, other: &Self) -> Self::Output {
194        &self * other
195    }
196}
197
198impl<const L: usize> Mul for LInt<L> {
199    type Output = LInt<L>;
200    fn mul(self, other: Self) -> Self::Output {
201        &self * &other
202    }
203}
204
205impl<const L: usize> Mul<i64> for &LInt<L> {
206    type Output = LInt<L>;
207    fn mul(self, other: i64) -> Self::Output {
208        let mut data = [0; L];
209        // If the short multiplicand is non-negative, the standard multiplication
210        // algorithm is performed. Otherwise, the product of the additively negated
211        // multiplicands is found as follows. Since for the two's complement code
212        // the additive negation is the result of adding 1 to the bitwise inverted
213        // argument's representation, for any encoded integers x and y we have
214        // x * y = (-x) * (-y) = (!x + 1) * (-y) = !x * (-y) + (-y),  where "!" is
215        // the bitwise inversion and arithmetic operations are performed according
216        // to the rules of the code. If the short multiplicand is negative, the
217        // algorithm below uses this formula by substituting the short multiplicand
218        // for y and becomes the modified standard multiplication algorithm, where
219        // the carry variable is being initialized with the additively negated short
220        // multiplicand and the chunks of the long multiplicand are bitwise inverted
221        let (other, mut carry, mask) = if other < 0 {
222            (-other as u64, -other as u64, u64::MAX)
223        } else {
224            (other as u64, 0, 0)
225        };
226        for (i, d) in data.iter_mut().enumerate().take(L) {
227            (*d, carry) = Self::Output::prodsum(self.0[i] ^ mask, other, 0, carry);
228        }
229        LInt::<L>(data)
230    }
231}
232
233impl<const L: usize> Mul<i64> for LInt<L> {
234    type Output = LInt<L>;
235    fn mul(self, other: i64) -> Self::Output {
236        &self * other
237    }
238}
239
240impl<const L: usize> Mul<&LInt<L>> for i64 {
241    type Output = LInt<L>;
242    fn mul(self, other: &LInt<L>) -> Self::Output {
243        other * self
244    }
245}
246
247impl<const L: usize> Mul<LInt<L>> for i64 {
248    type Output = LInt<L>;
249    fn mul(self, other: LInt<L>) -> Self::Output {
250        other * self
251    }
252}
253
254/// Returns the "approximations" of the arguments and the flag indicating
255/// whether both arguments are equal to their "approximations". Both the
256/// arguments must be non-negative, and at least one of them must be non-zero.
257/// For an incorrect input, the behavior of the function is undefined. These
258/// "approximations" are defined in the following way. Let n be the bit length
259/// of the largest argument without leading zeros. For n > 64 the
260/// "approximation" of the argument, which equals v, is (v div 2 ^ (n - 32)) * 2
261/// ^ 32 + (v mod 2 ^ 32), i.e. it retains the high and low bits of the n-bit
262/// representation of v. If n does not exceed 64, an argument
263/// and its "approximation" are equal. These "approximations" are defined
264/// slightly differently from the ones in the Pornin's method for modular
265/// inversion: instead of taking the 33 high and 31 low bits of the n-bit
266/// representation of an argument, the 32 high and 32 low bits are taken
267fn approximate<const L: usize>(x: &LInt<L>, y: &LInt<L>) -> (u64, u64, bool) {
268    debug_assert!(
269        !(x.is_negative() || y.is_negative()),
270        "Both the arguments must be non-negative!"
271    );
272    debug_assert!(
273        (*x != LInt::ZERO) || (*y != LInt::ZERO),
274        "At least one argument must be non-zero!"
275    );
276    let mut i = L - 1;
277    while (x.0[i] == 0) && (y.0[i] == 0) {
278        i -= 1;
279    }
280    if i == 0 {
281        return (x.0[0], y.0[0], true);
282    }
283    let mut h = (x.0[i], y.0[i]);
284    let z = h.0.leading_zeros().min(h.1.leading_zeros());
285    h = (h.0 << z, h.1 << z);
286    if z > 32 {
287        h.0 |= x.0[i - 1] >> z;
288        h.1 |= y.0[i - 1] >> z;
289    }
290    let h = (h.0 & u64::MAX << 32, h.1 & u64::MAX << 32);
291    let l = (x.0[0] & u64::MAX >> 32, y.0[0] & u64::MAX >> 32);
292    (h.0 | l.0, h.1 | l.1, false)
293}
294
/// Returns the Jacobi symbol ("n" / "d") multiplied by either 1 or -1.
/// The latter multiplicand is -1 iff the second-lowest bit of "t" is 1.
/// The value of "d" must be odd in accordance with the Jacobi symbol
/// definition. For even values of "d", the behavior is not defined.
/// The implementation is based on the binary Euclidean algorithm
fn jacobinary(mut n: u64, mut d: u64, mut t: u64) -> i64 {
    debug_assert!(d & 1 > 0, "The second argument must be odd!");
    while n != 0 {
        if n & 1 == 0 {
            // Stripping all the factors 2 of "n" at once. The symbol
            // (2 / "d") is -1 iff "d" mod 8 is 3 or 5, i.e. iff the bit 1 of
            // "d ^ (d >> 1)" is set, and it only matters for an odd number
            // of the stripped factors, i.e. when the bit 1 of "2 * shift" is set
            let shift = n.trailing_zeros();
            t ^= (d ^ (d >> 1)) & u64::from(shift << 1);
            n >>= shift;
            continue;
        }
        if n < d {
            // Quadratic reciprocity: the sign flips iff both the odd numbers
            // are 3 modulo 4, i.e. iff the bit 1 of "n & d" is set
            t ^= n & d;
            std::mem::swap(&mut n, &mut d);
        }
        // Both the numbers are odd, so their difference is even; it is
        // halved right away, which contributes the symbol (2 / "d")
        n = (n - d) >> 1;
        t ^= d ^ (d >> 1);
    }
    // The symbol is 0 unless the greatest common divisor, which is left in
    // "d", equals 1; otherwise its sign is kept in the bit 1 of "t"
    if d == 1 {
        1 - (t & 2) as i64
    } else {
        0
    }
}
318
319/// Returns the Jacobi symbol ("n" / "d") computed by means of the modification
320/// of the Pornin's method for modular inversion. The arguments are unsigned
321/// big integers in the form of arrays of 64-bit chunks, the ordering of which
322/// is little-endian. The value of "d" must be odd in accordance with the Jacobi
323/// symbol definition. Both the arguments must be less than 2 ^ (64 * L - 31).
324/// For an incorrect input, the behavior of the function is undefined. The
325/// method differs from the Pornin's method for modular inversion in absence of
326/// the parts, which are not necessary to compute the greatest common divisor of
327/// arguments, presence of the parts used to compute the Jacobi symbol, which
328/// are based on the properties of the modified Jacobi symbol (x / |y|)
329/// described by M. Hamburg, and some original optimizations. Only these
330/// differences have been commented; the aforesaid Pornin's method and the used
331/// ideas of M. Hamburg were given here:
332/// - T. Pornin, "Optimized Binary GCD for Modular Inversion",
333/// <https://eprint.iacr.org/2020/972.pdf>
334/// - M. Hamburg, "Computing the Jacobi symbol using Bernstein-Yang",
335/// <https://eprint.iacr.org/2021/1271.pdf>
336pub fn jacobi<const L: usize>(n: &[u64], d: &[u64]) -> i64 {
337    // Instead of the variable "j" taking the values from {-1, 1} and satisfying
338    // at the end of the outer loop iteration the equation J = "j" * ("n" / |"d"|)
339    // for the modified Jacobi symbol ("n" / |"d"|) and the sought Jacobi symbol J,
340    // we store the sign bit of "j" in the second-lowest bit of "t" for optimization
341    // purposes. This approach was influenced by the paper by M. Hamburg
342    let (mut n, mut d, mut t) = (LInt::<L>::new(n), LInt::<L>::new(d), 0u64);
343    debug_assert!(d.0[0] & 1 > 0, "The second argument must be odd!");
344    debug_assert!(
345        n.0[L - 1].leading_zeros().min(d.0[L - 1].leading_zeros()) >= 31,
346        "Both the arguments must be less than 2 ^ (64 * L - 31)!"
347    );
348    loop {
349        // The inner loop performs 30 iterations instead of 31 ones in the
350        // aforementioned Pornin's method, and the "approximations" of "n" and
351        // "d" retain 32 of the lowest bits instead of 31 in that method. These
352        // modifications allow the values of the "approximation" variables to be
353        // equal modulo 8 to the corresponding "precise" variables' values,
354        // which would have been computed, if the "precise" variables
355        // had been updated in the inner loop along with the "approximations". This
356        // equality modulo 8 is used to update the second-lowest bit of "t" in
357        // accordance with the properties of the modified Jacobi symbol (x /
358        // |y|). The admissibility of these modifications has been proven using
359        // the appropriately modified Pornin's theorems
360        let (mut u, mut v, mut i) = ((1i64, 0i64), (0i64, 1i64), 30);
361        let (mut a, mut b, precise) = approximate(&n, &d);
362        // When each "approximation" variable has the same value as the corresponding
363        // "precise" one, the computation is accomplished using the
364        // short-arithmetic method of the Jacobi symbol calculation by means of
365        // the binary Euclidean algorithm. This approach aims at avoiding the
366        // parts of the final computations, which are related to long arithmetic
367        if precise {
368            return jacobinary(a, b, t);
369        }
370        while i > 0 {
371            if a & 1 > 0 {
372                if a < b {
373                    (a, b, u, v) = (b, a, v, u);
374                    // In both the aforesaid Pornin's method and its modification "n" and "d"
375                    // could not become negative simultaneously even if they were updated after
376                    // each iteration of the inner loop. Also at this point they both have odd
377                    // values. Therefore, the quadratic reciprocity law for the modified Jacobi
378                    // symbol (x / |y|) can be used. According to it, if both x and y are odd
379                    // numbers, among which there is a positive one, then for x = y = 3 (mod 4)
380                    // we have (x / |y|) = -(y / |x|) and for either x or y equal 1 modulo 4
381                    // the symbols (x / |y|) and (y / |x|) are equal
382                    t ^= a & b;
383                }
384                a = (a - b) >> 1;
385                u = (u.0 - v.0, u.1 - v.1);
386                v = (v.0 << 1, v.1 << 1);
387                // The modified Jacobi symbol (2 / |y|) is -1, iff y mod 8 is {3, 5}
388                t ^= b ^ b >> 1;
389                i -= 1;
390            } else {
391                // Performing the batch of sequential iterations, which divide "a" by 2
392                let z = i.min(a.trailing_zeros());
393                // The modified Jacobi symbol (2 / |y|) is -1, iff y mod 8 is {3, 5}. However,
394                // we do not need its value for a batch with an even number of divisions by 2
395                t ^= (b ^ b >> 1) & (z << 1) as u64;
396                v = (v.0 << z, v.1 << z);
397                a >>= z;
398                i -= z;
399            }
400        }
401        (n, d) = ((&n * u.0 + &d * u.1) >> 30, (&n * v.0 + &d * v.1) >> 30);
402
403        // This fragment is present to guarantee the correct behavior of the function
404        // in the case of arguments, whose greatest common divisor is no less than 2^64
405        if n == LInt::ZERO {
406            // In both the aforesaid Pornin's method and its modification the pair of the
407            // values of "n" and "d" after the divergence point contains a
408            // positive number and a negative one. Since the value of "n" is 0,
409            // the divergence point has not been reached by the inner loop this
410            // time, so there is no need to check whether "d" is equal to -1
411            return (d == LInt::ONE) as i64 * (1 - (t & 2) as i64);
412        }
413
414        if n.is_negative() {
415            // Since in both the aforesaid Pornin's method and its modification "d" is
416            // always odd and cannot become negative simultaneously with "n",
417            // the value of "d" is positive. The modified Jacobi symbol (-1 /
418            // |y|) for a positive y is -1, iff y mod 4 = 3
419            t ^= d.0[0];
420            n = -n;
421        } else if d.is_negative() {
422            // The modified Jacobi symbols (x / |y|) and (x / |-y|) are equal, so "t" is not
423            // updated
424            d = -d;
425        }
426    }
427}