elliptic_curve/scalar/
nonzero.rs

1//! Non-zero scalar type.
2
3use crate::{
4    ops::{Invert, Reduce, ReduceNonZero},
5    scalar::IsHigh,
6    CurveArithmetic, Error, FieldBytes, PrimeCurve, Scalar, ScalarPrimitive, SecretKey,
7};
8use base16ct::HexDisplay;
9use core::{
10    fmt,
11    ops::{Deref, Mul, Neg},
12    str,
13};
14use crypto_bigint::{ArrayEncoding, Integer};
15use ff::{Field, PrimeField};
16use generic_array::{typenum::Unsigned, GenericArray};
17use rand_core::CryptoRngCore;
18use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
19use zeroize::Zeroize;
20
21#[cfg(feature = "serde")]
22use serdect::serde::{de, ser, Deserialize, Serialize};
23
24/// Non-zero scalar type.
25///
26/// This type ensures that its value is not zero, ala `core::num::NonZero*`.
27/// To do this, the generic `S` type must impl both `Default` and
28/// `ConstantTimeEq`, with the requirement that `S::default()` returns 0.
29///
30/// In the context of ECC, it's useful for ensuring that scalar multiplication
31/// cannot result in the point at infinity.
32#[derive(Clone)]
33pub struct NonZeroScalar<C>
34where
35    C: CurveArithmetic,
36{
37    scalar: Scalar<C>,
38}
39
40impl<C> NonZeroScalar<C>
41where
42    C: CurveArithmetic,
43{
44    /// Generate a random `NonZeroScalar`.
45    pub fn random(mut rng: &mut impl CryptoRngCore) -> Self {
46        // Use rejection sampling to eliminate zero values.
47        // While this method isn't constant-time, the attacker shouldn't learn
48        // anything about unrelated outputs so long as `rng` is a secure `CryptoRng`.
49        loop {
50            if let Some(result) = Self::new(Field::random(&mut rng)).into() {
51                break result;
52            }
53        }
54    }
55
56    /// Create a [`NonZeroScalar`] from a scalar.
57    pub fn new(scalar: Scalar<C>) -> CtOption<Self> {
58        CtOption::new(Self { scalar }, !scalar.is_zero())
59    }
60
61    /// Decode a [`NonZeroScalar`] from a big endian-serialized field element.
62    pub fn from_repr(repr: FieldBytes<C>) -> CtOption<Self> {
63        Scalar::<C>::from_repr(repr).and_then(Self::new)
64    }
65
66    /// Create a [`NonZeroScalar`] from a `C::Uint`.
67    pub fn from_uint(uint: C::Uint) -> CtOption<Self> {
68        ScalarPrimitive::new(uint).and_then(|scalar| Self::new(scalar.into()))
69    }
70}
71
72impl<C> AsRef<Scalar<C>> for NonZeroScalar<C>
73where
74    C: CurveArithmetic,
75{
76    fn as_ref(&self) -> &Scalar<C> {
77        &self.scalar
78    }
79}
80
81impl<C> ConditionallySelectable for NonZeroScalar<C>
82where
83    C: CurveArithmetic,
84{
85    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
86        Self {
87            scalar: Scalar::<C>::conditional_select(&a.scalar, &b.scalar, choice),
88        }
89    }
90}
91
92impl<C> ConstantTimeEq for NonZeroScalar<C>
93where
94    C: CurveArithmetic,
95{
96    fn ct_eq(&self, other: &Self) -> Choice {
97        self.scalar.ct_eq(&other.scalar)
98    }
99}
100
101impl<C> Copy for NonZeroScalar<C> where C: CurveArithmetic {}
102
103impl<C> Deref for NonZeroScalar<C>
104where
105    C: CurveArithmetic,
106{
107    type Target = Scalar<C>;
108
109    fn deref(&self) -> &Scalar<C> {
110        &self.scalar
111    }
112}
113
114impl<C> From<NonZeroScalar<C>> for FieldBytes<C>
115where
116    C: CurveArithmetic,
117{
118    fn from(scalar: NonZeroScalar<C>) -> FieldBytes<C> {
119        Self::from(&scalar)
120    }
121}
122
123impl<C> From<&NonZeroScalar<C>> for FieldBytes<C>
124where
125    C: CurveArithmetic,
126{
127    fn from(scalar: &NonZeroScalar<C>) -> FieldBytes<C> {
128        scalar.to_repr()
129    }
130}
131
132impl<C> From<NonZeroScalar<C>> for ScalarPrimitive<C>
133where
134    C: CurveArithmetic,
135{
136    #[inline]
137    fn from(scalar: NonZeroScalar<C>) -> ScalarPrimitive<C> {
138        Self::from(&scalar)
139    }
140}
141
142impl<C> From<&NonZeroScalar<C>> for ScalarPrimitive<C>
143where
144    C: CurveArithmetic,
145{
146    fn from(scalar: &NonZeroScalar<C>) -> ScalarPrimitive<C> {
147        ScalarPrimitive::from_bytes(&scalar.to_repr()).unwrap()
148    }
149}
150
151impl<C> From<SecretKey<C>> for NonZeroScalar<C>
152where
153    C: CurveArithmetic,
154{
155    fn from(sk: SecretKey<C>) -> NonZeroScalar<C> {
156        Self::from(&sk)
157    }
158}
159
160impl<C> From<&SecretKey<C>> for NonZeroScalar<C>
161where
162    C: CurveArithmetic,
163{
164    fn from(sk: &SecretKey<C>) -> NonZeroScalar<C> {
165        let scalar = sk.as_scalar_primitive().to_scalar();
166        debug_assert!(!bool::from(scalar.is_zero()));
167        Self { scalar }
168    }
169}
170
171impl<C> Invert for NonZeroScalar<C>
172where
173    C: CurveArithmetic,
174    Scalar<C>: Invert<Output = CtOption<Scalar<C>>>,
175{
176    type Output = Self;
177
178    fn invert(&self) -> Self {
179        Self {
180            // This will always succeed since `scalar` will never be 0
181            scalar: Invert::invert(&self.scalar).unwrap(),
182        }
183    }
184
185    fn invert_vartime(&self) -> Self::Output {
186        Self {
187            // This will always succeed since `scalar` will never be 0
188            scalar: Invert::invert_vartime(&self.scalar).unwrap(),
189        }
190    }
191}
192
193impl<C> IsHigh for NonZeroScalar<C>
194where
195    C: CurveArithmetic,
196{
197    fn is_high(&self) -> Choice {
198        self.scalar.is_high()
199    }
200}
201
202impl<C> Neg for NonZeroScalar<C>
203where
204    C: CurveArithmetic,
205{
206    type Output = NonZeroScalar<C>;
207
208    fn neg(self) -> NonZeroScalar<C> {
209        let scalar = -self.scalar;
210        debug_assert!(!bool::from(scalar.is_zero()));
211        NonZeroScalar { scalar }
212    }
213}
214
215impl<C> Mul<NonZeroScalar<C>> for NonZeroScalar<C>
216where
217    C: PrimeCurve + CurveArithmetic,
218{
219    type Output = Self;
220
221    #[inline]
222    fn mul(self, other: Self) -> Self {
223        Self::mul(self, &other)
224    }
225}
226
227impl<C> Mul<&NonZeroScalar<C>> for NonZeroScalar<C>
228where
229    C: PrimeCurve + CurveArithmetic,
230{
231    type Output = Self;
232
233    fn mul(self, other: &Self) -> Self {
234        // Multiplication is modulo a prime, so the product of two non-zero
235        // scalars is also non-zero.
236        let scalar = self.scalar * other.scalar;
237        debug_assert!(!bool::from(scalar.is_zero()));
238        NonZeroScalar { scalar }
239    }
240}
241
242/// Note: this is a non-zero reduction, as it's impl'd for [`NonZeroScalar`].
243impl<C, I> Reduce<I> for NonZeroScalar<C>
244where
245    C: CurveArithmetic,
246    I: Integer + ArrayEncoding,
247    Scalar<C>: Reduce<I> + ReduceNonZero<I>,
248{
249    type Bytes = <Scalar<C> as Reduce<I>>::Bytes;
250
251    fn reduce(n: I) -> Self {
252        let scalar = Scalar::<C>::reduce_nonzero(n);
253        debug_assert!(!bool::from(scalar.is_zero()));
254        Self { scalar }
255    }
256
257    fn reduce_bytes(bytes: &Self::Bytes) -> Self {
258        let scalar = Scalar::<C>::reduce_nonzero_bytes(bytes);
259        debug_assert!(!bool::from(scalar.is_zero()));
260        Self { scalar }
261    }
262}
263
264/// Note: forwards to the [`Reduce`] impl.
265impl<C, I> ReduceNonZero<I> for NonZeroScalar<C>
266where
267    Self: Reduce<I>,
268    C: CurveArithmetic,
269    I: Integer + ArrayEncoding,
270    Scalar<C>: Reduce<I, Bytes = Self::Bytes> + ReduceNonZero<I>,
271{
272    fn reduce_nonzero(n: I) -> Self {
273        Self::reduce(n)
274    }
275
276    fn reduce_nonzero_bytes(bytes: &Self::Bytes) -> Self {
277        Self::reduce_bytes(bytes)
278    }
279}
280
281impl<C> TryFrom<&[u8]> for NonZeroScalar<C>
282where
283    C: CurveArithmetic,
284{
285    type Error = Error;
286
287    fn try_from(bytes: &[u8]) -> Result<Self, Error> {
288        if bytes.len() == C::FieldBytesSize::USIZE {
289            Option::from(NonZeroScalar::from_repr(GenericArray::clone_from_slice(
290                bytes,
291            )))
292            .ok_or(Error)
293        } else {
294            Err(Error)
295        }
296    }
297}
298
299impl<C> Zeroize for NonZeroScalar<C>
300where
301    C: CurveArithmetic,
302{
303    fn zeroize(&mut self) {
304        // Use zeroize's volatile writes to ensure value is cleared.
305        self.scalar.zeroize();
306
307        // Write a 1 instead of a 0 to ensure this type's non-zero invariant
308        // is upheld.
309        self.scalar = Scalar::<C>::ONE;
310    }
311}
312
313impl<C> fmt::Display for NonZeroScalar<C>
314where
315    C: CurveArithmetic,
316{
317    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318        write!(f, "{self:X}")
319    }
320}
321
322impl<C> fmt::LowerHex for NonZeroScalar<C>
323where
324    C: CurveArithmetic,
325{
326    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327        write!(f, "{:x}", HexDisplay(&self.to_repr()))
328    }
329}
330
331impl<C> fmt::UpperHex for NonZeroScalar<C>
332where
333    C: CurveArithmetic,
334{
335    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
336        write!(f, "{:}", HexDisplay(&self.to_repr()))
337    }
338}
339
340impl<C> str::FromStr for NonZeroScalar<C>
341where
342    C: CurveArithmetic,
343{
344    type Err = Error;
345
346    fn from_str(hex: &str) -> Result<Self, Error> {
347        let mut bytes = FieldBytes::<C>::default();
348
349        if base16ct::mixed::decode(hex, &mut bytes)?.len() == bytes.len() {
350            Option::from(Self::from_repr(bytes)).ok_or(Error)
351        } else {
352            Err(Error)
353        }
354    }
355}
356
#[cfg(feature = "serde")]
impl<C> Serialize for NonZeroScalar<C>
where
    C: CurveArithmetic,
{
    /// Serialize through the scalar's [`ScalarPrimitive`] representation.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let primitive: ScalarPrimitive<C> = self.into();
        primitive.serialize(serializer)
    }
}
369
#[cfg(feature = "serde")]
impl<'de, C> Deserialize<'de> for NonZeroScalar<C>
where
    C: CurveArithmetic,
{
    /// Deserialize a [`ScalarPrimitive`] and reject it if it is zero.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let primitive = ScalarPrimitive::<C>::deserialize(deserializer)?;
        let scalar: Scalar<C> = primitive.into();
        let maybe_nonzero: Option<Self> = Self::new(scalar).into();
        maybe_nonzero.ok_or_else(|| de::Error::custom("expected non-zero scalar"))
    }
}
384
#[cfg(all(test, feature = "dev"))]
mod tests {
    use crate::dev::{NonZeroScalar, Scalar};
    use ff::{Field, PrimeField};
    use hex_literal::hex;
    use zeroize::Zeroize;

    /// Decoding a valid non-zero scalar and re-encoding it yields the input bytes.
    #[test]
    fn round_trip() {
        let encoded = hex!("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721");
        let decoded = NonZeroScalar::from_repr(encoded.into()).unwrap();
        let reencoded = decoded.to_repr();
        assert_eq!(&encoded, reencoded.as_slice());
    }

    /// Zeroizing must leave the value at one, never at zero.
    #[test]
    fn zeroize() {
        let mut nonzero = NonZeroScalar::new(Scalar::from(42u64)).unwrap();
        nonzero.zeroize();
        assert_eq!(*nonzero, Scalar::ONE);
    }
}