p3_mds/
karatsuba_convolution.rs

//! Calculate the convolution of two vectors using a Karatsuba-style
//! decomposition and the CRT.
//!
//! This is not a new idea, but we did have the pleasure of
//! reinventing it independently. Some references:
//! - https://cr.yp.to/lineartime/multapps-20080515.pdf
//! - https://2π.com/23/convolution/
//!
//! Given a vector v \in F^N, let v(x) \in F[X] denote the polynomial
//! v_0 + v_1 x + ... + v_{N - 1} x^{N - 1}.  Then w is equal to the
//! (cyclic) convolution v * u if and only if w(x) = v(x)u(x) mod x^N - 1.
//! Additionally, define the negacyclic convolution by w(x) = v(x)u(x)
//! mod x^N + 1.  For even N, the Chinese remainder theorem lets us
//! compute w(x) as
//!     w(x) = 1/2 (w_0(x) + w_1(x)) + x^{N/2}/2 (w_0(x) - w_1(x))
//! where
//!     w_0 = v(x)u(x) mod x^{N/2} - 1
//!     w_1 = v(x)u(x) mod x^{N/2} + 1
//!
//! To compute w_0 and w_1 we first compute
//!                  v_0(x) = v(x) mod x^{N/2} - 1
//!                  v_1(x) = v(x) mod x^{N/2} + 1
//!                  u_0(x) = u(x) mod x^{N/2} - 1
//!                  u_1(x) = u(x) mod x^{N/2} + 1
//!
//! Now w_0 is the convolution of v_0 and u_0, which we can compute
//! recursively.  For w_1 we compute the negacyclic convolution
//! v_1(x)u_1(x) mod x^{N/2} + 1 using Karatsuba.
//!
//! There are 2 possible approaches to applying Karatsuba, which mirror
//! the DIT vs DIF approaches to FFTs: the left/right decomposition
//! or the even/odd decomposition. The latter seems to have fewer
//! operations, so it is the one implemented below, though it does
//! require a bit more data manipulation. It works as follows:
//!
//! Define the even part v_e and odd part v_o so that v(x) = v_e(x^2) + x v_o(x^2).
//! Then v(x)u(x)
//!    = (v_e(x^2)u_e(x^2) + x^2 v_o(x^2)u_o(x^2))
//!      + x ((v_e(x^2) + v_o(x^2))(u_e(x^2) + u_o(x^2))
//!            - (v_e(x^2)u_e(x^2) + v_o(x^2)u_o(x^2)))
//! Since x^N + 1 = y^{N/2} + 1 with y = x^2, each of the products
//! v_e u_e, v_o u_o and (v_e + v_o)(u_e + u_o) is a negacyclic
//! convolution of size N/2 in y, and the extra factor x^2 = y is just a
//! negacyclic shift by one position. This reduces the problem to 3
//! negacyclic convolutions of size N/2, which can be computed recursively.
//!
//! Of course, for small sizes we just explicitly write out the O(n^2)
//! approach.
46
47use core::ops::{Add, AddAssign, Neg, ShrAssign, Sub, SubAssign};
48
/// The arithmetic operations required of the internal element types
/// used by `Convolve` below.
///
/// A type qualifies if it is `Copy + Default`, supports addition,
/// subtraction and negation (plus the in-place `+=` / `-=` forms), and
/// can be right-shifted in place by a `u32` amount.
///
/// TODO: Think of a better name for this.
pub trait RngElt:
    Copy
    + Default
    + Add<Output = Self>
    + Sub<Output = Self>
    + Neg<Output = Self>
    + AddAssign
    + SubAssign
    + ShrAssign<u32>
{
}

