openvm_rv32im_circuit/branch_eq/core.rs

use std::borrow::{Borrow, BorrowMut};

use openvm_circuit::{
    arch::*,
    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
};
use openvm_circuit_primitives::utils::not;
use openvm_circuit_primitives_derive::{AlignedBorrow, AlignedBytesBorrow};
use openvm_instructions::{instruction::Instruction, LocalOpcode};
use openvm_rv32im_transpiler::BranchEqualOpcode;
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};
use strum::IntoEnumIterator;
18
19#[repr(C)]
20#[derive(AlignedBorrow)]
21pub struct BranchEqualCoreCols<T, const NUM_LIMBS: usize> {
22 pub a: [T; NUM_LIMBS],
23 pub b: [T; NUM_LIMBS],
24
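    // Boolean result of the comparison; the branch is taken if and only if
    // cmp_result = 1.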
    pub cmp_result: T,
    pub imm: T,

    pub opcode_beq_flag: T,
    pub opcode_bne_flag: T,

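    // Witness for inequality: dotted against (a[i] - b[i]) in the AIR. When a != b,
    // it holds the inverse of the first differing limb's difference.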
    pub diff_inv_marker: [T; NUM_LIMBS],
}

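/// Core AIR for BEQ/BNE. `offset` is the global opcode offset of the
/// `BranchEqualOpcode` class; `pc_step` is the default program counter increment
/// when the branch is not taken.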
#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct BranchEqualCoreAir<const NUM_LIMBS: usize> {
    offset: usize,
    pc_step: u32,
}

impl<F: Field, const NUM_LIMBS: usize> BaseAir<F> for BranchEqualCoreAir<NUM_LIMBS> {
    fn width(&self) -> usize {
        BranchEqualCoreCols::<F, NUM_LIMBS>::width()
    }
}
impl<F: Field, const NUM_LIMBS: usize> BaseAirWithPublicValues<F>
    for BranchEqualCoreAir<NUM_LIMBS>
{
}

impl<AB, I, const NUM_LIMBS: usize> VmCoreAir<AB, I> for BranchEqualCoreAir<NUM_LIMBS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: Default,
    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &BranchEqualCoreCols<_, NUM_LIMBS> = local.borrow();
        let flags = [cols.opcode_beq_flag, cols.opcode_bne_flag];

        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());
        builder.assert_bool(cols.cmp_result);

        let a = &cols.a;
        let b = &cols.b;
        let inv_marker = &cols.diff_inv_marker;

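        // cmp_eq is 1 exactly when the selected opcode's semantics claim a == b:
        // cmp_result for BEQ, its negation for BNE.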
        let cmp_eq =
            cols.cmp_result * cols.opcode_beq_flag + not(cols.cmp_result) * cols.opcode_bne_flag;
        let mut sum = cmp_eq.clone();

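        // Equality check via the inverse-marker trick:
        // - If a == b, every (a[i] - b[i]) is 0, so sum stays at cmp_eq and the
        //   assert_one below forces cmp_eq = 1.
        // - If a != b, the per-limb constraint forces cmp_eq = 0 (cmp_eq times a
        //   nonzero difference must vanish), and diff_inv_marker must hold the
        //   inverse of (a[i] - b[i]) at some differing limb i so that sum reaches 1.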
        for i in 0..NUM_LIMBS {
            sum += (a[i] - b[i]) * inv_marker[i];
            builder.assert_zero(cmp_eq.clone() * (a[i] - b[i]));
        }
        builder.when(is_valid.clone()).assert_one(sum);

        let expected_opcode = flags
            .iter()
            .zip(BranchEqualOpcode::iter())
            .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
            })
            + AB::Expr::from_canonical_usize(self.offset);

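        // Next PC: jump by the immediate when the branch is taken, otherwise fall
        // through by pc_step.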
        let to_pc = from_pc
            + cols.cmp_result * cols.imm
            + not(cols.cmp_result) * AB::Expr::from_canonical_u32(self.pc_step);

        AdapterAirContext {
            to_pc: Some(to_pc),
            reads: [cols.a.map(Into::into), cols.b.map(Into::into)].into(),
            writes: Default::default(),
            instruction: ImmInstruction {
                is_valid,
                opcode: expected_opcode,
                immediate: cols.imm.into(),
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}

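/// Minimal per-instruction record captured during preflight execution: the two
/// operand byte arrays, the immediate, and the local opcode discriminant.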
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct BranchEqualCoreRecord<const NUM_LIMBS: usize> {
    pub a: [u8; NUM_LIMBS],
    pub b: [u8; NUM_LIMBS],
    pub imm: u32,
    pub local_opcode: u8,
}

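/// Preflight executor: performs the branch against `TracingMemory` and writes a
/// `BranchEqualCoreRecord` into the record arena for later trace filling.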
#[derive(Clone, Copy, derive_new::new)]
pub struct BranchEqualExecutor<A, const NUM_LIMBS: usize> {
    adapter: A,
    pub offset: usize,
    pub pc_step: u32,
}

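/// Trace filler: expands each `BranchEqualCoreRecord` into a full
/// `BranchEqualCoreCols` row.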
#[derive(Clone, Copy, derive_new::new)]
pub struct BranchEqualFiller<A, const NUM_LIMBS: usize> {
    adapter: A,
    pub offset: usize,
    pub pc_step: u32,
}

impl<F, A, RA, const NUM_LIMBS: usize> PreflightExecutor<F, RA>
    for BranchEqualExecutor<A, NUM_LIMBS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceExecutor<F, ReadData: Into<[[u8; NUM_LIMBS]; 2]>, WriteData = ()>,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (
            A::RecordMut<'buf>,
            &'buf mut BranchEqualCoreRecord<NUM_LIMBS>,
        ),
    >,
{
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!("{:?}", BranchEqualOpcode::from_usize(opcode - self.offset))
    }

    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let &Instruction { opcode, c: imm, .. } = instruction;

        let branch_eq_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset));

        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        let [rs1, rs2] = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record)
            .into();

        core_record.a = rs1;
        core_record.b = rs2;
        core_record.imm = imm.as_canonical_u32();
        core_record.local_opcode = branch_eq_opcode as u8;

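        // Taken branch: add the immediate in the field, so an immediate encoding a
        // negative offset wraps to the correct u32 PC. Fall-through: plain wrapping
        // u32 step.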
        if fast_run_eq(branch_eq_opcode, &rs1, &rs2) {
            *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32();
        } else {
            *state.pc = state.pc.wrapping_add(self.pc_step);
        }

        Ok(())
    }
}

impl<F, A, const NUM_LIMBS: usize> TraceFiller<F> for BranchEqualFiller<A, NUM_LIMBS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
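        // SAFETY: during preflight execution the record arena wrote a valid
        // `BranchEqualCoreRecord` into this row, so reinterpreting its bytes is sound.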
        let record: &BranchEqualCoreRecord<NUM_LIMBS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };
        let core_row: &mut BranchEqualCoreCols<F, NUM_LIMBS> = core_row.borrow_mut();

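        // `record` points into the same buffer as `core_row`, so the columns are
        // filled from the tail of the struct toward the head: each record byte is
        // consumed before the field cell overlapping it is written.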
        let (cmp_result, diff_idx, diff_inv_val) = run_eq::<F, NUM_LIMBS>(
            record.local_opcode == BranchEqualOpcode::BEQ as u8,
            &record.a,
            &record.b,
        );
        core_row.diff_inv_marker = [F::ZERO; NUM_LIMBS];
        core_row.diff_inv_marker[diff_idx] = diff_inv_val;

        core_row.opcode_bne_flag =
            F::from_bool(record.local_opcode == BranchEqualOpcode::BNE as u8);
        core_row.opcode_beq_flag =
            F::from_bool(record.local_opcode == BranchEqualOpcode::BEQ as u8);

        core_row.imm = F::from_canonical_u32(record.imm);
        core_row.cmp_result = F::from_bool(cmp_result);

        core_row.b = record.b.map(F::from_canonical_u8);
        core_row.a = record.a.map(F::from_canonical_u8);
    }
}

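/// Fast path used during execution: returns whether the branch is taken.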
#[inline(always)]
pub(super) fn fast_run_eq<const NUM_LIMBS: usize>(
    local_opcode: BranchEqualOpcode,
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> bool {
    match local_opcode {
        BranchEqualOpcode::BEQ => x == y,
        BranchEqualOpcode::BNE => x != y,
    }
}

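/// Returns `(cmp_result, diff_idx, diff_inv_val)`: whether the branch is taken, the
/// index of the first differing limb, and the field inverse of that limb difference
/// (index 0 and `F::ZERO` when `x == y`). Used to fill `diff_inv_marker`.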
#[inline(always)]
pub(super) fn run_eq<F, const NUM_LIMBS: usize>(
    is_beq: bool,
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> (bool, usize, F)
where
    F: PrimeField32,
{
    for i in 0..NUM_LIMBS {
        if x[i] != y[i] {
            return (
                !is_beq,
                i,
                (F::from_canonical_u8(x[i]) - F::from_canonical_u8(y[i])).inverse(),
            );
        }
    }
    (is_beq, 0, F::ZERO)
}