elliptic_curve/scalar/
primitive.rs

1//! Generic scalar type with primitive functionality.
2
3use crate::{
4    bigint::{prelude::*, Limb, NonZero},
5    scalar::FromUintUnchecked,
6    scalar::IsHigh,
7    Curve, Error, FieldBytes, FieldBytesEncoding, Result,
8};
9use base16ct::HexDisplay;
10use core::{
11    cmp::Ordering,
12    fmt,
13    ops::{Add, AddAssign, Neg, ShrAssign, Sub, SubAssign},
14    str,
15};
16use generic_array::{typenum::Unsigned, GenericArray};
17use rand_core::CryptoRngCore;
18use subtle::{
19    Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess,
20    CtOption,
21};
22use zeroize::DefaultIsZeroes;
23
24#[cfg(feature = "arithmetic")]
25use super::{CurveArithmetic, Scalar};
26
27#[cfg(feature = "serde")]
28use serdect::serde::{de, ser, Deserialize, Serialize};
29
30/// Generic scalar type with primitive functionality.
31///
32/// This type provides a baseline level of scalar arithmetic functionality
33/// which is always available for all curves, regardless of if they implement
34/// any arithmetic traits.
35///
36/// # `serde` support
37///
38/// When the optional `serde` feature of this create is enabled, [`Serialize`]
39/// and [`Deserialize`] impls are provided for this type.
40///
41/// The serialization is a fixed-width big endian encoding. When used with
42/// textual formats, the binary data is encoded as hexadecimal.
43// TODO(tarcieri): use `crypto-bigint`'s `Residue` type, expose more functionality?
44#[derive(Copy, Clone, Debug, Default)]
45pub struct ScalarPrimitive<C: Curve> {
46    /// Inner unsigned integer type.
47    inner: C::Uint,
48}
49
50impl<C> ScalarPrimitive<C>
51where
52    C: Curve,
53{
54    /// Zero scalar.
55    pub const ZERO: Self = Self {
56        inner: C::Uint::ZERO,
57    };
58
59    /// Multiplicative identity.
60    pub const ONE: Self = Self {
61        inner: C::Uint::ONE,
62    };
63
64    /// Scalar modulus.
65    pub const MODULUS: C::Uint = C::ORDER;
66
67    /// Generate a random [`ScalarPrimitive`].
68    pub fn random(rng: &mut impl CryptoRngCore) -> Self {
69        Self {
70            inner: C::Uint::random_mod(rng, &NonZero::new(Self::MODULUS).unwrap()),
71        }
72    }
73
74    /// Create a new scalar from [`Curve::Uint`].
75    pub fn new(uint: C::Uint) -> CtOption<Self> {
76        CtOption::new(Self { inner: uint }, uint.ct_lt(&Self::MODULUS))
77    }
78
79    /// Decode [`ScalarPrimitive`] from a serialized field element
80    pub fn from_bytes(bytes: &FieldBytes<C>) -> CtOption<Self> {
81        Self::new(C::Uint::decode_field_bytes(bytes))
82    }
83
84    /// Decode [`ScalarPrimitive`] from a big endian byte slice.
85    pub fn from_slice(slice: &[u8]) -> Result<Self> {
86        if slice.len() == C::FieldBytesSize::USIZE {
87            Option::from(Self::from_bytes(GenericArray::from_slice(slice))).ok_or(Error)
88        } else {
89            Err(Error)
90        }
91    }
92
93    /// Borrow the inner `C::Uint`.
94    pub fn as_uint(&self) -> &C::Uint {
95        &self.inner
96    }
97
98    /// Borrow the inner limbs as a slice.
99    pub fn as_limbs(&self) -> &[Limb] {
100        self.inner.as_ref()
101    }
102
103    /// Is this [`ScalarPrimitive`] value equal to zero?
104    pub fn is_zero(&self) -> Choice {
105        self.inner.is_zero()
106    }
107
108    /// Is this [`ScalarPrimitive`] value even?
109    pub fn is_even(&self) -> Choice {
110        self.inner.is_even()
111    }
112
113    /// Is this [`ScalarPrimitive`] value odd?
114    pub fn is_odd(&self) -> Choice {
115        self.inner.is_odd()
116    }
117
118    /// Encode [`ScalarPrimitive`] as a serialized field element.
119    pub fn to_bytes(&self) -> FieldBytes<C> {
120        self.inner.encode_field_bytes()
121    }
122
123    /// Convert to a `C::Uint`.
124    pub fn to_uint(&self) -> C::Uint {
125        self.inner
126    }
127}
128
129impl<C> FromUintUnchecked for ScalarPrimitive<C>
130where
131    C: Curve,
132{
133    type Uint = C::Uint;
134
135    fn from_uint_unchecked(uint: C::Uint) -> Self {
136        Self { inner: uint }
137    }
138}
139
#[cfg(feature = "arithmetic")]
impl<C> ScalarPrimitive<C>
where
    C: CurveArithmetic,
{
    /// Convert this [`ScalarPrimitive`] into the curve's arithmetic [`Scalar`] type.
    ///
    /// Goes through `from_uint_unchecked`, so no range check is performed here.
    pub(super) fn to_scalar(self) -> Scalar<C> {
        let uint = self.inner;
        Scalar::<C>::from_uint_unchecked(uint)
    }
}
150
151// TODO(tarcieri): better encapsulate this?
152impl<C> AsRef<[Limb]> for ScalarPrimitive<C>
153where
154    C: Curve,
155{
156    fn as_ref(&self) -> &[Limb] {
157        self.as_limbs()
158    }
159}
160
161impl<C> ConditionallySelectable for ScalarPrimitive<C>
162where
163    C: Curve,
164{
165    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
166        Self {
167            inner: C::Uint::conditional_select(&a.inner, &b.inner, choice),
168        }
169    }
170}
171
172impl<C> ConstantTimeEq for ScalarPrimitive<C>
173where
174    C: Curve,
175{
176    fn ct_eq(&self, other: &Self) -> Choice {
177        self.inner.ct_eq(&other.inner)
178    }
179}
180
181impl<C> ConstantTimeLess for ScalarPrimitive<C>
182where
183    C: Curve,
184{
185    fn ct_lt(&self, other: &Self) -> Choice {
186        self.inner.ct_lt(&other.inner)
187    }
188}
189
190impl<C> ConstantTimeGreater for ScalarPrimitive<C>
191where
192    C: Curve,
193{
194    fn ct_gt(&self, other: &Self) -> Choice {
195        self.inner.ct_gt(&other.inner)
196    }
197}
198
199impl<C: Curve> DefaultIsZeroes for ScalarPrimitive<C> {}
200
201impl<C: Curve> Eq for ScalarPrimitive<C> {}
202
203impl<C> PartialEq for ScalarPrimitive<C>
204where
205    C: Curve,
206{
207    fn eq(&self, other: &Self) -> bool {
208        self.ct_eq(other).into()
209    }
210}
211
212impl<C> PartialOrd for ScalarPrimitive<C>
213where
214    C: Curve,
215{
216    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
217        Some(self.cmp(other))
218    }
219}
220
221impl<C> Ord for ScalarPrimitive<C>
222where
223    C: Curve,
224{
225    fn cmp(&self, other: &Self) -> Ordering {
226        self.inner.cmp(&other.inner)
227    }
228}
229
230impl<C> From<u64> for ScalarPrimitive<C>
231where
232    C: Curve,
233{
234    fn from(n: u64) -> Self {
235        Self {
236            inner: C::Uint::from(n),
237        }
238    }
239}
240
241impl<C> Add<ScalarPrimitive<C>> for ScalarPrimitive<C>
242where
243    C: Curve,
244{
245    type Output = Self;
246
247    fn add(self, other: Self) -> Self {
248        self.add(&other)
249    }
250}
251
252impl<C> Add<&ScalarPrimitive<C>> for ScalarPrimitive<C>
253where
254    C: Curve,
255{
256    type Output = Self;
257
258    fn add(self, other: &Self) -> Self {
259        Self {
260            inner: self.inner.add_mod(&other.inner, &Self::MODULUS),
261        }
262    }
263}
264
265impl<C> AddAssign<ScalarPrimitive<C>> for ScalarPrimitive<C>
266where
267    C: Curve,
268{
269    fn add_assign(&mut self, other: Self) {
270        *self = *self + other;
271    }
272}
273
274impl<C> AddAssign<&ScalarPrimitive<C>> for ScalarPrimitive<C>
275where
276    C: Curve,
277{
278    fn add_assign(&mut self, other: &Self) {
279        *self = *self + other;
280    }
281}
282
283impl<C> Sub<ScalarPrimitive<C>> for ScalarPrimitive<C>
284where
285    C: Curve,
286{
287    type Output = Self;
288
289    fn sub(self, other: Self) -> Self {
290        self.sub(&other)
291    }
292}
293
294impl<C> Sub<&ScalarPrimitive<C>> for ScalarPrimitive<C>
295where
296    C: Curve,
297{
298    type Output = Self;
299
300    fn sub(self, other: &Self) -> Self {
301        Self {
302            inner: self.inner.sub_mod(&other.inner, &Self::MODULUS),
303        }
304    }
305}
306
307impl<C> SubAssign<ScalarPrimitive<C>> for ScalarPrimitive<C>
308where
309    C: Curve,
310{
311    fn sub_assign(&mut self, other: Self) {
312        *self = *self - other;
313    }
314}
315
316impl<C> SubAssign<&ScalarPrimitive<C>> for ScalarPrimitive<C>
317where
318    C: Curve,
319{
320    fn sub_assign(&mut self, other: &Self) {
321        *self = *self - other;
322    }
323}
324
325impl<C> Neg for ScalarPrimitive<C>
326where
327    C: Curve,
328{
329    type Output = Self;
330
331    fn neg(self) -> Self {
332        Self {
333            inner: self.inner.neg_mod(&Self::MODULUS),
334        }
335    }
336}
337
338impl<C> Neg for &ScalarPrimitive<C>
339where
340    C: Curve,
341{
342    type Output = ScalarPrimitive<C>;
343
344    fn neg(self) -> ScalarPrimitive<C> {
345        -*self
346    }
347}
348
349impl<C> ShrAssign<usize> for ScalarPrimitive<C>
350where
351    C: Curve,
352{
353    fn shr_assign(&mut self, rhs: usize) {
354        self.inner >>= rhs;
355    }
356}
357
358impl<C> IsHigh for ScalarPrimitive<C>
359where
360    C: Curve,
361{
362    fn is_high(&self) -> Choice {
363        let n_2 = C::ORDER >> 1;
364        self.inner.ct_gt(&n_2)
365    }
366}
367
368impl<C> fmt::Display for ScalarPrimitive<C>
369where
370    C: Curve,
371{
372    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
373        write!(f, "{self:X}")
374    }
375}
376
377impl<C> fmt::LowerHex for ScalarPrimitive<C>
378where
379    C: Curve,
380{
381    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
382        write!(f, "{:x}", HexDisplay(&self.to_bytes()))
383    }
384}
385
386impl<C> fmt::UpperHex for ScalarPrimitive<C>
387where
388    C: Curve,
389{
390    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
391        write!(f, "{:X}", HexDisplay(&self.to_bytes()))
392    }
393}
394
395impl<C> str::FromStr for ScalarPrimitive<C>
396where
397    C: Curve,
398{
399    type Err = Error;
400
401    fn from_str(hex: &str) -> Result<Self> {
402        let mut bytes = FieldBytes::<C>::default();
403        base16ct::lower::decode(hex, &mut bytes)?;
404        Self::from_slice(&bytes)
405    }
406}
407
#[cfg(feature = "serde")]
impl<C> Serialize for ScalarPrimitive<C>
where
    C: Curve,
{
    /// Serialize the field-element bytes: upper-case hex for human-readable
    /// formats, raw bytes otherwise (per `serialize_hex_upper_or_bin`).
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let bytes = self.to_bytes();
        serdect::array::serialize_hex_upper_or_bin(&bytes, serializer)
    }
}
420
#[cfg(feature = "serde")]
impl<'de, C> Deserialize<'de> for ScalarPrimitive<C>
where
    C: Curve,
{
    /// Read a fixed-size field-element buffer (hex or binary) and validate it
    /// with `from_slice`, reporting "scalar out of range" on failure.
    fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let mut bytes = FieldBytes::<C>::default();
        serdect::array::deserialize_hex_or_bin(&mut bytes, deserializer)?;

        match Self::from_slice(&bytes) {
            Ok(scalar) => Ok(scalar),
            Err(_) => Err(de::Error::custom("scalar out of range")),
        }
    }
}