1use core::{
2 cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd},
3 ops::{
4 Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Mul,
5 MulAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
6 },
7};
8
9use serde::{Deserialize, Serialize};
10use serde_big_array::BigArray;
11#[cfg(not(target_os = "zkvm"))]
12use {super::bigint_to_limbs, num_bigint::BigInt};
13#[cfg(target_os = "zkvm")]
14use {
15 super::{Int256Funct7, BEQ256_FUNCT3, INT256_FUNCT3, OPCODE},
16 core::{arch::asm, mem::MaybeUninit},
17 openvm_platform::custom_insn_r,
18};
19
20use crate::impl_bin_op;
21
22#[derive(Debug, Serialize, Deserialize)]
24#[repr(align(32), C)]
25pub struct I256 {
26 #[serde(with = "BigArray")]
27 limbs: [u8; 32],
28}
29
30impl I256 {
31 pub const MIN: Self = Self::generate_min();
33
34 pub const MAX: Self = Self::generate_max();
36
37 pub const ZERO: Self = Self { limbs: [0u8; 32] };
39
40 pub const fn from_le_bytes(bytes: [u8; 32]) -> Self {
42 Self { limbs: bytes }
43 }
44
45 #[cfg(not(target_os = "zkvm"))]
47 pub fn as_bigint(&self) -> BigInt {
48 BigInt::from_signed_bytes_le(&self.limbs)
49 }
50
51 #[cfg(not(target_os = "zkvm"))]
53 pub fn from_bigint(value: &BigInt) -> Self {
54 Self {
55 limbs: bigint_to_limbs(value),
56 }
57 }
58
59 pub fn from_i8(value: i8) -> Self {
61 let mut limbs = if value < 0 { [u8::MAX; 32] } else { [0u8; 32] };
62 limbs[0] = value as u8;
63 Self { limbs }
64 }
65
66 pub fn from_i32(value: i32) -> Self {
68 let mut limbs = if value < 0 { [u8::MAX; 32] } else { [0u8; 32] };
69 let value = value as u32;
70 limbs[..4].copy_from_slice(&value.to_le_bytes());
71 Self { limbs }
72 }
73
74 pub fn from_i64(value: i64) -> Self {
76 let mut limbs = if value < 0 { [u8::MAX; 32] } else { [0u8; 32] };
77 let value = value as u64;
78 limbs[..8].copy_from_slice(&value.to_le_bytes());
79 Self { limbs }
80 }
81
82 const fn generate_min() -> Self {
84 let mut limbs = [0u8; 32];
85 limbs[31] = i8::MIN as u8;
86 Self { limbs }
87 }
88
89 const fn generate_max() -> Self {
91 let mut limbs = [u8::MAX; 32];
92 limbs[31] = i8::MAX as u8;
93 Self { limbs }
94 }
95}
96
// `Add` / `AddAssign` (`+`, `+=`) for `I256`.
//
// The arguments presumably name the traits and methods `impl_bin_op!` generates, the
// custom-instruction encoding used on the zkVM (`OPCODE`, `INT256_FUNCT3`,
// `Int256Funct7::Add`) and the compound-assignment operator. The trailing closure calls
// `as_bigint`/`from_bigint`, which only exist off the zkVM, so it is presumably the host
// fallback — confirm against the macro definition.
// NOTE(review): the exact sum of two 256-bit values can need 257 bits; wrapping modulo 2^256
// depends on `bigint_to_limbs` truncating its input — confirm.
impl_bin_op!(
    I256,
    Add,
    AddAssign,
    add,
    add_assign,
    OPCODE,
    INT256_FUNCT3,
    Int256Funct7::Add as u8,
    +=,
    |lhs: &I256, rhs: &I256| -> I256 {I256::from_bigint(&(lhs.as_bigint() + rhs.as_bigint()))}
);
109
// `Sub` / `SubAssign` (`-`, `-=`) for `I256`, encoded on the zkVM with `Int256Funct7::Sub`.
//
// The closure uses `as_bigint`/`from_bigint` (host-only), so it is presumably the off-zkVM
// fallback of `impl_bin_op!` — confirm against the macro definition.
// NOTE(review): the exact difference can need 257 bits; wrapping modulo 2^256 depends on
// `bigint_to_limbs` truncating its input — confirm.
impl_bin_op!(
    I256,
    Sub,
    SubAssign,
    sub,
    sub_assign,
    OPCODE,
    INT256_FUNCT3,
    Int256Funct7::Sub as u8,
    -=,
    |lhs: &I256, rhs: &I256| -> I256 {I256::from_bigint(&(lhs.as_bigint() - rhs.as_bigint()))}
);
122
// `Mul` / `MulAssign` (`*`, `*=`) for `I256`, encoded on the zkVM with `Int256Funct7::Mul`.
//
// The closure uses `as_bigint`/`from_bigint` (host-only), so it is presumably the off-zkVM
// fallback of `impl_bin_op!` — confirm against the macro definition.
// NOTE(review): the exact product can need up to 512 bits; keeping only the low 256 bits
// depends on `bigint_to_limbs` truncating its input — confirm.
impl_bin_op!(
    I256,
    Mul,
    MulAssign,
    mul,
    mul_assign,
    OPCODE,
    INT256_FUNCT3,
    Int256Funct7::Mul as u8,
    *=,
    |lhs: &I256, rhs: &I256| -> I256 {I256::from_bigint(&(lhs.as_bigint() * rhs.as_bigint()))}
);
135
// `BitXor` / `BitXorAssign` (`^`, `^=`) for `I256`, encoded on the zkVM with
// `Int256Funct7::Xor`.
//
// The closure uses `as_bigint`/`from_bigint` (host-only), so it is presumably the off-zkVM
// fallback of `impl_bin_op!` — confirm against the macro definition. It relies on num-bigint
// applying bitwise operators to negative `BigInt`s with two's-complement semantics.
impl_bin_op!(
    I256,
    BitXor,
    BitXorAssign,
    bitxor,
    bitxor_assign,
    OPCODE,
    INT256_FUNCT3,
    Int256Funct7::Xor as u8,
    ^=,
    |lhs: &I256, rhs: &I256| -> I256 {I256::from_bigint(&(lhs.as_bigint() ^ rhs.as_bigint()))}
);
148
// `BitAnd` / `BitAndAssign` (`&`, `&=`) for `I256`, encoded on the zkVM with
// `Int256Funct7::And`.
//
// The closure uses `as_bigint`/`from_bigint` (host-only), so it is presumably the off-zkVM
// fallback of `impl_bin_op!` — confirm against the macro definition. It relies on num-bigint
// applying bitwise operators to negative `BigInt`s with two's-complement semantics.
impl_bin_op!(
    I256,
    BitAnd,
    BitAndAssign,
    bitand,
    bitand_assign,
    OPCODE,
    INT256_FUNCT3,
    Int256Funct7::And as u8,
    &=,
    |lhs: &I256, rhs: &I256| -> I256 {I256::from_bigint(&(lhs.as_bigint() & rhs.as_bigint()))}
);
161
// `BitOr` / `BitOrAssign` (`|`, `|=`) for `I256`, encoded on the zkVM with
// `Int256Funct7::Or`.
//
// The closure uses `as_bigint`/`from_bigint` (host-only), so it is presumably the off-zkVM
// fallback of `impl_bin_op!` — confirm against the macro definition. It relies on num-bigint
// applying bitwise operators to negative `BigInt`s with two's-complement semantics.
impl_bin_op!(
    I256,
    BitOr,
    BitOrAssign,
    bitor,
    bitor_assign,
    OPCODE,
    INT256_FUNCT3,
    Int256Funct7::Or as u8,
    |=,
    |lhs: &I256, rhs: &I256| -> I256 {I256::from_bigint(&(lhs.as_bigint() | rhs.as_bigint()))}
);
174
// `Shl` / `ShlAssign` (`<<`, `<<=`) for `I256`, encoded on the zkVM with `Int256Funct7::Sll`.
//
// The right-hand side is a full `I256`, but the host closure reads only its lowest byte
// (`rhs.limbs[0]`) as the shift amount, i.e. the amount is taken modulo 256.
// NOTE(review): confirm that the zkVM `Sll` instruction uses the same shift-amount rule.
// NOTE(review): `<<` on `BigInt` does not drop high bits; keeping only the low 256 bits
// depends on `bigint_to_limbs` truncating its input — confirm.
impl_bin_op!(
    I256,
    Shl,
    ShlAssign,
    shl,
    shl_assign,
    OPCODE,
    INT256_FUNCT3,
    Int256Funct7::Sll as u8,
    <<=,
    |lhs: &I256, rhs: &I256| -> I256 {I256::from_bigint(&(lhs.as_bigint() << rhs.limbs[0] as usize))}
);
187
// `Shr` / `ShrAssign` (`>>`, `>>=`) for `I256`, encoded on the zkVM with `Int256Funct7::Sra`
// (an arithmetic, sign-preserving shift, judging by the name).
//
// The right-hand side is a full `I256`, but the host closure reads only its lowest byte
// (`rhs.limbs[0]`) as the shift amount, i.e. the amount is taken modulo 256.
// NOTE(review): the host fallback matches an arithmetic shift only if num-bigint's `>>` rounds
// negative values toward negative infinity — confirm for the pinned num-bigint version.
impl_bin_op!(
    I256,
    Shr,
    ShrAssign,
    shr,
    shr_assign,
    OPCODE,
    INT256_FUNCT3,
    Int256Funct7::Sra as u8,
    >>=,
    |lhs: &I256, rhs: &I256| -> I256 {I256::from_bigint(&(lhs.as_bigint() >> rhs.limbs[0] as usize))}
);
200
201impl PartialEq for I256 {
202 fn eq(&self, other: &Self) -> bool {
203 #[cfg(target_os = "zkvm")]
204 {
205 let mut is_equal: u32;
206 unsafe {
207 asm!("li {res}, 1",
208 ".insn b {opcode}, {func3}, {rs1}, {rs2}, 8",
209 "li {res}, 0",
210 opcode = const OPCODE,
211 func3 = const BEQ256_FUNCT3,
212 rs1 = in(reg) self as *const Self,
213 rs2 = in(reg) other as *const Self,
214 res = out(reg) is_equal
215 );
216 }
217 return is_equal == 1;
218 }
219 #[cfg(not(target_os = "zkvm"))]
220 return self.as_bigint() == other.as_bigint();
221 }
222}
223
// `PartialEq for I256` compares complete values (limb-for-limb / numerically), which is
// reflexive, symmetric and transitive, so the `Eq` marker is sound.
impl Eq for I256 {}
225
226impl PartialOrd for I256 {
227 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
228 Some(self.cmp(other))
229 }
230}
231
232impl Ord for I256 {
233 fn cmp(&self, other: &Self) -> Ordering {
234 #[cfg(target_os = "zkvm")]
235 {
236 let mut cmp_result = MaybeUninit::<I256>::uninit();
237 custom_insn_r!(
238 opcode = OPCODE,
239 funct3 = INT256_FUNCT3,
240 funct7 = Int256Funct7::Slt as u8,
241 rd = In cmp_result.as_mut_ptr(),
242 rs1 = In self as *const Self,
243 rs2 = In other as *const Self
244 );
245 let mut cmp_result = unsafe { cmp_result.assume_init() };
246 if cmp_result.limbs[0] != 0 {
247 return Ordering::Less;
248 }
249 custom_insn_r!(
250 opcode = OPCODE,
251 funct3 = INT256_FUNCT3,
252 funct7 = Int256Funct7::Slt as u8,
253 rd = In &mut cmp_result as *mut I256,
254 rs1 = In other as *const Self,
255 rs2 = In self as *const Self
256 );
257 if cmp_result.limbs[0] != 0 {
258 return Ordering::Greater;
259 }
260 return Ordering::Equal;
261 }
262 #[cfg(not(target_os = "zkvm"))]
263 return self.as_bigint().cmp(&other.as_bigint());
264 }
265}
266
267impl Clone for I256 {
268 fn clone(&self) -> Self {
269 #[cfg(target_os = "zkvm")]
270 {
271 let mut uninit: MaybeUninit<Self> = MaybeUninit::uninit();
272 custom_insn_r!(
273 opcode = OPCODE,
274 funct3 = INT256_FUNCT3,
275 funct7 = Int256Funct7::Add as u8,
276 rd = In uninit.as_mut_ptr(),
277 rs1 = In self as *const Self,
278 rs2 = In &Self::ZERO as *const Self
279 );
280 unsafe { uninit.assume_init() }
281 }
282 #[cfg(not(target_os = "zkvm"))]
283 return Self { limbs: self.limbs };
284 }
285}