1use core::ops::{Add, AddAssign, Neg, ShrAssign, Sub, SubAssign};
48
/// Arithmetic requirements for the working types of the convolutions below.
///
/// A `RngElt` must be `Copy + Default` (the algorithms build stack arrays
/// filled with `Default::default()`), support `+`, `-` and unary `-` by value,
/// the in-place forms `+=` and `-=`, and an in-place right shift by a `u32`
/// amount, which the recursive convolution steps use (`>>= 1`) to halve
/// intermediate sums.
pub trait RngElt:
    Copy
    + Default
    + Add<Output = Self>
    + Sub<Output = Self>
    + Neg<Output = Self>
    + AddAssign
    + SubAssign
    + ShrAssign<u32>
{
}

/// 64-bit signed integers satisfy every `RngElt` bound.
impl RngElt for i64 {}

/// 128-bit signed integers satisfy every `RngElt` bound.
impl RngElt for i128 {}
66
67pub trait Convolve<F, T: RngElt, U: RngElt, V: RngElt> {
94 fn read(input: F) -> T;
97
98 fn parity_dot<const N: usize>(lhs: [T; N], rhs: [U; N]) -> V;
105
106 fn reduce(z: V) -> F;
109
110 #[inline(always)]
115 fn apply<const N: usize, C: Fn([T; N], [U; N], &mut [V])>(
116 lhs: [F; N],
117 rhs: [U; N],
118 conv: C,
119 ) -> [F; N] {
120 let lhs = lhs.map(Self::read);
121 let mut output = [V::default(); N];
122 conv(lhs, rhs, &mut output);
123 output.map(Self::reduce)
124 }
125
126 #[inline(always)]
127 fn conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [V]) {
128 output[0] = Self::parity_dot(lhs, [rhs[0], rhs[2], rhs[1]]);
129 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], rhs[2]]);
130 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
131 }
132
133 #[inline(always)]
134 fn negacyclic_conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [V]) {
135 output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[2], -rhs[1]]);
136 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[2]]);
137 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
138 }
139
140 #[inline(always)]
141 fn conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [V]) {
142 let u_p = [lhs[0] + lhs[2], lhs[1] + lhs[3]];
145 let u_m = [lhs[0] - lhs[2], lhs[1] - lhs[3]];
146 let v_p = [rhs[0] + rhs[2], rhs[1] + rhs[3]];
147 let v_m = [rhs[0] - rhs[2], rhs[1] - rhs[3]];
148
149 output[0] = Self::parity_dot(u_m, [v_m[0], -v_m[1]]);
150 output[1] = Self::parity_dot(u_m, [v_m[1], v_m[0]]);
151 output[2] = Self::parity_dot(u_p, v_p);
152 output[3] = Self::parity_dot(u_p, [v_p[1], v_p[0]]);
153
154 output[0] += output[2];
155 output[1] += output[3];
156
157 output[0] >>= 1;
158 output[1] >>= 1;
159
160 output[2] -= output[0];
161 output[3] -= output[1];
162 }
163
164 #[inline(always)]
165 fn negacyclic_conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [V]) {
166 output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[3], -rhs[2], -rhs[1]]);
167 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[3], -rhs[2]]);
168 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0], -rhs[3]]);
169 output[3] = Self::parity_dot(lhs, [rhs[3], rhs[2], rhs[1], rhs[0]]);
170 }
171
172 #[inline(always)]
173 fn conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [V]) {
174 conv_n_recursive::<6, 3, T, U, V, _, _>(
175 lhs,
176 rhs,
177 output,
178 Self::conv3,
179 Self::negacyclic_conv3,
180 )
181 }
182
183 #[inline(always)]
184 fn negacyclic_conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [V]) {
185 negacyclic_conv_n_recursive::<6, 3, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv3)
186 }
187
188 #[inline(always)]
189 fn conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [V]) {
190 conv_n_recursive::<8, 4, T, U, V, _, _>(
191 lhs,
192 rhs,
193 output,
194 Self::conv4,
195 Self::negacyclic_conv4,
196 )
197 }
198
199 #[inline(always)]
200 fn negacyclic_conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [V]) {
201 negacyclic_conv_n_recursive::<8, 4, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv4)
202 }
203
204 #[inline(always)]
205 fn conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [V]) {
206 conv_n_recursive::<12, 6, T, U, V, _, _>(
207 lhs,
208 rhs,
209 output,
210 Self::conv6,
211 Self::negacyclic_conv6,
212 )
213 }
214
215 #[inline(always)]
216 fn negacyclic_conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [V]) {
217 negacyclic_conv_n_recursive::<12, 6, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv6)
218 }
219
220 #[inline(always)]
221 fn conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [V]) {
222 conv_n_recursive::<16, 8, T, U, V, _, _>(
223 lhs,
224 rhs,
225 output,
226 Self::conv8,
227 Self::negacyclic_conv8,
228 )
229 }
230
231 #[inline(always)]
232 fn negacyclic_conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [V]) {
233 negacyclic_conv_n_recursive::<16, 8, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv8)
234 }
235
236 #[inline(always)]
237 fn conv24(lhs: [T; 24], rhs: [U; 24], output: &mut [V]) {
238 conv_n_recursive::<24, 12, T, U, V, _, _>(
239 lhs,
240 rhs,
241 output,
242 Self::conv12,
243 Self::negacyclic_conv12,
244 )
245 }
246
247 #[inline(always)]
248 fn conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [V]) {
249 conv_n_recursive::<32, 16, T, U, V, _, _>(
250 lhs,
251 rhs,
252 output,
253 Self::conv16,
254 Self::negacyclic_conv16,
255 )
256 }
257
258 #[inline(always)]
259 fn negacyclic_conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [V]) {
260 negacyclic_conv_n_recursive::<32, 16, T, U, V, _>(lhs, rhs, output, Self::negacyclic_conv16)
261 }
262
263 #[inline(always)]
264 fn conv64(lhs: [T; 64], rhs: [U; 64], output: &mut [V]) {
265 conv_n_recursive::<64, 32, T, U, V, _, _>(
266 lhs,
267 rhs,
268 output,
269 Self::conv32,
270 Self::negacyclic_conv32,
271 )
272 }
273}
274
275#[inline(always)]
278fn conv_n_recursive<const N: usize, const HALF_N: usize, T, U, V, C, NC>(
279 lhs: [T; N],
280 rhs: [U; N],
281 output: &mut [V],
282 inner_conv: C,
283 inner_negacyclic_conv: NC,
284) where
285 T: RngElt,
286 U: RngElt,
287 V: RngElt,
288 C: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
289 NC: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
290{
291 debug_assert_eq!(2 * HALF_N, N);
292 let mut lhs_pos = [T::default(); HALF_N]; let mut lhs_neg = [T::default(); HALF_N]; let mut rhs_pos = [U::default(); HALF_N]; let mut rhs_neg = [U::default(); HALF_N]; for i in 0..HALF_N {
299 let s = lhs[i];
300 let t = lhs[i + HALF_N];
301
302 lhs_pos[i] = s + t;
303 lhs_neg[i] = s - t;
304
305 let s = rhs[i];
306 let t = rhs[i + HALF_N];
307
308 rhs_pos[i] = s + t;
309 rhs_neg[i] = s - t;
310 }
311
312 let (left, right) = output.split_at_mut(HALF_N);
313
314 inner_negacyclic_conv(lhs_neg, rhs_neg, left);
316
317 inner_conv(lhs_pos, rhs_pos, right);
319
320 for i in 0..HALF_N {
321 left[i] += right[i]; left[i] >>= 1; right[i] -= left[i]; }
325}
326
327#[inline(always)]
330fn negacyclic_conv_n_recursive<const N: usize, const HALF_N: usize, T, U, V, NC>(
331 lhs: [T; N],
332 rhs: [U; N],
333 output: &mut [V],
334 inner_negacyclic_conv: NC,
335) where
336 T: RngElt,
337 U: RngElt,
338 V: RngElt,
339 NC: Fn([T; HALF_N], [U; HALF_N], &mut [V]),
340{
341 debug_assert_eq!(2 * HALF_N, N);
342 let mut lhs_even = [T::default(); HALF_N];
344 let mut lhs_odd = [T::default(); HALF_N];
345 let mut lhs_sum = [T::default(); HALF_N];
346 let mut rhs_even = [U::default(); HALF_N];
347 let mut rhs_odd = [U::default(); HALF_N];
348 let mut rhs_sum = [U::default(); HALF_N];
349
350 for i in 0..HALF_N {
351 let s = lhs[2 * i];
352 let t = lhs[2 * i + 1];
353 lhs_even[i] = s;
354 lhs_odd[i] = t;
355 lhs_sum[i] = s + t;
356
357 let s = rhs[2 * i];
358 let t = rhs[2 * i + 1];
359 rhs_even[i] = s;
360 rhs_odd[i] = t;
361 rhs_sum[i] = s + t;
362 }
363
364 let mut even_s_conv = [V::default(); HALF_N];
365 let (left, right) = output.split_at_mut(HALF_N);
366
367 inner_negacyclic_conv(lhs_even, rhs_even, &mut even_s_conv);
370 inner_negacyclic_conv(lhs_odd, rhs_odd, left);
371 inner_negacyclic_conv(lhs_sum, rhs_sum, right);
372
373 right[0] -= even_s_conv[0] + left[0];
376 even_s_conv[0] -= left[HALF_N - 1];
377
378 for i in 1..HALF_N {
379 right[i] -= even_s_conv[i] + left[i];
380 even_s_conv[i] += left[i - 1];
381 }
382
383 for i in 0..HALF_N {
385 output[2 * i] = even_s_conv[i];
386 output[2 * i + 1] = output[i + HALF_N];
387 }
388}