bls12_381/
scalar.rs

//! This module provides an implementation of the BLS12-381 scalar field $\mathbb{F}_q$
//! where `q = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001`.

4use core::fmt;
5use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
6use rand_core::RngCore;
7
8use ff::{Field, PrimeField};
9use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
10
11#[cfg(feature = "bits")]
12use ff::{FieldBits, PrimeFieldBits};
13
14use crate::util::{adc, mac, sbb};
15
/// Represents an element of the scalar field $\mathbb{F}_q$ of the BLS12-381 elliptic
/// curve construction.
// The internal representation of this type is four 64-bit unsigned
// integers in little-endian order. `Scalar` values are always in
// Montgomery form; i.e., Scalar(a) = aR mod q, with R = 2^256.
//
// The limbs are kept fully reduced (strictly less than q), so every field
// element has exactly one internal representation. That is why equality can
// compare limbs directly. `PartialEq` is written by hand on top of
// `ConstantTimeEq` rather than derived, so comparisons do not exit early.
#[derive(Clone, Copy, Eq)]
pub struct Scalar(pub(crate) [u64; 4]);
23
24impl fmt::Debug for Scalar {
25    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
26        let tmp = self.to_bytes();
27        write!(f, "0x")?;
28        for &b in tmp.iter().rev() {
29            write!(f, "{:02x}", b)?;
30        }
31        Ok(())
32    }
33}
34
35impl fmt::Display for Scalar {
36    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37        write!(f, "{:?}", self)
38    }
39}
40
41impl From<u64> for Scalar {
42    fn from(val: u64) -> Scalar {
43        Scalar([val, 0, 0, 0]) * R2
44    }
45}
46
47impl ConstantTimeEq for Scalar {
48    fn ct_eq(&self, other: &Self) -> Choice {
49        self.0[0].ct_eq(&other.0[0])
50            & self.0[1].ct_eq(&other.0[1])
51            & self.0[2].ct_eq(&other.0[2])
52            & self.0[3].ct_eq(&other.0[3])
53    }
54}
55
56impl PartialEq for Scalar {
57    #[inline]
58    fn eq(&self, other: &Self) -> bool {
59        bool::from(self.ct_eq(other))
60    }
61}
62
63impl ConditionallySelectable for Scalar {
64    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
65        Scalar([
66            u64::conditional_select(&a.0[0], &b.0[0], choice),
67            u64::conditional_select(&a.0[1], &b.0[1], choice),
68            u64::conditional_select(&a.0[2], &b.0[2], choice),
69            u64::conditional_select(&a.0[3], &b.0[3], choice),
70        ])
71    }
72}
73
/// Constant representing the modulus
/// q = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
///
/// Stored as plain little-endian 64-bit limbs, NOT in Montgomery form: the
/// limbs below are exactly the hex digits of q.
const MODULUS: Scalar = Scalar([
    0xffff_ffff_0000_0001,
    0x53bd_a402_fffe_5bfe,
    0x3339_d808_09a1_d805,
    0x73ed_a753_299d_7d48,
]);

/// The modulus as u32 limbs.
///
/// The same value as `MODULUS`, split into eight little-endian 32-bit limbs.
/// It is only compiled with the `bits` feature on targets whose pointer width
/// is not 64.
#[cfg(all(feature = "bits", not(target_pointer_width = "64")))]
const MODULUS_LIMBS_32: [u32; 8] = [
    0x0000_0001,
    0xffff_ffff,
    0xfffe_5bfe,
    0x53bd_a402,
    0x09a1_d805,
    0x3339_d808,
    0x299d_7d48,
    0x73ed_a753,
];

// The number of bits needed to represent the modulus.
// (The top limb 0x73ed... has its highest set bit at position 62, so 3 * 64 + 63 = 255.)
const MODULUS_BITS: u32 = 255;

// GENERATOR = 7: a generator of the multiplicative group of order q - 1
// that is also a quadratic non-residue. Like every `Scalar`, it is stored in
// Montgomery form, so these limbs encode 7 * R mod q rather than the integer 7.
const GENERATOR: Scalar = Scalar([
    0x0000_000e_ffff_fff1,
    0x17e3_63d3_0018_9c0f,
    0xff9c_5787_6f84_57b0,
    0x3513_3220_8fc5_a8c4,
]);
106
107impl<'a> Neg for &'a Scalar {
108    type Output = Scalar;
109
110    #[inline]
111    fn neg(self) -> Scalar {
112        self.neg()
113    }
114}
115
116impl Neg for Scalar {
117    type Output = Scalar;
118
119    #[inline]
120    fn neg(self) -> Scalar {
121        -&self
122    }
123}
124
125impl<'a, 'b> Sub<&'b Scalar> for &'a Scalar {
126    type Output = Scalar;
127
128    #[inline]
129    fn sub(self, rhs: &'b Scalar) -> Scalar {
130        self.sub(rhs)
131    }
132}
133
134impl<'a, 'b> Add<&'b Scalar> for &'a Scalar {
135    type Output = Scalar;
136
137    #[inline]
138    fn add(self, rhs: &'b Scalar) -> Scalar {
139        self.add(rhs)
140    }
141}
142
143impl<'a, 'b> Mul<&'b Scalar> for &'a Scalar {
144    type Output = Scalar;
145
146    #[inline]
147    fn mul(self, rhs: &'b Scalar) -> Scalar {
148        self.mul(rhs)
149    }
150}
151
152impl_binops_additive!(Scalar, Scalar);
153impl_binops_multiplicative!(Scalar, Scalar);
154
/// INV = -(q^{-1} mod 2^64) mod 2^64
///
/// Used by `montgomery_reduce`. For each limb it computes the multiplier
/// `k = limb * INV` so that adding `k * q` zeroes that limb; the zeroed limb
/// is then discarded.
const INV: u64 = 0xffff_fffe_ffff_ffff;

