p3_air/virtual_column.rs
1//! Utilities for describing logical columns over a `(preprocessed_row, main_row)` pair.
2//!
3//! [`VirtualPairCol`] lets a gadget talk about a derived value without allocating a dedicated
4//! trace column. Instead, define the value once as an affine combination of existing
5//! preprocessed and main columns, then reuse that description anywhere the row layout matches.
6//!
7//! Typical uses include:
8//! - combining fixed columns and witness columns into one gadget input,
9//! - defining lookup/table values as packed or shifted views over existing columns,
10//! - reusing the same gadget wiring across multiple traces that share a column layout.
11//!
12//! # Example: combine fixed and witness columns
13//!
14//! ```
15//! use p3_air::{PairCol, VirtualPairCol};
16//! use p3_baby_bear::BabyBear;
17//! use p3_field::PrimeCharacteristicRing;
18//!
19//! type F = BabyBear;
20//!
21//! let gadget_input = VirtualPairCol::new(
22//! vec![
23//! (PairCol::Main(0), F::ONE),
24//! (PairCol::Main(1), F::from_u8(2)),
25//! (PairCol::Preprocessed(0), F::ONE),
26//! ],
27//! F::from_u8(7),
28//! );
29//!
30//! let preprocessed = [F::from_u8(5)];
31//! let main = [F::from_u8(11), F::from_u8(13)];
32//!
33//! assert_eq!(gadget_input.apply::<F, F>(&preprocessed, &main), F::from_u8(49));
34//! ```
35//!
36//! # Example: reuse one logical lookup key across traces
37//!
38//! ```
39//! use p3_air::{PairCol, VirtualPairCol};
40//! use p3_baby_bear::BabyBear;
41//! use p3_field::PrimeCharacteristicRing;
42//!
43//! type F = BabyBear;
44//!
45//! // Pack two witness limbs, and include a fixed table offset.
46//! let lookup_key = VirtualPairCol::new(
47//! vec![
48//! (PairCol::Main(0), F::ONE),
49//! (PairCol::Main(1), F::from_u8(16)),
50//! (PairCol::Preprocessed(0), F::from_u8(64)),
51//! ],
52//! F::ZERO,
53//! );
54//!
55//! let table_layout = [F::from_u8(1)];
56//! let trace_a = [F::from_u8(3), F::from_u8(2)];
57//! let trace_b = [F::from_u8(5), F::from_u8(1)];
58//!
59//! assert_eq!(lookup_key.apply::<F, F>(&table_layout, &trace_a), F::from_u8(99));
60//! assert_eq!(lookup_key.apply::<F, F>(&table_layout, &trace_b), F::from_u8(85));
61//! ```
62
63use alloc::vec;
64use alloc::vec::Vec;
65use core::ops::Mul;
66
67use p3_field::{Field, PrimeCharacteristicRing};
68
69/// An affine linear combination of columns in a PAIR (Preprocessed AIR).
70///
71/// This structure represents the column `V` with entries `V[j] = Σ(w_i * V_i[j]) + c` where:
72/// - `w_i` are the column weights
73/// - `V_i` are the columns (either preprocessed or main trace columns)
74/// - `c` is a constant term
75///
76/// A `VirtualPairCol` does not allocate any new trace storage. It only records how to read a
77/// logical value from the existing preprocessed and main rows, which makes it convenient for
78/// reusable gadget inputs, packed lookup values, and cross-trace layouts that share the same
79/// column wiring.
80#[derive(Clone, Debug)]
81pub struct VirtualPairCol<F: Field> {
82 /// Linear combination coefficients: pairs of (column, weight).
83 column_weights: Vec<(PairCol, F)>,
84 /// Constant term added to the linear combination.
85 constant: F,
86}
87
/// Identifies one column of a PAIR (Preprocessed AIR) by the trace it lives in and its index.
#[derive(Clone, Copy, Debug)]
pub enum PairCol {
    /// The preprocessed (fixed) column with the given index.
    ///
    /// Preprocessed columns are determined during the setup phase and stay the same
    /// across all proof generations.
    Preprocessed(usize),
    /// The main trace column with the given index.
    ///
    /// Main columns hold witness values; they are filled during trace generation and
    /// vary from one execution to the next.
    Main(usize),
}

