1use std::{
2 array::{self, from_fn},
3 borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
8use openvm_circuit::arch::{
9 AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
10 VmCoreAir, VmCoreChip,
11};
12use openvm_circuit_primitives::{
13 bigint::utils::big_uint_to_limbs,
14 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
15 is_equal_array::{IsEqArrayIo, IsEqArraySubAir},
16 SubAir, TraceSubRowGenerator,
17};
18use openvm_circuit_primitives_derive::AlignedBorrow;
19use openvm_instructions::{instruction::Instruction, LocalOpcode};
20use openvm_stark_backend::{
21 interaction::InteractionBuilder,
22 p3_air::{AirBuilder, BaseAir},
23 p3_field::{Field, FieldAlgebra, PrimeField32},
24 rap::BaseAirWithPublicValues,
25};
26use serde::{Deserialize, Serialize};
27use serde_big_array::BigArray;
/// Trace columns for the core of the modular `IS_EQ` / `SETUP_ISEQ` instructions.
///
/// The struct is `#[repr(C)]` and is borrowed directly from a row slice, so the field order
/// below is the column order of the core's trace row.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct ModularIsEqualCoreCols<T, const READ_LIMBS: usize> {
    // 1 on rows that carry an instruction.
    pub is_valid: T,
    // 1 for `SETUP_ISEQ`, 0 for `IS_EQ`; the AIR requires is_setup => is_valid.
    pub is_setup: T,
    // The two operands, `READ_LIMBS` limbs each.
    pub b: [T; READ_LIMBS],
    pub c: [T; READ_LIMBS],
    // Boolean result; written to limb 0 of the instruction's output.
    pub cmp_result: T,

    // Auxiliary columns for `IsEqArraySubAir`, which compares b with c to produce `cmp_result`.
    pub eq_marker: [T; READ_LIMBS],

    // Auxiliary columns proving b < N and c < N, where N is the modulus.
    // Let b_diff_idx be the highest limb index at which b differs from N (at that limb
    // b[b_diff_idx] < N[b_diff_idx] is required), and define c_diff_idx likewise for c. Then
    //   lt_marker[b_diff_idx] = 1, lt_marker[c_diff_idx] = c_lt_mark, and 0 elsewhere,
    // where c_lt_mark = 1 if b_diff_idx == c_diff_idx and 2 otherwise (one entry cannot carry
    // both markers at once). On setup rows the AIR forces b == N, so lt_marker holds only a
    // single 2 (at c_diff_idx).
    pub lt_marker: [T; READ_LIMBS],
    // N[b_diff_idx] - b[b_diff_idx]; `b_lt_diff - 1` is range checked on IS_EQ rows.
    pub b_lt_diff: T,
    // N[c_diff_idx] - c[c_diff_idx]; `c_lt_diff - 1` is range checked on IS_EQ rows.
    pub c_lt_diff: T,
    // 1 or 2 as described above; forced to 2 on setup rows.
    pub c_lt_mark: T,
}
60
/// AIR for the core of the modular `IS_EQ` / `SETUP_ISEQ` instructions.
///
/// `READ_LIMBS` is the limb count of each operand, `WRITE_LIMBS` the limb count of the single
/// write (only limb 0 carries the result), and `LIMB_BITS` the bit width of one limb.
#[derive(Clone, Debug)]
pub struct ModularIsEqualCoreAir<
    const READ_LIMBS: usize,
    const WRITE_LIMBS: usize,
    const LIMB_BITS: usize,
