openvm_mod_circuit_builder/
core_chip.rs
1use itertools::Itertools;
2use num_bigint::BigUint;
3use num_traits::Zero;
4use openvm_circuit::arch::{
5 AdapterAirContext, AdapterRuntimeContext, DynAdapterInterface, DynArray, MinimalInstruction,
6 Result, VmAdapterInterface, VmCoreAir, VmCoreChip,
7};
8use openvm_circuit_primitives::{
9 var_range::SharedVariableRangeCheckerChip, SubAir, TraceSubRowGenerator,
10};
11use openvm_instructions::instruction::Instruction;
12use openvm_stark_backend::{
13 interaction::InteractionBuilder,
14 p3_air::BaseAir,
15 p3_field::{Field, FieldAlgebra, PrimeField32},
16 p3_matrix::{dense::RowMajorMatrix, Matrix},
17 rap::BaseAirWithPublicValues,
18};
19use openvm_stark_sdk::p3_baby_bear::BabyBear;
20use serde::{Deserialize, Serialize};
21use serde_with::{serde_as, DisplayFromStr};
22
23use crate::{
24 utils::{biguint_to_limbs_vec, limbs_to_biguint},
25 FieldExpr, FieldExprCols,
26};
27
/// Core AIR for a chip whose instructions are described by a single [`FieldExpr`].
///
/// Besides the expression's own constraints, it maps the expression's flags to the
/// global opcode expected by the adapter.
#[derive(Clone)]
pub struct FieldExpressionCoreAir {
    /// The field expression whose constraints this AIR evaluates.
    pub expr: FieldExpr,

    /// Opcode offset of this chip: the global opcode is `offset + local opcode index`.
    pub offset: usize,

    /// Local opcode indices (relative to `offset`) accepted by this AIR.
    pub local_opcode_idx: Vec<usize>,
    /// `opcode_flag_idx[i]` is the index into the expression's flags that selects
    /// `local_opcode_idx[i]`. It has one entry fewer than `local_opcode_idx`: the last
    /// opcode has no dedicated flag and is selected implicitly.
    pub opcode_flag_idx: Vec<usize>,
}
51
52impl FieldExpressionCoreAir {
53 pub fn new(
54 expr: FieldExpr,
55 offset: usize,
56 local_opcode_idx: Vec<usize>,
57 opcode_flag_idx: Vec<usize>,
58 ) -> Self {
59 let opcode_flag_idx = if opcode_flag_idx.is_empty() && expr.needs_setup() {
60 vec![0]
62 } else {
63 opcode_flag_idx
65 };
66 assert_eq!(opcode_flag_idx.len(), local_opcode_idx.len() - 1);
67 Self {
68 expr,
69 offset,
70 local_opcode_idx,
71 opcode_flag_idx,
72 }
73 }
74
75 pub fn num_inputs(&self) -> usize {
76 self.expr.builder.num_input
77 }
78
79 pub fn num_vars(&self) -> usize {
80 self.expr.builder.num_variables
81 }
82
83 pub fn num_flags(&self) -> usize {
84 self.expr.builder.num_flags
85 }
86
87 pub fn output_indices(&self) -> &[usize] {
88 &self.expr.builder.output_indices
89 }
90}
91
92impl<F: Field> BaseAir<F> for FieldExpressionCoreAir {
93 fn width(&self) -> usize {
94 BaseAir::<F>::width(&self.expr)
95 }
96}
97
98impl<F: Field> BaseAirWithPublicValues<F> for FieldExpressionCoreAir {}
99
/// AIR side of the field-expression core.
///
/// Emits the expression's constraints plus the opcode-selection constraint, and returns
/// the reads, writes and expected instruction for the adapter to match.
impl<AB: InteractionBuilder, I> VmCoreAir<AB, I> for FieldExpressionCoreAir
where
    I: VmAdapterInterface<AB::Expr>,
    AdapterAirContext<AB::Expr, I>:
        From<AdapterAirContext<AB::Expr, DynAdapterInterface<AB::Expr>>>,
{
    /// Evaluates the core constraints on `local` (exactly the expression's columns) and
    /// builds the adapter context.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        assert_eq!(local.len(), BaseAir::<AB::F>::width(&self.expr));
        // Constraints of the field expression itself.
        self.expr.eval(builder, local);
        let FieldExprCols {
            is_valid,
            inputs,
            vars,
            flags,
            ..
        } = self.expr.load_vars(local);
        assert_eq!(inputs.len(), self.num_inputs());
        assert_eq!(vars.len(), self.num_vars());
        assert_eq!(flags.len(), self.num_flags());
        // Reads: all input columns, concatenated in input order.
        let reads: Vec<AB::Expr> = inputs.concat().iter().map(|x| (*x).into()).collect();
        // Writes: the columns of each output variable `vars[i]`, concatenated in
        // `output_indices` order.
        let writes: Vec<AB::Expr> = self
            .output_indices()
            .iter()
            .flat_map(|&i| vars[i].clone())
            .map(Into::into)
            .collect();

        // Every opcode except the last has an explicit flag. The last opcode's selector is
        // derived as `is_valid - sum(explicit flags)` and constrained to be boolean.
        // Assuming `is_valid` and the explicit flags are themselves boolean (presumably
        // enforced inside `self.expr.eval` — verify), exactly one selector is active on a
        // valid row and none on an invalid row.
        let opcode_flags_except_last = self.opcode_flag_idx.iter().map(|&i| flags[i]).collect_vec();
        let last_opcode_flag = is_valid
            - opcode_flags_except_last
                .iter()
                .map(|&v| v.into())
                .sum::<AB::Expr>();
        builder.assert_bool(last_opcode_flag.clone());
        // Selectors in the same order as `local_opcode_idx`: explicit flags, then the derived one.
        let opcode_flags = opcode_flags_except_last
            .into_iter()
            .map(Into::into)
            .chain(Some(last_opcode_flag));
        // Expected global opcode = sum over opcodes of selector * (offset + local index).
        let expected_opcode = opcode_flags
            .zip(self.local_opcode_idx.iter().map(|&i| i + self.offset))
            .map(|(flag, global_idx)| flag * AB::Expr::from_canonical_usize(global_idx))
            .sum();

        let instruction = MinimalInstruction {
            is_valid: is_valid.into(),
            opcode: expected_opcode,
        };

        // Built with the dynamic interface, then converted into the adapter's concrete `I`.
        let ctx: AdapterAirContext<_, DynAdapterInterface<_>> = AdapterAirContext {
            to_pc: None,
            reads: reads.into(),
            writes: writes.into(),
            instruction: instruction.into(),
        };
        ctx.into()
    }

    /// Opcode offset of this chip's opcodes.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
