openvm_stark_sdk/dummy_airs/interaction/
dummy_interaction_air.rs

1//! Air with columns
2//! | count | fields[..] |
3//!
//! Chip will either send or receive the fields with multiplicity count.
//! The main Air has no constraints; the only constraints are specified by the Chip trait.
6
7use std::{iter, sync::Arc};
8
9use derivative::Derivative;
10use itertools::izip;
11use openvm_stark_backend::{
12    air_builders::PartitionedAirBuilder,
13    config::{StarkGenericConfig, Val},
14    interaction::{BusIndex, InteractionBuilder},
15    p3_air::{Air, BaseAir},
16    p3_field::{Field, FieldAlgebra},
17    p3_matrix::{dense::RowMajorMatrix, Matrix},
18    prover::{
19        cpu::CpuDevice,
20        hal::TraceCommitter,
21        types::{AirProofInput, AirProofRawInput, CommittedTraceData},
22    },
23    rap::{AnyRap, BaseAirWithPublicValues, PartitionedBaseAir},
24    Chip, ChipUsageGetter,
25};
26
/// Column indices for the dummy interaction trace, laid out as `| count | fields[..] |`.
pub struct DummyInteractionCols;

impl DummyInteractionCols {
    /// Index of the `count` (multiplicity) column; it is always the first column.
    pub fn count_col() -> usize {
        0
    }

