// openvm_rv32im_circuit/branch_eq/core.rs
1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7 AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface,
8 VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::utils::not;
11use openvm_circuit_primitives_derive::AlignedBorrow;
12use openvm_instructions::{instruction::Instruction, LocalOpcode};
13use openvm_rv32im_transpiler::BranchEqualOpcode;
14use openvm_stark_backend::{
15 interaction::InteractionBuilder,
16 p3_air::{AirBuilder, BaseAir},
17 p3_field::{Field, FieldAlgebra, PrimeField32},
18 rap::BaseAirWithPublicValues,
19};
20use serde::{Deserialize, Serialize};
21use serde_big_array::BigArray;
22use strum::IntoEnumIterator;
23
/// Trace columns for the core (opcode-specific) part of the BEQ/BNE chip.
///
/// This struct is borrowed directly from a trace row (`local.borrow()` in `eval`,
/// `row_slice.borrow_mut()` in `generate_trace_row`), so the field order *is* the column
/// layout and must not be changed.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct BranchEqualCoreCols<T, const NUM_LIMBS: usize> {
    /// First operand, limb by limb.
    pub a: [T; NUM_LIMBS],
    /// Second operand, limb by limb.
    pub b: [T; NUM_LIMBS],

    /// Boolean branch decision: when 1 the AIR sets `to_pc = from_pc + imm`, when 0 it sets
    /// `to_pc = from_pc + pc_step`.
    pub cmp_result: T,
    /// Branch offset added to `from_pc` when `cmp_result` is 1.
    pub imm: T,

    /// Boolean selector: this row executes BEQ.
    pub opcode_beq_flag: T,
    /// Boolean selector: this row executes BNE.
    pub opcode_bne_flag: T,

    /// Witness for `a != b`: zero everywhere except (when the operands differ) one index `i`
    /// holding `(a[i] - b[i])^-1`. On valid rows the AIR requires
    /// `cmp_eq + sum_i (a[i] - b[i]) * diff_inv_marker[i] == 1`.
    pub diff_inv_marker: [T; NUM_LIMBS],
}
39
/// AIR for the core logic of the BEQ/BNE branch instructions.
#[derive(Copy, Clone, Debug)]
pub struct BranchEqualCoreAir<const NUM_LIMBS: usize> {
    /// Global opcode offset: a local opcode index is `global opcode - offset`.
    offset: usize,
    /// Amount added to the pc when the branch is not taken.
    pc_step: u32,
}
45
46impl<F: Field, const NUM_LIMBS: usize> BaseAir<F> for BranchEqualCoreAir<NUM_LIMBS> {
47 fn width(&self) -> usize {
48 BranchEqualCoreCols::<F, NUM_LIMBS>::width()
49 }
50}
51impl<F: Field, const NUM_LIMBS: usize> BaseAirWithPublicValues<F>
52 for BranchEqualCoreAir<NUM_LIMBS>
53{
54}
55
impl<AB, I, const NUM_LIMBS: usize> VmCoreAir<AB, I> for BranchEqualCoreAir<NUM_LIMBS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: Default,
    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
{
    /// Constrains one row of the core trace and returns what is handed to the adapter:
    /// the next pc, the two operand reads `a` and `b`, no writes, and the
    /// `(is_valid, opcode, immediate)` instruction triple.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &BranchEqualCoreCols<_, NUM_LIMBS> = local.borrow();
        let flags = [cols.opcode_beq_flag, cols.opcode_bne_flag];

        // Both flags are boolean and so is their sum, so at most one flag is set:
        // `is_valid` is 1 on a BEQ/BNE row and 0 on a padding row.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());
        builder.assert_bool(cols.cmp_result);

        let a = &cols.a;
        let b = &cols.b;
        let inv_marker = &cols.diff_inv_marker;

        // cmp_eq is 1 exactly when the row claims a == b: for BEQ that is `cmp_result`,
        // for BNE it is `1 - cmp_result`. On padding rows both flags are 0, so cmp_eq is 0.
        let cmp_eq =
            cols.cmp_result * cols.opcode_beq_flag + not(cols.cmp_result) * cols.opcode_bne_flag;
        let mut sum = cmp_eq.clone();

        // - If cmp_eq == 1, every limb difference is forced to zero, i.e. a == b, and the
        //   inverse-marker terms vanish, so `sum` is 1.
        // - If cmp_eq == 0, `sum == 1` requires sum_i (a[i] - b[i]) * inv_marker[i] == 1,
        //   which cannot hold when a == b. This proves a != b.
        for i in 0..NUM_LIMBS {
            sum += (a[i] - b[i]) * inv_marker[i];
            builder.assert_zero(cmp_eq.clone() * (a[i] - b[i]));
        }
        builder.when(is_valid.clone()).assert_one(sum);

        // Global opcode = offset + sum(flag * local opcode index). Flags are paired with
        // `BranchEqualOpcode::iter()` by position, so `flags` must follow the enum's variant
        // order. NOTE(review): assumes BranchEqualOpcode declares BEQ before BNE.
        let expected_opcode = flags
            .iter()
            .zip(BranchEqualOpcode::iter())
            .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
            })
            + AB::Expr::from_canonical_usize(self.offset);

        // Taken branch: from_pc + imm. Not taken: from_pc + pc_step.
        let to_pc = from_pc
            + cols.cmp_result * cols.imm
            + not(cols.cmp_result) * AB::Expr::from_canonical_u32(self.pc_step);

        AdapterAirContext {
            to_pc: Some(to_pc),
            reads: [cols.a.map(Into::into), cols.b.map(Into::into)].into(),
            writes: Default::default(),
            instruction: ImmInstruction {
                is_valid,
                opcode: expected_opcode,
                immediate: cols.imm.into(),
            }
            .into(),
        }
    }

    /// Global opcode offset of the opcodes handled by this AIR.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
