1use std::{
2 marker::PhantomData,
3 mem::{align_of, size_of},
4 sync::Arc,
5};
6
7use itertools::Itertools;
8use num_bigint::BigUint;
9use num_traits::Zero;
10use openvm_circuit::{
11 arch::*,
12 system::memory::{online::TracingMemory, MemoryAuxColsFactory},
13};
14use openvm_circuit_primitives::{
15 var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerChip},
16 SubAir, TraceSubRowGenerator,
17};
18use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP};
19use openvm_stark_backend::{
20 interaction::InteractionBuilder,
21 p3_air::BaseAir,
22 p3_field::{Field, FieldAlgebra, PrimeField32},
23 rap::BaseAirWithPublicValues,
24};
25use openvm_stark_sdk::p3_baby_bear::BabyBear;
26
27use crate::{
28 builder::{FieldExpr, FieldExprCols},
29 utils::biguint_to_limbs_vec,
30};
31
32#[derive(Clone)]
33pub struct FieldExpressionCoreAir {
34 pub expr: FieldExpr,
35
36 pub offset: usize,
38
39 pub local_opcode_idx: Vec<usize>,
42 pub opcode_flag_idx: Vec<usize>,
45 }
56
57impl FieldExpressionCoreAir {
58 pub fn new(
59 expr: FieldExpr,
60 offset: usize,
61 local_opcode_idx: Vec<usize>,
62 opcode_flag_idx: Vec<usize>,
63 ) -> Self {
64 let opcode_flag_idx = if opcode_flag_idx.is_empty() && expr.needs_setup() {
65 vec![0]
67 } else {
68 opcode_flag_idx
70 };
71 assert_eq!(opcode_flag_idx.len(), local_opcode_idx.len() - 1);
72 Self {
73 expr,
74 offset,
75 local_opcode_idx,
76 opcode_flag_idx,
77 }
78 }
79
80 pub fn num_inputs(&self) -> usize {
81 self.expr.builder.num_input
82 }
83
84 pub fn num_vars(&self) -> usize {
85 self.expr.builder.num_variables
86 }
87
88 pub fn num_flags(&self) -> usize {
89 self.expr.builder.num_flags
90 }
91
92 pub fn output_indices(&self) -> &[usize] {
93 &self.expr.builder.output_indices
94 }
95}
96
97impl<F: Field> BaseAir<F> for FieldExpressionCoreAir {
98 fn width(&self) -> usize {
99 BaseAir::<F>::width(&self.expr)
100 }
101}
102
103impl<F: Field> BaseAirWithPublicValues<F> for FieldExpressionCoreAir {}
104
105impl<AB: InteractionBuilder, I> VmCoreAir<AB, I> for FieldExpressionCoreAir
106where
107 I: VmAdapterInterface<AB::Expr>,
108 AdapterAirContext<AB::Expr, I>:
109 From<AdapterAirContext<AB::Expr, DynAdapterInterface<AB::Expr>>>,
110{
111 fn eval(
112 &self,
113 builder: &mut AB,
114 local: &[AB::Var],
115 _from_pc: AB::Var,
116 ) -> AdapterAirContext<AB::Expr, I> {
117 assert_eq!(local.len(), BaseAir::<AB::F>::width(&self.expr));
118 self.expr.eval(builder, local);
119 let FieldExprCols {
120 is_valid,
121 inputs,
122 vars,
123 flags,
124 ..
125 } = self.expr.load_vars(local);
126 assert_eq!(inputs.len(), self.num_inputs());
127 assert_eq!(vars.len(), self.num_vars());
128 assert_eq!(flags.len(), self.num_flags());
129 let reads: Vec<AB::Expr> = inputs.concat().iter().map(|x| (*x).into()).collect();
130 let writes: Vec<AB::Expr> = self
131 .output_indices()
132 .iter()
133 .flat_map(|&i| vars[i].clone())
134 .map(Into::into)
135 .collect();
136
137 let opcode_flags_except_last = self.opcode_flag_idx.iter().map(|&i| flags[i]).collect_vec();
138 let last_opcode_flag = is_valid
139 - opcode_flags_except_last
140 .iter()
141 .map(|&v| v.into())
142 .sum::<AB::Expr>();
143 builder.assert_bool(last_opcode_flag.clone());
144 let opcode_flags = opcode_flags_except_last
145 .into_iter()
146 .map(Into::into)
147 .chain(Some(last_opcode_flag));
148 let expected_opcode = opcode_flags
149 .zip(self.local_opcode_idx.iter().map(|&i| i + self.offset))
150 .map(|(flag, global_idx)| flag * AB::Expr::from_canonical_usize(global_idx))
151 .sum();
152
153 let instruction = MinimalInstruction {
154 is_valid: is_valid.into(),
155 opcode: expected_opcode,
156 };
157
158 let ctx: AdapterAirContext<_, DynAdapterInterface<_>> = AdapterAirContext {
159 to_pc: None,
160 reads: reads.into(),
161 writes: writes.into(),
162 instruction: instruction.into(),
163 };
164 ctx.into()
165 }
166
167 fn start_offset(&self) -> usize {
168 self.offset
169 }
170}
171
/// Layout metadata for a field-expression record.
///
/// `F` and `A` are carried only as type-level markers; no value of either type
/// is stored.
pub struct FieldExpressionMetadata<F, A> {
    /// Total number of input limb bytes held by one core record.
    pub total_input_limbs: usize,
    _phantom: PhantomData<(F, A)>,
}
176
177impl<F, A> Clone for FieldExpressionMetadata<F, A> {
178 fn clone(&self) -> Self {
179 Self {
180 total_input_limbs: self.total_input_limbs,
181 _phantom: PhantomData,
182 }
183 }
184}
185
186impl<F, A> Default for FieldExpressionMetadata<F, A> {
187 fn default() -> Self {
188 Self {
189 total_input_limbs: 0,
190 _phantom: PhantomData,
191 }
192 }
193}
194
195impl<F, A> FieldExpressionMetadata<F, A> {
196 pub fn new(total_input_limbs: usize) -> Self {
197 Self {
198 total_input_limbs,
199 _phantom: PhantomData,
200 }
201 }
202}
203
204impl<F, A> AdapterCoreMetadata for FieldExpressionMetadata<F, A>
205where
206 A: AdapterTraceExecutor<F>,
207{
208 #[inline(always)]
209 fn get_adapter_width() -> usize {
210 A::WIDTH * size_of::<F>()
211 }
212}
213
214pub type FieldExpressionRecordLayout<F, A> = AdapterCoreLayout<FieldExpressionMetadata<F, A>>;
215
/// Mutable view over the core part of an execution record.
///
/// Both fields borrow from a caller-owned byte buffer.
pub struct FieldExpressionCoreRecordMut<'a> {
    /// The opcode byte of the record.
    pub opcode: &'a mut u8,
    /// The raw input limb bytes of the record.
    pub input_limbs: &'a mut [u8],
}
220
221impl<'a, F, A> CustomBorrow<'a, FieldExpressionCoreRecordMut<'a>, FieldExpressionRecordLayout<F, A>>
222 for [u8]
223{
224 fn custom_borrow(
225 &'a mut self,
226 layout: FieldExpressionRecordLayout<F, A>,
227 ) -> FieldExpressionCoreRecordMut<'a> {
228 let (opcode_buf, input_limbs_buff) = unsafe { self.split_at_mut_unchecked(1) };
230
231 let opcode_buf = unsafe { opcode_buf.get_unchecked_mut(0) };
233
234 FieldExpressionCoreRecordMut {
235 opcode: opcode_buf,
236 input_limbs: &mut input_limbs_buff[..layout.metadata.total_input_limbs],
237 }
238 }
239
240 unsafe fn extract_layout(&self) -> FieldExpressionRecordLayout<F, A> {
241 panic!("Should get the Layout information from FieldExpressionExecutor");
242 }
243}
244
245impl<F, A> SizedRecord<FieldExpressionRecordLayout<F, A>> for FieldExpressionCoreRecordMut<'_> {
246 fn size(layout: &FieldExpressionRecordLayout<F, A>) -> usize {
247 layout.metadata.total_input_limbs + 1
248 }
249
250 fn alignment(_layout: &FieldExpressionRecordLayout<F, A>) -> usize {
251 align_of::<u8>()
252 }
253}
254
255impl<'a> FieldExpressionCoreRecordMut<'a> {
256 pub fn new_from_execution_data(
258 buffer: &'a mut [u8],
259 inputs: &[BigUint],
260 limbs_per_input: usize,
261 ) -> Self {
262 let record_info = FieldExpressionMetadata::<(), ()>::new(inputs.len() * limbs_per_input);
263
264 let record: Self = buffer.custom_borrow(FieldExpressionRecordLayout {
265 metadata: record_info,
266 });
267 record
268 }
269
270 #[inline(always)]
271 pub fn fill_from_execution_data(&mut self, opcode: u8, data: &[u8]) {
272 *self.opcode = opcode;
275 self.input_limbs.copy_from_slice(data);
276 }
277}
278
279#[derive(Clone)]
280pub struct FieldExpressionExecutor<A> {
281 adapter: A,
282 pub expr: FieldExpr,
283 pub offset: usize,
284 pub local_opcode_idx: Vec<usize>,
285 pub opcode_flag_idx: Vec<usize>,
286 pub name: String,
287}
288
289impl<A> FieldExpressionExecutor<A> {
290 #[allow(clippy::too_many_arguments)]
291 pub fn new(
292 adapter: A,
293 expr: FieldExpr,
294 offset: usize,
295 local_opcode_idx: Vec<usize>,
296 opcode_flag_idx: Vec<usize>,
297 name: &str,
298 ) -> Self {
299 let opcode_flag_idx = if opcode_flag_idx.is_empty() && expr.needs_setup() {
300 vec![0]
302 } else {
303 opcode_flag_idx
305 };
306 assert_eq!(opcode_flag_idx.len(), local_opcode_idx.len() - 1);
307 tracing::debug!(
308 "FieldExpressionCoreExecutor: opcode={name}, main_width={}",
309 BaseAir::<BabyBear>::width(&expr)
310 );
311 Self {
312 adapter,
313 expr,
314 offset,
315 local_opcode_idx,
316 opcode_flag_idx,
317 name: name.to_string(),
318 }
319 }
320
321 pub fn get_record_layout<F>(&self) -> FieldExpressionRecordLayout<F, A> {
322 FieldExpressionRecordLayout {
323 metadata: FieldExpressionMetadata::new(
324 self.expr.builder.num_input * self.expr.canonical_num_limbs(),
325 ),
326 }
327 }
328}
329
330pub struct FieldExpressionFiller<A> {
331 adapter: A,
332 pub expr: FieldExpr,
333 pub local_opcode_idx: Vec<usize>,
334 pub opcode_flag_idx: Vec<usize>,
335 pub range_checker: SharedVariableRangeCheckerChip,
336 pub should_finalize: bool,
337}
338
339impl<A> FieldExpressionFiller<A> {
340 #[allow(clippy::too_many_arguments)]
341 pub fn new(
342 adapter: A,
343 expr: FieldExpr,
344 local_opcode_idx: Vec<usize>,
345 opcode_flag_idx: Vec<usize>,
346 range_checker: SharedVariableRangeCheckerChip,
347 should_finalize: bool,
348 ) -> Self {
349 let opcode_flag_idx = if opcode_flag_idx.is_empty() && expr.needs_setup() {
350 vec![0]
352 } else {
353 opcode_flag_idx
355 };
356 assert_eq!(opcode_flag_idx.len(), local_opcode_idx.len() - 1);
357 Self {
358 adapter,
359 expr,
360 local_opcode_idx,
361 opcode_flag_idx,
362 range_checker,
363 should_finalize,
364 }
365 }
366 pub fn num_inputs(&self) -> usize {
367 self.expr.builder.num_input
368 }
369
370 pub fn num_flags(&self) -> usize {
371 self.expr.builder.num_flags
372 }
373
374 pub fn get_record_layout<F>(&self) -> FieldExpressionRecordLayout<F, A> {
375 FieldExpressionRecordLayout {
376 metadata: FieldExpressionMetadata::new(
377 self.num_inputs() * self.expr.canonical_num_limbs(),
378 ),
379 }
380 }
381}
382
383impl<F, A, RA> PreflightExecutor<F, RA> for FieldExpressionExecutor<A>
384where
385 F: PrimeField32,
386 A: 'static
387 + AdapterTraceExecutor<F, ReadData: Into<DynArray<u8>>, WriteData: From<DynArray<u8>>>,
388 for<'buf> RA: RecordArena<
389 'buf,
390 FieldExpressionRecordLayout<F, A>,
391 (A::RecordMut<'buf>, FieldExpressionCoreRecordMut<'buf>),
392 >,
393{
394 fn execute(
395 &self,
396 state: VmStateMut<F, TracingMemory, RA>,
397 instruction: &Instruction<F>,
398 ) -> Result<(), ExecutionError> {
399 let (mut adapter_record, mut core_record) = state.ctx.alloc(self.get_record_layout());
400
401 A::start(*state.pc, state.memory, &mut adapter_record);
402
403 let data: DynArray<_> = self
404 .adapter
405 .read(state.memory, instruction, &mut adapter_record)
406 .into();
407
408 core_record.fill_from_execution_data(
409 instruction.opcode.local_opcode_idx(self.offset) as u8,
410 &data.0,
411 );
412
413 let (writes, _, _) = run_field_expression(
414 &self.expr,
415 &self.local_opcode_idx,
416 &self.opcode_flag_idx,
417 core_record.input_limbs,
418 *core_record.opcode as usize,
419 );
420
421 self.adapter.write(
422 state.memory,
423 instruction,
424 writes.into(),
425 &mut adapter_record,
426 );
427
428 *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
429 Ok(())
430 }
431
432 fn get_opcode_name(&self, _opcode: usize) -> String {
433 self.name.clone()
434 }
435}
436
437impl<F, A> TraceFiller<F> for FieldExpressionFiller<A>
438where
439 F: PrimeField32 + Send + Sync + Clone,
440 A: 'static + AdapterTraceFiller<F>,
441{
442 fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
443 let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
446
447 self.adapter.fill_trace_row(mem_helper, adapter_row);
448
449 let record: FieldExpressionCoreRecordMut =
455 unsafe { get_record_from_slice(&mut core_row, self.get_record_layout::<F>()) };
456
457 let (_, inputs, flags) = run_field_expression(
458 &self.expr,
459 &self.local_opcode_idx,
460 &self.opcode_flag_idx,
461 record.input_limbs,
462 *record.opcode as usize,
463 );
464
465 let range_checker = self.range_checker.as_ref();
466 self.expr
467 .generate_subrow((range_checker, inputs, flags), core_row);
468 }
469
470 fn fill_dummy_trace_row(&self, row_slice: &mut [F]) {
471 if !self.should_finalize {
472 return;
473 }
474
475 let inputs: Vec<BigUint> = vec![BigUint::zero(); self.num_inputs()];
476 let flags: Vec<bool> = vec![false; self.num_flags()];
477 let core_row = &mut row_slice[A::WIDTH..];
478 let tmp_range_checker = Arc::new(VariableRangeCheckerChip::new(self.range_checker.bus()));
481 self.expr
482 .generate_subrow((&tmp_range_checker, inputs, flags), core_row);
483 core_row[0] = F::ZERO; }
485}
486
487fn run_field_expression(
488 expr: &FieldExpr,
489 local_opcode_flags: &[usize],
490 opcode_flag_idx: &[usize],
491 data: &[u8],
492 local_opcode_idx: usize,
493) -> (DynArray<u8>, Vec<BigUint>, Vec<bool>) {
494 let field_element_limbs = expr.canonical_num_limbs();
495 assert_eq!(data.len(), expr.builder.num_input * field_element_limbs);
496
497 let mut inputs = Vec::with_capacity(expr.builder.num_input);
498 for i in 0..expr.builder.num_input {
499 let start = i * field_element_limbs;
500 let end = start + field_element_limbs;
501 let limb_slice = &data[start..end];
502 let input = BigUint::from_bytes_le(limb_slice);
503 inputs.push(input);
504 }
505
506 let mut flags = vec![];
507 if expr.needs_setup() {
508 flags = vec![false; expr.builder.num_flags];
509
510 if let Some(opcode_position) = local_opcode_flags
512 .iter()
513 .position(|&idx| idx == local_opcode_idx)
514 {
515 if opcode_position < opcode_flag_idx.len() {
517 let flag_idx = opcode_flag_idx[opcode_position];
518 flags[flag_idx] = true;
519 }
520 }
523 }
524
525 let vars = expr.execute(inputs.clone(), flags.clone());
526 assert_eq!(vars.len(), expr.builder.num_variables);
527
528 let outputs: Vec<BigUint> = expr
529 .builder
530 .output_indices
531 .iter()
532 .map(|&i| vars[i].clone())
533 .collect();
534 let writes: DynArray<_> = outputs
535 .iter()
536 .map(|x| biguint_to_limbs_vec(x, field_element_limbs))
537 .concat()
538 .into_iter()
539 .collect::<Vec<_>>()
540 .into();
541
542 (writes, inputs, flags)
543}
544
545#[inline(always)]
546pub fn run_field_expression_precomputed<const NEEDS_SETUP: bool>(
547 expr: &FieldExpr,
548 flag_idx: usize,
549 data: &[u8],
550) -> DynArray<u8> {
551 let field_element_limbs = expr.canonical_num_limbs();
552 assert_eq!(data.len(), expr.num_inputs() * field_element_limbs);
553
554 let mut inputs = Vec::with_capacity(expr.num_inputs());
555 for i in 0..expr.num_inputs() {
556 let start = i * expr.canonical_num_limbs();
557 let end = start + expr.canonical_num_limbs();
558 let limb_slice = &data[start..end];
559 let input = BigUint::from_bytes_le(limb_slice);
560 inputs.push(input);
561 }
562
563 let flags = if NEEDS_SETUP {
564 let mut flags = vec![false; expr.num_flags()];
565 if flag_idx < expr.num_flags() {
566 flags[flag_idx] = true;
567 }
568 flags
569 } else {
570 vec![]
571 };
572
573 let vars = expr.execute(inputs, flags);
574 assert_eq!(vars.len(), expr.num_vars());
575
576 let outputs: Vec<BigUint> = expr
577 .output_indices()
578 .iter()
579 .map(|&i| vars[i].clone())
580 .collect();
581
582 outputs
583 .iter()
584 .map(|x| biguint_to_limbs_vec(x, field_element_limbs))
585 .concat()
586 .into_iter()
587 .collect::<Vec<_>>()
588 .into()
589}