1use crate::{
4 ops::{Invert, Reduce, ReduceNonZero},
5 scalar::IsHigh,
6 CurveArithmetic, Error, FieldBytes, PrimeCurve, Scalar, ScalarPrimitive, SecretKey,
7};
8use base16ct::HexDisplay;
9use core::{
10 fmt,
11 ops::{Deref, Mul, Neg},
12 str,
13};
14use crypto_bigint::{ArrayEncoding, Integer};
15use ff::{Field, PrimeField};
16use generic_array::{typenum::Unsigned, GenericArray};
17use rand_core::CryptoRngCore;
18use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
19use zeroize::Zeroize;
20
21#[cfg(feature = "serde")]
22use serdect::serde::{de, ser, Deserialize, Serialize};
23
24#[derive(Clone)]
33pub struct NonZeroScalar<C>
34where
35 C: CurveArithmetic,
36{
37 scalar: Scalar<C>,
38}
39
40impl<C> NonZeroScalar<C>
41where
42 C: CurveArithmetic,
43{
44 pub fn random(mut rng: &mut impl CryptoRngCore) -> Self {
46 loop {
50 if let Some(result) = Self::new(Field::random(&mut rng)).into() {
51 break result;
52 }
53 }
54 }
55
56 pub fn new(scalar: Scalar<C>) -> CtOption<Self> {
58 CtOption::new(Self { scalar }, !scalar.is_zero())
59 }
60
61 pub fn from_repr(repr: FieldBytes<C>) -> CtOption<Self> {
63 Scalar::<C>::from_repr(repr).and_then(Self::new)
64 }
65
66 pub fn from_uint(uint: C::Uint) -> CtOption<Self> {
68 ScalarPrimitive::new(uint).and_then(|scalar| Self::new(scalar.into()))
69 }
70}
71
72impl<C> AsRef<Scalar<C>> for NonZeroScalar<C>
73where
74 C: CurveArithmetic,
75{
76 fn as_ref(&self) -> &Scalar<C> {
77 &self.scalar
78 }
79}
80
81impl<C> ConditionallySelectable for NonZeroScalar<C>
82where
83 C: CurveArithmetic,
84{
85 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
86 Self {
87 scalar: Scalar::<C>::conditional_select(&a.scalar, &b.scalar, choice),
88 }
89 }
90}
91
92impl<C> ConstantTimeEq for NonZeroScalar<C>
93where
94 C: CurveArithmetic,
95{
96 fn ct_eq(&self, other: &Self) -> Choice {
97 self.scalar.ct_eq(&other.scalar)
98 }
99}
100
101impl<C> Copy for NonZeroScalar<C> where C: CurveArithmetic {}
102
103impl<C> Deref for NonZeroScalar<C>
104where
105 C: CurveArithmetic,
106{
107 type Target = Scalar<C>;
108
109 fn deref(&self) -> &Scalar<C> {
110 &self.scalar
111 }
112}
113
114impl<C> From<NonZeroScalar<C>> for FieldBytes<C>
115where
116 C: CurveArithmetic,
117{
118 fn from(scalar: NonZeroScalar<C>) -> FieldBytes<C> {
119 Self::from(&scalar)
120 }
121}
122
123impl<C> From<&NonZeroScalar<C>> for FieldBytes<C>
124where
125 C: CurveArithmetic,
126{
127 fn from(scalar: &NonZeroScalar<C>) -> FieldBytes<C> {
128 scalar.to_repr()
129 }
130}
131
132impl<C> From<NonZeroScalar<C>> for ScalarPrimitive<C>
133where
134 C: CurveArithmetic,
135{
136 #[inline]
137 fn from(scalar: NonZeroScalar<C>) -> ScalarPrimitive<C> {
138 Self::from(&scalar)
139 }
140}
141
142impl<C> From<&NonZeroScalar<C>> for ScalarPrimitive<C>
143where
144 C: CurveArithmetic,
145{
146 fn from(scalar: &NonZeroScalar<C>) -> ScalarPrimitive<C> {
147 ScalarPrimitive::from_bytes(&scalar.to_repr()).unwrap()
148 }
149}
150
151impl<C> From<SecretKey<C>> for NonZeroScalar<C>
152where
153 C: CurveArithmetic,
154{
155 fn from(sk: SecretKey<C>) -> NonZeroScalar<C> {
156 Self::from(&sk)
157 }
158}
159
160impl<C> From<&SecretKey<C>> for NonZeroScalar<C>
161where
162 C: CurveArithmetic,
163{
164 fn from(sk: &SecretKey<C>) -> NonZeroScalar<C> {
165 let scalar = sk.as_scalar_primitive().to_scalar();
166 debug_assert!(!bool::from(scalar.is_zero()));
167 Self { scalar }
168 }
169}
170
171impl<C> Invert for NonZeroScalar<C>
172where
173 C: CurveArithmetic,
174 Scalar<C>: Invert<Output = CtOption<Scalar<C>>>,
175{
176 type Output = Self;
177
178 fn invert(&self) -> Self {
179 Self {
180 scalar: Invert::invert(&self.scalar).unwrap(),
182 }
183 }
184
185 fn invert_vartime(&self) -> Self::Output {
186 Self {
187 scalar: Invert::invert_vartime(&self.scalar).unwrap(),
189 }
190 }
191}
192
193impl<C> IsHigh for NonZeroScalar<C>
194where
195 C: CurveArithmetic,
196{
197 fn is_high(&self) -> Choice {
198 self.scalar.is_high()
199 }
200}
201
202impl<C> Neg for NonZeroScalar<C>
203where
204 C: CurveArithmetic,
205{
206 type Output = NonZeroScalar<C>;
207
208 fn neg(self) -> NonZeroScalar<C> {
209 let scalar = -self.scalar;
210 debug_assert!(!bool::from(scalar.is_zero()));
211 NonZeroScalar { scalar }
212 }
213}
214
215impl<C> Mul<NonZeroScalar<C>> for NonZeroScalar<C>
216where
217 C: PrimeCurve + CurveArithmetic,
218{
219 type Output = Self;
220
221 #[inline]
222 fn mul(self, other: Self) -> Self {
223 Self::mul(self, &other)
224 }
225}
226
227impl<C> Mul<&NonZeroScalar<C>> for NonZeroScalar<C>
228where
229 C: PrimeCurve + CurveArithmetic,
230{
231 type Output = Self;
232
233 fn mul(self, other: &Self) -> Self {
234 let scalar = self.scalar * other.scalar;
237 debug_assert!(!bool::from(scalar.is_zero()));
238 NonZeroScalar { scalar }
239 }
240}
241
242impl<C, I> Reduce<I> for NonZeroScalar<C>
244where
245 C: CurveArithmetic,
246 I: Integer + ArrayEncoding,
247 Scalar<C>: Reduce<I> + ReduceNonZero<I>,
248{
249 type Bytes = <Scalar<C> as Reduce<I>>::Bytes;
250
251 fn reduce(n: I) -> Self {
252 let scalar = Scalar::<C>::reduce_nonzero(n);
253 debug_assert!(!bool::from(scalar.is_zero()));
254 Self { scalar }
255 }
256
257 fn reduce_bytes(bytes: &Self::Bytes) -> Self {
258 let scalar = Scalar::<C>::reduce_nonzero_bytes(bytes);
259 debug_assert!(!bool::from(scalar.is_zero()));
260 Self { scalar }
261 }
262}
263
264impl<C, I> ReduceNonZero<I> for NonZeroScalar<C>
266where
267 Self: Reduce<I>,
268 C: CurveArithmetic,
269 I: Integer + ArrayEncoding,
270 Scalar<C>: Reduce<I, Bytes = Self::Bytes> + ReduceNonZero<I>,
271{
272 fn reduce_nonzero(n: I) -> Self {
273 Self::reduce(n)
274 }
275
276 fn reduce_nonzero_bytes(bytes: &Self::Bytes) -> Self {
277 Self::reduce_bytes(bytes)
278 }
279}
280
281impl<C> TryFrom<&[u8]> for NonZeroScalar<C>
282where
283 C: CurveArithmetic,
284{
285 type Error = Error;
286
287 fn try_from(bytes: &[u8]) -> Result<Self, Error> {
288 if bytes.len() == C::FieldBytesSize::USIZE {
289 Option::from(NonZeroScalar::from_repr(GenericArray::clone_from_slice(
290 bytes,
291 )))
292 .ok_or(Error)
293 } else {
294 Err(Error)
295 }
296 }
297}
298
impl<C> Zeroize for NonZeroScalar<C>
where
    C: CurveArithmetic,
{
    /// Wipe the wrapped value, then reset it to one.
    ///
    /// The statement order matters. The scalar is first cleared through its
    /// own `Zeroize` impl, and only afterwards overwritten with
    /// `Scalar::<C>::ONE`, so the non-zero invariant still holds once this returns.
    fn zeroize(&mut self) {
        // Clear the current (possibly secret) value using the scalar's own
        // `Zeroize` implementation.
        self.scalar.zeroize();

        // Leave one behind instead of zero: a zero value would break the
        // non-zero invariant this type promises.
        self.scalar = Scalar::<C>::ONE;
    }
}
312
313impl<C> fmt::Display for NonZeroScalar<C>
314where
315 C: CurveArithmetic,
316{
317 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318 write!(f, "{self:X}")
319 }
320}
321
322impl<C> fmt::LowerHex for NonZeroScalar<C>
323where
324 C: CurveArithmetic,
325{
326 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327 write!(f, "{:x}", HexDisplay(&self.to_repr()))
328 }
329}
330
331impl<C> fmt::UpperHex for NonZeroScalar<C>
332where
333 C: CurveArithmetic,
334{
335 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
336 write!(f, "{:}", HexDisplay(&self.to_repr()))
337 }
338}
339
340impl<C> str::FromStr for NonZeroScalar<C>
341where
342 C: CurveArithmetic,
343{
344 type Err = Error;
345
346 fn from_str(hex: &str) -> Result<Self, Error> {
347 let mut bytes = FieldBytes::<C>::default();
348
349 if base16ct::mixed::decode(hex, &mut bytes)?.len() == bytes.len() {
350 Option::from(Self::from_repr(bytes)).ok_or(Error)
351 } else {
352 Err(Error)
353 }
354 }
355}
356
#[cfg(feature = "serde")]
impl<C> Serialize for NonZeroScalar<C>
where
    C: CurveArithmetic,
{
    /// Serialize through the [`ScalarPrimitive`] representation.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let primitive = ScalarPrimitive::<C>::from(self);
        primitive.serialize(serializer)
    }
}
369
#[cfg(feature = "serde")]
impl<'de, C> Deserialize<'de> for NonZeroScalar<C>
where
    C: CurveArithmetic,
{
    /// Deserialize a [`ScalarPrimitive`] and reject it if it is zero.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let primitive = ScalarPrimitive::<C>::deserialize(deserializer)?;
        let scalar: Scalar<C> = primitive.into();

        Option::from(Self::new(scalar))
            .ok_or_else(|| de::Error::custom("expected non-zero scalar"))
    }
}
384
#[cfg(all(test, feature = "dev"))]
mod tests {
    use crate::dev::{NonZeroScalar, Scalar};
    use ff::{Field, PrimeField};
    use hex_literal::hex;
    use zeroize::Zeroize;

    /// `to_repr` must reproduce exactly the bytes accepted by `from_repr`.
    #[test]
    fn round_trip() {
        let encoded = hex!("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721");
        let decoded = NonZeroScalar::from_repr(encoded.into()).unwrap();
        let reencoded = decoded.to_repr();
        assert_eq!(&encoded, reencoded.as_slice());
    }

    /// Zeroizing must leave the value at one (not zero), preserving the
    /// non-zero invariant.
    #[test]
    fn zeroize() {
        let mut value = NonZeroScalar::new(Scalar::from(42u64)).unwrap();
        value.zeroize();
        assert_eq!(*value, Scalar::ONE);
    }
}