halo2curves/ff_ext/jacobi.rs
1use core::cmp::PartialEq;
2use std::ops::{Add, Mul, Neg, Shr, Sub};
3
/// A signed integer that is (64 * L) bits wide and kept in the two's
/// complement code. The value is split into L chunks of 64 bits each,
/// and the least significant chunk is stored first (little-endian order).
/// Every arithmetic operation defined for this type wraps on overflow
#[derive(Clone)]
pub struct LInt<const L: usize>([u64; L]);

impl<const L: usize> LInt<L> {
    /// The number -1: every bit of every chunk is set
    pub const MINUS_ONE: Self = Self([u64::MAX; L]);

    /// The number 0: every chunk is zero
    pub const ZERO: Self = Self([0; L]);

    /// The number 1: only the lowest bit of the lowest chunk is set
    pub const ONE: Self = {
        let mut chunks = [0; L];
        chunks[0] = 1;
        Self(chunks)
    };

    /// Builds a number from the given little-endian chunks, filling the
    /// missing high chunks with zeros up to the length L.
    ///
    /// # Panics
    ///
    /// Panics if `data` holds more than L chunks
    pub fn new(data: &[u64]) -> Self {
        let mut chunks = [0u64; L];
        chunks[..data.len()].copy_from_slice(data);
        Self(chunks)
    }

    /// Tells whether the number is strictly below zero
    #[inline]
    pub fn is_negative(&self) -> bool {
        // In the two's complement code the sign is the top bit of the highest chunk
        self.0[L - 1] >> (u64::BITS - 1) != 0
    }

    /// Adds two 64-bit chunks and an incoming carry bit. The first element
    /// of the result is the sum modulo 2^64, the second one is the outgoing
    /// carry, i.e. it is "true" iff the full sum is at least 2^64
    #[inline]
    fn sum(first: u64, second: u64, carry: bool) -> (u64, bool) {
        // At most one of the two partial additions can overflow, so the
        // outgoing carry is simply the disjunction of both overflow flags
        let (partial, overflow_a) = first.overflowing_add(second);
        let (total, overflow_b) = partial.overflowing_add(u64::from(carry));
        (total, overflow_a | overflow_b)
    }

