// openvm_rv32im_circuit/divrem/core.rs
1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use num_integer::Integer;
8use openvm_circuit::{
9    arch::*,
10    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
11};
12use openvm_circuit_primitives::{
13    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
14    range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip},
15    utils::{not, select},
16    AlignedBytesBorrow,
17};
18use openvm_circuit_primitives_derive::AlignedBorrow;
19use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
20use openvm_rv32im_transpiler::DivRemOpcode;
21use openvm_stark_backend::{
22    interaction::InteractionBuilder,
23    p3_air::{AirBuilder, BaseAir},
24    p3_field::{Field, FieldAlgebra, PrimeField32},
25    rap::BaseAirWithPublicValues,
26};
27use strum::IntoEnumIterator;
28
/// Trace columns for the DIV/DIVU/REM/REMU core chip. Layout must stay in sync with
/// `DivRemCoreRecord` written during execution (`repr(C)` + `AlignedBorrow` make the
/// field order part of the on-trace ABI, so do not reorder fields).
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct DivRemCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    // b = c * q + r for some 0 <= |r| < |c| and sign(r) = sign(b) or r = 0.
    pub b: [T; NUM_LIMBS], // dividend limbs
    pub c: [T; NUM_LIMBS], // divisor limbs
    pub q: [T; NUM_LIMBS], // quotient limbs
    pub r: [T; NUM_LIMBS], // remainder limbs

    // Flags to indicate special cases.
    pub zero_divisor: T, // 1 iff c == 0
    pub r_zero: T,       // 1 iff r == 0 and zero_divisor == 0 (covers signed overflow)

    // Sign of b and c respectively, while q_sign = b_sign ^ c_sign if q is non-zero
    // and is 0 otherwise. sign_xor = b_sign ^ c_sign always.
    pub b_sign: T,
    pub c_sign: T,
    pub q_sign: T,
    pub sign_xor: T,

    // Auxiliary columns to constrain that zero_divisor = 1 if and only if c = 0.
    pub c_sum_inv: T,
    // Auxiliary columns to constrain that r_zero = 1 if and only if r = 0 and zero_divisor = 0.
    pub r_sum_inv: T,

    // Auxiliary columns to constrain that 0 <= |r| < |c|. When sign_xor == 1 we have
    // r_prime = -r, and when sign_xor == 0 we have r_prime = r. Each r_inv[i] is the
    // field inverse of r_prime[i] - 2^LIMB_BITS, ensures each r_prime[i] is in range.
    pub r_prime: [T; NUM_LIMBS],
    pub r_inv: [T; NUM_LIMBS],
    pub lt_marker: [T; NUM_LIMBS], // one-hot marker of the most significant differing limb
    pub lt_diff: T,                // the (non-zero) limb difference at the marked index

    // Opcode flags (at most one is set; all zero on padding rows)
    pub opcode_div_flag: T,
    pub opcode_divu_flag: T,
    pub opcode_rem_flag: T,
    pub opcode_remu_flag: T,
}
68
/// AIR for the divrem core. Interacts with a bitwise-operation lookup (sign and
/// non-zero range checks) and a 2-tuple range checker ((limb, carry) pairs).
#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct DivRemCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    pub range_tuple_bus: RangeTupleCheckerBus<2>,
    // Global opcode offset for this instruction class.
    offset: usize,
}
75
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for DivRemCoreAir<NUM_LIMBS, LIMB_BITS>
{
    /// Trace width is exactly the column struct's width (from `AlignedBorrow`).
    fn width(&self) -> usize {
        DivRemCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
// This AIR exposes no public values; the default (empty) implementation suffices.
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for DivRemCoreAir<NUM_LIMBS, LIMB_BITS>
{
}
87
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for DivRemCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Constrains DIV/DIVU/REM/REMU over NUM_LIMBS-limb operands. Given the two reads
    /// b (dividend) and c (divisor), enforces b = c * q + r limb-wise (with sign
    /// extension for the signed opcodes), 0 <= |r| < |c| with sign(r) = sign(b), plus
    /// the division-by-zero and signed-overflow special cases. The write is q for
    /// DIV/DIVU and r for REM/REMU.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &DivRemCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [
            cols.opcode_div_flag,
            cols.opcode_divu_flag,
            cols.opcode_rem_flag,
            cols.opcode_remu_flag,
        ];

        // Each opcode flag is boolean and at most one is set; their sum is is_valid
        // (0 on padding rows).
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());

        let b = &cols.b;
        let c = &cols.c;
        let q = &cols.q;
        let r = &cols.r;

        // Constrain that b = (c * q + r) % 2^{NUM_LIMBS * LIMB_BITS} and range check
        // each element in q.
        let b_ext = cols.b_sign * AB::F::from_canonical_u32((1 << LIMB_BITS) - 1);
        let c_ext = cols.c_sign * AB::F::from_canonical_u32((1 << LIMB_BITS) - 1);
        let carry_divide = AB::F::from_canonical_u32(1 << LIMB_BITS).inverse();
        let mut carry: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);

        for i in 0..NUM_LIMBS {
            // Schoolbook multiplication: limb i of c * q plus r[i] plus incoming carry.
            let expected_limb = if i == 0 {
                AB::Expr::ZERO
            } else {
                carry[i - 1].clone()
            } + (0..=i).fold(r[i].into(), |ac, k| ac + (c[k] * q[i - k]));
            carry[i] = (expected_limb - b[i]) * carry_divide;
        }

        // Range check each (q[i], carry[i]) pair via the tuple range checker.
        for (q, carry) in q.iter().zip(carry.iter()) {
            self.range_tuple_bus
                .send(vec![(*q).into(), carry.clone()])
                .eval(builder, is_valid.clone());
        }

        // Constrain that the upper limbs of b = c * q + r are all equal to b_ext and
        // range check each element in r.
        let q_ext = cols.q_sign * AB::F::from_canonical_u32((1 << LIMB_BITS) - 1);
        let mut carry_ext: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);

        for j in 0..NUM_LIMBS {
            let expected_limb = if j == 0 {
                carry[NUM_LIMBS - 1].clone()
            } else {
                carry_ext[j - 1].clone()
            } + ((j + 1)..NUM_LIMBS)
                .fold(AB::Expr::ZERO, |acc, k| acc + (c[k] * q[NUM_LIMBS + j - k]))
                + (0..(j + 1)).fold(AB::Expr::ZERO, |acc, k| {
                    acc + (c[k] * q_ext.clone()) + (q[k] * c_ext.clone())
                })
                + (AB::Expr::ONE - cols.r_zero) * b_ext.clone();
            // Technically there are ways to constrain that c * q is in range without
            // using a range checker, but because we already have to range check each
            // limb of r it requires no additional columns to also range check each
            // carry_ext.
            //
            // Note that the sign of r is not equal to the sign of b only when r = 0.
            // Flag column r_zero tracks this special case.
            carry_ext[j] = (expected_limb - b_ext.clone()) * carry_divide;
        }

        // Range check each (r[i], carry_ext[i]) pair via the tuple range checker.
        for (r, carry) in r.iter().zip(carry_ext.iter()) {
            self.range_tuple_bus
                .send(vec![(*r).into(), carry.clone()])
                .eval(builder, is_valid.clone());
        }

        // Handle special cases. We can have either at most one of a zero divisor,
        // or a 0 remainder. Signed overflow falls under the latter.
        let special_case = cols.zero_divisor + cols.r_zero;
        builder.assert_bool(special_case.clone());

        // Constrain that zero_divisor = 1 if and only if c = 0.
        builder.assert_bool(cols.zero_divisor);
        let mut when_zero_divisor = builder.when(cols.zero_divisor);
        for i in 0..NUM_LIMBS {
            // On a zero divisor, c = 0 and q is forced to all ones (q = -1 / MAX).
            when_zero_divisor.assert_zero(c[i]);
            when_zero_divisor.assert_eq(q[i], AB::F::from_canonical_u32((1 << LIMB_BITS) - 1));
        }
        // c_sum is guaranteed to be non-zero if c is non-zero since we assume
        // each limb of c to be within [0, 2^LIMB_BITS) already.
        // To constrain that if c = 0 then zero_divisor = 1, we check that if zero_divisor = 0
        // and is_valid = 1 then c_sum is non-zero using c_sum_inv.
        let c_sum = c.iter().fold(AB::Expr::ZERO, |acc, c| acc + *c);
        let valid_and_not_zero_divisor = is_valid.clone() - cols.zero_divisor;
        builder.assert_bool(valid_and_not_zero_divisor.clone());
        builder
            .when(valid_and_not_zero_divisor)
            .assert_one(c_sum * cols.c_sum_inv);

        // Constrain that r_zero = 1 if and only if r = 0 and zero_divisor = 0.
        builder.assert_bool(cols.r_zero);
        r.iter()
            .for_each(|r_i| builder.when(cols.r_zero).assert_zero(*r_i));
        // To constrain that if r = 0 and zero_divisor = 0 then r_zero = 1, we check that
        // if special_case = 0 and is_valid = 1 then r_sum is non-zero (using r_sum_inv).
        let r_sum = r.iter().fold(AB::Expr::ZERO, |acc, r| acc + *r);
        let valid_and_not_special_case = is_valid.clone() - special_case.clone();
        builder.assert_bool(valid_and_not_special_case.clone());
        builder
            .when(valid_and_not_special_case)
            .assert_one(r_sum * cols.r_sum_inv);

        // Constrain the correctness of b_sign and c_sign. Note that we do not need to
        // check that the sign of r is b_sign since we cannot have r_prime < c (or c < r_prime
        // if c is negative) if this is not the case.
        let signed = cols.opcode_div_flag + cols.opcode_rem_flag;

        builder.assert_bool(cols.b_sign);
        builder.assert_bool(cols.c_sign);
        // Unsigned opcodes treat both operands as non-negative.
        builder
            .when(not::<AB::Expr>(signed.clone()))
            .assert_zero(cols.b_sign);
        builder
            .when(not::<AB::Expr>(signed.clone()))
            .assert_zero(cols.c_sign);
        // sign_xor = b_sign XOR c_sign, expressed arithmetically on booleans.
        builder.assert_eq(
            cols.b_sign + cols.c_sign - AB::Expr::from_canonical_u32(2) * cols.b_sign * cols.c_sign,
            cols.sign_xor,
        );

        // To constrain the correctness of q_sign we make sure if q is non-zero then
        // q_sign = b_sign ^ c_sign, and if q is zero then q_sign = 0.
        // Note:
        // - q_sum is guaranteed to be non-zero if q is non-zero since we've range checked each
        // limb of q to be within [0, 2^LIMB_BITS) already.
        // - If q is zero and q_ext satisfies the constraint
        // sign_extend(b) = sign_extend(c) * sign_extend(q) + sign_extend(r), then q_sign must be 0.
        // Thus, we do not need additional constraints in case q is zero.
        let nonzero_q = q.iter().fold(AB::Expr::ZERO, |acc, q| acc + *q);
        builder.assert_bool(cols.q_sign);
        builder
            .when(nonzero_q)
            .when(not(cols.zero_divisor))
            .assert_eq(cols.q_sign, cols.sign_xor);
        builder
            .when_ne(cols.q_sign, cols.sign_xor)
            .when(not(cols.zero_divisor))
            .assert_zero(cols.q_sign);

        // Check that the signs of b and c are correct: subtracting the claimed sign bit
        // and doubling must keep the top limb in byte range (checked via lookup).
        let sign_mask = AB::F::from_canonical_u32(1 << (LIMB_BITS - 1));
        self.bitwise_lookup_bus
            .send_range(
                AB::Expr::from_canonical_u32(2) * (b[NUM_LIMBS - 1] - cols.b_sign * sign_mask),
                AB::Expr::from_canonical_u32(2) * (c[NUM_LIMBS - 1] - cols.c_sign * sign_mask),
            )
            .eval(builder, signed.clone());

        // Constrain that 0 <= |r| < |c| by checking that r_prime < c (unsigned LT). By
        // definition, the sign of r must be b_sign. If c is negative then we want
        // to constrain c < r_prime. If c is positive, then we want to constrain r_prime < c.
        //
        // Because we already constrain that r and q are correct for special cases,
        // we skip the range check when special_case = 1.
        let r_p = &cols.r_prime;
        let mut carry_lt: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);

        for i in 0..NUM_LIMBS {
            // When the signs of r (i.e. b) and c are the same, r_prime = r.
            builder.when(not(cols.sign_xor)).assert_eq(r[i], r_p[i]);

            // When the signs of r and c are different, r_prime = -r. To constrain this, we
            // first ensure each r[i] + r_prime[i] + carry[i - 1] is in {0, 2^LIMB_BITS}, and
            // that when the sum is 0 then r_prime[i] = 0 as well. Passing both constraints
            // implies that 0 <= r_prime[i] <= 2^LIMB_BITS, and in order to ensure r_prime[i] !=
            // 2^LIMB_BITS we check that r_prime[i] - 2^LIMB_BITS has an inverse in F.
            let last_carry = if i > 0 {
                carry_lt[i - 1].clone()
            } else {
                AB::Expr::ZERO
            };
            carry_lt[i] = (last_carry.clone() + r[i] + r_p[i]) * carry_divide;
            builder.when(cols.sign_xor).assert_zero(
                (carry_lt[i].clone() - last_carry) * (carry_lt[i].clone() - AB::Expr::ONE),
            );
            builder
                .when(cols.sign_xor)
                .assert_one((r_p[i] - AB::F::from_canonical_u32(1 << LIMB_BITS)) * cols.r_inv[i]);
            builder
                .when(cols.sign_xor)
                .when(not::<AB::Expr>(carry_lt[i].clone()))
                .assert_zero(r_p[i]);
        }

        let marker = &cols.lt_marker;
        let mut prefix_sum = special_case.clone();

        // Scan limbs from most to least significant; diff is oriented so that it is
        // positive at the first differing limb exactly when the LT claim holds.
        for i in (0..NUM_LIMBS).rev() {
            let diff = r_p[i] * (AB::Expr::from_canonical_u8(2) * cols.c_sign - AB::Expr::ONE)
                + c[i] * (AB::Expr::ONE - AB::Expr::from_canonical_u8(2) * cols.c_sign);
            prefix_sum += marker[i].into();
            builder.assert_bool(marker[i]);
            builder.assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
            builder.when(marker[i]).assert_eq(cols.lt_diff, diff);
        }
        // - If r_prime != c, then prefix_sum = 1 so marker[i] must be 1 iff i is the first index
        //   where diff != 0. Constrains that diff == lt_diff where lt_diff is non-zero.
        // - If r_prime == c, then prefix_sum = 0. Here, prefix_sum cannot be 1 because all diff
        //   are zero, which would make the diff == lt_diff constraint fail.

        builder.when(is_valid.clone()).assert_one(prefix_sum);
        // Range check to ensure lt_diff is non-zero.
        self.bitwise_lookup_bus
            .send_range(cols.lt_diff - AB::Expr::ONE, AB::F::ZERO)
            .eval(builder, is_valid.clone() - special_case);

        // Generate expected opcode and output a to pass to the adapter.
        let expected_opcode = flags.iter().zip(DivRemOpcode::iter()).fold(
            AB::Expr::ZERO,
            |acc, (flag, local_opcode)| {
                acc + (*flag).into() * AB::Expr::from_canonical_u8(local_opcode as u8)
            },
        ) + AB::Expr::from_canonical_usize(self.offset);

        // DIV/DIVU write the quotient; REM/REMU write the remainder.
        let is_div = cols.opcode_div_flag + cols.opcode_divu_flag;
        let a = array::from_fn(|i| select(is_div.clone(), q[i], r[i]));

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [a.map(Into::into)].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    /// Global opcode offset of this instruction class.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
343
/// Classification of the division edge cases that are handled separately
/// (see `run_divrem` for the values returned in each case).
#[derive(Debug, Eq, PartialEq)]
#[repr(u8)]
pub(super) enum DivRemCoreSpecialCase {
    /// Ordinary division: 0 <= |r| < |c|.
    None,
    /// Divisor c is zero.
    ZeroDivisor,
    /// Signed overflow: most negative dividend divided by -1.
    SignedOverflow,
}
351
/// Compact per-row execution record. Only the operands and the opcode are stored;
/// q, r and every auxiliary column are re-derived during trace filling.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct DivRemCoreRecord<const NUM_LIMBS: usize> {
    // Dividend limbs as read from memory.
    pub b: [u8; NUM_LIMBS],
    // Divisor limbs as read from memory.
    pub c: [u8; NUM_LIMBS],
    // DivRemOpcode index local to this executor's offset.
    pub local_opcode: u8,
}
359
/// Preflight executor for DIV/DIVU/REM/REMU, pairing a memory adapter `A`
/// with the divrem core logic.
#[derive(Clone, Copy, derive_new::new)]
pub struct DivRemExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    // Global opcode offset of this instruction class.
    pub offset: usize,
}
365
/// Trace filler for the divrem chip: expands execution records into full trace rows
/// and logs the lookup/range-check requests the AIR constraints consume.
pub struct DivRemFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub offset: usize,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    pub range_tuple_chip: SharedRangeTupleCheckerChip<2>,
}
372
impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> DivRemFiller<A, NUM_LIMBS, LIMB_BITS> {
    /// Constructs the filler, verifying (in debug builds) that the shared
    /// `RangeTupleChecker` is large enough for the (limb, carry) pairs this chip sends.
    pub fn new(
        adapter: A,
        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
        range_tuple_chip: SharedRangeTupleCheckerChip<2>,
        offset: usize,
    ) -> Self {
        // The RangeTupleChecker is used to range check (a[i], carry[i]) pairs where 0 <= i
        // < 2 * NUM_LIMBS. a[i] must have LIMB_BITS bits and carry[i] is the sum of i + 1
        // bytes (with LIMB_BITS bits). BitwiseOperationLookup is used to sign check bytes.
        debug_assert!(
            range_tuple_chip.sizes()[0] == 1 << LIMB_BITS,
            "First element of RangeTupleChecker must have size {}",
            1 << LIMB_BITS
        );
        debug_assert!(
            range_tuple_chip.sizes()[1] >= (1 << LIMB_BITS) * 2 * NUM_LIMBS as u32,
            "Second element of RangeTupleChecker must have size of at least {}",
            (1 << LIMB_BITS) * 2 * NUM_LIMBS as u32
        );

        Self {
            adapter,
            offset,
            bitwise_lookup_chip,
            range_tuple_chip,
        }
    }
}
402
impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
    for DivRemExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            ReadData: Into<[[u8; NUM_LIMBS]; 2]>,
            WriteData: From<[[u8; NUM_LIMBS]; 1]>,
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (A::RecordMut<'buf>, &'buf mut DivRemCoreRecord<NUM_LIMBS>),
    >,
{
    /// Debug name of the local opcode (e.g. "DIV", "REMU").
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!("{:?}", DivRemOpcode::from_usize(opcode - self.offset))
    }

    /// Executes one DIV/DIVU/REM/REMU instruction: reads b and c through the adapter,
    /// computes quotient and remainder, writes q (div) or r (rem) back, records the
    /// operands for trace filling, and advances the pc by one instruction step.
    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let Instruction { opcode, .. } = instruction;

        // Allocate the paired (adapter, core) record in the record arena.
        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8;

        let is_signed = core_record.local_opcode == DivRemOpcode::DIV as u8
            || core_record.local_opcode == DivRemOpcode::REM as u8;
        let is_div = core_record.local_opcode == DivRemOpcode::DIV as u8
            || core_record.local_opcode == DivRemOpcode::DIVU as u8;

        // Read dividend and divisor limbs directly into the record.
        [core_record.b, core_record.c] = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record)
            .into();

        let b = core_record.b.map(u32::from);
        let c = core_record.c.map(u32::from);
        let (q, r, _, _, _, _) = run_divrem::<NUM_LIMBS, LIMB_BITS>(is_signed, &b, &c);

        // DIV/DIVU write the quotient; REM/REMU write the remainder.
        let rd = if is_div {
            q.map(|x| x as u8)
        } else {
            r.map(|x| x as u8)
        };

        self.adapter
            .write(state.memory, instruction, [rd].into(), &mut adapter_record);

        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}
