1use std::{
2 marker::PhantomData,
3 mem::{align_of, size_of},
4 sync::Arc,
5};
6
7use itertools::Itertools;
8use num_bigint::BigUint;
9use num_traits::Zero;
10use openvm_circuit::{
11 arch::*,
12 system::memory::{online::TracingMemory, MemoryAuxColsFactory},
13};
14use openvm_circuit_primitives::{
15 var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerChip},
16 SubAir, TraceSubRowGenerator,
17};
18use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP};
19use openvm_stark_backend::{
20 interaction::InteractionBuilder,
21 p3_air::BaseAir,
22 p3_field::{Field, FieldAlgebra, PrimeField32},
23 rap::BaseAirWithPublicValues,
24};
25use openvm_stark_sdk::p3_baby_bear::BabyBear;
26
27use crate::{
28 builder::{FieldExpr, FieldExprCols},
29 utils::biguint_to_limbs_vec,
30};
31
32#[derive(Clone)]
33pub struct FieldExpressionCoreAir {
34 pub expr: FieldExpr,
35
36 pub offset: usize,
38
39 pub local_opcode_idx: Vec<usize>,
42 pub opcode_flag_idx: Vec<usize>,
45 }
56
57impl FieldExpressionCoreAir {
58 pub fn new(
59 expr: FieldExpr,
60 offset: usize,
61 local_opcode_idx: Vec<usize>,
62 opcode_flag_idx: Vec<usize>,
63 ) -> Self {
64 let opcode_flag_idx = if opcode_flag_idx.is_empty() && expr.needs_setup() {
65 vec![0]
67 } else {
68 opcode_flag_idx
70 };
71 assert_eq!(opcode_flag_idx.len(), local_opcode_idx.len() - 1);
72 Self {
73 expr,
74 offset,
75 local_opcode_idx,
76 opcode_flag_idx,
77 }
78 }
79
80 pub fn num_inputs(&self) -> usize {
81 self.expr.builder.num_input
82 }
83
84 pub fn num_vars(&self) -> usize {
85 self.expr.builder.num_variables
86 }
87
88 pub fn num_flags(&self) -> usize {
89 self.expr.builder.num_flags
90 }
91
92 pub fn output_indices(&self) -> &[usize] {
93 &self.expr.builder.output_indices
94 }
95}
96
97impl<F: Field> BaseAir<F> for FieldExpressionCoreAir {
98 fn width(&self) -> usize {
99 BaseAir::<F>::width(&self.expr)
100 }
101}
102
103impl<F: Field> BaseAirWithPublicValues<F> for FieldExpressionCoreAir {}
104
105impl<AB: InteractionBuilder, I> VmCoreAir<AB, I> for FieldExpressionCoreAir
106where
107 I: VmAdapterInterface<AB::Expr>,
108 AdapterAirContext<AB::Expr, I>:
109 From<AdapterAirContext<AB::Expr, DynAdapterInterface<AB::Expr>>>,
110{
111 fn eval(
112 &self,
113 builder: &mut AB,
114 local: &[AB::Var],
115 _from_pc: AB::Var,
116 ) -> AdapterAirContext<AB::Expr, I> {
117 assert_eq!(local.len(), BaseAir::<AB::F>::width(&self.expr));
118 self.expr.eval(builder, local);
119 let FieldExprCols {
120 is_valid,
121 inputs,
122 vars,
123 flags,
124 ..
125 } = self.expr.load_vars(local);
126 assert_eq!(inputs.len(), self.num_inputs());
127 assert_eq!(vars.len(), self.num_vars());
128 assert_eq!(flags.len(), self.num_flags());
129 let reads: Vec<AB::Expr> = inputs.concat().iter().map(|x| (*x).into()).collect();
130 let writes: Vec<AB::Expr> = self
131 .output_indices()
132 .iter()
133 .flat_map(|&i| vars[i].clone())
134 .map(Into::into)
135 .collect();
136
137 let opcode_flags_except_last = self.opcode_flag_idx.iter().map(|&i| flags[i]).collect_vec();
138 let last_opcode_flag = is_valid
139 - opcode_flags_except_last
140 .iter()
141 .map(|&v| v.into())
142 .sum::<AB::Expr>();
143 builder.assert_bool(last_opcode_flag.clone());
144 let opcode_flags = opcode_flags_except_last
145 .into_iter()
146 .map(Into::into)
147 .chain(Some(last_opcode_flag));
148 let expected_opcode = opcode_flags
149 .zip(self.local_opcode_idx.iter().map(|&i| i + self.offset))
150 .map(|(flag, global_idx)| flag * AB::Expr::from_canonical_usize(global_idx))
151 .sum();
152
153 let instruction = MinimalInstruction {
154 is_valid: is_valid.into(),
155 opcode: expected_opcode,
156 };
157
158 let ctx: AdapterAirContext<_, DynAdapterInterface<_>> = AdapterAirContext {
159 to_pc: None,
160 reads: reads.into(),
161 writes: writes.into(),
162 instruction: instruction.into(),
163 };
164 ctx.into()
165 }
166
167 fn start_offset(&self) -> usize {
168 self.offset
169 }
170}
171
/// Layout metadata for a field-expression record.
///
/// `F` and `A` appear only in the phantom marker; the `AdapterCoreMetadata`
/// impl uses them to compute the adapter width.
pub struct FieldExpressionMetadata<F, A> {
    /// Total number of input limbs (bytes) stored in the core record.
    pub total_input_limbs: usize,
    _phantom: PhantomData<(F, A)>,
}
176
177impl<F, A> Clone for FieldExpressionMetadata<F, A> {
178 fn clone(&self) -> Self {
179 Self {
180 total_input_limbs: self.total_input_limbs,
181 _phantom: PhantomData,
182 }
183 }
184}
185
186impl<F, A> Default for FieldExpressionMetadata<F, A> {
187 fn default() -> Self {
188 Self {
189 total_input_limbs: 0,
190 _phantom: PhantomData,
191 }
192 }
193}
194
195impl<F, A> FieldExpressionMetadata<F, A> {
196 pub fn new(total_input_limbs: usize) -> Self {
197 Self {
198 total_input_limbs,
199 _phantom: PhantomData,
200 }
201 }
202}
203
204impl<F, A> AdapterCoreMetadata for FieldExpressionMetadata<F, A>
205where
206 A: AdapterTraceExecutor<F>,
207{
208 #[inline(always)]
209 fn get_adapter_width() -> usize {
210 A::WIDTH * size_of::<F>()
211 }
212}
213
214pub type FieldExpressionRecordLayout<F, A> = AdapterCoreLayout<FieldExpressionMetadata<F, A>>;
215
/// Mutable view over a core record stored in a byte buffer: one opcode byte
/// followed by the input limbs.
pub struct FieldExpressionCoreRecordMut<'a> {
    /// The first byte of the record; holds the local opcode index.
    pub opcode: &'a mut u8,
    /// The `total_input_limbs` bytes that follow the opcode byte.
    pub input_limbs: &'a mut [u8],
}
220
221impl<'a, F, A> CustomBorrow<'a, FieldExpressionCoreRecordMut<'a>, FieldExpressionRecordLayout<F, A>>
222 for [u8]
223{
224 fn custom_borrow(
225 &'a mut self,
226 layout: FieldExpressionRecordLayout<F, A>,
227 ) -> FieldExpressionCoreRecordMut<'a> {
228 let (opcode_buf, input_limbs_buff) = unsafe { self.split_at_mut_unchecked(1) };
230
231 let opcode_buf = unsafe { opcode_buf.get_unchecked_mut(0) };
233
234 FieldExpressionCoreRecordMut {
235 opcode: opcode_buf,
236 input_limbs: &mut input_limbs_buff[..layout.metadata.total_input_limbs],
237 }
238 }
239
240 unsafe fn extract_layout(&self) -> FieldExpressionRecordLayout<F, A> {
241 panic!("Should get the Layout information from FieldExpressionExecutor");
242 }
243}
244
245impl<F, A> SizedRecord<FieldExpressionRecordLayout<F, A>> for FieldExpressionCoreRecordMut<'_> {
246 fn size(layout: &FieldExpressionRecordLayout<F, A>) -> usize {
247 layout.metadata.total_input_limbs + 1
248 }
249
250 fn alignment(_layout: &FieldExpressionRecordLayout<F, A>) -> usize {
251 align_of::<u8>()
252 }
253}
254
255impl<'a> FieldExpressionCoreRecordMut<'a> {
256 pub fn new_from_execution_data(
258 buffer: &'a mut [u8],
259 inputs: &[BigUint],
260 limbs_per_input: usize,
261 ) -> Self {
262 let record_info = FieldExpressionMetadata::<(), ()>::new(inputs.len() * limbs_per_input);
263
264 let record: Self = buffer.custom_borrow(FieldExpressionRecordLayout {
265 metadata: record_info,
266 });
267 record
268 }
269
270 #[inline(always)]
271 pub fn fill_from_execution_data(&mut self, opcode: u8, data: &[u8]) {
272 *self.opcode = opcode;
275 self.input_limbs.copy_from_slice(data);
276 }
277}
278
279#[derive(Clone)]
280pub struct FieldExpressionExecutor<A> {
281 adapter: A,
282 pub expr: FieldExpr,
283 pub offset: usize,
284 pub local_opcode_idx: Vec<usize>,
285 pub opcode_flag_idx: Vec<usize>,
286 pub name: String,
287}
288
289impl<A> FieldExpressionExecutor<A> {
290 #[allow(clippy::too_many_arguments)]
291 pub fn new(
292 adapter: A,
293 expr: FieldExpr,
294 offset: usize,
295 local_opcode_idx: Vec<usize>,
296 opcode_flag_idx: Vec<usize>,
297 name: &str,
298 ) -> Self {
299 let opcode_flag_idx = if opcode_flag_idx.is_empty() && expr.needs_setup() {
300 vec![0]
302 } else {
303 opcode_flag_idx
305 };
306 assert_eq!(opcode_flag_idx.len(), local_opcode_idx.len() - 1);
307 tracing::debug!(
308 "FieldExpressionCoreExecutor: opcode={name}, main_width={}",
309 BaseAir::<BabyBear>::width(&expr)
310 );
311 Self {
312 adapter,
313 expr,
314 offset,
315 local_opcode_idx,
316 opcode_flag_idx,
317 name: name.to_string(),
318 }
319 }
320
321 pub fn get_record_layout<F>(&self) -> FieldExpressionRecordLayout<F, A> {
322 FieldExpressionRecordLayout {
323 metadata: FieldExpressionMetadata::new(
324 self.expr.builder.num_input * self.expr.canonical_num_limbs(),
325 ),
326 }
327 }
328}
329
330pub struct FieldExpressionFiller<A> {
331 adapter: A,
332 pub expr: FieldExpr,
333 pub local_opcode_idx: Vec<usize>,
334 pub opcode_flag_idx: Vec<usize>,
335 pub range_checker: SharedVariableRangeCheckerChip,
336 pub should_finalize: bool,
337}
338
339impl<A> FieldExpressionFiller<A> {
340 #[allow(clippy::too_many_arguments)]
341 pub fn new(
342 adapter: A,
343 expr: FieldExpr,
344 local_opcode_idx: Vec<usize>,
345 opcode_flag_idx: Vec<usize>,
346 range_checker: SharedVariableRangeCheckerChip,
347 should_finalize: bool,
348 ) -> Self {
349 let opcode_flag_idx = if opcode_flag_idx.is_empty() && expr.needs_setup() {
350 vec![0]
352 } else {
353 opcode_flag_idx
355 };
356 assert_eq!(opcode_flag_idx.len(), local_opcode_idx.len() - 1);
357 Self {
358 adapter,
359 expr,
360 local_opcode_idx,
361 opcode_flag_idx,
362 range_checker,
363 should_finalize,
364 }
365 }
366 pub fn num_inputs(&self) -> usize {
367 self.expr.builder.num_input
368 }
369
370 pub fn num_flags(&self) -> usize {
371 self.expr.builder.num_flags
372 }
373
374 pub fn get_record_layout<F>(&self) -> FieldExpressionRecordLayout<F, A> {
375 FieldExpressionRecordLayout {
376 metadata: FieldExpressionMetadata::new(
377 self.num_inputs() * self.expr.canonical_num_limbs(),
378 ),
379 }
380 }
381}
382
383impl<F, A, RA> PreflightExecutor<F, RA> for FieldExpressionExecutor<A>
384where
385 F: PrimeField32,
386 A: 'static
387 + AdapterTraceExecutor<F, ReadData: Into<DynArray<u8>>, WriteData: From<DynArray<u8>>>,
388 for<'buf> RA: RecordArena<
389 'buf,
390 FieldExpressionRecordLayout<F, A>,
391 (A::RecordMut<'buf>, FieldExpressionCoreRecordMut<'buf>),
392 >,
393{
394 fn execute(
395 &self,
396 state: VmStateMut<F, TracingMemory, RA>,
397 instruction: &Instruction<F>,
398 ) -> Result<(), ExecutionError> {
399 let (mut adapter_record, mut core_record) = state.ctx.alloc(self.get_record_layout());
400
401 A::start(*state.pc, state.memory, &mut adapter_record);
402
403 let data: DynArray<_> = self
404 .adapter
405 .read(state.memory, instruction, &mut adapter_record)
406 .into();
407
408 core_record.fill_from_execution_data(
409 instruction.opcode.local_opcode_idx(self.offset) as u8,
410 &data.0,
411 );
412
413 let (writes, _, _) = run_field_expression(
414 &self.expr,
415 &self.local_opcode_idx,
416 &self.opcode_flag_idx,
417 core_record.input_limbs,
418 *core_record.opcode as usize,
419 );
420
421 self.adapter.write(
422 state.memory,
423 instruction,
424 writes.into(),
425 &mut adapter_record,
426 );
427
428 *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
429 Ok(())
430 }
431
432 fn get_opcode_name(&self, _opcode: usize) -> String {
433 self.name.clone()
434 }
435}
436
/// Uses the trait's default methods; only compiled with the `aot` feature.
#[cfg(feature = "aot")]
impl<F, A> AotExecutor<F> for FieldExpressionExecutor<A> where F: PrimeField32 {}
439
440impl<F, A> TraceFiller<F> for FieldExpressionFiller<A>
441where
442 F: PrimeField32 + Send + Sync + Clone,
443 A: 'static + AdapterTraceFiller<F>,
444{
445 fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
446 let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
449
450 self.adapter.fill_trace_row(mem_helper, adapter_row);
451
452 let record: FieldExpressionCoreRecordMut =
458 unsafe { get_record_from_slice(&mut core_row, self.get_record_layout::<F>()) };
459
460 let (_, inputs, flags) = run_field_expression(
461 &self.expr,
462 &self.local_opcode_idx,
463 &self.opcode_flag_idx,
464 record.input_limbs,
465 *record.opcode as usize,
466 );
467
468 let range_checker = self.range_checker.as_ref();
469 self.expr
470 .generate_subrow((range_checker, inputs, flags), core_row);
471 }
472
473 fn fill_dummy_trace_row(&self, row_slice: &mut [F]) {
474 if !self.should_finalize {
475 return;
476 }
477
478 let inputs: Vec<BigUint> = vec![BigUint::zero(); self.num_inputs()];
479 let flags: Vec<bool> = vec![false; self.num_flags()];
480 let core_row = &mut row_slice[A::WIDTH..];
481 let tmp_range_checker = Arc::new(VariableRangeCheckerChip::new(self.range_checker.bus()));
484 self.expr
485 .generate_subrow((&tmp_range_checker, inputs, flags), core_row);
486 core_row[0] = F::ZERO; }
488}
489
490fn run_field_expression(
491 expr: &FieldExpr,
492 local_opcode_flags: &[usize],
493 opcode_flag_idx: &[usize],
494 data: &[u8],
495 local_opcode_idx: usize,
496) -> (DynArray<u8>, Vec<BigUint>, Vec<bool>) {
497 let field_element_limbs = expr.canonical_num_limbs();
498 assert_eq!(data.len(), expr.builder.num_input * field_element_limbs);
499
500 let mut inputs = Vec::with_capacity(expr.builder.num_input);
501 for i in 0..expr.builder.num_input {
502 let start = i * field_element_limbs;
503 let end = start + field_element_limbs;
504 let limb_slice = &data[start..end];
505 let input = BigUint::from_bytes_le(limb_slice);
506 inputs.push(input);
507 }
508
509 let mut flags = vec![];
510 if expr.needs_setup() {
511 flags = vec![false; expr.builder.num_flags];
512
513 if let Some(opcode_position) = local_opcode_flags
515 .iter()
516 .position(|&idx| idx == local_opcode_idx)
517 {
518 if opcode_position < opcode_flag_idx.len() {
520 let flag_idx = opcode_flag_idx[opcode_position];
521 flags[flag_idx] = true;
522 }
523 }
526 }
527
528 let vars = expr.execute(inputs.clone(), flags.clone());
529 assert_eq!(vars.len(), expr.builder.num_variables);
530
531 let outputs: Vec<BigUint> = expr
532 .builder
533 .output_indices
534 .iter()
535 .map(|&i| vars[i].clone())
536 .collect();
537 let writes: DynArray<_> = outputs
538 .iter()
539 .map(|x| biguint_to_limbs_vec(x, field_element_limbs))
540 .concat()
541 .into_iter()
542 .collect::<Vec<_>>()
543 .into();
544
545 (writes, inputs, flags)
546}
547
548#[inline(always)]
549pub fn run_field_expression_precomputed<const NEEDS_SETUP: bool>(
550 expr: &FieldExpr,
551 flag_idx: usize,
552 data: &[u8],
553) -> DynArray<u8> {
554 let field_element_limbs = expr.canonical_num_limbs();
555 assert_eq!(data.len(), expr.num_inputs() * field_element_limbs);
556
557 let mut inputs = Vec::with_capacity(expr.num_inputs());
558 for i in 0..expr.num_inputs() {
559 let start = i * expr.canonical_num_limbs();
560 let end = start + expr.canonical_num_limbs();
561 let limb_slice = &data[start..end];
562 let input = BigUint::from_bytes_le(limb_slice);
563 inputs.push(input);
564 }
565
566 let flags = if NEEDS_SETUP {
567 let mut flags = vec![false; expr.num_flags()];
568 if flag_idx < expr.num_flags() {
569 flags[flag_idx] = true;
570 }
571 flags
572 } else {
573 vec![]
574 };
575
576 let vars = expr.execute(inputs, flags);
577 assert_eq!(vars.len(), expr.num_vars());
578
579 let outputs: Vec<BigUint> = expr
580 .output_indices()
581 .iter()
582 .map(|&i| vars[i].clone())
583 .collect();
584
585 outputs
586 .iter()
587 .map(|x| biguint_to_limbs_vec(x, field_element_limbs))
588 .concat()
589 .into_iter()
590 .collect::<Vec<_>>()
591 .into()
592}