ring/arithmetic/
montgomery.rs

1// Copyright 2017-2025 Brian Smith.
2//
3// Permission to use, copy, modify, and/or distribute this software for any
4// purpose with or without fee is hereby granted, provided that the above
5// copyright notice and this permission notice appear in all copies.
6//
7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15pub use super::n0::N0;
16use super::{inout::AliasingSlices3, LimbSliceError, MIN_LIMBS};
17use crate::cpu;
18use cfg_if::cfg_if;
19
/// Marker for an element that is not encoded; there is no *R* factor that
/// needs to be canceled out.
#[derive(Copy, Clone)]
pub enum Unencoded {}

/// Marker for an element that is encoded; the value has one *R* factor that
/// needs to be canceled out.
#[derive(Copy, Clone)]
pub enum R {}

/// Marker for an element that is encoded three times; the value has three
/// *R* factors that need to be canceled out.
// The name spells out R·R·R, so the all-caps "acronym" is intentional.
#[allow(clippy::upper_case_acronyms)]
#[derive(Copy, Clone)]
pub enum RRR {}

/// Marker for an element that is encoded twice; the value has two *R*
/// factors that need to be canceled out.
#[derive(Copy, Clone)]
pub enum RR {}

/// Marker for an element that is inversely encoded; the value has one
/// 1/*R* factor that needs to be canceled out.
#[derive(Copy, Clone)]
pub enum RInverse {}

/// Implemented by every Montgomery-encoding marker type. The markers are
/// uninhabited enums, so they exist only at the type level.
pub trait Encoding {}