/// R = 2^256 mod q
///
/// This is also the Montgomery form of 1 (`Scalar::one` returns it).
const R: Scalar = Scalar([
    0x0000_0001_ffff_fffe,
    0x5884_b7fa_0003_4802,
    0x998c_4fef_ecbc_4ff5,
    0x1824_b159_acc5_056f,
]);

/// R^2 = 2^512 mod q
///
/// Montgomery-multiplying a plain integer `a` by R2 gives (a * R^2) / R = aR,
/// i.e. converts `a` into Montgomery form.
const R2: Scalar = Scalar([
    0xc999_e990_f3f2_9c6d,
    0x2b6c_edcb_8792_5c23,
    0x05d3_1496_7254_398f,
    0x0748_d9d9_9f59_ff11,
]);

/// R^3 = 2^768 mod q
///
/// Used by `from_u512` for the upper 256 bits of a 512-bit input. Those bits
/// carry an extra factor of 2^256 = R, hence R^2 * R.
const R3: Scalar = Scalar([
    0xc62c_1807_439b_73af,
    0x1b3e_0d18_8cf0_6990,
    0x73d1_3c71_c7b5_f418,
    0x6e2a_5bb9_c8db_33e9,
]);

// 2^S * t = MODULUS - 1 with t odd
// (q - 1 has low limb 0xffff_ffff_0000_0000: 32 trailing zero bits.)
const S: u32 = 32;