> {
    // Bitwise lookup bus used to range check `b_lt_diff - 1` and `c_lt_diff - 1`.
    pub bus: BitwiseOperationLookupBus,
    // Sub-AIR that ties `cmp_result` to the comparison of b and c.
    pub subair: IsEqArraySubAir<READ_LIMBS>,
    // Modulus split into `LIMB_BITS`-bit limbs and zero-padded to `READ_LIMBS`; higher
    // indices are treated as more significant by the comparisons in this file.
    pub modulus_limbs: [u32; READ_LIMBS],
    // Offset added to the local `Rv32ModularArithmeticOpcode` values to form global opcodes.
    pub offset: usize,
}
72
73impl<const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
74 ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
75{
76 pub fn new(modulus: BigUint, bus: BitwiseOperationLookupBus, offset: usize) -> Self {
77 let mod_vec = big_uint_to_limbs(&modulus, LIMB_BITS);
78 assert!(mod_vec.len() <= READ_LIMBS);
79 let modulus_limbs = array::from_fn(|i| {
80 if i < mod_vec.len() {
81 mod_vec[i] as u32
82 } else {
83 0
84 }
85 });
86 Self {
87 bus,
88 subair: IsEqArraySubAir::<READ_LIMBS>,
89 modulus_limbs,
90 offset,
91 }
92 }
93}
94
95impl<F: Field, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
96 for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
97{
98 fn width(&self) -> usize {
99 ModularIsEqualCoreCols::<F, READ_LIMBS>::width()
100 }
101}
102impl<F: Field, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
103 BaseAirWithPublicValues<F> for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
104{
105}
106
impl<AB, I, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
    VmCoreAir<AB, I> for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; READ_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; WRITE_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Constrains one core row of `IS_EQ` / `SETUP_ISEQ`.
    ///
    /// * `IS_EQ` rows (`is_valid - is_setup = 1`): `cmp_result` is tied to the comparison of
    ///   `b` and `c` through `IsEqArraySubAir`, and both operands are shown to be strictly
    ///   below the modulus N using `lt_marker` plus a range check on the limb differences.
    /// * `SETUP_ISEQ` rows: `b` is forced to equal N limb-by-limb. The equality sub-AIR and the
    ///   range-check lookup are disabled there, so `cmp_result` is only constrained to be
    ///   boolean and `c < N` is not enforced.
    ///
    /// Returns the adapter context: reads `[b, c]`, one write whose limb 0 is `cmp_result`
    /// (other limbs zero), and the opcode selected by `is_setup`.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &ModularIsEqualCoreCols<_, READ_LIMBS> = local_core.borrow();

        // Flags are boolean, and every setup row must also be a valid row.
        builder.assert_bool(cols.is_valid);
        builder.assert_bool(cols.is_setup);
        builder.when(cols.is_setup).assert_one(cols.is_valid);
        builder.assert_bool(cols.cmp_result);

        // Compare b and c via IsEqArraySubAir with output `cmp_result`. The condition
        // `is_valid - is_setup` is 1 only on IS_EQ rows, so setup rows skip this check.
        let eq_subair_io = IsEqArrayIo {
            x: cols.b.map(Into::into),
            y: cols.c.map(Into::into),
            out: cols.cmp_result.into(),
            condition: cols.is_valid - cols.is_setup,
        };
        self.subair.eval(builder, (eq_subair_io, cols.eq_marker));

        // Shape of lt_marker. Each entry is 0, 1 or c_lt_mark (enforced in the loop below).
        //  * c_lt_mark = 1: the entries sum to 1, i.e. a single 1.
        //  * c_lt_mark = 2: the entries sum to 3 and sum(x * (x - 1)) = 2. Since x * (x - 1)
        //    is 0 for x in {0, 1} and 2 for x = 2, that is exactly one 2 and one 1.
        let lt_marker_sum = cols
            .lt_marker
            .iter()
            .fold(AB::Expr::ZERO, |acc, x| acc + *x);
        let lt_marker_one_check_sum = cols
            .lt_marker
            .iter()
            .fold(AB::Expr::ZERO, |acc, x| acc + (*x) * (*x - AB::F::ONE));

        // On IS_EQ rows, c_lt_mark is 1 or 2.
        builder
            .when(cols.is_valid - cols.is_setup)
            .assert_bool(cols.c_lt_mark - AB::F::ONE);

        // IS_EQ rows with c_lt_mark = 1: lt_marker sums to 1.
        builder
            .when(cols.is_valid - cols.is_setup)
            .when_ne(cols.c_lt_mark, AB::F::from_canonical_u8(2))
            .assert_one(lt_marker_sum.clone());

        // IS_EQ rows with c_lt_mark = 2: lt_marker sums to 3.
        builder
            .when(cols.is_valid - cols.is_setup)
            .when_ne(cols.c_lt_mark, AB::F::ONE)
            .assert_eq(lt_marker_sum.clone(), AB::F::from_canonical_u8(3));

        // Whenever c_lt_mark != 1 (IS_EQ rows with c_lt_mark = 2, and all setup rows),
        // sum(x * (x - 1)) = 2 * is_valid. With entries in {0, 1, 2} this means a valid row
        // carries exactly one 2.
        builder.when_ne(cols.c_lt_mark, AB::F::ONE).assert_eq(
            lt_marker_one_check_sum,
            cols.is_valid * AB::F::from_canonical_u8(2),
        );

        // Setup rows: c_lt_mark = 2 and lt_marker sums to 2, so lt_marker holds a single 2 and
        // no 1. With no 1 present, the loop below ends up forcing b[i] = N[i] for every limb.
        builder
            .when(cols.is_setup)
            .assert_eq(cols.c_lt_mark, AB::F::from_canonical_u8(2));
        builder
            .when(cols.is_setup)
            .assert_eq(lt_marker_sum.clone(), AB::F::from_canonical_u8(2));

        // Walk the limbs from the highest index down. After adding lt_marker[i], prefix_sum is
        // the sum of lt_marker[j] for j >= i, which tells whether limb i lies above
        // b_diff_idx and/or c_diff_idx.
        let modulus = self.modulus_limbs.map(AB::F::from_canonical_u32);
        let mut prefix_sum = AB::Expr::ZERO;

        for i in (0..READ_LIMBS).rev() {
            prefix_sum += cols.lt_marker[i].into();
            // lt_marker[i] is 0, 1 or c_lt_mark.
            builder.assert_zero(
                cols.lt_marker[i]
                    * (cols.lt_marker[i] - AB::F::ONE)
                    * (cols.lt_marker[i] - cols.c_lt_mark),
            );

            // b[i] = N[i] for every limb above b_diff_idx. By casework, i > b_diff_idx exactly
            // when prefix_sum is neither 1 nor lt_marker_sum - is_setup:
            //  * IS_EQ, c_lt_mark = 1 (sum 1): prefix_sum is 0 above the 1, and 1 from it on.
            //  * IS_EQ, c_lt_mark = 2 (sum 3): prefix_sum is 0 or 2 above the 1, and 1 or 3
            //    from it on.
            //  * setup (sum 2, no 1): both conditions read "prefix_sum != 1", and prefix_sum is
            //    only ever 0 or 2, so every limb is constrained and b = N.
            builder
                .when_ne(prefix_sum.clone(), AB::F::ONE)
                .when_ne(prefix_sum.clone(), lt_marker_sum.clone() - cols.is_setup)
                .assert_eq(cols.b[i], modulus[i]);
            // At the limb marked 1 (lt_marker[i] not 0 and not 2, i.e. i == b_diff_idx):
            // b_lt_diff = N[i] - b[i].
            builder
                .when_ne(cols.lt_marker[i], AB::F::ZERO)
                .when_ne(cols.lt_marker[i], AB::F::from_canonical_u8(2))
                .assert_eq(AB::Expr::from(modulus[i]) - cols.b[i], cols.b_lt_diff);

            // c[i] = N[i] for every limb above c_diff_idx. By similar casework, i > c_diff_idx
            // exactly when prefix_sum is neither c_lt_mark nor lt_marker_sum.
            builder
                .when_ne(prefix_sum.clone(), cols.c_lt_mark)
                .when_ne(prefix_sum.clone(), lt_marker_sum.clone())
                .assert_eq(cols.c[i], modulus[i]);
            // At the limb marked c_lt_mark (i == c_diff_idx): c_lt_diff = N[i] - c[i]. With
            // c_lt_mark in {1, 2}, removing {0, 3 - c_lt_mark} from {0, 1, 2} leaves exactly
            // {c_lt_mark}.
            builder
                .when_ne(cols.lt_marker[i], AB::F::ZERO)
                .when_ne(
                    cols.lt_marker[i],
                    AB::Expr::from_canonical_u8(3) - cols.c_lt_mark,
                )
                .assert_eq(AB::Expr::from(modulus[i]) - cols.c[i], cols.c_lt_diff);
        }

        // On IS_EQ rows, range check b_lt_diff - 1 and c_lt_diff - 1 through the bitwise lookup
        // bus (presumably into [0, 2^LIMB_BITS)), which makes both differences positive and
        // hence b < N and c < N. Nothing is sent on setup rows.
        self.bus
            .send_range(
                cols.b_lt_diff - AB::Expr::ONE,
                cols.c_lt_diff - AB::Expr::ONE,
            )
            .eval(builder, cols.is_valid - cols.is_setup);

        // Opcode is offset + SETUP_ISEQ on setup rows and offset + IS_EQ otherwise.
        let expected_opcode = AB::Expr::from_canonical_usize(self.offset)
            + cols.is_setup
                * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize)
            + (AB::Expr::ONE - cols.is_setup)
                * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::IS_EQ as usize);
        // The result occupies limb 0 of the write; all other limbs are zero.
        let mut a: [AB::Expr; WRITE_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        a[0] = cols.cmp_result.into();

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [a].into(),
            instruction: MinimalInstruction {
                is_valid: cols.is_valid.into(),
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    /// Opcode offset this AIR was constructed with.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
278}
279
/// Per-instruction data captured by `execute_instruction` and later consumed by
/// `generate_trace_row`.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ModularIsEqualCoreRecord<T, const READ_LIMBS: usize> {
    // Operand limbs exactly as read by the adapter.
    #[serde(with = "BigArray")]
    pub b: [T; READ_LIMBS],
    #[serde(with = "BigArray")]
    pub c: [T; READ_LIMBS],
    // Result computed by the IsEqArray sub-row generator.
    pub cmp_result: T,
    // Auxiliary equality columns produced together with `cmp_result`.
    #[serde(with = "BigArray")]
    pub eq_marker: [T; READ_LIMBS],
    // Highest limb index at which b differs from the modulus, as returned by
    // `run_unsigned_less_than`; equals READ_LIMBS when b is exactly the modulus (the expected
    // case on setup rows).
    pub b_diff_idx: usize,
    // Same as `b_diff_idx`, but for c.
    pub c_diff_idx: usize,
    // Whether the executed opcode was SETUP_ISEQ.
    pub is_setup: bool,
}
294
/// Runtime chip for the modular `IS_EQ` / `SETUP_ISEQ` core: executes instructions and fills
/// the corresponding trace rows.
pub struct ModularIsEqualCoreChip<
    const READ_LIMBS: usize,
    const WRITE_LIMBS: usize,
    const LIMB_BITS: usize,
