k256/arithmetic/
scalar.rs

1//! Scalar field arithmetic.
2
// Wide (512-bit) scalar backend used for multiplication and reduction.
// The implementation file is chosen by the target's pointer width:
// `scalar/wide64.rs` on 64-bit targets, `scalar/wide32.rs` on every other target.
#[cfg_attr(not(target_pointer_width = "64"), path = "scalar/wide32.rs")]
#[cfg_attr(target_pointer_width = "64", path = "scalar/wide64.rs")]
mod wide;
6
7pub(crate) use self::wide::WideScalar;
8
9use crate::{FieldBytes, Secp256k1, WideBytes, ORDER, ORDER_HEX};
10use core::{
11    iter::{Product, Sum},
12    ops::{Add, AddAssign, Mul, MulAssign, Neg, Shr, ShrAssign, Sub, SubAssign},
13};
14use elliptic_curve::{
15    bigint::{prelude::*, Limb, Word, U256, U512},
16    ff::{self, Field, PrimeField},
17    ops::{Invert, Reduce, ReduceNonZero},
18    rand_core::{CryptoRngCore, RngCore},
19    scalar::{FromUintUnchecked, IsHigh},
20    subtle::{
21        Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess,
22        CtOption,
23    },
24    zeroize::DefaultIsZeroes,
25    Curve, ScalarPrimitive,
26};
27
28#[cfg(feature = "bits")]
29use {crate::ScalarBits, elliptic_curve::group::ff::PrimeFieldBits};
30
31#[cfg(feature = "serde")]
32use serdect::serde::{de, ser, Deserialize, Serialize};
33
34#[cfg(test)]
35use num_bigint::{BigUint, ToBigUint};
36
/// Constant representing the modulus
/// n = FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141
///
/// Stored as the raw `Word` limbs of `ORDER`. It is not referenced in this part of
/// the file; presumably the `wide` backend reads it (NOTE(review): confirm before removing).
const MODULUS: [Word; U256::LIMBS] = ORDER.to_words();

