openvm_stark_sdk/dummy_airs/interaction/
dummy_interaction_air.rs

//! Air with columns
//! | count | fields[..] |
//!
//! The chip will either send or receive the fields with multiplicity `count`.
//! The main Air has no constraints; the only constraints are specified by the Chip trait.
6
7use std::{iter, sync::Arc};
8
9use derivative::Derivative;
10use itertools::izip;
11use openvm_stark_backend::{
12    air_builders::PartitionedAirBuilder,
13    config::{StarkGenericConfig, Val},
14    interaction::{BusIndex, InteractionBuilder},
15    p3_air::{Air, BaseAir},
16    p3_field::{Field, FieldAlgebra},
17    p3_matrix::{dense::RowMajorMatrix, Matrix},
18    prover::{
19        cpu::{CpuBackend, CpuDevice},
20        hal::TraceCommitter,
21        types::{AirProvingContext, CommittedTraceData},
22    },
23    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
24    AirRef, Chip, ChipUsageGetter,
25};
26
/// Column indices of the dummy interaction trace, laid out as `| count | fields[..] |`.
pub struct DummyInteractionCols;

impl DummyInteractionCols {
    /// Index of the `count` column, which is the first column of a row.
    pub fn count_col() -> usize {
        0
    }

    /// Index of the column holding `fields[field_idx]`; field columns directly follow `count`.
    pub fn field_col(field_idx: usize) -> usize {
        Self::count_col() + field_idx + 1
    }
}
36
37#[derive(Clone, Copy)]
38pub struct DummyInteractionAir {
39    field_width: usize,
40    /// Send if true. Receive if false.
41    pub is_send: bool,
42    bus_index: BusIndex,
43    pub count_weight: u32,
44    /// If true, then | count | and | fields[..] | are in separate main trace partitions.
45    pub partition: bool,
46}
47
48impl DummyInteractionAir {
49    pub fn new(field_width: usize, is_send: bool, bus_index: BusIndex) -> Self {
50        Self {
51            field_width,
52            is_send,
53            bus_index,
54            count_weight: 0,
55            partition: false,
56        }
57    }
58
59    pub fn partition(self) -> Self {
60        Self {
61            partition: true,
62            ..self
63        }
64    }
65
66    pub fn field_width(&self) -> usize {
67        self.field_width
68    }
69}
70
71impl<F: Field> BaseAirWithPublicValues<F> for DummyInteractionAir {}
72impl<F: Field> PartitionedBaseAir<F> for DummyInteractionAir {
73    fn cached_main_widths(&self) -> Vec<usize> {
74        if self.partition {
75            vec![self.field_width]
76        } else {
77            vec![]
78        }
79    }
80    fn common_main_width(&self) -> usize {
81        if self.partition {
82            1
83        } else {
84            1 + self.field_width
85        }
86    }
87}
88impl<F: Field> BaseAir<F> for DummyInteractionAir {
89    fn width(&self) -> usize {
90        1 + self.field_width
91    }
92
93    fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
94        None
95    }
96}
97
98impl<AB: InteractionBuilder + PartitionedAirBuilder> Air<AB> for DummyInteractionAir {
99    fn eval(&self, builder: &mut AB) {
100        let (fields, count) = if self.partition {
101            let local_0 = builder.common_main().row_slice(0);
102            let local_1 = builder.cached_mains()[0].row_slice(0);
103            let count = local_0[0];
104            let fields = local_1.to_vec();
105            (fields, count)
106        } else {
107            let main = builder.main();
108            let local = main.row_slice(0);
109            let count = local[DummyInteractionCols::count_col()];
110            let fields: Vec<_> = (0..self.field_width)
111                .map(|i| local[DummyInteractionCols::field_col(i)])
112                .collect();
113            (fields, count)
114        };
115        if self.is_send {
116            builder.push_interaction(self.bus_index, fields, count, self.count_weight);
117        } else {
118            builder.push_interaction(
119                self.bus_index,
120                fields,
121                AB::Expr::NEG_ONE * count,
122                self.count_weight,
123            );
124        }
125    }
126}
127
128/// Note: in principle, committing cached trace is out of scope of a chip. But this chip is for
129/// usually testing, so we support it for convenience.
130#[derive(Derivative)]
131#[derivative(Clone(bound = ""))]
132pub struct DummyInteractionChip<SC: StarkGenericConfig> {
133    device: Option<CpuDevice<SC>>,
134    data: Option<DummyInteractionData>,
135    pub air: DummyInteractionAir,
136}
137
/// Row-wise input for a dummy interaction trace.
#[derive(Debug, Clone)]
pub struct DummyInteractionData {
    /// Multiplicity of each row.
    pub count: Vec<u32>,
    /// Field values of each row; `fields[i]` belongs to `count[i]`.
    pub fields: Vec<Vec<u32>>,
}
143
144impl<SC: StarkGenericConfig> DummyInteractionChip<SC>
145where
146    Val<SC>: FieldAlgebra,
147{
148    pub fn new_without_partition(field_width: usize, is_send: bool, bus_index: BusIndex) -> Self {
149        let air = DummyInteractionAir::new(field_width, is_send, bus_index);
150        Self {
151            device: None,
152            data: None,
153            air,
154        }
155    }
156    pub fn new_with_partition(
157        device: CpuDevice<SC>,
158        field_width: usize,
159        is_send: bool,
160        bus_index: BusIndex,
161    ) -> Self {
162        let air = DummyInteractionAir::new(field_width, is_send, bus_index).partition();
163        Self {
164            device: Some(device),
165            data: None,
166            air,
167        }
168    }
169    pub fn load_data(&mut self, data: DummyInteractionData) {
170        let DummyInteractionData { count, fields } = &data;
171        let h = count.len();
172        assert_eq!(fields.len(), h);
173        let w = fields[0].len();
174        assert_eq!(self.air.field_width, w);
175        assert!(fields.iter().all(|r| r.len() == w));
176        self.data = Some(data);
177    }
178    pub fn air(&self) -> AirRef<SC> {
179        Arc::new(self.air)
180    }
181
182    fn generate_traces_with_partition(
183        &self,
184        data: DummyInteractionData,
185    ) -> AirProvingContext<CpuBackend<SC>> {
186        let DummyInteractionData {
187            mut count,
188            mut fields,
189        } = data;
190        let h = count.len();
191        assert_eq!(fields.len(), h);
192        let w = fields[0].len();
193        assert_eq!(self.air.field_width, w);
194        assert!(fields.iter().all(|r| r.len() == w));
195        let h = h.next_power_of_two();
196        count.resize(h, 0);
197        fields.resize(h, vec![0; w]);
198        let common_main_val: Vec<_> = count
199            .into_iter()
200            .map(Val::<SC>::from_canonical_u32)
201            .collect();
202        let cached_trace_val: Vec<_> = fields
203            .into_iter()
204            .flatten()
205            .map(Val::<SC>::from_canonical_u32)
206            .collect();
207        let cached_trace = Arc::new(RowMajorMatrix::new(cached_trace_val, w));
208        let (commit, data) = self
209            .device
210            .as_ref()
211            .unwrap()
212            .commit(std::slice::from_ref(&cached_trace));
213
214        AirProvingContext {
215            cached_mains: vec![CommittedTraceData {
216                commitment: commit,
217                data,
218                trace: cached_trace,
219            }],
220            common_main: Some(Arc::new(RowMajorMatrix::new(common_main_val, 1))),
221            public_values: vec![],
222        }
223    }
224
225    fn generate_traces_without_partition(
226        &self,
227        data: DummyInteractionData,
228    ) -> RowMajorMatrix<Val<SC>> {
229        let DummyInteractionData { count, fields } = data;
230        let h = count.len();
231        assert_eq!(fields.len(), h);
232        let w = fields[0].len();
233        assert_eq!(self.air.field_width, w);
234        assert!(fields.iter().all(|r| r.len() == w));
235        let common_main_val: Vec<_> = izip!(count, fields)
236            .flat_map(|(count, fields)| iter::once(count).chain(fields))
237            .chain(iter::repeat(0))
238            .take((w + 1) * h.next_power_of_two())
239            .map(Val::<SC>::from_canonical_u32)
240            .collect();
241        RowMajorMatrix::new(common_main_val, w + 1)
242    }
243}
244
245impl<SC: StarkGenericConfig> Chip<(), CpuBackend<SC>> for DummyInteractionChip<SC> {
246    fn generate_proving_ctx(&self, _: ()) -> AirProvingContext<CpuBackend<SC>> {
247        assert!(self.data.is_some());
248        let data = self.data.clone().unwrap();
249        if self.air.partition {
250            self.generate_traces_with_partition(data)
251        } else {
252            let trace = self.generate_traces_without_partition(data);
253            AirProvingContext::simple_no_pis(Arc::new(trace))
254        }
255    }
256}
257
258impl<SC: StarkGenericConfig> ChipUsageGetter for DummyInteractionChip<SC> {
259    fn air_name(&self) -> String {
260        "DummyInteractionAir".to_string()
261    }
262    fn current_trace_height(&self) -> usize {
263        if let Some(data) = &self.data {
264            data.count.len()
265        } else {
266            0
267        }
268    }
269
270    fn trace_width(&self) -> usize {
271        self.air.field_width + 1
272    }
273}