1use crate::{
2 algorithms::{borrowing_sub, carrying_add},
3 Uint,
4};
5use core::{
6 iter::Sum,
7 ops::{Add, AddAssign, Neg, Sub, SubAssign},
8};
9
10impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
11 #[inline(always)]
15 #[must_use]
16 pub fn abs_diff(self, other: Self) -> Self {
17 if self < other {
18 other.wrapping_sub(self)
19 } else {
20 self.wrapping_sub(other)
21 }
22 }
23
24 #[inline(always)]
26 #[must_use]
27 pub const fn checked_add(self, rhs: Self) -> Option<Self> {
28 match self.overflowing_add(rhs) {
29 (value, false) => Some(value),
30 _ => None,
31 }
32 }
33
34 #[inline(always)]
36 #[must_use]
37 pub const fn checked_neg(self) -> Option<Self> {
38 match self.overflowing_neg() {
39 (value, false) => Some(value),
40 _ => None,
41 }
42 }
43
44 #[inline(always)]
46 #[must_use]
47 pub const fn checked_sub(self, rhs: Self) -> Option<Self> {
48 match self.overflowing_sub(rhs) {
49 (value, false) => Some(value),
50 _ => None,
51 }
52 }
53
54 #[inline]
60 #[must_use]
61 pub const fn overflowing_add(mut self, rhs: Self) -> (Self, bool) {
62 if BITS == 0 {
63 return (Self::ZERO, false);
64 }
65 let mut carry = false;
66 let mut i = 0;
67 while i < LIMBS {
68 (self.limbs[i], carry) = carrying_add(self.limbs[i], rhs.limbs[i], carry);
69 i += 1;
70 }
71 let overflow = carry | (self.limbs[LIMBS - 1] > Self::MASK);
72 (self.masked(), overflow)
73 }
74
75 #[inline(always)]
82 #[must_use]
83 pub const fn overflowing_neg(self) -> (Self, bool) {
84 Self::ZERO.overflowing_sub(self)
85 }
86
87 #[inline]
93 #[must_use]
94 pub const fn overflowing_sub(mut self, rhs: Self) -> (Self, bool) {
95 if BITS == 0 {
96 return (Self::ZERO, false);
97 }
98 let mut borrow = false;
99 let mut i = 0;
100 while i < LIMBS {
101 (self.limbs[i], borrow) = borrowing_sub(self.limbs[i], rhs.limbs[i], borrow);
102 i += 1;
103 }
104 let overflow = borrow | (self.limbs[LIMBS - 1] > Self::MASK);
105 (self.masked(), overflow)
106 }
107
108 #[inline(always)]
111 #[must_use]
112 pub const fn saturating_add(self, rhs: Self) -> Self {
113 match self.overflowing_add(rhs) {
114 (value, false) => value,
115 _ => Self::MAX,
116 }
117 }
118
119 #[inline(always)]
122 #[must_use]
123 pub const fn saturating_sub(self, rhs: Self) -> Self {
124 match self.overflowing_sub(rhs) {
125 (value, false) => value,
126 _ => Self::ZERO,
127 }
128 }
129
130 #[cfg(not(target_os = "zkvm"))]
132 #[inline(always)]
133 #[must_use]
134 pub const fn wrapping_add(self, rhs: Self) -> Self {
135 self.overflowing_add(rhs).0
136 }
137
138 #[cfg(target_os = "zkvm")]
140 #[inline(always)]
141 #[must_use]
142 pub fn wrapping_add(mut self, rhs: Self) -> Self {
143 use crate::support::zkvm::zkvm_u256_wrapping_add_impl;
144 if BITS == 256 {
145 unsafe {
146 zkvm_u256_wrapping_add_impl(
147 self.limbs.as_mut_ptr() as *mut u8,
148 self.limbs.as_ptr() as *const u8,
149 rhs.limbs.as_ptr() as *const u8,
150 );
151 }
152 return self;
153 }
154 self.overflowing_add(rhs).0
155 }
156
157 #[cfg(not(target_os = "zkvm"))]
159 #[inline(always)]
160 #[must_use]
161 pub const fn wrapping_neg(self) -> Self {
162 self.overflowing_neg().0
163 }
164
165 #[cfg(target_os = "zkvm")]
167 #[inline(always)]
168 #[must_use]
169 pub fn wrapping_neg(self) -> Self {
170 Self::ZERO.wrapping_sub(self)
171 }
172
173 #[cfg(not(target_os = "zkvm"))]
175 #[inline(always)]
176 #[must_use]
177 pub const fn wrapping_sub(self, rhs: Self) -> Self {
178 self.overflowing_sub(rhs).0
179 }
180
181 #[cfg(target_os = "zkvm")]
183 #[inline(always)]
184 #[must_use]
185 pub fn wrapping_sub(mut self, rhs: Self) -> Self {
186 use crate::support::zkvm::zkvm_u256_wrapping_sub_impl;
187 if BITS == 256 {
188 unsafe {
189 zkvm_u256_wrapping_sub_impl(
190 self.limbs.as_mut_ptr() as *mut u8,
191 self.limbs.as_ptr() as *const u8,
192 rhs.limbs.as_ptr() as *const u8,
193 );
194 }
195 return self;
196 }
197 self.overflowing_sub(rhs).0
198 }
199}
200
201impl<const BITS: usize, const LIMBS: usize> Neg for Uint<BITS, LIMBS> {
202 type Output = Self;
203
204 #[inline(always)]
205 fn neg(self) -> Self::Output {
206 self.wrapping_neg()
207 }
208}
209
210impl<const BITS: usize, const LIMBS: usize> Neg for &Uint<BITS, LIMBS> {
211 type Output = Uint<BITS, LIMBS>;
212
213 #[inline(always)]
214 fn neg(self) -> Self::Output {
215 self.wrapping_neg()
216 }
217}
218
219impl<const BITS: usize, const LIMBS: usize> Sum<Self> for Uint<BITS, LIMBS> {
220 #[inline]
221 fn sum<I>(iter: I) -> Self
222 where
223 I: Iterator<Item = Self>,
224 {
225 iter.fold(Self::ZERO, Self::wrapping_add)
226 }
227}
228
229impl<'a, const BITS: usize, const LIMBS: usize> Sum<&'a Self> for Uint<BITS, LIMBS> {
230 #[inline]
231 fn sum<I>(iter: I) -> Self
232 where
233 I: Iterator<Item = &'a Self>,
234 {
235 iter.copied().fold(Self::ZERO, Self::wrapping_add)
236 }
237}
238
239impl_bin_op!(Add, add, AddAssign, add_assign, wrapping_add);
240impl_bin_op!(Sub, sub, SubAssign, sub_assign, wrapping_sub);
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{const_for, nlimbs};
    use proptest::proptest;

    /// `-1` must equal the all-ones bit pattern at every non-zero width.
    #[test]
    fn test_neg_one() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let all_ones = !U::ZERO;
            assert_eq!(-U::ONE, all_ones);
        });
    }

    /// Addition commutes, and swapping subtraction operands negates the result.
    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(x: U, y: U)| {
                assert_eq!(x + y, y + x);
                let forward = x - y;
                let backward = y - x;
                assert_eq!(forward, -backward);
            });
        });
    }

    /// Wrapping addition is associative.
    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(x: U, y: U, z: U)| {
                let grouped_right = x + (y + z);
                let grouped_left = (x + y) + z;
                assert_eq!(grouped_right, grouped_left);
            });
        });
    }

    /// Zero is the identity for both `+` and `-`.
    #[test]
    fn test_identity() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(v: U)| {
                assert_eq!(v + U::ZERO, v);
                assert_eq!(v - U::ZERO, v);
            });
        });
    }

    /// Negation is the additive inverse and is its own inverse.
    #[test]
    fn test_inverse() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(x: U)| {
                let negated = -x;
                assert_eq!(x + negated, U::ZERO);
                assert_eq!(x - x, U::ZERO);
                assert_eq!(-negated, x);
            });
        });
    }
}