/// Constant representing the modulus / 2
///
/// Because `n` is odd this equals `(n - 1) / 2`. Scalars strictly greater than this
/// value are "high" (see the `IsHigh` impl), and `invert_vartime` adds this value plus
/// one when halving an odd coefficient modulo `n`.
const FRAC_MODULUS_2: U256 = ORDER.shr_vartime(1);
43
44/// Scalars are elements in the finite field modulo n.
45///
46/// # Trait impls
47///
48/// Much of the important functionality of scalars is provided by traits from
49/// the [`ff`](https://docs.rs/ff/) crate, which is re-exported as
50/// `k256::elliptic_curve::ff`:
51///
52/// - [`Field`](https://docs.rs/ff/latest/ff/trait.Field.html) -
53///   represents elements of finite fields and provides:
54///   - [`Field::random`](https://docs.rs/ff/latest/ff/trait.Field.html#tymethod.random) -
55///     generate a random scalar
56///   - `double`, `square`, and `invert` operations
57///   - Bounds for [`Add`], [`Sub`], [`Mul`], and [`Neg`] (as well as `*Assign` equivalents)
58///   - Bounds for [`ConditionallySelectable`] from the `subtle` crate
59/// - [`PrimeField`](https://docs.rs/ff/latest/ff/trait.PrimeField.html) -
60///   represents elements of prime fields and provides:
61///   - `from_repr`/`to_repr` for converting field elements from/to big integers.
62///   - `multiplicative_generator` and `root_of_unity` constants.
63/// - [`PrimeFieldBits`](https://docs.rs/ff/latest/ff/trait.PrimeFieldBits.html) -
64///   operations over field elements represented as bits (requires `bits` feature)
65///
66/// Please see the documentation for the relevant traits for more information.
67///
68/// # `serde` support
69///
70/// When the `serde` feature of this crate is enabled, the `Serialize` and
71/// `Deserialize` traits are impl'd for this type.
72///
73/// The serialization is a fixed-width big endian encoding. When used with
74/// textual formats, the binary data is encoded as hexadecimal.
75#[derive(Clone, Copy, Debug, Default, PartialOrd, Ord)]
76pub struct Scalar(pub(crate) U256);
77
78impl Scalar {
79    /// Zero scalar.
80    pub const ZERO: Self = Self(U256::ZERO);
81
82    /// Multiplicative identity.
83    pub const ONE: Self = Self(U256::ONE);
84
85    /// Checks if the scalar is zero.
86    pub fn is_zero(&self) -> Choice {
87        self.0.is_zero()
88    }
89
90    /// Returns the SEC1 encoding of this scalar.
91    pub fn to_bytes(&self) -> FieldBytes {
92        self.0.to_be_byte_array()
93    }
94
95    /// Negates the scalar.
96    pub const fn negate(&self) -> Self {
97        Self(self.0.neg_mod(&ORDER))
98    }
99
100    /// Returns self + rhs mod n.
101    pub const fn add(&self, rhs: &Self) -> Self {
102        Self(self.0.add_mod(&rhs.0, &ORDER))
103    }
104
105    /// Returns self - rhs mod n.
106    pub const fn sub(&self, rhs: &Self) -> Self {
107        Self(self.0.sub_mod(&rhs.0, &ORDER))
108    }
109
110    /// Modulo multiplies two scalars.
111    pub fn mul(&self, rhs: &Scalar) -> Scalar {
112        WideScalar::mul_wide(self, rhs).reduce()
113    }
114
115    /// Modulo squares the scalar.
116    pub fn square(&self) -> Self {
117        self.mul(self)
118    }
119
120    /// Right shifts the scalar.
121    ///
122    /// Note: not constant-time with respect to the `shift` parameter.
123    pub fn shr_vartime(&self, shift: usize) -> Scalar {
124        Self(self.0.shr_vartime(shift))
125    }
126
127    /// Inverts the scalar.
128    pub fn invert(&self) -> CtOption<Self> {
129        // Using an addition chain from
130        // https://briansmith.org/ecc-inversion-addition-chains-01#secp256k1_scalar_inversion
131        let x_1 = *self;
132        let x_10 = self.pow2k(1);
133        let x_11 = x_10.mul(&x_1);
134        let x_101 = x_10.mul(&x_11);
135        let x_111 = x_10.mul(&x_101);
136        let x_1001 = x_10.mul(&x_111);
137        let x_1011 = x_10.mul(&x_1001);
138        let x_1101 = x_10.mul(&x_1011);
139
140        let x6 = x_1101.pow2k(2).mul(&x_1011);
141        let x8 = x6.pow2k(2).mul(&x_11);
142        let x14 = x8.pow2k(6).mul(&x6);
143        let x28 = x14.pow2k(14).mul(&x14);
144        let x56 = x28.pow2k(28).mul(&x28);
145
146        #[rustfmt::skip]
147            let res = x56
148            .pow2k(56).mul(&x56)
149            .pow2k(14).mul(&x14)
150            .pow2k(3).mul(&x_101)
151            .pow2k(4).mul(&x_111)
152            .pow2k(4).mul(&x_101)
153            .pow2k(5).mul(&x_1011)
154            .pow2k(4).mul(&x_1011)
155            .pow2k(4).mul(&x_111)
156            .pow2k(5).mul(&x_111)
157            .pow2k(6).mul(&x_1101)
158            .pow2k(4).mul(&x_101)
159            .pow2k(3).mul(&x_111)
160            .pow2k(5).mul(&x_1001)
161            .pow2k(6).mul(&x_101)
162            .pow2k(10).mul(&x_111)
163            .pow2k(4).mul(&x_111)
164            .pow2k(9).mul(&x8)
165            .pow2k(5).mul(&x_1001)
166            .pow2k(6).mul(&x_1011)
167            .pow2k(4).mul(&x_1101)
168            .pow2k(5).mul(&x_11)
169            .pow2k(6).mul(&x_1101)
170            .pow2k(10).mul(&x_1101)
171            .pow2k(4).mul(&x_1001)
172            .pow2k(6).mul(&x_1)
173            .pow2k(8).mul(&x6);
174
175        CtOption::new(res, !self.is_zero())
176    }
177
178    /// Returns the scalar modulus as a `BigUint` object.
179    #[cfg(test)]
180    pub fn modulus_as_biguint() -> BigUint {
181        Self::ONE.negate().to_biguint().unwrap() + 1.to_biguint().unwrap()
182    }
183
184    /// Returns a (nearly) uniformly-random scalar, generated in constant time.
185    pub fn generate_biased(rng: &mut impl CryptoRngCore) -> Self {
186        // We reduce a random 512-bit value into a 256-bit field, which results in a
187        // negligible bias from the uniform distribution, but the process is constant-time.
188        let mut buf = [0u8; 64];
189        rng.fill_bytes(&mut buf);
190        WideScalar::from_bytes(&buf).reduce()
191    }
192
193    /// Returns a uniformly-random scalar, generated using rejection sampling.
194    // TODO(tarcieri): make this a `CryptoRng` when `ff` allows it
195    pub fn generate_vartime(rng: &mut impl RngCore) -> Self {
196        let mut bytes = FieldBytes::default();
197
198        // TODO: pre-generate several scalars to bring the probability of non-constant-timeness down?
199        loop {
200            rng.fill_bytes(&mut bytes);
201            if let Some(scalar) = Scalar::from_repr(bytes).into() {
202                return scalar;
203            }
204        }
205    }
206
207    /// Attempts to parse the given byte array as a scalar.
208    /// Does not check the result for being in the correct range.
209    pub(crate) const fn from_bytes_unchecked(bytes: &[u8; 32]) -> Self {
210        Self(U256::from_be_slice(bytes))
211    }
212
213    /// Raises the scalar to the power `2^k`.
214    fn pow2k(&self, k: usize) -> Self {
215        let mut x = *self;
216        for _j in 0..k {
217            x = x.square();
218        }
219        x
220    }
221}
222
223impl Field for Scalar {
224    const ZERO: Self = Self::ZERO;
225    const ONE: Self = Self::ONE;
226
227    fn random(mut rng: impl RngCore) -> Self {
228        // Uses rejection sampling as the default random generation method,
229        // which produces a uniformly random distribution of scalars.
230        //
231        // This method is not constant time, but should be secure so long as
232        // rejected RNG outputs are unrelated to future ones (which is a
233        // necessary property of a `CryptoRng`).
234        //
235        // With an unbiased RNG, the probability of failing to complete after 4
236        // iterations is vanishingly small.
237        Self::generate_vartime(&mut rng)
238    }
239
240    #[must_use]
241    fn square(&self) -> Self {
242        Scalar::square(self)
243    }
244
245    #[must_use]
246    fn double(&self) -> Self {
247        self.add(self)
248    }
249
250    fn invert(&self) -> CtOption<Self> {
251        Scalar::invert(self)
252    }
253
254    /// Tonelli-Shank's algorithm for q mod 16 = 1
255    /// <https://eprint.iacr.org/2012/685.pdf> (page 12, algorithm 5)
256    #[allow(clippy::many_single_char_names)]
257    fn sqrt(&self) -> CtOption<Self> {
258        // Note: `pow_vartime` is constant-time with respect to `self`
259        let w = self.pow_vartime([
260            0x777fa4bd19a06c82,
261            0xfd755db9cd5e9140,
262            0xffffffffffffffff,
263            0x1ffffffffffffff,
264        ]);
265
266        let mut v = Self::S;
267        let mut x = *self * w;
268        let mut b = x * w;
269        let mut z = Self::ROOT_OF_UNITY;
270
271        for max_v in (1..=Self::S).rev() {
272            let mut k = 1;
273            let mut tmp = b.square();
274            let mut j_less_than_v = Choice::from(1);
275
276            for j in 2..max_v {
277                let tmp_is_one = tmp.ct_eq(&Self::ONE);
278                let squared = Self::conditional_select(&tmp, &z, tmp_is_one).square();
279                tmp = Self::conditional_select(&squared, &tmp, tmp_is_one);
280                let new_z = Self::conditional_select(&z, &squared, tmp_is_one);
281                j_less_than_v &= !j.ct_eq(&v);
282                k = u32::conditional_select(&j, &k, tmp_is_one);
283                z = Self::conditional_select(&z, &new_z, j_less_than_v);
284            }
285
286            let result = x * z;
287            x = Self::conditional_select(&result, &x, b.ct_eq(&Self::ONE));
288            z = z.square();
289            b *= z;
290            v = k;
291        }
292
293        CtOption::new(x, x.square().ct_eq(self))
294    }
295
296    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
297        ff::helpers::sqrt_ratio_generic(num, div)
298    }
299}
300
301impl AsRef<Scalar> for Scalar {
302    fn as_ref(&self) -> &Scalar {
303        self
304    }
305}
306
307impl PrimeField for Scalar {
308    type Repr = FieldBytes;
309
310    const MODULUS: &'static str = ORDER_HEX;
311    const NUM_BITS: u32 = 256;
312    const CAPACITY: u32 = 255;
313    const TWO_INV: Self = Self(U256::from_be_hex(
314        "7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a1",
315    ));
316    const MULTIPLICATIVE_GENERATOR: Self = Self(U256::from_u8(7));
317    const S: u32 = 6;
318    const ROOT_OF_UNITY: Self = Self(U256::from_be_hex(
319        "0c1dc060e7a91986df9879a3fbc483a898bdeab680756045992f4b5402b052f2",
320    ));
321    const ROOT_OF_UNITY_INV: Self = Self(U256::from_be_hex(
322        "fd3ae181f12d7096efc7b0c75b8cbb7277a275910aa413c3b6fb30a0884f0d1c",
323    ));
324    const DELTA: Self = Self(U256::from_be_hex(
325        "0000000000000000000cbc21fe4561c8d63b78e780e1341e199417c8c0bb7601",
326    ));
327
328    /// Attempts to parse the given byte array as an SEC1-encoded scalar.
329    ///
330    /// Returns None if the byte array does not contain a big-endian integer in the range
331    /// [0, p).
332    fn from_repr(bytes: FieldBytes) -> CtOption<Self> {
333        let inner = U256::from_be_byte_array(bytes);
334        CtOption::new(Self(inner), inner.ct_lt(&Secp256k1::ORDER))
335    }
336
337    fn to_repr(&self) -> FieldBytes {
338        self.to_bytes()
339    }
340
341    fn is_odd(&self) -> Choice {
342        self.0.is_odd()
343    }
344}
345
#[cfg(feature = "bits")]
impl PrimeFieldBits for Scalar {
    #[cfg(target_pointer_width = "32")]
    type ReprBits = [u32; 8];

