p3_monty_31/mds.rs
1use core::marker::PhantomData;
2
3use p3_mds::karatsuba_convolution::Convolve;
4use p3_mds::util::dot_product;
5use p3_mds::MdsPermutation;
6use p3_symmetric::Permutation;
7
8use crate::{BarrettParameters, MontyField31, MontyParameters};
9
/// Source of the circulant MDS matrices used by [`MdsMatrixMontyField31`].
///
/// A circulant matrix is fully determined by a single column, so each
/// constant stores only the leftmost column of the matrix of that size.
///
/// The 8, 12 and 16 columns are consumed by the "small" convolution and must
/// therefore keep the sum of their entries small (below roughly `2^24`). The
/// 24, 32 and 64 columns go through the "large" convolution and may hold
/// arbitrary field-sized entries.
pub trait MDSUtils: Clone + Sync {
    /// Leftmost column of the 8x8 circulant MDS matrix.
    const MATRIX_CIRC_MDS_8_COL: [i64; 8];
    /// Leftmost column of the 12x12 circulant MDS matrix.
    const MATRIX_CIRC_MDS_12_COL: [i64; 12];
    /// Leftmost column of the 16x16 circulant MDS matrix.
    const MATRIX_CIRC_MDS_16_COL: [i64; 16];
    /// Leftmost column of the 24x24 circulant MDS matrix.
    const MATRIX_CIRC_MDS_24_COL: [i64; 24];
    /// Leftmost column of the 32x32 circulant MDS matrix.
    const MATRIX_CIRC_MDS_32_COL: [i64; 32];
    /// Leftmost column of the 64x64 circulant MDS matrix.
    const MATRIX_CIRC_MDS_64_COL: [i64; 64];
}
19
20#[derive(Clone, Debug, Default)]
21pub struct MdsMatrixMontyField31<MU: MDSUtils> {
22 _phantom: PhantomData<MU>,
23}
24
/// Convolution over a 31-bit Monty field (`MontyField31`) for "small" RHS vectors.
///
/// Here "small" means N = len(rhs) <= 16 and sum(r for r in rhs) <
/// 2^24 (roughly), though in practice the sum will be less than 2^9.
/// Under these bounds every dot product fits in an `i64`, so no widening or
/// Barrett reduction is required.
// Derives match `LargeConvolveMontyField31` so both convolution helpers offer
// the same basic traits.
#[derive(Debug, Clone, Default)]
struct SmallConvolveMontyField31;
30
31impl<FP: MontyParameters> Convolve<MontyField31<FP>, i64, i64, i64> for SmallConvolveMontyField31 {
32 /// Return the lift of a Monty31 element, satisfying 0 <=
33 /// input.value < P < 2^31. Note that Monty31 elements are
34 /// represented in Monty form.
35 #[inline(always)]
36 fn read(input: MontyField31<FP>) -> i64 {
37 input.value as i64
38 }
39
40 /// For a convolution of size N, |x| < N * 2^31 and (as per the
41 /// assumption above), |y| < 2^24. So the product is at most N * 2^55
42 /// which will not overflow for N <= 16.
43 ///
44 /// Note that the LHS element is in Monty form, while the RHS
45 /// element is a "plain integer". This informs the implementation
46 /// of `reduce()` below.
47 #[inline(always)]
48 fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
49 dot_product(u, v)
50 }
51
52 /// The assumptions above mean z < N^2 * 2^55, which is at most
53 /// 2^63 when N <= 16.
54 ///
55 /// Because the LHS elements were in Monty form and the RHS
56 /// elements were plain integers, reduction is simply the usual
57 /// reduction modulo P, rather than "Monty reduction".
58 ///
59 /// NB: Even though intermediate values could be negative, the
60 /// output must be non-negative since the inputs were
61 /// non-negative.
62 #[inline(always)]
63 fn reduce(z: i64) -> MontyField31<FP> {
64 debug_assert!(z >= 0);
65
66 MontyField31::new_monty((z as u64 % FP::PRIME as u64) as u32)
67 }
68}
69
70/// Given |x| < 2^80 compute x' such that:
71/// |x'| < 2**50
72/// x' = x mod p
73/// x' = x mod 2^10
74/// See Thm 1 (Below function) for a proof that this function is correct.
75#[inline(always)]
76fn barrett_red_monty31<BP: BarrettParameters>(input: i128) -> i64 {
77 // input = input_low + beta*input_high
78 // So input_high < 2**63 and fits in an i64.
79 let input_high = (input >> BP::N) as i64; // input_high < input / beta < 2**{80 - N}
80
81 // I, input_high are i64's so this multiplication can't overflow.
82 let quot = (((input_high as i128) * (BP::PSEUDO_INV as i128)) >> BP::N) as i64;
83
84 // Replace quot by a close value which is divisible by 2^10.
85 let quot_2adic = quot & BP::MASK;
86
87 // quot_2adic, P are i64's so this can't overflow.
88 // sub is by construction divisible by both P and 2^10.
89 let sub = (quot_2adic as i128) * BP::PRIME_I128;
90
91 (input - sub) as i64
92}
93
94// Theorem 1:
// Given |x| < 2^80, barrett_red_monty31(x) computes an x' such that:
96// x' = x mod p
97// x' = x mod 2^10
98// |x'| < 2**50.
99///////////////////////////////////////////////////////////////////////////////////////
100// PROOF:
101// By construction P, 2**10 | sub and so we immediately see that
102// x' = x mod p
103// x' = x mod 2^10.
104//
105// It remains to prove that |x'| < 2**50.
106//
107// We start by introducing some simple inequalities and relations between our variables:
108//
109// First consider the relationship between bit-shift and division.
110// It's easy to check that for all x:
111// 1: (x >> N) <= x / 2**N <= 1 + (x >> N)
112//
// Similarly, as our mask just zeroes the last 10 bits (rounding towards -infinity),
114// 2: x + 1 - 2^10 <= x & mask <= x
115//
116// Now if x, y are positive integers then
117// (x / y) - 1 <= x // y <= x / y
118// Where // denotes integer division.
119//
120// From this last inequality we immediately derive:
121// 3: (2**{2N} / P) - 1 <= I <= (2**{2N} / P)
122// 3a: 2**{2N} - P <= PI
123//
124// Finally, note that by definition:
125// input = input_high*(2**N) + input_low
126// Hence a simple rearrangement gets us
127// 4: input_high*(2**N) = input - input_low
128//
129//
130// We now need to split into cases depending on the sign of input.
131// Note that if x = 0 then x' = 0 so that case is trivial.
132///////////////////////////////////////////////////////////////////////////
133// CASE 1: input > 0
134//
135// If input > 0 then:
136// sub = Q*P = ((((input >> N) * I) >> N) & mask) * P <= P * (input / 2**{N}) * (2**{2N} / P) / 2**{N} = input
137// So input - sub >= 0.
138//
139// We need to improve our bound on Q. Observe that:
140// Q = (((input_high * I) >> N) & mask)
// --(2) => Q + (2^10 - 1) >= (input_high * I) >> N
// --(1) => Q + 2^10 >= (I*input_high)/(2**N)
// => (2**N)*Q + 2^10*(2**N) >= I*input_high
144//
145// Hence we find that:
146// (2**N)*Q*P + 2^10*(2**N)*P >= input_high*I*P
147// --(3a) >= input_high*2**{2N} - P*input_high
148// --(4) >= (2**N)*input - (2**N)*input_low - (2**N)*input_high (Assuming P < 2**N)
149//
150// Dividing by 2**N we get
151// Q*P + 2^{10}*P >= input - input_low - input_high
152// which rearranges to
153// x' = input - Q*P <= 2^{10}*P + input_low + input_high
154//
// Picking N = 40 we see that input_low, input_high < 2**40 and 2^{10}*P < 2**41 (as P < 2**31).
// Hence x' < 2**41 + 2**40 + 2**40 = 2**42 < 2**50 as desired.
157//
158//
159//
160///////////////////////////////////////////////////////////////////////////
161// CASE 2: input < 0
162//
163// This case will be similar but all our inequalities will change slightly as negatives complicate things.
// First observe that input_high = input >> N < 0 and, by (3), I <= 2**(2N) / P. Hence
// (input >> N) * I >= (input >> N) * 2**(2N) / P
//                  >= ((input / 2**N) - 1) * 2**(2N) / P        --(1)
//                  = (input - 2**N) * 2**N / P
//
// Thus, writing quot = ((input >> N) * I) >> N for the value before masking:
// quot >= ((input >> N) * I) / 2**N - 1 >= ((input - 2**N) / P) - 1       --(1)
// and so
// Q = quot & mask >= quot + 1 - 2^10 >= ((input - 2**N) / P) - 2^10        --(2)
//
// Hence sub = Q*P >= input - 2**N - 2^10*P, i.e. input - sub <= 2**N + 2^10*P < 2**42.
//
// Thus if input - sub >= 0 then |input - sub| < 2**50.
// We are left with bounding -(input - sub) = (sub - input) when input - sub < 0.
178// Again we will proceed by improving our bound on Q.
179//
180// Q = (((input_high * I) >> N) & mask)
// --(2) => Q <= (input_high * I) >> N
// --(1) => Q <= (I*input_high)/(2**N)
// => (2**N)*Q <= I*input_high
184//
185// Hence we find that:
186// (2**N)*Q*P <= input_high*I*P
187// --(3a) <= input_high*2**{2N} - P*input_high
188// --(4) <= (2**N)*input - (2**N)*input_low - (2**N)*input_high (Assuming P < 2**N)
189//
190// Dividing by 2**N we get
191// Q*P <= input - input_low - input_high
192// which rearranges to
193// -x' = -input + Q*P <= -input_high - input_low < 2**50
194//
195// This completes the proof.
196
/// Convolution over a 31-bit Monty field (`MontyField31`) for "large" RHS vectors.
///
/// Here "large" means the elements can be as big as the field
/// characteristic, and the size N of the RHS is <= 64. Dot products are
/// accumulated in `i128` and then partially reduced with
/// `barrett_red_monty31`.
#[derive(Debug, Clone, Default)]
struct LargeConvolveMontyField31;
203
204impl<FP> Convolve<MontyField31<FP>, i64, i64, i64> for LargeConvolveMontyField31
205where
206 FP: BarrettParameters,
207{
208 /// Return the lift of a MontyField31 element, satisfying
209 /// 0 <= input.value < P < 2^31.
210 /// Note that MontyField31 elements are represented in Monty form.
211 #[inline(always)]
212 fn read(input: MontyField31<FP>) -> i64 {
213 input.value as i64
214 }
215
216 #[inline(always)]
217 fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
218 // For a convolution of size N, |x|, |y| < N * 2^31, so the
219 // product could be as much as N^2 * 2^62. This will overflow an
220 // i64, so we first widen to i128. Note that N^2 * 2^62 < 2^80
221 // for N <= 64, as required by `barrett_red_monty31()`.
222
223 let mut dp = 0i128;
224 for i in 0..N {
225 dp += u[i] as i128 * v[i] as i128;
226 }
227 barrett_red_monty31::<FP>(dp)
228 }
229
230 #[inline(always)]
231 fn reduce(z: i64) -> MontyField31<FP> {
232 // After the barrett reduction method, the output z of parity
233 // dot satisfies |z| < 2^50 (See Thm 1 above).
234 //
235 // In the recombining steps, conv_n maps (wo, w1) ->
236 // ((wo + w1)/2, (wo + w1)/2) which has no effect on the maximal
237 // size. (Indeed, it makes sizes almost strictly smaller).
238 //
239 // On the other hand, negacyclic_conv_n (ignoring the re-index)
240 // recombines as: (w0, w1, w2) -> (w0 + w1, w2 - w0 - w1).
241 // Hence if the input is <= K, the output is <= 3K.
242 //
243 // Thus the values appearing at the end are bounded by 3^n 2^50
244 // where n is the maximal number of negacyclic_conv
245 // recombination steps. When N = 64, we need to recombine for
246 // signed_conv_32, signed_conv_16, signed_conv_8 so the
247 // overall bound will be 3^3 2^50 < 32 * 2^50 < 2^55.
248 debug_assert!(z > -(1i64 << 55));
249 debug_assert!(z < (1i64 << 55));
250
251 // Note we do NOT move it into MONTY form. We assume it is already
252 // in this form.
253 let red = (z % (FP::PRIME as i64)) as u32;
254
255 // If z >= 0: 0 <= red < P is the correct value and P + red will
256 // not overflow.
257 // If z < 0: -P < red < 0 and the value we want is P + red.
258 // On bits, + acts identically for i32 and u32. Hence we can use
259 // u32's and just check for overflow.
260
261 let (corr, over) = red.overflowing_add(FP::PRIME);
262 let value = if over { corr } else { red };
263 MontyField31::new_monty(value)
264 }
265}
266
267impl<FP: MontyParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 8]>
268 for MdsMatrixMontyField31<MU>
269{
270 fn permute(&self, input: [MontyField31<FP>; 8]) -> [MontyField31<FP>; 8] {
271 SmallConvolveMontyField31::apply(
272 input,
273 MU::MATRIX_CIRC_MDS_8_COL,
274 <SmallConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64, i64>>::conv8,
275 )
276 }
277
278 fn permute_mut(&self, input: &mut [MontyField31<FP>; 8]) {
279 *input = self.permute(*input);
280 }
281}
282impl<FP: MontyParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 8>
283 for MdsMatrixMontyField31<MU>
284{
285}
286
287impl<FP: MontyParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 12]>
288 for MdsMatrixMontyField31<MU>
289{
290 fn permute(&self, input: [MontyField31<FP>; 12]) -> [MontyField31<FP>; 12] {
291 SmallConvolveMontyField31::apply(
292 input,
293 MU::MATRIX_CIRC_MDS_12_COL,
294 <SmallConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64, i64>>::conv12,
295 )
296 }
297
298 fn permute_mut(&self, input: &mut [MontyField31<FP>; 12]) {
299 *input = self.permute(*input);
300 }
301}
302impl<FP: MontyParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 12>
303 for MdsMatrixMontyField31<MU>
304{
305}
306
307impl<FP: MontyParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 16]>
308 for MdsMatrixMontyField31<MU>
309{
310 fn permute(&self, input: [MontyField31<FP>; 16]) -> [MontyField31<FP>; 16] {
311 SmallConvolveMontyField31::apply(
312 input,
313 MU::MATRIX_CIRC_MDS_16_COL,
314 <SmallConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64, i64>>::conv16,
315 )
316 }
317
318 fn permute_mut(&self, input: &mut [MontyField31<FP>; 16]) {
319 *input = self.permute(*input);
320 }
321}
322impl<FP: MontyParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 16>
323 for MdsMatrixMontyField31<MU>
324{
325}
326
327impl<FP, MU: MDSUtils> Permutation<[MontyField31<FP>; 24]> for MdsMatrixMontyField31<MU>
328where
329 FP: BarrettParameters,
330{
331 fn permute(&self, input: [MontyField31<FP>; 24]) -> [MontyField31<FP>; 24] {
332 LargeConvolveMontyField31::apply(
333 input,
334 MU::MATRIX_CIRC_MDS_24_COL,
335 <LargeConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64, i64>>::conv24,
336 )
337 }
338
339 fn permute_mut(&self, input: &mut [MontyField31<FP>; 24]) {
340 *input = self.permute(*input);
341 }
342}
343impl<FP: BarrettParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 24>
344 for MdsMatrixMontyField31<MU>
345{
346}
347
348impl<FP: BarrettParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 32]>
349 for MdsMatrixMontyField31<MU>
350{
351 fn permute(&self, input: [MontyField31<FP>; 32]) -> [MontyField31<FP>; 32] {
352 LargeConvolveMontyField31::apply(
353 input,
354 MU::MATRIX_CIRC_MDS_32_COL,
355 <LargeConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64, i64>>::conv32,
356 )
357 }
358
359 fn permute_mut(&self, input: &mut [MontyField31<FP>; 32]) {
360 *input = self.permute(*input);
361 }
362}
363impl<FP: BarrettParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 32>
364 for MdsMatrixMontyField31<MU>
365{
366}
367
368impl<FP: BarrettParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 64]>
369 for MdsMatrixMontyField31<MU>
370{
371 fn permute(&self, input: [MontyField31<FP>; 64]) -> [MontyField31<FP>; 64] {
372 LargeConvolveMontyField31::apply(
373 input,
374 MU::MATRIX_CIRC_MDS_64_COL,
375 <LargeConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64, i64>>::conv64,
376 )
377 }
378
379 fn permute_mut(&self, input: &mut [MontyField31<FP>; 64]) {
380 *input = self.permute(*input);
381 }
382}
383impl<FP: BarrettParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 64>
384 for MdsMatrixMontyField31<MU>
385{
386}