1use std::{
2 array::{self, from_fn},
3 borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
8use openvm_circuit::arch::{
9 AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
10 VmCoreAir, VmCoreChip,
11};
12use openvm_circuit_primitives::{
13 bigint::utils::big_uint_to_limbs,
14 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
15 is_equal_array::{IsEqArrayIo, IsEqArraySubAir},
16 SubAir, TraceSubRowGenerator,
17};
18use openvm_circuit_primitives_derive::AlignedBorrow;
19use openvm_instructions::{instruction::Instruction, LocalOpcode};
20use openvm_stark_backend::{
21 interaction::InteractionBuilder,
22 p3_air::{AirBuilder, BaseAir},
23 p3_field::{Field, FieldAlgebra, PrimeField32},
24 rap::BaseAirWithPublicValues,
25};
26use serde::{Deserialize, Serialize};
27use serde_big_array::BigArray;
/// Trace columns for the modular `IS_EQ` / `SETUP_ISEQ` core.
///
/// Besides the operands and the boolean result, the row carries a witness that each operand is
/// strictly below the modulus: `lt_marker` locates the most significant limb where the operand
/// differs from the modulus, and `b_lt_diff` / `c_lt_diff` hold the (positive) limb difference
/// at that position. See `ModularIsEqualCoreAir::eval` for the exact constraints.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct ModularIsEqualCoreCols<T, const READ_LIMBS: usize> {
    /// Boolean; 1 on rows that execute an instruction.
    pub is_valid: T,
    /// Boolean; 1 on `SETUP_ISEQ` rows. Constrained to imply `is_valid = 1`.
    pub is_setup: T,
    /// First operand, one field element per limb.
    pub b: [T; READ_LIMBS],
    /// Second operand, one field element per limb.
    pub c: [T; READ_LIMBS],
    /// Boolean result of `b == c`; the equality check is only enforced when
    /// `is_valid - is_setup = 1`.
    pub cmp_result: T,

    /// Auxiliary witness consumed by `IsEqArraySubAir` for the `b == c` check.
    pub eq_marker: [T; READ_LIMBS],

    /// Per-limb marker: 1 at `b_diff_idx` (most significant limb where `b` differs from the
    /// modulus), `c_lt_mark` at `c_diff_idx` (same for `c`), and 0 elsewhere. On setup rows
    /// `b` equals the modulus, so there is no 1 and only the `c` entry (value 2) is set.
    pub lt_marker: [T; READ_LIMBS],
    /// `modulus[b_diff_idx] - b[b_diff_idx]`; `b_lt_diff - 1` is range-checked on non-setup rows.
    pub b_lt_diff: T,
    /// `modulus[c_diff_idx] - c[c_diff_idx]`; `c_lt_diff - 1` is range-checked on non-setup rows.
    pub c_lt_diff: T,
    /// 1 if `b_diff_idx == c_diff_idx`, otherwise 2 (forced to 2 on setup rows).
    pub c_lt_mark: T,
}
60
/// AIR for the modular equality core: proves `cmp_result == (b == c)` together with
/// `b, c < modulus` on `IS_EQ` rows, and `b == modulus` on `SETUP_ISEQ` rows.
#[derive(Clone, Debug)]
pub struct ModularIsEqualCoreAir<
    const READ_LIMBS: usize,
    const WRITE_LIMBS: usize,
    const LIMB_BITS: usize,
