1use core::cmp::PartialEq;
2use std::ops::{Add, Mul, Neg, Sub};
3
/// Signed `(B * L)`-bit integer stored in two's complement as `L` limbs of
/// `B` bits each, least significant limb first.
///
/// The arithmetic operators mask every output limb to `B` bits and wrap
/// modulo `2^(B * L)`. `Debug` is derived so values can be printed and
/// compared with `assert_eq!` in tests and diagnostics.
#[derive(Clone, Debug)]
struct CInt<const B: usize, const L: usize>(pub [u64; L]);
10
11impl<const B: usize, const L: usize> CInt<B, L> {
12 pub const MASK: u64 = u64::MAX >> (64 - B);
14
15 pub const MINUS_ONE: Self = Self([Self::MASK; L]);
17
18 pub const ZERO: Self = Self([0; L]);
20
21 pub const ONE: Self = {
23 let mut data = [0; L];
24 data[0] = 1;
25 Self(data)
26 };
27
28 pub fn shift(&self) -> Self {
31 let mut data = [0; L];
32 if self.is_negative() {
33 data[L - 1] = Self::MASK;
34 }
35 data[..L - 1].copy_from_slice(&self.0[1..]);
36 Self(data)
37 }
38
39 pub fn lowest(&self) -> u64 {
41 self.0[0]
42 }
43
44 pub fn is_negative(&self) -> bool {
46 self.0[L - 1] > (Self::MASK >> 1)
47 }
48}
49
50impl<const B: usize, const L: usize> PartialEq for CInt<B, L> {
51 fn eq(&self, other: &Self) -> bool {
52 self.0 == other.0
53 }
54}
55
56impl<const B: usize, const L: usize> Add for &CInt<B, L> {
57 type Output = CInt<B, L>;
58 fn add(self, other: Self) -> Self::Output {
59 let (mut data, mut carry) = ([0; L], 0);
60 for (i, d) in data.iter_mut().enumerate().take(L) {
61 let sum = self.0[i] + other.0[i] + carry;
62 *d = sum & CInt::<B, L>::MASK;
63 carry = sum >> B;
64 }
65 CInt::<B, L>(data)
66 }
67}
68
69impl<const B: usize, const L: usize> Add<&CInt<B, L>> for CInt<B, L> {
70 type Output = CInt<B, L>;
71 fn add(self, other: &Self) -> Self::Output {
72 &self + other
73 }
74}
75
76impl<const B: usize, const L: usize> Add for CInt<B, L> {
77 type Output = CInt<B, L>;
78 fn add(self, other: Self) -> Self::Output {
79 &self + &other
80 }
81}
82
83impl<const B: usize, const L: usize> Sub for &CInt<B, L> {
84 type Output = CInt<B, L>;
85 fn sub(self, other: Self) -> Self::Output {
86 let (mut data, mut carry) = ([0; L], 1);
94 for (i, d) in data.iter_mut().enumerate().take(L) {
95 let sum = self.0[i] + (other.0[i] ^ CInt::<B, L>::MASK) + carry;
96 *d = sum & CInt::<B, L>::MASK;
97 carry = sum >> B;
98 }
99 CInt::<B, L>(data)
100 }
101}
102
103impl<const B: usize, const L: usize> Sub<&CInt<B, L>> for CInt<B, L> {
104 type Output = CInt<B, L>;
105 fn sub(self, other: &Self) -> Self::Output {
106 &self - other
107 }
108}
109
110impl<const B: usize, const L: usize> Sub for CInt<B, L> {
111 type Output = CInt<B, L>;
112 fn sub(self, other: Self) -> Self::Output {
113 &self - &other
114 }
115}
116
117impl<const B: usize, const L: usize> Neg for &CInt<B, L> {
118 type Output = CInt<B, L>;
119 fn neg(self) -> Self::Output {
120 let (mut data, mut carry) = ([0; L], 1);
123 for (i, d) in data.iter_mut().enumerate().take(L) {
124 let sum = (self.0[i] ^ CInt::<B, L>::MASK) + carry;
125 *d = sum & CInt::<B, L>::MASK;
126 carry = sum >> B;
127 }
128 CInt::<B, L>(data)
129 }
130}
131
132impl<const B: usize, const L: usize> Neg for CInt<B, L> {
133 type Output = CInt<B, L>;
134 fn neg(self) -> Self::Output {
135 -&self
136 }
137}
138
139impl<const B: usize, const L: usize> Mul for &CInt<B, L> {
140 type Output = CInt<B, L>;
141 fn mul(self, other: Self) -> Self::Output {
142 let mut data = [0; L];
143 for i in 0..L {
144 let mut carry = 0;
145 for k in 0..(L - i) {
146 let sum = (data[i + k] as u128)
147 + (carry as u128)
148 + (self.0[i] as u128) * (other.0[k] as u128);
149 data[i + k] = sum as u64 & CInt::<B, L>::MASK;
150 carry = (sum >> B) as u64;
151 }
152 }
153 CInt::<B, L>(data)
154 }
155}
156
157impl<const B: usize, const L: usize> Mul<&CInt<B, L>> for CInt<B, L> {
158 type Output = CInt<B, L>;
159 fn mul(self, other: &Self) -> Self::Output {
160 &self * other
161 }
162}
163
164impl<const B: usize, const L: usize> Mul for CInt<B, L> {
165 type Output = CInt<B, L>;
166 fn mul(self, other: Self) -> Self::Output {
167 &self * &other
168 }
169}
170
171impl<const B: usize, const L: usize> Mul<i64> for &CInt<B, L> {
172 type Output = CInt<B, L>;
173 fn mul(self, other: i64) -> Self::Output {
174 let mut data = [0; L];
175 let (other, mut carry, mask) = if other < 0 {
188 (-other, -other as u64, CInt::<B, L>::MASK)
189 } else {
190 (other, 0, 0)
191 };
192 for (i, d) in data.iter_mut().enumerate().take(L) {
193 let sum = (carry as u128) + ((self.0[i] ^ mask) as u128) * (other as u128);
194 *d = sum as u64 & CInt::<B, L>::MASK;
195 carry = (sum >> B) as u64;
196 }
197 CInt::<B, L>(data)
198 }
199}
200
201impl<const B: usize, const L: usize> Mul<i64> for CInt<B, L> {
202 type Output = CInt<B, L>;
203 fn mul(self, other: i64) -> Self::Output {
204 &self * other
205 }
206}
207
208impl<const B: usize, const L: usize> Mul<&CInt<B, L>> for i64 {
209 type Output = CInt<B, L>;
210 fn mul(self, other: &CInt<B, L>) -> Self::Output {
211 other * self
212 }
213}
214
215impl<const B: usize, const L: usize> Mul<CInt<B, L>> for i64 {
216 type Output = CInt<B, L>;
217 fn mul(self, other: CInt<B, L>) -> Self::Output {
218 other * self
219 }
220}
221
222pub struct BYInverter<const L: usize> {
251 modulus: CInt<62, L>,
253
254 adjuster: CInt<62, L>,
256
257 inverse: i64,
259}
260
/// 2x2 transition matrix with signed 64-bit entries, indexed `t[row][column]`.
/// Row 0 produces the new `f` (and `d`); row 1 produces the new `g` (and `e`).
type Matrix = [[i64; 2]; 2];
263
impl<const L: usize> BYInverter<L> {
    /// Runs 62 divsteps on the lowest 62-bit limbs of `f` and `g`.
    ///
    /// Returns the updated `delta` and the transition matrix `t` accumulated
    /// over those steps. Every halving of `g` doubles row 0 of `t`, so after
    /// all 62 steps `t * (f, g)` is divisible by `2^62`. `fg` and `de` rely on
    /// this when they drop the lowest limb with `shift`.
    ///
    /// NOTE(review): the `w` computation below is only valid for odd `f`, so
    /// this presumes the modulus is odd — confirm callers guarantee it.
    fn jump(f: &CInt<62, L>, g: &CInt<62, L>, mut delta: i64) -> (i64, Matrix) {
        // Only the lowest limbs of `f` and `g` are used. `g` is presumably
        // widened to i128 so that `g += w * f` below cannot overflow.
        let (mut steps, mut f, mut g) = (62, f.lowest() as i64, g.lowest() as i128);
        let mut t: Matrix = [[1, 0], [0, 1]];

        loop {
            // Strip up to `steps` trailing zero bits of `g` at once. Each
            // halving increments `delta` and doubles row 0 of `t`.
            let zeros = steps.min(g.trailing_zeros() as i64);
            (steps, delta, g) = (steps - zeros, delta + zeros, g >> zeros);
            t[0] = [t[0][0] << zeros, t[0][1] << zeros];

            if steps == 0 {
                break;
            }
            // Positive `delta`: swap the roles of `f` and `g`. The new `g`
            // and the new row 1 are negated, and `delta` flips sign.
            if delta > 0 {
                (delta, f, g) = (-delta, g as i64, -f as i128);
                (t[0], t[1]) = (t[1], [-t[0][0], -t[0][1]]);
            }

            // Clear up to min(steps, 1 - delta, 5) low bits of `g` in one
            // step. For odd `f`, `(3 * f) ^ 28` equals `-f^-1 (mod 32)`, so
            // `w` is `-g / f` modulo the mask width. `g + w * f` therefore
            // ends in that many zero bits.
            let mask = (1 << steps.min(1 - delta).min(5)) - 1;
            let w = (g as i64).wrapping_mul(f.wrapping_mul(3) ^ 28) & mask;

            t[1] = [t[0][0] * w + t[1][0], t[0][1] * w + t[1][1]];
            g += w as i128 * f as i128;
        }

        (delta, t)
    }