135
/// Values captured by `execute_instruction` for one BEQ/BNE instruction and later written
/// into a trace row by `generate_trace_row`.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BranchEqualCoreRecord<T, const NUM_LIMBS: usize> {
    /// First operand limbs, taken from the instruction's reads.
    // `BigArray` lets serde handle arrays whose length is the const generic `NUM_LIMBS`.
    #[serde(with = "BigArray")]
    pub a: [T; NUM_LIMBS],
    /// Second operand limbs, taken from the instruction's reads.
    #[serde(with = "BigArray")]
    pub b: [T; NUM_LIMBS],
    /// 1 if the branch is taken, 0 otherwise.
    pub cmp_result: T,
    /// Branch offset (the instruction's `c` operand).
    pub imm: T,
    /// Inverse of `a[diff_idx] - b[diff_idx]` when the operands differ; 0 when they are equal.
    pub diff_inv_val: T,
    /// Index of the first limb where `a` and `b` differ; 0 when they are equal.
    pub diff_idx: usize,
    /// Which of BEQ / BNE was executed.
    pub opcode: BranchEqualOpcode,
}
149
/// Execution and trace-generation chip for the BEQ/BNE core; owns the corresponding AIR.
#[derive(Debug)]
pub struct BranchEqualCoreChip<const NUM_LIMBS: usize> {
    pub air: BranchEqualCoreAir<NUM_LIMBS>,
}
154
155impl<const NUM_LIMBS: usize> BranchEqualCoreChip<NUM_LIMBS> {
156 pub fn new(offset: usize, pc_step: u32) -> Self {
157 Self {
158 air: BranchEqualCoreAir { offset, pc_step },
159 }
160 }
161}
162
163impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_LIMBS: usize> VmCoreChip<F, I>
164 for BranchEqualCoreChip<NUM_LIMBS>
165where
166 I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
167 I::Writes: Default,
168{
169 type Record = BranchEqualCoreRecord<F, NUM_LIMBS>;
170 type Air = BranchEqualCoreAir<NUM_LIMBS>;
171
172 #[allow(clippy::type_complexity)]
173 fn execute_instruction(
174 &self,
175 instruction: &Instruction<F>,
176 from_pc: u32,
177 reads: I::Reads,
178 ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
179 let Instruction { opcode, c: imm, .. } = *instruction;
180 let branch_eq_opcode =
181 BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));
182
183 let data: [[F; NUM_LIMBS]; 2] = reads.into();
184 let x = data[0].map(|x| x.as_canonical_u32());
185 let y = data[1].map(|y| y.as_canonical_u32());
186 let (cmp_result, diff_idx, diff_inv_val) = run_eq::<F, NUM_LIMBS>(branch_eq_opcode, &x, &y);
187
188 let output = AdapterRuntimeContext {
189 to_pc: cmp_result.then_some((F::from_canonical_u32(from_pc) + imm).as_canonical_u32()),
190 writes: Default::default(),
191 };
192 let record = BranchEqualCoreRecord {
193 opcode: branch_eq_opcode,
194 a: data[0],
195 b: data[1],
196 cmp_result: F::from_bool(cmp_result),
197 imm,
198 diff_idx,
199 diff_inv_val,
200 };
201
202 Ok((output, record))
203 }
204
205 fn get_opcode_name(&self, opcode: usize) -> String {
206 format!(
207 "{:?}",
208 BranchEqualOpcode::from_usize(opcode - self.air.offset)
209 )
210 }
211
212 fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
213 let row_slice: &mut BranchEqualCoreCols<_, NUM_LIMBS> = row_slice.borrow_mut();
214 row_slice.a = record.a;
215 row_slice.b = record.b;
216 row_slice.cmp_result = record.cmp_result;
217 row_slice.imm = record.imm;
218 row_slice.opcode_beq_flag = F::from_bool(record.opcode == BranchEqualOpcode::BEQ);
219 row_slice.opcode_bne_flag = F::from_bool(record.opcode == BranchEqualOpcode::BNE);
220 row_slice.diff_inv_marker = array::from_fn(|i| {
221 if i == record.diff_idx {
222 record.diff_inv_val
223 } else {
224 F::ZERO
225 }
226 });
227 }
228
229 fn air(&self) -> &Self::Air {
230 &self.air
231 }
232}
233
234pub(super) fn run_eq<F: PrimeField32, const NUM_LIMBS: usize>(
236 local_opcode: BranchEqualOpcode,
237 x: &[u32; NUM_LIMBS],
238 y: &[u32; NUM_LIMBS],
239) -> (bool, usize, F) {
240 for i in 0..NUM_LIMBS {
241 if x[i] != y[i] {
242 return (
243 local_opcode == BranchEqualOpcode::BNE,
244 i,
245 (F::from_canonical_u32(x[i]) - F::from_canonical_u32(y[i])).inverse(),
246 );
247 }
248 }
249 (local_opcode == BranchEqualOpcode::BEQ, 0, F::ZERO)
250}