    #[cfg(target_pointer_width = "64")]
    type ReprBits = [u64; 4];

    /// Bit view of this scalar's limbs, in the little-endian order `ff` expects.
    fn to_le_bits(&self) -> ScalarBits {
        ScalarBits::from(self)
    }

    /// Bit view of the field characteristic, i.e. the group order `n`.
    fn char_le_bits() -> ScalarBits {
        let order_limbs = ORDER.to_words();
        ScalarBits::from(order_limbs)
    }
}
362
363impl DefaultIsZeroes for Scalar {}
364
365impl From<u32> for Scalar {
366    fn from(k: u32) -> Self {
367        Self(k.into())
368    }
369}
370
371impl From<u64> for Scalar {
372    fn from(k: u64) -> Self {
373        Self(k.into())
374    }
375}
376
377impl From<u128> for Scalar {
378    fn from(k: u128) -> Self {
379        Self(k.into())
380    }
381}
382
383impl From<ScalarPrimitive<Secp256k1>> for Scalar {
384    fn from(scalar: ScalarPrimitive<Secp256k1>) -> Scalar {
385        Scalar(*scalar.as_uint())
386    }
387}
388
389impl From<&ScalarPrimitive<Secp256k1>> for Scalar {
390    fn from(scalar: &ScalarPrimitive<Secp256k1>) -> Scalar {
391        Scalar(*scalar.as_uint())
392    }
393}
394
395impl From<Scalar> for ScalarPrimitive<Secp256k1> {
396    fn from(scalar: Scalar) -> ScalarPrimitive<Secp256k1> {
397        ScalarPrimitive::from(&scalar)
398    }
399}
400
401impl From<&Scalar> for ScalarPrimitive<Secp256k1> {
402    fn from(scalar: &Scalar) -> ScalarPrimitive<Secp256k1> {
403        ScalarPrimitive::new(scalar.0).unwrap()
404    }
405}
406
407impl FromUintUnchecked for Scalar {
408    type Uint = U256;
409
410    fn from_uint_unchecked(uint: Self::Uint) -> Self {
411        Self(uint)
412    }
413}
414
415impl Invert for Scalar {
416    type Output = CtOption<Self>;
417
418    fn invert(&self) -> CtOption<Self> {
419        self.invert()
420    }
421
422    /// Fast variable-time inversion using Stein's algorithm.
423    ///
424    /// Returns none if the scalar is zero.
425    ///
426    /// <https://link.springer.com/article/10.1007/s13389-016-0135-4>
427    ///
428    /// ⚠️ WARNING!
429    ///
430    /// This method should not be used with (unblinded) secret scalars, as its
431    /// variable-time operation can potentially leak secrets through
432    /// sidechannels.
433    #[allow(non_snake_case)]
434    fn invert_vartime(&self) -> CtOption<Self> {
435        let mut u = *self;
436        let mut v = Self::from_uint_unchecked(Secp256k1::ORDER);
437        let mut A = Self::ONE;
438        let mut C = Self::ZERO;
439
440        while !bool::from(u.is_zero()) {
441            // u-loop
442            while bool::from(u.is_even()) {
443                u >>= 1;
444
445                let was_odd: bool = A.is_odd().into();
446                A >>= 1;
447
448                if was_odd {
449                    A += Self::from_uint_unchecked(FRAC_MODULUS_2);
450                    A += Self::ONE;
451                }
452            }
453
454            // v-loop
455            while bool::from(v.is_even()) {
456                v >>= 1;
457
458                let was_odd: bool = C.is_odd().into();
459                C >>= 1;
460
461                if was_odd {
462                    C += Self::from_uint_unchecked(FRAC_MODULUS_2);
463                    C += Self::ONE;
464                }
465            }
466
467            // sub-step
468            if u >= v {
469                u -= &v;
470                A -= &C;
471            } else {
472                v -= &u;
473                C -= &A;
474            }
475        }
476
477        CtOption::new(C, !self.is_zero())
478    }
479}
480
481impl IsHigh for Scalar {
482    fn is_high(&self) -> Choice {
483        self.0.ct_gt(&FRAC_MODULUS_2)
484    }
485}
486
487impl Shr<usize> for Scalar {
488    type Output = Self;
489
490    fn shr(self, rhs: usize) -> Self::Output {
491        self.shr_vartime(rhs)
492    }
493}
494
495impl Shr<usize> for &Scalar {
496    type Output = Scalar;
497
498    fn shr(self, rhs: usize) -> Self::Output {
499        self.shr_vartime(rhs)
500    }
501}
502
503impl ShrAssign<usize> for Scalar {
504    fn shr_assign(&mut self, rhs: usize) {
505        *self = *self >> rhs;
506    }
507}
508
509impl ConditionallySelectable for Scalar {
510    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
511        Self(U256::conditional_select(&a.0, &b.0, choice))
512    }
513}
514
515impl ConstantTimeEq for Scalar {
516    fn ct_eq(&self, other: &Self) -> Choice {
517        self.0.ct_eq(&(other.0))
518    }
519}
520
521impl PartialEq for Scalar {
522    fn eq(&self, other: &Self) -> bool {
523        self.ct_eq(other).into()
524    }
525}
526
527impl Eq for Scalar {}
528
529impl Neg for Scalar {
530    type Output = Scalar;
531
532    fn neg(self) -> Scalar {
533        self.negate()
534    }
535}
536
537impl Neg for &Scalar {
538    type Output = Scalar;
539
540    fn neg(self) -> Scalar {
541        self.negate()
542    }
543}
544
545impl Add<Scalar> for Scalar {
546    type Output = Scalar;
547
548    fn add(self, other: Scalar) -> Scalar {
549        Scalar::add(&self, &other)
550    }
551}
552
553impl Add<&Scalar> for &Scalar {
554    type Output = Scalar;
555
556    fn add(self, other: &Scalar) -> Scalar {
557        Scalar::add(self, other)
558    }
559}
560
561impl Add<Scalar> for &Scalar {
562    type Output = Scalar;
563
564    fn add(self, other: Scalar) -> Scalar {
565        Scalar::add(self, &other)
566    }
567}
568
569impl Add<&Scalar> for Scalar {
570    type Output = Scalar;
571
572    fn add(self, other: &Scalar) -> Scalar {
573        Scalar::add(&self, other)
574    }
575}
576
577impl AddAssign<Scalar> for Scalar {
578    #[inline]
579    fn add_assign(&mut self, rhs: Scalar) {
580        *self = Scalar::add(self, &rhs);
581    }
582}
583
584impl AddAssign<&Scalar> for Scalar {
585    fn add_assign(&mut self, rhs: &Scalar) {
586        *self = Scalar::add(self, rhs);
587    }
588}
589
590impl Sub<Scalar> for Scalar {
591    type Output = Scalar;
592
593    fn sub(self, other: Scalar) -> Scalar {
594        Scalar::sub(&self, &other)
595    }
596}
597
598impl Sub<&Scalar> for &Scalar {
599    type Output = Scalar;
600
601    fn sub(self, other: &Scalar) -> Scalar {
602        Scalar::sub(self, other)
603    }
604}
605
606impl Sub<&Scalar> for Scalar {
607    type Output = Scalar;
608
609    fn sub(self, other: &Scalar) -> Scalar {
610        Scalar::sub(&self, other)
611    }
612}
613
614impl SubAssign<Scalar> for Scalar {
615    fn sub_assign(&mut self, rhs: Scalar) {
616        *self = Scalar::sub(self, &rhs);
617    }
618}
619
620impl SubAssign<&Scalar> for Scalar {
621    fn sub_assign(&mut self, rhs: &Scalar) {
622        *self = Scalar::sub(self, rhs);
623    }
624}
625
626impl Mul<Scalar> for Scalar {
627    type Output = Scalar;
628
629    fn mul(self, other: Scalar) -> Scalar {
630        Scalar::mul(&self, &other)
631    }
632}
633
634impl Mul<&Scalar> for &Scalar {
635    type Output = Scalar;
636
637    fn mul(self, other: &Scalar) -> Scalar {
638        Scalar::mul(self, other)
639    }
640}
641
642impl Mul<&Scalar> for Scalar {
643    type Output = Scalar;
644
645    fn mul(self, other: &Scalar) -> Scalar {
646        Scalar::mul(&self, other)
647    }
648}
649
650impl MulAssign<Scalar> for Scalar {
651    fn mul_assign(&mut self, rhs: Scalar) {
652        *self = Scalar::mul(self, &rhs);
653    }
654}
655
656impl MulAssign<&Scalar> for Scalar {
657    fn mul_assign(&mut self, rhs: &Scalar) {
658        *self = Scalar::mul(self, rhs);
659    }
660}
661
662impl Reduce<U256> for Scalar {
663    type Bytes = FieldBytes;
664
665    fn reduce(w: U256) -> Self {
666        let (r, underflow) = w.sbb(&ORDER, Limb::ZERO);
667        let underflow = Choice::from((underflow.0 >> (Limb::BITS - 1)) as u8);
668        Self(U256::conditional_select(&w, &r, !underflow))
669    }
670
671    #[inline]
672    fn reduce_bytes(bytes: &FieldBytes) -> Self {
673        Self::reduce(U256::from_be_byte_array(*bytes))
674    }
675}
676
677impl Reduce<U512> for Scalar {
678    type Bytes = WideBytes;
679
680    fn reduce(w: U512) -> Self {
681        WideScalar(w).reduce()
682    }
683
684    fn reduce_bytes(bytes: &WideBytes) -> Self {
685        Self::reduce(U512::from_be_byte_array(*bytes))
686    }
687}
688
689impl ReduceNonZero<U256> for Scalar {
690    fn reduce_nonzero(w: U256) -> Self {
691        const ORDER_MINUS_ONE: U256 = ORDER.wrapping_sub(&U256::ONE);
692        let (r, underflow) = w.sbb(&ORDER_MINUS_ONE, Limb::ZERO);
693        let underflow = Choice::from((underflow.0 >> (Limb::BITS - 1)) as u8);
694        Self(U256::conditional_select(&w, &r, !underflow).wrapping_add(&U256::ONE))
695    }
696
697    #[inline]
698    fn reduce_nonzero_bytes(bytes: &FieldBytes) -> Self {
699        Self::reduce_nonzero(U256::from_be_byte_array(*bytes))
700    }
701}
702
703impl ReduceNonZero<U512> for Scalar {
704    fn reduce_nonzero(w: U512) -> Self {
705        WideScalar(w).reduce_nonzero()
706    }
707
708    #[inline]
709    fn reduce_nonzero_bytes(bytes: &WideBytes) -> Self {
710        Self::reduce_nonzero(U512::from_be_byte_array(*bytes))
711    }
712}
713
714impl Sum for Scalar {
715    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
716        iter.reduce(core::ops::Add::add).unwrap_or(Self::ZERO)
717    }
718}
719
720impl<'a> Sum<&'a Scalar> for Scalar {
721    fn sum<I: Iterator<Item = &'a Scalar>>(iter: I) -> Self {
722        iter.copied().sum()
723    }
724}
725
726impl Product for Scalar {
727    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
728        iter.reduce(core::ops::Mul::mul).unwrap_or(Self::ONE)
729    }
730}
731
732impl<'a> Product<&'a Scalar> for Scalar {
733    fn product<I: Iterator<Item = &'a Scalar>>(iter: I) -> Self {
734        iter.copied().product()
735    }
736}
737
#[cfg(feature = "bits")]
impl From<&Scalar> for ScalarBits {
    /// Reinterprets the scalar's limbs as a bit array.
    fn from(scalar: &Scalar) -> ScalarBits {
        let limbs = scalar.0.to_words();
        ScalarBits::from(limbs)
    }
}
744
745impl From<Scalar> for FieldBytes {
746    fn from(scalar: Scalar) -> Self {
747        scalar.to_bytes()
748    }
749}
750
751impl From<&Scalar> for FieldBytes {
752    fn from(scalar: &Scalar) -> Self {
753        scalar.to_bytes()
754    }
755}
756
757impl From<Scalar> for U256 {
758    fn from(scalar: Scalar) -> Self {
759        scalar.0
760    }
761}
762
763impl From<&Scalar> for U256 {
764    fn from(scalar: &Scalar) -> Self {
765        scalar.0
766    }
767}
768
#[cfg(feature = "serde")]
impl Serialize for Scalar {
    /// Serializes via `ScalarPrimitive`'s fixed-width big-endian encoding.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let primitive: ScalarPrimitive<Secp256k1> = self.into();
        primitive.serialize(serializer)
    }
}

