p3_monty_31/
mds.rs

1use core::marker::PhantomData;
2
3use p3_mds::karatsuba_convolution::Convolve;
4use p3_mds::util::dot_product;
5use p3_mds::MdsPermutation;
6use p3_symmetric::Permutation;
7
8use crate::{BarrettParameters, MontyField31, MontyParameters};
9
/// A collection of circulant MDS matrices, each stored as its leftmost (first) column.
///
/// An implementor supplies one column per supported width. Widths 8, 12 and 16 are
/// multiplied via `SmallConvolveMontyField31`, which assumes the entries are non-negative
/// and small (their sum roughly below 2^24). Widths 24, 32 and 64 go through
/// `LargeConvolveMontyField31`, which allows entries as large as the field characteristic.
pub trait MDSUtils: Clone + Sync {
    /// First column of the 8 x 8 circulant MDS matrix.
    const MATRIX_CIRC_MDS_8_COL: [i64; 8];
    /// First column of the 12 x 12 circulant MDS matrix.
    const MATRIX_CIRC_MDS_12_COL: [i64; 12];
    /// First column of the 16 x 16 circulant MDS matrix.
    const MATRIX_CIRC_MDS_16_COL: [i64; 16];
    /// First column of the 24 x 24 circulant MDS matrix.
    const MATRIX_CIRC_MDS_24_COL: [i64; 24];
    /// First column of the 32 x 32 circulant MDS matrix.
    const MATRIX_CIRC_MDS_32_COL: [i64; 32];
    /// First column of the 64 x 64 circulant MDS matrix.
    const MATRIX_CIRC_MDS_64_COL: [i64; 64];
}
19
/// A circulant MDS permutation over `MontyField31`.
///
/// The struct holds no data. `MU` selects, at the type level, which `MDSUtils`
/// column constants the `Permutation` / `MdsPermutation` impls in this module
/// multiply by. Those impls exist for widths 8, 12, 16, 24, 32 and 64.
///
/// Note: `#[derive]` only implements `Debug` / `Default` when `MU` implements them too.
#[derive(Clone, Debug, Default)]
pub struct MdsMatrixMontyField31<MU: MDSUtils> {
    // Zero-sized marker that makes the otherwise-unused type parameter `MU` legal.
    _phantom: PhantomData<MU>,
}
24
/// Instantiate convolution for "small" RHS vectors over a 31-bit `MontyField31`.
///
/// Here "small" means N = len(rhs) <= 16 and sum(r for r in rhs) <
/// 2^24 (roughly), though in practice the sum will be less than 2^9.
///
/// This is a zero-sized marker type: it only selects a `Convolve` implementation.
/// It derives the same traits as `LargeConvolveMontyField31` so the two strategy
/// markers can be used interchangeably.
#[derive(Debug, Clone, Default)]
struct SmallConvolveMontyField31;
30
31impl<FP: MontyParameters> Convolve<MontyField31<FP>, i64, i64, i64> for SmallConvolveMontyField31 {
32    /// Return the lift of a Monty31 element, satisfying 0 <=
33    /// input.value < P < 2^31. Note that Monty31 elements are
34    /// represented in Monty form.
35    #[inline(always)]
36    fn read(input: MontyField31<FP>) -> i64 {
37        input.value as i64
38    }
39
40    /// For a convolution of size N, |x| < N * 2^31 and (as per the
41    /// assumption above), |y| < 2^24. So the product is at most N * 2^55
42    /// which will not overflow for N <= 16.
43    ///
44    /// Note that the LHS element is in Monty form, while the RHS
45    /// element is a "plain integer". This informs the implementation
46    /// of `reduce()` below.
47    #[inline(always)]
48    fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
49        dot_product(u, v)
50    }
51
52    /// The assumptions above mean z < N^2 * 2^55, which is at most
53    /// 2^63 when N <= 16.
54    ///
55    /// Because the LHS elements were in Monty form and the RHS
56    /// elements were plain integers, reduction is simply the usual
57    /// reduction modulo P, rather than "Monty reduction".
58    ///
59    /// NB: Even though intermediate values could be negative, the
60    /// output must be non-negative since the inputs were
61    /// non-negative.
62    #[inline(always)]
63    fn reduce(z: i64) -> MontyField31<FP> {
64        debug_assert!(z >= 0);
65
66        MontyField31::new_monty((z as u64 % FP::PRIME as u64) as u32)
67    }
68}
69
70/// Given |x| < 2^80 compute x' such that:
71/// |x'| < 2**50
72/// x' = x mod p
73/// x' = x mod 2^10
74/// See Thm 1 (Below function) for a proof that this function is correct.
75#[inline(always)]
76fn barrett_red_monty31<BP: BarrettParameters>(input: i128) -> i64 {
77    // input = input_low + beta*input_high
78    // So input_high < 2**63 and fits in an i64.
79    let input_high = (input >> BP::N) as i64; // input_high < input / beta < 2**{80 - N}
80
81    // I, input_high are i64's so this multiplication can't overflow.
82    let quot = (((input_high as i128) * (BP::PSEUDO_INV as i128)) >> BP::N) as i64;
83
84    // Replace quot by a close value which is divisible by 2^10.
85    let quot_2adic = quot & BP::MASK;
86
87    // quot_2adic, P are i64's so this can't overflow.
88    // sub is by construction divisible by both P and 2^10.
89    let sub = (quot_2adic as i128) * BP::PRIME_I128;
90
91    (input - sub) as i64
92}
93
94// Theorem 1:
95// Given |x| < 2^80, barrett_red(x) computes an x' such that:
96//       x' = x mod p
97//       x' = x mod 2^10
98//       |x'| < 2**50.
99///////////////////////////////////////////////////////////////////////////////////////
100// PROOF:
101// By construction P, 2**10 | sub and so we immediately see that
102// x' = x mod p
103// x' = x mod 2^10.
104//
105// It remains to prove that |x'| < 2**50.
106//
107// We start by introducing some simple inequalities and relations between our variables:
108//
109// First consider the relationship between bit-shift and division.
110// It's easy to check that for all x:
111// 1: (x >> N) <= x / 2**N <= 1 + (x >> N)
112//
113// Similarly, as our mask just 0's the last 10 bits,
114// 2: x + 1 - 2^10 <= x & mask <= x
115//
116// Now if x, y are positive integers then
117// (x / y) - 1 <= x // y <= x / y
118// Where // denotes integer division.
119//
120// From this last inequality we immediately derive:
121// 3: (2**{2N} / P) - 1 <= I <= (2**{2N} / P)
122// 3a: 2**{2N} - P <= PI
123//
124// Finally, note that by definition:
125// input = input_high*(2**N) + input_low
126// Hence a simple rearrangement gets us
127// 4: input_high*(2**N) = input - input_low
128//
129//
130// We now need to split into cases depending on the sign of input.
131// Note that if x = 0 then x' = 0 so that case is trivial.
132///////////////////////////////////////////////////////////////////////////
133// CASE 1: input > 0
134//
135// If input > 0 then:
136// sub = Q*P = ((((input >> N) * I) >> N) & mask) * P <= P * (input / 2**{N}) * (2**{2N} / P) / 2**{N} = input
137// So input - sub >= 0.
138//
139// We need to improve our bound on Q. Observe that:
140// Q = (((input_high * I) >> N) & mask)
// --(2)   => Q + (2^10 - 1) >= ((input_high * I) >> N)
// --(1)   => Q + 2^10 >= (I*input_high)/(2**N)
//         => (2**N)*Q + 2^10*(2**N) >= I*input_high
144//
145// Hence we find that:
146// (2**N)*Q*P + 2^10*(2**N)*P >= input_high*I*P
147// --(3a)                     >= input_high*2**{2N} - P*input_high
148// --(4)                      >= (2**N)*input - (2**N)*input_low - (2**N)*input_high   (Assuming P < 2**N)
149//
150// Dividing by 2**N we get
151// Q*P + 2^{10}*P >= input - input_low - input_high
152// which rearranges to
153// x' = input - Q*P <= 2^{10}*P + input_low + input_high
154//
// Picking N = 40 we see that input_low, input_high < 2**40 and 2^{10}*P < 2**41 (as P < 2**31).
// Hence x' < 2**41 + 2**40 + 2**40 = 2**42 < 2**50 as desired.
157//
158//
159//
160///////////////////////////////////////////////////////////////////////////
161// CASE 2: input < 0
162//
163// This case will be similar but all our inequalities will change slightly as negatives complicate things.
164// First observe that:
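// (input >> N) * I   >= (input >> N) * 2**(2N) / P            (as input >> N < 0 and I <= 2**(2N) / P)
//                    >= ((input / 2**N) - 1) * 2**(2N) / P    --(1)
//                    =  (input - 2**N) * 2**N / P
//
// Thus, accounting for both the shift and the mask:
// Q = (((input >> N) * I) >> N) & mask
// --(2)   => Q >= (((input >> N) * I) >> N) + 1 - 2^10
// --(1)   => Q >= ((input >> N) * I) / 2**N - 2^10
//         => Q >= ((input - 2**N) / P) - 2^10
//
// And so sub = Q*P >= input - 2**N - 2^10*P.
// Hence input - sub <= 2**N + 2^10*P < 2**40 + 2**41 < 2**50 (taking N = 40 as in Case 1).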
175//
176// Thus if input - sub > 0 then |input - sub| < 2**50.
177// Thus we are left with bounding -(input - sub) = (sub - input).
178// Again we will proceed by improving our bound on Q.
179//
180// Q = (((input_high * I) >> N) & mask)
// --(2)   => Q <= ((input_high * I) >> N)
// --(1)   => Q <= (I*input_high)/(2**N)
//         => (2**N)*Q <= I*input_high
184//
185// Hence we find that:
186// (2**N)*Q*P <= input_high*I*P
187// --(3a)     <= input_high*2**{2N} - P*input_high
188// --(4)      <= (2**N)*input - (2**N)*input_low - (2**N)*input_high   (Assuming P < 2**N)
189//
190// Dividing by 2**N we get
191// Q*P <= input - input_low - input_high
192// which rearranges to
193// -x' = -input + Q*P <= -input_high - input_low < 2**50
194//
195// This completes the proof.
196
/// Instantiate convolution for "large" RHS vectors over a 31-bit `MontyField31`
/// whose parameters implement `BarrettParameters` (not only BabyBear).
///
/// Here "large" means the elements can be as big as the field
/// characteristic, and the size N of the RHS is <= 64.
///
/// This is a zero-sized marker type: it only selects a `Convolve` implementation.
#[derive(Debug, Clone, Default)]
struct LargeConvolveMontyField31;
203
204impl<FP> Convolve<MontyField31<FP>, i64, i64, i64> for LargeConvolveMontyField31
205where
206    FP: BarrettParameters,
207{
208    /// Return the lift of a MontyField31 element, satisfying
209    /// 0 <= input.value < P < 2^31.
210    /// Note that MontyField31 elements are represented in Monty form.
211    #[inline(always)]
212    fn read(input: MontyField31<FP>) -> i64 {
213        input.value as i64
214    }
215
216    #[inline(always)]
217    fn parity_dot<const N: usize>(u: [i64; N], v: [i64; N]) -> i64 {
218        // For a convolution of size N, |x|, |y| < N * 2^31, so the
219        // product could be as much as N^2 * 2^62. This will overflow an
220        // i64, so we first widen to i128. Note that N^2 * 2^62 < 2^80
221        // for N <= 64, as required by `barrett_red_monty31()`.
222
223        let mut dp = 0i128;
224        for i in 0..N {
225            dp += u[i] as i128 * v[i] as i128;
226        }
227        barrett_red_monty31::<FP>(dp)
228    }
229
230    #[inline(always)]
231    fn reduce(z: i64) -> MontyField31<FP> {
232        // After the barrett reduction method, the output z of parity
233        // dot satisfies |z| < 2^50 (See Thm 1 above).
234        //
235        // In the recombining steps, conv_n maps (wo, w1) ->
236        // ((wo + w1)/2, (wo + w1)/2) which has no effect on the maximal
237        // size. (Indeed, it makes sizes almost strictly smaller).
238        //
239        // On the other hand, negacyclic_conv_n (ignoring the re-index)
240        // recombines as: (w0, w1, w2) -> (w0 + w1, w2 - w0 - w1).
241        // Hence if the input is <= K, the output is <= 3K.
242        //
243        // Thus the values appearing at the end are bounded by 3^n 2^50
244        // where n is the maximal number of negacyclic_conv
245        // recombination steps. When N = 64, we need to recombine for
246        // signed_conv_32, signed_conv_16, signed_conv_8 so the
247        // overall bound will be 3^3 2^50 < 32 * 2^50 < 2^55.
248        debug_assert!(z > -(1i64 << 55));
249        debug_assert!(z < (1i64 << 55));
250
251        // Note we do NOT move it into MONTY form. We assume it is already
252        // in this form.
253        let red = (z % (FP::PRIME as i64)) as u32;
254
255        // If z >= 0: 0 <= red < P is the correct value and P + red will
256        // not overflow.
257        // If z < 0: -P < red < 0 and the value we want is P + red.
258        // On bits, + acts identically for i32 and u32. Hence we can use
259        // u32's and just check for overflow.
260
261        let (corr, over) = red.overflowing_add(FP::PRIME);
262        let value = if over { corr } else { red };
263        MontyField31::new_monty(value)
264    }
265}
266
267impl<FP: MontyParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 8]>
268    for MdsMatrixMontyField31<MU>
269{
270    fn permute(&self, input: [MontyField31<FP>; 8]) -> [MontyField31<FP>; 8] {
271        SmallConvolveMontyField31::apply(
272            input,
273            MU::MATRIX_CIRC_MDS_8_COL,
274            <SmallConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64, i64>>::conv8,
275        )
276    }
277
278    fn permute_mut(&self, input: &mut [MontyField31<FP>; 8]) {
279        *input = self.permute(*input);
280    }
281}
282impl<FP: MontyParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 8>
283    for MdsMatrixMontyField31<MU>
284{
285}
286
287impl<FP: MontyParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 12]>
288    for MdsMatrixMontyField31<MU>
289{
290    fn permute(&self, input: [MontyField31<FP>; 12]) -> [MontyField31<FP>; 12] {
291        SmallConvolveMontyField31::apply(
292            input,
293            MU::MATRIX_CIRC_MDS_12_COL,
294            <SmallConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64, i64>>::conv12,
295        )
296    }
297
298    fn permute_mut(&self, input: &mut [MontyField31<FP>; 12]) {
299        *input = self.permute(*input);
300    }
301}
302impl<FP: MontyParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 12>
303    for MdsMatrixMontyField31<MU>
304{
305}
306
307impl<FP: MontyParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 16]>
308    for MdsMatrixMontyField31<MU>
309{
310    fn permute(&self, input: [MontyField31<FP>; 16]) -> [MontyField31<FP>; 16] {
311        SmallConvolveMontyField31::apply(
312            input,
313            MU::MATRIX_CIRC_MDS_16_COL,
314            <SmallConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64, i64>>::conv16,
315        )
316    }
317
318    fn permute_mut(&self, input: &mut [MontyField31<FP>; 16]) {
319        *input = self.permute(*input);
320    }
321}
322impl<FP: MontyParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 16>
323    for MdsMatrixMontyField31<MU>
324{
325}
326
327impl<FP, MU: MDSUtils> Permutation<[MontyField31<FP>; 24]> for MdsMatrixMontyField31<MU>
328where
329    FP: BarrettParameters,
330{
331    fn permute(&self, input: [MontyField31<FP>; 24]) -> [MontyField31<FP>; 24] {
332        LargeConvolveMontyField31::apply(
333            input,
334            MU::MATRIX_CIRC_MDS_24_COL,
335            <LargeConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64, i64>>::conv24,
336        )
337    }
338
339    fn permute_mut(&self, input: &mut [MontyField31<FP>; 24]) {
340        *input = self.permute(*input);
341    }
342}
343impl<FP: BarrettParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 24>
344    for MdsMatrixMontyField31<MU>
345{
346}
347
348impl<FP: BarrettParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 32]>
349    for MdsMatrixMontyField31<MU>
350{
351    fn permute(&self, input: [MontyField31<FP>; 32]) -> [MontyField31<FP>; 32] {
352        LargeConvolveMontyField31::apply(
353            input,
354            MU::MATRIX_CIRC_MDS_32_COL,
355            <LargeConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64, i64>>::conv32,
356        )
357    }
358
359    fn permute_mut(&self, input: &mut [MontyField31<FP>; 32]) {
360        *input = self.permute(*input);
361    }
362}
363impl<FP: BarrettParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 32>
364    for MdsMatrixMontyField31<MU>
365{
366}
367
368impl<FP: BarrettParameters, MU: MDSUtils> Permutation<[MontyField31<FP>; 64]>
369    for MdsMatrixMontyField31<MU>
370{
371    fn permute(&self, input: [MontyField31<FP>; 64]) -> [MontyField31<FP>; 64] {
372        LargeConvolveMontyField31::apply(
373            input,
374            MU::MATRIX_CIRC_MDS_64_COL,
375            <LargeConvolveMontyField31 as Convolve<MontyField31<FP>, i64, i64, i64>>::conv64,
376        )
377    }
378
379    fn permute_mut(&self, input: &mut [MontyField31<FP>; 64]) {
380        *input = self.permute(*input);
381    }
382}
383impl<FP: BarrettParameters, MU: MDSUtils> MdsPermutation<MontyField31<FP>, 64>
384    for MdsMatrixMontyField31<MU>
385{
386}