openvm_rv32im_circuit/branch_eq/core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7    AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface,
8    VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::utils::not;
11use openvm_circuit_primitives_derive::AlignedBorrow;
12use openvm_instructions::{instruction::Instruction, LocalOpcode};
13use openvm_rv32im_transpiler::BranchEqualOpcode;
14use openvm_stark_backend::{
15    interaction::InteractionBuilder,
16    p3_air::{AirBuilder, BaseAir},
17    p3_field::{Field, FieldAlgebra, PrimeField32},
18    rap::BaseAirWithPublicValues,
19};
20use serde::{Deserialize, Serialize};
21use serde_big_array::BigArray;
22use strum::IntoEnumIterator;
23
24#[repr(C)]
25#[derive(AlignedBorrow)]
26pub struct BranchEqualCoreCols<T, const NUM_LIMBS: usize> {
27    pub a: [T; NUM_LIMBS],
28    pub b: [T; NUM_LIMBS],
29
30    // Boolean result of a op b. Should branch if and only if cmp_result = 1.
31    pub cmp_result: T,
32    pub imm: T,
33
34    pub opcode_beq_flag: T,
35    pub opcode_bne_flag: T,
36
37    pub diff_inv_marker: [T; NUM_LIMBS],
38}
39
40#[derive(Copy, Clone, Debug)]
41pub struct BranchEqualCoreAir<const NUM_LIMBS: usize> {
42    offset: usize,
43    pc_step: u32,
44}
45
46impl<F: Field, const NUM_LIMBS: usize> BaseAir<F> for BranchEqualCoreAir<NUM_LIMBS> {
47    fn width(&self) -> usize {
48        BranchEqualCoreCols::<F, NUM_LIMBS>::width()
49    }
50}
51impl<F: Field, const NUM_LIMBS: usize> BaseAirWithPublicValues<F>
52    for BranchEqualCoreAir<NUM_LIMBS>
53{
54}
55
56impl<AB, I, const NUM_LIMBS: usize> VmCoreAir<AB, I> for BranchEqualCoreAir<NUM_LIMBS>
57where
58    AB: InteractionBuilder,
59    I: VmAdapterInterface<AB::Expr>,
60    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
61    I::Writes: Default,
62    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
63{
64    fn eval(
65        &self,
66        builder: &mut AB,
67        local: &[AB::Var],
68        from_pc: AB::Var,
69    ) -> AdapterAirContext<AB::Expr, I> {
70        let cols: &BranchEqualCoreCols<_, NUM_LIMBS> = local.borrow();
71        let flags = [cols.opcode_beq_flag, cols.opcode_bne_flag];
72
73        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
74            builder.assert_bool(flag);
75            acc + flag.into()
76        });
77        builder.assert_bool(is_valid.clone());
78        builder.assert_bool(cols.cmp_result);
79
80        let a = &cols.a;
81        let b = &cols.b;
82        let inv_marker = &cols.diff_inv_marker;
83
84        // 1 if cmp_result indicates a and b are equal, 0 otherwise
85        let cmp_eq =
86            cols.cmp_result * cols.opcode_beq_flag + not(cols.cmp_result) * cols.opcode_bne_flag;
87        let mut sum = cmp_eq.clone();
88
89        // For BEQ, inv_marker is used to check equality of a and b:
90        // - If a == b, all inv_marker values must be 0 (sum = 0)
91        // - If a != b, inv_marker contains 0s for all positions except ONE position i where a[i] != b[i]
92        // - At this position, inv_marker[i] contains the multiplicative inverse of (a[i] - b[i])
93        // - This ensures inv_marker[i] * (a[i] - b[i]) = 1, making the sum = 1
94        // Note: There might be multiple valid inv_marker if a != b.
95        // But as long as the trace can provide at least one, that’s sufficient to prove a != b.
96        //
97        // Note:
98        // - If cmp_eq == 0, then it is impossible to have sum != 0 if a == b.
99        // - If cmp_eq == 1, then it is impossible for a[i] - b[i] == 0 to pass for all i if a != b.
100        for i in 0..NUM_LIMBS {
101            sum += (a[i] - b[i]) * inv_marker[i];
102            builder.assert_zero(cmp_eq.clone() * (a[i] - b[i]));
103        }
104        builder.when(is_valid.clone()).assert_one(sum);
105
106        let expected_opcode = flags
107            .iter()
108            .zip(BranchEqualOpcode::iter())
109            .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
110                acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
111            })
112            + AB::Expr::from_canonical_usize(self.offset);
113
114        let to_pc = from_pc
115            + cols.cmp_result * cols.imm
116            + not(cols.cmp_result) * AB::Expr::from_canonical_u32(self.pc_step);
117
118        AdapterAirContext {
119            to_pc: Some(to_pc),
120            reads: [cols.a.map(Into::into), cols.b.map(Into::into)].into(),
121            writes: Default::default(),
122            instruction: ImmInstruction {
123                is_valid,
124                opcode: expected_opcode,
125                immediate: cols.imm.into(),
126            }
127            .into(),
128        }
129    }
130
131    fn start_offset(&self) -> usize {
132        self.offset
133    }
134}
135
136#[repr(C)]
137#[derive(Clone, Debug, Serialize, Deserialize)]
138pub struct BranchEqualCoreRecord<T, const NUM_LIMBS: usize> {
139    #[serde(with = "BigArray")]
140    pub a: [T; NUM_LIMBS],
141    #[serde(with = "BigArray")]
142    pub b: [T; NUM_LIMBS],
143    pub cmp_result: T,
144    pub imm: T,
145    pub diff_inv_val: T,
146    pub diff_idx: usize,
147    pub opcode: BranchEqualOpcode,
148}
149
150#[derive(Debug)]
151pub struct BranchEqualCoreChip<const NUM_LIMBS: usize> {
152    pub air: BranchEqualCoreAir<NUM_LIMBS>,
153}
154
155impl<const NUM_LIMBS: usize> BranchEqualCoreChip<NUM_LIMBS> {
156    pub fn new(offset: usize, pc_step: u32) -> Self {
157        Self {
158            air: BranchEqualCoreAir { offset, pc_step },
159        }
160    }
161}
162
163impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_LIMBS: usize> VmCoreChip<F, I>
164    for BranchEqualCoreChip<NUM_LIMBS>
165where
166    I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
167    I::Writes: Default,
168{
169    type Record = BranchEqualCoreRecord<F, NUM_LIMBS>;
170    type Air = BranchEqualCoreAir<NUM_LIMBS>;
171
172    #[allow(clippy::type_complexity)]
173    fn execute_instruction(
174        &self,
175        instruction: &Instruction<F>,
176        from_pc: u32,
177        reads: I::Reads,
178    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
179        let Instruction { opcode, c: imm, .. } = *instruction;
180        let branch_eq_opcode =
181            BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));
182
183        let data: [[F; NUM_LIMBS]; 2] = reads.into();
184        let x = data[0].map(|x| x.as_canonical_u32());
185        let y = data[1].map(|y| y.as_canonical_u32());
186        let (cmp_result, diff_idx, diff_inv_val) = run_eq::<F, NUM_LIMBS>(branch_eq_opcode, &x, &y);
187
188        let output = AdapterRuntimeContext {
189            to_pc: cmp_result.then_some((F::from_canonical_u32(from_pc) + imm).as_canonical_u32()),
190            writes: Default::default(),
191        };
192        let record = BranchEqualCoreRecord {
193            opcode: branch_eq_opcode,
194            a: data[0],
195            b: data[1],
196            cmp_result: F::from_bool(cmp_result),
197            imm,
198            diff_idx,
199            diff_inv_val,
200        };
201
202        Ok((output, record))
203    }
204
205    fn get_opcode_name(&self, opcode: usize) -> String {
206        format!(
207            "{:?}",
208            BranchEqualOpcode::from_usize(opcode - self.air.offset)
209        )
210    }
211
212    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
213        let row_slice: &mut BranchEqualCoreCols<_, NUM_LIMBS> = row_slice.borrow_mut();
214        row_slice.a = record.a;
215        row_slice.b = record.b;
216        row_slice.cmp_result = record.cmp_result;
217        row_slice.imm = record.imm;
218        row_slice.opcode_beq_flag = F::from_bool(record.opcode == BranchEqualOpcode::BEQ);
219        row_slice.opcode_bne_flag = F::from_bool(record.opcode == BranchEqualOpcode::BNE);
220        row_slice.diff_inv_marker = array::from_fn(|i| {
221            if i == record.diff_idx {
222                record.diff_inv_val
223            } else {
224                F::ZERO
225            }
226        });
227    }
228
229    fn air(&self) -> &Self::Air {
230        &self.air
231    }
232}
233
234// Returns (cmp_result, diff_idx, x[diff_idx] - y[diff_idx])
235pub(super) fn run_eq<F: PrimeField32, const NUM_LIMBS: usize>(
236    local_opcode: BranchEqualOpcode,
237    x: &[u32; NUM_LIMBS],
238    y: &[u32; NUM_LIMBS],
239) -> (bool, usize, F) {
240    for i in 0..NUM_LIMBS {
241        if x[i] != y[i] {
242            return (
243                local_opcode == BranchEqualOpcode::BNE,
244                i,
245                (F::from_canonical_u32(x[i]) - F::from_canonical_u32(y[i])).inverse(),
246            );
247        }
248    }
249    (local_opcode == BranchEqualOpcode::BEQ, 0, F::ZERO)
250}