p3_air/
virtual_column.rs

1//! Utilities for describing logical columns over a `(preprocessed_row, main_row)` pair.
2//!
3//! [`VirtualPairCol`] lets a gadget talk about a derived value without allocating a dedicated
4//! trace column. Instead, define the value once as an affine combination of existing
5//! preprocessed and main columns, then reuse that description anywhere the row layout matches.
6//!
7//! Typical uses include:
8//! - combining fixed columns and witness columns into one gadget input,
9//! - defining lookup/table values as packed or shifted views over existing columns,
10//! - reusing the same gadget wiring across multiple traces that share a column layout.
11//!
12//! # Example: combine fixed and witness columns
13//!
14//! ```
15//! use p3_air::{PairCol, VirtualPairCol};
16//! use p3_baby_bear::BabyBear;
17//! use p3_field::PrimeCharacteristicRing;
18//!
19//! type F = BabyBear;
20//!
21//! let gadget_input = VirtualPairCol::new(
22//!     vec![
23//!         (PairCol::Main(0), F::ONE),
24//!         (PairCol::Main(1), F::from_u8(2)),
25//!         (PairCol::Preprocessed(0), F::ONE),
26//!     ],
27//!     F::from_u8(7),
28//! );
29//!
30//! let preprocessed = [F::from_u8(5)];
31//! let main = [F::from_u8(11), F::from_u8(13)];
32//!
33//! assert_eq!(gadget_input.apply::<F, F>(&preprocessed, &main), F::from_u8(49));
34//! ```
35//!
36//! # Example: reuse one logical lookup key across traces
37//!
38//! ```
39//! use p3_air::{PairCol, VirtualPairCol};
40//! use p3_baby_bear::BabyBear;
41//! use p3_field::PrimeCharacteristicRing;
42//!
43//! type F = BabyBear;
44//!
45//! // Pack two witness limbs, and include a fixed table offset.
46//! let lookup_key = VirtualPairCol::new(
47//!     vec![
48//!         (PairCol::Main(0), F::ONE),
49//!         (PairCol::Main(1), F::from_u8(16)),
50//!         (PairCol::Preprocessed(0), F::from_u8(64)),
51//!     ],
52//!     F::ZERO,
53//! );
54//!
55//! let table_layout = [F::from_u8(1)];
56//! let trace_a = [F::from_u8(3), F::from_u8(2)];
57//! let trace_b = [F::from_u8(5), F::from_u8(1)];
58//!
59//! assert_eq!(lookup_key.apply::<F, F>(&table_layout, &trace_a), F::from_u8(99));
60//! assert_eq!(lookup_key.apply::<F, F>(&table_layout, &trace_b), F::from_u8(85));
61//! ```
62
63use alloc::vec;
64use alloc::vec::Vec;
65use core::ops::Mul;
66
67use p3_field::{Field, PrimeCharacteristicRing};
68
69/// An affine linear combination of columns in a PAIR (Preprocessed AIR).
70///
71/// This structure represents the column `V` with entries `V[j] = Σ(w_i * V_i[j]) + c` where:
72/// - `w_i` are the column weights
73/// - `V_i` are the columns (either preprocessed or main trace columns)
74/// - `c` is a constant term
75///
76/// A `VirtualPairCol` does not allocate any new trace storage. It only records how to read a
77/// logical value from the existing preprocessed and main rows, which makes it convenient for
78/// reusable gadget inputs, packed lookup values, and cross-trace layouts that share the same
79/// column wiring.
80#[derive(Clone, Debug)]
81pub struct VirtualPairCol<F: Field> {
82    /// Linear combination coefficients: pairs of (column, weight).
83    column_weights: Vec<(PairCol, F)>,
84    /// Constant term added to the linear combination.
85    constant: F,
86}
87
/// A reference to a column in a PAIR (Preprocessed AIR).
///
/// `PairCol` is a small `Copy` handle. It also implements `PartialEq`, `Eq` and `Hash`, so callers
/// can compare column references or use them as keys (for example, to deduplicate terms).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PairCol {
    /// A preprocessed (fixed) column at the specified index.
    ///
    /// These columns contain values that are determined during the setup phase
    /// and remain constant across all proof generations.
    Preprocessed(usize),
    /// A main trace column at the specified index.
    ///
    /// These columns contain witness values that vary between different executions
    /// and are filled during trace generation.
    Main(usize),
}

