1use core::cmp::PartialEq;
2use std::ops::{Add, Mul, Neg, Sub};
3
/// Signed `B * L`-bit integer in two's complement, stored as `L` little-endian chunks of
/// `B` bits each (every chunk is kept below `2^B`, the top chunk carries the sign bit).
///
/// All arithmetic on this type wraps modulo `2^(B * L)`: the carry out of the top chunk is
/// discarded. The chunk arithmetic needs at least one spare bit per `u64`, so `B` must lie
/// in `1..=63` (only `B = 62` is used in this file).
#[derive(Clone, Debug)]
struct CInt<const B: usize, const L: usize>(pub [u64; L]);
10
11impl<const B: usize, const L: usize> CInt<B, L> {
12 pub const MASK: u64 = u64::MAX >> (64 - B);
14
15 pub const MINUS_ONE: Self = Self([Self::MASK; L]);
17
18 pub const ZERO: Self = Self([0; L]);
20
21 pub const ONE: Self = {
23 let mut data = [0; L];
24 data[0] = 1;
25 Self(data)
26 };
27
28 pub fn shift(&self) -> Self {
31 let mut data = [0; L];
32 if self.is_negative() {
33 data[L - 1] = Self::MASK;
34 }
35 data[..L - 1].copy_from_slice(&self.0[1..]);
36 Self(data)
37 }
38
39 pub fn lowest(&self) -> u64 {
41 self.0[0]
42 }
43
44 pub fn is_negative(&self) -> bool {
46 self.0[L - 1] > (Self::MASK >> 1)
47 }
48}
49
50impl<const B: usize, const L: usize> PartialEq for CInt<B, L> {
51 fn eq(&self, other: &Self) -> bool {
52 self.0 == other.0
53 }
54}
55
56impl<const B: usize, const L: usize> Add for &CInt<B, L> {
57 type Output = CInt<B, L>;
58 fn add(self, other: Self) -> Self::Output {
59 let (mut data, mut carry) = ([0; L], 0);
60 for (i, d) in data.iter_mut().enumerate().take(L) {
61 let sum = self.0[i] + other.0[i] + carry;
62 *d = sum & CInt::<B, L>::MASK;
63 carry = sum >> B;
64 }
65 CInt::<B, L>(data)
66 }
67}
68
69impl<const B: usize, const L: usize> Add<&CInt<B, L>> for CInt<B, L> {
70 type Output = CInt<B, L>;
71 fn add(self, other: &Self) -> Self::Output {
72 &self + other
73 }
74}
75
76impl<const B: usize, const L: usize> Add for CInt<B, L> {
77 type Output = CInt<B, L>;
78 fn add(self, other: Self) -> Self::Output {
79 &self + &other
80 }
81}
82
83impl<const B: usize, const L: usize> Sub for &CInt<B, L> {
84 type Output = CInt<B, L>;
85 fn sub(self, other: Self) -> Self::Output {
86 let (mut data, mut carry) = ([0; L], 1);
94 for (i, d) in data.iter_mut().enumerate().take(L) {
95 let sum = self.0[i] + (other.0[i] ^ CInt::<B, L>::MASK) + carry;
96 *d = sum & CInt::<B, L>::MASK;
97 carry = sum >> B;
98 }
99 CInt::<B, L>(data)
100 }
101}
102
103impl<const B: usize, const L: usize> Sub<&CInt<B, L>> for CInt<B, L> {
104 type Output = CInt<B, L>;
105 fn sub(self, other: &Self) -> Self::Output {
106 &self - other
107 }
108}
109
110impl<const B: usize, const L: usize> Sub for CInt<B, L> {
111 type Output = CInt<B, L>;
112 fn sub(self, other: Self) -> Self::Output {
113 &self - &other
114 }
115}
116
117impl<const B: usize, const L: usize> Neg for &CInt<B, L> {
118 type Output = CInt<B, L>;
119 fn neg(self) -> Self::Output {
120 let (mut data, mut carry) = ([0; L], 1);
123 for (i, d) in data.iter_mut().enumerate().take(L) {
124 let sum = (self.0[i] ^ CInt::<B, L>::MASK) + carry;
125 *d = sum & CInt::<B, L>::MASK;
126 carry = sum >> B;
127 }
128 CInt::<B, L>(data)
129 }
130}
131
132impl<const B: usize, const L: usize> Neg for CInt<B, L> {
133 type Output = CInt<B, L>;
134 fn neg(self) -> Self::Output {
135 -&self
136 }
137}
138
139impl<const B: usize, const L: usize> Mul for &CInt<B, L> {
140 type Output = CInt<B, L>;
141 fn mul(self, other: Self) -> Self::Output {
142 let mut data = [0; L];
143 for i in 0..L {
144 let mut carry = 0;
145 for k in 0..(L - i) {
146 let sum = (data[i + k] as u128)
147 + (carry as u128)
148 + (self.0[i] as u128) * (other.0[k] as u128);
149 data[i + k] = sum as u64 & CInt::<B, L>::MASK;
150 carry = (sum >> B) as u64;
151 }
152 }
153 CInt::<B, L>(data)
154 }
155}
156
157impl<const B: usize, const L: usize> Mul<&CInt<B, L>> for CInt<B, L> {
158 type Output = CInt<B, L>;
159 fn mul(self, other: &Self) -> Self::Output {
160 &self * other
161 }
162}
163
164impl<const B: usize, const L: usize> Mul for CInt<B, L> {
165 type Output = CInt<B, L>;
166 fn mul(self, other: Self) -> Self::Output {
167 &self * &other
168 }
169}
170
171impl<const B: usize, const L: usize> Mul<i64> for &CInt<B, L> {
172 type Output = CInt<B, L>;
173 fn mul(self, other: i64) -> Self::Output {
174 let mut data = [0; L];
175 let (other, mut carry, mask) = if other < 0 {
188 (-other, -other as u64, CInt::<B, L>::MASK)
189 } else {
190 (other, 0, 0)
191 };
192 for (i, d) in data.iter_mut().enumerate().take(L) {
193 let sum = (carry as u128) + ((self.0[i] ^ mask) as u128) * (other as u128);
194 *d = sum as u64 & CInt::<B, L>::MASK;
195 carry = (sum >> B) as u64;
196 }
197 CInt::<B, L>(data)
198 }
199}
200
201impl<const B: usize, const L: usize> Mul<i64> for CInt<B, L> {
202 type Output = CInt<B, L>;
203 fn mul(self, other: i64) -> Self::Output {
204 &self * other
205 }
206}
207
208impl<const B: usize, const L: usize> Mul<&CInt<B, L>> for i64 {
209 type Output = CInt<B, L>;
210 fn mul(self, other: &CInt<B, L>) -> Self::Output {
211 other * self
212 }
213}
214
215impl<const B: usize, const L: usize> Mul<CInt<B, L>> for i64 {
216 type Output = CInt<B, L>;
217 fn mul(self, other: CInt<B, L>) -> Self::Output {
218 other * self
219 }
220}
221
222pub struct BYInverter<const L: usize> {
246 modulus: CInt<62, L>,
248
249 adjuster: CInt<62, L>,
251
252 inverse: i64,
254}
255
256type Matrix = [[i64; 2]; 2];
258
259impl<const L: usize> BYInverter<L> {
260 fn jump(f: &CInt<62, L>, g: &CInt<62, L>, mut delta: i64) -> (i64, Matrix) {
264 let (mut steps, mut f, mut g) = (62, f.lowest() as i64, g.lowest() as i128);
265 let mut t: Matrix = [[1, 0], [0, 1]];
266
267 loop {
268 let zeros = steps.min(g.trailing_zeros() as i64);
269 (steps, delta, g) = (steps - zeros, delta + zeros, g >> zeros);
270 t[0] = [t[0][0] << zeros, t[0][1] << zeros];
271
272 if steps == 0 {
273 break;
274 }
275 if delta > 0 {
276 (delta, f, g) = (-delta, g as i64, -f as i128);
277 (t[0], t[1]) = (t[1], [-t[0][0], -t[0][1]]);
278 }
279
280 let mask = (1 << steps.min(1 - delta).min(5)) - 1;
284 let w = (g as i64).wrapping_mul(f.wrapping_mul(3) ^ 28) & mask;
285
286 t[1] = [t[0][0] * w + t[1][0], t[0][1] * w + t[1][1]];
287 g += w as i128 * f as i128;
288 }
289
290 (delta, t)
291 }
292
293 fn fg(f: CInt<62, L>, g: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) {
296 (
297 (t[0][0] * &f + t[0][1] * &g).shift(),
298 (t[1][0] * &f + t[1][1] * &g).shift(),
299 )
300 }
301
302 fn de(&self, d: CInt<62, L>, e: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) {
307 let mask = CInt::<62, L>::MASK as i64;
308 let mut md = t[0][0] * d.is_negative() as i64 + t[0][1] * e.is_negative() as i64;
309 let mut me = t[1][0] * d.is_negative() as i64 + t[1][1] * e.is_negative() as i64;
310
311 let cd = t[0][0]
312 .wrapping_mul(d.lowest() as i64)
313 .wrapping_add(t[0][1].wrapping_mul(e.lowest() as i64))
314 & mask;
315 let ce = t[1][0]
316 .wrapping_mul(d.lowest() as i64)
317 .wrapping_add(t[1][1].wrapping_mul(e.lowest() as i64))
318 & mask;
319
320 md -= (self.inverse.wrapping_mul(cd).wrapping_add(md)) & mask;
321 me -= (self.inverse.wrapping_mul(ce).wrapping_add(me)) & mask;
322
323 let cd = t[0][0] * &d + t[0][1] * &e + md * &self.modulus;
324 let ce = t[1][0] * &d + t[1][1] * &e + me * &self.modulus;
325
326 (cd.shift(), ce.shift())
327 }
328
329 fn norm(&self, mut value: CInt<62, L>, negate: bool) -> CInt<62, L> {
333 if value.is_negative() {
334 value = value + &self.modulus;
335 }
336
337 if negate {
338 value = -value;
339 }
340
341 if value.is_negative() {
342 value = value + &self.modulus;
343 }
344
345 value
346 }
347
348 const fn convert<const I: usize, const O: usize, const S: usize>(input: &[u64]) -> [u64; S] {
352 const fn min(a: usize, b: usize) -> usize {
354 if a > b {
355 b
356 } else {
357 a
358 }
359 }
360
361 let (total, mut output, mut bits) = (min(input.len() * I, S * O), [0; S], 0);
362
363 while bits < total {
364 let (i, o) = (bits % I, bits % O);
365 output[bits / O] |= (input[bits / I] >> i) << o;
366 bits += min(I - i, O - o);
367 }
368
369 let mask = u64::MAX >> (64 - O);
370 let mut filled = total / O + if total % O > 0 { 1 } else { 0 };
371
372 while filled > 0 {
373 filled -= 1;
374 output[filled] &= mask;
375 }
376
377 output
378 }
379
380 const fn inv(value: u64) -> i64 {
386 let x = value.wrapping_mul(3) ^ 2;
387 let y = 1u64.wrapping_sub(x.wrapping_mul(value));
388 let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
389 let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
390 let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
391 (x.wrapping_mul(y.wrapping_add(1)) & CInt::<62, L>::MASK) as i64
392 }
393
394 pub const fn new(modulus: &[u64], adjuster: &[u64]) -> Self {
396 Self {
397 modulus: CInt::<62, L>(Self::convert::<64, 62, L>(modulus)),
398 adjuster: CInt::<62, L>(Self::convert::<64, 62, L>(adjuster)),
399 inverse: Self::inv(modulus[0]),
400 }
401 }
402
403 pub fn invert<const S: usize>(&self, value: &[u64]) -> Option<[u64; S]> {
406 let (mut d, mut e) = (CInt::ZERO, self.adjuster.clone());
407 let mut g = CInt::<62, L>(Self::convert::<64, 62, L>(value));
408 let (mut delta, mut f) = (1, self.modulus.clone());
409 let mut matrix;
410 while g != CInt::ZERO {
411 (delta, matrix) = Self::jump(&f, &g, delta);
412 (f, g) = Self::fg(f, g, matrix);
413 (d, e) = self.de(d, e, matrix);
414 }
415 let antiunit = f == CInt::MINUS_ONE;
419 if (f != CInt::ONE) && !antiunit {
420 return None;
421 }
422 Some(Self::convert::<62, 64, S>(&self.norm(d, antiunit).0))
423 }
424}