openvm_rv32im_circuit/branch_eq/
core.rs

use std::borrow::{Borrow, BorrowMut};

use openvm_circuit::{
    arch::*,
    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
};
use openvm_circuit_primitives::utils::not;
use openvm_circuit_primitives_derive::{AlignedBorrow, AlignedBytesBorrow};
use openvm_instructions::{instruction::Instruction, LocalOpcode};
use openvm_rv32im_transpiler::BranchEqualOpcode;
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};
use strum::IntoEnumIterator;

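/// Core columns for the BEQ/BNE chip. One row per executed branch-equal
/// instruction; padding rows have both opcode flags set to 0.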
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct BranchEqualCoreCols<T, const NUM_LIMBS: usize> {
    pub a: [T; NUM_LIMBS],
    pub b: [T; NUM_LIMBS],

    // Boolean result of a op b. Should branch if and only if cmp_result = 1.
    pub cmp_result: T,
    pub imm: T,

    pub opcode_beq_flag: T,
    pub opcode_bne_flag: T,

    pub diff_inv_marker: [T; NUM_LIMBS],
}

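/// Core AIR for BEQ/BNE. `offset` is the global index of the first
/// `BranchEqualOpcode`; `pc_step` is the amount the PC advances when the
/// branch is not taken.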
#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct BranchEqualCoreAir<const NUM_LIMBS: usize> {
    offset: usize,
    pc_step: u32,
}

impl<F: Field, const NUM_LIMBS: usize> BaseAir<F> for BranchEqualCoreAir<NUM_LIMBS> {
    fn width(&self) -> usize {
        BranchEqualCoreCols::<F, NUM_LIMBS>::width()
    }
}
impl<F: Field, const NUM_LIMBS: usize> BaseAirWithPublicValues<F>
    for BranchEqualCoreAir<NUM_LIMBS>
{
}

impl<AB, I, const NUM_LIMBS: usize> VmCoreAir<AB, I> for BranchEqualCoreAir<NUM_LIMBS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: Default,
    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &BranchEqualCoreCols<_, NUM_LIMBS> = local.borrow();
        let flags = [cols.opcode_beq_flag, cols.opcode_bne_flag];

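        // Each opcode flag must be boolean and at most one may be set;
        // their sum `is_valid` is 0 exactly on padding rows.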
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());
        builder.assert_bool(cols.cmp_result);

        let a = &cols.a;
        let b = &cols.b;
        let inv_marker = &cols.diff_inv_marker;

        // 1 if cmp_result indicates a and b are equal, 0 otherwise
        let cmp_eq =
            cols.cmp_result * cols.opcode_beq_flag + not(cols.cmp_result) * cols.opcode_bne_flag;
        let mut sum = cmp_eq.clone();

        // inv_marker is used to check equality of a and b:
        // - If a == b, every term (a[i] - b[i]) * inv_marker[i] below is 0, so sum stays cmp_eq.
        // - If a != b, the trace sets inv_marker to 0 at all positions except ONE position i
        //   where a[i] != b[i]; inv_marker[i] holds the multiplicative inverse of (a[i] - b[i]),
        //   so inv_marker[i] * (a[i] - b[i]) = 1 and the sum becomes cmp_eq + 1.
        // Note: there may be multiple valid inv_marker assignments when a != b, but as long as
        // the trace can provide at least one, that's sufficient to prove a != b.
        //
        // Together with the constraint sum = 1 on valid rows:
        // - If cmp_eq == 0, sum can reach 1 only if some a[i] != b[i], so a != b.
        // - If cmp_eq == 1, cmp_eq * (a[i] - b[i]) == 0 forces a[i] == b[i] for all i, so a == b.
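        // Illustrative example (values chosen here, not from the codebase): with
        // NUM_LIMBS = 4, a = [5, 2, 9, 0], b = [5, 7, 9, 0], and cmp_eq = 0 (e.g. a
        // taken BNE), a valid witness is inv_marker = [0, (2 - 7)^{-1}, 0, 0] in the
        // field, giving sum = 0 + (2 - 7) * (2 - 7)^{-1} = 1 as required.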
        for i in 0..NUM_LIMBS {
            sum += (a[i] - b[i]) * inv_marker[i];
            builder.assert_zero(cmp_eq.clone() * (a[i] - b[i]));
        }
        builder.when(is_valid.clone()).assert_one(sum);

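        // Recover the global opcode: the one-hot flags select the local
        // BranchEqualOpcode index, shifted by this chip's offset.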
        let expected_opcode = flags
            .iter()
            .zip(BranchEqualOpcode::iter())
            .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
            })
            + AB::Expr::from_canonical_usize(self.offset);

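        // Next PC: jump to from_pc + imm when the branch is taken, otherwise
        // fall through to from_pc + pc_step.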
        let to_pc = from_pc
            + cols.cmp_result * cols.imm
            + not(cols.cmp_result) * AB::Expr::from_canonical_u32(self.pc_step);

        AdapterAirContext {
            to_pc: Some(to_pc),
            reads: [cols.a.map(Into::into), cols.b.map(Into::into)].into(),
            writes: Default::default(),
            instruction: ImmInstruction {
                is_valid,
                opcode: expected_opcode,
                immediate: cols.imm.into(),
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}

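/// Minimal per-instruction record written during execution and consumed by
/// `BranchEqualFiller` when the trace is filled.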
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct BranchEqualCoreRecord<const NUM_LIMBS: usize> {
    pub a: [u8; NUM_LIMBS],
    pub b: [u8; NUM_LIMBS],
    pub imm: u32,
    pub local_opcode: u8,
}

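/// Executes BEQ/BNE during preflight execution, updating the PC and recording
/// the operands needed later for trace generation.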
#[derive(Clone, Copy, derive_new::new)]
pub struct BranchEqualExecutor<A, const NUM_LIMBS: usize> {
    adapter: A,
    pub offset: usize,
    pub pc_step: u32,
}

#[derive(Clone, Copy, derive_new::new)]
pub struct BranchEqualFiller<A, const NUM_LIMBS: usize> {
    adapter: A,
    pub offset: usize,
    pub pc_step: u32,
}

impl<F, A, RA, const NUM_LIMBS: usize> PreflightExecutor<F, RA>
    for BranchEqualExecutor<A, NUM_LIMBS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceExecutor<F, ReadData: Into<[[u8; NUM_LIMBS]; 2]>, WriteData = ()>,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (
            A::RecordMut<'buf>,
            &'buf mut BranchEqualCoreRecord<NUM_LIMBS>,
        ),
    >,
{
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!("{:?}", BranchEqualOpcode::from_usize(opcode - self.offset))
    }

    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let &Instruction { opcode, c: imm, .. } = instruction;

        let branch_eq_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset));

        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        let [rs1, rs2] = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record)
            .into();

        core_record.a = rs1;
        core_record.b = rs2;
        core_record.imm = imm.as_canonical_u32();
        core_record.local_opcode = branch_eq_opcode as u8;

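        // Taken branch: add the immediate to the PC in the field, which also
        // handles backward branches (a negative offset encoded as a large
        // canonical field element wraps correctly). Otherwise, advance by pc_step.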
        if fast_run_eq(branch_eq_opcode, &rs1, &rs2) {
            *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32();
        } else {
            *state.pc = state.pc.wrapping_add(self.pc_step);
        }

        Ok(())
    }
}

impl<F, A, const NUM_LIMBS: usize> TraceFiller<F> for BranchEqualFiller<A, NUM_LIMBS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
        // BranchEqualCoreCols::width() elements
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY: core_row contains a valid BranchEqualCoreRecord written by the executor
        // during trace generation
        let record: &BranchEqualCoreRecord<NUM_LIMBS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };
        let core_row: &mut BranchEqualCoreCols<F, NUM_LIMBS> = core_row.borrow_mut();

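        // The record bytes alias the start of this same buffer, so the columns are
        // written in reverse field order: later fields first, `a` last, ensuring
        // record.a and record.b are not clobbered before their final reads.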
        let (cmp_result, diff_idx, diff_inv_val) = run_eq::<F, NUM_LIMBS>(
            record.local_opcode == BranchEqualOpcode::BEQ as u8,
            &record.a,
            &record.b,
        );
        core_row.diff_inv_marker = [F::ZERO; NUM_LIMBS];
        core_row.diff_inv_marker[diff_idx] = diff_inv_val;

        core_row.opcode_bne_flag =
            F::from_bool(record.local_opcode == BranchEqualOpcode::BNE as u8);
        core_row.opcode_beq_flag =
            F::from_bool(record.local_opcode == BranchEqualOpcode::BEQ as u8);

        core_row.imm = F::from_canonical_u32(record.imm);
        core_row.cmp_result = F::from_bool(cmp_result);

        core_row.b = record.b.map(F::from_canonical_u8);
        core_row.a = record.a.map(F::from_canonical_u8);
    }
}

// Returns whether the branch is taken, i.e. the boolean result of `x op y`.
#[inline(always)]
pub(super) fn fast_run_eq<const NUM_LIMBS: usize>(
    local_opcode: BranchEqualOpcode,
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> bool {
    match local_opcode {
        BranchEqualOpcode::BEQ => x == y,
        BranchEqualOpcode::BNE => x != y,
    }
}

// Returns (cmp_result, diff_idx, (x[diff_idx] - y[diff_idx])^{-1}), where diff_idx is the
// first index at which x and y differ (or 0, with inverse value 0, if x == y).
#[inline(always)]
pub(super) fn run_eq<F, const NUM_LIMBS: usize>(
    is_beq: bool,
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> (bool, usize, F)
where
    F: PrimeField32,
{
    for i in 0..NUM_LIMBS {
        if x[i] != y[i] {
            return (
                !is_beq,
                i,
                (F::from_canonical_u8(x[i]) - F::from_canonical_u8(y[i])).inverse(),
            );
        }
    }
    (is_beq, 0, F::ZERO)
}