464
465impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
466    for DivRemFiller<A, NUM_LIMBS, LIMB_BITS>
467where
468    F: PrimeField32,
469    A: 'static + AdapterTraceFiller<F>,
470{
471    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
472        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
473        // DivRemCoreCols::width() elements
474        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
475        self.adapter.fill_trace_row(mem_helper, adapter_row);
476        // SAFETY: core_row contains a valid DivRemCoreRecord written by the executor
477        // during trace generation
478        let record: &DivRemCoreRecord<NUM_LIMBS> =
479            unsafe { get_record_from_slice(&mut core_row, ()) };
480        let core_row: &mut DivRemCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();
481
482        let opcode = DivRemOpcode::from_usize(record.local_opcode as usize);
483        let is_signed = opcode == DivRemOpcode::DIV || opcode == DivRemOpcode::REM;
484
485        let (q, r, b_sign, c_sign, q_sign, case) = run_divrem::<NUM_LIMBS, LIMB_BITS>(
486            is_signed,
487            &record.b.map(u32::from),
488            &record.c.map(u32::from),
489        );
490
491        let carries = run_mul_carries::<NUM_LIMBS, LIMB_BITS>(
492            is_signed,
493            &record.c.map(u32::from),
494            &q,
495            &r,
496            q_sign,
497        );
498        for i in 0..NUM_LIMBS {
499            self.range_tuple_chip.add_count(&[q[i], carries[i]]);
500            self.range_tuple_chip
501                .add_count(&[r[i], carries[i + NUM_LIMBS]]);
502        }
503
504        let sign_xor = b_sign ^ c_sign;
505        let r_prime = if sign_xor {
506            negate::<NUM_LIMBS, LIMB_BITS>(&r)
507        } else {
508            r
509        };
510        let r_zero = r.iter().all(|&v| v == 0) && case != DivRemCoreSpecialCase::ZeroDivisor;
511
512        if is_signed {
513            let b_sign_mask = if b_sign { 1 << (LIMB_BITS - 1) } else { 0 };
514            let c_sign_mask = if c_sign { 1 << (LIMB_BITS - 1) } else { 0 };
515            self.bitwise_lookup_chip.request_range(
516                (record.b[NUM_LIMBS - 1] as u32 - b_sign_mask) << 1,
517                (record.c[NUM_LIMBS - 1] as u32 - c_sign_mask) << 1,
518            );
519        }
520
521        // Write in a reverse order
522        core_row.opcode_remu_flag = F::from_bool(opcode == DivRemOpcode::REMU);
523        core_row.opcode_rem_flag = F::from_bool(opcode == DivRemOpcode::REM);
524        core_row.opcode_divu_flag = F::from_bool(opcode == DivRemOpcode::DIVU);
525        core_row.opcode_div_flag = F::from_bool(opcode == DivRemOpcode::DIV);
526
527        core_row.lt_diff = F::ZERO;
528        core_row.lt_marker = [F::ZERO; NUM_LIMBS];
529        if case == DivRemCoreSpecialCase::None && !r_zero {
530            let idx = run_sltu_diff_idx(&record.c.map(u32::from), &r_prime, c_sign);
531            let val = if c_sign {
532                r_prime[idx] - record.c[idx] as u32
533            } else {
534                record.c[idx] as u32 - r_prime[idx]
535            };
536            self.bitwise_lookup_chip.request_range(val - 1, 0);
537            core_row.lt_diff = F::from_canonical_u32(val);
538            core_row.lt_marker[idx] = F::ONE;
539        }
540
541        let r_prime_f = r_prime.map(F::from_canonical_u32);
542        core_row.r_inv = r_prime_f.map(|r| (r - F::from_canonical_u32(256)).inverse());
543        core_row.r_prime = r_prime_f;
544
545        let r_sum_f = r
546            .iter()
547            .fold(F::ZERO, |acc, r| acc + F::from_canonical_u32(*r));
548        core_row.r_sum_inv = r_sum_f.try_inverse().unwrap_or(F::ZERO);
549
550        let c_sum_f = F::from_canonical_u32(record.c.iter().fold(0, |acc, c| acc + *c as u32));
551        core_row.c_sum_inv = c_sum_f.try_inverse().unwrap_or(F::ZERO);
552
553        core_row.sign_xor = F::from_bool(sign_xor);
554        core_row.q_sign = F::from_bool(q_sign);
555        core_row.c_sign = F::from_bool(c_sign);
556        core_row.b_sign = F::from_bool(b_sign);
557
558        core_row.r_zero = F::from_bool(r_zero);
559        core_row.zero_divisor = F::from_bool(case == DivRemCoreSpecialCase::ZeroDivisor);
560
561        core_row.r = r.map(F::from_canonical_u32);
562        core_row.q = q.map(F::from_canonical_u32);
563        core_row.c = record.c.map(F::from_canonical_u8);
564        core_row.b = record.b.map(F::from_canonical_u8);
565    }
566}
567
// Returns (quotient, remainder, x_sign, y_sign, q_sign, case), where the special
// cases (zero divisor, signed overflow) follow the values the RISC-V M extension
// specifies for DIV/REM — q = all ones, r = x for a zero divisor; q = x, r = 0
// for signed overflow.
#[inline(always)]
pub(super) fn run_divrem<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    signed: bool,
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> (
    [u32; NUM_LIMBS],
    [u32; NUM_LIMBS],
    bool,
    bool,
    bool,
    DivRemCoreSpecialCase,
) {
    // Signs come from the top bit of the most significant limb; unsigned ops are
    // always treated as non-negative.
    let x_sign = signed && (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1);
    let y_sign = signed && (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1);
    let max_limb = (1 << LIMB_BITS) - 1;

    let zero_divisor = y.iter().all(|val| *val == 0);
    // Signed overflow: x == -2^(NUM_LIMBS * LIMB_BITS - 1) and y == -1.
    let overflow = x[NUM_LIMBS - 1] == 1 << (LIMB_BITS - 1)
        && x[..(NUM_LIMBS - 1)].iter().all(|val| *val == 0)
        && y.iter().all(|val| *val == max_limb)
        && x_sign
        && y_sign;

    if zero_divisor {
        // q = all ones, r = x; q_sign is `signed` so q reads as -1 in the signed case.
        return (
            [max_limb; NUM_LIMBS],
            *x,
            x_sign,
            y_sign,
            signed,
            DivRemCoreSpecialCase::ZeroDivisor,
        );
    } else if overflow {
        // q = x (the most negative value), r = 0.
        return (
            *x,
            [0; NUM_LIMBS],
            x_sign,
            y_sign,
            false,
            DivRemCoreSpecialCase::SignedOverflow,
        );
    }

    // Divide magnitudes, then restore signs below.
    let x_abs = if x_sign {
        negate::<NUM_LIMBS, LIMB_BITS>(x)
    } else {
        *x
    };
    let y_abs = if y_sign {
        negate::<NUM_LIMBS, LIMB_BITS>(y)
    } else {
        *y
    };

    let x_big = limbs_to_biguint::<NUM_LIMBS, LIMB_BITS>(&x_abs);
    let y_big = limbs_to_biguint::<NUM_LIMBS, LIMB_BITS>(&y_abs);
    let q_big = x_big.clone() / y_big.clone();
    let r_big = x_big.clone() % y_big.clone();

    // The quotient is negative exactly when the operand signs differ.
    let q = if x_sign ^ y_sign {
        negate::<NUM_LIMBS, LIMB_BITS>(&biguint_to_limbs::<NUM_LIMBS, LIMB_BITS>(&q_big))
    } else {
        biguint_to_limbs::<NUM_LIMBS, LIMB_BITS>(&q_big)
    };
    let q_sign = signed && (q[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1);

    // In C |q * y| <= |x|, which means if x is negative then r <= 0 and vice versa.
    let r = if x_sign {
        negate::<NUM_LIMBS, LIMB_BITS>(&biguint_to_limbs::<NUM_LIMBS, LIMB_BITS>(&r_big))
    } else {
        biguint_to_limbs::<NUM_LIMBS, LIMB_BITS>(&r_big)
    };

    (q, r, x_sign, y_sign, q_sign, DivRemCoreSpecialCase::None)
}
646
647#[inline(always)]
648pub(super) fn run_sltu_diff_idx<const NUM_LIMBS: usize>(
649    x: &[u32; NUM_LIMBS],
650    y: &[u32; NUM_LIMBS],
651    cmp: bool,
652) -> usize {
653    for i in (0..NUM_LIMBS).rev() {
654        if x[i] != y[i] {
655            assert!((x[i] < y[i]) == cmp);
656            return i;
657        }
658    }
659    assert!(!cmp);
660    NUM_LIMBS
661}
662
663// returns carries of d * q + r
664#[inline(always)]
665pub(super) fn run_mul_carries<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
666    signed: bool,
667    d: &[u32; NUM_LIMBS],
668    q: &[u32; NUM_LIMBS],
669    r: &[u32; NUM_LIMBS],
670    q_sign: bool,
671) -> Vec<u32> {
672    let mut carry = vec![0u32; 2 * NUM_LIMBS];
673    for i in 0..NUM_LIMBS {
674        let mut val = r[i] + if i > 0 { carry[i - 1] } else { 0 };
675        for j in 0..=i {
676            val += d[j] * q[i - j];
677        }
678        carry[i] = val >> LIMB_BITS;
679    }
680
681    let q_ext = if q_sign && signed {
682        (1 << LIMB_BITS) - 1
683    } else {
684        0
685    };
686    let d_ext =
687        (d[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) * if signed { (1 << LIMB_BITS) - 1 } else { 0 };
688    let r_ext =
689        (r[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) * if signed { (1 << LIMB_BITS) - 1 } else { 0 };
690    let mut d_prefix = 0;
691    let mut q_prefix = 0;
692
693    for i in 0..NUM_LIMBS {
694        d_prefix += d[i];
695        q_prefix += q[i];
696        let mut val = carry[NUM_LIMBS + i - 1] + d_prefix * q_ext + q_prefix * d_ext + r_ext;
697        for j in (i + 1)..NUM_LIMBS {
698            val += d[j] * q[NUM_LIMBS + i - j];
699        }
700        carry[NUM_LIMBS + i] = val >> LIMB_BITS;
701    }
702    carry
703}
704
705#[inline(always)]
706fn limbs_to_biguint<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
707    x: &[u32; NUM_LIMBS],
708) -> BigUint {
709    let base = BigUint::new(vec![1 << LIMB_BITS]);
710    let mut res = BigUint::new(vec![0]);
711    for val in x.iter().rev() {
712        res *= base.clone();
713        res += BigUint::new(vec![*val]);
714    }
715    res
716}
717
718#[inline(always)]
719fn biguint_to_limbs<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
720    x: &BigUint,
721) -> [u32; NUM_LIMBS] {
722    let mut res = [0; NUM_LIMBS];
723    let mut x = x.clone();
724    let base = BigUint::from(1u32 << LIMB_BITS);
725    for limb in res.iter_mut() {
726        let (quot, rem) = x.div_rem(&base);
727        *limb = rem.iter_u32_digits().next().unwrap_or(0);
728        x = quot;
729    }
730    debug_assert_eq!(x, BigUint::from(0u32));
731    res
732}
733
/// Two's-complement negation of `x` modulo 2^(NUM_LIMBS * LIMB_BITS),
/// limb by limb (i.e. bitwise complement of each limb plus one, with carry).
#[inline(always)]
fn negate<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let mask = (1u32 << LIMB_BITS) - 1;
    let mut carry_in = 1u32;
    array::from_fn(|i| {
        // (2^LIMB_BITS - 1 - x[i]) + carry_in, then split into limb and carry-out.
        let sum = (1u32 << LIMB_BITS) + carry_in - 1 - x[i];
        carry_in = sum >> LIMB_BITS;
        sum & mask
    })
}