openvm_bigint_guest/
u256.rs

1use core::{
2    cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd},
3    ops::{
4        Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Mul,
5        MulAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
6    },
7};
8
9use serde::{Deserialize, Serialize};
10use serde_big_array::BigArray;
11#[cfg(target_os = "zkvm")]
12use {
13    super::{Int256Funct7, BEQ256_FUNCT3, INT256_FUNCT3, OPCODE},
14    core::{arch::asm, mem::MaybeUninit},
15    openvm_platform::custom_insn_r,
16};
17#[cfg(not(target_os = "zkvm"))]
18use {num_bigint::BigUint, num_traits::One, openvm::utils::biguint_to_limbs};
19
20use crate::impl_bin_op;
21
22/// A 256-bit unsigned integer type.
23#[derive(Debug, Serialize, Deserialize)]
24#[repr(align(32), C)]
25pub struct U256 {
26    #[serde(with = "BigArray")]
27    limbs: [u8; 32],
28}
29
30impl U256 {
31    /// The maximum value of a U256.
32    pub const MAX: Self = Self {
33        limbs: [u8::MAX; 32],
34    };
35
36    /// The minimum value of a U256.
37    pub const MIN: Self = Self { limbs: [0u8; 32] };
38
39    /// The zero constant.
40    pub const ZERO: Self = Self { limbs: [0u8; 32] };
41
42    /// Construct [U256] from little-endian bytes.
43    pub const fn from_le_bytes(bytes: [u8; 32]) -> Self {
44        Self { limbs: bytes }
45    }
46
47    /// Value of this U256 as a BigUint.
48    #[cfg(not(target_os = "zkvm"))]
49    pub fn as_biguint(&self) -> BigUint {
50        BigUint::from_bytes_le(&self.limbs)
51    }
52
53    /// Creates a new U256 from a BigUint.
54    #[cfg(not(target_os = "zkvm"))]
55    pub fn from_biguint(value: &BigUint) -> Self {
56        Self {
57            limbs: biguint_to_limbs(value),
58        }
59    }
60
61    /// Creates a new U256 that equals to the given u8 value.
62    pub fn from_u8(value: u8) -> Self {
63        let mut limbs = [0u8; 32];
64        limbs[0] = value;
65        Self { limbs }
66    }
67
68    /// Creates a new U256 that equals to the given u32 value.
69    pub fn from_u32(value: u32) -> Self {
70        let mut limbs = [0u8; 32];
71        limbs[..4].copy_from_slice(&value.to_le_bytes());
72        Self { limbs }
73    }
74
75    /// Creates a new U256 that equals to the given u64 value.
76    pub fn from_u64(value: u64) -> Self {
77        let mut limbs = [0u8; 32];
78        limbs[..8].copy_from_slice(&value.to_le_bytes());
79        Self { limbs }
80    }
81
82    /// The little-endian byte representation of this U256.
83    pub fn as_le_bytes(&self) -> &[u8; 32] {
84        &self.limbs
85    }
86}
87
88impl_bin_op!(
89    U256,
90    Add,
91    AddAssign,
92    add,
93    add_assign,
94    OPCODE,
95    INT256_FUNCT3,
96    Int256Funct7::Add as u8,
97    +=,
98    |lhs: &U256, rhs: &U256| -> U256 {U256::from_biguint(&(lhs.as_biguint() + rhs.as_biguint()))}
99);
100
101impl_bin_op!(
102    U256,
103    Sub,
104    SubAssign,
105    sub,
106    sub_assign,
107    OPCODE,
108    INT256_FUNCT3,
109    Int256Funct7::Sub as u8,
110    -=,
111    |lhs: &U256, rhs: &U256| -> U256 {U256::from_biguint(&(U256::MAX.as_biguint() + BigUint::one() + lhs.as_biguint() - rhs.as_biguint()))}
112);
113
114impl_bin_op!(
115    U256,
116    Mul,
117    MulAssign,
118    mul,
119    mul_assign,
120    OPCODE,
121    INT256_FUNCT3,
122    Int256Funct7::Mul as u8,
123    *=,
124    |lhs: &U256, rhs: &U256| -> U256 {U256::from_biguint(&(lhs.as_biguint() * rhs.as_biguint()))}
125);
126
127impl_bin_op!(
128    U256,
129    BitXor,
130    BitXorAssign,
131    bitxor,
132    bitxor_assign,
133    OPCODE,
134    INT256_FUNCT3,
135    Int256Funct7::Xor as u8,
136    ^=,
137    |lhs: &U256, rhs: &U256| -> U256 {U256::from_biguint(&(lhs.as_biguint() ^ rhs.as_biguint()))}
138);
139
140impl_bin_op!(
141    U256,
142    BitAnd,
143    BitAndAssign,
144    bitand,
145    bitand_assign,
146    OPCODE,
147    INT256_FUNCT3,
148    Int256Funct7::And as u8,
149    &=,
150    |lhs: &U256, rhs: &U256| -> U256 {U256::from_biguint(&(lhs.as_biguint() & rhs.as_biguint()))}
151);
152
153impl_bin_op!(
154    U256,
155    BitOr,
156    BitOrAssign,
157    bitor,
158    bitor_assign,
159    OPCODE,
160    INT256_FUNCT3,
161    Int256Funct7::Or as u8,
162    |=,
163    |lhs: &U256, rhs: &U256| -> U256 {U256::from_biguint(&(lhs.as_biguint() | rhs.as_biguint()))}
164);
165
166impl_bin_op!(
167    U256,
168    Shl,
169    ShlAssign,
170    shl,
171    shl_assign,
172    OPCODE,
173    INT256_FUNCT3,
174    Int256Funct7::Sll as u8,
175    <<=,
176    |lhs: &U256, rhs: &U256| -> U256 {U256::from_biguint(&(lhs.as_biguint() << rhs.limbs[0] as usize))}
177);
178
179impl_bin_op!(
180    U256,
181    Shr,
182    ShrAssign,
183    shr,
184    shr_assign,
185    OPCODE,
186    INT256_FUNCT3,
187    Int256Funct7::Srl as u8,
188    >>=,
189    |lhs: &U256, rhs: &U256| -> U256 {U256::from_biguint(&(lhs.as_biguint() >> rhs.limbs[0] as usize))}
190);
191
192impl PartialEq for U256 {
193    fn eq(&self, other: &Self) -> bool {
194        #[cfg(target_os = "zkvm")]
195        {
196            let mut is_equal: u32;
197            unsafe {
198                asm!("li {res}, 1",
199                    ".insn b {opcode}, {func3}, {rs1}, {rs2}, 8",
200                    "li {res}, 0",
201                    opcode = const OPCODE,
202                    func3 = const BEQ256_FUNCT3,
203                    rs1 = in(reg) self as *const Self,
204                    rs2 = in(reg) other as *const Self,
205                    res = out(reg) is_equal
206                );
207            }
208            return is_equal == 1;
209        }
210        #[cfg(not(target_os = "zkvm"))]
211        return self.as_biguint() == other.as_biguint();
212    }
213}
214
215impl Eq for U256 {}
216
217impl PartialOrd for U256 {
218    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
219        Some(self.cmp(other))
220    }
221}
222
223impl Ord for U256 {
224    fn cmp(&self, other: &Self) -> Ordering {
225        #[cfg(target_os = "zkvm")]
226        {
227            let mut cmp_result = MaybeUninit::<U256>::uninit();
228            custom_insn_r!(
229                opcode = OPCODE,
230                funct3 = INT256_FUNCT3,
231                funct7 = Int256Funct7::Sltu as u8,
232                rd = In cmp_result.as_mut_ptr(),
233                rs1 = In self as *const Self,
234                rs2 = In other as *const Self
235            );
236            let mut cmp_result = unsafe { cmp_result.assume_init() };
237            if cmp_result.limbs[0] != 0 {
238                return Ordering::Less;
239            }
240            custom_insn_r!(
241                opcode = OPCODE,
242                funct3 = INT256_FUNCT3,
243                funct7 = Int256Funct7::Sltu as u8,
244                rd = In &mut cmp_result as *mut U256,
245                rs1 = In other as *const Self,
246                rs2 = In self as *const Self
247            );
248            if cmp_result.limbs[0] != 0 {
249                return Ordering::Greater;
250            }
251            return Ordering::Equal;
252        }
253        #[cfg(not(target_os = "zkvm"))]
254        return self.as_biguint().cmp(&other.as_biguint());
255    }
256}
257
258impl Clone for U256 {
259    fn clone(&self) -> Self {
260        #[cfg(target_os = "zkvm")]
261        {
262            let mut uninit: MaybeUninit<Self> = MaybeUninit::uninit();
263            custom_insn_r!(
264                opcode = OPCODE,
265                funct3 = INT256_FUNCT3,
266                funct7 = Int256Funct7::Add as u8,
267                rd = In uninit.as_mut_ptr(),
268                rs1 = In self as *const Self,
269                rs2 = In &Self::ZERO as *const Self
270            );
271            unsafe { uninit.assume_init() }
272        }
273        #[cfg(not(target_os = "zkvm"))]
274        return Self { limbs: self.limbs };
275    }
276}