/// GENERATOR^t where t * 2^s + 1 = q
/// with t odd. In other words, this
/// is a 2^s root of unity.
///
/// `GENERATOR = 7 mod q` is a generator
/// of the multiplicative group, which
/// has order q - 1.
const ROOT_OF_UNITY: Scalar = Scalar([
    0xb9b5_8d8c_5f0e_466a,
    0x5b1b_4c80_1819_d7ec,
    0x0af5_3ae3_52a3_1e64,
    0x5bf3_adda_19e9_b27b,
]);
198
199impl Default for Scalar {
200    #[inline]
201    fn default() -> Self {
202        Self::zero()
203    }
204}
205
206#[cfg(feature = "zeroize")]
207impl zeroize::DefaultIsZeroes for Scalar {}
208
209impl Scalar {
210    /// Returns zero, the additive identity.
211    #[inline]
212    pub const fn zero() -> Scalar {
213        Scalar([0, 0, 0, 0])
214    }
215
216    /// Returns one, the multiplicative identity.
217    #[inline]
218    pub const fn one() -> Scalar {
219        R
220    }
221
222    /// Doubles this field element.
223    #[inline]
224    pub const fn double(&self) -> Scalar {
225        // TODO: This can be achieved more efficiently with a bitshift.
226        self.add(self)
227    }
228
229    /// Attempts to convert a little-endian byte representation of
230    /// a scalar into a `Scalar`, failing if the input is not canonical.
231    pub fn from_bytes(bytes: &[u8; 32]) -> CtOption<Scalar> {
232        let mut tmp = Scalar([0, 0, 0, 0]);
233
234        tmp.0[0] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap());
235        tmp.0[1] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap());
236        tmp.0[2] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap());
237        tmp.0[3] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap());
238
239        // Try to subtract the modulus
240        let (_, borrow) = sbb(tmp.0[0], MODULUS.0[0], 0);
241        let (_, borrow) = sbb(tmp.0[1], MODULUS.0[1], borrow);
242        let (_, borrow) = sbb(tmp.0[2], MODULUS.0[2], borrow);
243        let (_, borrow) = sbb(tmp.0[3], MODULUS.0[3], borrow);
244
245        // If the element is smaller than MODULUS then the
246        // subtraction will underflow, producing a borrow value
247        // of 0xffff...ffff. Otherwise, it'll be zero.
248        let is_some = (borrow as u8) & 1;
249
250        // Convert to Montgomery form by computing
251        // (a.R^0 * R^2) / R = a.R
252        tmp *= &R2;
253
254        CtOption::new(tmp, Choice::from(is_some))
255    }
256
257    /// Converts an element of `Scalar` into a byte representation in
258    /// little-endian byte order.
259    pub fn to_bytes(&self) -> [u8; 32] {
260        // Turn into canonical form by computing
261        // (a.R) / R = a
262        let tmp = Scalar::montgomery_reduce(self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0);
263
264        let mut res = [0; 32];
265        res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes());
266        res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes());
267        res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes());
268        res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes());
269
270        res
271    }
272
273    /// Converts a 512-bit little endian integer into
274    /// a `Scalar` by reducing by the modulus.
275    pub fn from_bytes_wide(bytes: &[u8; 64]) -> Scalar {
276        Scalar::from_u512([
277            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap()),
278            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()),
279            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()),
280            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()),
281            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[32..40]).unwrap()),
282            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[40..48]).unwrap()),
283            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[48..56]).unwrap()),
284            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[56..64]).unwrap()),
285        ])
286    }
287
288    fn from_u512(limbs: [u64; 8]) -> Scalar {
289        // We reduce an arbitrary 512-bit number by decomposing it into two 256-bit digits
290        // with the higher bits multiplied by 2^256. Thus, we perform two reductions
291        //
292        // 1. the lower bits are multiplied by R^2, as normal
293        // 2. the upper bits are multiplied by R^2 * 2^256 = R^3
294        //
295        // and computing their sum in the field. It remains to see that arbitrary 256-bit
296        // numbers can be placed into Montgomery form safely using the reduction. The
297        // reduction works so long as the product is less than R=2^256 multiplied by
298        // the modulus. This holds because for any `c` smaller than the modulus, we have
299        // that (2^256 - 1)*c is an acceptable product for the reduction. Therefore, the
300        // reduction always works so long as `c` is in the field; in this case it is either the
301        // constant `R2` or `R3`.
302        let d0 = Scalar([limbs[0], limbs[1], limbs[2], limbs[3]]);
303        let d1 = Scalar([limbs[4], limbs[5], limbs[6], limbs[7]]);
304        // Convert to Montgomery form
305        d0 * R2 + d1 * R3
306    }
307
308    /// Converts from an integer represented in little endian
309    /// into its (congruent) `Scalar` representation.
310    pub const fn from_raw(val: [u64; 4]) -> Self {
311        (&Scalar(val)).mul(&R2)
312    }
313
314    /// Squares this element.
315    #[inline]
316    pub const fn square(&self) -> Scalar {
317        let (r1, carry) = mac(0, self.0[0], self.0[1], 0);
318        let (r2, carry) = mac(0, self.0[0], self.0[2], carry);
319        let (r3, r4) = mac(0, self.0[0], self.0[3], carry);
320
321        let (r3, carry) = mac(r3, self.0[1], self.0[2], 0);
322        let (r4, r5) = mac(r4, self.0[1], self.0[3], carry);
323
324        let (r5, r6) = mac(r5, self.0[2], self.0[3], 0);
325
326        let r7 = r6 >> 63;
327        let r6 = (r6 << 1) | (r5 >> 63);
328        let r5 = (r5 << 1) | (r4 >> 63);
329        let r4 = (r4 << 1) | (r3 >> 63);
330        let r3 = (r3 << 1) | (r2 >> 63);
331        let r2 = (r2 << 1) | (r1 >> 63);
332        let r1 = r1 << 1;
333
334        let (r0, carry) = mac(0, self.0[0], self.0[0], 0);
335        let (r1, carry) = adc(0, r1, carry);
336        let (r2, carry) = mac(r2, self.0[1], self.0[1], carry);
337        let (r3, carry) = adc(0, r3, carry);
338        let (r4, carry) = mac(r4, self.0[2], self.0[2], carry);
339        let (r5, carry) = adc(0, r5, carry);
340        let (r6, carry) = mac(r6, self.0[3], self.0[3], carry);
341        let (r7, _) = adc(0, r7, carry);
342
343        Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7)
344    }
345
346    /// Computes the square root of this element, if it exists.
347    pub fn sqrt(&self) -> CtOption<Self> {
348        // Tonelli-Shank's algorithm for q mod 16 = 1
349        // https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
350
351        // w = self^((t - 1) // 2)
352        //   = self^6104339283789297388802252303364915521546564123189034618274734669823
353        let w = self.pow_vartime(&[
354            0x7fff_2dff_7fff_ffff,
355            0x04d0_ec02_a9de_d201,
356            0x94ce_bea4_199c_ec04,
357            0x0000_0000_39f6_d3a9,
358        ]);
359
360        let mut v = S;
361        let mut x = self * w;
362        let mut b = x * w;
363
364        // Initialize z as the 2^S root of unity.
365        let mut z = ROOT_OF_UNITY;
366
367        for max_v in (1..=S).rev() {
368            let mut k = 1;
369            let mut tmp = b.square();
370            let mut j_less_than_v: Choice = 1.into();
371
372            for j in 2..max_v {
373                let tmp_is_one = tmp.ct_eq(&Scalar::one());
374                let squared = Scalar::conditional_select(&tmp, &z, tmp_is_one).square();
375                tmp = Scalar::conditional_select(&squared, &tmp, tmp_is_one);
376                let new_z = Scalar::conditional_select(&z, &squared, tmp_is_one);
377                j_less_than_v &= !j.ct_eq(&v);
378                k = u32::conditional_select(&j, &k, tmp_is_one);
379                z = Scalar::conditional_select(&z, &new_z, j_less_than_v);
380            }
381
382            let result = x * z;
383            x = Scalar::conditional_select(&result, &x, b.ct_eq(&Scalar::one()));
384            z = z.square();
385            b *= z;
386            v = k;
387        }
388
389        CtOption::new(
390            x,
391            (x * x).ct_eq(self), // Only return Some if it's the square root.
392        )
393    }
394
395    /// Exponentiates `self` by `by`, where `by` is a
396    /// little-endian order integer exponent.
397    pub fn pow(&self, by: &[u64; 4]) -> Self {
398        let mut res = Self::one();
399        for e in by.iter().rev() {
400            for i in (0..64).rev() {
401                res = res.square();
402                let mut tmp = res;
403                tmp *= self;
404                res.conditional_assign(&tmp, (((*e >> i) & 0x1) as u8).into());
405            }
406        }
407        res
408    }
409
410    /// Exponentiates `self` by `by`, where `by` is a
411    /// little-endian order integer exponent.
412    ///
413    /// **This operation is variable time with respect
414    /// to the exponent.** If the exponent is fixed,
415    /// this operation is effectively constant time.
416    pub fn pow_vartime(&self, by: &[u64; 4]) -> Self {
417        let mut res = Self::one();
418        for e in by.iter().rev() {
419            for i in (0..64).rev() {
420                res = res.square();
421
422                if ((*e >> i) & 1) == 1 {
423                    res.mul_assign(self);
424                }
425            }
426        }
427        res
428    }
429
430    /// Computes the multiplicative inverse of this element,
431    /// failing if the element is zero.
432    pub fn invert(&self) -> CtOption<Self> {
433        #[inline(always)]
434        fn square_assign_multi(n: &mut Scalar, num_times: usize) {
435            for _ in 0..num_times {
436                *n = n.square();
437            }
438        }
439        // found using https://github.com/kwantam/addchain
440        let mut t0 = self.square();
441        let mut t1 = t0 * self;
442        let mut t16 = t0.square();
443        let mut t6 = t16.square();
444        let mut t5 = t6 * t0;
445        t0 = t6 * t16;
446        let mut t12 = t5 * t16;
447        let mut t2 = t6.square();
448        let mut t7 = t5 * t6;
449        let mut t15 = t0 * t5;
450        let mut t17 = t12.square();
451        t1 *= t17;
452        let mut t3 = t7 * t2;
453        let t8 = t1 * t17;
454        let t4 = t8 * t2;
455        let t9 = t8 * t7;
456        t7 = t4 * t5;
457        let t11 = t4 * t17;
458        t5 = t9 * t17;
459        let t14 = t7 * t15;
460        let t13 = t11 * t12;
461        t12 = t11 * t17;
462        t15 *= &t12;
463        t16 *= &t15;
464        t3 *= &t16;
465        t17 *= &t3;
466        t0 *= &t17;
467        t6 *= &t0;
468        t2 *= &t6;
469        square_assign_multi(&mut t0, 8);
470        t0 *= &t17;
471        square_assign_multi(&mut t0, 9);
472        t0 *= &t16;
473        square_assign_multi(&mut t0, 9);
474        t0 *= &t15;
475        square_assign_multi(&mut t0, 9);
476        t0 *= &t15;
477        square_assign_multi(&mut t0, 7);
478        t0 *= &t14;
479        square_assign_multi(&mut t0, 7);
480        t0 *= &t13;
481        square_assign_multi(&mut t0, 10);
482        t0 *= &t12;
483        square_assign_multi(&mut t0, 9);
484        t0 *= &t11;
485        square_assign_multi(&mut t0, 8);
486        t0 *= &t8;
487        square_assign_multi(&mut t0, 8);
488        t0 *= self;
489        square_assign_multi(&mut t0, 14);
490        t0 *= &t9;
491        square_assign_multi(&mut t0, 10);
492        t0 *= &t8;
493        square_assign_multi(&mut t0, 15);
494        t0 *= &t7;
495        square_assign_multi(&mut t0, 10);
496        t0 *= &t6;
497        square_assign_multi(&mut t0, 8);
498        t0 *= &t5;
499        square_assign_multi(&mut t0, 16);
500        t0 *= &t3;
501        square_assign_multi(&mut t0, 8);
502        t0 *= &t2;
503        square_assign_multi(&mut t0, 7);
504        t0 *= &t4;
505        square_assign_multi(&mut t0, 9);
506        t0 *= &t2;
507        square_assign_multi(&mut t0, 8);
508        t0 *= &t3;
509        square_assign_multi(&mut t0, 8);
510        t0 *= &t2;
511        square_assign_multi(&mut t0, 8);
512        t0 *= &t2;
513        square_assign_multi(&mut t0, 8);
514        t0 *= &t2;
515        square_assign_multi(&mut t0, 8);
516        t0 *= &t3;
517        square_assign_multi(&mut t0, 8);
518        t0 *= &t2;
519        square_assign_multi(&mut t0, 8);
520        t0 *= &t2;
521        square_assign_multi(&mut t0, 5);
522        t0 *= &t1;
523        square_assign_multi(&mut t0, 5);
524        t0 *= &t1;
525
526        CtOption::new(t0, !self.ct_eq(&Self::zero()))
527    }
528
529    #[inline(always)]
530    const fn montgomery_reduce(
531        r0: u64,
532        r1: u64,
533        r2: u64,
534        r3: u64,
535        r4: u64,
536        r5: u64,
537        r6: u64,
538        r7: u64,
539    ) -> Self {
540        // The Montgomery reduction here is based on Algorithm 14.32 in
541        // Handbook of Applied Cryptography
542        // <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>.
543
544        let k = r0.wrapping_mul(INV);
545        let (_, carry) = mac(r0, k, MODULUS.0[0], 0);
546        let (r1, carry) = mac(r1, k, MODULUS.0[1], carry);
547        let (r2, carry) = mac(r2, k, MODULUS.0[2], carry);
548        let (r3, carry) = mac(r3, k, MODULUS.0[3], carry);
549        let (r4, carry2) = adc(r4, 0, carry);
550
551        let k = r1.wrapping_mul(INV);
552        let (_, carry) = mac(r1, k, MODULUS.0[0], 0);
553        let (r2, carry) = mac(r2, k, MODULUS.0[1], carry);
554        let (r3, carry) = mac(r3, k, MODULUS.0[2], carry);
555        let (r4, carry) = mac(r4, k, MODULUS.0[3], carry);
556        let (r5, carry2) = adc(r5, carry2, carry);
557
558        let k = r2.wrapping_mul(INV);
559        let (_, carry) = mac(r2, k, MODULUS.0[0], 0);
560        let (r3, carry) = mac(r3, k, MODULUS.0[1], carry);
561        let (r4, carry) = mac(r4, k, MODULUS.0[2], carry);
562        let (r5, carry) = mac(r5, k, MODULUS.0[3], carry);
563        let (r6, carry2) = adc(r6, carry2, carry);
564
565        let k = r3.wrapping_mul(INV);
566        let (_, carry) = mac(r3, k, MODULUS.0[0], 0);
567        let (r4, carry) = mac(r4, k, MODULUS.0[1], carry);
568        let (r5, carry) = mac(r5, k, MODULUS.0[2], carry);
569        let (r6, carry) = mac(r6, k, MODULUS.0[3], carry);
570        let (r7, _) = adc(r7, carry2, carry);
571
572        // Result may be within MODULUS of the correct value
573        (&Scalar([r4, r5, r6, r7])).sub(&MODULUS)
574    }
575
576    /// Multiplies `rhs` by `self`, returning the result.
577    #[inline]
578    pub const fn mul(&self, rhs: &Self) -> Self {
579        // Schoolbook multiplication
580
581        let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0);
582        let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry);
583        let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry);
584        let (r3, r4) = mac(0, self.0[0], rhs.0[3], carry);
585
586        let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0);
587        let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry);
588        let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry);
589        let (r4, r5) = mac(r4, self.0[1], rhs.0[3], carry);
590
591        let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0);
592        let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry);
593        let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry);
594        let (r5, r6) = mac(r5, self.0[2], rhs.0[3], carry);
595
596        let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0);
597        let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry);
598        let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry);
599        let (r6, r7) = mac(r6, self.0[3], rhs.0[3], carry);
600
601        Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7)
602    }
603
604    /// Subtracts `rhs` from `self`, returning the result.
605    #[inline]
606    pub const fn sub(&self, rhs: &Self) -> Self {
607        let (d0, borrow) = sbb(self.0[0], rhs.0[0], 0);
608        let (d1, borrow) = sbb(self.0[1], rhs.0[1], borrow);
609        let (d2, borrow) = sbb(self.0[2], rhs.0[2], borrow);
610        let (d3, borrow) = sbb(self.0[3], rhs.0[3], borrow);
611
612        // If underflow occurred on the final limb, borrow = 0xfff...fff, otherwise
613        // borrow = 0x000...000. Thus, we use it as a mask to conditionally add the modulus.
614        let (d0, carry) = adc(d0, MODULUS.0[0] & borrow, 0);
615        let (d1, carry) = adc(d1, MODULUS.0[1] & borrow, carry);
616        let (d2, carry) = adc(d2, MODULUS.0[2] & borrow, carry);
617        let (d3, _) = adc(d3, MODULUS.0[3] & borrow, carry);
618
619        Scalar([d0, d1, d2, d3])
620    }
621
622    /// Adds `rhs` to `self`, returning the result.
623    #[inline]
624    pub const fn add(&self, rhs: &Self) -> Self {
625        let (d0, carry) = adc(self.0[0], rhs.0[0], 0);
626        let (d1, carry) = adc(self.0[1], rhs.0[1], carry);
627        let (d2, carry) = adc(self.0[2], rhs.0[2], carry);
628        let (d3, _) = adc(self.0[3], rhs.0[3], carry);
629
630        // Attempt to subtract the modulus, to ensure the value
631        // is smaller than the modulus.
632        (&Scalar([d0, d1, d2, d3])).sub(&MODULUS)
633    }
634
635    /// Negates `self`.
636    #[inline]
637    pub const fn neg(&self) -> Self {
638        // Subtract `self` from `MODULUS` to negate. Ignore the final
639        // borrow because it cannot underflow; self is guaranteed to
640        // be in the field.
641        let (d0, borrow) = sbb(MODULUS.0[0], self.0[0], 0);
642        let (d1, borrow) = sbb(MODULUS.0[1], self.0[1], borrow);
643        let (d2, borrow) = sbb(MODULUS.0[2], self.0[2], borrow);
644        let (d3, _) = sbb(MODULUS.0[3], self.0[3], borrow);
645
646        // `tmp` could be `MODULUS` if `self` was zero. Create a mask that is
647        // zero if `self` was zero, and `u64::max_value()` if self was nonzero.
648        let mask = (((self.0[0] | self.0[1] | self.0[2] | self.0[3]) == 0) as u64).wrapping_sub(1);
649
650        Scalar([d0 & mask, d1 & mask, d2 & mask, d3 & mask])
651    }
652}
653
654impl From<Scalar> for [u8; 32] {
655    fn from(value: Scalar) -> [u8; 32] {
656        value.to_bytes()
657    }
658}
659
660impl<'a> From<&'a Scalar> for [u8; 32] {
661    fn from(value: &'a Scalar) -> [u8; 32] {
662        value.to_bytes()
663    }
664}
665
666impl Field for Scalar {
667    fn random(mut rng: impl RngCore) -> Self {
668        let mut buf = [0; 64];
669        rng.fill_bytes(&mut buf);
670        Self::from_bytes_wide(&buf)
671    }
672
673    fn zero() -> Self {
674        Self::zero()
675    }
676
677    fn one() -> Self {
678        Self::one()
679    }
680
681    #[must_use]
682    fn square(&self) -> Self {
683        self.square()
684    }
685
686    #[must_use]
687    fn double(&self) -> Self {
688        self.double()
689    }
690
691    fn invert(&self) -> CtOption<Self> {
692        self.invert()
693    }
694
695    fn sqrt(&self) -> CtOption<Self> {
696        self.sqrt()
697    }
698}
699
700impl PrimeField for Scalar {
701    type Repr = [u8; 32];
702
703    fn from_repr(r: Self::Repr) -> CtOption<Self> {
704        Self::from_bytes(&r)
705    }
706
707    fn to_repr(&self) -> Self::Repr {
708        self.to_bytes()
709    }
710
711    fn is_odd(&self) -> Choice {
712        Choice::from(self.to_bytes()[0] & 1)
713    }
714
715    const NUM_BITS: u32 = MODULUS_BITS;
716    const CAPACITY: u32 = Self::NUM_BITS - 1;
717
718    fn multiplicative_generator() -> Self {
719        GENERATOR
720    }
721
722    const S: u32 = S;
723
724    fn root_of_unity() -> Self {
725        ROOT_OF_UNITY
726    }
727}
728
/// Limb storage for `FieldBits` on targets whose pointer width is not 64.
#[cfg(all(feature = "bits", not(target_pointer_width = "64")))]
type ReprBits = [u32; 8];

