1use crate::{
4 bigint::{prelude::*, Limb, NonZero},
5 scalar::FromUintUnchecked,
6 scalar::IsHigh,
7 Curve, Error, FieldBytes, FieldBytesEncoding, Result,
8};
9use base16ct::HexDisplay;
10use core::{
11 cmp::Ordering,
12 fmt,
13 ops::{Add, AddAssign, Neg, ShrAssign, Sub, SubAssign},
14 str,
15};
16use generic_array::{typenum::Unsigned, GenericArray};
17use rand_core::CryptoRngCore;
18use subtle::{
19 Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess,
20 CtOption,
21};
22use zeroize::DefaultIsZeroes;
23
24#[cfg(feature = "arithmetic")]
25use super::{CurveArithmetic, Scalar};
26
27#[cfg(feature = "serde")]
28use serdect::serde::{de, ser, Deserialize, Serialize};
29
30#[derive(Copy, Clone, Debug, Default)]
45pub struct ScalarPrimitive<C: Curve> {
46 inner: C::Uint,
48}
49
50impl<C> ScalarPrimitive<C>
51where
52 C: Curve,
53{
54 pub const ZERO: Self = Self {
56 inner: C::Uint::ZERO,
57 };
58
59 pub const ONE: Self = Self {
61 inner: C::Uint::ONE,
62 };
63
64 pub const MODULUS: C::Uint = C::ORDER;
66
67 pub fn random(rng: &mut impl CryptoRngCore) -> Self {
69 Self {
70 inner: C::Uint::random_mod(rng, &NonZero::new(Self::MODULUS).unwrap()),
71 }
72 }
73
74 pub fn new(uint: C::Uint) -> CtOption<Self> {
76 CtOption::new(Self { inner: uint }, uint.ct_lt(&Self::MODULUS))
77 }
78
79 pub fn from_bytes(bytes: &FieldBytes<C>) -> CtOption<Self> {
81 Self::new(C::Uint::decode_field_bytes(bytes))
82 }
83
84 pub fn from_slice(slice: &[u8]) -> Result<Self> {
86 if slice.len() == C::FieldBytesSize::USIZE {
87 Option::from(Self::from_bytes(GenericArray::from_slice(slice))).ok_or(Error)
88 } else {
89 Err(Error)
90 }
91 }
92
93 pub fn as_uint(&self) -> &C::Uint {
95 &self.inner
96 }
97
98 pub fn as_limbs(&self) -> &[Limb] {
100 self.inner.as_ref()
101 }
102
103 pub fn is_zero(&self) -> Choice {
105 self.inner.is_zero()
106 }
107
108 pub fn is_even(&self) -> Choice {
110 self.inner.is_even()
111 }
112
113 pub fn is_odd(&self) -> Choice {
115 self.inner.is_odd()
116 }
117
118 pub fn to_bytes(&self) -> FieldBytes<C> {
120 self.inner.encode_field_bytes()
121 }
122
123 pub fn to_uint(&self) -> C::Uint {
125 self.inner
126 }
127}
128
129impl<C> FromUintUnchecked for ScalarPrimitive<C>
130where
131 C: Curve,
132{
133 type Uint = C::Uint;
134
135 fn from_uint_unchecked(uint: C::Uint) -> Self {
136 Self { inner: uint }
137 }
138}
139
#[cfg(feature = "arithmetic")]
impl<C> ScalarPrimitive<C>
where
    C: CurveArithmetic,
{
    /// Convert into the curve's arithmetic `Scalar` type.
    ///
    /// No range check happens here. The conversion relies on `self` already
    /// holding a value that is valid for `Scalar<C>`.
    pub(super) fn to_scalar(self) -> Scalar<C> {
        let uint = self.inner;
        Scalar::<C>::from_uint_unchecked(uint)
    }
}
150
151impl<C> AsRef<[Limb]> for ScalarPrimitive<C>
153where
154 C: Curve,
155{
156 fn as_ref(&self) -> &[Limb] {
157 self.as_limbs()
158 }
159}
160
161impl<C> ConditionallySelectable for ScalarPrimitive<C>
162where
163 C: Curve,
164{
165 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
166 Self {
167 inner: C::Uint::conditional_select(&a.inner, &b.inner, choice),
168 }
169 }
170}
171
172impl<C> ConstantTimeEq for ScalarPrimitive<C>
173where
174 C: Curve,
175{
176 fn ct_eq(&self, other: &Self) -> Choice {
177 self.inner.ct_eq(&other.inner)
178 }
179}
180
181impl<C> ConstantTimeLess for ScalarPrimitive<C>
182where
183 C: Curve,
184{
185 fn ct_lt(&self, other: &Self) -> Choice {
186 self.inner.ct_lt(&other.inner)
187 }
188}
189
190impl<C> ConstantTimeGreater for ScalarPrimitive<C>
191where
192 C: Curve,
193{
194 fn ct_gt(&self, other: &Self) -> Choice {
195 self.inner.ct_gt(&other.inner)
196 }
197}
198
199impl<C: Curve> DefaultIsZeroes for ScalarPrimitive<C> {}
200
201impl<C: Curve> Eq for ScalarPrimitive<C> {}
202
203impl<C> PartialEq for ScalarPrimitive<C>
204where
205 C: Curve,
206{
207 fn eq(&self, other: &Self) -> bool {
208 self.ct_eq(other).into()
209 }
210}
211
212impl<C> PartialOrd for ScalarPrimitive<C>
213where
214 C: Curve,
215{
216 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
217 Some(self.cmp(other))
218 }
219}
220
221impl<C> Ord for ScalarPrimitive<C>
222where
223 C: Curve,
224{
225 fn cmp(&self, other: &Self) -> Ordering {
226 self.inner.cmp(&other.inner)
227 }
228}
229
230impl<C> From<u64> for ScalarPrimitive<C>
231where
232 C: Curve,
233{
234 fn from(n: u64) -> Self {
235 Self {
236 inner: C::Uint::from(n),
237 }
238 }
239}
240
241impl<C> Add<ScalarPrimitive<C>> for ScalarPrimitive<C>
242where
243 C: Curve,
244{
245 type Output = Self;
246
247 fn add(self, other: Self) -> Self {
248 self.add(&other)
249 }
250}
251
252impl<C> Add<&ScalarPrimitive<C>> for ScalarPrimitive<C>
253where
254 C: Curve,
255{
256 type Output = Self;
257
258 fn add(self, other: &Self) -> Self {
259 Self {
260 inner: self.inner.add_mod(&other.inner, &Self::MODULUS),
261 }
262 }
263}
264
265impl<C> AddAssign<ScalarPrimitive<C>> for ScalarPrimitive<C>
266where
267 C: Curve,
268{
269 fn add_assign(&mut self, other: Self) {
270 *self = *self + other;
271 }
272}
273
274impl<C> AddAssign<&ScalarPrimitive<C>> for ScalarPrimitive<C>
275where
276 C: Curve,
277{
278 fn add_assign(&mut self, other: &Self) {
279 *self = *self + other;
280 }
281}
282
283impl<C> Sub<ScalarPrimitive<C>> for ScalarPrimitive<C>
284where
285 C: Curve,
286{
287 type Output = Self;
288
289 fn sub(self, other: Self) -> Self {
290 self.sub(&other)
291 }
292}
293
294impl<C> Sub<&ScalarPrimitive<C>> for ScalarPrimitive<C>
295where
296 C: Curve,
297{
298 type Output = Self;
299
300 fn sub(self, other: &Self) -> Self {
301 Self {
302 inner: self.inner.sub_mod(&other.inner, &Self::MODULUS),
303 }
304 }
305}
306
307impl<C> SubAssign<ScalarPrimitive<C>> for ScalarPrimitive<C>
308where
309 C: Curve,
310{
311 fn sub_assign(&mut self, other: Self) {
312 *self = *self - other;
313 }
314}
315
316impl<C> SubAssign<&ScalarPrimitive<C>> for ScalarPrimitive<C>
317where
318 C: Curve,
319{
320 fn sub_assign(&mut self, other: &Self) {
321 *self = *self - other;
322 }
323}
324
325impl<C> Neg for ScalarPrimitive<C>
326where
327 C: Curve,
328{
329 type Output = Self;
330
331 fn neg(self) -> Self {
332 Self {
333 inner: self.inner.neg_mod(&Self::MODULUS),
334 }
335 }
336}
337
338impl<C> Neg for &ScalarPrimitive<C>
339where
340 C: Curve,
341{
342 type Output = ScalarPrimitive<C>;
343
344 fn neg(self) -> ScalarPrimitive<C> {
345 -*self
346 }
347}
348
349impl<C> ShrAssign<usize> for ScalarPrimitive<C>
350where
351 C: Curve,
352{
353 fn shr_assign(&mut self, rhs: usize) {
354 self.inner >>= rhs;
355 }
356}
357
358impl<C> IsHigh for ScalarPrimitive<C>
359where
360 C: Curve,
361{
362 fn is_high(&self) -> Choice {
363 let n_2 = C::ORDER >> 1;
364 self.inner.ct_gt(&n_2)
365 }
366}
367
368impl<C> fmt::Display for ScalarPrimitive<C>
369where
370 C: Curve,
371{
372 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
373 write!(f, "{self:X}")
374 }
375}
376
377impl<C> fmt::LowerHex for ScalarPrimitive<C>
378where
379 C: Curve,
380{
381 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
382 write!(f, "{:x}", HexDisplay(&self.to_bytes()))
383 }
384}
385
386impl<C> fmt::UpperHex for ScalarPrimitive<C>
387where
388 C: Curve,
389{
390 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
391 write!(f, "{:X}", HexDisplay(&self.to_bytes()))
392 }
393}
394
395impl<C> str::FromStr for ScalarPrimitive<C>
396where
397 C: Curve,
398{
399 type Err = Error;
400
401 fn from_str(hex: &str) -> Result<Self> {
402 let mut bytes = FieldBytes::<C>::default();
403 base16ct::lower::decode(hex, &mut bytes)?;
404 Self::from_slice(&bytes)
405 }
406}
407
#[cfg(feature = "serde")]
impl<C> Serialize for ScalarPrimitive<C>
where
    C: Curve,
{
    /// Serialize the field-byte encoding.
    ///
    /// `serdect` chooses between uppercase hex and binary output.
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let bytes = self.to_bytes();
        serdect::array::serialize_hex_upper_or_bin(&bytes, serializer)
    }
}
420
#[cfg(feature = "serde")]
impl<'de, C> Deserialize<'de> for ScalarPrimitive<C>
where
    C: Curve,
{
    /// Deserialize from hex or binary (via `serdect`) and range-check the value.
    ///
    /// A value that `from_slice` rejects is reported as "scalar out of range".
    fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let mut buf = FieldBytes::<C>::default();
        serdect::array::deserialize_hex_or_bin(&mut buf, deserializer)?;

        match Self::from_slice(&buf) {
            Ok(scalar) => Ok(scalar),
            Err(_) => Err(de::Error::custom("scalar out of range")),
        }
    }
}