impl PairCol {
    /// Reads the value this column refers to from a single row.
    ///
    /// `Preprocessed(i)` selects `preprocessed[i]` and `Main(i)` selects `main[i]`. The other
    /// slice is not read.
    ///
    /// # Arguments
    /// * `preprocessed` - Slice containing the preprocessed row
    /// * `main` - Slice containing the main trace row
    ///
    /// # Panics
    /// Panics if the column index is out of bounds for the selected row.
    #[must_use]
    pub const fn get<T: Copy>(&self, preprocessed: &[T], main: &[T]) -> T {
        match self {
            Self::Preprocessed(i) => preprocessed[*i],
            Self::Main(i) => main[*i],
        }
    }
}
119
120impl<F: Field> VirtualPairCol<F> {
121    /// Creates a new virtual column with the specified column weights and constant term.
122    ///
123    /// # Arguments
124    /// * `column_weights` - Vector of (column, weight) pairs defining the linear combination
125    /// * `constant` - Constant term to add to the linear combination
126    pub const fn new(column_weights: Vec<(PairCol, F)>, constant: F) -> Self {
127        Self {
128            column_weights,
129            constant,
130        }
131    }
132
133    /// Creates a virtual column as a linear combination of preprocessed columns.
134    ///
135    /// # Arguments
136    /// * `column_weights` - Vector of (column_index, weight) pairs for preprocessed columns
137    /// * `constant` - Constant term to add to the combination
138    pub fn new_preprocessed(column_weights: Vec<(usize, F)>, constant: F) -> Self {
139        Self::new(
140            column_weights
141                .into_iter()
142                .map(|(i, w)| (PairCol::Preprocessed(i), w))
143                .collect(),
144            constant,
145        )
146    }
147
148    /// Creates a virtual column as a linear combination of main trace columns.
149    ///
150    /// # Arguments
151    /// * `column_weights` - Vector of (column_index, weight) pairs for main trace columns
152    /// * `constant` - Constant term to add to the combination
153    pub fn new_main(column_weights: Vec<(usize, F)>, constant: F) -> Self {
154        Self::new(
155            column_weights
156                .into_iter()
157                .map(|(i, w)| (PairCol::Main(i), w))
158                .collect(),
159            constant,
160        )
161    }
162
163    /// A virtual column that always evaluates to the field element `1`.
164    pub const ONE: Self = Self::constant(F::ONE);
165
166    /// Create a virtual column whose value on every row is equal and constant.
167    ///
168    /// # Arguments
169    /// * `x` - The constant field element.
170    #[must_use]
171    pub const fn constant(x: F) -> Self {
172        Self {
173            column_weights: vec![],
174            constant: x,
175        }
176    }
177
178    /// Creates a virtual column equal to a provided column.
179    ///
180    /// # Arguments
181    /// * `column` - The column to represent.
182    #[must_use]
183    pub fn single(column: PairCol) -> Self {
184        Self {
185            column_weights: vec![(column, F::ONE)],
186            constant: F::ZERO,
187        }
188    }
189
190    /// Creates a virtual column equal to a preprocessed column.
191    ///
192    /// # Arguments
193    /// * `column` - Index of the preprocessed column
194    #[must_use]
195    pub fn single_preprocessed(column: usize) -> Self {
196        Self::single(PairCol::Preprocessed(column))
197    }
198
199    /// Creates a virtual column equal to a main trace column.
200    ///
201    /// # Arguments
202    /// * `column` - Index of the main trace column
203    #[must_use]
204    pub fn single_main(column: usize) -> Self {
205        Self::single(PairCol::Main(column))
206    }
207
208    /// Create a virtual column which is the sum of main trace columns.
209    ///
210    /// # Arguments
211    /// * `columns` - Vector of main trace column indices to sum
212    #[must_use]
213    pub fn sum_main(columns: Vec<usize>) -> Self {
214        let column_weights = columns.into_iter().map(|col| (col, F::ONE)).collect();
215        Self::new_main(column_weights, F::ZERO)
216    }
217
218    /// Create a virtual column which is the sum of preprocessed columns.
219    ///
220    /// # Arguments
221    /// * `columns` - Vector of preprocessed column indices to sum
222    #[must_use]
223    pub fn sum_preprocessed(columns: Vec<usize>) -> Self {
224        let column_weights = columns.into_iter().map(|col| (col, F::ONE)).collect();
225        Self::new_preprocessed(column_weights, F::ZERO)
226    }
227
228    /// Create a virtual column which is the difference between two preprocessed columns.
229    ///
230    /// # Arguments
231    /// * `a_col` - Index of the minuend preprocessed column.
232    /// * `b_col` - Index of the subtrahend preprocessed column.
233    #[must_use]
234    pub fn diff_preprocessed(a_col: usize, b_col: usize) -> Self {
235        Self::new_preprocessed(vec![(a_col, F::ONE), (b_col, F::NEG_ONE)], F::ZERO)
236    }
237
238    /// Create a virtual column which is the difference between two main trace columns.
239    ///
240    /// # Arguments
241    /// * `a_col` - Index of the minuend main trace column.
242    /// * `b_col` - Index of the subtrahend main trace column.
243    #[must_use]
244    pub fn diff_main(a_col: usize, b_col: usize) -> Self {
245        Self::new_main(vec![(a_col, F::ONE), (b_col, F::NEG_ONE)], F::ZERO)
246    }
247
248    /// Evaluates the virtual column at a given row by applying the affine linear combination to a pair of preprocessed and main trace rows.
249    ///
250    /// This computes `Σ(w_i * column_values[i]) + constant`
251    ///
252    /// # Arguments
253    /// * `preprocessed` - Row of preprocessed values.
254    /// * `main` - Row of main trace values.
255    pub fn apply<Expr, Var>(&self, preprocessed: &[Var], main: &[Var]) -> Expr
256    where
257        F: Into<Expr>,
258        Expr: PrimeCharacteristicRing + Mul<F, Output = Expr>,
259        Var: Into<Expr> + Copy,
260    {
261        self.column_weights
262            .iter()
263            .fold(self.constant.into(), |acc, &(col, w)| {
264                acc + col.get(preprocessed, main).into() * w
265            })
266    }
267}
268
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;

    use super::*;

    type F = BabyBear;

    #[test]
    fn test_pair_col_get_main_and_preprocessed() {
        let pre = [F::from_u8(10), F::from_u8(20)];
        let main = [F::from_u8(30), F::from_u8(40)];

        // Preprocessed(1) should return 20
        assert_eq!(PairCol::Preprocessed(1).get(&pre, &main), F::from_u8(20));

        // Main(0) should return 30
        assert_eq!(PairCol::Main(0).get(&pre, &main), F::from_u8(30));
    }

    #[test]
    fn test_constant_only_virtual_pair_col() {
        let col = VirtualPairCol::<F>::constant(F::from_u8(7));

        // Apply to any input: result should always be the constant
        let pre = [F::ONE];
        let main = [F::ONE];
        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(7));
    }

    #[test]
    fn test_single_main_column() {
        let col = VirtualPairCol::<F>::single_main(1); // column index 1

        let main = [F::from_u8(9), F::from_u8(5)];
        let pre = [F::ZERO]; // ignored

        let result = col.apply::<F, F>(&pre, &main);

        // Since we used single_main(1), this should equal main[1] = 5
        assert_eq!(result, F::from_u8(5));
    }

    #[test]
    fn test_single_preprocessed_column() {
        let col = VirtualPairCol::<F>::single_preprocessed(0);

        let pre = [F::from_u8(12)];
        let main = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(12));
    }

    #[test]
    fn test_single_with_explicit_pair_col() {
        let pre = [F::from_u8(21)];
        let main = [F::from_u8(34)];

        let from_pre = VirtualPairCol::<F>::single(PairCol::Preprocessed(0));
        let from_main = VirtualPairCol::<F>::single(PairCol::Main(0));

        assert_eq!(from_pre.apply::<F, F>(&pre, &main), F::from_u8(21));
        assert_eq!(from_main.apply::<F, F>(&pre, &main), F::from_u8(34));
    }

    #[test]
    fn test_new_main_applies_weights_and_constant() {
        // Computes: 3 * main[0] + 2 * main[1] + 1
        let col = VirtualPairCol::<F>::new_main(vec![(0, F::from_u8(3)), (1, F::TWO)], F::ONE);

        let main = [F::from_u8(4), F::from_u8(5)];
        let pre = [F::from_u8(100)]; // ignored

        // result = 3*4 + 2*5 + 1
        assert_eq!(col.apply::<F, F>(&pre, &main), F::from_u8(23));
    }

    #[test]
    fn test_new_preprocessed_applies_weights_and_constant() {
        // Computes: 2 * pre[1] - pre[0] + 4
        let col = VirtualPairCol::<F>::new_preprocessed(
            vec![(1, F::TWO), (0, F::NEG_ONE)],
            F::from_u8(4),
        );

        let pre = [F::from_u8(3), F::from_u8(8)];
        let main = [F::from_u8(100)]; // ignored

        // result = 2*8 - 3 + 4
        assert_eq!(col.apply::<F, F>(&pre, &main), F::from_u8(17));
    }

    #[test]
    fn test_sum_main_columns() {
        // This adds up main[0] + main[2]
        let col = VirtualPairCol::<F>::sum_main(vec![0, 2]);

        let main = [
            F::TWO,
            F::from_u8(99), // ignored
            F::from_u8(5),
        ];
        let pre = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(2) + F::from_u8(5));
    }

    #[test]
    fn test_sum_of_no_columns_is_zero() {
        let col = VirtualPairCol::<F>::sum_main(vec![]);

        let pre = [F::from_u8(8)];
        let main = [F::from_u8(9)];

        assert_eq!(col.apply::<F, F>(&pre, &main), F::ZERO);
    }

    #[test]
    fn test_sum_preprocessed_columns() {
        let col = VirtualPairCol::<F>::sum_preprocessed(vec![1, 2]);

        let pre = [
            F::from_u8(3), // ignored
            F::from_u8(4),
            F::from_u8(6),
        ];
        let main = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(4) + F::from_u8(6));
    }

    #[test]
    fn test_diff_main_columns() {
        // Computes main[2] - main[0]
        let col = VirtualPairCol::<F>::diff_main(2, 0);

        let main = [
            F::from_u8(7),
            F::ZERO, // ignored
            F::from_u8(10),
        ];
        let pre = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(10) - F::from_u8(7));
    }

    #[test]
    fn test_diff_preprocessed_columns() {
        // Computes pre[1] - pre[0]
        let col = VirtualPairCol::<F>::diff_preprocessed(1, 0);

        let pre = [F::from_u8(4), F::from_u8(15)];
        let main = [];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(15) - F::from_u8(4));
    }

    #[test]
    fn test_combination_with_constant_and_weights() {
        // Computes: 3 * main[1] + 2 * pre[0] + constant (5)
        let col = VirtualPairCol {
            column_weights: vec![
                (PairCol::Main(1), F::from_u8(3)),
                (PairCol::Preprocessed(0), F::TWO),
            ],
            constant: F::from_u8(5),
        };

        let main = [F::ZERO, F::from_u8(4)];
        let pre = [F::from_u8(6)];

        let result = col.apply::<F, F>(&pre, &main);

        // result = 3*4 + 2*6 + 5
        assert_eq!(result, F::from_u8(29));
    }

    #[test]
    fn test_virtual_pair_col_can_pack_lookup_value() {
        let col = VirtualPairCol::new(
            vec![
                (PairCol::Main(0), F::ONE),
                (PairCol::Main(1), F::from_u8(16)),
                (PairCol::Preprocessed(0), F::from_u8(64)),
            ],
            F::ZERO,
        );

        let pre = [F::from_u8(1)];
        let main = [F::from_u8(3), F::from_u8(2)];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::from_u8(99));
    }

    #[test]
    fn test_virtual_pair_col_reuses_layout_across_traces() {
        let col = VirtualPairCol::new(
            vec![
                (PairCol::Preprocessed(0), F::ONE),
                (PairCol::Main(0), F::ONE),
                (PairCol::Main(1), F::NEG_ONE),
            ],
            F::from_u8(3),
        );

        let pre_a = [F::from_u8(10)];
        let main_a = [F::from_u8(8), F::from_u8(2)];
        let pre_b = [F::from_u8(7)];
        let main_b = [F::from_u8(4), F::from_u8(9)];

        assert_eq!(col.apply::<F, F>(&pre_a, &main_a), F::from_u8(19));
        assert_eq!(col.apply::<F, F>(&pre_b, &main_b), F::from_u8(5));
    }

    #[test]
    fn test_virtual_pair_col_one_is_identity() {
        // VirtualPairCol::ONE should always evaluate to 1 regardless of input
        let col = VirtualPairCol::<F>::ONE;
        let pre = [F::from_u8(99)];
        let main = [F::from_u8(42)];

        let result = col.apply::<F, F>(&pre, &main);

        assert_eq!(result, F::ONE);
    }

    #[test]
    #[should_panic]
    fn test_apply_panics_on_out_of_bounds_column() {
        // main has only two entries, so reading main[2] must panic.
        let col = VirtualPairCol::<F>::single_main(2);
        let pre: [F; 0] = [];
        let main = [F::ONE, F::TWO];

        let _ = col.apply::<F, F>(&pre, &main);
    }
}