p3_air/
virtual_column.rs

1use alloc::vec;
2use alloc::vec::Vec;
3use core::ops::Mul;
4
5use p3_field::{Field, PrimeCharacteristicRing};
6
7/// An affine linear combination of columns in a PAIR (Preprocessed AIR).
8///
9/// This structure represents the column `V` with entries `V[j] = Σ(w_i * V_i[j]) + c` where:
10/// - `w_i` are the column weights
11/// - `V_i` are the columns (either preprocessed or main trace columns)
12/// - `c` is a constant term
13#[derive(Clone, Debug)]
14pub struct VirtualPairCol<F: Field> {
15    /// Linear combination coefficients: pairs of (column, weight).
16    column_weights: Vec<(PairCol, F)>,
17    /// Constant term added to the linear combination.
18    constant: F,
19}
20
/// A reference to a single column in a PAIR (Preprocessed AIR).
///
/// A `PairCol` only stores which trace and which index it refers to; the actual
/// row values are supplied when calling [`PairCol::get`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PairCol {
    /// A preprocessed (fixed) column at the specified index.
    ///
    /// These columns contain values that are determined during the setup phase
    /// and remain constant across all proof generations.
    Preprocessed(usize),
    /// A main trace column at the specified index.
    ///
    /// These columns contain witness values that vary between different executions
    /// and are filled during trace generation.
    Main(usize),
}