166
/// Per-instruction execution record: the data `generate_trace_row` needs to rebuild the
/// core trace row.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)]
pub struct FieldExpressionRecord {
    /// Expression inputs, reconstructed from the adapter reads (one per expression input).
    /// Serialized through their `Display`/`FromStr` string form.
    #[serde_as(as = "Vec<DisplayFromStr>")]
    pub inputs: Vec<BigUint>,
    /// Flags passed to the expression. `execute_instruction` leaves this empty when the
    /// expression does not need setup.
    pub flags: Vec<bool>,
}
174
/// Core chip that executes [`FieldExpr`]-based instructions and fills their core trace rows.
pub struct FieldExpressionCoreChip {
    /// The core AIR: expression, opcode offset, and opcode/flag mapping.
    pub air: FieldExpressionCoreAir,
    /// Range checker passed to trace generation for real (non-dummy) rows.
    pub range_checker: SharedVariableRangeCheckerChip,

    /// Name returned by `get_opcode_name` for every opcode of this chip.
    pub name: String,

    /// When `true`, `finalize` overwrites every row after the recorded ones with a dummy
    /// row generated from all-zero inputs.
    pub should_finalize: bool,
}
184
185impl FieldExpressionCoreChip {
186 pub fn new(
187 expr: FieldExpr,
188 offset: usize,
189 local_opcode_idx: Vec<usize>,
190 opcode_flag_idx: Vec<usize>,
191 range_checker: SharedVariableRangeCheckerChip,
192 name: &str,
193 should_finalize: bool,
194 ) -> Self {
195 let air = FieldExpressionCoreAir::new(expr, offset, local_opcode_idx, opcode_flag_idx);
196 tracing::info!(
197 "FieldExpressionCoreChip: opcode={name}, main_width={}",
198 BaseAir::<BabyBear>::width(&air)
199 );
200 Self {
201 air,
202 range_checker,
203 name: name.to_string(),
204 should_finalize,
205 }
206 }
207
208 pub fn expr(&self) -> &FieldExpr {
209 &self.air.expr
210 }
211}
212
213impl<F: PrimeField32, I> VmCoreChip<F, I> for FieldExpressionCoreChip
214where
215 I: VmAdapterInterface<F>,
216 I::Reads: Into<DynArray<F>>,
217 AdapterRuntimeContext<F, I>: From<AdapterRuntimeContext<F, DynAdapterInterface<F>>>,
218{
219 type Record = FieldExpressionRecord;
220 type Air = FieldExpressionCoreAir;
221
222 fn execute_instruction(
223 &self,
224 instruction: &Instruction<F>,
225 _from_pc: u32,
226 reads: I::Reads,
227 ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
228 let field_element_limbs = self.air.expr.canonical_num_limbs();
229 let limb_bits = self.air.expr.canonical_limb_bits();
230 let data: DynArray<_> = reads.into();
231 let data = data.0;
232 assert_eq!(data.len(), self.air.num_inputs() * field_element_limbs);
233 let data_u32: Vec<u32> = data.iter().map(|x| x.as_canonical_u32()).collect();
234
235 let mut inputs = vec![];
236 for i in 0..self.air.num_inputs() {
237 let start = i * field_element_limbs;
238 let end = start + field_element_limbs;
239 let limb_slice = &data_u32[start..end];
240 let input = limbs_to_biguint(limb_slice, limb_bits);
241 inputs.push(input);
242 }
243
244 let Instruction { opcode, .. } = instruction;
245 let local_opcode_idx = opcode.local_opcode_idx(self.air.offset);
246 let mut flags = vec![];
247
248 if self.expr().needs_setup() {
251 flags = vec![false; self.air.num_flags()];
252 self.air
253 .opcode_flag_idx
254 .iter()
255 .enumerate()
256 .for_each(|(i, &flag_idx)| {
257 flags[flag_idx] = local_opcode_idx == self.air.local_opcode_idx[i]
258 });
259 }
260
261 let vars = self.air.expr.execute(inputs.clone(), flags.clone());
262 assert_eq!(vars.len(), self.air.num_vars());
263
264 let outputs: Vec<BigUint> = self
265 .air
266 .output_indices()
267 .iter()
268 .map(|&i| vars[i].clone())
269 .collect();
270 let writes: Vec<F> = outputs
271 .iter()
272 .map(|x| biguint_to_limbs_vec(x.clone(), limb_bits, field_element_limbs))
273 .concat()
274 .into_iter()
275 .map(|x| F::from_canonical_u32(x))
276 .collect();
277
278 let ctx = AdapterRuntimeContext::<_, DynAdapterInterface<_>>::without_pc(writes);
279 Ok((ctx.into(), FieldExpressionRecord { inputs, flags }))
280 }
281
282 fn get_opcode_name(&self, _opcode: usize) -> String {
283 self.name.clone()
284 }
285
286 fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
287 self.air.expr.generate_subrow(
288 (self.range_checker.as_ref(), record.inputs, record.flags),
289 row_slice,
290 );
291 }
292
293 fn air(&self) -> &Self::Air {
294 &self.air
295 }
296
297 fn finalize(&self, trace: &mut RowMajorMatrix<F>, num_records: usize) {
298 if !self.should_finalize || num_records == 0 {
299 return;
300 }
301
302 let core_width = <Self::Air as BaseAir<F>>::width(&self.air);
303 let adapter_width = trace.width() - core_width;
304 let dummy_row = self.generate_dummy_trace_row(adapter_width, core_width);
305 for row in trace.rows_mut().skip(num_records) {
306 row.copy_from_slice(&dummy_row);
307 }
308 }
309}
310
311impl FieldExpressionCoreChip {
312 fn generate_dummy_trace_row<F: PrimeField32>(
315 &self,
316 adapter_width: usize,
317 core_width: usize,
318 ) -> Vec<F> {
319 let record = FieldExpressionRecord {
320 inputs: vec![BigUint::zero(); self.air.num_inputs()],
321 flags: vec![false; self.air.num_flags()],
322 };
323 let mut row = vec![F::ZERO; adapter_width + core_width];
324 let core_row = &mut row[adapter_width..];
325 let tmp_range_checker = SharedVariableRangeCheckerChip::new(self.range_checker.bus());
328 self.air.expr.generate_subrow(
329 (tmp_range_checker.as_ref(), record.inputs, record.flags),
330 core_row,
331 );
332 core_row[0] = F::ZERO; row
334 }
335}