openvm_mod_circuit_builder/
core_chip.rs

1use std::{
2    marker::PhantomData,
3    mem::{align_of, size_of},
4    sync::Arc,
5};
6
7use itertools::Itertools;
8use num_bigint::BigUint;
9use num_traits::Zero;
10use openvm_circuit::{
11    arch::*,
12    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
13};
14use openvm_circuit_primitives::{
15    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerChip},
16    SubAir, TraceSubRowGenerator,
17};
18use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP};
19use openvm_stark_backend::{
20    interaction::InteractionBuilder,
21    p3_air::BaseAir,
22    p3_field::{Field, FieldAlgebra, PrimeField32},
23    rap::BaseAirWithPublicValues,
24};
25use openvm_stark_sdk::p3_baby_bear::BabyBear;
26
27use crate::{
28    builder::{FieldExpr, FieldExprCols},
29    utils::biguint_to_limbs_vec,
30};
31
32#[derive(Clone)]
33pub struct FieldExpressionCoreAir {
34    pub expr: FieldExpr,
35
36    /// The global opcode offset.
37    pub offset: usize,
38
39    /// All the opcode indices (including setup) supported by this Air.
40    /// The last one must be the setup opcode if it's a chip needs setup.
41    pub local_opcode_idx: Vec<usize>,
42    /// Opcode flag idx (indices from builder.new_flag()) for all except setup opcode. Empty if
43    /// single op chip.
44    pub opcode_flag_idx: Vec<usize>,
45    // Example 1: 1-op chip EcAdd that needs setup
46    //   local_opcode_idx = [0, 2], where 0 is EcAdd, 2 is setup
47    //   opcode_flag_idx = [], not needed for single op chip.
48    // Example 2: 1-op chip EvaluateLine that doesn't need setup
49    //   local_opcode_idx = [2], the id within PairingOpcodeEnum
50    //   opcode_flag_idx = [], not needed
51    // Example 3: 2-op chip MulDiv that needs setup
52    //   local_opcode_idx = [2, 3, 4], where 2 is Mul, 3 is Div, 4 is setup
53    //   opcode_flag_idx = [0, 1], where 0 is mul_flag, 1 is div_flag, in the builder
54    // We don't support 2-op chip that doesn't need setup right now.
55}
56
57impl FieldExpressionCoreAir {
58    pub fn new(
59        expr: FieldExpr,
60        offset: usize,
61        local_opcode_idx: Vec<usize>,
62        opcode_flag_idx: Vec<usize>,
63    ) -> Self {
64        let opcode_flag_idx = if opcode_flag_idx.is_empty() && expr.needs_setup() {
65            // single op chip that needs setup, so there is only one default flag, must be 0.
66            vec![0]
67        } else {
68            // multi ops chip or no-setup chip, use as is.
69            opcode_flag_idx
70        };
71        assert_eq!(opcode_flag_idx.len(), local_opcode_idx.len() - 1);
72        Self {
73            expr,
74            offset,
75            local_opcode_idx,
76            opcode_flag_idx,
77        }
78    }
79
80    pub fn num_inputs(&self) -> usize {
81        self.expr.builder.num_input
82    }
83
84    pub fn num_vars(&self) -> usize {
85        self.expr.builder.num_variables
86    }
87
88    pub fn num_flags(&self) -> usize {
89        self.expr.builder.num_flags
90    }
91
92    pub fn output_indices(&self) -> &[usize] {
93        &self.expr.builder.output_indices
94    }
95}
96
97impl<F: Field> BaseAir<F> for FieldExpressionCoreAir {
98    fn width(&self) -> usize {
99        BaseAir::<F>::width(&self.expr)
100    }
101}
102
103impl<F: Field> BaseAirWithPublicValues<F> for FieldExpressionCoreAir {}
104
105impl<AB: InteractionBuilder, I> VmCoreAir<AB, I> for FieldExpressionCoreAir
106where
107    I: VmAdapterInterface<AB::Expr>,
108    AdapterAirContext<AB::Expr, I>:
109        From<AdapterAirContext<AB::Expr, DynAdapterInterface<AB::Expr>>>,
110{
111    fn eval(
112        &self,
113        builder: &mut AB,
114        local: &[AB::Var],
115        _from_pc: AB::Var,
116    ) -> AdapterAirContext<AB::Expr, I> {
117        assert_eq!(local.len(), BaseAir::<AB::F>::width(&self.expr));
118        self.expr.eval(builder, local);
119        let FieldExprCols {
120            is_valid,
121            inputs,
122            vars,
123            flags,
124            ..
125        } = self.expr.load_vars(local);
126        assert_eq!(inputs.len(), self.num_inputs());
127        assert_eq!(vars.len(), self.num_vars());
128        assert_eq!(flags.len(), self.num_flags());
129        let reads: Vec<AB::Expr> = inputs.concat().iter().map(|x| (*x).into()).collect();
130        let writes: Vec<AB::Expr> = self
131            .output_indices()
132            .iter()
133            .flat_map(|&i| vars[i].clone())
134            .map(Into::into)
135            .collect();
136
137        let opcode_flags_except_last = self.opcode_flag_idx.iter().map(|&i| flags[i]).collect_vec();
138        let last_opcode_flag = is_valid
139            - opcode_flags_except_last
140                .iter()
141                .map(|&v| v.into())
142                .sum::<AB::Expr>();
143        builder.assert_bool(last_opcode_flag.clone());
144        let opcode_flags = opcode_flags_except_last
145            .into_iter()
146            .map(Into::into)
147            .chain(Some(last_opcode_flag));
148        let expected_opcode = opcode_flags
149            .zip(self.local_opcode_idx.iter().map(|&i| i + self.offset))
150            .map(|(flag, global_idx)| flag * AB::Expr::from_canonical_usize(global_idx))
151            .sum();
152
153        let instruction = MinimalInstruction {
154            is_valid: is_valid.into(),
155            opcode: expected_opcode,
156        };
157
158        let ctx: AdapterAirContext<_, DynAdapterInterface<_>> = AdapterAirContext {
159            to_pc: None,
160            reads: reads.into(),
161            writes: writes.into(),
162            instruction: instruction.into(),
163        };
164        ctx.into()
165    }
166
167    fn start_offset(&self) -> usize {
168        self.offset
169    }
170}
171
/// Layout metadata for a field-expression record, tagged with the field type `F` and the
/// adapter type `A` without storing either.
pub struct FieldExpressionMetadata<F, A> {
    /// Total number of input limbs, i.e. `num_inputs * limbs_per_input`.
    pub total_input_limbs: usize,
    _phantom: PhantomData<(F, A)>,
}
176
177impl<F, A> Clone for FieldExpressionMetadata<F, A> {
178    fn clone(&self) -> Self {
179        Self {
180            total_input_limbs: self.total_input_limbs,
181            _phantom: PhantomData,
182        }
183    }
184}
185
186impl<F, A> Default for FieldExpressionMetadata<F, A> {
187    fn default() -> Self {
188        Self {
189            total_input_limbs: 0,
190            _phantom: PhantomData,
191        }
192    }
193}
194
195impl<F, A> FieldExpressionMetadata<F, A> {
196    pub fn new(total_input_limbs: usize) -> Self {
197        Self {
198            total_input_limbs,
199            _phantom: PhantomData,
200        }
201    }
202}
203
204impl<F, A> AdapterCoreMetadata for FieldExpressionMetadata<F, A>
205where
206    A: AdapterTraceExecutor<F>,
207{
208    #[inline(always)]
209    fn get_adapter_width() -> usize {
210        A::WIDTH * size_of::<F>()
211    }
212}
213
214pub type FieldExpressionRecordLayout<F, A> = AdapterCoreLayout<FieldExpressionMetadata<F, A>>;
215
/// Mutable view over a core record stored in a byte buffer: one opcode byte followed by the
/// input limbs.
pub struct FieldExpressionCoreRecordMut<'a> {
    /// The local opcode, stored in the first byte of the record.
    pub opcode: &'a mut u8,
    /// The input limbs, stored directly after the opcode byte.
    pub input_limbs: &'a mut [u8],
}
220
221impl<'a, F, A> CustomBorrow<'a, FieldExpressionCoreRecordMut<'a>, FieldExpressionRecordLayout<F, A>>
222    for [u8]
223{
224    fn custom_borrow(
225        &'a mut self,
226        layout: FieldExpressionRecordLayout<F, A>,
227    ) -> FieldExpressionCoreRecordMut<'a> {
228        // SAFETY: The buffer length is the width of the trace which should be at least 1
229        let (opcode_buf, input_limbs_buff) = unsafe { self.split_at_mut_unchecked(1) };
230
231        // SAFETY: opcode_buf has exactly 1 element from split_at_mut_unchecked(1)
232        let opcode_buf = unsafe { opcode_buf.get_unchecked_mut(0) };
233
234        FieldExpressionCoreRecordMut {
235            opcode: opcode_buf,
236            input_limbs: &mut input_limbs_buff[..layout.metadata.total_input_limbs],
237        }
238    }
239
240    unsafe fn extract_layout(&self) -> FieldExpressionRecordLayout<F, A> {
241        panic!("Should get the Layout information from FieldExpressionExecutor");
242    }
243}
244
245impl<F, A> SizedRecord<FieldExpressionRecordLayout<F, A>> for FieldExpressionCoreRecordMut<'_> {
246    fn size(layout: &FieldExpressionRecordLayout<F, A>) -> usize {
247        layout.metadata.total_input_limbs + 1
248    }
249
250    fn alignment(_layout: &FieldExpressionRecordLayout<F, A>) -> usize {
251        align_of::<u8>()
252    }
253}
254
255impl<'a> FieldExpressionCoreRecordMut<'a> {
256    // This method is only used in testing
257    pub fn new_from_execution_data(
258        buffer: &'a mut [u8],
259        inputs: &[BigUint],
260        limbs_per_input: usize,
261    ) -> Self {
262        let record_info = FieldExpressionMetadata::<(), ()>::new(inputs.len() * limbs_per_input);
263
264        let record: Self = buffer.custom_borrow(FieldExpressionRecordLayout {
265            metadata: record_info,
266        });
267        record
268    }
269
270    #[inline(always)]
271    pub fn fill_from_execution_data(&mut self, opcode: u8, data: &[u8]) {
272        // Rust will assert that length of `data` and `self.input_limbs` are the same
273        // That is `data.len() == num_inputs * limbs_per_input`
274        *self.opcode = opcode;
275        self.input_limbs.copy_from_slice(data);
276    }
277}
278
279#[derive(Clone)]
280pub struct FieldExpressionExecutor<A> {
281    adapter: A,
282    pub expr: FieldExpr,
283    pub offset: usize,
284    pub local_opcode_idx: Vec<usize>,
285    pub opcode_flag_idx: Vec<usize>,
286    pub name: String,
287}
288
289impl<A> FieldExpressionExecutor<A> {
290    #[allow(clippy::too_many_arguments)]
291    pub fn new(
292        adapter: A,
293        expr: FieldExpr,
294        offset: usize,
295        local_opcode_idx: Vec<usize>,
296        opcode_flag_idx: Vec<usize>,
297        name: &str,
298    ) -> Self {
299        let opcode_flag_idx = if opcode_flag_idx.is_empty() && expr.needs_setup() {
300            // single op chip that needs setup, so there is only one default flag, must be 0.
301            vec![0]
302        } else {
303            // multi ops chip or no-setup chip, use as is.
304            opcode_flag_idx
305        };
306        assert_eq!(opcode_flag_idx.len(), local_opcode_idx.len() - 1);
307        tracing::debug!(
308            "FieldExpressionCoreExecutor: opcode={name}, main_width={}",
309            BaseAir::<BabyBear>::width(&expr)
310        );
311        Self {
312            adapter,
313            expr,
314            offset,
315            local_opcode_idx,
316            opcode_flag_idx,
317            name: name.to_string(),
318        }
319    }
320
321    pub fn get_record_layout<F>(&self) -> FieldExpressionRecordLayout<F, A> {
322        FieldExpressionRecordLayout {
323            metadata: FieldExpressionMetadata::new(
324                self.expr.builder.num_input * self.expr.canonical_num_limbs(),
325            ),
326        }
327    }
328}
329
330pub struct FieldExpressionFiller<A> {
331    adapter: A,
332    pub expr: FieldExpr,
333    pub local_opcode_idx: Vec<usize>,
334    pub opcode_flag_idx: Vec<usize>,
335    pub range_checker: SharedVariableRangeCheckerChip,
336    pub should_finalize: bool,
337}
338
339impl<A> FieldExpressionFiller<A> {
340    #[allow(clippy::too_many_arguments)]
341    pub fn new(
342        adapter: A,
343        expr: FieldExpr,
344        local_opcode_idx: Vec<usize>,
345        opcode_flag_idx: Vec<usize>,
346        range_checker: SharedVariableRangeCheckerChip,
347        should_finalize: bool,
348    ) -> Self {
349        let opcode_flag_idx = if opcode_flag_idx.is_empty() && expr.needs_setup() {
350            // single op chip that needs setup, so there is only one default flag, must be 0.
351            vec![0]
352        } else {
353            // multi ops chip or no-setup chip, use as is.
354            opcode_flag_idx
355        };
356        assert_eq!(opcode_flag_idx.len(), local_opcode_idx.len() - 1);
357        Self {
358            adapter,
359            expr,
360            local_opcode_idx,
361            opcode_flag_idx,
362            range_checker,
363            should_finalize,
364        }
365    }
366    pub fn num_inputs(&self) -> usize {
367        self.expr.builder.num_input
368    }
369
370    pub fn num_flags(&self) -> usize {
371        self.expr.builder.num_flags
372    }
373
374    pub fn get_record_layout<F>(&self) -> FieldExpressionRecordLayout<F, A> {
375        FieldExpressionRecordLayout {
376            metadata: FieldExpressionMetadata::new(
377                self.num_inputs() * self.expr.canonical_num_limbs(),
378            ),
379        }
380    }
381}
382
383impl<F, A, RA> PreflightExecutor<F, RA> for FieldExpressionExecutor<A>
384where
385    F: PrimeField32,
386    A: 'static
387        + AdapterTraceExecutor<F, ReadData: Into<DynArray<u8>>, WriteData: From<DynArray<u8>>>,
388    for<'buf> RA: RecordArena<
389        'buf,
390        FieldExpressionRecordLayout<F, A>,
391        (A::RecordMut<'buf>, FieldExpressionCoreRecordMut<'buf>),
392    >,
393{
394    fn execute(
395        &self,
396        state: VmStateMut<F, TracingMemory, RA>,
397        instruction: &Instruction<F>,
398    ) -> Result<(), ExecutionError> {
399        let (mut adapter_record, mut core_record) = state.ctx.alloc(self.get_record_layout());
400
401        A::start(*state.pc, state.memory, &mut adapter_record);
402
403        let data: DynArray<_> = self
404            .adapter
405            .read(state.memory, instruction, &mut adapter_record)
406            .into();
407
408        core_record.fill_from_execution_data(
409            instruction.opcode.local_opcode_idx(self.offset) as u8,
410            &data.0,
411        );
412
413        let (writes, _, _) = run_field_expression(
414            &self.expr,
415            &self.local_opcode_idx,
416            &self.opcode_flag_idx,
417            core_record.input_limbs,
418            *core_record.opcode as usize,
419        );
420
421        self.adapter.write(
422            state.memory,
423            instruction,
424            writes.into(),
425            &mut adapter_record,
426        );
427
428        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
429        Ok(())
430    }
431
432    fn get_opcode_name(&self, _opcode: usize) -> String {
433        self.name.clone()
434    }
435}
436
437impl<F, A> TraceFiller<F> for FieldExpressionFiller<A>
438where
439    F: PrimeField32 + Send + Sync + Clone,
440    A: 'static + AdapterTraceFiller<F>,
441{
442    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
443        // Get the core record from the row slice
444        // SAFETY: Caller guarantees that row_slice has width A::WIDTH + core width
445        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
446
447        self.adapter.fill_trace_row(mem_helper, adapter_row);
448
449        // SAFETY:
450        // - caller ensures `core_row` contains a valid record representation that was previously
451        //   written by the executor
452        // - core_row slice is transmuted to FieldExpressionCoreRecordMut using the specified
453        //   layout, which satisfies CustomBorrow requirements for safe access.
454        let record: FieldExpressionCoreRecordMut =
455            unsafe { get_record_from_slice(&mut core_row, self.get_record_layout::<F>()) };
456
457        let (_, inputs, flags) = run_field_expression(
458            &self.expr,
459            &self.local_opcode_idx,
460            &self.opcode_flag_idx,
461            record.input_limbs,
462            *record.opcode as usize,
463        );
464
465        let range_checker = self.range_checker.as_ref();
466        self.expr
467            .generate_subrow((range_checker, inputs, flags), core_row);
468    }
469
470    fn fill_dummy_trace_row(&self, row_slice: &mut [F]) {
471        if !self.should_finalize {
472            return;
473        }
474
475        let inputs: Vec<BigUint> = vec![BigUint::zero(); self.num_inputs()];
476        let flags: Vec<bool> = vec![false; self.num_flags()];
477        let core_row = &mut row_slice[A::WIDTH..];
478        // We **do not** want this trace row to update the range checker
479        // so we must create a temporary range checker
480        let tmp_range_checker = Arc::new(VariableRangeCheckerChip::new(self.range_checker.bus()));
481        self.expr
482            .generate_subrow((&tmp_range_checker, inputs, flags), core_row);
483        core_row[0] = F::ZERO; // is_valid = 0
484    }
485}
486
487fn run_field_expression(
488    expr: &FieldExpr,
489    local_opcode_flags: &[usize],
490    opcode_flag_idx: &[usize],
491    data: &[u8],
492    local_opcode_idx: usize,
493) -> (DynArray<u8>, Vec<BigUint>, Vec<bool>) {
494    let field_element_limbs = expr.canonical_num_limbs();
495    assert_eq!(data.len(), expr.builder.num_input * field_element_limbs);
496
497    let mut inputs = Vec::with_capacity(expr.builder.num_input);
498    for i in 0..expr.builder.num_input {
499        let start = i * field_element_limbs;
500        let end = start + field_element_limbs;
501        let limb_slice = &data[start..end];
502        let input = BigUint::from_bytes_le(limb_slice);
503        inputs.push(input);
504    }
505
506    let mut flags = vec![];
507    if expr.needs_setup() {
508        flags = vec![false; expr.builder.num_flags];
509
510        // Find which opcode this is in our local_opcode_idx list
511        if let Some(opcode_position) = local_opcode_flags
512            .iter()
513            .position(|&idx| idx == local_opcode_idx)
514        {
515            // If this is NOT the last opcode (setup), set the corresponding flag
516            if opcode_position < opcode_flag_idx.len() {
517                let flag_idx = opcode_flag_idx[opcode_position];
518                flags[flag_idx] = true;
519            }
520            // If opcode_position == step.opcode_flag_idx.len(), it's the setup operation
521            // and all flags should remain false (which they already are)
522        }
523    }
524
525    let vars = expr.execute(inputs.clone(), flags.clone());
526    assert_eq!(vars.len(), expr.builder.num_variables);
527
528    let outputs: Vec<BigUint> = expr
529        .builder
530        .output_indices
531        .iter()
532        .map(|&i| vars[i].clone())
533        .collect();
534    let writes: DynArray<_> = outputs
535        .iter()
536        .map(|x| biguint_to_limbs_vec(x, field_element_limbs))
537        .concat()
538        .into_iter()
539        .collect::<Vec<_>>()
540        .into();
541
542    (writes, inputs, flags)
543}
544
545#[inline(always)]
546pub fn run_field_expression_precomputed<const NEEDS_SETUP: bool>(
547    expr: &FieldExpr,
548    flag_idx: usize,
549    data: &[u8],
550) -> DynArray<u8> {
551    let field_element_limbs = expr.canonical_num_limbs();
552    assert_eq!(data.len(), expr.num_inputs() * field_element_limbs);
553
554    let mut inputs = Vec::with_capacity(expr.num_inputs());
555    for i in 0..expr.num_inputs() {
556        let start = i * expr.canonical_num_limbs();
557        let end = start + expr.canonical_num_limbs();
558        let limb_slice = &data[start..end];
559        let input = BigUint::from_bytes_le(limb_slice);
560        inputs.push(input);
561    }
562
563    let flags = if NEEDS_SETUP {
564        let mut flags = vec![false; expr.num_flags()];
565        if flag_idx < expr.num_flags() {
566            flags[flag_idx] = true;
567        }
568        flags
569    } else {
570        vec![]
571    };
572
573    let vars = expr.execute(inputs, flags);
574    assert_eq!(vars.len(), expr.num_vars());
575
576    let outputs: Vec<BigUint> = expr
577        .output_indices()
578        .iter()
579        .map(|&i| vars[i].clone())
580        .collect();
581
582    outputs
583        .iter()
584        .map(|x| biguint_to_limbs_vec(x, field_element_limbs))
585        .concat()
586        .into_iter()
587        .collect::<Vec<_>>()
588        .into()
589}