openvm_stark_sdk/dummy_airs/interaction/
dummy_interaction_air.rs1use std::{iter, sync::Arc};
8
9use derivative::Derivative;
10use itertools::izip;
11use openvm_stark_backend::{
12 air_builders::PartitionedAirBuilder,
13 config::{StarkGenericConfig, Val},
14 interaction::{BusIndex, InteractionBuilder},
15 p3_air::{Air, BaseAir},
16 p3_field::{Field, FieldAlgebra},
17 p3_matrix::{dense::RowMajorMatrix, Matrix},
18 prover::{
19 cpu::{CpuBackend, CpuDevice},
20 hal::TraceCommitter,
21 types::{AirProvingContext, CommittedTraceData},
22 },
23 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
24 AirRef, Chip, ChipUsageGetter,
25};
26
/// Column layout of the unpartitioned `DummyInteractionAir` main trace:
/// the count column comes first, followed by one column per message field.
pub struct DummyInteractionCols;

impl DummyInteractionCols {
    /// Column holding the interaction count.
    const COUNT: usize = 0;
    /// First column holding a message field; fields are stored contiguously after it.
    const FIELDS_START: usize = Self::COUNT + 1;

    /// Index of the count column.
    pub fn count_col() -> usize {
        Self::COUNT
    }