    /// Applies the transition matrix `t` to `(f, g)` and drops the lowest
    /// limb, i.e. returns `(t * (f, g)) >> 62`.
    fn fg(f: CInt<62, L>, g: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) {
        (
            (t[0][0] * &f + t[0][1] * &g).shift(),
            (t[1][0] * &f + t[1][1] * &g).shift(),
        )
    }

    /// Applies the transition matrix `t` to `(d, e)` and divides the result
    /// by `2^62` modulo the modulus.
    ///
    /// For each row, a multiple `m * modulus` is added so that the low 62
    /// bits of `t * (d, e) + m * modulus` vanish. The lowest limb is then
    /// dropped with `shift`.
    fn de(&self, d: CInt<62, L>, e: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) {
        let mask = CInt::<62, L>::MASK as i64;
        // Start from the matrix entries that belong to negative inputs.
        // NOTE(review): this appears to mirror libsecp256k1's modinv62
        // `update_de` range-bounding trick — confirm the intended output range.
        let mut md = t[0][0] * d.is_negative() as i64 + t[0][1] * e.is_negative() as i64;
        let mut me = t[1][0] * d.is_negative() as i64 + t[1][1] * e.is_negative() as i64;

        // Low 62 bits of each row of `t * (d, e)`.
        let cd = t[0][0]
            .wrapping_mul(d.lowest() as i64)
            .wrapping_add(t[0][1].wrapping_mul(e.lowest() as i64))
            & mask;
        let ce = t[1][0]
            .wrapping_mul(d.lowest() as i64)
            .wrapping_add(t[1][1].wrapping_mul(e.lowest() as i64))
            & mask;

        // Adjust `md`/`me` so that `cd + md * modulus` and `ce + me * modulus`
        // are multiples of 2^62 (`self.inverse` is `modulus^-1 mod 2^62`).
        md -= (self.inverse.wrapping_mul(cd).wrapping_add(md)) & mask;
        me -= (self.inverse.wrapping_mul(ce).wrapping_add(me)) & mask;

        let cd = t[0][0] * &d + t[0][1] * &e + md * &self.modulus;
        let ce = t[1][0] * &d + t[1][1] * &e + me * &self.modulus;

        (cd.shift(), ce.shift())
    }

