halo2curves_axiom/ff_ext/jacobi.rs
1use core::cmp::PartialEq;
2use std::ops::{Add, Mul, Neg, Shr, Sub};
3
/// Big signed (64 * L)-bit integer type, whose variables store
/// numbers in the two's complement code as arrays of 64-bit chunks.
/// The ordering of the chunks in these arrays is little-endian.
/// The arithmetic operations for this type are wrapping ones
#[derive(Clone, Debug)]
pub struct LInt<const L: usize>([u64; L]);

impl<const L: usize> LInt<L> {
    /// Representation of -1 (all bits of all chunks are set)
    pub const MINUS_ONE: Self = Self([u64::MAX; L]);

    /// Representation of 0
    pub const ZERO: Self = Self([0; L]);

    /// Representation of 1 (only the lowest bit of the lowest chunk is set)
    pub const ONE: Self = {
        let mut data = [0; L];
        data[0] = 1;
        Self(data)
    };

    /// Returns the number, which is stored as the specified
    /// sequence padded with zeros to length L.
    ///
    /// # Panics
    ///
    /// Panics if the input sequence is longer than L
    pub fn new(data: &[u64]) -> Self {
        let mut number = Self::ZERO;
        // Slicing the destination panics when "data" holds more than L chunks
        number.0[..data.len()].copy_from_slice(data);
        number
    }

    /// Returns "true" iff the current number is negative, i.e. iff
    /// the highest bit of the most significant chunk is set
    #[inline]
    pub fn is_negative(&self) -> bool {
        self.0[L - 1] > (u64::MAX >> 1)
    }

    /// Returns a tuple representing the sum of the first two arguments and the bit
    /// described by the third argument. The first element of the tuple is this sum
    /// modulo 2^64, the second one indicates whether the sum is no less than 2^64
    #[inline]
    fn sum(first: u64, second: u64, carry: bool) -> (u64, bool) {
        // The implementation is inspired by the "carrying_add" function from this source:
        // https://github.com/rust-lang/rust/blob/master/library/core/src/num/uint_macros.rs
        // At most one of the two additions can overflow, so OR-ing the flags is sufficient
        let (second, carry) = second.overflowing_add(carry as u64);
        let (first, high) = first.overflowing_add(second);
        (first, carry || high)
    }