impl PairCol {
    /// Reads the entry this column refers to out of a pair of rows.
    ///
    /// # Arguments
    /// * `preprocessed` - The preprocessed row, consulted for [`PairCol::Preprocessed`].
    /// * `main` - The main trace row, consulted for [`PairCol::Main`].
    ///
    /// # Panics
    /// Panics if the stored index is out of bounds for the row that gets selected.
    pub const fn get<T: Copy>(&self, preprocessed: &[T], main: &[T]) -> T {
        // Decide which row to read and where, then perform a single indexed read.
        let (row, index) = match *self {
            Self::Preprocessed(index) => (preprocessed, index),
            Self::Main(index) => (main, index),
        };
        row[index]
    }
}
119
120impl<F: Field> VirtualPairCol<F> {
121 /// Creates a new virtual column with the specified column weights and constant term.
122 ///
123 /// # Arguments
124 /// * `column_weights` - Vector of (column, weight) pairs defining the linear combination
125 /// * `constant` - Constant term to add to the linear combination
126 pub const fn new(column_weights: Vec<(PairCol, F)>, constant: F) -> Self {
127 Self {
128 column_weights,
129 constant,
130 }
131 }
132
133 /// Creates a virtual column as a linear combination of preprocessed columns.
134 ///
135 /// # Arguments
136 /// * `column_weights` - Vector of (column_index, weight) pairs for preprocessed columns
137 /// * `constant` - Constant term to add to the combination
138 pub fn new_preprocessed(column_weights: Vec<(usize, F)>, constant: F) -> Self {
139 Self::new(
140 column_weights
141 .into_iter()
142 .map(|(i, w)| (PairCol::Preprocessed(i), w))
143 .collect(),
144 constant,
145 )
146 }
147
148 /// Creates a virtual column as a linear combination of main trace columns.
149 ///
150 /// # Arguments
151 /// * `column_weights` - Vector of (column_index, weight) pairs for main trace columns
152 /// * `constant` - Constant term to add to the combination
153 pub fn new_main(column_weights: Vec<(usize, F)>, constant: F) -> Self {
154 Self::new(
155 column_weights
156 .into_iter()
157 .map(|(i, w)| (PairCol::Main(i), w))
158 .collect(),
159 constant,
160 )
161 }
162
163 /// A virtual column that always evaluates to the field element `1`.
164 pub const ONE: Self = Self::constant(F::ONE);
165
166 /// Create a virtual column whose value on every row is equal and constant.
167 ///
168 /// # Arguments
169 /// * `x` - The constant field element.
170 #[must_use]
171 pub const fn constant(x: F) -> Self {
172 Self {
173 column_weights: vec![],
174 constant: x,
175 }
176 }
177
178 /// Creates a virtual column equal to a provided column.
179 ///
180 /// # Arguments
181 /// * `column` - The column to represent.
182 #[must_use]
183 pub fn single(column: PairCol) -> Self {
184 Self {
185 column_weights: vec![(column, F::ONE)],
186 constant: F::ZERO,
187 }
188 }
189
190 /// Creates a virtual column equal to a preprocessed column.
191 ///
192 /// # Arguments
193 /// * `column` - Index of the preprocessed column
194 #[must_use]
195 pub fn single_preprocessed(column: usize) -> Self {
196 Self::single(PairCol::Preprocessed(column))
197 }
198
199 /// Creates a virtual column equal to a main trace column.
200 ///
201 /// # Arguments
202 /// * `column` - Index of the main trace column
203 #[must_use]
204 pub fn single_main(column: usize) -> Self {
205 Self::single(PairCol::Main(column))
206 }
207
208 /// Create a virtual column which is the sum of main trace columns.
209 ///
210 /// # Arguments
211 /// * `columns` - Vector of main trace column indices to sum
212 #[must_use]
213 pub fn sum_main(columns: Vec<usize>) -> Self {
214 let column_weights = columns.into_iter().map(|col| (col, F::ONE)).collect();
215 Self::new_main(column_weights, F::ZERO)
216 }
217
218 /// Create a virtual column which is the sum of preprocessed columns.
219 ///
220 /// # Arguments
221 /// * `columns` - Vector of preprocessed column indices to sum
222 #[must_use]
223 pub fn sum_preprocessed(columns: Vec<usize>) -> Self {
224 let column_weights = columns.into_iter().map(|col| (col, F::ONE)).collect();
225 Self::new_preprocessed(column_weights, F::ZERO)
226 }
227
228 /// Create a virtual column which is the difference between two preprocessed columns.
229 ///
230 /// # Arguments
231 /// * `a_col` - Index of the minuend preprocessed column.
232 /// * `b_col` - Index of the subtrahend preprocessed column.
233 #[must_use]
234 pub fn diff_preprocessed(a_col: usize, b_col: usize) -> Self {
235 Self::new_preprocessed(vec![(a_col, F::ONE), (b_col, F::NEG_ONE)], F::ZERO)
236 }
237
238 /// Create a virtual column which is the difference between two main trace columns.
239 ///
240 /// # Arguments
241 /// * `a_col` - Index of the minuend main trace column.
242 /// * `b_col` - Index of the subtrahend main trace column.
243 #[must_use]
244 pub fn diff_main(a_col: usize, b_col: usize) -> Self {
245 Self::new_main(vec![(a_col, F::ONE), (b_col, F::NEG_ONE)], F::ZERO)
246 }
247
248 /// Evaluates the virtual column at a given row by applying the affine linear combination to a pair of preprocessed and main trace rows.
249 ///
250 /// This computes `Σ(w_i * column_values[i]) + constant`
251 ///
252 /// # Arguments
253 /// * `preprocessed` - Row of preprocessed values.
254 /// * `main` - Row of main trace values.
255 pub fn apply<Expr, Var>(&self, preprocessed: &[Var], main: &[Var]) -> Expr
256 where
257 F: Into<Expr>,
258 Expr: PrimeCharacteristicRing + Mul<F, Output = Expr>,
259 Var: Into<Expr> + Copy,
260 {
261 self.column_weights
262 .iter()
263 .fold(self.constant.into(), |acc, &(col, w)| {
264 acc + col.get(preprocessed, main).into() * w
265 })
266 }
267}
268
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;

    use super::*;

    type F = BabyBear;

    /// Shorthand for lifting a small integer into the test field.
    fn fe(value: u8) -> F {
        F::from_u8(value)
    }

    #[test]
    fn test_pair_col_get_main_and_preprocessed() {
        let fixed_row = [fe(10), fe(20)];
        let witness_row = [fe(30), fe(40)];

        // Index 1 of the preprocessed row holds 20.
        let from_fixed = PairCol::Preprocessed(1).get(&fixed_row, &witness_row);
        assert_eq!(from_fixed, fe(20));

        // Index 0 of the main row holds 30.
        let from_witness = PairCol::Main(0).get(&fixed_row, &witness_row);
        assert_eq!(from_witness, fe(30));
    }

    #[test]
    fn test_constant_only_virtual_pair_col() {
        let expected = fe(7);
        let virtual_col = VirtualPairCol::<F>::constant(expected);

        // The rows are irrelevant: a constant column never reads them.
        let fixed_row = [F::ONE];
        let witness_row = [F::ONE];

        assert_eq!(virtual_col.apply::<F, F>(&fixed_row, &witness_row), expected);
    }

    #[test]
    fn test_single_main_column() {
        // Mirror main column 1.
        let virtual_col = VirtualPairCol::<F>::single_main(1);

        let fixed_row = [F::ZERO]; // never read
        let witness_row = [fe(9), fe(5)];

        // The virtual column must reproduce main[1] = 5.
        assert_eq!(virtual_col.apply::<F, F>(&fixed_row, &witness_row), fe(5));
    }

    #[test]
    fn test_single_preprocessed_column() {
        let virtual_col = VirtualPairCol::<F>::single_preprocessed(0);

        let fixed_row = [fe(12)];
        let witness_row: [F; 0] = [];

        assert_eq!(virtual_col.apply::<F, F>(&fixed_row, &witness_row), fe(12));
    }

    #[test]
    fn test_sum_main_columns() {
        // main[0] + main[2]; main[1] must not contribute.
        let virtual_col = VirtualPairCol::<F>::sum_main(vec![0, 2]);

        let fixed_row: [F; 0] = [];
        let witness_row = [F::TWO, fe(99), fe(5)];

        let evaluated = virtual_col.apply::<F, F>(&fixed_row, &witness_row);

        assert_eq!(evaluated, fe(2) + fe(5));
    }

    #[test]
    fn test_sum_preprocessed_columns() {
        // pre[1] + pre[2]; pre[0] must not contribute.
        let virtual_col = VirtualPairCol::<F>::sum_preprocessed(vec![1, 2]);

        let fixed_row = [fe(3), fe(4), fe(6)];
        let witness_row: [F; 0] = [];

        let evaluated = virtual_col.apply::<F, F>(&fixed_row, &witness_row);

        assert_eq!(evaluated, fe(4) + fe(6));
    }

    #[test]
    fn test_diff_main_columns() {
        // main[2] - main[0]; main[1] must not contribute.
        let virtual_col = VirtualPairCol::<F>::diff_main(2, 0);

        let fixed_row: [F; 0] = [];
        let witness_row = [fe(7), F::ZERO, fe(10)];

        let evaluated = virtual_col.apply::<F, F>(&fixed_row, &witness_row);

        assert_eq!(evaluated, fe(10) - fe(7));
    }

    #[test]
    fn test_diff_preprocessed_columns() {
        // pre[1] - pre[0]
        let virtual_col = VirtualPairCol::<F>::diff_preprocessed(1, 0);

        let fixed_row = [fe(4), fe(15)];
        let witness_row: [F; 0] = [];

        let evaluated = virtual_col.apply::<F, F>(&fixed_row, &witness_row);

        assert_eq!(evaluated, fe(15) - fe(4));
    }

    #[test]
    fn test_combination_with_constant_and_weights() {
        // 3 * main[1] + 2 * pre[0] + 5, assembled directly from the struct fields.
        let column_weights = vec![
            (PairCol::Main(1), fe(3)),
            (PairCol::Preprocessed(0), F::TWO),
        ];
        let virtual_col = VirtualPairCol {
            column_weights,
            constant: fe(5),
        };

        let fixed_row = [fe(6)];
        let witness_row = [F::ZERO, fe(4)];

        // 3*4 + 2*6 + 5 = 29
        assert_eq!(virtual_col.apply::<F, F>(&fixed_row, &witness_row), fe(29));
    }

    #[test]
    fn test_virtual_pair_col_can_pack_lookup_value() {
        // main[0] + 16 * main[1] + 64 * pre[0]
        let terms = vec![
            (PairCol::Main(0), F::ONE),
            (PairCol::Main(1), fe(16)),
            (PairCol::Preprocessed(0), fe(64)),
        ];
        let virtual_col = VirtualPairCol::new(terms, F::ZERO);

        let fixed_row = [fe(1)];
        let witness_row = [fe(3), fe(2)];

        // 3 + 16*2 + 64*1 = 99
        assert_eq!(virtual_col.apply::<F, F>(&fixed_row, &witness_row), fe(99));
    }

    #[test]
    fn test_virtual_pair_col_reuses_layout_across_traces() {
        // pre[0] + main[0] - main[1] + 3, evaluated against two unrelated row pairs.
        let terms = vec![
            (PairCol::Preprocessed(0), F::ONE),
            (PairCol::Main(0), F::ONE),
            (PairCol::Main(1), F::NEG_ONE),
        ];
        let virtual_col = VirtualPairCol::new(terms, fe(3));

        // First trace: 10 + 8 - 2 + 3 = 19
        let fixed_a = [fe(10)];
        let witness_a = [fe(8), fe(2)];
        assert_eq!(virtual_col.apply::<F, F>(&fixed_a, &witness_a), fe(19));

        // Second trace: 7 + 4 - 9 + 3 = 5
        let fixed_b = [fe(7)];
        let witness_b = [fe(4), fe(9)];
        assert_eq!(virtual_col.apply::<F, F>(&fixed_b, &witness_b), fe(5));
    }

    #[test]
    fn test_virtual_pair_col_one_is_identity() {
        // `ONE` must evaluate to 1 no matter what the rows contain.
        let virtual_col = VirtualPairCol::<F>::ONE;

        let fixed_row = [fe(99)];
        let witness_row = [fe(42)];

        assert_eq!(virtual_col.apply::<F, F>(&fixed_row, &witness_row), F::ONE);
    }
}
459}