ring/arithmetic/bigint/modulus.rs

// Copyright 2015-2024 Brian Smith.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use super::{
    super::montgomery::Unencoded, unwrap_impossible_len_mismatch_error, BoxedLimbs, Elem,
    OwnedModulusValue, PublicModulus, Storage, N0,
};
use crate::{
    bits::BitLength,
    cpu, error,
    limb::{self, Limb, LIMB_BITS},
    polyfill::LeadingZerosStripped,
};
use core::marker::PhantomData;

/// The modulus *m* for a ring ℤ/mℤ, along with the precomputed values needed
/// for efficient Montgomery multiplication modulo *m*. The value must be odd
/// and larger than 2. The larger-than-1 requirement is imposed, at least, by
/// the modular inversion code.
pub struct OwnedModulus<M> {
    inner: OwnedModulusValue<M>,

    // n0 * N == -1 (mod r).
    //
    // r == 2**(N0::LIMBS_USED * LIMB_BITS) and LG_LITTLE_R == lg(r). This
    // ensures that we can do integer division by |r| by simply ignoring
    // `N0::LIMBS_USED` limbs. Similarly, we can calculate values modulo `r` by
    // just looking at the lowest `N0::LIMBS_USED` limbs. This is what makes
    // Montgomery multiplication efficient.
    //
    // As shown in Algorithm 1 of "Fast Prime Field Elliptic Curve Cryptography
    // with 256 Bit Primes" by Shay Gueron and Vlad Krasnov, in the loop of a
    // multi-limb Montgomery multiplication of a * b (mod n), given the
    // unreduced product t == a * b, we repeatedly calculate:
    //
    //    t1 := t % r         |t1| is |t|'s lowest limb (see previous paragraph).
    //    t2 := t1*n0*n
    //    t3 := t + t2
    //    t := t3 / r         copy all limbs of |t3| except the lowest to |t|.
    //
    // In the last step, it would only make sense to ignore the lowest limb of
    // |t3| if it were zero. The middle steps ensure that this is the case:
    //
    //                            t3 ==  0 (mod r)
    //                        t + t2 ==  0 (mod r)
    //                   t + t1*n0*n ==  0 (mod r)
    //                       t1*n0*n == -t (mod r)
    //                        t*n0*n == -t (mod r)
    //                          n0*n == -1 (mod r)
    //                            n0 == -1/n (mod r)
    //
    // Thus, in each iteration of the loop, we multiply by the constant factor
    // n0, the negative inverse of n (mod r).
    //
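    // As a toy, non-cryptographic example of one iteration, take r = 16 and
    // n = 7. Then n0 == -1/7 == 9 (mod 16), since 7*7 == 49 == 1 (mod 16)
    // and 9*7 == 63 == -1 (mod 16). Starting from t == 10:
    //
    //    t1 := 10 % 16       == 10
    //    t2 := 10*9*7        == 630
    //    t3 := 10 + 630      == 640 == 0 (mod 16), as promised
    //    t  := 640 / 16      == 40
    //
    // The iteration preserves the value modulo n up to a factor of r:
    // t_new * r == 40*16 == 640 == 10 == t_old (mod 7); both sides reduce to 3.
    //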
    // TODO(perf): Not all 32-bit platforms actually make use of n0[1]. For the
    // ones that don't, we could use a shorter `R` value and use faster `Limb`
    // calculations instead of double-precision `u64` calculations.
    n0: N0,
}

impl<M: PublicModulus> Clone for OwnedModulus<M> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            n0: self.n0,
        }
    }
}

impl<M> OwnedModulus<M> {
    pub(crate) fn from(n: OwnedModulusValue<M>) -> Self {
        // n_mod_r = n % r. As explained in the documentation for `n0`, this is
        // done by taking the lowest `N0::LIMBS_USED` limbs of `n`.
        #[allow(clippy::useless_conversion)]
        let n0 = {
            prefixed_extern! {
                fn bn_neg_inv_mod_r_u64(n: u64) -> u64;
            }

            // XXX: u64::from isn't guaranteed to be constant time.
            let mut n_mod_r: u64 = u64::from(n.limbs()[0]);

            if N0::LIMBS_USED == 2 {
                // XXX: If we use `<< LIMB_BITS` here then 64-bit builds
                // fail to compile because of `deny(exceeding_bitshifts)`.
                debug_assert_eq!(LIMB_BITS, 32);
                n_mod_r |= u64::from(n.limbs()[1]) << 32;
            }
            N0::precalculated(unsafe { bn_neg_inv_mod_r_u64(n_mod_r) })
        };

        Self { inner: n, n0 }
    }

    pub fn to_elem<L>(&self, l: &Modulus<L>) -> Result<Elem<L, Unencoded>, error::Unspecified> {
        self.inner.verify_less_than(l)?;
        let mut limbs = BoxedLimbs::zero(l.limbs().len());
        limbs[..self.inner.limbs().len()].copy_from_slice(self.inner.limbs());
        Ok(Elem {
            limbs,
            encoding: PhantomData,
        })
    }

    pub(crate) fn modulus(&self, cpu_features: cpu::Features) -> Modulus<M> {
        Modulus {
            limbs: self.inner.limbs(),
            n0: self.n0,
            len_bits: self.len_bits(),
            m: PhantomData,
            cpu_features,
        }
    }

    pub fn len_bits(&self) -> BitLength {
        self.inner.len_bits()
    }
}
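
// A minimal, variable-time sketch of the value that the constant-time
// `bn_neg_inv_mod_r_u64` (declared in `from` above, implemented outside this
// file) computes: the negative inverse of an odd `n` modulo r = 2**64, here
// found by Newton/Hensel iteration. This function is hypothetical and for
// illustration only; it is not ring's implementation.
#[allow(dead_code)]
fn neg_inv_mod_r_u64_sketch(n: u64) -> u64 {
    debug_assert_eq!(n & 1, 1); // Inverses modulo 2**64 exist only for odd n.
    // For odd n, n*n == 1 (mod 8), so x = n inverts n to 3 bits of
    // precision. Each step x := x*(2 - n*x) doubles the precision:
    // 3 -> 6 -> 12 -> 24 -> 48 -> 96 >= 64 bits.
    let mut x = n;
    for _ in 0..5 {
        x = x.wrapping_mul(2u64.wrapping_sub(n.wrapping_mul(x)));
    }
    // Now x == 1/n (mod 2**64); negate so that the result times n is
    // -1 (mod 2**64), i.e. `result.wrapping_mul(n) == u64::MAX`.
    x.wrapping_neg()
}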

impl<M: PublicModulus> OwnedModulus<M> {
    pub fn be_bytes(&self) -> LeadingZerosStripped<impl ExactSizeIterator<Item = u8> + Clone + '_> {
        LeadingZerosStripped::new(limb::unstripped_be_bytes(self.inner.limbs()))
    }
}

pub struct Modulus<'a, M> {
    limbs: &'a [Limb],
    n0: N0,
    len_bits: BitLength,
    m: PhantomData<M>,
    cpu_features: cpu::Features,
}

impl<M> Modulus<'_, M> {
    pub(super) fn oneR(&self, out: &mut [Limb]) {
        assert_eq!(self.limbs.len(), out.len());

        let r = self.limbs.len() * LIMB_BITS;

        // out = 2**r - m where m = self.
        limb::limbs_negative_odd(out, self.limbs);

        let lg_m = self.len_bits().as_bits();
        let leading_zero_bits_in_m = r - lg_m;

        // When m's length is a multiple of LIMB_BITS, which is the case we
        // most want to optimize for, then we already have
        // out == 2**r - m == 2**r (mod m).
        if leading_zero_bits_in_m != 0 {
            debug_assert!(leading_zero_bits_in_m < LIMB_BITS);
            // Correct out to 2**(lg m) (mod m). `limbs_negative_odd` flipped
            // all the leading zero bits to ones. Flip them back.
            *out.last_mut().unwrap() &= (!0) >> leading_zero_bits_in_m;

            // Now we have out == 2**(lg m) (mod m). Keep doubling until we get
            // to 2**r (mod m).
            for _ in 0..leading_zero_bits_in_m {
                limb::limbs_double_mod(out, self.limbs)
                    .unwrap_or_else(unwrap_impossible_len_mismatch_error);
            }
        }

        // Now out == 2**r (mod m) == 1*R.
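
        // A toy walk-through with a hypothetical 4-bit limb (so r = 4):
        // take m = 0b0101 = 5, so lg(m) = 3 and one leading zero bit.
        //
        //    negate:      out = 2**4 - 5 = 0b1011
        //    mask:        out = 0b1011 & 0b0111 = 3 == 2**3 (mod 5)
        //    double once: out = 2*3 mod 5 = 1 == 2**4 (mod 5)
        //
        // giving out == 2**r (mod m) == 1*R, as claimed.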
    }

    // TODO: XXX Avoid duplication with `Modulus`.
    pub fn alloc_zero(&self) -> Storage<M> {
        Storage {
            limbs: BoxedLimbs::zero(self.limbs.len()),
        }
    }

    #[inline]
    pub(super) fn limbs(&self) -> &[Limb] {
        self.limbs
    }

    #[inline]
    pub(super) fn n0(&self) -> &N0 {
        &self.n0
    }

    pub fn len_bits(&self) -> BitLength {
        self.len_bits
    }

    #[inline]
    pub(crate) fn cpu_features(&self) -> cpu::Features {
        self.cpu_features
    }
}