openvm_circuit_primitives/bigint/
mod.rs

1use std::{
2    cmp::{max, min},
3    ops::{Add, AddAssign, Mul, MulAssign, Sub},
4};
5
6use num_bigint::BigUint;
7use openvm_stark_backend::p3_util::log2_ceil_usize;
8
9pub mod check_carry_mod_to_zero;
10pub mod check_carry_to_zero;
11pub mod utils;
12
/// A big integer held as a polynomial in the limb base, whose limbs may exceed the
/// canonical limb range ("overflow").
///
/// The limb type `T` can be a plain integer or a symbolic expression (for example
/// `AB::Expr`). When an `OverflowInt` represents a product `x * y`, its limbs are
/// `a_0 = x_0 * y_0`, `a_1 = x_0 * y_1 + x_1 * y_0`, and so on.
#[derive(Debug, Clone)]
pub struct OverflowInt<T> {
    /// Limbs `[a_0, a_1, a_2, ...]`, standing for `a_0 + a_1 x + a_2 x^2 + ...`.
    limbs: Vec<T>,

    /// Upper bound on the absolute value of any limb; propagated through arithmetic.
    limb_max_abs: usize,

    /// Every limb should lie in `[-2^max_overflow_bits, 2^max_overflow_bits)`.
    max_overflow_bits: usize,
}
27
28impl<T> OverflowInt<T> {
29    // Note: sign or unsigned are not about the type T.
30    // It's how we will range check the limbs. If the limbs are non-negative, use this one.
31    pub fn from_canonical_unsigned_limbs(x: Vec<T>, limb_bits: usize) -> OverflowInt<T> {
32        OverflowInt {
33            limbs: x,
34            max_overflow_bits: limb_bits,
35            limb_max_abs: (1 << limb_bits) - 1,
36        }
37    }
38
39    // Limbs can be negative. So the max_overflow_bits and limb_max_abs are different from the range check result.
40    pub fn from_canonical_signed_limbs(x: Vec<T>, limb_bits: usize) -> OverflowInt<T> {
41        OverflowInt {
42            limbs: x,
43            max_overflow_bits: limb_bits + 1,
44            limb_max_abs: (1 << limb_bits),
45        }
46    }
47
48    // Used only when limbs are hand calculated.
49    pub fn from_computed_limbs(
50        x: Vec<T>,
51        limb_max_abs: usize,
52        max_overflow_bits: usize,
53    ) -> OverflowInt<T> {
54        OverflowInt {
55            limbs: x,
56            max_overflow_bits,
57            limb_max_abs,
58        }
59    }
60
61    pub fn max_overflow_bits(&self) -> usize {
62        self.max_overflow_bits
63    }
64
65    pub fn limb_max_abs(&self) -> usize {
66        self.limb_max_abs
67    }
68
69    pub fn num_limbs(&self) -> usize {
70        self.limbs.len()
71    }
72
73    pub fn limb(&self, i: usize) -> &T {
74        self.limbs.get(i).unwrap()
75    }
76
77    pub fn limbs(&self) -> &[T] {
78        &self.limbs
79    }
80}
81
82impl<T> OverflowInt<T>
83where
84    T: Clone + AddAssign + MulAssign,
85{
86    pub fn int_add(&self, s: isize, convert: fn(isize) -> T) -> OverflowInt<T> {
87        let mut limbs = self.limbs.clone();
88        limbs[0] += convert(s);
89        let limb_max_abs = self.limb_max_abs + s.unsigned_abs();
90        OverflowInt {
91            limbs,
92            limb_max_abs,
93            max_overflow_bits: log2_ceil_usize(limb_max_abs),
94        }
95    }
96
97    pub fn int_mul(&self, s: isize, convert: fn(isize) -> T) -> OverflowInt<T> {
98        let mut limbs = self.limbs.clone();
99        for limb in limbs.iter_mut() {
100            *limb *= convert(s);
101        }
102        let limb_max_abs = self.limb_max_abs * s.unsigned_abs();
103        OverflowInt {
104            limbs,
105            limb_max_abs,
106            max_overflow_bits: log2_ceil_usize(limb_max_abs),
107        }
108    }
109}
110
111impl OverflowInt<isize> {
112    pub fn from_biguint(
113        x: &BigUint,
114        limb_bits: usize,
115        min_limbs: Option<usize>,
116    ) -> OverflowInt<isize> {
117        let limbs = match min_limbs {
118            Some(min_limbs) => utils::big_uint_to_num_limbs(x, limb_bits, min_limbs),
119            None => utils::big_uint_to_limbs(x, limb_bits),
120        };
121        let limbs: Vec<isize> = limbs.iter().map(|x| *x as isize).collect();
122        OverflowInt::from_canonical_unsigned_limbs(limbs, limb_bits)
123    }
124
125    pub fn calculate_carries(&self, limb_bits: usize) -> Vec<isize> {
126        let mut carries = Vec::with_capacity(self.limbs.len());
127
128        let mut carry = 0;
129        for i in 0..self.limbs.len() {
130            carry = (carry + self.limbs[i]) >> limb_bits;
131            carries.push(carry);
132        }
133        carries
134    }
135}
136
137impl<T> Add for OverflowInt<T>
138where
139    T: Add<Output = T> + Clone + Default,
140{
141    type Output = OverflowInt<T>;
142
143    fn add(self, other: OverflowInt<T>) -> OverflowInt<T> {
144        let len = max(self.limbs.len(), other.limbs.len());
145        let mut limbs = Vec::with_capacity(len);
146        let zero = T::default();
147        for i in 0..len {
148            let a = self.limbs.get(i).unwrap_or(&zero);
149            let b = other.limbs.get(i).unwrap_or(&zero);
150            limbs.push(a.clone() + b.clone());
151        }
152        let new_max = self.limb_max_abs + other.limb_max_abs;
153        let max_bits = log2_ceil_usize(new_max);
154        OverflowInt {
155            limbs,
156            max_overflow_bits: max_bits,
157            limb_max_abs: new_max,
158        }
159    }
160}
161
162impl<T> Sub for OverflowInt<T>
163where
164    T: Sub<Output = T> + Clone + Default,
165{
166    type Output = OverflowInt<T>;
167
168    fn sub(self, other: OverflowInt<T>) -> OverflowInt<T> {
169        let len = max(self.limbs.len(), other.limbs.len());
170        let mut limbs = Vec::with_capacity(len);
171        for i in 0..len {
172            let zero = T::default();
173            let a = self.limbs.get(i).unwrap_or(&zero);
174            let b = other.limbs.get(i).unwrap_or(&zero);
175            limbs.push(a.clone() - b.clone());
176        }
177        let new_max = self.limb_max_abs + other.limb_max_abs;
178        let max_bits = log2_ceil_usize(new_max);
179        OverflowInt {
180            limbs,
181            max_overflow_bits: max_bits,
182            limb_max_abs: new_max,
183        }
184    }
185}
186
187impl<T> Mul for OverflowInt<T>
188where
189    T: Add<Output = T> + Mul<Output = T> + Clone + Default,
190{
191    type Output = OverflowInt<T>;
192
193    fn mul(self, other: OverflowInt<T>) -> OverflowInt<T> {
194        let len = self.limbs.len() + other.limbs.len() - 1;
195        let mut limbs = vec![T::default(); len];
196        for i in 0..self.limbs.len() {
197            for j in 0..other.limbs.len() {
198                // += doesn't work for T.
199                limbs[i + j] =
200                    limbs[i + j].clone() + self.limbs[i].clone() * other.limbs[j].clone();
201            }
202        }
203        let new_max =
204            self.limb_max_abs * other.limb_max_abs * min(self.limbs.len(), other.limbs.len());
205        let max_bits = log2_ceil_usize(new_max);
206        OverflowInt {
207            limbs,
208            max_overflow_bits: max_bits,
209            limb_max_abs: new_max,
210        }
211    }
212}