> {
    // The AIR; also supplies the modulus limbs and opcode offset used during execution.
    pub air: ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>,
    // Shared lookup chip that receives the range-check requests for `b_lt_diff - 1` and
    // `c_lt_diff - 1` on IS_EQ instructions.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
}
303
304impl<const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
305 ModularIsEqualCoreChip<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
306{
307 pub fn new(
308 modulus: BigUint,
309 bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
310 offset: usize,
311 ) -> Self {
312 Self {
313 air: ModularIsEqualCoreAir::new(modulus, bitwise_lookup_chip.bus(), offset),
314 bitwise_lookup_chip,
315 }
316 }
317}
318
319impl<
320 F: PrimeField32,
321 I: VmAdapterInterface<F>,
322 const READ_LIMBS: usize,
323 const WRITE_LIMBS: usize,
324 const LIMB_BITS: usize,
325 > VmCoreChip<F, I> for ModularIsEqualCoreChip<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
326where
327 I::Reads: Into<[[F; READ_LIMBS]; 2]>,
328 I::Writes: From<[[F; WRITE_LIMBS]; 1]>,
329{
330 type Record = ModularIsEqualCoreRecord<F, READ_LIMBS>;
331 type Air = ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>;
332
333 #[allow(clippy::type_complexity)]
334 fn execute_instruction(
335 &self,
336 instruction: &Instruction<F>,
337 _from_pc: u32,
338 reads: I::Reads,
339 ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
340 let data: [[F; READ_LIMBS]; 2] = reads.into();
341 let b = data[0].map(|x| x.as_canonical_u32());
342 let c = data[1].map(|y| y.as_canonical_u32());
343 let (b_cmp, b_diff_idx) = run_unsigned_less_than::<READ_LIMBS>(&b, &self.air.modulus_limbs);
344 let (c_cmp, c_diff_idx) = run_unsigned_less_than::<READ_LIMBS>(&c, &self.air.modulus_limbs);
345 let is_setup = instruction.opcode.local_opcode_idx(self.air.offset)
346 == Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize;
347
348 if !is_setup {
349 assert!(b_cmp, "{:?} >= {:?}", b, self.air.modulus_limbs);
350 }
351 assert!(c_cmp, "{:?} >= {:?}", c, self.air.modulus_limbs);
352 if !is_setup {
353 self.bitwise_lookup_chip.request_range(
354 self.air.modulus_limbs[b_diff_idx] - b[b_diff_idx] - 1,
355 self.air.modulus_limbs[c_diff_idx] - c[c_diff_idx] - 1,
356 );
357 }
358
359 let mut eq_marker = [F::ZERO; READ_LIMBS];
360 let mut cmp_result = F::ZERO;
361 self.air
362 .subair
363 .generate_subrow((&data[0], &data[1]), (&mut eq_marker, &mut cmp_result));
364
365 let mut writes = [F::ZERO; WRITE_LIMBS];
366 writes[0] = cmp_result;
367
368 let output = AdapterRuntimeContext::without_pc([writes]);
369 let record = ModularIsEqualCoreRecord {
370 is_setup,
371 b: data[0],
372 c: data[1],
373 cmp_result,
374 eq_marker,
375 b_diff_idx,
376 c_diff_idx,
377 };
378
379 Ok((output, record))
380 }
381
382 fn get_opcode_name(&self, opcode: usize) -> String {
383 format!(
384 "{:?}",
385 Rv32ModularArithmeticOpcode::from_usize(opcode - self.air.offset)
386 )
387 }
388
389 fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
390 let row_slice: &mut ModularIsEqualCoreCols<_, READ_LIMBS> = row_slice.borrow_mut();
391 row_slice.is_valid = F::ONE;
392 row_slice.is_setup = F::from_bool(record.is_setup);
393 row_slice.b = record.b;
394 row_slice.c = record.c;
395 row_slice.cmp_result = record.cmp_result;
396
397 row_slice.eq_marker = record.eq_marker;
398
399 if !record.is_setup {
400 row_slice.b_lt_diff = F::from_canonical_u32(self.air.modulus_limbs[record.b_diff_idx])
401 - record.b[record.b_diff_idx];
402 }
403 row_slice.c_lt_diff = F::from_canonical_u32(self.air.modulus_limbs[record.c_diff_idx])
404 - record.c[record.c_diff_idx];
405 row_slice.c_lt_mark = if record.b_diff_idx == record.c_diff_idx {
406 F::ONE
407 } else {
408 F::from_canonical_u8(2)
409 };
410 row_slice.lt_marker = from_fn(|i| {
411 if i == record.b_diff_idx {
412 F::ONE
413 } else if i == record.c_diff_idx {
414 row_slice.c_lt_mark
415 } else {
416 F::ZERO
417 }
418 });
419 }
420
421 fn air(&self) -> &Self::Air {
422 &self.air
423 }
424}
425
426pub(super) fn run_unsigned_less_than<const NUM_LIMBS: usize>(
428 x: &[u32; NUM_LIMBS],
429 y: &[u32; NUM_LIMBS],
430) -> (bool, usize) {
431 for i in (0..NUM_LIMBS).rev() {
432 if x[i] != y[i] {
433 return (x[i] < y[i], i);
434 }
435 }
436 (false, NUM_LIMBS)
437}