    /// Index of the column holding field number `field_idx`.
    ///
    /// The field columns start immediately after the `count` column.
    pub fn field_col(field_idx: usize) -> usize {
        Self::count_col() + 1 + field_idx
    }
}
36
37#[derive(Clone, Copy)]
38pub struct DummyInteractionAir {
39    field_width: usize,
40    /// Send if true. Receive if false.
41    pub is_send: bool,
42    bus_index: BusIndex,
43    pub count_weight: u32,
44    /// If true, then | count | and | fields[..] | are in separate main trace partitions.
45    pub partition: bool,
46}
47
48impl DummyInteractionAir {
49    pub fn new(field_width: usize, is_send: bool, bus_index: BusIndex) -> Self {
50        Self {
51            field_width,
52            is_send,
53            bus_index,
54            count_weight: 0,
55            partition: false,
56        }
57    }
58
59    pub fn partition(self) -> Self {
60        Self {
61            partition: true,
62            ..self
63        }
64    }
65
66    pub fn field_width(&self) -> usize {
67        self.field_width
68    }
69}
70
71impl<F: Field> BaseAirWithPublicValues<F> for DummyInteractionAir {}
72impl<F: Field> PartitionedBaseAir<F> for DummyInteractionAir {
73    fn cached_main_widths(&self) -> Vec<usize> {
74        if self.partition {
75            vec![self.field_width]
76        } else {
77            vec![]
78        }
79    }
80    fn common_main_width(&self) -> usize {
81        if self.partition {
82            1
83        } else {
84            1 + self.field_width
85        }
86    }
87}
88impl<F: Field> BaseAir<F> for DummyInteractionAir {
89    fn width(&self) -> usize {
90        1 + self.field_width
91    }
92
93    fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
94        None
95    }
96}
97
98impl<AB: InteractionBuilder + PartitionedAirBuilder> Air<AB> for DummyInteractionAir {
99    fn eval(&self, builder: &mut AB) {
100        let (fields, count) = if self.partition {
101            let local_0 = builder.common_main().row_slice(0);
102            let local_1 = builder.cached_mains()[0].row_slice(0);
103            let count = local_0[0];
104            let fields = local_1.to_vec();
105            (fields, count)
106        } else {
107            let main = builder.main();
108            let local = main.row_slice(0);
109            let count = local[DummyInteractionCols::count_col()];
110            let fields: Vec<_> = (0..self.field_width)
111                .map(|i| local[DummyInteractionCols::field_col(i)])
112                .collect();
113            (fields, count)
114        };
115        if self.is_send {
116            builder.push_interaction(self.bus_index, fields, count, self.count_weight);
117        } else {
118            builder.push_interaction(
119                self.bus_index,
120                fields,
121                AB::Expr::NEG_ONE * count,
122                self.count_weight,
123            );
124        }
125    }
126}
127
128/// Note: in principle, committing cached trace is out of scope of a chip. But this chip is for
129/// usually testing, so we support it for convenience.
130#[derive(Derivative)]
131#[derivative(Clone(bound = ""))]
132pub struct DummyInteractionChip<'a, SC: StarkGenericConfig> {
133    device: Option<CpuDevice<'a, SC>>,
134    // common_main: Option<RowMajorMatrix<Val<SC>>>,
135    data: Option<DummyInteractionData>,
136    pub air: DummyInteractionAir,
137}
138
/// Row data for a [`DummyInteractionChip`]: one multiplicity and one list of fields per row.
///
/// `count[i]` and `fields[i]` together describe row `i`, so both vectors are expected to have
/// the same length.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DummyInteractionData {
    /// Multiplicity of each row.
    pub count: Vec<u32>,
    /// Field values of each row.
    pub fields: Vec<Vec<u32>>,
}
144
145impl<'a, SC: StarkGenericConfig> DummyInteractionChip<'a, SC>
146where
147    Val<SC>: FieldAlgebra,
148{
149    pub fn new_without_partition(field_width: usize, is_send: bool, bus_index: BusIndex) -> Self {
150        let air = DummyInteractionAir::new(field_width, is_send, bus_index);
151        Self {
152            device: None,
153            data: None,
154            air,
155        }
156    }
157    pub fn new_with_partition(
158        config: &'a SC,
159        field_width: usize,
160        is_send: bool,
161        bus_index: BusIndex,
162    ) -> Self {
163        let air = DummyInteractionAir::new(field_width, is_send, bus_index).partition();
164        Self {
165            device: Some(CpuDevice::new(config)),
166            data: None,
167            air,
168        }
169    }
170    pub fn load_data(&mut self, data: DummyInteractionData) {
171        let DummyInteractionData { count, fields } = &data;
172        let h = count.len();
173        assert_eq!(fields.len(), h);
174        let w = fields[0].len();
175        assert_eq!(self.air.field_width, w);
176        assert!(fields.iter().all(|r| r.len() == w));
177        self.data = Some(data);
178    }
179
180    #[allow(clippy::type_complexity)]
181    fn generate_traces_with_partition(
182        &self,
183        data: DummyInteractionData,
184    ) -> (RowMajorMatrix<Val<SC>>, CommittedTraceData<SC>) {
185        let DummyInteractionData {
186            mut count,
187            mut fields,
188        } = data;
189        let h = count.len();
190        assert_eq!(fields.len(), h);
191        let w = fields[0].len();
192        assert_eq!(self.air.field_width, w);
193        assert!(fields.iter().all(|r| r.len() == w));
194        let h = h.next_power_of_two();
195        count.resize(h, 0);
196        fields.resize(h, vec![0; w]);
197        let common_main_val: Vec<_> = count
198            .into_iter()
199            .map(Val::<SC>::from_canonical_u32)
200            .collect();
201        let cached_trace_val: Vec<_> = fields
202            .into_iter()
203            .flatten()
204            .map(Val::<SC>::from_canonical_u32)
205            .collect();
206        let cached_trace = Arc::new(RowMajorMatrix::new(cached_trace_val, w));
207        let (commit, data) = self
208            .device
209            .as_ref()
210            .unwrap()
211            .commit(&[cached_trace.clone()]);
212        (
213            RowMajorMatrix::new(common_main_val, 1),
214            CommittedTraceData {
215                trace: cached_trace,
216                commitment: commit,
217                pcs_data: data.data,
218            },
219        )
220    }
221
222    fn generate_traces_without_partition(
223        &self,
224        data: DummyInteractionData,
225    ) -> RowMajorMatrix<Val<SC>> {
226        let DummyInteractionData { count, fields } = data;
227        let h = count.len();
228        assert_eq!(fields.len(), h);
229        let w = fields[0].len();
230        assert_eq!(self.air.field_width, w);
231        assert!(fields.iter().all(|r| r.len() == w));
232        let common_main_val: Vec<_> = izip!(count, fields)
233            .flat_map(|(count, fields)| iter::once(count).chain(fields))
234            .chain(iter::repeat(0))
235            .take((w + 1) * h.next_power_of_two())
236            .map(Val::<SC>::from_canonical_u32)
237            .collect();
238        RowMajorMatrix::new(common_main_val, w + 1)
239    }
240}
241
242impl<SC: StarkGenericConfig> Chip<SC> for DummyInteractionChip<'_, SC> {
243    fn air(&self) -> Arc<dyn AnyRap<SC>> {
244        Arc::new(self.air)
245    }
246
247    fn generate_air_proof_input(self) -> AirProofInput<SC> {
248        assert!(self.data.is_some());
249        let data = self.data.clone().unwrap();
250        if self.device.is_some() {
251            let (common_main, cached) = self.generate_traces_with_partition(data);
252            AirProofInput {
253                cached_mains_pdata: vec![(cached.commitment, cached.pcs_data)],
254                raw: AirProofRawInput {
255                    cached_mains: vec![cached.trace],
256                    common_main: Some(common_main),
257                    public_values: vec![],
258                },
259            }
260        } else {
261            let common_main = self.generate_traces_without_partition(data);
262            AirProofInput {
263                cached_mains_pdata: vec![],
264                raw: AirProofRawInput {
265                    cached_mains: vec![],
266                    common_main: Some(common_main),
267                    public_values: vec![],
268                },
269            }
270        }
271    }
272}
273
274impl<SC: StarkGenericConfig> ChipUsageGetter for DummyInteractionChip<'_, SC> {
275    fn air_name(&self) -> String {
276        "DummyInteractionAir".to_string()
277    }
278    fn current_trace_height(&self) -> usize {
279        if let Some(data) = &self.data {
280            data.count.len()
281        } else {
282            0
283        }
284    }
285
286    fn trace_width(&self) -> usize {
287        self.air.field_width + 1
288    }
289}