impl Encoding for RRR {}
impl Encoding for RR {}
impl Encoding for R {}
impl Encoding for Unencoded {}
impl Encoding for RInverse {}
53
54/// The encoding of the result of a reduction.
55pub trait ReductionEncoding {
56    type Output: Encoding;
57}
58
59impl ReductionEncoding for RRR {
60    type Output = RR;
61}
62
63impl ReductionEncoding for RR {
64    type Output = R;
65}
66impl ReductionEncoding for R {
67    type Output = Unencoded;
68}
69impl ReductionEncoding for Unencoded {
70    type Output = RInverse;
71}
72
73/// The encoding of the result of a multiplication.
74pub trait ProductEncoding {
75    type Output: Encoding;
76}
77
78impl<E: ReductionEncoding> ProductEncoding for (Unencoded, E) {
79    type Output = E::Output;
80}
81
82impl<E: Encoding> ProductEncoding for (R, E) {
83    type Output = E;
84}
85
86impl ProductEncoding for (RR, RR) {
87    type Output = RRR;
88}
89
90impl<E: ReductionEncoding> ProductEncoding for (RInverse, E)
91where
92    E::Output: ReductionEncoding,
93{
94    type Output = <<E as ReductionEncoding>::Output as ReductionEncoding>::Output;
95}
96
97// XXX: Rust doesn't allow overlapping impls,
98// TODO (if/when Rust allows it):
99// impl<E1, E2: ReductionEncoding> ProductEncoding for
100//         (E1, E2) {
101//     type Output = <(E2, E1) as ProductEncoding>::Output;
102// }
103impl ProductEncoding for (RR, Unencoded) {
104    type Output = <(Unencoded, RR) as ProductEncoding>::Output;
105}
106impl ProductEncoding for (RR, RInverse) {
107    type Output = <(RInverse, RR) as ProductEncoding>::Output;
108}
109
110impl ProductEncoding for (RRR, RInverse) {
111    type Output = <(RInverse, RRR) as ProductEncoding>::Output;
112}
113
114#[allow(unused_imports)]
115use crate::{bssl, c, limb::Limb};
116
/// Montgomery multiplication modulo `n`, dispatched to the best available
/// implementation for the target architecture.
///
/// `in_out` supplies the output and the two input operands. Judging by the
/// `AliasingSlices3` name, the output may alias one or both inputs (e.g.
/// `limbs_square_mont` passes a single `&mut [Limb]`). NOTE(review): confirm
/// the exact semantics against `inout::AliasingSlices3`. `n0` is the
/// Montgomery constant that accompanies `n`.
///
/// Errors are whatever the selected backend reports as a `LimbSliceError`.
/// Presumably these cover operand lengths the backend cannot handle, since each
/// backend below is declared with a minimum limb count and a required length
/// multiple. TODO confirm against `bn_mul_mont_ffi`.
#[inline(always)]
pub(super) fn limbs_mul_mont(
    in_out: impl AliasingSlices3<Limb>,
    n: &[Limb],
    n0: &N0,
    cpu: cpu::Features,
) -> Result<(), LimbSliceError> {
    // Length multiple required by the generic (non-4x/8x) kernels.
    const MOD_FALLBACK: usize = 1; // No restriction.
    cfg_if! {
        if #[cfg(all(target_arch = "aarch64", target_endian = "little"))] {
            // `cpu` is not consulted on this target; this binding marks it as
            // intentionally unused.
            let _: cpu::Features = cpu;
            const MIN_4X: usize = 4;
            const MOD_4X: usize = 4;
            // The 4x kernel needs at least 4 limbs and a multiple of 4;
            // everything else goes to the generic kernel.
            if n.len() >= MIN_4X && n.len() % MOD_4X == 0 {
                bn_mul_mont_ffi!(in_out, n, n0, (), unsafe {
                    (MIN_4X, MOD_4X, ()) => bn_mul4x_mont
                })
            } else {
                bn_mul_mont_ffi!(in_out, n, n0, (), unsafe {
                    (MIN_LIMBS, MOD_FALLBACK, ()) => bn_mul_mont_nohw
                })
            }
        } else if #[cfg(all(target_arch = "arm", target_endian = "little"))] {
            const MIN_8X: usize = 8;
            const MOD_8X: usize = 8;
            // Use the NEON 8x kernel only for lengths it supports and only when
            // `cpu` reports NEON; otherwise fall through to the generic kernel.
            if n.len() >= MIN_8X && n.len() % MOD_8X == 0 {
                use crate::cpu::{GetFeature as _, arm::Neon};
                if let Some(cpu) = cpu.get_feature() {
                    return bn_mul_mont_ffi!(in_out, n, n0, cpu, unsafe {
                        (MIN_8X, MOD_8X, Neon) => bn_mul8x_mont_neon
                    });
                }
            }
            // The ARM version of `bn_mul_mont_nohw` has a minimum of 2.
            const _MIN_LIMBS_AT_LEAST_2: () = assert!(MIN_LIMBS >= 2);
            bn_mul_mont_ffi!(in_out, n, n0, (), unsafe {
                (MIN_LIMBS, MOD_FALLBACK, ()) => bn_mul_mont_nohw
            })
        } else if #[cfg(target_arch = "x86")] {
            use crate::{cpu::GetFeature as _, cpu::intel::Sse2};
            // The X86 implementation of `bn_mul_mont` has a minimum of 4.
            const _MIN_LIMBS_AT_LEAST_4: () = assert!(MIN_LIMBS >= 4);
            // With SSE2, call the `bn_mul_mont` implementation that requires
            // it; without SSE2, use the portable Rust fallback defined below.
            if let Some(cpu) = cpu.get_feature() {
                bn_mul_mont_ffi!(in_out, n, n0, cpu, unsafe {
                    (MIN_LIMBS, MOD_FALLBACK, Sse2) => bn_mul_mont
                })
            } else {
                // This isn't really an FFI call; it's defined below.
                unsafe {
                    super::ffi::bn_mul_mont_ffi::<(), {MIN_LIMBS}, 1>(in_out, n, n0, (),
                    bn_mul_mont_fallback)
                }
            }
        } else if #[cfg(target_arch = "x86_64")] {
            use crate::{cpu::GetFeature as _, polyfill::slice};
            use super::limbs::x86_64;
            // Moduli with at least `MIN_4X` limbs that split evenly into whole
            // chunks go to the 4x kernel; everything else uses
            // `bn_mul_mont_nohw`.
            if n.len() >= x86_64::mont::MIN_4X {
                if let (n, []) = slice::as_chunks(n) {
                    return x86_64::mont::mul_mont5_4x(in_out, n, n0, cpu.get_feature());
                }
            }
            bn_mul_mont_ffi!(in_out, n, n0, (), unsafe {
                (MIN_LIMBS, MOD_FALLBACK, ()) => bn_mul_mont_nohw
            })
        } else {
            // Route through the exported `bn_mul_mont` wrapper defined below,
            // which forwards to the Rust fallback, so that Rust and C code both
            // go through `bn_mul_mont`.
            bn_mul_mont_ffi!(in_out, n, n0, cpu, unsafe {
                (MIN_LIMBS, MOD_FALLBACK, cpu::Features) => bn_mul_mont
            })
        }
    }
}
191
cfg_if! {
    if  #[cfg(not(any(
            all(target_arch = "aarch64", target_endian = "little"),
            all(target_arch = "arm", target_endian = "little"),
            target_arch = "x86_64")))] {

        // On the targets where it is compiled, this exports a C-ABI
        // `bn_mul_mont` that simply forwards to `bn_mul_mont_fallback`. It is
        // not compiled on x86, where `limbs_mul_mont` passes
        // `bn_mul_mont_fallback` to the FFI wrapper directly.
        // TODO: Stop calling this from C and un-export it.
        #[cfg(not(target_arch = "x86"))]
        prefixed_export! {
            unsafe extern "C" fn bn_mul_mont(
                r: *mut Limb,
                a: *const Limb,
                b: *const Limb,
                n: *const Limb,
                n0: &N0,
                num_limbs: c::NonZero_size_t,
            ) {
                unsafe { bn_mul_mont_fallback(r, a, b, n, n0, num_limbs) }
            }
        }

        // Portable Montgomery multiplication. It forms the full
        // `2 * num_limbs`-limb product `a * b` in a stack buffer, then
        // Montgomery-reduces that product modulo `n` into `r` via
        // `limbs_from_mont_in_place`.
        //
        // Safety: `a`, `b`, and `n` must be valid for reads of `num_limbs`
        // limbs, and `r` must be valid for writes of `num_limbs` limbs. `r` may
        // alias `a` and/or `b`, but nothing may alias `n`.
        //
        // Panics if `num_limbs > MAX_LIMBS`, via the slicing of `tmp` below.
        // On x86 this is reached only when SSE2 is unavailable, hence the
        // `cold`/`inline(never)` hints there.
        #[cfg_attr(target_arch = "x86", cold)]
        #[cfg_attr(target_arch = "x86", inline(never))]
        unsafe extern "C" fn bn_mul_mont_fallback(
            r: *mut Limb,
            a: *const Limb,
            b: *const Limb,
            n: *const Limb,
            n0: &N0,
            num_limbs: c::NonZero_size_t,
        ) {
            use super::MAX_LIMBS;

            let num_limbs = num_limbs.get();

            // The mutable pointer `r` may alias `a` and/or `b`, so the lifetimes of
            // any slices for `a` or `b` must not overlap with the lifetime of any
            // mutable slice for `r`.

            // Nothing aliases `n`
            let n = unsafe { core::slice::from_raw_parts(n, num_limbs) };

            // Double-width scratch buffer for the unreduced product.
            let mut tmp = [0; 2 * MAX_LIMBS];
            let tmp = &mut tmp[..(2 * num_limbs)];
            // Scope `a` and `b` so their shared slices are dead before `r` is
            // turned into a mutable slice below.
            {
                let a: &[Limb] = unsafe { core::slice::from_raw_parts(a, num_limbs) };
                let b: &[Limb] = unsafe { core::slice::from_raw_parts(b, num_limbs) };
                limbs_mul(tmp, a, b);
            }
            let r: &mut [Limb] = unsafe { core::slice::from_raw_parts_mut(r, num_limbs) };
            limbs_from_mont_in_place(r, tmp, n, n0);
        }
    }
}
246
// `bigint` needs this when the `alloc` feature is enabled. `bn_mul_mont_fallback` above needs
// this on every target where it is compiled, i.e. all targets other than little-endian
// aarch64/arm and x86_64, which have their own `limbs_mul_mont` backends.
/// Montgomery-reduces `tmp` modulo `m` and writes the result to `r`.
///
/// Callers in this file pass a `tmp` holding the full `2 * m.len()`-limb product
/// of two operands. `tmp` is handed to C as a mutable buffer and may be modified.
///
/// # Panics
///
/// Panics if `bn_from_montgomery_in_place` reports failure. Presumably that happens
/// only for inconsistent slice lengths (e.g. `r.len() != m.len()` or
/// `tmp.len() != 2 * m.len()`). NOTE(review): confirm against the C implementation.
#[cfg(any(
    feature = "alloc",
    not(any(
        all(target_arch = "aarch64", target_endian = "little"),
        all(target_arch = "arm", target_endian = "little"),
        target_arch = "x86_64"
    ))
))]
pub(super) fn limbs_from_mont_in_place(r: &mut [Limb], tmp: &mut [Limb], m: &[Limb], n0: &N0) {
    prefixed_extern! {
        fn bn_from_montgomery_in_place(
            r: *mut Limb,
            num_r: c::size_t,
            a: *mut Limb,
            num_a: c::size_t,
            n: *const Limb,
            num_n: c::size_t,
            n0: &N0,
        ) -> bssl::Result;
    }
    // SAFETY: every pointer/length pair is taken from a live slice, and `r` and `tmp`
    // are distinct `&mut` borrows, so the two mutable buffers cannot overlap.
    Result::from(unsafe {
        bn_from_montgomery_in_place(
            r.as_mut_ptr(),
            r.len(),
            tmp.as_mut_ptr(),
            tmp.len(),
            m.as_ptr(),
            m.len(),
            n0,
        )
    })
    .unwrap()
}
282
/// Schoolbook multiplication: writes the full `2 * a.len()`-limb product
/// `a * b` into `r`.
///
/// `a` and `b` are expected to have equal lengths and `r` twice that length;
/// this is checked only by `debug_assert`s.
#[cfg(not(any(
    all(target_arch = "aarch64", target_endian = "little"),
    all(target_arch = "arm", target_endian = "little"),
    target_arch = "x86_64"
)))]
fn limbs_mul(r: &mut [Limb], a: &[Limb], b: &[Limb]) {
    debug_assert_eq!(r.len(), 2 * a.len());
    debug_assert_eq!(a.len(), b.len());
    let width = a.len();

    // Only the low half needs clearing: each high limb `r[width + row]` is
    // assigned (as the carry out of `row`) before any later row reads it.
    r[..width].fill(0);
    for (row, &multiplier) in b.iter().enumerate() {
        let window = &mut r[row..][..width];
        // SAFETY: `window` and `a` are both `width` limbs long, and `window` is a
        // mutable borrow of `r`, so it cannot alias the shared slice `a`.
        let carry =
            unsafe { limbs_mul_add_limb(window.as_mut_ptr(), a.as_ptr(), multiplier, width) };
        r[width + row] = carry;
    }
}
300
// Also declared under `test` because the unit test below calls it directly.
#[cfg(any(
    test,
    not(any(
        all(target_arch = "aarch64", target_endian = "little"),
        all(target_arch = "arm", target_endian = "little"),
        target_arch = "x86_64",
    ))
))]
prefixed_extern! {
    // Single-limb multiply-accumulate: adds `a * b` into the `num_limbs`-limb
    // value at `r` in place and returns the carry-out limb. For example, with
    // r = [0], a = [MAX], b = MAX, it leaves r = [1] and returns MAX - 1.
    //
    // `r` and `a` must both be valid for `num_limbs` limbs.
    // `r` must not alias `a`
    #[must_use]
    fn limbs_mul_add_limb(r: *mut Limb, a: *const Limb, b: Limb, num_limbs: c::size_t) -> Limb;
}
314
315/// r = r**2
316pub(super) fn limbs_square_mont(
317    r: &mut [Limb],
318    n: &[Limb],
319    n0: &N0,
320    cpu: cpu::Features,
321) -> Result<(), LimbSliceError> {
322    #[cfg(all(target_arch = "aarch64", target_endian = "little"))]
323    {
324        use super::limbs::aarch64;
325        use crate::polyfill::slice;
326        if let ((r, []), (n, [])) = (slice::as_chunks_mut(r), slice::as_chunks(n)) {
327            return aarch64::mont::sqr_mont5(r, n, n0);
328        }
329    }
330
331    #[cfg(target_arch = "x86_64")]
332    {
333        use super::limbs::x86_64;
334        use crate::{cpu::GetFeature as _, polyfill::slice};
335        if let ((r, []), (n, [])) = (slice::as_chunks_mut(r), slice::as_chunks(n)) {
336            return x86_64::mont::sqr_mont5(r, n, n0, cpu.get_feature());
337        }
338    }
339
340    limbs_mul_mont(r, n, n0, cpu)
341}
342
#[cfg(test)]
mod tests {
    use super::super::MAX_LIMBS;
    use super::*;
    use crate::limb::Limb;