    /// Brings `value` into `[0, modulus)`, first negating it (modulo the
    /// modulus) when `negate` is set.
    ///
    /// NOTE(review): at most two additions of the modulus are performed, so
    /// this presumably requires `value` in `(-2 * modulus, modulus)` — confirm.
    fn norm(&self, mut value: CInt<62, L>, negate: bool) -> CInt<62, L> {
        if value.is_negative() {
            value = value + &self.modulus;
        }

        if negate {
            value = -value;
        }

        if value.is_negative() {
            value = value + &self.modulus;
        }

        value
    }

    /// Repacks a sequence of `I`-bit limbs (least significant first) into `S`
    /// limbs of `O` bits each, keeping at most `S * O` bits of the input.
    ///
    /// Every written output limb is masked to `O` bits; limbs beyond the
    /// converted bits stay zero.
    /// NOTE(review): input limbs are assumed to hold at most `I` significant
    /// bits; stray higher bits could be OR-ed into neighbouring output bits.
    const fn convert<const I: usize, const O: usize, const S: usize>(input: &[u64]) -> [u64; S] {
        // `Ord::min` cannot be called in a `const fn`, hence this helper.
        const fn min(a: usize, b: usize) -> usize {
            if a > b {
                b
            } else {
                a
            }
        }

        let (total, mut output, mut bits) = (min(input.len() * I, S * O), [0; S], 0);

        // Each iteration copies the longest run of bits that stays inside
        // both the current input limb and the current output limb.
        while bits < total {
            let (i, o) = (bits % I, bits % O);
            output[bits / O] |= (input[bits / I] >> i) << o;
            bits += min(I - i, O - o);
        }

        // Clear bits shifted above position `O` in every limb that was written.
        let mask = u64::MAX >> (64 - O);
        let mut filled = total / O + if total % O > 0 { 1 } else { 0 };

        while filled > 0 {
            filled -= 1;
            output[filled] &= mask;
        }

        output
    }

