halo2curves_axiom/ff_ext/
jacobi.rs

1use core::cmp::PartialEq;
2use std::ops::{Add, Mul, Neg, Shr, Sub};
3
/// Signed big integer that is exactly (64 * L) bits wide.
///
/// The value is kept in two's complement form as `L` chunks of 64 bits each,
/// with the least significant chunk stored first (little-endian chunk order).
/// All arithmetic operations implemented for this type wrap around on overflow.
///
/// `Debug` is derived so that values can be printed in assertions and logs
#[derive(Clone, Debug)]
pub struct LInt<const L: usize>([u64; L]);
10
11impl<const L: usize> LInt<L> {
12    /// Representation of -1
13    pub const MINUS_ONE: Self = Self([u64::MAX; L]);
14
15    /// Representation of 0
16    pub const ZERO: Self = Self([0; L]);
17
18    /// Representation of 1
19    pub const ONE: Self = {
20        let mut data = [0; L];
21        data[0] = 1;
22        Self(data)
23    };
24
25    /// Returns the number, which is stored as the specified
26    /// sequence padded with zeros to length L. If the input
27    /// sequence is longer than L, the method panics
28    pub fn new(data: &[u64]) -> Self {
29        let mut number = Self::ZERO;
30        number.0[..data.len()].copy_from_slice(data);
31        number
32    }
33
34    /// Returns "true" iff the current number is negative
35    #[inline]
36    pub fn is_negative(&self) -> bool {
37        self.0[L - 1] > (u64::MAX >> 1)
38    }
39
40    /// Returns a tuple representing the sum of the first two arguments and the bit
41    /// described by the third argument. The first element of the tuple is this sum
42    /// modulo 2^64, the second one indicates whether the sum is no less than 2^64
43    #[inline]
44    fn sum(first: u64, second: u64, carry: bool) -> (u64, bool) {
45        // The implementation is inspired with the "carrying_add" function from this source:
46        // https://github.com/rust-lang/rust/blob/master/library/core/src/num/uint_macros.rs
47        let (second, carry) = second.overflowing_add(carry as u64);
48        let (first, high) = first.overflowing_add(second);
49        (first, carry || high)
50    }
51
52    /// Returns "(low, high)", where "high * 2^64 + low = first * second + carry + summand"
53    #[inline]
54    fn prodsum(first: u64, second: u64, summand: u64, carry: u64) -> (u64, u64) {
55        let all = (first as u128) * (second as u128) + (carry as u128) + (summand as u128);
56        (all as u64, (all >> u64::BITS) as u64)
57    }
58}
59
60impl<const L: usize> PartialEq for LInt<L> {
61    fn eq(&self, other: &Self) -> bool {
62        self.0 == other.0
63    }
64}
65
66impl<const L: usize> Shr<u32> for &LInt<L> {
67    type Output = LInt<L>;
68    /// Returns the result of applying the arithmetic right shift to the current number.
69    /// The specified bit quantity the number is shifted by must lie in {1, 2, ..., 63}.
70    /// For the quantities outside of the range, the behavior of the method is undefined
71    fn shr(self, bits: u32) -> Self::Output {
72        debug_assert!(
73            (bits > 0) && (bits < 64),
74            "Cannot shift by 0 or more than 63 bits!"
75        );
76        let (mut data, right) = ([0; L], u64::BITS - bits);
77
78        for (i, d) in data.iter_mut().enumerate().take(L - 1) {
79            *d = (self.0[i] >> bits) | (self.0[i + 1] << right);
80        }
81        data[L - 1] = self.0[L - 1] >> bits;
82        if self.is_negative() {
83            data[L - 1] |= u64::MAX << right;
84        }
85        LInt::<L>(data)
86    }
87}
88
89impl<const L: usize> Shr<u32> for LInt<L> {
90    type Output = LInt<L>;
91    fn shr(self, bits: u32) -> Self::Output {
92        &self >> bits
93    }
94}
95
96impl<const L: usize> Add for &LInt<L> {
97    type Output = LInt<L>;
98    fn add(self, other: Self) -> Self::Output {
99        let (mut data, mut carry) = ([0; L], false);
100        for (i, d) in data.iter_mut().enumerate().take(L) {
101            (*d, carry) = Self::Output::sum(self.0[i], other.0[i], carry);
102        }
103        LInt::<L>(data)
104    }
105}
106
107impl<const L: usize> Add<&LInt<L>> for LInt<L> {
108    type Output = LInt<L>;
109    fn add(self, other: &Self) -> Self::Output {
110        &self + other
111    }
112}
113
114impl<const L: usize> Add for LInt<L> {
115    type Output = LInt<L>;
116    fn add(self, other: Self) -> Self::Output {
117        &self + &other
118    }
119}
120
121impl<const L: usize> Sub for &LInt<L> {
122    type Output = LInt<L>;
123    fn sub(self, other: Self) -> Self::Output {
124        // For the two's complement code the additive negation is the result of
125        // adding 1 to the bitwise inverted argument's representation. Thus, for
126        // any encoded integers x and y we have x - y = x + !y + 1, where "!" is
127        // the bitwise inversion and addition is done according to the rules of
128        // the code. The algorithm below uses this formula and is the modified
129        // addition algorithm, where the carry flag is initialized with "true"
130        // and the chunks of the second argument are bitwise inverted
131        let (mut data, mut carry) = ([0; L], true);
132        for (i, d) in data.iter_mut().enumerate().take(L) {
133            (*d, carry) = Self::Output::sum(self.0[i], !other.0[i], carry);
134        }
135        LInt::<L>(data)
136    }
137}
138
139impl<const L: usize> Sub<&LInt<L>> for LInt<L> {
140    type Output = LInt<L>;
141    fn sub(self, other: &Self) -> Self::Output {
142        &self - other
143    }
144}
145
146impl<const L: usize> Sub for LInt<L> {
147    type Output = LInt<L>;
148    fn sub(self, other: Self) -> Self::Output {
149        &self - &other
150    }
151}
152
153impl<const L: usize> Neg for &LInt<L> {
154    type Output = LInt<L>;
155    fn neg(self) -> Self::Output {
156        // For the two's complement code the additive negation is the result
157        // of adding 1 to the bitwise inverted argument's representation
158        let (mut data, mut carry) = ([0; L], true);
159        for (i, d) in data.iter_mut().enumerate().take(L) {
160            (*d, carry) = (!self.0[i]).overflowing_add(carry as u64);
161        }
162        LInt::<L>(data)
163    }
164}
165
166impl<const L: usize> Neg for LInt<L> {
167    type Output = LInt<L>;
168    fn neg(self) -> Self::Output {
169        -&self
170    }
171}
172
173impl<const L: usize> Mul for &LInt<L> {
174    type Output = LInt<L>;
175    fn mul(self, other: Self) -> Self::Output {
176        let mut data = [0; L];
177        for i in 0..L {
178            let mut carry = 0;
179            for k in 0..(L - i) {
180                (data[i + k], carry) =
181                    Self::Output::prodsum(self.0[i], other.0[k], data[i + k], carry);
182            }
183        }
184        LInt::<L>(data)
185    }
186}
187
188impl<const L: usize> Mul<&LInt<L>> for LInt<L> {
189    type Output = LInt<L>;
190    fn mul(self, other: &Self) -> Self::Output {
191        &self * other
192    }
193}
194
195impl<const L: usize> Mul for LInt<L> {
196    type Output = LInt<L>;
197    fn mul(self, other: Self) -> Self::Output {
198        &self * &other
199    }
200}
201
202impl<const L: usize> Mul<i64> for &LInt<L> {
203    type Output = LInt<L>;
204    fn mul(self, other: i64) -> Self::Output {
205        let mut data = [0; L];
206        // If the short multiplicand is non-negative, the standard multiplication
207        // algorithm is performed. Otherwise, the product of the additively negated
208        // multiplicands is found as follows. Since for the two's complement code
209        // the additive negation is the result of adding 1 to the bitwise inverted
210        // argument's representation, for any encoded integers x and y we have
211        // x * y = (-x) * (-y) = (!x + 1) * (-y) = !x * (-y) + (-y),  where "!" is
212        // the bitwise inversion and arithmetic operations are performed according
213        // to the rules of the code. If the short multiplicand is negative, the
214        // algorithm below uses this formula by substituting the short multiplicand
215        // for y and becomes the modified standard multiplication algorithm, where
216        // the carry variable is being initialized with the additively negated short
217        // multiplicand and the chunks of the long multiplicand are bitwise inverted
218        let (other, mut carry, mask) = if other < 0 {
219            (-other as u64, -other as u64, u64::MAX)
220        } else {
221            (other as u64, 0, 0)
222        };
223        for (i, d) in data.iter_mut().enumerate().take(L) {
224            (*d, carry) = Self::Output::prodsum(self.0[i] ^ mask, other, 0, carry);
225        }
226        LInt::<L>(data)
227    }
228}
229
230impl<const L: usize> Mul<i64> for LInt<L> {
231    type Output = LInt<L>;
232    fn mul(self, other: i64) -> Self::Output {
233        &self * other
234    }
235}
236
237impl<const L: usize> Mul<&LInt<L>> for i64 {
238    type Output = LInt<L>;
239    fn mul(self, other: &LInt<L>) -> Self::Output {
240        other * self
241    }
242}
243
244impl<const L: usize> Mul<LInt<L>> for i64 {
245    type Output = LInt<L>;
246    fn mul(self, other: LInt<L>) -> Self::Output {
247        other * self
248    }
249}
250
251/// Returns the "approximations" of the arguments and the flag indicating whether
252/// both arguments are equal to their "approximations". Both the arguments must be
253/// non-negative, and at least one of them must be non-zero. For an incorrect input,
254/// the behavior of the function is undefined. These "approximations" are defined
255/// in the following way. Let n be the bit length of the largest argument without
256/// leading zeros. For n > 64 the "approximation" of the argument, which equals v,
257/// is (v div 2 ^ (n - 32)) * 2 ^ 32 + (v mod 2 ^ 32), i.e. it retains the high and
258/// low bits of the n-bit representation of v. If n does not exceed 64, an argument
259/// and its "approximation" are equal. These "approximations" are defined slightly
260/// differently from the ones in the Pornin's method for modular inversion: instead
261/// of taking the 33 high and 31 low bits of the n-bit representation of an argument,
262/// the 32 high and 32 low bits are taken
263fn approximate<const L: usize>(x: &LInt<L>, y: &LInt<L>) -> (u64, u64, bool) {
264    debug_assert!(
265        !(x.is_negative() || y.is_negative()),
266        "Both the arguments must be non-negative!"
267    );
268    debug_assert!(
269        (*x != LInt::ZERO) || (*y != LInt::ZERO),
270        "At least one argument must be non-zero!"
271    );
272    let mut i = L - 1;
273    while (x.0[i] == 0) && (y.0[i] == 0) {
274        i -= 1;
275    }
276    if i == 0 {
277        return (x.0[0], y.0[0], true);
278    }
279    let mut h = (x.0[i], y.0[i]);
280    let z = h.0.leading_zeros().min(h.1.leading_zeros());
281    h = (h.0 << z, h.1 << z);
282    if z > 32 {
283        h.0 |= x.0[i - 1] >> z;
284        h.1 |= y.0[i - 1] >> z;
285    }
286    let h = (h.0 & u64::MAX << 32, h.1 & u64::MAX << 32);
287    let l = (x.0[0] & u64::MAX >> 32, y.0[0] & u64::MAX >> 32);
288    (h.0 | l.0, h.1 | l.1, false)
289}
290
/// Computes the Jacobi symbol ("n" / "d") and multiplies it by -1, iff the
/// second-lowest bit of "t" is set (the other bits of "t" are ignored).
/// The value of "d" must be odd as required by the definition of the Jacobi
/// symbol; for even values of "d" the behavior is not defined.
/// The computation follows the binary Euclidean algorithm
fn jacobinary(mut n: u64, mut d: u64, mut t: u64) -> i64 {
    debug_assert!(d & 1 > 0, "The second argument must be odd!");
    while n != 0 {
        if n & 1 == 0 {
            // Removing all the factors 2 from "n" at once. The symbol (2 / d) is -1,
            // iff d mod 8 is in {3, 5}, and it matters only for an odd number of factors
            let twos = n.trailing_zeros();
            t ^= (d ^ d >> 1) & u64::from(twos << 1);
            n >>= twos;
            continue;
        }
        if n < d {
            // The quadratic reciprocity law: the sign flips, iff n = d = 3 (mod 4)
            core::mem::swap(&mut n, &mut d);
            t ^= n & d;
        }
        // (n / d) = ((n - d) / d) = (2 / d) * (((n - d) / 2) / d)
        n = (n - d) >> 1;
        t ^= d ^ d >> 1;
    }
    // At this point "d" holds the greatest common divisor of the arguments
    if d == 1 {
        1 - (t & 2) as i64
    } else {
        0
    }
}
314
315/// Returns the Jacobi symbol ("n" / "d") computed by means of the modification
316/// of the the Pornin's method for modular inversion. The arguments are unsigned
317/// big integers in the form of arrays of 64-bit chunks, the ordering of which
318/// is little-endian. The value of "d" must be odd in accordance with the Jacobi
319/// symbol definition. Both the arguments must be less than 2 ^ (64 * L - 31).
320/// For an incorrect input, the behavior of the function is undefined. The method
321/// differs from the Pornin's method for modular inversion in absence of the parts,
322/// which are not necessary to compute the greatest common divisor of arguments,
323/// presence of the parts used to compute the Jacobi symbol, which are based on
324/// the properties of the modified Jacobi symbol (x / |y|) described by M. Hamburg,
325/// and some original optimizations. Only these differences have been commented;
326/// the aforesaid Pornin's method and the used ideas of M. Hamburg were given here:
327/// - T. Pornin, "Optimized Binary GCD for Modular Inversion",
328/// <https://eprint.iacr.org/2020/972.pdf>
329/// - M. Hamburg, "Computing the Jacobi symbol using Bernstein-Yang",
330/// <https://eprint.iacr.org/2021/1271.pdf>
331pub fn jacobi<const L: usize>(n: &[u64], d: &[u64]) -> i64 {
332    // Instead of the variable "j" taking the values from {-1, 1} and satysfying
333    // at the end of the outer loop iteration the equation J = "j" * ("n" / |"d"|)
334    // for the modified Jacobi symbol ("n" / |"d"|) and the sought Jacobi symbol J,
335    // we store the sign bit of "j" in the second-lowest bit of "t" for optimization
336    // purposes. This approach was influenced by the paper by M. Hamburg
337    let (mut n, mut d, mut t) = (LInt::<L>::new(n), LInt::<L>::new(d), 0u64);
338    debug_assert!(d.0[0] & 1 > 0, "The second argument must be odd!");
339    debug_assert!(
340        n.0[L - 1].leading_zeros().min(d.0[L - 1].leading_zeros()) >= 31,
341        "Both the arguments must be less than 2 ^ (64 * L - 31)!"
342    );
343    loop {
344        // The inner loop performs 30 iterations instead of 31 ones in the aforementioned
345        // Pornin's method, and the "approximations" of "n" and "d" retain 32 of the lowest
346        // bits instead of 31 in that method. These modifications allow the values of the
347        // "approximation" variables to be equal modulo 8 to the corresponding "precise"
348        // variables' values, which would have been computed, if the "precise" variables
349        // had been updated in the inner loop along with the "approximations". This equality
350        // modulo 8 is used to update the second-lowest bit of "t" in accordance with the
351        // properties of the modified Jacobi symbol (x / |y|). The admissibility of these
352        // modifications has been proven using the appropriately modified Pornin's theorems
353        let (mut u, mut v, mut i) = ((1i64, 0i64), (0i64, 1i64), 30);
354        let (mut a, mut b, precise) = approximate(&n, &d);
355        // When each "approximation" variable has the same value as the corresponding "precise"
356        // one, the computation is accomplished using the short-arithmetic method of the Jacobi
357        // symbol calculation by means of the binary Euclidean algorithm. This approach aims at
358        // avoiding the parts of the final computations, which are related to long arithmetics
359        if precise {
360            return jacobinary(a, b, t);
361        }
362        while i > 0 {
363            if a & 1 > 0 {
364                if a < b {
365                    (a, b, u, v) = (b, a, v, u);
366                    // In both the aforesaid Pornin's method and its modification "n" and "d"
367                    // could not become negative simultaneously even if they were updated after
368                    // each iteration of the inner loop. Also at this point they both have odd
369                    // values. Therefore, the quadratic reciprocity law for the modified Jacobi
370                    // symbol (x / |y|) can be used. According to it, if both x and y are odd
371                    // numbers, among which there is a positive one, then for x = y = 3 (mod 4)
372                    // we have (x / |y|) = -(y / |x|) and for either x or y equal 1 modulo 4
373                    // the symbols (x / |y|) and (y / |x|) are equal
374                    t ^= a & b;
375                }
376                a = (a - b) >> 1;
377                u = (u.0 - v.0, u.1 - v.1);
378                v = (v.0 << 1, v.1 << 1);
379                // The modified Jacobi symbol (2 / |y|) is -1, iff y mod 8 is {3, 5}
380                t ^= b ^ b >> 1;
381                i -= 1;
382            } else {
383                // Performing the batch of sequential iterations, which divide "a" by 2
384                let z = i.min(a.trailing_zeros());
385                // The modified Jacobi symbol (2 / |y|) is -1, iff y mod 8 is {3, 5}. However,
386                // we do not need its value for a batch with an even number of divisions by 2
387                t ^= (b ^ b >> 1) & (z << 1) as u64;
388                v = (v.0 << z, v.1 << z);
389                a >>= z;
390                i -= z;
391            }
392        }
393        (n, d) = ((&n * u.0 + &d * u.1) >> 30, (&n * v.0 + &d * v.1) >> 30);
394
395        // This fragment is present to guarantee the correct behavior of the function
396        // in the case of arguments, whose greatest common divisor is no less than 2^64
397        if n == LInt::ZERO {
398            // In both the aforesaid Pornin's method and its modification the pair of the values
399            // of "n" and "d" after the divergence point contains a positive number and a negative
400            // one. Since the value of "n" is 0, the divergence point has not been reached by the
401            // inner loop this time, so there is no need to check whether "d" is equal to -1
402            return (d == LInt::ONE) as i64 * (1 - (t & 2) as i64);
403        }
404
405        if n.is_negative() {
406            // Since in both the aforesaid Pornin's method and its modification "d" is always odd
407            // and cannot become negative simultaneously with "n", the value of "d" is positive.
408            // The modified Jacobi symbol (-1 / |y|) for a positive y is -1, iff y mod 4 = 3
409            t ^= d.0[0];
410            n = -n;
411        } else if d.is_negative() {
412            // The modified Jacobi symbols (x / |y|) and (x / |-y|) are equal, so "t" is not updated
413            d = -d;
414        }
415    }
416}