1use alloc::vec;
2use alloc::vec::Vec;
3use core::ops::Mul;
4
5use p3_field::{Field, PrimeCharacteristicRing};
6
7#[derive(Clone, Debug)]
14pub struct VirtualPairCol<F: Field> {
15 column_weights: Vec<(PairCol, F)>,
17 constant: F,
19}
20
/// Names one column of either the preprocessed trace or the main trace.
#[derive(Debug, Clone, Copy)]
pub enum PairCol {
    /// Index of a column in the preprocessed row.
    Preprocessed(usize),
    /// Index of a column in the main row.
    Main(usize),
}

impl PairCol {
    /// Returns the value of this column, read from whichever row slice it refers to.
    ///
    /// # Panics
    ///
    /// Panics if the stored index is out of bounds for the selected slice.
    pub const fn get<T: Copy>(&self, preprocessed: &[T], main: &[T]) -> T {
        // Pick the source row first, then perform a single indexing operation.
        let (row, index) = match *self {
            Self::Preprocessed(index) => (preprocessed, index),
            Self::Main(index) => (main, index),
        };
        row[index]
    }
}
52
53impl<F: Field> VirtualPairCol<F> {
54 pub const fn new(column_weights: Vec<(PairCol, F)>, constant: F) -> Self {
60 Self {
61 column_weights,
62 constant,
63 }
64 }
65
66 pub fn new_preprocessed(column_weights: Vec<(usize, F)>, constant: F) -> Self {
72 Self::new(
73 column_weights
74 .into_iter()
75 .map(|(i, w)| (PairCol::Preprocessed(i), w))
76 .collect(),
77 constant,
78 )
79 }
80
81 pub fn new_main(column_weights: Vec<(usize, F)>, constant: F) -> Self {
87 Self::new(
88 column_weights
89 .into_iter()
90 .map(|(i, w)| (PairCol::Main(i), w))
91 .collect(),
92 constant,
93 )
94 }
95
96 pub const ONE: Self = Self::constant(F::ONE);
98
99 #[must_use]
104 pub const fn constant(x: F) -> Self {
105 Self {
106 column_weights: vec![],
107 constant: x,
108 }
109 }
110
111 #[must_use]
116 pub fn single(column: PairCol) -> Self {
117 Self {
118 column_weights: vec![(column, F::ONE)],
119 constant: F::ZERO,
120 }
121 }
122
123 #[must_use]
128 pub fn single_preprocessed(column: usize) -> Self {
129 Self::single(PairCol::Preprocessed(column))
130 }
131
132 #[must_use]
137 pub fn single_main(column: usize) -> Self {
138 Self::single(PairCol::Main(column))
139 }
140
141 #[must_use]
146 pub fn sum_main(columns: Vec<usize>) -> Self {
147 let column_weights = columns.into_iter().map(|col| (col, F::ONE)).collect();
148 Self::new_main(column_weights, F::ZERO)
149 }
150
151 #[must_use]
156 pub fn sum_preprocessed(columns: Vec<usize>) -> Self {
157 let column_weights = columns.into_iter().map(|col| (col, F::ONE)).collect();
158 Self::new_preprocessed(column_weights, F::ZERO)
159 }
160
161 #[must_use]
167 pub fn diff_preprocessed(a_col: usize, b_col: usize) -> Self {
168 Self::new_preprocessed(vec![(a_col, F::ONE), (b_col, F::NEG_ONE)], F::ZERO)
169 }
170
171 #[must_use]
177 pub fn diff_main(a_col: usize, b_col: usize) -> Self {
178 Self::new_main(vec![(a_col, F::ONE), (b_col, F::NEG_ONE)], F::ZERO)
179 }
180
181 pub fn apply<Expr, Var>(&self, preprocessed: &[Var], main: &[Var]) -> Expr
189 where
190 F: Into<Expr>,
191 Expr: PrimeCharacteristicRing + Mul<F, Output = Expr>,
192 Var: Into<Expr> + Copy,
193 {
194 self.column_weights
195 .iter()
196 .fold(self.constant.into(), |acc, &(col, w)| {
197 acc + col.get(preprocessed, main).into() * w
198 })
199 }
200}
201
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;

    use super::*;

    type F = BabyBear;

    /// Shorthand for building small field constants.
    fn fe(value: u8) -> F {
        F::from_u8(value)
    }

    #[test]
    fn test_pair_col_get_main_and_preprocessed() {
        let preprocessed_row = [fe(10), fe(20)];
        let main_row = [fe(30), fe(40)];

        let from_preprocessed = PairCol::Preprocessed(1).get(&preprocessed_row, &main_row);
        assert_eq!(from_preprocessed, fe(20));

        let from_main = PairCol::Main(0).get(&preprocessed_row, &main_row);
        assert_eq!(from_main, fe(30));
    }

    #[test]
    fn test_constant_only_virtual_pair_col() {
        let virtual_col = VirtualPairCol::<F>::constant(fe(7));
        let preprocessed_row = [F::ONE];
        let main_row = [F::ONE];

        let evaluated = virtual_col.apply::<F, F>(&preprocessed_row, &main_row);
        assert_eq!(evaluated, fe(7));
    }

    #[test]
    fn test_single_main_column() {
        let virtual_col = VirtualPairCol::<F>::single_main(1);
        let main_row = [fe(9), fe(5)];
        let preprocessed_row = [F::ZERO];

        let evaluated = virtual_col.apply::<F, F>(&preprocessed_row, &main_row);
        assert_eq!(evaluated, fe(5));
    }

    #[test]
    fn test_single_preprocessed_column() {
        let virtual_col = VirtualPairCol::<F>::single_preprocessed(0);
        let preprocessed_row = [fe(12)];
        let main_row: [F; 0] = [];

        let evaluated = virtual_col.apply::<F, F>(&preprocessed_row, &main_row);
        assert_eq!(evaluated, fe(12));
    }

    #[test]
    fn test_sum_main_columns() {
        // Column 1 is skipped on purpose; its value must not appear in the sum.
        let virtual_col = VirtualPairCol::<F>::sum_main(vec![0, 2]);
        let main_row = [F::TWO, fe(99), fe(5)];
        let preprocessed_row: [F; 0] = [];

        let evaluated = virtual_col.apply::<F, F>(&preprocessed_row, &main_row);
        assert_eq!(evaluated, fe(2) + fe(5));
    }

    #[test]
    fn test_sum_preprocessed_columns() {
        // Column 0 is skipped on purpose; its value must not appear in the sum.
        let virtual_col = VirtualPairCol::<F>::sum_preprocessed(vec![1, 2]);
        let preprocessed_row = [fe(3), fe(4), fe(6)];
        let main_row: [F; 0] = [];

        let evaluated = virtual_col.apply::<F, F>(&preprocessed_row, &main_row);
        assert_eq!(evaluated, fe(4) + fe(6));
    }

    #[test]
    fn test_diff_main_columns() {
        let virtual_col = VirtualPairCol::<F>::diff_main(2, 0);
        let main_row = [fe(7), F::ZERO, fe(10)];
        let preprocessed_row: [F; 0] = [];

        let evaluated = virtual_col.apply::<F, F>(&preprocessed_row, &main_row);
        assert_eq!(evaluated, fe(10) - fe(7));
    }

    #[test]
    fn test_diff_preprocessed_columns() {
        let virtual_col = VirtualPairCol::<F>::diff_preprocessed(1, 0);
        let preprocessed_row = [fe(4), fe(15)];
        let main_row: [F; 0] = [];

        let evaluated = virtual_col.apply::<F, F>(&preprocessed_row, &main_row);
        assert_eq!(evaluated, fe(15) - fe(4));
    }

    #[test]
    fn test_combination_with_constant_and_weights() {
        let virtual_col = VirtualPairCol::new(
            vec![
                (PairCol::Main(1), fe(3)),
                (PairCol::Preprocessed(0), F::TWO),
            ],
            fe(5),
        );
        let main_row = [F::ZERO, fe(4)];
        let preprocessed_row = [fe(6)];

        // 5 + 3 * 4 + 2 * 6 = 29
        let evaluated = virtual_col.apply::<F, F>(&preprocessed_row, &main_row);
        assert_eq!(evaluated, fe(29));
    }

    #[test]
    fn test_virtual_pair_col_one_is_identity() {
        let virtual_col = VirtualPairCol::<F>::ONE;
        let preprocessed_row = [fe(99)];
        let main_row = [fe(42)];

        let evaluated = virtual_col.apply::<F, F>(&preprocessed_row, &main_row);
        assert_eq!(evaluated, F::ONE);
    }
}