impl PairCol {
    /// Retrieves the value of this column from the row it refers to.
    ///
    /// `Preprocessed(i)` reads `preprocessed[i]` and `Main(i)` reads `main[i]`.
    /// The other slice is never accessed, so it may be empty.
    ///
    /// # Arguments
    /// * `preprocessed` - Slice containing the preprocessed row.
    /// * `main` - Slice containing the main trace row.
    ///
    /// # Panics
    /// Panics if the column index is out of bounds for the slice it refers to.
    #[must_use]
    pub const fn get<T: Copy>(&self, preprocessed: &[T], main: &[T]) -> T {
        match *self {
            Self::Preprocessed(i) => preprocessed[i],
            Self::Main(i) => main[i],
        }
    }
}
52
53impl<F: Field> VirtualPairCol<F> {
54    /// Creates a new virtual column with the specified column weights and constant term.
55    ///
56    /// # Arguments
57    /// * `column_weights` - Vector of (column, weight) pairs defining the linear combination
58    /// * `constant` - Constant term to add to the linear combination
59    pub const fn new(column_weights: Vec<(PairCol, F)>, constant: F) -> Self {
60        Self {
61            column_weights,
62            constant,
63        }
64    }
65
66    /// Creates a virtual column as a linear combination of preprocessed columns.
67    ///
68    /// # Arguments
69    /// * `column_weights` - Vector of (column_index, weight) pairs for preprocessed columns
70    /// * `constant` - Constant term to add to the combination
71    pub fn new_preprocessed(column_weights: Vec<(usize, F)>, constant: F) -> Self {
72        Self::new(
73            column_weights
74                .into_iter()
75                .map(|(i, w)| (PairCol::Preprocessed(i), w))
76                .collect(),
77            constant,
78        )
79    }
80
81    /// Creates a virtual column as a linear combination of main trace columns.
82    ///
83    /// # Arguments
84    /// * `column_weights` - Vector of (column_index, weight) pairs for main trace columns
85    /// * `constant` - Constant term to add to the combination
86    pub fn new_main(column_weights: Vec<(usize, F)>, constant: F) -> Self {
87        Self::new(
88            column_weights
89                .into_iter()
90                .map(|(i, w)| (PairCol::Main(i), w))
91                .collect(),
92            constant,
93        )
94    }
95
96    /// A virtual column that always evaluates to the field element `1`.
97    pub const ONE: Self = Self::constant(F::ONE);
98
99    /// Create a virtual column whose value on every row is equal and constant.
100    ///
101    /// # Arguments
102    /// * `x` - The constant field element.
103    #[must_use]
104    pub const fn constant(x: F) -> Self {
105        Self {
106            column_weights: vec![],
107            constant: x,
108        }
109    }
110
111    /// Creates a virtual column equal to a provided column.
112    ///
113    /// # Arguments
114    /// * `column` - The column to represent.
115    #[must_use]
116    pub fn single(column: PairCol) -> Self {
117        Self {
118            column_weights: vec![(column, F::ONE)],
119            constant: F::ZERO,
120        }
121    }
122
123    /// Creates a virtual column equal to a preprocessed column.
124    ///
125    /// # Arguments
126    /// * `column` - Index of the preprocessed column
127    #[must_use]
128    pub fn single_preprocessed(column: usize) -> Self {
129        Self::single(PairCol::Preprocessed(column))
130    }
131
132    /// Creates a virtual column equal to a main trace column.
133    ///
134    /// # Arguments
135    /// * `column` - Index of the main trace column
136    #[must_use]
137    pub fn single_main(column: usize) -> Self {
138        Self::single(PairCol::Main(column))
139    }
140
141    /// Create a virtual column which is the sum of main trace columns.
142    ///
143    /// # Arguments
144    /// * `columns` - Vector of main trace column indices to sum
145    #[must_use]
146    pub fn sum_main(columns: Vec<usize>) -> Self {
147        let column_weights = columns.into_iter().map(|col| (col, F::ONE)).collect();
148        Self::new_main(column_weights, F::ZERO)
149    }
150
151    /// Create a virtual column which is the sum of preprocessed columns.
152    ///
153    /// # Arguments
154    /// * `columns` - Vector of preprocessed column indices to sum
155    #[must_use]
156    pub fn sum_preprocessed(columns: Vec<usize>) -> Self {
157        let column_weights = columns.into_iter().map(|col| (col, F::ONE)).collect();
158        Self::new_preprocessed(column_weights, F::ZERO)
159    }
160
161    /// Create a virtual column which is the difference between two preprocessed columns.
162    ///
163    /// # Arguments
164    /// * `a_col` - Index of the minuend preprocessed column.
165    /// * `b_col` - Index of the subtrahend preprocessed column.
166    #[must_use]
167    pub fn diff_preprocessed(a_col: usize, b_col: usize) -> Self {
168        Self::new_preprocessed(vec![(a_col, F::ONE), (b_col, F::NEG_ONE)], F::ZERO)
169    }
170
171    /// Create a virtual column which is the difference between two main trace columns.
172    ///
173    /// # Arguments
174    /// * `a_col` - Index of the minuend main trace column.
175    /// * `b_col` - Index of the subtrahend main trace column.
176    #[must_use]
177    pub fn diff_main(a_col: usize, b_col: usize) -> Self {
178        Self::new_main(vec![(a_col, F::ONE), (b_col, F::NEG_ONE)], F::ZERO)
179    }
180
181    /// Evaluates the virtual column at a given row by applying the affine linear combination to a pair of preprocessed and main trace rows.
182    ///
183    /// This computes `Σ(w_i * column_values[i]) + constant`
184    ///
185    /// # Arguments
186    /// * `preprocessed` - Row of preprocessed values.
187    /// * `main` - Row of main trace values.
188    pub fn apply<Expr, Var>(&self, preprocessed: &[Var], main: &[Var]) -> Expr
189    where
190        F: Into<Expr>,
191        Expr: PrimeCharacteristicRing + Mul<F, Output = Expr>,
192        Var: Into<Expr> + Copy,
193    {
194        self.column_weights
195            .iter()
196            .fold(self.constant.into(), |acc, &(col, w)| {
197                acc + col.get(preprocessed, main).into() * w
198            })
199    }
200}
201
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;

    use super::*;

    type F = BabyBear;

    #[test]
    fn test_pair_col_get_main_and_preprocessed() {
        let pre = [F::from_u8(10), F::from_u8(20)];
        let main = [F::from_u8(30), F::from_u8(40)];

        // Preprocessed(1) should return 20
        assert_eq!(PairCol::Preprocessed(1).get(&pre, &main), F::from_u8(20));

        // Main(0) should return 30
        assert_eq!(PairCol::Main(0).get(&pre, &main), F::from_u8(30));
    }

    #[test]
    fn test_constant_only_virtual_pair_col() {
        let col = VirtualPairCol::<F>::constant(F::from_u8(7));

        // Apply to any input: result should always be the constant
        let pre = [F::ONE];
        let main = [F::ONE];
        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(7));
    }

    #[test]
    fn test_single_main_column() {
        let col = VirtualPairCol::<F>::single_main(1); // column index 1

        let main = [F::from_u8(9), F::from_u8(5)];
        let pre = [F::ZERO]; // ignored

        let result = col.apply::<F, F>(&pre, &main);

        // Since we used single_main(1), this should equal main[1] = 5
        assert_eq!(result, F::from_u8(5));
    }

    #[test]
    fn test_single_preprocessed_column() {
        let col = VirtualPairCol::<F>::single_preprocessed(0);

        let pre = [F::from_u8(12)];
        let main: [F; 0] = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(12));
    }

    #[test]
    fn test_single_generic_column() {
        // `single` accepts an explicit PairCol and uses weight 1, constant 0.
        let col = VirtualPairCol::<F>::single(PairCol::Main(0));

        let main = [F::from_u8(11)];
        let pre: [F; 0] = [];

        assert_eq!(col.apply::<F, F>(&pre, &main), F::from_u8(11));
    }

    #[test]
    fn test_sum_main_columns() {
        // This adds up main[0] + main[2]
        let col = VirtualPairCol::<F>::sum_main(vec![0, 2]);

        let main = [
            F::TWO,
            F::from_u8(99), // ignored
            F::from_u8(5),
        ];
        let pre: [F; 0] = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(2) + F::from_u8(5));
    }

    #[test]
    fn test_empty_sum_is_zero() {
        let col = VirtualPairCol::<F>::sum_main(vec![]);

        let main = [F::from_u8(3)];
        let pre: [F; 0] = [];

        assert_eq!(col.apply::<F, F>(&pre, &main), F::ZERO);
    }

    #[test]
    fn test_sum_preprocessed_columns() {
        let col = VirtualPairCol::<F>::sum_preprocessed(vec![1, 2]);

        let pre = [
            F::from_u8(3), // ignored
            F::from_u8(4),
            F::from_u8(6),
        ];
        let main: [F; 0] = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(4) + F::from_u8(6));
    }

    #[test]
    fn test_diff_main_columns() {
        // Computes main[2] - main[0]
        let col = VirtualPairCol::<F>::diff_main(2, 0);

        let main = [
            F::from_u8(7),
            F::ZERO, // ignored
            F::from_u8(10),
        ];
        let pre: [F; 0] = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(10) - F::from_u8(7));
    }

    #[test]
    fn test_diff_main_can_be_negative() {
        // main[0] - main[1] = 3 - 10 = -7 in the field.
        let col = VirtualPairCol::<F>::diff_main(0, 1);

        let main = [F::from_u8(3), F::from_u8(10)];
        let pre: [F; 0] = [];

        assert_eq!(col.apply::<F, F>(&pre, &main), -F::from_u8(7));
    }

    #[test]
    fn test_diff_preprocessed_columns() {
        // Computes pre[1] - pre[0]
        let col = VirtualPairCol::<F>::diff_preprocessed(1, 0);

        let pre = [F::from_u8(4), F::from_u8(15)];
        let main: [F; 0] = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(15) - F::from_u8(4));
    }

    #[test]
    fn test_new_main_with_weights_and_constant() {
        // Computes: 2 * main[0] + 3 * main[1] + 1
        let col = VirtualPairCol::<F>::new_main(vec![(0, F::TWO), (1, F::from_u8(3))], F::ONE);

        let main = [F::from_u8(4), F::from_u8(5)];
        let pre: [F; 0] = [];

        // result = 2*4 + 3*5 + 1
        assert_eq!(col.apply::<F, F>(&pre, &main), F::from_u8(24));
    }

    #[test]
    fn test_new_preprocessed_with_weights_and_constant() {
        // Computes: 7 * pre[1] + 2
        let col = VirtualPairCol::<F>::new_preprocessed(vec![(1, F::from_u8(7))], F::TWO);

        let pre = [F::from_u8(100), F::from_u8(3)]; // pre[0] ignored
        let main: [F; 0] = [];

        // result = 7*3 + 2
        assert_eq!(col.apply::<F, F>(&pre, &main), F::from_u8(23));
    }

    #[test]
    fn test_repeated_column_contributions_accumulate() {
        // Listing the same column twice adds its contribution twice.
        let col = VirtualPairCol::<F>::new_main(vec![(0, F::ONE), (0, F::ONE)], F::ZERO);

        let main = [F::from_u8(9)];
        let pre: [F; 0] = [];

        assert_eq!(col.apply::<F, F>(&pre, &main), F::from_u8(18));
    }

    #[test]
    fn test_combination_with_constant_and_weights() {
        // Computes: 3 * main[1] + 2 * pre[0] + constant (5)
        let col = VirtualPairCol {
            column_weights: vec![
                (PairCol::Main(1), F::from_u8(3)),
                (PairCol::Preprocessed(0), F::TWO),
            ],
            constant: F::from_u8(5),
        };

        let main = [F::ZERO, F::from_u8(4)];
        let pre = [F::from_u8(6)];

        let result = col.apply::<F, F>(&pre, &main);

        // result = 3*4 + 2*6 + 5
        assert_eq!(result, F::from_u8(29));
    }

    #[test]
    fn test_virtual_pair_col_one_is_constant_one() {
        // VirtualPairCol::ONE should always evaluate to 1 regardless of input
        let col = VirtualPairCol::<F>::ONE;
        let pre = [F::from_u8(99)];
        let main = [F::from_u8(42)];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::ONE);
    }

    #[test]
    #[should_panic]
    fn test_apply_panics_on_out_of_bounds_column() {
        // main has a single entry, so main[3] is out of bounds.
        let col = VirtualPairCol::<F>::single_main(3);

        let main = [F::ONE];
        let pre: [F; 0] = [];

        let _ = col.apply::<F, F>(&pre, &main);
    }
}