    /// Computes "first * second + summand + carry" without losing precision
    /// and returns it as "(low, high)", where "high * 2^64 + low" is the
    /// exact value. The result always fits into 128 bits
    #[inline]
    fn prodsum(first: u64, second: u64, summand: u64, carry: u64) -> (u64, u64) {
        let wide =
            u128::from(first) * u128::from(second) + u128::from(summand) + u128::from(carry);
        (wide as u64, (wide >> u64::BITS) as u64)
    }
}
61
62impl<const L: usize> PartialEq for LInt<L> {
63 fn eq(&self, other: &Self) -> bool {
64 self.0 == other.0
65 }
66}
67
68impl<const L: usize> Shr<u32> for &LInt<L> {
69 type Output = LInt<L>;
70 /// Returns the result of applying the arithmetic right shift to the current
71 /// number. The specified bit quantity the number is shifted by must lie
72 /// in {1, 2, ..., 63}. For the quantities outside of the range, the
73 /// behavior of the method is undefined
74 fn shr(self, bits: u32) -> Self::Output {
75 debug_assert!(
76 (bits > 0) && (bits < 64),
77 "Cannot shift by 0 or more than 63 bits!"
78 );
79 let (mut data, right) = ([0; L], u64::BITS - bits);
80
81 for (i, d) in data.iter_mut().enumerate().take(L - 1) {
82 *d = (self.0[i] >> bits) | (self.0[i + 1] << right);
83 }
84 data[L - 1] = self.0[L - 1] >> bits;
85 if self.is_negative() {
86 data[L - 1] |= u64::MAX << right;
87 }
88 LInt::<L>(data)
89 }
90}
91
92impl<const L: usize> Shr<u32> for LInt<L> {
93 type Output = LInt<L>;
94 fn shr(self, bits: u32) -> Self::Output {
95 &self >> bits
96 }
97}
98
99impl<const L: usize> Add for &LInt<L> {
100 type Output = LInt<L>;
101 fn add(self, other: Self) -> Self::Output {
102 let (mut data, mut carry) = ([0; L], false);
103 for (i, d) in data.iter_mut().enumerate().take(L) {
104 (*d, carry) = Self::Output::sum(self.0[i], other.0[i], carry);
105 }
106 LInt::<L>(data)
107 }
108}
109
110impl<const L: usize> Add<&LInt<L>> for LInt<L> {
111 type Output = LInt<L>;
112 fn add(self, other: &Self) -> Self::Output {
113 &self + other
114 }
115}
116
117impl<const L: usize> Add for LInt<L> {
118 type Output = LInt<L>;
119 fn add(self, other: Self) -> Self::Output {
120 &self + &other
121 }
122}
123
124impl<const L: usize> Sub for &LInt<L> {
125 type Output = LInt<L>;
126 fn sub(self, other: Self) -> Self::Output {
127 // For the two's complement code the additive negation is the result of
128 // adding 1 to the bitwise inverted argument's representation. Thus, for
129 // any encoded integers x and y we have x - y = x + !y + 1, where "!" is
130 // the bitwise inversion and addition is done according to the rules of
131 // the code. The algorithm below uses this formula and is the modified
132 // addition algorithm, where the carry flag is initialized with "true"
133 // and the chunks of the second argument are bitwise inverted
134 let (mut data, mut carry) = ([0; L], true);
135 for (i, d) in data.iter_mut().enumerate().take(L) {
136 (*d, carry) = Self::Output::sum(self.0[i], !other.0[i], carry);
137 }
138 LInt::<L>(data)
139 }
140}
141
142impl<const L: usize> Sub<&LInt<L>> for LInt<L> {
143 type Output = LInt<L>;
144 fn sub(self, other: &Self) -> Self::Output {
145 &self - other
146 }
147}
148
149impl<const L: usize> Sub for LInt<L> {
150 type Output = LInt<L>;
151 fn sub(self, other: Self) -> Self::Output {
152 &self - &other
153 }
154}
155
156impl<const L: usize> Neg for &LInt<L> {
157 type Output = LInt<L>;
158 fn neg(self) -> Self::Output {
159 // For the two's complement code the additive negation is the result
160 // of adding 1 to the bitwise inverted argument's representation
161 let (mut data, mut carry) = ([0; L], true);
162 for (i, d) in data.iter_mut().enumerate().take(L) {
163 (*d, carry) = (!self.0[i]).overflowing_add(carry as u64);
164 }
165 LInt::<L>(data)
166 }
167}
168
169impl<const L: usize> Neg for LInt<L> {
170 type Output = LInt<L>;
171 fn neg(self) -> Self::Output {
172 -&self
173 }
174}
175
176impl<const L: usize> Mul for &LInt<L> {
177 type Output = LInt<L>;
178 fn mul(self, other: Self) -> Self::Output {
179 let mut data = [0; L];
180 for i in 0..L {
181 let mut carry = 0;
182 for k in 0..(L - i) {
183 (data[i + k], carry) =
184 Self::Output::prodsum(self.0[i], other.0[k], data[i + k], carry);
185 }
186 }
187 LInt::<L>(data)
188 }
189}
190
191impl<const L: usize> Mul<&LInt<L>> for LInt<L> {
192 type Output = LInt<L>;
193 fn mul(self, other: &Self) -> Self::Output {
194 &self * other
195 }
196}
197
198impl<const L: usize> Mul for LInt<L> {
199 type Output = LInt<L>;
200 fn mul(self, other: Self) -> Self::Output {
201 &self * &other
202 }
203}
204
205impl<const L: usize> Mul<i64> for &LInt<L> {
206 type Output = LInt<L>;
207 fn mul(self, other: i64) -> Self::Output {
208 let mut data = [0; L];
209 // If the short multiplicand is non-negative, the standard multiplication
210 // algorithm is performed. Otherwise, the product of the additively negated
211 // multiplicands is found as follows. Since for the two's complement code
212 // the additive negation is the result of adding 1 to the bitwise inverted
213 // argument's representation, for any encoded integers x and y we have
214 // x * y = (-x) * (-y) = (!x + 1) * (-y) = !x * (-y) + (-y), where "!" is
215 // the bitwise inversion and arithmetic operations are performed according
216 // to the rules of the code. If the short multiplicand is negative, the
217 // algorithm below uses this formula by substituting the short multiplicand
218 // for y and becomes the modified standard multiplication algorithm, where
219 // the carry variable is being initialized with the additively negated short
220 // multiplicand and the chunks of the long multiplicand are bitwise inverted
221 let (other, mut carry, mask) = if other < 0 {
222 (-other as u64, -other as u64, u64::MAX)
223 } else {
224 (other as u64, 0, 0)
225 };
226 for (i, d) in data.iter_mut().enumerate().take(L) {
227 (*d, carry) = Self::Output::prodsum(self.0[i] ^ mask, other, 0, carry);
228 }
229 LInt::<L>(data)
230 }
231}
232
233impl<const L: usize> Mul<i64> for LInt<L> {
234 type Output = LInt<L>;
235 fn mul(self, other: i64) -> Self::Output {
236 &self * other
237 }
238}
239
240impl<const L: usize> Mul<&LInt<L>> for i64 {
241 type Output = LInt<L>;
242 fn mul(self, other: &LInt<L>) -> Self::Output {
243 other * self
244 }
245}
246
247impl<const L: usize> Mul<LInt<L>> for i64 {
248 type Output = LInt<L>;
249 fn mul(self, other: LInt<L>) -> Self::Output {
250 other * self
251 }
252}
253
254/// Returns the "approximations" of the arguments and the flag indicating
255/// whether both arguments are equal to their "approximations". Both the
256/// arguments must be non-negative, and at least one of them must be non-zero.
257/// For an incorrect input, the behavior of the function is undefined. These
258/// "approximations" are defined in the following way. Let n be the bit length
259/// of the largest argument without leading zeros. For n > 64 the
260/// "approximation" of the argument, which equals v, is (v div 2 ^ (n - 32)) * 2
261/// ^ 32 + (v mod 2 ^ 32), i.e. it retains the high and low bits of the n-bit
262/// representation of v. If n does not exceed 64, an argument
263/// and its "approximation" are equal. These "approximations" are defined
264/// slightly differently from the ones in the Pornin's method for modular
265/// inversion: instead of taking the 33 high and 31 low bits of the n-bit
266/// representation of an argument, the 32 high and 32 low bits are taken
267fn approximate<const L: usize>(x: &LInt<L>, y: &LInt<L>) -> (u64, u64, bool) {
268 debug_assert!(
269 !(x.is_negative() || y.is_negative()),
270 "Both the arguments must be non-negative!"
271 );
272 debug_assert!(
273 (*x != LInt::ZERO) || (*y != LInt::ZERO),
274 "At least one argument must be non-zero!"
275 );
276 let mut i = L - 1;
277 while (x.0[i] == 0) && (y.0[i] == 0) {
278 i -= 1;
279 }
280 if i == 0 {
281 return (x.0[0], y.0[0], true);
282 }
283 let mut h = (x.0[i], y.0[i]);
284 let z = h.0.leading_zeros().min(h.1.leading_zeros());
285 h = (h.0 << z, h.1 << z);
286 if z > 32 {
287 h.0 |= x.0[i - 1] >> z;
288 h.1 |= y.0[i - 1] >> z;
289 }
290 let h = (h.0 & u64::MAX << 32, h.1 & u64::MAX << 32);
291 let l = (x.0[0] & u64::MAX >> 32, y.0[0] & u64::MAX >> 32);
292 (h.0 | l.0, h.1 | l.1, false)
293}
294
/// Returns the Jacobi symbol ("n" / "d") multiplied by either 1 or -1.
/// The latter multiplicand is -1 iff the second-lowest bit of "t" is 1.
/// The value of "d" must be odd in accordance with the Jacobi symbol
/// definition (verified only in debug builds). For even values of "d",
/// the behavior is not defined. The implementation is based on the
/// binary Euclidean algorithm
fn jacobinary(mut n: u64, mut d: u64, mut t: u64) -> i64 {
    debug_assert!(d & 1 > 0, "The second argument must be odd!");
    while n != 0 {
        if n & 1 == 0 {
            // Removing all the factors 2 of "n" at once. The symbol (2 / d)
            // is -1 iff d mod 8 is in {3, 5}, i.e. iff the second-lowest bit
            // of "d ^ (d >> 1)" is set; it matters only for an odd number of
            // the removed factors, which is the lowest bit of "shift"
            let shift = n.trailing_zeros();
            t ^= (d ^ (d >> 1)) & u64::from(shift << 1);
            n >>= shift;
            continue;
        }
        if n < d {
            // Quadratic reciprocity: the sign changes iff both the numbers
            // are equal to 3 modulo 4, i.e. iff the second-lowest bit of
            // "n & d" is set
            core::mem::swap(&mut n, &mut d);
            t ^= n & d;
        }
        // The difference of two odd numbers is even, so it is halved at
        // once, and the sign is corrected by the symbol (2 / d)
        n = (n - d) >> 1;
        t ^= d ^ (d >> 1);
    }
    // At this point "d" is the greatest common divisor of the arguments;
    // the symbol is 0, unless they are coprime
    if d == 1 {
        1 - (t & 2) as i64
    } else {
        0
    }
}
318
319/// Returns the Jacobi symbol ("n" / "d") computed by means of the modification
320/// of the Pornin's method for modular inversion. The arguments are unsigned
321/// big integers in the form of arrays of 64-bit chunks, the ordering of which
322/// is little-endian. The value of "d" must be odd in accordance with the Jacobi
323/// symbol definition. Both the arguments must be less than 2 ^ (64 * L - 31).
324/// For an incorrect input, the behavior of the function is undefined. The
325/// method differs from the Pornin's method for modular inversion in absence of
326/// the parts, which are not necessary to compute the greatest common divisor of
327/// arguments, presence of the parts used to compute the Jacobi symbol, which
328/// are based on the properties of the modified Jacobi symbol (x / |y|)
329/// described by M. Hamburg, and some original optimizations. Only these
330/// differences have been commented; the aforesaid Pornin's method and the used
331/// ideas of M. Hamburg were given here:
332/// - T. Pornin, "Optimized Binary GCD for Modular Inversion",
333/// <https://eprint.iacr.org/2020/972.pdf>
334/// - M. Hamburg, "Computing the Jacobi symbol using Bernstein-Yang",
335/// <https://eprint.iacr.org/2021/1271.pdf>
336pub fn jacobi<const L: usize>(n: &[u64], d: &[u64]) -> i64 {
337 // Instead of the variable "j" taking the values from {-1, 1} and satisfying
338 // at the end of the outer loop iteration the equation J = "j" * ("n" / |"d"|)
339 // for the modified Jacobi symbol ("n" / |"d"|) and the sought Jacobi symbol J,
340 // we store the sign bit of "j" in the second-lowest bit of "t" for optimization
341 // purposes. This approach was influenced by the paper by M. Hamburg
342 let (mut n, mut d, mut t) = (LInt::<L>::new(n), LInt::<L>::new(d), 0u64);
343 debug_assert!(d.0[0] & 1 > 0, "The second argument must be odd!");
344 debug_assert!(
345 n.0[L - 1].leading_zeros().min(d.0[L - 1].leading_zeros()) >= 31,
346 "Both the arguments must be less than 2 ^ (64 * L - 31)!"
347 );
348 loop {
349 // The inner loop performs 30 iterations instead of 31 ones in the
350 // aforementioned Pornin's method, and the "approximations" of "n" and
351 // "d" retain 32 of the lowest bits instead of 31 in that method. These
352 // modifications allow the values of the "approximation" variables to be
353 // equal modulo 8 to the corresponding "precise" variables' values,
354 // which would have been computed, if the "precise" variables
355 // had been updated in the inner loop along with the "approximations". This
356 // equality modulo 8 is used to update the second-lowest bit of "t" in
357 // accordance with the properties of the modified Jacobi symbol (x /
358 // |y|). The admissibility of these modifications has been proven using
359 // the appropriately modified Pornin's theorems
360 let (mut u, mut v, mut i) = ((1i64, 0i64), (0i64, 1i64), 30);
361 let (mut a, mut b, precise) = approximate(&n, &d);
362 // When each "approximation" variable has the same value as the corresponding
363 // "precise" one, the computation is accomplished using the
364 // short-arithmetic method of the Jacobi symbol calculation by means of
365 // the binary Euclidean algorithm. This approach aims at avoiding the
366 // parts of the final computations, which are related to long arithmetic
367 if precise {
368 return jacobinary(a, b, t);
369 }
370 while i > 0 {
371 if a & 1 > 0 {
372 if a < b {
373 (a, b, u, v) = (b, a, v, u);
374 // In both the aforesaid Pornin's method and its modification "n" and "d"
375 // could not become negative simultaneously even if they were updated after
376 // each iteration of the inner loop. Also at this point they both have odd
377 // values. Therefore, the quadratic reciprocity law for the modified Jacobi
378 // symbol (x / |y|) can be used. According to it, if both x and y are odd
379 // numbers, among which there is a positive one, then for x = y = 3 (mod 4)
380 // we have (x / |y|) = -(y / |x|) and for either x or y equal 1 modulo 4
381 // the symbols (x / |y|) and (y / |x|) are equal
382 t ^= a & b;
383 }
384 a = (a - b) >> 1;
385 u = (u.0 - v.0, u.1 - v.1);
386 v = (v.0 << 1, v.1 << 1);
387 // The modified Jacobi symbol (2 / |y|) is -1, iff y mod 8 is {3, 5}
388 t ^= b ^ b >> 1;
389 i -= 1;
390 } else {
391 // Performing the batch of sequential iterations, which divide "a" by 2
392 let z = i.min(a.trailing_zeros());
393 // The modified Jacobi symbol (2 / |y|) is -1, iff y mod 8 is {3, 5}. However,
394 // we do not need its value for a batch with an even number of divisions by 2
395 t ^= (b ^ b >> 1) & (z << 1) as u64;
396 v = (v.0 << z, v.1 << z);
397 a >>= z;
398 i -= z;
399 }
400 }
401 (n, d) = ((&n * u.0 + &d * u.1) >> 30, (&n * v.0 + &d * v.1) >> 30);
402
403 // This fragment is present to guarantee the correct behavior of the function
404 // in the case of arguments, whose greatest common divisor is no less than 2^64
405 if n == LInt::ZERO {
406 // In both the aforesaid Pornin's method and its modification the pair of the
407 // values of "n" and "d" after the divergence point contains a
408 // positive number and a negative one. Since the value of "n" is 0,
409 // the divergence point has not been reached by the inner loop this
410 // time, so there is no need to check whether "d" is equal to -1
411 return (d == LInt::ONE) as i64 * (1 - (t & 2) as i64);
412 }
413
414 if n.is_negative() {
415 // Since in both the aforesaid Pornin's method and its modification "d" is
416 // always odd and cannot become negative simultaneously with "n",
417 // the value of "d" is positive. The modified Jacobi symbol (-1 /
418 // |y|) for a positive y is -1, iff y mod 4 = 3
419 t ^= d.0[0];
420 n = -n;
421 } else if d.is_negative() {
422 // The modified Jacobi symbols (x / |y|) and (x / |-y|) are equal, so "t" is not
423 // updated
424 d = -d;
425 }
426 }
427}