1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use num_integer::Integer;
8use openvm_circuit::arch::{
9 AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
10 VmCoreAir, VmCoreChip,
11};
12use openvm_circuit_primitives::{
13 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
14 range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip},
15 utils::{not, select},
16};
17use openvm_circuit_primitives_derive::AlignedBorrow;
18use openvm_instructions::{instruction::Instruction, LocalOpcode};
19use openvm_rv32im_transpiler::DivRemOpcode;
20use openvm_stark_backend::{
21 interaction::InteractionBuilder,
22 p3_air::{AirBuilder, BaseAir},
23 p3_field::{Field, FieldAlgebra, PrimeField32},
24 rap::BaseAirWithPublicValues,
25};
26use serde::{de::DeserializeOwned, Deserialize, Serialize};
27use serde_big_array::BigArray;
28use strum::IntoEnumIterator;
29
/// Trace columns for one row of the DIV/DIVU/REM/REMU core.
///
/// Every array holds `NUM_LIMBS` little-endian limbs of `LIMB_BITS` bits. The AIR constrains
/// `b = c * q + r` limb by limb over the sign-extended `2 * NUM_LIMBS`-limb representation,
/// plus side conditions on the special cases, the signs, and the size of the remainder.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct DivRemCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Dividend (first register read).
    pub b: [T; NUM_LIMBS],
    /// Divisor (second register read).
    pub c: [T; NUM_LIMBS],
    /// Quotient; this is the written value for DIV/DIVU.
    pub q: [T; NUM_LIMBS],
    /// Remainder; this is the written value for REM/REMU.
    pub r: [T; NUM_LIMBS],

    /// Zero-divisor flag. When set, every limb of `c` must be 0 and every limb of `q` must be
    /// `2^LIMB_BITS - 1`; when unset on a valid row, `c_sum_inv` must invert the limb sum of `c`.
    pub zero_divisor: T,
    /// Zero-remainder flag. When set, every limb of `r` must be 0. `zero_divisor + r_zero`
    /// must be boolean, so at most one of the two flags is set.
    pub r_zero: T,

    /// Sign bit of `b`; forced to 0 for unsigned opcodes.
    pub b_sign: T,
    /// Sign bit of `c`; forced to 0 for unsigned opcodes.
    pub c_sign: T,
    /// Sign used to extend `q`. Outside the zero-divisor case it must equal `sign_xor` when `q`
    /// is nonzero, and otherwise be either 0 or `sign_xor`.
    pub q_sign: T,
    /// `b_sign XOR c_sign`.
    pub sign_xor: T,

    /// Witness inverse of the limb sum of `c` (used to show `c != 0` when `zero_divisor` is 0).
    pub c_sum_inv: T,
    /// Witness inverse of the limb sum of `r` (used to show `r != 0` on valid rows with neither
    /// special flag set).
    pub r_sum_inv: T,

    /// Equal to `r` when `sign_xor` is 0; when `sign_xor` is 1 it is tied to `r` through a carry
    /// chain intended to make it the two's-complement negation of `r`. Compared against `c`.
    pub r_prime: [T; NUM_LIMBS],
    /// When `sign_xor` is 1, `r_inv[i]` must invert `r_prime[i] - 2^LIMB_BITS`, which rules out
    /// `r_prime[i] == 2^LIMB_BITS`.
    pub r_inv: [T; NUM_LIMBS],
    /// Boolean markers; on a valid non-special row exactly one is set, at the most significant
    /// limb where `r_prime` and `c` differ.
    pub lt_marker: [T; NUM_LIMBS],
    /// At the marked limb: `c[i] - r_prime[i]` if `c_sign` is 0, else `r_prime[i] - c[i]`.
    /// `lt_diff - 1` is range-checked on valid non-special rows.
    pub lt_diff: T,

    /// Selector for DIV. Each selector is boolean and their sum (the row's `is_valid`) must be
    /// boolean too.
    pub opcode_div_flag: T,
    /// Selector for DIVU.
    pub opcode_divu_flag: T,
    /// Selector for REM.
    pub opcode_rem_flag: T,
    /// Selector for REMU.
    pub opcode_remu_flag: T,
}
69
/// AIR for the DIV/DIVU/REM/REMU core; the row layout is [`DivRemCoreCols`].
#[derive(Copy, Clone, Debug)]
pub struct DivRemCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Lookup bus used for the sign-bit checks on `b`/`c` and the `lt_diff - 1` range check.
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    /// 2-tuple range checker bus; receives `(q[i], carry[i])` and `(r[i], carry_ext[i])` pairs.
    pub range_tuple_bus: RangeTupleCheckerBus<2>,
    // Global opcode offset of this opcode class; added to the local opcode in `eval` and
    // returned by `start_offset`.
    offset: usize,
}
76
77impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
78 for DivRemCoreAir<NUM_LIMBS, LIMB_BITS>
79{
80 fn width(&self) -> usize {
81 DivRemCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
82 }
83}
84impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
85 for DivRemCoreAir<NUM_LIMBS, LIMB_BITS>
86{
87}
88
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for DivRemCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Constrains one DIV/DIVU/REM/REMU row and returns the adapter context.
    ///
    /// The adapter receives reads `[b, c]`, writes `[q]` for DIV/DIVU or `[r]` for REM/REMU,
    /// and the expected global opcode.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &DivRemCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [
            cols.opcode_div_flag,
            cols.opcode_divu_flag,
            cols.opcode_rem_flag,
            cols.opcode_remu_flag,
        ];

        // Every selector is boolean and their sum is boolean, so at most one opcode is active;
        // the sum doubles as the row's validity flag.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());

        let b = &cols.b;
        let c = &cols.c;
        let q = &cols.q;
        let r = &cols.r;

        // Low half: b == (c * q + r) mod 2^(NUM_LIMBS * LIMB_BITS), limb by limb.
        // carry[i] = (sum_{k<=i} c[k]*q[i-k] + r[i] + carry[i-1] - b[i]) / 2^LIMB_BITS as a field
        // element; range-checking it below is what forces each limb equation to hold.
        let b_ext = cols.b_sign * AB::F::from_canonical_u32((1 << LIMB_BITS) - 1);
        let c_ext = cols.c_sign * AB::F::from_canonical_u32((1 << LIMB_BITS) - 1);
        let carry_divide = AB::F::from_canonical_u32(1 << LIMB_BITS).inverse();
        let mut carry: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);

        for i in 0..NUM_LIMBS {
            let expected_limb = if i == 0 {
                AB::Expr::ZERO
            } else {
                carry[i - 1].clone()
            } + (0..=i).fold(r[i].into(), |ac, k| ac + (c[k] * q[i - k]));
            carry[i] = (expected_limb - b[i]) * carry_divide;
        }

        // Range-check each (q[i], carry[i]) pair on valid rows.
        for (q, carry) in q.iter().zip(carry.iter()) {
            self.range_tuple_bus
                .send(vec![(*q).into(), carry.clone()])
                .eval(builder, is_valid.clone());
        }

        // High half: with q, c, b extended by their sign limbs (q_ext, c_ext, b_ext) and r
        // extended by b_ext unless r_zero is set, every upper limb of c * q + r must equal b_ext.
        let q_ext = cols.q_sign * AB::F::from_canonical_u32((1 << LIMB_BITS) - 1);
        let mut carry_ext: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);

        for j in 0..NUM_LIMBS {
            let expected_limb = if j == 0 {
                carry[NUM_LIMBS - 1].clone()
            } else {
                carry_ext[j - 1].clone()
            } + ((j + 1)..NUM_LIMBS)
                .fold(AB::Expr::ZERO, |acc, k| acc + (c[k] * q[NUM_LIMBS + j - k]))
                + (0..(j + 1)).fold(AB::Expr::ZERO, |acc, k| {
                    acc + (c[k] * q_ext.clone()) + (q[k] * c_ext.clone())
                })
                + (AB::Expr::ONE - cols.r_zero) * b_ext.clone();
            carry_ext[j] = (expected_limb - b_ext.clone()) * carry_divide;
        }

        // Range-check each (r[i], carry_ext[i]) pair on valid rows.
        for (r, carry) in r.iter().zip(carry_ext.iter()) {
            self.range_tuple_bus
                .send(vec![(*r).into(), carry.clone()])
                .eval(builder, is_valid.clone());
        }

        // At most one special case (zero divisor or zero remainder) per row.
        let special_case = cols.zero_divisor + cols.r_zero;
        builder.assert_bool(special_case.clone());

        // zero_divisor => c == 0 and q is all-ones limbs.
        builder.assert_bool(cols.zero_divisor);
        let mut when_zero_divisor = builder.when(cols.zero_divisor);
        for i in 0..NUM_LIMBS {
            when_zero_divisor.assert_zero(c[i]);
            when_zero_divisor.assert_eq(q[i], AB::F::from_canonical_u32((1 << LIMB_BITS) - 1));
        }
        // Valid row without zero_divisor => the limb sum of c has an inverse (c_sum_inv).
        let c_sum = c.iter().fold(AB::Expr::ZERO, |acc, c| acc + *c);
        let valid_and_not_zero_divisor = is_valid.clone() - cols.zero_divisor;
        builder.assert_bool(valid_and_not_zero_divisor.clone());
        builder
            .when(valid_and_not_zero_divisor)
            .assert_one(c_sum * cols.c_sum_inv);

        // r_zero => r == 0; valid row with neither special flag => limb sum of r is invertible.
        builder.assert_bool(cols.r_zero);
        r.iter()
            .for_each(|r_i| builder.when(cols.r_zero).assert_zero(*r_i));
        let r_sum = r.iter().fold(AB::Expr::ZERO, |acc, r| acc + *r);
        let valid_and_not_special_case = is_valid.clone() - special_case.clone();
        builder.assert_bool(valid_and_not_special_case.clone());
        builder
            .when(valid_and_not_special_case)
            .assert_one(r_sum * cols.r_sum_inv);

        // DIV and REM are the signed opcodes.
        let signed = cols.opcode_div_flag + cols.opcode_rem_flag;

        // Sign flags are boolean, zero for unsigned opcodes, and sign_xor = b_sign XOR c_sign.
        builder.assert_bool(cols.b_sign);
        builder.assert_bool(cols.c_sign);
        builder
            .when(not::<AB::Expr>(signed.clone()))
            .assert_zero(cols.b_sign);
        builder
            .when(not::<AB::Expr>(signed.clone()))
            .assert_zero(cols.c_sign);
        builder.assert_eq(
            cols.b_sign + cols.c_sign - AB::Expr::from_canonical_u32(2) * cols.b_sign * cols.c_sign,
            cols.sign_xor,
        );

        // Outside the zero-divisor case: a nonzero q forces q_sign == sign_xor, and in any case
        // q_sign may only differ from sign_xor by being 0.
        let nonzero_q = q.iter().fold(AB::Expr::ZERO, |acc, q| acc + *q);
        builder.assert_bool(cols.q_sign);
        builder
            .when(nonzero_q)
            .when(not(cols.zero_divisor))
            .assert_eq(cols.q_sign, cols.sign_xor);
        builder
            .when_ne(cols.q_sign, cols.sign_xor)
            .when(not(cols.zero_divisor))
            .assert_zero(cols.q_sign);

        // Signed rows: 2 * (top_limb - sign * 2^(LIMB_BITS-1)) must pass the range lookup for
        // both b and c, which ties b_sign / c_sign to the top bit of each operand.
        let sign_mask = AB::F::from_canonical_u32(1 << (LIMB_BITS - 1));
        self.bitwise_lookup_bus
            .send_range(
                AB::Expr::from_canonical_u32(2) * (b[NUM_LIMBS - 1] - cols.b_sign * sign_mask),
                AB::Expr::from_canonical_u32(2) * (c[NUM_LIMBS - 1] - cols.c_sign * sign_mask),
            )
            .eval(builder, signed.clone());

        // r_prime: equal to r when signs agree; otherwise each r[i] + r_prime[i] + carry_lt[i-1]
        // is constrained so that carry_lt[i] is either the previous carry or 1, r_prime[i] is 0
        // whenever carry_lt[i] is 0, and r_prime[i] != 2^LIMB_BITS (via r_inv). Together these
        // are intended to make r_prime the two's-complement negation of r.
        let r_p = &cols.r_prime;
        let mut carry_lt: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);

        for i in 0..NUM_LIMBS {
            builder.when(not(cols.sign_xor)).assert_eq(r[i], r_p[i]);

            let last_carry = if i > 0 {
                carry_lt[i - 1].clone()
            } else {
                AB::Expr::ZERO
            };
            carry_lt[i] = (last_carry.clone() + r[i] + r_p[i]) * carry_divide;
            builder.when(cols.sign_xor).assert_zero(
                (carry_lt[i].clone() - last_carry) * (carry_lt[i].clone() - AB::Expr::ONE),
            );
            builder
                .when(cols.sign_xor)
                .assert_one((r_p[i] - AB::F::from_canonical_u32(1 << LIMB_BITS)) * cols.r_inv[i]);
            builder
                .when(cols.sign_xor)
                .when(not::<AB::Expr>(carry_lt[i].clone()))
                .assert_zero(r_p[i]);
        }

        // Magnitude check, scanning from the most significant limb. diff is c[i] - r_prime[i]
        // when c_sign = 0 and r_prime[i] - c[i] when c_sign = 1. Until a marker (or a special
        // case) makes prefix_sum 1, diff must be 0; the marked limb's diff is copied to lt_diff.
        let marker = &cols.lt_marker;
        let mut prefix_sum = special_case.clone();

        for i in (0..NUM_LIMBS).rev() {
            let diff = r_p[i] * (AB::Expr::from_canonical_u8(2) * cols.c_sign - AB::Expr::ONE)
                + c[i] * (AB::Expr::ONE - AB::Expr::from_canonical_u8(2) * cols.c_sign);
            prefix_sum += marker[i].into();
            builder.assert_bool(marker[i]);
            builder.assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
            builder.when(marker[i]).assert_eq(cols.lt_diff, diff);
        }
        // Exactly one marker on valid non-special rows (none when a special case is active).
        builder.when(is_valid.clone()).assert_one(prefix_sum);
        // lt_diff >= 1 on valid non-special rows: r_prime <_u c (c_sign = 0) or c <_u r_prime
        // (c_sign = 1), i.e. the remainder is strictly smaller than the divisor in magnitude.
        self.bitwise_lookup_bus
            .send_range(cols.lt_diff - AB::Expr::ONE, AB::F::ZERO)
            .eval(builder, is_valid.clone() - special_case);

        // Global opcode = offset + local index of the active selector (flags are listed in
        // DivRemOpcode iteration order).
        let expected_opcode = flags.iter().zip(DivRemOpcode::iter()).fold(
            AB::Expr::ZERO,
            |acc, (flag, local_opcode)| {
                acc + (*flag).into() * AB::Expr::from_canonical_u8(local_opcode as u8)
            },
        ) + AB::Expr::from_canonical_usize(self.offset);

        // Written value: q for DIV/DIVU, r for REM/REMU.
        let is_div = cols.opcode_div_flag + cols.opcode_divu_flag;
        let a = array::from_fn(|i| select(is_div.clone(), q[i], r[i]));

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [a.map(Into::into)].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    /// Global opcode offset of the DIVREM opcode class handled by this AIR.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
343
/// Execution-side chip for DIV/DIVU/REM/REMU: owns the AIR and the shared lookup chips whose
/// multiplicities it increments while executing instructions.
pub struct DivRemCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// The AIR whose constraints the generated rows must satisfy.
    pub air: DivRemCoreAir<NUM_LIMBS, LIMB_BITS>,
    /// Shared bitwise lookup chip; receives the sign-bit and `lt_diff - 1` range requests.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    /// Shared 2-tuple range checker; receives `(limb, carry)` counts.
    pub range_tuple_chip: SharedRangeTupleCheckerChip<2>,
}
349
350impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> DivRemCoreChip<NUM_LIMBS, LIMB_BITS> {
351 pub fn new(
352 bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
353 range_tuple_chip: SharedRangeTupleCheckerChip<2>,
354 offset: usize,
355 ) -> Self {
356 debug_assert!(
360 range_tuple_chip.sizes()[0] == 1 << LIMB_BITS,
361 "First element of RangeTupleChecker must have size {}",
362 1 << LIMB_BITS
363 );
364 debug_assert!(
365 range_tuple_chip.sizes()[1] >= (1 << LIMB_BITS) * 2 * NUM_LIMBS as u32,
366 "Second element of RangeTupleChecker must have size of at least {}",
367 (1 << LIMB_BITS) * 2 * NUM_LIMBS as u32
368 );
369
370 Self {
371 air: DivRemCoreAir {
372 bitwise_lookup_bus: bitwise_lookup_chip.bus(),
373 range_tuple_bus: *range_tuple_chip.bus(),
374 offset,
375 },
376 bitwise_lookup_chip,
377 range_tuple_chip,
378 }
379 }
380}
381
/// Witness produced by `execute_instruction` for one row; `generate_trace_row` copies it into
/// [`DivRemCoreCols`].
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "T: Serialize + DeserializeOwned")]
pub struct DivRemCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Dividend limbs as read.
    #[serde(with = "BigArray")]
    pub b: [T; NUM_LIMBS],
    /// Divisor limbs as read.
    #[serde(with = "BigArray")]
    pub c: [T; NUM_LIMBS],
    /// Quotient limbs.
    #[serde(with = "BigArray")]
    pub q: [T; NUM_LIMBS],
    /// Remainder limbs.
    #[serde(with = "BigArray")]
    pub r: [T; NUM_LIMBS],
    /// 1 when the divisor is zero.
    pub zero_divisor: T,
    /// 1 when the remainder is zero and the divisor is not.
    pub r_zero: T,
    /// Sign of the dividend (0 for unsigned opcodes).
    pub b_sign: T,
    /// Sign of the divisor (0 for unsigned opcodes).
    pub c_sign: T,
    /// Quotient sign as returned by `run_divrem`.
    pub q_sign: T,
    /// `b_sign XOR c_sign`.
    pub sign_xor: T,
    /// Inverse of the divisor's limb sum, or 0 if that sum is 0.
    pub c_sum_inv: T,
    /// Inverse of the remainder's limb sum, or 0 if that sum is 0.
    pub r_sum_inv: T,
    /// `r`, or its two's-complement negation when `sign_xor` is set.
    #[serde(with = "BigArray")]
    pub r_prime: [T; NUM_LIMBS],
    /// Per-limb inverses of `r_prime[i] - 2^LIMB_BITS`.
    #[serde(with = "BigArray")]
    pub r_inv: [T; NUM_LIMBS],
    /// Value placed in the `lt_diff` column (0 when no limb is marked).
    pub lt_diff_val: T,
    /// Index of the marked limb, or `NUM_LIMBS` when no limb is marked.
    pub lt_diff_idx: usize,
    /// Local opcode that was executed.
    pub opcode: DivRemOpcode,
}
410
/// Which special case, if any, `run_divrem` took instead of the generic quotient/remainder path.
#[derive(Debug, Eq, PartialEq)]
#[repr(u8)]
pub(super) enum DivRemCoreSpecialCase {
    /// No special case applied.
    None,
    /// The divisor was zero: the quotient is all-ones limbs and the remainder is the dividend.
    ZeroDivisor,
    /// Signed division of the most negative value by -1: the quotient is the dividend and the
    /// remainder is zero.
    SignedOverflow,
}
418
419impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_LIMBS: usize, const LIMB_BITS: usize>
420 VmCoreChip<F, I> for DivRemCoreChip<NUM_LIMBS, LIMB_BITS>
421where
422 I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
423 I::Writes: From<[[F; NUM_LIMBS]; 1]>,
424{
425 type Record = DivRemCoreRecord<F, NUM_LIMBS, LIMB_BITS>;
426 type Air = DivRemCoreAir<NUM_LIMBS, LIMB_BITS>;
427
428 #[allow(clippy::type_complexity)]
429 fn execute_instruction(
430 &self,
431 instruction: &Instruction<F>,
432 _from_pc: u32,
433 reads: I::Reads,
434 ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
435 let Instruction { opcode, .. } = instruction;
436 let divrem_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));
437
438 let is_div = divrem_opcode == DivRemOpcode::DIV || divrem_opcode == DivRemOpcode::DIVU;
439 let is_signed = divrem_opcode == DivRemOpcode::DIV || divrem_opcode == DivRemOpcode::REM;
440
441 let data: [[F; NUM_LIMBS]; 2] = reads.into();
442 let b = data[0].map(|x| x.as_canonical_u32());
443 let c = data[1].map(|y| y.as_canonical_u32());
444 let (q, r, b_sign, c_sign, q_sign, case) =
445 run_divrem::<NUM_LIMBS, LIMB_BITS>(is_signed, &b, &c);
446
447 let carries = run_mul_carries::<NUM_LIMBS, LIMB_BITS>(is_signed, &c, &q, &r, q_sign);
448 for i in 0..NUM_LIMBS {
449 self.range_tuple_chip.add_count(&[q[i], carries[i]]);
450 self.range_tuple_chip
451 .add_count(&[r[i], carries[i + NUM_LIMBS]]);
452 }
453
454 let sign_xor = b_sign ^ c_sign;
455 let r_prime = if sign_xor {
456 negate::<NUM_LIMBS, LIMB_BITS>(&r)
457 } else {
458 r
459 };
460 let r_zero = r.iter().all(|&v| v == 0) && case != DivRemCoreSpecialCase::ZeroDivisor;
461
462 if is_signed {
463 let b_sign_mask = if b_sign { 1 << (LIMB_BITS - 1) } else { 0 };
464 let c_sign_mask = if c_sign { 1 << (LIMB_BITS - 1) } else { 0 };
465 self.bitwise_lookup_chip.request_range(
466 (b[NUM_LIMBS - 1] - b_sign_mask) << 1,
467 (c[NUM_LIMBS - 1] - c_sign_mask) << 1,
468 );
469 }
470
471 let c_sum_f = data[1].iter().fold(F::ZERO, |acc, c| acc + *c);
472 let c_sum_inv_f = c_sum_f.try_inverse().unwrap_or(F::ZERO);
473
474 let r_sum_f = r
475 .iter()
476 .fold(F::ZERO, |acc, r| acc + F::from_canonical_u32(*r));
477 let r_sum_inv_f = r_sum_f.try_inverse().unwrap_or(F::ZERO);
478
479 let (lt_diff_idx, lt_diff_val) = if case == DivRemCoreSpecialCase::None && !r_zero {
480 let idx = run_sltu_diff_idx(&c, &r_prime, c_sign);
481 let val = if c_sign {
482 r_prime[idx] - c[idx]
483 } else {
484 c[idx] - r_prime[idx]
485 };
486 self.bitwise_lookup_chip.request_range(val - 1, 0);
487 (idx, val)
488 } else {
489 (NUM_LIMBS, 0)
490 };
491
492 let r_prime_f = r_prime.map(F::from_canonical_u32);
493 let output = AdapterRuntimeContext::without_pc([
494 (if is_div { &q } else { &r }).map(F::from_canonical_u32)
495 ]);
496 let record = DivRemCoreRecord {
497 opcode: divrem_opcode,
498 b: data[0],
499 c: data[1],
500 q: q.map(F::from_canonical_u32),
501 r: r.map(F::from_canonical_u32),
502 zero_divisor: F::from_bool(case == DivRemCoreSpecialCase::ZeroDivisor),
503 r_zero: F::from_bool(r_zero),
504 b_sign: F::from_bool(b_sign),
505 c_sign: F::from_bool(c_sign),
506 q_sign: F::from_bool(q_sign),
507 sign_xor: F::from_bool(sign_xor),
508 c_sum_inv: c_sum_inv_f,
509 r_sum_inv: r_sum_inv_f,
510 r_prime: r_prime_f,
511 r_inv: r_prime_f.map(|r| (r - F::from_canonical_u32(256)).inverse()),
512 lt_diff_val: F::from_canonical_u32(lt_diff_val),
513 lt_diff_idx,
514 };
515
516 Ok((output, record))
517 }
518
519 fn get_opcode_name(&self, opcode: usize) -> String {
520 format!("{:?}", DivRemOpcode::from_usize(opcode - self.air.offset))
521 }
522
523 fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
524 let row_slice: &mut DivRemCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut();
525 row_slice.b = record.b;
526 row_slice.c = record.c;
527 row_slice.q = record.q;
528 row_slice.r = record.r;
529 row_slice.zero_divisor = record.zero_divisor;
530 row_slice.r_zero = record.r_zero;
531 row_slice.b_sign = record.b_sign;
532 row_slice.c_sign = record.c_sign;
533 row_slice.q_sign = record.q_sign;
534 row_slice.sign_xor = record.sign_xor;
535 row_slice.c_sum_inv = record.c_sum_inv;
536 row_slice.r_sum_inv = record.r_sum_inv;
537 row_slice.r_prime = record.r_prime;
538 row_slice.r_inv = record.r_inv;
539 row_slice.lt_marker = array::from_fn(|i| F::from_bool(i == record.lt_diff_idx));
540 row_slice.lt_diff = record.lt_diff_val;
541 row_slice.opcode_div_flag = F::from_bool(record.opcode == DivRemOpcode::DIV);
542 row_slice.opcode_divu_flag = F::from_bool(record.opcode == DivRemOpcode::DIVU);
543 row_slice.opcode_rem_flag = F::from_bool(record.opcode == DivRemOpcode::REM);
544 row_slice.opcode_remu_flag = F::from_bool(record.opcode == DivRemOpcode::REMU);
545 }
546
547 fn air(&self) -> &Self::Air {
548 &self.air
549 }
550}
551
552pub(super) fn run_divrem<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
555 signed: bool,
556 x: &[u32; NUM_LIMBS],
557 y: &[u32; NUM_LIMBS],
558) -> (
559 [u32; NUM_LIMBS],
560 [u32; NUM_LIMBS],
561 bool,
562 bool,
563 bool,
564 DivRemCoreSpecialCase,
565) {
566 let x_sign = signed && (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1);
567 let y_sign = signed && (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1);
568 let max_limb = (1 << LIMB_BITS) - 1;
569
570 let zero_divisor = y.iter().all(|val| *val == 0);
571 let overflow = x[NUM_LIMBS - 1] == 1 << (LIMB_BITS - 1)
572 && x[..(NUM_LIMBS - 1)].iter().all(|val| *val == 0)
573 && y.iter().all(|val| *val == max_limb)
574 && x_sign
575 && y_sign;
576
577 if zero_divisor {
578 return (
579 [max_limb; NUM_LIMBS],
580 *x,
581 x_sign,
582 y_sign,
583 signed,
584 DivRemCoreSpecialCase::ZeroDivisor,
585 );
586 } else if overflow {
587 return (
588 *x,
589 [0; NUM_LIMBS],
590 x_sign,
591 y_sign,
592 false,
593 DivRemCoreSpecialCase::SignedOverflow,
594 );
595 }
596
597 let x_abs = if x_sign {
598 negate::<NUM_LIMBS, LIMB_BITS>(x)
599 } else {
600 *x
601 };
602 let y_abs = if y_sign {
603 negate::<NUM_LIMBS, LIMB_BITS>(y)
604 } else {
605 *y
606 };
607
608 let x_big = limbs_to_biguint::<NUM_LIMBS, LIMB_BITS>(&x_abs);
609 let y_big = limbs_to_biguint::<NUM_LIMBS, LIMB_BITS>(&y_abs);
610 let q_big = x_big.clone() / y_big.clone();
611 let r_big = x_big.clone() % y_big.clone();
612
613 let q = if x_sign ^ y_sign {
614 negate::<NUM_LIMBS, LIMB_BITS>(&biguint_to_limbs::<NUM_LIMBS, LIMB_BITS>(&q_big))
615 } else {
616 biguint_to_limbs::<NUM_LIMBS, LIMB_BITS>(&q_big)
617 };
618 let q_sign = signed && (q[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1);
619
620 let r = if x_sign {
622 negate::<NUM_LIMBS, LIMB_BITS>(&biguint_to_limbs::<NUM_LIMBS, LIMB_BITS>(&r_big))
623 } else {
624 biguint_to_limbs::<NUM_LIMBS, LIMB_BITS>(&r_big)
625 };
626
627 (q, r, x_sign, y_sign, q_sign, DivRemCoreSpecialCase::None)
628}
629
630pub(super) fn run_sltu_diff_idx<const NUM_LIMBS: usize>(
631 x: &[u32; NUM_LIMBS],
632 y: &[u32; NUM_LIMBS],
633 cmp: bool,
634) -> usize {
635 for i in (0..NUM_LIMBS).rev() {
636 if x[i] != y[i] {
637 assert!((x[i] < y[i]) == cmp);
638 return i;
639 }
640 }
641 assert!(!cmp);
642 NUM_LIMBS
643}
644
645pub(super) fn run_mul_carries<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
647 signed: bool,
648 d: &[u32; NUM_LIMBS],
649 q: &[u32; NUM_LIMBS],
650 r: &[u32; NUM_LIMBS],
651 q_sign: bool,
652) -> Vec<u32> {
653 let mut carry = vec![0u32; 2 * NUM_LIMBS];
654 for i in 0..NUM_LIMBS {
655 let mut val = r[i] + if i > 0 { carry[i - 1] } else { 0 };
656 for j in 0..=i {
657 val += d[j] * q[i - j];
658 }
659 carry[i] = val >> LIMB_BITS;
660 }
661
662 let q_ext = if q_sign && signed {
663 (1 << LIMB_BITS) - 1
664 } else {
665 0
666 };
667 let d_ext =
668 (d[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) * if signed { (1 << LIMB_BITS) - 1 } else { 0 };
669 let r_ext =
670 (r[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) * if signed { (1 << LIMB_BITS) - 1 } else { 0 };
671 let mut d_prefix = 0;
672 let mut q_prefix = 0;
673
674 for i in 0..NUM_LIMBS {
675 d_prefix += d[i];
676 q_prefix += q[i];
677 let mut val = carry[NUM_LIMBS + i - 1] + d_prefix * q_ext + q_prefix * d_ext + r_ext;
678 for j in (i + 1)..NUM_LIMBS {
679 val += d[j] * q[NUM_LIMBS + i - j];
680 }
681 carry[NUM_LIMBS + i] = val >> LIMB_BITS;
682 }
683 carry
684}
685
686fn limbs_to_biguint<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
687 x: &[u32; NUM_LIMBS],
688) -> BigUint {
689 let base = BigUint::new(vec![1 << LIMB_BITS]);
690 let mut res = BigUint::new(vec![0]);
691 for val in x.iter().rev() {
692 res *= base.clone();
693 res += BigUint::new(vec![*val]);
694 }
695 res
696}
697
698fn biguint_to_limbs<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
699 x: &BigUint,
700) -> [u32; NUM_LIMBS] {
701 let mut res = [0; NUM_LIMBS];
702 let mut x = x.clone();
703 let base = BigUint::from(1u32 << LIMB_BITS);
704 for limb in res.iter_mut() {
705 let (quot, rem) = x.div_rem(&base);
706 *limb = rem.iter_u32_digits().next().unwrap_or(0);
707 x = quot;
708 }
709 debug_assert_eq!(x, BigUint::from(0u32));
710 res
711}
712
/// Two's-complement negation of a little-endian limb array, modulo `2^(NUM_LIMBS * LIMB_BITS)`.
fn negate<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let base: u32 = 1 << LIMB_BITS;
    let mut out = [0u32; NUM_LIMBS];
    // Computes !x + 1: each digit is (base - 1 - limb) plus the incoming carry, which starts at 1.
    let mut carry_in = 1;
    for (dst, &limb) in out.iter_mut().zip(x.iter()) {
        let digit = base + carry_in - 1 - limb;
        carry_in = digit >> LIMB_BITS;
        *dst = digit % base;
    }
    out
}