    /// `limbs_mul_add_limb` computes `r += a * w` in place and returns the
    /// carry limb. The cases cover zero inputs, all-ones limbs, and carries
    /// propagating across two limbs.
    // TODO: wasm
    #[test]
    fn test_mul_add_words() {
        const ZERO: Limb = 0;
        const MAX: Limb = ZERO.wrapping_sub(1);
        // Columns: (initial r, a, w, expected carry, expected final r).
        static TEST_CASES: &[(&[Limb], &[Limb], Limb, Limb, &[Limb])] = &[
            (&[0], &[0], 0, 0, &[0]),
            (&[MAX], &[0], MAX, 0, &[MAX]),
            (&[0], &[MAX], MAX, MAX - 1, &[1]),
            (&[MAX], &[MAX], MAX, MAX, &[0]),
            (&[0, 0], &[MAX, MAX], MAX, MAX - 1, &[1, MAX]),
            (&[1, 0], &[MAX, MAX], MAX, MAX - 1, &[2, MAX]),
            (&[MAX, 0], &[MAX, MAX], MAX, MAX, &[0, 0]),
            (&[0, 1], &[MAX, MAX], MAX, MAX, &[1, 0]),
            (&[MAX, MAX], &[MAX, MAX], MAX, MAX, &[0, MAX]),
        ];

        for (index, (r_input, a, w, expected_retval, expected_r)) in
            TEST_CASES.iter().enumerate()
        {
            // Copy the initial `r` into a fixed-size buffer so it can be
            // mutated in place.
            let mut r_storage = [0; MAX_LIMBS];
            let r = &mut r_storage[..r_input.len()];
            r.copy_from_slice(r_input);
            assert_eq!(r.len(), a.len()); // Sanity check

            let actual_retval =
                unsafe { limbs_mul_add_limb(r.as_mut_ptr(), a.as_ptr(), *w, a.len()) };

            assert_eq!(&r, expected_r, "{}: {:x?} != {:x?}", index, r, expected_r);
            assert_eq!(
                actual_retval, *expected_retval,
                "{}: {:x?} != {:x?}",
                index, actual_retval, *expected_retval
            );
        }
    }
}