/// Limb storage for `FieldBits` on 64-bit targets.
#[cfg(all(feature = "bits", target_pointer_width = "64"))]
type ReprBits = [u64; 4];

#[cfg(feature = "bits")]
impl PrimeFieldBits for Scalar {
    type ReprBits = ReprBits;

    /// Little-endian bits of the canonical (non-Montgomery) value, packed
    /// into native-width limbs.
    fn to_le_bits(&self) -> FieldBits<Self::ReprBits> {
        let bytes = self.to_bytes();

        // Each sub-slice has exactly the limb's width, so `try_into` cannot fail.
        #[cfg(not(target_pointer_width = "64"))]
        let limbs = {
            let limb = |i: usize| u32::from_le_bytes(bytes[4 * i..4 * i + 4].try_into().unwrap());
            [
                limb(0),
                limb(1),
                limb(2),
                limb(3),
                limb(4),
                limb(5),
                limb(6),
                limb(7),
            ]
        };

        #[cfg(target_pointer_width = "64")]
        let limbs = {
            let limb = |i: usize| u64::from_le_bytes(bytes[8 * i..8 * i + 8].try_into().unwrap());
            [limb(0), limb(1), limb(2), limb(3)]
        };

        FieldBits::new(limbs)
    }

    /// Little-endian bits of the field characteristic `q`.
    fn char_le_bits() -> FieldBits<Self::ReprBits> {
        #[cfg(not(target_pointer_width = "64"))]
        let limbs = MODULUS_LIMBS_32;

        // `MODULUS` is stored as plain (non-Montgomery) limbs, so it can be
        // used directly.
        #[cfg(target_pointer_width = "64")]
        let limbs = MODULUS.0;

        FieldBits::new(limbs)
    }
}
775
776impl<T> core::iter::Sum<T> for Scalar
777where
778    T: core::borrow::Borrow<Scalar>,
779{
780    fn sum<I>(iter: I) -> Self
781    where
782        I: Iterator<Item = T>,
783    {
784        iter.fold(Self::zero(), |acc, item| acc + item.borrow())
785    }
786}
787
#[test]
fn test_inv() {
    // INV must be -(q^{-1} mod 2^64) mod 2^64. The odd residues mod 2^64 form
    // a group of order totient(2^64) = 2^63, so q^{-1} = q^(2^63 - 1). Each
    // fold step maps x -> x^2 * q; after 63 steps the exponent is 2^63 - 1.
    let q_inv = (0..63).fold(1u64, |acc, _| acc.wrapping_mul(acc).wrapping_mul(MODULUS.0[0]));

    assert_eq!(q_inv.wrapping_neg(), INV);
}
802
#[cfg(feature = "std")]
#[test]
fn test_debug() {
    // Debug prints the canonical value, so R2 (the Montgomery form of R) shows R.
    let cases = [
        (
            Scalar::zero(),
            "0x0000000000000000000000000000000000000000000000000000000000000000",
        ),
        (
            Scalar::one(),
            "0x0000000000000000000000000000000000000000000000000000000000000001",
        ),
        (
            R2,
            "0x1824b159acc5056f998c4fefecbc4ff55884b7fa0003480200000001fffffffe",
        ),
    ];

    for (value, expected) in cases.iter() {
        assert_eq!(format!("{:?}", value), *expected);
    }
}
819
#[test]
fn test_equality() {
    let zero = Scalar::zero();
    let one = Scalar::one();

    // Every value equals itself...
    assert_eq!(zero, zero);
    assert_eq!(one, one);
    assert_eq!(R2, R2);

    // ...and distinct values compare unequal.
    assert_ne!(zero, one);
    assert_ne!(one, R2);
}
829
// Checks that `to_bytes` emits the canonical (non-Montgomery) value in
// little-endian byte order.
#[test]
fn test_to_bytes() {
    assert_eq!(
        Scalar::zero().to_bytes(),
        [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0
        ]
    );

    // One is stored as R but encodes as the integer 1.
    assert_eq!(
        Scalar::one().to_bytes(),
        [
            1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0
        ]
    );

    // R2 is the Montgomery form of R, so its encoding is R = 2^256 mod q.
    assert_eq!(
        R2.to_bytes(),
        [
            254, 255, 255, 255, 1, 0, 0, 0, 2, 72, 3, 0, 250, 183, 132, 88, 245, 79, 188, 236, 239,
            79, 140, 153, 111, 5, 197, 172, 89, 177, 36, 24
        ]
    );

    // -1 encodes as q - 1.
    assert_eq!(
        (-&Scalar::one()).to_bytes(),
        [
            0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
        ]
    );
}
864
// Checks that `from_bytes` accepts every canonical encoding (< q) and rejects
// q itself and anything larger.
#[test]
fn test_from_bytes() {
    assert_eq!(
        Scalar::from_bytes(&[
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0
        ])
        .unwrap(),
        Scalar::zero()
    );

    assert_eq!(
        Scalar::from_bytes(&[
            1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0
        ])
        .unwrap(),
        Scalar::one()
    );

    // The bytes of R decode to the element whose Montgomery form is R2.
    assert_eq!(
        Scalar::from_bytes(&[
            254, 255, 255, 255, 1, 0, 0, 0, 2, 72, 3, 0, 250, 183, 132, 88, 245, 79, 188, 236, 239,
            79, 140, 153, 111, 5, 197, 172, 89, 177, 36, 24
        ])
        .unwrap(),
        R2
    );

    // -1 should work
    assert!(bool::from(
        Scalar::from_bytes(&[
            0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
        ])
        .is_some()
    ));

    // modulus is invalid
    assert!(bool::from(
        Scalar::from_bytes(&[
            1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
        ])
        .is_none()
    ));

    // Anything larger than the modulus is invalid
    // (q + 1, then q with a middle byte bumped, then q with the top byte bumped).
    assert!(bool::from(
        Scalar::from_bytes(&[
            2, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
        ])
        .is_none()
    ));
    assert!(bool::from(
        Scalar::from_bytes(&[
            1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
            216, 58, 51, 72, 125, 157, 41, 83, 167, 237, 115
        ])
        .is_none()
    ));
    assert!(bool::from(
        Scalar::from_bytes(&[
            1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 116
        ])
        .is_none()
    ));
}
935
#[test]
fn test_from_u512_zero() {
    // q itself, with a zero upper half, must reduce to zero.
    let [q0, q1, q2, q3] = MODULUS.0;
    let q_wide = [q0, q1, q2, q3, 0, 0, 0, 0];

    assert_eq!(Scalar::zero(), Scalar::from_u512(q_wide));
}
952
#[test]
fn test_from_u512_r() {
    // The integer 1 has Montgomery form R.
    let one_wide = [1, 0, 0, 0, 0, 0, 0, 0];
    assert_eq!(R, Scalar::from_u512(one_wide));
}
957
#[test]
fn test_from_u512_r2() {
    // 2^256 reduces to R as an integer, whose Montgomery form is R^2.
    let two_pow_256 = [0, 0, 0, 0, 1, 0, 0, 0];
    assert_eq!(R2, Scalar::from_u512(two_pow_256));
}
962
#[test]
fn test_from_u512_max() {
    // Every limb saturated: the largest 512-bit input, 2^512 - 1.
    let all_ones = [u64::MAX; 8];
    assert_eq!(R3 - R, Scalar::from_u512(all_ones));
}
971
#[test]
fn test_from_bytes_wide_r2() {
    // Only the low 32 bytes of the 64-byte input are populated; the high half stays zero.
    let mut wide = [0u8; 64];
    wide[..32].copy_from_slice(&[
        254, 255, 255, 255, 1, 0, 0, 0, 2, 72, 3, 0, 250, 183, 132, 88, 245, 79, 188, 236, 239,
        79, 140, 153, 111, 5, 197, 172, 89, 177, 36, 24,
    ]);
    assert_eq!(R2, Scalar::from_bytes_wide(&wide));
}
983
#[test]
fn test_from_bytes_wide_negative_one() {
    // Low 32 bytes hold q - 1 in little-endian byte order; the high 32 bytes are zero.
    let mut wide = [0u8; 64];
    wide[..32].copy_from_slice(&[
        0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
        216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115,
    ]);
    assert_eq!(-&Scalar::one(), Scalar::from_bytes_wide(&wide));
}
995
#[test]
fn test_from_bytes_wide_maximum() {
    // Every input byte set to 0xff, i.e. the 512-bit value 2^512 - 1.
    let all_ones = [0xffu8; 64];
    // Expected internal limbs of the reduced result.
    let expected = Scalar([
        0xc62c_1805_439b_73b1,
        0xc2b9_551e_8ced_218e,
        0xda44_ec81_daf9_a422,
        0x5605_aa60_1c16_2e79,
    ]);
    assert_eq!(expected, Scalar::from_bytes_wide(&all_ones));
}
1008
#[test]
fn test_zero() {
    // Zero is fixed by negation and by +, -, * with itself.
    let zero = Scalar::zero();
    assert_eq!(zero, -&zero);
    assert_eq!(zero, zero + zero);
    assert_eq!(zero, zero - zero);
    assert_eq!(zero, zero * zero);
}
1016
#[cfg(test)]
// Raw little-endian limbs of q - 1, the largest value below the modulus q given in the
// module header. These are the stored (Montgomery-form) words, so as a field element this
// denotes (q - 1) * R^-1 rather than -1. Used in tests as a boundary value: adding
// Scalar([1, 0, 0, 0]) to it must wrap around to zero.
const LARGEST: Scalar = Scalar([
    0xffff_ffff_0000_0000,
    0x53bd_a402_fffe_5bfe,
    0x3339_d808_09a1_d805,
    0x73ed_a753_299d_7d48,
]);
1024
#[test]
fn test_addition() {
    // (q - 1) + (q - 1) = 2q - 2, which must reduce to q - 2.
    let doubled = LARGEST + LARGEST;
    let q_minus_two = Scalar([
        0xffff_fffe_ffff_ffff,
        0x53bd_a402_fffe_5bfe,
        0x3339_d808_09a1_d805,
        0x73ed_a753_299d_7d48,
    ]);
    assert_eq!(doubled, q_minus_two);

    // (q - 1) + 1 wraps around to zero.
    let wrapped = LARGEST + Scalar([1, 0, 0, 0]);
    assert_eq!(wrapped, Scalar::zero());
}
1045
#[test]
fn test_negation() {
    let one_raw = Scalar([1, 0, 0, 0]);

    // Negation swaps q - 1 and 1 and leaves zero untouched.
    assert_eq!(-&LARGEST, one_raw);
    assert_eq!(-&Scalar::zero(), Scalar::zero());
    assert_eq!(-&one_raw, LARGEST);
}
1057
#[test]
fn test_subtraction() {
    // x - x = 0.
    assert_eq!(LARGEST - LARGEST, Scalar::zero());

    // 0 - (q - 1) and q - (q - 1) must land on the same value.
    let from_zero = Scalar::zero() - LARGEST;
    let from_modulus = MODULUS - LARGEST;
    assert_eq!(from_zero, from_modulus);
}
1073
#[test]
fn test_multiplication() {
    let mut cur = LARGEST;

    for _ in 0..100 {
        let product = cur * cur;

        // Reference product via double-and-add over the bits of `cur`. `to_bytes` is
        // little-endian, so the bytes are reversed to scan from the most significant one,
        // and each byte is scanned from bit 7 down to bit 0.
        let expected = cur
            .to_bytes()
            .iter()
            .rev()
            .flat_map(|byte| (0..8).rev().map(move |i| ((byte >> i) & 1u8) == 1u8))
            .fold(Scalar::zero(), |acc, bit| {
                let doubled = acc + acc;
                if bit {
                    doubled + cur
                } else {
                    doubled
                }
            });

        assert_eq!(product, expected);

        // Step to a different operand for the next round.
        cur += &LARGEST;
    }
}
1102
#[test]
fn test_squaring() {
    // Walk 100 distinct elements (LARGEST, 2*LARGEST, ...) and check that `square` agrees
    // with generic multiplication and with a bit-by-bit double-and-add reference.
    let mut cur = LARGEST;

    for _ in 0..100 {
        let squared = cur.square();

        // Squaring must agree with multiplying the element by itself.
        assert_eq!(squared, cur * cur);

        // Independent reference: double-and-add over the bits of `cur`. `to_bytes` is
        // little-endian, so bytes are reversed to start at the most significant one.
        let mut expected = Scalar::zero();
        for bit in cur
            .to_bytes()
            .iter()
            .rev()
            .flat_map(|byte| (0..8).rev().map(move |i| ((byte >> i) & 1u8) == 1u8))
        {
            let prev = expected;
            expected += &prev;

            if bit {
                expected += &cur;
            }
        }

        assert_eq!(squared, expected);

        cur += &LARGEST;
    }
}
1131
#[test]
fn test_inversion() {
    let one = Scalar::one();
    let minus_one = -&one;

    // Zero has no inverse; 1 and -1 are their own inverses.
    assert!(bool::from(Scalar::zero().invert().is_none()));
    assert_eq!(one.invert().unwrap(), one);
    assert_eq!(minus_one.invert().unwrap(), minus_one);

    // x^-1 * x = 1 for x = R2, 2*R2, ..., 100*R2.
    let mut x = R2;
    for _ in 0..100 {
        assert_eq!(x.invert().unwrap() * x, one);
        x += &R2;
    }
}
1149
#[test]
fn test_invert_is_pow() {
    // By Fermat's little theorem x^(q - 2) = x^-1 for nonzero x; the exponent is given as
    // little-endian 64-bit limbs.
    let q_minus_2 = [
        0xffff_fffe_ffff_ffff,
        0x53bd_a402_fffe_5bfe,
        0x3339_d808_09a1_d805,
        0x73ed_a753_299d_7d48,
    ];

    let mut x = R;
    for _ in 0..100 {
        let inverted = x.invert().unwrap();
        let via_pow_vartime = x.pow_vartime(&q_minus_2);
        let via_pow = x.pow(&q_minus_2);

        assert_eq!(inverted, via_pow_vartime);
        assert_eq!(via_pow_vartime, via_pow);

        // Move to a different base for the next round: the inverse plus R.
        x = inverted + R;
    }
}
1176
#[test]
fn test_sqrt() {
    // sqrt(0) = 0.
    assert_eq!(Scalar::zero().sqrt().unwrap(), Scalar::zero());

    // Step downward by one from a fixed starting value for 100 iterations. Every returned
    // root must square back to its input; inputs with no root are counted, and exactly 49
    // of them are expected for this starting point.
    let mut candidate = Scalar([
        0x46cd_85a5_f273_077e,
        0x1d30_c47d_d68f_c735,
        0x77f6_56f6_0bec_a0eb,
        0x494a_a01b_df32_468d,
    ]);
    let mut non_residues = 0;

    for _ in 0..100 {
        let root = candidate.sqrt();
        if bool::from(root.is_some()) {
            let r = root.unwrap();
            assert_eq!(r * r, candidate);
        } else {
            non_residues += 1;
        }
        candidate -= Scalar::one();
    }

    assert_eq!(49, non_residues);
}
1204
#[test]
fn test_from_raw() {
    // (2^256 - 1) mod q, written out as little-endian limbs.
    let reduced_all_ones = Scalar::from_raw([
        0x0001_ffff_fffd,
        0x5884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff5,
        0x1824_b159_acc5_056f,
    ]);
    let all_ones = Scalar::from_raw([u64::MAX; 4]);
    assert_eq!(reduced_all_ones, all_ones);

    // The modulus itself is congruent to zero.
    assert_eq!(Scalar::from_raw(MODULUS.0), Scalar::zero());

    // The integer 1 converts to the constant R.
    assert_eq!(Scalar::from_raw([1, 0, 0, 0]), R);
}
1221
#[test]
fn test_double() {
    // Doubling must agree with adding a value to itself.
    let x = Scalar::from_raw([
        0x1fff_3231_233f_fffd,
        0x4884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff3,
        0x1824_b159_acc5_0562,
    ]);
    let sum = x + x;
    assert_eq!(x.double(), sum);
}
1233
#[cfg(feature = "zeroize")]
#[test]
fn test_zeroize() {
    use zeroize::Zeroize;

    // After zeroize() a nonzero scalar must read back as zero.
    let mut secret = Scalar::from_raw([
        0x1fff_3231_233f_fffd,
        0x4884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff3,
        0x1824_b159_acc5_0562,
    ]);
    secret.zeroize();
    assert!(bool::from(secret.is_zero()));
}