openvm_circuit_primitives/bigint/mod.rs

1use std::{
2    cmp::{max, min},
3    ops::{Add, AddAssign, Mul, MulAssign, Sub},
4};
5
6use num_bigint::BigUint;
7use openvm_stark_backend::p3_util::log2_ceil_usize;
8
9pub mod check_carry_mod_to_zero;
10pub mod check_carry_to_zero;
11pub mod utils;
12
/// A big integer stored as a vector of (possibly unreduced) limbs plus magnitude bounds.
///
/// The represented value is `limbs[0] + limbs[1] * X + limbs[2] * X^2 + ...`, where `X` is
/// the limb radix (`2^limb_bits`). `T` need not be a plain number: it can be a symbolic
/// expression such as `AB::Expr`. For example, when the integer is a product `x * y`, the
/// limbs are `a0 = x0 * y0`, `a1 = x0 * y1 + x1 * y0`, and so on. Limbs are therefore allowed
/// to overflow the radix; the bounds below are what make further arithmetic on them safe.
#[derive(Debug, Clone)]
pub struct OverflowInt<T> {
    /// Limb coefficients, least significant first.
    limbs: Vec<T>,

    /// Upper bound on the absolute value of every limb, propagated through arithmetic.
    limb_max_abs: usize,

    /// Every limb is expected to lie in `[-2^max_overflow_bits, 2^max_overflow_bits)`.
    max_overflow_bits: usize,
}
27
28impl<T> OverflowInt<T> {
29    // Note: sign or unsigned are not about the type T.
30    // It's how we will range check the limbs. If the limbs are non-negative, use this one.
31    pub fn from_canonical_unsigned_limbs(x: Vec<T>, limb_bits: usize) -> OverflowInt<T> {
32        OverflowInt {
33            limbs: x,
34            max_overflow_bits: limb_bits,
35            limb_max_abs: (1 << limb_bits) - 1,
36        }
37    }
38
39    // Limbs can be negative. So the max_overflow_bits and limb_max_abs are different from the range
40    // check result.
41    pub fn from_canonical_signed_limbs(x: Vec<T>, limb_bits: usize) -> OverflowInt<T> {
42        OverflowInt {
43            limbs: x,
44            max_overflow_bits: limb_bits + 1,
45            limb_max_abs: (1 << limb_bits),
46        }
47    }
48
49    // Used only when limbs are hand calculated.
50    pub fn from_computed_limbs(
51        x: Vec<T>,
52        limb_max_abs: usize,
53        max_overflow_bits: usize,
54    ) -> OverflowInt<T> {
55        OverflowInt {
56            limbs: x,
57            max_overflow_bits,
58            limb_max_abs,
59        }
60    }
61
62    pub fn max_overflow_bits(&self) -> usize {
63        self.max_overflow_bits
64    }
65
66    pub fn limb_max_abs(&self) -> usize {
67        self.limb_max_abs
68    }
69
70    pub fn num_limbs(&self) -> usize {
71        self.limbs.len()
72    }
73
74    pub fn limb(&self, i: usize) -> &T {
75        self.limbs.get(i).unwrap()
76    }
77
78    pub fn limbs(&self) -> &[T] {
79        &self.limbs
80    }
81}
82
83impl<T> OverflowInt<T>
84where
85    T: Clone + AddAssign + MulAssign,
86{
87    pub fn int_add(&self, s: isize, convert: fn(isize) -> T) -> OverflowInt<T> {
88        let mut limbs = self.limbs.clone();
89        limbs[0] += convert(s);
90        let limb_max_abs = self.limb_max_abs + s.unsigned_abs();
91        OverflowInt {
92            limbs,
93            limb_max_abs,
94            max_overflow_bits: log2_ceil_usize(limb_max_abs),
95        }
96    }
97
98    pub fn int_mul(&self, s: isize, convert: fn(isize) -> T) -> OverflowInt<T> {
99        let mut limbs = self.limbs.clone();
100        for limb in limbs.iter_mut() {
101            *limb *= convert(s);
102        }
103        let limb_max_abs = self.limb_max_abs * s.unsigned_abs();
104        OverflowInt {
105            limbs,
106            limb_max_abs,
107            max_overflow_bits: log2_ceil_usize(limb_max_abs),
108        }
109    }
110}
111
112impl OverflowInt<isize> {
113    pub fn from_biguint(
114        x: &BigUint,
115        limb_bits: usize,
116        min_limbs: Option<usize>,
117    ) -> OverflowInt<isize> {
118        let limbs = match min_limbs {
119            Some(min_limbs) => utils::big_uint_to_num_limbs(x, limb_bits, min_limbs),
120            None => utils::big_uint_to_limbs(x, limb_bits),
121        };
122        let limbs: Vec<isize> = limbs.iter().map(|x| *x as isize).collect();
123        OverflowInt::from_canonical_unsigned_limbs(limbs, limb_bits)
124    }
125
126    pub fn calculate_carries(&self, limb_bits: usize) -> Vec<isize> {
127        let mut carries = Vec::with_capacity(self.limbs.len());
128
129        let mut carry = 0;
130        for i in 0..self.limbs.len() {
131            carry = (carry + self.limbs[i]) >> limb_bits;
132            carries.push(carry);
133        }
134        carries
135    }
136}
137
138impl<T> Add for OverflowInt<T>
139where
140    T: Add<Output = T> + Clone + Default,
141{
142    type Output = OverflowInt<T>;
143
144    fn add(self, other: OverflowInt<T>) -> OverflowInt<T> {
145        let len = max(self.limbs.len(), other.limbs.len());
146        let mut limbs = Vec::with_capacity(len);
147        let zero = T::default();
148        for i in 0..len {
149            let a = self.limbs.get(i).unwrap_or(&zero);
150            let b = other.limbs.get(i).unwrap_or(&zero);
151            limbs.push(a.clone() + b.clone());
152        }
153        let new_max = self.limb_max_abs + other.limb_max_abs;
154        let max_bits = log2_ceil_usize(new_max);
155        OverflowInt {
156            limbs,
157            max_overflow_bits: max_bits,
158            limb_max_abs: new_max,
159        }
160    }
161}
162
163impl<T> Sub for OverflowInt<T>
164where
165    T: Sub<Output = T> + Clone + Default,
166{
167    type Output = OverflowInt<T>;
168
169    fn sub(self, other: OverflowInt<T>) -> OverflowInt<T> {
170        let len = max(self.limbs.len(), other.limbs.len());
171        let mut limbs = Vec::with_capacity(len);
172        for i in 0..len {
173            let zero = T::default();
174            let a = self.limbs.get(i).unwrap_or(&zero);
175            let b = other.limbs.get(i).unwrap_or(&zero);
176            limbs.push(a.clone() - b.clone());
177        }
178        let new_max = self.limb_max_abs + other.limb_max_abs;
179        let max_bits = log2_ceil_usize(new_max);
180        OverflowInt {
181            limbs,
182            max_overflow_bits: max_bits,
183            limb_max_abs: new_max,
184        }
185    }
186}
187
188impl<T> Mul for OverflowInt<T>
189where
190    T: Add<Output = T> + Mul<Output = T> + Clone + Default,
191{
192    type Output = OverflowInt<T>;
193
194    fn mul(self, other: OverflowInt<T>) -> OverflowInt<T> {
195        let len = self.limbs.len() + other.limbs.len() - 1;
196        let mut limbs = vec![T::default(); len];
197        for i in 0..self.limbs.len() {
198            for j in 0..other.limbs.len() {
199                // += doesn't work for T.
200                limbs[i + j] =
201                    limbs[i + j].clone() + self.limbs[i].clone() * other.limbs[j].clone();
202            }
203        }
204        let new_max =
205            self.limb_max_abs * other.limb_max_abs * min(self.limbs.len(), other.limbs.len());
206        let max_bits = log2_ceil_usize(new_max);
207        OverflowInt {
208            limbs,
209            max_overflow_bits: max_bits,
210            limb_max_abs: new_max,
211        }
212    }
213}