    /// Index of the column holding message field number `field_idx`.
    pub fn field_col(field_idx: usize) -> usize {
        Self::FIELDS_START + field_idx
    }
}
36
37#[derive(Clone, Copy)]
38pub struct DummyInteractionAir {
39 field_width: usize,
40 pub is_send: bool,
42 bus_index: BusIndex,
43 pub count_weight: u32,
44 pub partition: bool,
46}
47
48impl DummyInteractionAir {
49 pub fn new(field_width: usize, is_send: bool, bus_index: BusIndex) -> Self {
50 Self {
51 field_width,
52 is_send,
53 bus_index,
54 count_weight: 0,
55 partition: false,
56 }
57 }
58
59 pub fn partition(self) -> Self {
60 Self {
61 partition: true,
62 ..self
63 }
64 }
65
66 pub fn field_width(&self) -> usize {
67 self.field_width
68 }
69}
70
71impl<F: Field> BaseAirWithPublicValues<F> for DummyInteractionAir {}
72impl<F: Field> PartitionedBaseAir<F> for DummyInteractionAir {
73 fn cached_main_widths(&self) -> Vec<usize> {
74 if self.partition {
75 vec![self.field_width]
76 } else {
77 vec![]
78 }
79 }
80 fn common_main_width(&self) -> usize {
81 if self.partition {
82 1
83 } else {
84 1 + self.field_width
85 }
86 }
87}
88impl<F: Field> BaseAir<F> for DummyInteractionAir {
89 fn width(&self) -> usize {
90 1 + self.field_width
91 }
92
93 fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
94 None
95 }
96}
97
98impl<AB: InteractionBuilder + PartitionedAirBuilder> Air<AB> for DummyInteractionAir {
99 fn eval(&self, builder: &mut AB) {
100 let (fields, count) = if self.partition {
101 let local_0 = builder.common_main().row_slice(0);
102 let local_1 = builder.cached_mains()[0].row_slice(0);
103 let count = local_0[0];
104 let fields = local_1.to_vec();
105 (fields, count)
106 } else {
107 let main = builder.main();
108 let local = main.row_slice(0);
109 let count = local[DummyInteractionCols::count_col()];
110 let fields: Vec<_> = (0..self.field_width)
111 .map(|i| local[DummyInteractionCols::field_col(i)])
112 .collect();
113 (fields, count)
114 };
115 if self.is_send {
116 builder.push_interaction(self.bus_index, fields, count, self.count_weight);
117 } else {
118 builder.push_interaction(
119 self.bus_index,
120 fields,
121 AB::Expr::NEG_ONE * count,
122 self.count_weight,
123 );
124 }
125 }
126}
127
128#[derive(Derivative)]
131#[derivative(Clone(bound = ""))]
132pub struct DummyInteractionChip<SC: StarkGenericConfig> {
133 device: Option<CpuDevice<SC>>,
134 data: Option<DummyInteractionData>,
135 pub air: DummyInteractionAir,
136}
137
/// Unpadded trace contents for a `DummyInteractionChip`, one entry per row.
#[derive(Clone, Debug)]
pub struct DummyInteractionData {
    /// Count (multiplicity) of each row's interaction.
    pub count: Vec<u32>,
    /// Message field values of each row; every inner vector should have `field_width` entries.
    pub fields: Vec<Vec<u32>>,
}
143
144impl<SC: StarkGenericConfig> DummyInteractionChip<SC>
145where
146 Val<SC>: FieldAlgebra,
147{
148 pub fn new_without_partition(field_width: usize, is_send: bool, bus_index: BusIndex) -> Self {
149 let air = DummyInteractionAir::new(field_width, is_send, bus_index);
150 Self {
151 device: None,
152 data: None,
153 air,
154 }
155 }
156 pub fn new_with_partition(
157 device: CpuDevice<SC>,
158 field_width: usize,
159 is_send: bool,
160 bus_index: BusIndex,
161 ) -> Self {
162 let air = DummyInteractionAir::new(field_width, is_send, bus_index).partition();
163 Self {
164 device: Some(device),
165 data: None,
166 air,
167 }
168 }
169 pub fn load_data(&mut self, data: DummyInteractionData) {
170 let DummyInteractionData { count, fields } = &data;
171 let h = count.len();
172 assert_eq!(fields.len(), h);
173 let w = fields[0].len();
174 assert_eq!(self.air.field_width, w);
175 assert!(fields.iter().all(|r| r.len() == w));
176 self.data = Some(data);
177 }
178 pub fn air(&self) -> AirRef<SC> {
179 Arc::new(self.air)
180 }
181
182 fn generate_traces_with_partition(
183 &self,
184 data: DummyInteractionData,
185 ) -> AirProvingContext<CpuBackend<SC>> {
186 let DummyInteractionData {
187 mut count,
188 mut fields,
189 } = data;
190 let h = count.len();
191 assert_eq!(fields.len(), h);
192 let w = fields[0].len();
193 assert_eq!(self.air.field_width, w);
194 assert!(fields.iter().all(|r| r.len() == w));
195 let h = h.next_power_of_two();
196 count.resize(h, 0);
197 fields.resize(h, vec![0; w]);
198 let common_main_val: Vec<_> = count
199 .into_iter()
200 .map(Val::<SC>::from_canonical_u32)
201 .collect();
202 let cached_trace_val: Vec<_> = fields
203 .into_iter()
204 .flatten()
205 .map(Val::<SC>::from_canonical_u32)
206 .collect();
207 let cached_trace = Arc::new(RowMajorMatrix::new(cached_trace_val, w));
208 let (commit, data) = self
209 .device
210 .as_ref()
211 .unwrap()
212 .commit(std::slice::from_ref(&cached_trace));
213
214 AirProvingContext {
215 cached_mains: vec![CommittedTraceData {
216 commitment: commit,
217 data,
218 trace: cached_trace,
219 }],
220 common_main: Some(Arc::new(RowMajorMatrix::new(common_main_val, 1))),
221 public_values: vec![],
222 }
223 }
224
225 fn generate_traces_without_partition(
226 &self,
227 data: DummyInteractionData,
228 ) -> RowMajorMatrix<Val<SC>> {
229 let DummyInteractionData { count, fields } = data;
230 let h = count.len();
231 assert_eq!(fields.len(), h);
232 let w = fields[0].len();
233 assert_eq!(self.air.field_width, w);
234 assert!(fields.iter().all(|r| r.len() == w));
235 let common_main_val: Vec<_> = izip!(count, fields)
236 .flat_map(|(count, fields)| iter::once(count).chain(fields))
237 .chain(iter::repeat(0))
238 .take((w + 1) * h.next_power_of_two())
239 .map(Val::<SC>::from_canonical_u32)
240 .collect();
241 RowMajorMatrix::new(common_main_val, w + 1)
242 }
243}
244
245impl<SC: StarkGenericConfig> Chip<(), CpuBackend<SC>> for DummyInteractionChip<SC> {
246 fn generate_proving_ctx(&self, _: ()) -> AirProvingContext<CpuBackend<SC>> {
247 assert!(self.data.is_some());
248 let data = self.data.clone().unwrap();
249 if self.air.partition {
250 self.generate_traces_with_partition(data)
251 } else {
252 let trace = self.generate_traces_without_partition(data);
253 AirProvingContext::simple_no_pis(Arc::new(trace))
254 }
255 }
256}
257
258impl<SC: StarkGenericConfig> ChipUsageGetter for DummyInteractionChip<SC> {
259 fn air_name(&self) -> String {
260 "DummyInteractionAir".to_string()
261 }
262 fn current_trace_height(&self) -> usize {
263 if let Some(data) = &self.data {
264 data.count.len()
265 } else {
266 0
267 }
268 }
269
270 fn trace_width(&self) -> usize {
271 self.air.field_width + 1
272 }
273}