    /// Returns the multiplicative inverse of `value` modulo `2^62`.
    ///
    /// The seed `(3 * value) ^ 2` is an inverse modulo `2^5` for odd `value`.
    /// Four Newton steps follow (`x *= 1 + y`, `y = y^2`), each doubling the
    /// number of correct bits: 5 -> 10 -> 20 -> 40 -> 80, which is at least 64.
    /// The result is then masked to 62 bits.
    /// NOTE(review): only meaningful for odd `value`.
    const fn inv(value: u64) -> i64 {
        let x = value.wrapping_mul(3) ^ 2;
        let y = 1u64.wrapping_sub(x.wrapping_mul(value));
        let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
        let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
        let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
        (x.wrapping_mul(y.wrapping_add(1)) & CInt::<62, L>::MASK) as i64
    }

    /// Creates an inverter for `modulus` whose results are multiplied by
    /// `adjuster`. Both are given as 64-bit limbs, least significant first,
    /// and are repacked into 62-bit limbs.
    ///
    /// # Panics
    /// Panics if `modulus` is empty, because its lowest word is read directly.
    pub const fn new(modulus: &[u64], adjuster: &[u64]) -> Self {
        Self {
            modulus: CInt::<62, L>(Self::convert::<64, 62, L>(modulus)),
            adjuster: CInt::<62, L>(Self::convert::<64, 62, L>(adjuster)),
            inverse: Self::inv(modulus[0]),
        }
    }

    /// Returns `adjuster * value^-1 mod modulus` as `S` 64-bit limbs (least
    /// significant first), or `None` when `value` has no inverse.
    ///
    /// `value` is given as 64-bit limbs, least significant first. Bits beyond
    /// `62 * L` are discarded by the conversion.
    pub fn invert<const S: usize>(&self, value: &[u64]) -> Option<[u64; S]> {
        // Invariant kept by every iteration (`fg` and `de` apply the same
        // matrix and the same division by 2^62):
        //   d * value ≡ f * adjuster,  e * value ≡ g * adjuster  (mod modulus)
        // It starts with f = modulus, d = 0, g = value, e = adjuster.
        let (mut d, mut e) = (CInt::ZERO, self.adjuster.clone());
        let mut g = CInt::<62, L>(Self::convert::<64, 62, L>(value));
        let (mut delta, mut f) = (1, self.modulus.clone());
        let mut matrix;
        while g != CInt::ZERO {
            (delta, matrix) = Self::jump(&f, &g, delta);
            (f, g) = Self::fg(f, g, matrix);
            (d, e) = self.de(d, e, matrix);
        }
        // Here |f| is the greatest common divisor of `value` and the modulus,
        // so an inverse exists only when f is 1 or -1. With f = -1 the
        // invariant gives d ≡ -adjuster / value, so `norm` negates `d`.
        let antiunit = f == CInt::MINUS_ONE;
        if (f != CInt::ONE) && !antiunit {
            return None;
        }
        Some(Self::convert::<62, 64, S>(&self.norm(d, antiunit).0))
    }
}