> {
    /// Bitwise lookup bus used to range-check `b_lt_diff - 1` and `c_lt_diff - 1`.
    pub bus: BitwiseOperationLookupBus,
    /// Sub-AIR that ties `cmp_result` to the limb-wise equality of `b` and `c`.
    pub subair: IsEqArraySubAir<READ_LIMBS>,
    /// Modulus split into `LIMB_BITS`-bit limbs, zero-padded to `READ_LIMBS` entries.
    pub modulus_limbs: [u32; READ_LIMBS],
    /// Opcode offset; global opcode = `offset` + local opcode index.
    pub offset: usize,
}
72
73impl<const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
74 ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
75{
76 pub fn new(modulus: BigUint, bus: BitwiseOperationLookupBus, offset: usize) -> Self {
77 let mod_vec = big_uint_to_limbs(&modulus, LIMB_BITS);
78 assert!(mod_vec.len() <= READ_LIMBS);
79 let modulus_limbs = array::from_fn(|i| {
80 if i < mod_vec.len() {
81 mod_vec[i] as u32
82 } else {
83 0
84 }
85 });
86 Self {
87 bus,
88 subair: IsEqArraySubAir::<READ_LIMBS>,
89 modulus_limbs,
90 offset,
91 }
92 }
93}
94
95impl<F: Field, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
96 for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
97{
98 fn width(&self) -> usize {
99 ModularIsEqualCoreCols::<F, READ_LIMBS>::width()
100 }
101}
/// Uses the trait's default methods unchanged; this core AIR's `eval` never reads public values.
impl<F: Field, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
    BaseAirWithPublicValues<F> for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
{
}
106
/// Constraints for the modular `IS_EQ` / `SETUP_ISEQ` core.
///
/// On an `IS_EQ` row (`is_valid = 1`, `is_setup = 0`) the AIR proves that `cmp_result` is the
/// boolean `b == c`, and that both `b < N` and `c < N`, where `N` is `modulus_limbs`.
///
/// On a `SETUP_ISEQ` row (`is_setup = 1`) the equality check is disabled and `b` is constrained
/// to equal `N` in every limb. The bus range check is also disabled on setup rows, so
/// `c_lt_diff` is not range-checked there.
impl<AB, I, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
    VmCoreAir<AB, I> for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; READ_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; WRITE_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Emits all core constraints for one row and returns the adapter context.
    ///
    /// The returned context reads `[b, c]`, writes `[cmp_result, 0, ..., 0]`, and selects the
    /// opcode `offset + SETUP_ISEQ` or `offset + IS_EQ` depending on `is_setup`.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &ModularIsEqualCoreCols<_, READ_LIMBS> = local_core.borrow();

        builder.assert_bool(cols.is_valid);
        builder.assert_bool(cols.is_setup);
        // A setup row is always a valid row.
        builder.when(cols.is_setup).assert_one(cols.is_valid);
        builder.assert_bool(cols.cmp_result);

        // cmp_result == (b == c), enforced only on IS_EQ rows (is_valid - is_setup == 1).
        let eq_subair_io = IsEqArrayIo {
            x: cols.b.map(Into::into),
            y: cols.c.map(Into::into),
            out: cols.cmp_result.into(),
            condition: cols.is_valid - cols.is_setup,
        };
        self.subair.eval(builder, (eq_subair_io, cols.eq_marker));

        // Shape of lt_marker (each entry is 0, 1 or c_lt_mark; see the loop below):
        //  * c_lt_mark = 1: exactly one entry is 1 (b and c first differ from N at the same limb).
        //  * c_lt_mark = 2: exactly one entry is 1 (b's limb) and one entry is 2 (c's limb).
        //  * setup: exactly one entry is 2 and there is no 1 (b == N).
        // lt_marker_sum = sum(x); lt_marker_one_check_sum = sum(x * (x - 1)), which counts
        // 2 for every entry equal to 2 and 0 for entries equal to 0 or 1.
        let lt_marker_sum = cols
            .lt_marker
            .iter()
            .fold(AB::Expr::ZERO, |acc, x| acc + *x);
        let lt_marker_one_check_sum = cols
            .lt_marker
            .iter()
            .fold(AB::Expr::ZERO, |acc, x| acc + (*x) * (*x - AB::F::ONE));

        // On IS_EQ rows, c_lt_mark is 1 or 2.
        builder
            .when(cols.is_valid - cols.is_setup)
            .assert_bool(cols.c_lt_mark - AB::F::ONE);

        // On IS_EQ rows with c_lt_mark = 1, lt_marker sums to 1.
        builder
            .when(cols.is_valid - cols.is_setup)
            .when_ne(cols.c_lt_mark, AB::F::from_canonical_u8(2))
            .assert_one(lt_marker_sum.clone());

        // On IS_EQ rows with c_lt_mark = 2, lt_marker sums to 3 (one 1 plus one 2).
        builder
            .when(cols.is_valid - cols.is_setup)
            .when_ne(cols.c_lt_mark, AB::F::ONE)
            .assert_eq(lt_marker_sum.clone(), AB::F::from_canonical_u8(3));

        // Whenever c_lt_mark != 1, exactly one entry equals 2 (given entries are in {0,1,2}).
        // The right-hand side is scaled by is_valid, so it is 0 on rows with is_valid = 0.
        builder.when_ne(cols.c_lt_mark, AB::F::ONE).assert_eq(
            lt_marker_one_check_sum,
            cols.is_valid * AB::F::from_canonical_u8(2),
        );

        // Setup rows: c_lt_mark = 2 and lt_marker sums to 2, i.e. a single 2 and no 1. With no
        // 1 present, the b-prefix check below constrains b[i] = N[i] for every limb.
        builder
            .when(cols.is_setup)
            .assert_eq(cols.c_lt_mark, AB::F::from_canonical_u8(2));
        builder
            .when(cols.is_setup)
            .assert_eq(lt_marker_sum.clone(), AB::F::from_canonical_u8(2));

        // Constrain b < N and c < N, scanning limbs from most to least significant.
        let modulus = self.modulus_limbs.map(AB::F::from_canonical_u32);
        // Sum of lt_marker[j] for j >= i (accumulated as the loop walks downward).
        let mut prefix_sum = AB::Expr::ZERO;

        for i in (0..READ_LIMBS).rev() {
            prefix_sum += cols.lt_marker[i].into();
            // lt_marker[i] is 0, 1, or c_lt_mark.
            builder.assert_zero(
                cols.lt_marker[i]
                    * (cols.lt_marker[i] - AB::F::ONE)
                    * (cols.lt_marker[i] - cols.c_lt_mark),
            );

            // b[i] = N[i] for every limb above b_diff_idx. Casework:
            //  - IS_EQ: lt_marker_sum is 1 or 3 and prefix_sum takes values in {0, 1, 2, 3};
            //    i > b_diff_idx exactly when prefix_sum is neither 1 nor lt_marker_sum.
            //  - setup: lt_marker_sum - is_setup = 1 and prefix_sum is 0 or 2, never 1, so
            //    b[i] = N[i] is enforced for every limb.
            builder
                .when_ne(prefix_sum.clone(), AB::F::ONE)
                .when_ne(prefix_sum.clone(), lt_marker_sum.clone() - cols.is_setup)
                .assert_eq(cols.b[i], modulus[i]);
            // lt_marker[i] == 1 marks b_diff_idx: there b_lt_diff = N[i] - b[i].
            builder
                .when_ne(cols.lt_marker[i], AB::F::ZERO)
                .when_ne(cols.lt_marker[i], AB::F::from_canonical_u8(2))
                .assert_eq(AB::Expr::from(modulus[i]) - cols.b[i], cols.b_lt_diff);

            // c[i] = N[i] for every limb above c_diff_idx: by casework, i > c_diff_idx exactly
            // when prefix_sum is neither c_lt_mark nor lt_marker_sum.
            builder
                .when_ne(prefix_sum.clone(), cols.c_lt_mark)
                .when_ne(prefix_sum.clone(), lt_marker_sum.clone())
                .assert_eq(cols.c[i], modulus[i]);
            // lt_marker[i] == c_lt_mark marks c_diff_idx. Since c_lt_mark is 1 or 2,
            // {0, 1, 2} \ {0, 3 - c_lt_mark} = {c_lt_mark}; there c_lt_diff = N[i] - c[i].
            builder
                .when_ne(cols.lt_marker[i], AB::F::ZERO)
                .when_ne(
                    cols.lt_marker[i],
                    AB::Expr::from_canonical_u8(3) - cols.c_lt_mark,
                )
                .assert_eq(AB::Expr::from(modulus[i]) - cols.c[i], cols.c_lt_diff);
        }

        // On IS_EQ rows, range-check b_lt_diff - 1 and c_lt_diff - 1 on the bitwise lookup bus,
        // making both differences positive (range presumably [0, 2^LIMB_BITS) — confirm
        // against BitwiseOperationLookupBus::send_range).
        self.bus
            .send_range(
                cols.b_lt_diff - AB::Expr::ONE,
                cols.c_lt_diff - AB::Expr::ONE,
            )
            .eval(builder, cols.is_valid - cols.is_setup);

        // offset + SETUP_ISEQ when is_setup = 1, otherwise offset + IS_EQ.
        let expected_opcode = AB::Expr::from_canonical_usize(self.offset)
            + cols.is_setup
                * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize)
            + (AB::Expr::ONE - cols.is_setup)
                * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::IS_EQ as usize);
        // Output: cmp_result in the lowest write limb, zeros elsewhere.
        let mut a: [AB::Expr; WRITE_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        a[0] = cols.cmp_result.into();

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [a].into(),
            instruction: MinimalInstruction {
                is_valid: cols.is_valid.into(),
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    /// Opcode offset of this chip's opcode family.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
273
/// Per-instruction data produced by `execute_instruction` and turned into a trace row by
/// `generate_trace_row`.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ModularIsEqualCoreRecord<T, const READ_LIMBS: usize> {
    /// First operand limbs as read by the adapter.
    #[serde(with = "BigArray")]
    pub b: [T; READ_LIMBS],
    /// Second operand limbs as read by the adapter.
    #[serde(with = "BigArray")]
    pub c: [T; READ_LIMBS],
    /// `b == c` as a field element (0 or 1), computed by the equality sub-AIR's row generator.
    pub cmp_result: T,
    /// Auxiliary equality witness produced alongside `cmp_result`.
    #[serde(with = "BigArray")]
    pub eq_marker: [T; READ_LIMBS],
    /// Index from `run_unsigned_less_than(b, modulus)`: the most significant limb where `b`
    /// differs from the modulus, or `READ_LIMBS` if `b` equals it.
    pub b_diff_idx: usize,
    /// Same as `b_diff_idx`, for `c`.
    pub c_diff_idx: usize,
    /// Whether the instruction was `SETUP_ISEQ`.
    pub is_setup: bool,
}
288
/// Execution and trace-generation counterpart of [`ModularIsEqualCoreAir`].
pub struct ModularIsEqualCoreChip<
    const READ_LIMBS: usize,
    const WRITE_LIMBS: usize,
    const LIMB_BITS: usize,
> {
    /// The AIR whose constraints the generated rows must satisfy.
    pub air: ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>,
    /// Shared lookup chip that receives the `*_lt_diff - 1` range-check requests.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
}
297
298impl<const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
299 ModularIsEqualCoreChip<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
300{
301 pub fn new(
302 modulus: BigUint,
303 bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
304 offset: usize,
305 ) -> Self {
306 Self {
307 air: ModularIsEqualCoreAir::new(modulus, bitwise_lookup_chip.bus(), offset),
308 bitwise_lookup_chip,
309 }
310 }
311}
312
313impl<
314 F: PrimeField32,
315 I: VmAdapterInterface<F>,
316 const READ_LIMBS: usize,
317 const WRITE_LIMBS: usize,
318 const LIMB_BITS: usize,
319 > VmCoreChip<F, I> for ModularIsEqualCoreChip<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
320where
321 I::Reads: Into<[[F; READ_LIMBS]; 2]>,
322 I::Writes: From<[[F; WRITE_LIMBS]; 1]>,
323{
324 type Record = ModularIsEqualCoreRecord<F, READ_LIMBS>;
325 type Air = ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>;
326
327 #[allow(clippy::type_complexity)]
328 fn execute_instruction(
329 &self,
330 instruction: &Instruction<F>,
331 _from_pc: u32,
332 reads: I::Reads,
333 ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
334 let data: [[F; READ_LIMBS]; 2] = reads.into();
335 let b = data[0].map(|x| x.as_canonical_u32());
336 let c = data[1].map(|y| y.as_canonical_u32());
337 let (b_cmp, b_diff_idx) = run_unsigned_less_than::<READ_LIMBS>(&b, &self.air.modulus_limbs);
338 let (c_cmp, c_diff_idx) = run_unsigned_less_than::<READ_LIMBS>(&c, &self.air.modulus_limbs);
339 let is_setup = instruction.opcode.local_opcode_idx(self.air.offset)
340 == Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize;
341
342 if !is_setup {
343 assert!(b_cmp, "{:?} >= {:?}", b, self.air.modulus_limbs);
344 }
345 assert!(c_cmp, "{:?} >= {:?}", c, self.air.modulus_limbs);
346 if !is_setup {
347 self.bitwise_lookup_chip.request_range(
348 self.air.modulus_limbs[b_diff_idx] - b[b_diff_idx] - 1,
349 self.air.modulus_limbs[c_diff_idx] - c[c_diff_idx] - 1,
350 );
351 }
352
353 let mut eq_marker = [F::ZERO; READ_LIMBS];
354 let mut cmp_result = F::ZERO;
355 self.air
356 .subair
357 .generate_subrow((&data[0], &data[1]), (&mut eq_marker, &mut cmp_result));
358
359 let mut writes = [F::ZERO; WRITE_LIMBS];
360 writes[0] = cmp_result;
361
362 let output = AdapterRuntimeContext::without_pc([writes]);
363 let record = ModularIsEqualCoreRecord {
364 is_setup,
365 b: data[0],
366 c: data[1],
367 cmp_result,
368 eq_marker,
369 b_diff_idx,
370 c_diff_idx,
371 };
372
373 Ok((output, record))
374 }
375
376 fn get_opcode_name(&self, opcode: usize) -> String {
377 format!(
378 "{:?}",
379 Rv32ModularArithmeticOpcode::from_usize(opcode - self.air.offset)
380 )
381 }
382
383 fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
384 let row_slice: &mut ModularIsEqualCoreCols<_, READ_LIMBS> = row_slice.borrow_mut();
385 row_slice.is_valid = F::ONE;
386 row_slice.is_setup = F::from_bool(record.is_setup);
387 row_slice.b = record.b;
388 row_slice.c = record.c;
389 row_slice.cmp_result = record.cmp_result;
390
391 row_slice.eq_marker = record.eq_marker;
392
393 if !record.is_setup {
394 row_slice.b_lt_diff = F::from_canonical_u32(self.air.modulus_limbs[record.b_diff_idx])
395 - record.b[record.b_diff_idx];
396 }
397 row_slice.c_lt_diff = F::from_canonical_u32(self.air.modulus_limbs[record.c_diff_idx])
398 - record.c[record.c_diff_idx];
399 row_slice.c_lt_mark = if record.b_diff_idx == record.c_diff_idx {
400 F::ONE
401 } else {
402 F::from_canonical_u8(2)
403 };
404 row_slice.lt_marker = from_fn(|i| {
405 if i == record.b_diff_idx {
406 F::ONE
407 } else if i == record.c_diff_idx {
408 row_slice.c_lt_mark
409 } else {
410 F::ZERO
411 }
412 });
413 }
414
415 fn air(&self) -> &Self::Air {
416 &self.air
417 }
418}
419
420pub(super) fn run_unsigned_less_than<const NUM_LIMBS: usize>(
422 x: &[u32; NUM_LIMBS],
423 y: &[u32; NUM_LIMBS],
424) -> (bool, usize) {
425 for i in (0..NUM_LIMBS).rev() {
426 if x[i] != y[i] {
427 return (x[i] < y[i], i);
428 }
429 }
430 (false, NUM_LIMBS)
431}