// openvm_stark_sdk/dummy_airs/interaction/dummy_interaction_air.rs
1use std::{iter, sync::Arc};
8
9use derivative::Derivative;
10use itertools::izip;
11use openvm_stark_backend::{
12 air_builders::PartitionedAirBuilder,
13 config::{StarkGenericConfig, Val},
14 interaction::{BusIndex, InteractionBuilder},
15 p3_air::{Air, BaseAir},
16 p3_field::{Field, FieldAlgebra},
17 p3_matrix::{dense::RowMajorMatrix, Matrix},
18 prover::{
19 cpu::CpuDevice,
20 hal::TraceCommitter,
21 types::{AirProofInput, AirProofRawInput, CommittedTraceData},
22 },
23 rap::{AnyRap, BaseAirWithPublicValues, PartitionedBaseAir},
24 Chip, ChipUsageGetter,
25};
26
/// Column layout of the single, unpartitioned main trace of
/// [`DummyInteractionAir`].
///
/// Column `0` holds the interaction count; the message fields follow it
/// contiguously.
pub struct DummyInteractionCols;

impl DummyInteractionCols {
    /// Index of the column holding the interaction count.
    pub fn count_col() -> usize {
        const COUNT_COL: usize = 0;
        COUNT_COL
    }

    /// Index of the column holding message field number `field_idx`; fields are
    /// stored immediately after the count column.
    pub fn field_col(field_idx: usize) -> usize {
        Self::count_col() + 1 + field_idx
    }
}
36
37#[derive(Clone, Copy)]
38pub struct DummyInteractionAir {
39 field_width: usize,
40 pub is_send: bool,
42 bus_index: BusIndex,
43 pub count_weight: u32,
44 pub partition: bool,
46}
47
48impl DummyInteractionAir {
49 pub fn new(field_width: usize, is_send: bool, bus_index: BusIndex) -> Self {
50 Self {
51 field_width,
52 is_send,
53 bus_index,
54 count_weight: 0,
55 partition: false,
56 }
57 }
58
59 pub fn partition(self) -> Self {
60 Self {
61 partition: true,
62 ..self
63 }
64 }
65
66 pub fn field_width(&self) -> usize {
67 self.field_width
68 }
69}
70
71impl<F: Field> BaseAirWithPublicValues<F> for DummyInteractionAir {}
72impl<F: Field> PartitionedBaseAir<F> for DummyInteractionAir {
73 fn cached_main_widths(&self) -> Vec<usize> {
74 if self.partition {
75 vec![self.field_width]
76 } else {
77 vec![]
78 }
79 }
80 fn common_main_width(&self) -> usize {
81 if self.partition {
82 1
83 } else {
84 1 + self.field_width
85 }
86 }
87}
88impl<F: Field> BaseAir<F> for DummyInteractionAir {
89 fn width(&self) -> usize {
90 1 + self.field_width
91 }
92
93 fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
94 None
95 }
96}
97
98impl<AB: InteractionBuilder + PartitionedAirBuilder> Air<AB> for DummyInteractionAir {
99 fn eval(&self, builder: &mut AB) {
100 let (fields, count) = if self.partition {
101 let local_0 = builder.common_main().row_slice(0);
102 let local_1 = builder.cached_mains()[0].row_slice(0);
103 let count = local_0[0];
104 let fields = local_1.to_vec();
105 (fields, count)
106 } else {
107 let main = builder.main();
108 let local = main.row_slice(0);
109 let count = local[DummyInteractionCols::count_col()];
110 let fields: Vec<_> = (0..self.field_width)
111 .map(|i| local[DummyInteractionCols::field_col(i)])
112 .collect();
113 (fields, count)
114 };
115 if self.is_send {
116 builder.push_interaction(self.bus_index, fields, count, self.count_weight);
117 } else {
118 builder.push_interaction(
119 self.bus_index,
120 fields,
121 AB::Expr::NEG_ONE * count,
122 self.count_weight,
123 );
124 }
125 }
126}
127
128#[derive(Derivative)]
131#[derivative(Clone(bound = ""))]
132pub struct DummyInteractionChip<'a, SC: StarkGenericConfig> {
133 device: Option<CpuDevice<'a, SC>>,
134 data: Option<DummyInteractionData>,
136 pub air: DummyInteractionAir,
137}
138
/// Raw per-row input for a [`DummyInteractionChip`].
///
/// Row `i` has count `count[i]` and message `fields[i]`. The chip asserts that
/// both vectors have the same length and that every row of `fields` matches the
/// AIR's field width.
#[derive(Debug, Clone)]
pub struct DummyInteractionData {
    /// Interaction count of each row.
    pub count: Vec<u32>,
    /// Message fields of each row.
    pub fields: Vec<Vec<u32>>,
}
144
145impl<'a, SC: StarkGenericConfig> DummyInteractionChip<'a, SC>
146where
147 Val<SC>: FieldAlgebra,
148{
149 pub fn new_without_partition(field_width: usize, is_send: bool, bus_index: BusIndex) -> Self {
150 let air = DummyInteractionAir::new(field_width, is_send, bus_index);
151 Self {
152 device: None,
153 data: None,
154 air,
155 }
156 }
157 pub fn new_with_partition(
158 config: &'a SC,
159 field_width: usize,
160 is_send: bool,
161 bus_index: BusIndex,
162 ) -> Self {
163 let air = DummyInteractionAir::new(field_width, is_send, bus_index).partition();
164 Self {
165 device: Some(CpuDevice::new(config)),
166 data: None,
167 air,
168 }
169 }
170 pub fn load_data(&mut self, data: DummyInteractionData) {
171 let DummyInteractionData { count, fields } = &data;
172 let h = count.len();
173 assert_eq!(fields.len(), h);
174 let w = fields[0].len();
175 assert_eq!(self.air.field_width, w);
176 assert!(fields.iter().all(|r| r.len() == w));
177 self.data = Some(data);
178 }
179
180 #[allow(clippy::type_complexity)]
181 fn generate_traces_with_partition(
182 &self,
183 data: DummyInteractionData,
184 ) -> (RowMajorMatrix<Val<SC>>, CommittedTraceData<SC>) {
185 let DummyInteractionData {
186 mut count,
187 mut fields,
188 } = data;
189 let h = count.len();
190 assert_eq!(fields.len(), h);
191 let w = fields[0].len();
192 assert_eq!(self.air.field_width, w);
193 assert!(fields.iter().all(|r| r.len() == w));
194 let h = h.next_power_of_two();
195 count.resize(h, 0);
196 fields.resize(h, vec![0; w]);
197 let common_main_val: Vec<_> = count
198 .into_iter()
199 .map(Val::<SC>::from_canonical_u32)
200 .collect();
201 let cached_trace_val: Vec<_> = fields
202 .into_iter()
203 .flatten()
204 .map(Val::<SC>::from_canonical_u32)
205 .collect();
206 let cached_trace = Arc::new(RowMajorMatrix::new(cached_trace_val, w));
207 let (commit, data) = self
208 .device
209 .as_ref()
210 .unwrap()
211 .commit(&[cached_trace.clone()]);
212 (
213 RowMajorMatrix::new(common_main_val, 1),
214 CommittedTraceData {
215 trace: cached_trace,
216 commitment: commit,
217 pcs_data: data.data,
218 },
219 )
220 }
221
222 fn generate_traces_without_partition(
223 &self,
224 data: DummyInteractionData,
225 ) -> RowMajorMatrix<Val<SC>> {
226 let DummyInteractionData { count, fields } = data;
227 let h = count.len();
228 assert_eq!(fields.len(), h);
229 let w = fields[0].len();
230 assert_eq!(self.air.field_width, w);
231 assert!(fields.iter().all(|r| r.len() == w));
232 let common_main_val: Vec<_> = izip!(count, fields)
233 .flat_map(|(count, fields)| iter::once(count).chain(fields))
234 .chain(iter::repeat(0))
235 .take((w + 1) * h.next_power_of_two())
236 .map(Val::<SC>::from_canonical_u32)
237 .collect();
238 RowMajorMatrix::new(common_main_val, w + 1)
239 }
240}
241
242impl<SC: StarkGenericConfig> Chip<SC> for DummyInteractionChip<'_, SC> {
243 fn air(&self) -> Arc<dyn AnyRap<SC>> {
244 Arc::new(self.air)
245 }
246
247 fn generate_air_proof_input(self) -> AirProofInput<SC> {
248 assert!(self.data.is_some());
249 let data = self.data.clone().unwrap();
250 if self.device.is_some() {
251 let (common_main, cached) = self.generate_traces_with_partition(data);
252 AirProofInput {
253 cached_mains_pdata: vec![(cached.commitment, cached.pcs_data)],
254 raw: AirProofRawInput {
255 cached_mains: vec![cached.trace],
256 common_main: Some(common_main),
257 public_values: vec![],
258 },
259 }
260 } else {
261 let common_main = self.generate_traces_without_partition(data);
262 AirProofInput {
263 cached_mains_pdata: vec![],
264 raw: AirProofRawInput {
265 cached_mains: vec![],
266 common_main: Some(common_main),
267 public_values: vec![],
268 },
269 }
270 }
271 }
272}
273
274impl<SC: StarkGenericConfig> ChipUsageGetter for DummyInteractionChip<'_, SC> {
275 fn air_name(&self) -> String {
276 "DummyInteractionAir".to_string()
277 }
278 fn current_trace_height(&self) -> usize {
279 if let Some(data) = &self.data {
280 data.count.len()
281 } else {
282 0
283 }
284 }
285
286 fn trace_width(&self) -> usize {
287 self.air.field_width + 1
288 }
289}