impl RngElt for i64 {}
impl RngElt for i128 {}
66
67/// Template function to perform convolution of vectors.
68///
69/// Roughly speaking, for a convolution of size `N`, it should be
70/// possible to add `N` elements of type `T` without overflowing, and
71/// similarly for `U`. Then multiplication via `Self::mul` should
72/// produce an element of type `V` which will not overflow after about
73/// `N` additions (this is an over-estimate).
74///
75/// For example usage, see `{mersenne-31,baby-bear,goldilocks}/src/mds.rs`.
76///
77/// NB: In practice, one of the parameters to the convolution will be
78/// constant (the MDS matrix). After inspecting Godbolt output, it
79/// seems that the compiler does indeed generate single constants as
80/// inputs to the multiplication, rather than doing all that
81/// arithmetic on the constant values every time. Note however that,
82/// for MDS matrices with large entries (N >= 24), these compile-time
83/// generated constants will be about N times bigger than they need to
84/// be in principle, which could be a potential avenue for some minor
85/// improvements.
86///
87/// NB: If primitive multiplications are still the bottleneck, a
88/// further possibility would be to find an MDS matrix some of whose
89/// entries are powers of 2. Then the multiplication can be replaced
90/// with a shift, which on most architectures has better throughput
91/// and latency, and is issued on different ports (1*p06) to
92/// multiplication (1*p1).
93pub trait Convolve<F, T: RngElt, U: RngElt, V: RngElt> {
94    /// Given an input element, retrieve the corresponding internal
95    /// element that will be used in calculations.
96    fn read(input: F) -> T;
97
98    /// Given input vectors `lhs` and `rhs`, calculate their dot
99    /// product. The result can be reduced with respect to the modulus
100    /// (of `F`), but it must have the same lower 10 bits as the dot
101    /// product if all inputs are considered integers. See
102    /// `monty-31/src/mds.rs::barrett_red_monty31()` for an example
103    /// of how this can be implemented in practice.
104    fn parity_dot<const N: usize>(lhs: [T; N], rhs: [U; N]) -> V;
105
106    /// Convert an internal element of type `V` back into an external
107    /// element.
108    fn reduce(z: V) -> F;
109
110    /// Convolve `lhs` and `rhs`.
111    ///
112    /// The parameter `conv` should be the function in this trait that
113    /// corresponds to length `N`.
114    #[inline(always)]
115    fn apply<const N: usize, C: Fn([T; N], [U; N], &mut [V])>(
116        lhs: [F; N],
117        rhs: [U; N],
118        conv: C,
119    ) -> [F; N] {
120        let lhs = lhs.map(Self::read);
121        let mut output = [V::default(); N];
122        conv(lhs, rhs, &mut output);
123        output.map(Self::reduce)
124    }
125
126    #[inline(always)]
127    fn conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [V]) {
128        output[0] = Self::parity_dot(lhs, [rhs[0], rhs[2], rhs[1]]);
129        output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], rhs[2]]);
130        output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
131    }
132
133    #[inline(always)]
134    fn negacyclic_conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [V]) {
135        output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[2], -rhs[1]]);
136        output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[2]]);
137        output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
138    }
139
140    #[inline(always)]
141    fn conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [V]) {
142        // NB: This is just explicitly implementing
143        // conv_n_recursive::<4, 2, _, _>(lhs, rhs, output, Self::conv2, Self::negacyclic_conv2)
144        let u_p = [lhs[0] + lhs[2], lhs[1] + lhs[3]];
145        let u_m = [lhs[0] - lhs[2], lhs[1] - lhs[3]];
146        let v_p = [rhs[0] + rhs[2], rhs[1] + rhs[3]];
147        let v_m = [rhs[0] - rhs[2], rhs[1] - rhs[3]];
148
149        output[0] = Self::parity_dot(u_m, [v_m[0], -v_m[1]]);
150        output[1] = Self::parity_dot(u_m, [v_m[1], v_m[0]]);
151        output[2] = Self::parity_dot(u_p, v_p);
152        output[3] = Self::parity_dot(u_p, [v_p[1], v_p[0]]);
153
154        output[0] += output[2];
155        output[1] += output[3];
156
157        output[0] >>= 1;
158        output[1] >>= 1;
159
160        output[2] -= output[0];
161        output[3] -= output[1];
162    }
163
164    #[inline(always)]
165    fn negacyclic_conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [V]) {
166        output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[3], -rhs[2], -rhs[1]]);
167        output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[3], -rhs[2]]);
168        output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0], -rhs[3]]);
169        output[3] = Self::parity_dot(lhs, [rhs[3], rhs[2], rhs[1], rhs[0]]);
170    }
171
172    #[inline(always)]
173    fn conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [V]) {
174        conv_n_recursive::<6, 3, T, U, V, _, _>(
175            lhs,
176            rhs,
177            output,
178            Self::conv3,
179            Self::negacyclic_conv3,
180        )
181    }
182
183    #[inline(always)]
184    fn negacyclic_conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [V]) {
185        negacyclic_conv_n_recursive::<6, 3, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv3)
186    }
187
188    #[inline(always)]
189    fn conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [V]) {
190        conv_n_recursive::<8, 4, T, U, V, _, _>(
191            lhs,
192            rhs,
193            output,
194            Self::conv4,
195            Self::negacyclic_conv4,
196        )
197    }
198
199    #[inline(always)]
200    fn negacyclic_conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [V]) {
201        negacyclic_conv_n_recursive::<8, 4, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv4)
202    }
203
204    #[inline(always)]
205    fn conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [V]) {
206        conv_n_recursive::<12, 6, T, U, V, _, _>(
207            lhs,
208            rhs,
209            output,
210            Self::conv6,
211            Self::negacyclic_conv6,
212        )
213    }
214
215    #[inline(always)]
216    fn negacyclic_conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [V]) {
217        negacyclic_conv_n_recursive::<12, 6, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv6)
218    }
219
220    #[inline(always)]
221    fn conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [V]) {
222        conv_n_recursive::<16, 8, T, U, V, _, _>(
223            lhs,
224            rhs,
225            output,
226            Self::conv8,
227            Self::negacyclic_conv8,
228        )
229    }
230
231    #[inline(always)]
232    fn negacyclic_conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [V]) {
233        negacyclic_conv_n_recursive::<16, 8, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv8)
234    }
235
236    #[inline(always)]
237    fn conv24(lhs: [T; 24], rhs: [U; 24], output: &mut [V]) {
238        conv_n_recursive::<24, 12, T, U, V, _, _>(
239            lhs,
240            rhs,
241            output,
242            Self::conv12,
243            Self::negacyclic_conv12,
244        )
245    }
246
247    #[inline(always)]
248    fn conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [V]) {
249        conv_n_recursive::<32, 16, T, U, V, _, _>(
250            lhs,
251            rhs,
252            output,
253            Self::conv16,
254            Self::negacyclic_conv16,
255        )
256    }
257
258    #[inline(always)]
259    fn negacyclic_conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [V]) {
260        negacyclic_conv_n_recursive::<32, 16, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv16)
261    }
262
263    #[inline(always)]
264    fn conv64(lhs: [T; 64], rhs: [U; 64], output: &mut [V]) {
265        conv_n_recursive::<64, 32, T, U, V, _, _>(
266            lhs,
267            rhs,
268            output,
269            Self::conv32,
270            Self::negacyclic_conv32,
271        )
272    }
273}
274
275/// Compute output(x) = lhs(x)rhs(x) mod x^N - 1.
276/// Do this recursively using a convolution and negacyclic convolution of size HALF_N = N/2.
277#[inline(always)]
278fn conv_n_recursive<const N: usize, const HALF_N: usize, T, U, V, C, NC>(
279    lhs: [T; N],
280    rhs: [U; N],
281    output: &mut [V],
282    inner_conv: C,
283    inner_negacyclic_conv: NC,
284) where
285    T: RngElt,
286    U: RngElt,
287    V: RngElt,
288    C: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
289    NC: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
290{
291    debug_assert_eq!(2 * HALF_N, N);
292    // NB: The compiler is smart enough not to initialise these arrays.
293    let mut lhs_pos = [T::default(); HALF_N]; // lhs_pos = lhs(x) mod x^{N/2} - 1
294    let mut lhs_neg = [T::default(); HALF_N]; // lhs_neg = lhs(x) mod x^{N/2} + 1
295    let mut rhs_pos = [U::default(); HALF_N]; // rhs_pos = rhs(x) mod x^{N/2} - 1
296    let mut rhs_neg = [U::default(); HALF_N]; // rhs_neg = rhs(x) mod x^{N/2} + 1
297
298    for i in 0..HALF_N {
299        let s = lhs[i];
300        let t = lhs[i + HALF_N];
301
302        lhs_pos[i] = s + t;
303        lhs_neg[i] = s - t;
304
305        let s = rhs[i];
306        let t = rhs[i + HALF_N];
307
308        rhs_pos[i] = s + t;
309        rhs_neg[i] = s - t;
310    }
311
312    let (left, right) = output.split_at_mut(HALF_N);
313
314    // left = w1 = lhs(x)rhs(x) mod x^{N/2} + 1
315    inner_negacyclic_conv(lhs_neg, rhs_neg, left);
316
317    // right = w0 = lhs(x)rhs(x) mod x^{N/2} - 1
318    inner_conv(lhs_pos, rhs_pos, right);
319
320    for i in 0..HALF_N {
321        left[i] += right[i]; // w_0 + w_1
322        left[i] >>= 1; // (w_0 + w_1)/2
323        right[i] -= left[i]; // (w_0 - w_1)/2
324    }
325}
326
327/// Compute output(x) = lhs(x)rhs(x) mod x^N + 1.
328/// Do this recursively using three negacyclic convolutions of size HALF_N = N/2.
329#[inline(always)]
330fn negacyclic_conv_n_recursive<const N: usize, const HALF_N: usize, T, U, V, NC>(
331    lhs: [T; N],
332    rhs: [U; N],
333    output: &mut [V],
334    inner_negacyclic_conv: NC,
335) where
336    T: RngElt,
337    U: RngElt,
338    V: RngElt,
339    NC: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
340{
341    debug_assert_eq!(2 * HALF_N, N);
342    // NB: The compiler is smart enough not to initialise these arrays.
343    let mut lhs_even = [T::default(); HALF_N];
344    let mut lhs_odd = [T::default(); HALF_N];
345    let mut lhs_sum = [T::default(); HALF_N];
346    let mut rhs_even = [U::default(); HALF_N];
347    let mut rhs_odd = [U::default(); HALF_N];
348    let mut rhs_sum = [U::default(); HALF_N];
349
350    for i in 0..HALF_N {
351        let s = lhs[2 * i];
352        let t = lhs[2 * i + 1];
353        lhs_even[i] = s;
354        lhs_odd[i] = t;
355        lhs_sum[i] = s + t;
356
357        let s = rhs[2 * i];
358        let t = rhs[2 * i + 1];
359        rhs_even[i] = s;
360        rhs_odd[i] = t;
361        rhs_sum[i] = s + t;
362    }
363
364    let mut even_s_conv = [V::default(); HALF_N];
365    let (left, right) = output.split_at_mut(HALF_N);
366
367    // Recursively compute the size N/2 negacyclic convolutions of
368    // the even parts, odd parts, and sums.
369    inner_negacyclic_conv(lhs_even, rhs_even, &mut even_s_conv);
370    inner_negacyclic_conv(lhs_odd, rhs_odd, left);
371    inner_negacyclic_conv(lhs_sum, rhs_sum, right);
372
373    // Adjust so that the correct values are in right and
374    // even_s_conv respectively:
375    right[0] -= even_s_conv[0] + left[0];
376    even_s_conv[0] -= left[HALF_N - 1];
377
378    for i in 1..HALF_N {
379        right[i] -= even_s_conv[i] + left[i];
380        even_s_conv[i] += left[i - 1];
381    }
382
383    // Interleave even_s_conv and right in the output:
384    for i in 0..HALF_N {
385        output[2 * i] = even_s_conv[i];
386        output[2 * i + 1] = output[i + HALF_N];
387    }
388}