#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Scalar {
    /// Deserializes via `ScalarPrimitive`, which performs the range check.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let primitive = ScalarPrimitive::<Secp256k1>::deserialize(deserializer)?;
        Ok(Scalar::from(primitive))
    }
}
788
789#[cfg(test)]
790mod tests {
791    use super::Scalar;
792    use crate::{
793        arithmetic::dev::{biguint_to_bytes, bytes_to_biguint},
794        FieldBytes, NonZeroScalar, WideBytes, ORDER,
795    };
796    use elliptic_curve::{
797        bigint::{ArrayEncoding, U256, U512},
798        ff::{Field, PrimeField},
799        generic_array::GenericArray,
800        ops::{Invert, Reduce},
801        scalar::IsHigh,
802    };
803    use num_bigint::{BigUint, ToBigUint};
804    use num_traits::Zero;
805    use proptest::prelude::*;
806    use rand_core::OsRng;
807
808    #[cfg(feature = "alloc")]
809    use alloc::vec::Vec;
810    use elliptic_curve::ops::BatchInvert;
811
812    impl From<&BigUint> for Scalar {
813        fn from(x: &BigUint) -> Self {
814            debug_assert!(x < &Scalar::modulus_as_biguint());
815            let bytes = biguint_to_bytes(x);
816            Self::from_repr(bytes.into()).unwrap()
817        }
818    }
819
820    impl From<BigUint> for Scalar {
821        fn from(x: BigUint) -> Self {
822            Self::from(&x)
823        }
824    }
825
826    impl ToBigUint for Scalar {
827        fn to_biguint(&self) -> Option<BigUint> {
828            Some(bytes_to_biguint(self.to_bytes().as_ref()))
829        }
830    }
831
832    /// t = (modulus - 1) >> S
833    const T: [u64; 4] = [
834        0xeeff497a3340d905,
835        0xfaeabb739abd2280,
836        0xffffffffffffffff,
837        0x03ffffffffffffff,
838    ];
839
840    #[test]
841    fn two_inv_constant() {
842        assert_eq!(Scalar::from(2u32) * Scalar::TWO_INV, Scalar::ONE);
843    }
844
845    #[test]
846    fn root_of_unity_constant() {
847        // ROOT_OF_UNITY^{2^s} mod m == 1
848        assert_eq!(
849            Scalar::ROOT_OF_UNITY.pow_vartime(&[1u64 << Scalar::S, 0, 0, 0]),
850            Scalar::ONE
851        );
852
853        // MULTIPLICATIVE_GENERATOR^{t} mod m == ROOT_OF_UNITY
854        assert_eq!(
855            Scalar::MULTIPLICATIVE_GENERATOR.pow_vartime(&T),
856            Scalar::ROOT_OF_UNITY
857        )
858    }
859
860    #[test]
861    fn root_of_unity_inv_constant() {
862        assert_eq!(
863            Scalar::ROOT_OF_UNITY * Scalar::ROOT_OF_UNITY_INV,
864            Scalar::ONE
865        );
866    }
867
868    #[test]
869    fn delta_constant() {
870        // DELTA^{t} mod m == 1
871        assert_eq!(Scalar::DELTA.pow_vartime(&T), Scalar::ONE);
872    }
873
874    #[test]
875    fn is_high() {
876        // 0 is not high
877        let high: bool = Scalar::ZERO.is_high().into();
878        assert!(!high);
879
880        // 1 is not high
881        let one = 1.to_biguint().unwrap();
882        let high: bool = Scalar::from(&one).is_high().into();
883        assert!(!high);
884
885        let m = Scalar::modulus_as_biguint();
886        let m_by_2 = &m >> 1;
887
888        // M / 2 is not high
889        let high: bool = Scalar::from(&m_by_2).is_high().into();
890        assert!(!high);
891
892        // M / 2 + 1 is high
893        let high: bool = Scalar::from(&m_by_2 + &one).is_high().into();
894        assert!(high);
895
896        // MODULUS - 1 is high
897        let high: bool = Scalar::from(&m - &one).is_high().into();
898        assert!(high);
899    }
900
901    /// Basic tests that sqrt works.
902    #[test]
903    fn sqrt() {
904        for &n in &[1u64, 4, 9, 16, 25, 36, 49, 64] {
905            let scalar = Scalar::from(n);
906            let sqrt = scalar.sqrt().unwrap();
907            assert_eq!(sqrt.square(), scalar);
908        }
909    }
910
911    /// Basic tests that `invert` works.
912    #[test]
913    fn invert() {
914        assert_eq!(Scalar::ONE, Scalar::ONE.invert().unwrap());
915
916        let three = Scalar::from(3u64);
917        let inv_three = three.invert().unwrap();
918        assert_eq!(three * inv_three, Scalar::ONE);
919
920        let minus_three = -three;
921        let inv_minus_three = minus_three.invert().unwrap();
922        assert_eq!(inv_minus_three, -inv_three);
923        assert_eq!(three * inv_minus_three, -Scalar::ONE);
924
925        assert!(bool::from(Scalar::ZERO.invert().is_none()));
926        assert_eq!(Scalar::from(2u64).invert().unwrap(), Scalar::TWO_INV);
927        assert_eq!(
928            Scalar::ROOT_OF_UNITY.invert_vartime().unwrap(),
929            Scalar::ROOT_OF_UNITY_INV
930        );
931    }
932
933    /// Basic tests that `invert_vartime` works.
934    #[test]
935    fn invert_vartime() {
936        assert_eq!(Scalar::ONE, Scalar::ONE.invert_vartime().unwrap());
937
938        let three = Scalar::from(3u64);
939        let inv_three = three.invert_vartime().unwrap();
940        assert_eq!(three * inv_three, Scalar::ONE);
941
942        let minus_three = -three;
943        let inv_minus_three = minus_three.invert_vartime().unwrap();
944        assert_eq!(inv_minus_three, -inv_three);
945        assert_eq!(three * inv_minus_three, -Scalar::ONE);
946
947        assert!(bool::from(Scalar::ZERO.invert_vartime().is_none()));
948        assert_eq!(
949            Scalar::from(2u64).invert_vartime().unwrap(),
950            Scalar::TWO_INV
951        );
952        assert_eq!(
953            Scalar::ROOT_OF_UNITY.invert_vartime().unwrap(),
954            Scalar::ROOT_OF_UNITY_INV
955        );
956    }
957
958    #[test]
959    fn batch_invert_array() {
960        let k: Scalar = Scalar::random(&mut OsRng);
961        let l: Scalar = Scalar::random(&mut OsRng);
962
963        let expected = [k.invert().unwrap(), l.invert().unwrap()];
964        assert_eq!(
965            <Scalar as BatchInvert<_>>::batch_invert(&[k, l]).unwrap(),
966            expected
967        );
968    }
969
970    #[test]
971    #[cfg(feature = "alloc")]
972    fn batch_invert() {
973        let k: Scalar = Scalar::random(&mut OsRng);
974        let l: Scalar = Scalar::random(&mut OsRng);
975
976        let expected = vec![k.invert().unwrap(), l.invert().unwrap()];
977        let scalars = vec![k, l];
978        let res: Vec<_> = <Scalar as BatchInvert<_>>::batch_invert(scalars.as_slice()).unwrap();
979        assert_eq!(res, expected);
980    }
981
982    #[test]
983    fn negate() {
984        let zero_neg = -Scalar::ZERO;
985        assert_eq!(zero_neg, Scalar::ZERO);
986
987        let m = Scalar::modulus_as_biguint();
988        let one = 1.to_biguint().unwrap();
989        let m_minus_one = &m - &one;
990        let m_by_2 = &m >> 1;
991
992        let one_neg = -Scalar::ONE;
993        assert_eq!(one_neg, Scalar::from(&m_minus_one));
994
995        let frac_modulus_2_neg = -Scalar::from(&m_by_2);
996        let frac_modulus_2_plus_one = Scalar::from(&m_by_2 + &one);
997        assert_eq!(frac_modulus_2_neg, frac_modulus_2_plus_one);
998
999        let modulus_minus_one_neg = -Scalar::from(&m - &one);
1000        assert_eq!(modulus_minus_one_neg, Scalar::ONE);
1001    }
1002
1003    #[test]
1004    fn add_result_within_256_bits() {
1005        // A regression for a bug where reduction was not applied
1006        // when the unreduced result of addition was in the range `[modulus, 2^256)`.
1007        let t = 1.to_biguint().unwrap() << 255;
1008        let one = 1.to_biguint().unwrap();
1009
1010        let a = Scalar::from(&t - &one);
1011        let b = Scalar::from(&t);
1012        let res = &a + &b;
1013
1014        let m = Scalar::modulus_as_biguint();
1015        let res_ref = Scalar::from((&t + &t - &one) % &m);
1016
1017        assert_eq!(res, res_ref);
1018    }
1019
1020    #[test]
1021    fn generate_biased() {
1022        use elliptic_curve::rand_core::OsRng;
1023        let a = Scalar::generate_biased(&mut OsRng);
1024        // just to make sure `a` is not optimized out by the compiler
1025        assert_eq!((a - &a).is_zero().unwrap_u8(), 1);
1026    }
1027
1028    #[test]
1029    fn generate_vartime() {
1030        use elliptic_curve::rand_core::OsRng;
1031        let a = Scalar::generate_vartime(&mut OsRng);
1032        // just to make sure `a` is not optimized out by the compiler
1033        assert_eq!((a - &a).is_zero().unwrap_u8(), 1);
1034    }
1035
1036    #[test]
1037    fn from_bytes_reduced() {
1038        let m = Scalar::modulus_as_biguint();
1039
1040        fn reduce<T: Reduce<U256, Bytes = FieldBytes>>(arr: &[u8]) -> T {
1041            T::reduce_bytes(GenericArray::from_slice(arr))
1042        }
1043
1044        // Regular reduction
1045
1046        let s = reduce::<Scalar>(&[0xffu8; 32]).to_biguint().unwrap();
1047        assert!(s < m);
1048
1049        let s = reduce::<Scalar>(&[0u8; 32]).to_biguint().unwrap();
1050        assert!(s.is_zero());
1051
1052        let s = reduce::<Scalar>(&ORDER.to_be_byte_array())
1053            .to_biguint()
1054            .unwrap();
1055        assert!(s.is_zero());
1056
1057        // Reduction to a non-zero scalar
1058
1059        let s = reduce::<NonZeroScalar>(&[0xffu8; 32]).to_biguint().unwrap();
1060        assert!(s < m);
1061
1062        let s = reduce::<NonZeroScalar>(&[0u8; 32]).to_biguint().unwrap();
1063        assert!(s < m);
1064        assert!(!s.is_zero());
1065
1066        let s = reduce::<NonZeroScalar>(&ORDER.to_be_byte_array())
1067            .to_biguint()
1068            .unwrap();
1069        assert!(s < m);
1070        assert!(!s.is_zero());
1071
1072        let s = reduce::<NonZeroScalar>(&(ORDER.wrapping_sub(&U256::ONE)).to_be_byte_array())
1073            .to_biguint()
1074            .unwrap();
1075        assert!(s < m);
1076        assert!(!s.is_zero());
1077    }
1078
1079    #[test]
1080    fn from_wide_bytes_reduced() {
1081        let m = Scalar::modulus_as_biguint();
1082
1083        fn reduce<T: Reduce<U512, Bytes = WideBytes>>(slice: &[u8]) -> T {
1084            let mut bytes = WideBytes::default();
1085            bytes[(64 - slice.len())..].copy_from_slice(slice);
1086            T::reduce_bytes(&bytes)
1087        }
1088
1089        // Regular reduction
1090
1091        let s = reduce::<Scalar>(&[0xffu8; 64]).to_biguint().unwrap();
1092        assert!(s < m);
1093
1094        let s = reduce::<Scalar>(&[0u8; 64]).to_biguint().unwrap();
1095        assert!(s.is_zero());
1096
1097        let s = reduce::<Scalar>(&ORDER.to_be_byte_array())
1098            .to_biguint()
1099            .unwrap();
1100        assert!(s.is_zero());
1101
1102        // Reduction to a non-zero scalar
1103
1104        let s = reduce::<NonZeroScalar>(&[0xffu8; 64]).to_biguint().unwrap();
1105        assert!(s < m);
1106
1107        let s = reduce::<NonZeroScalar>(&[0u8; 64]).to_biguint().unwrap();
1108        assert!(s < m);
1109        assert!(!s.is_zero());
1110
1111        let s = reduce::<NonZeroScalar>(&ORDER.to_be_byte_array())
1112            .to_biguint()
1113            .unwrap();
1114        assert!(s < m);
1115        assert!(!s.is_zero());
1116
1117        let s = reduce::<NonZeroScalar>(&(ORDER.wrapping_sub(&U256::ONE)).to_be_byte_array())
1118            .to_biguint()
1119            .unwrap();
1120        assert!(s < m);
1121        assert!(!s.is_zero());
1122    }
1123
1124    prop_compose! {
1125        fn scalar()(bytes in any::<[u8; 32]>()) -> Scalar {
1126            <Scalar as Reduce<U256>>::reduce_bytes(&bytes.into())
1127        }
1128    }
1129
1130    proptest! {
1131        #[test]
1132        fn fuzzy_roundtrip_to_bytes(a in scalar()) {
1133            let a_back = Scalar::from_repr(a.to_bytes()).unwrap();
1134            assert_eq!(a, a_back);
1135        }
1136
1137        #[test]
1138        fn fuzzy_roundtrip_to_bytes_unchecked(a in scalar()) {
1139            let bytes = a.to_bytes();
1140            let a_back = Scalar::from_bytes_unchecked(bytes.as_ref());
1141            assert_eq!(a, a_back);
1142        }
1143
1144        #[test]
1145        fn fuzzy_add(a in scalar(), b in scalar()) {
1146            let a_bi = a.to_biguint().unwrap();
1147            let b_bi = b.to_biguint().unwrap();
1148
1149            let res_bi = (&a_bi + &b_bi) % &Scalar::modulus_as_biguint();
1150            let res_ref = Scalar::from(&res_bi);
1151            let res_test = a.add(&b);
1152
1153            assert_eq!(res_ref, res_test);
1154        }
1155
1156        #[test]
1157        fn fuzzy_sub(a in scalar(), b in scalar()) {
1158            let a_bi = a.to_biguint().unwrap();
1159            let b_bi = b.to_biguint().unwrap();
1160
1161            let m = Scalar::modulus_as_biguint();
1162            let res_bi = (&m + &a_bi - &b_bi) % &m;
1163            let res_ref = Scalar::from(&res_bi);
1164            let res_test = a.sub(&b);
1165
1166            assert_eq!(res_ref, res_test);
1167        }
1168
1169        #[test]
1170        fn fuzzy_neg(a in scalar()) {
1171            let a_bi = a.to_biguint().unwrap();
1172
1173            let m = Scalar::modulus_as_biguint();
1174            let res_bi = (&m - &a_bi) % &m;
1175            let res_ref = Scalar::from(&res_bi);
1176            let res_test = -a;
1177
1178            assert_eq!(res_ref, res_test);
1179        }
1180
1181        #[test]
1182        fn fuzzy_mul(a in scalar(), b in scalar()) {
1183            let a_bi = a.to_biguint().unwrap();
1184            let b_bi = b.to_biguint().unwrap();
1185
1186            let res_bi = (&a_bi * &b_bi) % &Scalar::modulus_as_biguint();
1187            let res_ref = Scalar::from(&res_bi);
1188            let res_test = a.mul(&b);
1189
1190            assert_eq!(res_ref, res_test);
1191        }
1192
1193        #[test]
1194        fn fuzzy_rshift(a in scalar(), b in 0usize..512) {
1195            let a_bi = a.to_biguint().unwrap();
1196
1197            let res_bi = &a_bi >> b;
1198            let res_ref = Scalar::from(&res_bi);
1199            let res_test = a >> b;
1200
1201            assert_eq!(res_ref, res_test);
1202        }
1203
1204        #[test]
1205        fn fuzzy_invert(
1206            a in scalar()
1207        ) {
1208            let a = if bool::from(a.is_zero()) { Scalar::ONE } else { a };
1209            let a_bi = a.to_biguint().unwrap();
1210            let inv = a.invert().unwrap();
1211            let inv_bi = inv.to_biguint().unwrap();
1212            let m = Scalar::modulus_as_biguint();
1213            assert_eq!((&inv_bi * &a_bi) % &m, 1.to_biguint().unwrap());
1214        }
1215
1216        #[test]
1217        fn fuzzy_invert_vartime(w in scalar()) {
1218            let inv: Option<Scalar> = w.invert().into();
1219            let inv_vartime: Option<Scalar> = w.invert_vartime().into();
1220            assert_eq!(inv, inv_vartime);
1221        }
1222
1223        #[test]
1224        fn fuzzy_from_wide_bytes_reduced(bytes_hi in any::<[u8; 32]>(), bytes_lo in any::<[u8; 32]>()) {
1225            let m = Scalar::modulus_as_biguint();
1226            let mut bytes = [0u8; 64];
1227            bytes[0..32].clone_from_slice(&bytes_hi);
1228            bytes[32..64].clone_from_slice(&bytes_lo);
1229            let s = <Scalar as Reduce<U512>>::reduce(U512::from_be_slice(&bytes));
1230            let s_bu = s.to_biguint().unwrap();
1231            assert!(s_bu < m);
1232        }
1233    }
1234}