    /// Returns "(low, high)", where "high * 2^64 + low = first * second + carry + summand"
    #[inline]
    fn prodsum(first: u64, second: u64, summand: u64, carry: u64) -> (u64, u64) {
        // The result always fits in 128 bits: (2^64 - 1)^2 + 2 * (2^64 - 1) = 2^128 - 1
        let all = (first as u128) * (second as u128) + (carry as u128) + (summand as u128);
        (all as u64, (all >> u64::BITS) as u64)
    }
}
59
60impl<const L: usize> PartialEq for LInt<L> {
61 fn eq(&self, other: &Self) -> bool {
62 self.0 == other.0
63 }
64}
65
66impl<const L: usize> Shr<u32> for &LInt<L> {
67 type Output = LInt<L>;
68 /// Returns the result of applying the arithmetic right shift to the current number.
69 /// The specified bit quantity the number is shifted by must lie in {1, 2, ..., 63}.
70 /// For the quantities outside of the range, the behavior of the method is undefined
71 fn shr(self, bits: u32) -> Self::Output {
72 debug_assert!(
73 (bits > 0) && (bits < 64),
74 "Cannot shift by 0 or more than 63 bits!"
75 );
76 let (mut data, right) = ([0; L], u64::BITS - bits);
77
78 for (i, d) in data.iter_mut().enumerate().take(L - 1) {
79 *d = (self.0[i] >> bits) | (self.0[i + 1] << right);
80 }
81 data[L - 1] = self.0[L - 1] >> bits;
82 if self.is_negative() {
83 data[L - 1] |= u64::MAX << right;
84 }
85 LInt::<L>(data)
86 }
87}
88
89impl<const L: usize> Shr<u32> for LInt<L> {
90 type Output = LInt<L>;
91 fn shr(self, bits: u32) -> Self::Output {
92 &self >> bits
93 }
94}
95
96impl<const L: usize> Add for &LInt<L> {
97 type Output = LInt<L>;
98 fn add(self, other: Self) -> Self::Output {
99 let (mut data, mut carry) = ([0; L], false);
100 for (i, d) in data.iter_mut().enumerate().take(L) {
101 (*d, carry) = Self::Output::sum(self.0[i], other.0[i], carry);
102 }
103 LInt::<L>(data)
104 }
105}
106
107impl<const L: usize> Add<&LInt<L>> for LInt<L> {
108 type Output = LInt<L>;
109 fn add(self, other: &Self) -> Self::Output {
110 &self + other
111 }
112}
113
114impl<const L: usize> Add for LInt<L> {
115 type Output = LInt<L>;
116 fn add(self, other: Self) -> Self::Output {
117 &self + &other
118 }
119}
120
121impl<const L: usize> Sub for &LInt<L> {
122 type Output = LInt<L>;
123 fn sub(self, other: Self) -> Self::Output {
124 // For the two's complement code the additive negation is the result of
125 // adding 1 to the bitwise inverted argument's representation. Thus, for
126 // any encoded integers x and y we have x - y = x + !y + 1, where "!" is
127 // the bitwise inversion and addition is done according to the rules of
128 // the code. The algorithm below uses this formula and is the modified
129 // addition algorithm, where the carry flag is initialized with "true"
130 // and the chunks of the second argument are bitwise inverted
131 let (mut data, mut carry) = ([0; L], true);
132 for (i, d) in data.iter_mut().enumerate().take(L) {
133 (*d, carry) = Self::Output::sum(self.0[i], !other.0[i], carry);
134 }
135 LInt::<L>(data)
136 }
137}
138
139impl<const L: usize> Sub<&LInt<L>> for LInt<L> {
140 type Output = LInt<L>;
141 fn sub(self, other: &Self) -> Self::Output {
142 &self - other
143 }
144}
145
146impl<const L: usize> Sub for LInt<L> {
147 type Output = LInt<L>;
148 fn sub(self, other: Self) -> Self::Output {
149 &self - &other
150 }
151}
152
153impl<const L: usize> Neg for &LInt<L> {
154 type Output = LInt<L>;
155 fn neg(self) -> Self::Output {
156 // For the two's complement code the additive negation is the result
157 // of adding 1 to the bitwise inverted argument's representation
158 let (mut data, mut carry) = ([0; L], true);
159 for (i, d) in data.iter_mut().enumerate().take(L) {
160 (*d, carry) = (!self.0[i]).overflowing_add(carry as u64);
161 }
162 LInt::<L>(data)
163 }
164}
165
166impl<const L: usize> Neg for LInt<L> {
167 type Output = LInt<L>;
168 fn neg(self) -> Self::Output {
169 -&self
170 }
171}
172
173impl<const L: usize> Mul for &LInt<L> {
174 type Output = LInt<L>;
175 fn mul(self, other: Self) -> Self::Output {
176 let mut data = [0; L];
177 for i in 0..L {
178 let mut carry = 0;
179 for k in 0..(L - i) {
180 (data[i + k], carry) =
181 Self::Output::prodsum(self.0[i], other.0[k], data[i + k], carry);
182 }
183 }
184 LInt::<L>(data)
185 }
186}
187
188impl<const L: usize> Mul<&LInt<L>> for LInt<L> {
189 type Output = LInt<L>;
190 fn mul(self, other: &Self) -> Self::Output {
191 &self * other
192 }
193}
194
195impl<const L: usize> Mul for LInt<L> {
196 type Output = LInt<L>;
197 fn mul(self, other: Self) -> Self::Output {
198 &self * &other
199 }
200}
201
202impl<const L: usize> Mul<i64> for &LInt<L> {
203 type Output = LInt<L>;
204 fn mul(self, other: i64) -> Self::Output {
205 let mut data = [0; L];
206 // If the short multiplicand is non-negative, the standard multiplication
207 // algorithm is performed. Otherwise, the product of the additively negated
208 // multiplicands is found as follows. Since for the two's complement code
209 // the additive negation is the result of adding 1 to the bitwise inverted
210 // argument's representation, for any encoded integers x and y we have
211 // x * y = (-x) * (-y) = (!x + 1) * (-y) = !x * (-y) + (-y), where "!" is
212 // the bitwise inversion and arithmetic operations are performed according
213 // to the rules of the code. If the short multiplicand is negative, the
214 // algorithm below uses this formula by substituting the short multiplicand
215 // for y and becomes the modified standard multiplication algorithm, where
216 // the carry variable is being initialized with the additively negated short
217 // multiplicand and the chunks of the long multiplicand are bitwise inverted
218 let (other, mut carry, mask) = if other < 0 {
219 (-other as u64, -other as u64, u64::MAX)
220 } else {
221 (other as u64, 0, 0)
222 };
223 for (i, d) in data.iter_mut().enumerate().take(L) {
224 (*d, carry) = Self::Output::prodsum(self.0[i] ^ mask, other, 0, carry);
225 }
226 LInt::<L>(data)
227 }
228}
229
230impl<const L: usize> Mul<i64> for LInt<L> {
231 type Output = LInt<L>;
232 fn mul(self, other: i64) -> Self::Output {
233 &self * other
234 }
235}
236
237impl<const L: usize> Mul<&LInt<L>> for i64 {
238 type Output = LInt<L>;
239 fn mul(self, other: &LInt<L>) -> Self::Output {
240 other * self
241 }
242}
243
244impl<const L: usize> Mul<LInt<L>> for i64 {
245 type Output = LInt<L>;
246 fn mul(self, other: LInt<L>) -> Self::Output {
247 other * self
248 }
249}
250
251/// Returns the "approximations" of the arguments and the flag indicating whether
252/// both arguments are equal to their "approximations". Both the arguments must be
253/// non-negative, and at least one of them must be non-zero. For an incorrect input,
254/// the behavior of the function is undefined. These "approximations" are defined
255/// in the following way. Let n be the bit length of the largest argument without
256/// leading zeros. For n > 64 the "approximation" of the argument, which equals v,
257/// is (v div 2 ^ (n - 32)) * 2 ^ 32 + (v mod 2 ^ 32), i.e. it retains the high and
258/// low bits of the n-bit representation of v. If n does not exceed 64, an argument
259/// and its "approximation" are equal. These "approximations" are defined slightly
260/// differently from the ones in the Pornin's method for modular inversion: instead
261/// of taking the 33 high and 31 low bits of the n-bit representation of an argument,
262/// the 32 high and 32 low bits are taken
263fn approximate<const L: usize>(x: &LInt<L>, y: &LInt<L>) -> (u64, u64, bool) {
264 debug_assert!(
265 !(x.is_negative() || y.is_negative()),
266 "Both the arguments must be non-negative!"
267 );
268 debug_assert!(
269 (*x != LInt::ZERO) || (*y != LInt::ZERO),
270 "At least one argument must be non-zero!"
271 );
272 let mut i = L - 1;
273 while (x.0[i] == 0) && (y.0[i] == 0) {
274 i -= 1;
275 }
276 if i == 0 {
277 return (x.0[0], y.0[0], true);
278 }
279 let mut h = (x.0[i], y.0[i]);
280 let z = h.0.leading_zeros().min(h.1.leading_zeros());
281 h = (h.0 << z, h.1 << z);
282 if z > 32 {
283 h.0 |= x.0[i - 1] >> z;
284 h.1 |= y.0[i - 1] >> z;
285 }
286 let h = (h.0 & u64::MAX << 32, h.1 & u64::MAX << 32);
287 let l = (x.0[0] & u64::MAX >> 32, y.0[0] & u64::MAX >> 32);
288 (h.0 | l.0, h.1 | l.1, false)
289}
290
/// Returns the Jacobi symbol ("n" / "d") multiplied by either 1 or -1.
/// The latter multiplicand is -1 iff the second-lowest bit of "t" is 1.
/// The value of "d" must be odd in accordance with the Jacobi symbol
/// definition. For even values of "d", the behavior is not defined.
/// The implementation is based on the binary Euclidean algorithm
fn jacobinary(mut n: u64, mut d: u64, mut t: u64) -> i64 {
    debug_assert!(d & 1 > 0, "The second argument must be odd!");
    while n != 0 {
        if n & 1 == 0 {
            // Removing all the factors 2 of "n" at once. The symbol (2 / d) is -1
            // iff d mod 8 is in {3, 5}, which is the second-lowest bit of
            // d ^ (d >> 1); it matters only for an odd number of removed factors
            let shift = n.trailing_zeros();
            t ^= (d ^ (d >> 1)) & (u64::from(shift) << 1);
            n >>= shift;
            continue;
        }
        if n < d {
            // Quadratic reciprocity: the sign flips iff both values equal 3 modulo 4,
            // i.e. iff the second-lowest bit of their bitwise conjunction is set
            core::mem::swap(&mut n, &mut d);
            t ^= n & d;
        }
        // Both values are odd, so their difference is even and is halved right away
        n = (n - d) >> 1;
        t ^= d ^ (d >> 1);
    }
    // Here "d" is the greatest common divisor; the symbol is 0 unless it equals 1
    if d == 1 {
        1 - (t & 2) as i64
    } else {
        0
    }
}
314
/// Returns the Jacobi symbol ("n" / "d") computed by means of the modification
/// of the Pornin's method for modular inversion. The arguments are unsigned
/// big integers in the form of arrays of 64-bit chunks, the ordering of which
/// is little-endian. The value of "d" must be odd in accordance with the Jacobi
/// symbol definition. Both the arguments must be less than 2 ^ (64 * L - 31).
/// For an incorrect input, the behavior of the function is undefined. The method
/// differs from the Pornin's method for modular inversion in absence of the parts,
/// which are not necessary to compute the greatest common divisor of arguments,
/// presence of the parts used to compute the Jacobi symbol, which are based on
/// the properties of the modified Jacobi symbol (x / |y|) described by M. Hamburg,
/// and some original optimizations. Only these differences have been commented;
/// the aforesaid Pornin's method and the used ideas of M. Hamburg were given here:
/// - T. Pornin, "Optimized Binary GCD for Modular Inversion",
///   <https://eprint.iacr.org/2020/972.pdf>
/// - M. Hamburg, "Computing the Jacobi symbol using Bernstein-Yang",
///   <https://eprint.iacr.org/2021/1271.pdf>
///
/// Each argument is converted with "LInt::new", which is documented to panic
/// when the input sequence is longer than L chunks
pub fn jacobi<const L: usize>(n: &[u64], d: &[u64]) -> i64 {
    // Instead of the variable "j" taking the values from {-1, 1} and satisfying
    // at the end of the outer loop iteration the equation J = "j" * ("n" / |"d"|)
    // for the modified Jacobi symbol ("n" / |"d"|) and the sought Jacobi symbol J,
    // we store the sign bit of "j" in the second-lowest bit of "t" for optimization
    // purposes. This approach was influenced by the paper by M. Hamburg
    let (mut n, mut d, mut t) = (LInt::<L>::new(n), LInt::<L>::new(d), 0u64);
    debug_assert!(d.0[0] & 1 > 0, "The second argument must be odd!");
    debug_assert!(
        n.0[L - 1].leading_zeros().min(d.0[L - 1].leading_zeros()) >= 31,
        "Both the arguments must be less than 2 ^ (64 * L - 31)!"
    );
    loop {
        // The inner loop performs 30 iterations instead of 31 ones in the aforementioned
        // Pornin's method, and the "approximations" of "n" and "d" retain 32 of the lowest
        // bits instead of 31 in that method. These modifications allow the values of the
        // "approximation" variables to be equal modulo 8 to the corresponding "precise"
        // variables' values, which would have been computed, if the "precise" variables
        // had been updated in the inner loop along with the "approximations". This equality
        // modulo 8 is used to update the second-lowest bit of "t" in accordance with the
        // properties of the modified Jacobi symbol (x / |y|). The admissibility of these
        // modifications has been proven using the appropriately modified Pornin's theorems.
        // The pairs "u" and "v" accumulate the coefficients of the linear combinations
        // of "n" and "d", which are applied to the "precise" variables after the inner
        // loop, and "i" counts the remaining inner loop iterations
        let (mut u, mut v, mut i) = ((1i64, 0i64), (0i64, 1i64), 30);
        let (mut a, mut b, precise) = approximate(&n, &d);
        // When each "approximation" variable has the same value as the corresponding "precise"
        // one, the computation is accomplished using the short-arithmetic method of the Jacobi
        // symbol calculation by means of the binary Euclidean algorithm. This approach aims at
        // avoiding the parts of the final computations, which are related to long arithmetics
        if precise {
            return jacobinary(a, b, t);
        }
        while i > 0 {
            if a & 1 > 0 {
                if a < b {
                    (a, b, u, v) = (b, a, v, u);
                    // In both the aforesaid Pornin's method and its modification "n" and "d"
                    // could not become negative simultaneously even if they were updated after
                    // each iteration of the inner loop. Also at this point they both have odd
                    // values. Therefore, the quadratic reciprocity law for the modified Jacobi
                    // symbol (x / |y|) can be used. According to it, if both x and y are odd
                    // numbers, among which there is a positive one, then for x = y = 3 (mod 4)
                    // we have (x / |y|) = -(y / |x|) and for either x or y equal 1 modulo 4
                    // the symbols (x / |y|) and (y / |x|) are equal
                    t ^= a & b;
                }
                a = (a - b) >> 1;
                u = (u.0 - v.0, u.1 - v.1);
                v = (v.0 << 1, v.1 << 1);
                // The modified Jacobi symbol (2 / |y|) is -1, iff y mod 8 is {3, 5}
                t ^= b ^ b >> 1;
                i -= 1;
            } else {
                // Performing the batch of sequential iterations, which divide "a" by 2.
                // The batch is limited by the number of the remaining iterations "i"
                let z = i.min(a.trailing_zeros());
                // The modified Jacobi symbol (2 / |y|) is -1, iff y mod 8 is {3, 5}. However,
                // we do not need its value for a batch with an even number of divisions by 2
                t ^= (b ^ b >> 1) & (z << 1) as u64;
                v = (v.0 << z, v.1 << z);
                a >>= z;
                i -= z;
            }
        }
        // Applying the accumulated coefficients to the "precise" variables; the right
        // shift by 30 bits corresponds to the 30 iterations performed by the inner loop
        (n, d) = ((&n * u.0 + &d * u.1) >> 30, (&n * v.0 + &d * v.1) >> 30);

        // This fragment is present to guarantee the correct behavior of the function
        // in the case of arguments, whose greatest common divisor is no less than 2^64
        if n == LInt::ZERO {
            // In both the aforesaid Pornin's method and its modification the pair of the values
            // of "n" and "d" after the divergence point contains a positive number and a negative
            // one. Since the value of "n" is 0, the divergence point has not been reached by the
            // inner loop this time, so there is no need to check whether "d" is equal to -1
            return (d == LInt::ONE) as i64 * (1 - (t & 2) as i64);
        }

        if n.is_negative() {
            // Since in both the aforesaid Pornin's method and its modification "d" is always odd
            // and cannot become negative simultaneously with "n", the value of "d" is positive.
            // The modified Jacobi symbol (-1 / |y|) for a positive y is -1, iff y mod 4 = 3
            t ^= d.0[0];
            n = -n;
        } else if d.is_negative() {
            // The modified Jacobi symbols (x / |y|) and (x / |-y|) are equal, so "t" is not updated
            d = -d;
        }
    }
}
416}