k256/arithmetic/scalar/wide64.rs

1//! Wide scalar (64-bit limbs)
2
3use super::{Scalar, MODULUS};
4use crate::ORDER;
5use elliptic_curve::{
6    bigint::{Limb, U256, U512},
7    subtle::{Choice, ConditionallySelectable},
8};
9
/// Limbs of 2^256 minus the secp256k1 order, in little-endian limb order.
/// This is the two's-complement negation of `MODULUS` (`!MODULUS[0] + 1` on the
/// lowest limb, bitwise NOT on the rest), used by the reduction routines below
/// to fold the high limbs back into the low ones.
const NEG_MODULUS: [u64; 4] = [!MODULUS[0] + 1, !MODULUS[1], !MODULUS[2], !MODULUS[3]];
12
/// A 512-bit scalar, e.g. the raw (unreduced) product of two 256-bit scalars.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct WideScalar(pub(super) U512);
15
16impl WideScalar {
17    pub const fn from_bytes(bytes: &[u8; 64]) -> Self {
18        Self(U512::from_be_slice(bytes))
19    }
20
21    /// Multiplies two scalars without modulo reduction, producing up to a 512-bit scalar.
22    #[inline(always)] // only used in Scalar::mul(), so won't cause binary bloat
23    pub fn mul_wide(a: &Scalar, b: &Scalar) -> Self {
24        let a = a.0.to_words();
25        let b = b.0.to_words();
26
27        // 160 bit accumulator.
28        let c0 = 0;
29        let c1 = 0;
30        let c2 = 0;
31
32        // l[0..7] = a[0..3] * b[0..3].
33        let (c0, c1) = muladd_fast(a[0], b[0], c0, c1);
34        let (l0, c0, c1) = (c0, c1, 0);
35        let (c0, c1, c2) = muladd(a[0], b[1], c0, c1, c2);
36        let (c0, c1, c2) = muladd(a[1], b[0], c0, c1, c2);
37        let (l1, c0, c1, c2) = (c0, c1, c2, 0);
38        let (c0, c1, c2) = muladd(a[0], b[2], c0, c1, c2);
39        let (c0, c1, c2) = muladd(a[1], b[1], c0, c1, c2);
40        let (c0, c1, c2) = muladd(a[2], b[0], c0, c1, c2);
41        let (l2, c0, c1, c2) = (c0, c1, c2, 0);
42        let (c0, c1, c2) = muladd(a[0], b[3], c0, c1, c2);
43        let (c0, c1, c2) = muladd(a[1], b[2], c0, c1, c2);
44        let (c0, c1, c2) = muladd(a[2], b[1], c0, c1, c2);
45        let (c0, c1, c2) = muladd(a[3], b[0], c0, c1, c2);
46        let (l3, c0, c1, c2) = (c0, c1, c2, 0);
47        let (c0, c1, c2) = muladd(a[1], b[3], c0, c1, c2);
48        let (c0, c1, c2) = muladd(a[2], b[2], c0, c1, c2);
49        let (c0, c1, c2) = muladd(a[3], b[1], c0, c1, c2);
50        let (l4, c0, c1, c2) = (c0, c1, c2, 0);
51        let (c0, c1, c2) = muladd(a[2], b[3], c0, c1, c2);
52        let (c0, c1, c2) = muladd(a[3], b[2], c0, c1, c2);
53        let (l5, c0, c1, _c2) = (c0, c1, c2, 0);
54        let (c0, c1) = muladd_fast(a[3], b[3], c0, c1);
55        let (l6, c0, _c1) = (c0, c1, 0);
56        let l7 = c0;
57
58        Self(U512::from_words([l0, l1, l2, l3, l4, l5, l6, l7]))
59    }
60
61    /// Multiplies `a` by `b` (without modulo reduction) divide the result by `2^shift`
62    /// (rounding to the nearest integer).
63    /// Variable time in `shift`.
64    pub(crate) fn mul_shift_vartime(a: &Scalar, b: &Scalar, shift: usize) -> Scalar {
65        debug_assert!(shift >= 256);
66
67        let l = Self::mul_wide(a, b).0.to_words();
68        let shiftlimbs = shift >> 6;
69        let shiftlow = shift & 0x3F;
70        let shifthigh = 64 - shiftlow;
71
72        let r0 = if shift < 512 {
73            let lo = l[shiftlimbs] >> shiftlow;
74            let hi = if shift < 448 && shiftlow != 0 {
75                l[1 + shiftlimbs] << shifthigh
76            } else {
77                0
78            };
79            hi | lo
80        } else {
81            0
82        };
83
84        let r1 = if shift < 448 {
85            let lo = l[1 + shiftlimbs] >> shiftlow;
86            let hi = if shift < 384 && shiftlow != 0 {
87                l[2 + shiftlimbs] << shifthigh
88            } else {
89                0
90            };
91            hi | lo
92        } else {
93            0
94        };
95
96        let r2 = if shift < 384 {
97            let lo = l[2 + shiftlimbs] >> shiftlow;
98            let hi = if shift < 320 && shiftlow != 0 {
99                l[3 + shiftlimbs] << shifthigh
100            } else {
101                0
102            };
103            hi | lo
104        } else {
105            0
106        };
107
108        let r3 = if shift < 320 {
109            l[3 + shiftlimbs] >> shiftlow
110        } else {
111            0
112        };
113
114        let res = Scalar(U256::from_words([r0, r1, r2, r3]));
115
116        // Check the highmost discarded bit and round up if it is set.
117        let c = (l[(shift - 1) >> 6] >> ((shift - 1) & 0x3f)) & 1;
118        Scalar::conditional_select(&res, &res.add(&Scalar::ONE), Choice::from(c as u8))
119    }
120
121    fn reduce_impl(&self, modulus_minus_one: bool) -> Scalar {
122        let neg_modulus0 = if modulus_minus_one {
123            NEG_MODULUS[0] + 1
124        } else {
125            NEG_MODULUS[0]
126        };
127        let modulus = if modulus_minus_one {
128            ORDER.wrapping_sub(&U256::ONE)
129        } else {
130            ORDER
131        };
132
133        let w = self.0.to_words();
134        let n0 = w[4];
135        let n1 = w[5];
136        let n2 = w[6];
137        let n3 = w[7];
138
139        // Reduce 512 bits into 385.
140        // m[0..6] = self[0..3] + n[0..3] * neg_modulus.
141        let c0 = w[0];
142        let c1 = 0;
143        let c2 = 0;
144        let (c0, c1) = muladd_fast(n0, neg_modulus0, c0, c1);
145        let (m0, c0, c1) = (c0, c1, 0);
146        let (c0, c1) = sumadd_fast(w[1], c0, c1);
147        let (c0, c1, c2) = muladd(n1, neg_modulus0, c0, c1, c2);
148        let (c0, c1, c2) = muladd(n0, NEG_MODULUS[1], c0, c1, c2);
149        let (m1, c0, c1, c2) = (c0, c1, c2, 0);
150        let (c0, c1, c2) = sumadd(w[2], c0, c1, c2);
151        let (c0, c1, c2) = muladd(n2, neg_modulus0, c0, c1, c2);
152        let (c0, c1, c2) = muladd(n1, NEG_MODULUS[1], c0, c1, c2);
153        let (c0, c1, c2) = sumadd(n0, c0, c1, c2);
154        let (m2, c0, c1, c2) = (c0, c1, c2, 0);
155        let (c0, c1, c2) = sumadd(w[3], c0, c1, c2);
156        let (c0, c1, c2) = muladd(n3, neg_modulus0, c0, c1, c2);
157        let (c0, c1, c2) = muladd(n2, NEG_MODULUS[1], c0, c1, c2);
158        let (c0, c1, c2) = sumadd(n1, c0, c1, c2);
159        let (m3, c0, c1, c2) = (c0, c1, c2, 0);
160        let (c0, c1, c2) = muladd(n3, NEG_MODULUS[1], c0, c1, c2);
161        let (c0, c1, c2) = sumadd(n2, c0, c1, c2);
162        let (m4, c0, c1, _c2) = (c0, c1, c2, 0);
163        let (c0, c1) = sumadd_fast(n3, c0, c1);
164        let (m5, c0, _c1) = (c0, c1, 0);
165        debug_assert!(c0 <= 1);
166        let m6 = c0;
167
168        // Reduce 385 bits into 258.
169        // p[0..4] = m[0..3] + m[4..6] * neg_modulus.
170        let c0 = m0;
171        let c1 = 0;
172        let c2 = 0;
173        let (c0, c1) = muladd_fast(m4, neg_modulus0, c0, c1);
174        let (p0, c0, c1) = (c0, c1, 0);
175        let (c0, c1) = sumadd_fast(m1, c0, c1);
176        let (c0, c1, c2) = muladd(m5, neg_modulus0, c0, c1, c2);
177        let (c0, c1, c2) = muladd(m4, NEG_MODULUS[1], c0, c1, c2);
178        let (p1, c0, c1) = (c0, c1, 0);
179        let (c0, c1, c2) = sumadd(m2, c0, c1, c2);
180        let (c0, c1, c2) = muladd(m6, neg_modulus0, c0, c1, c2);
181        let (c0, c1, c2) = muladd(m5, NEG_MODULUS[1], c0, c1, c2);
182        let (c0, c1, c2) = sumadd(m4, c0, c1, c2);
183        let (p2, c0, c1, _c2) = (c0, c1, c2, 0);
184        let (c0, c1) = sumadd_fast(m3, c0, c1);
185        let (c0, c1) = muladd_fast(m6, NEG_MODULUS[1], c0, c1);
186        let (c0, c1) = sumadd_fast(m5, c0, c1);
187        let (p3, c0, _c1) = (c0, c1, 0);
188        let p4 = c0 + m6;
189        debug_assert!(p4 <= 2);
190
191        // Reduce 258 bits into 256.
192        // r[0..3] = p[0..3] + p[4] * neg_modulus.
193        let mut c = (p0 as u128) + (neg_modulus0 as u128) * (p4 as u128);
194        let r0 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
195        c >>= 64;
196        c += (p1 as u128) + (NEG_MODULUS[1] as u128) * (p4 as u128);
197        let r1 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
198        c >>= 64;
199        c += (p2 as u128) + (p4 as u128);
200        let r2 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
201        c >>= 64;
202        c += p3 as u128;
203        let r3 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
204        c >>= 64;
205
206        // Final reduction of r.
207        let r = U256::from([r0, r1, r2, r3]);
208        let (r2, underflow) = r.sbb(&modulus, Limb::ZERO);
209        let high_bit = Choice::from(c as u8);
210        let underflow = Choice::from((underflow.0 >> 63) as u8);
211        Scalar(U256::conditional_select(&r, &r2, !underflow | high_bit))
212    }
213
214    #[inline(always)] // only used in Scalar::mul(), so won't cause binary bloat
215    pub(super) fn reduce(&self) -> Scalar {
216        self.reduce_impl(false)
217    }
218
219    pub(super) fn reduce_nonzero(&self) -> Scalar {
220        self.reduce_impl(true) + Scalar::ONE
221    }
222}
223
/// Add `a` into the 192-bit accumulator `(c0, c1, c2)`, rippling the carry
/// upward. The caller must guarantee the carry into `c2` cannot overflow.
fn sumadd(a: u64, c0: u64, c1: u64, c2: u64) -> (u64, u64, u64) {
    let (low, carry) = c0.overflowing_add(a);
    let (mid, carry) = c1.overflowing_add(u64::from(carry));
    (low, mid, c2 + u64::from(carry))
}
231
/// Add `a` into the 128-bit accumulator `(c0, c1)`. The caller must guarantee
/// that `c1` cannot overflow.
fn sumadd_fast(a: u64, c0: u64, c1: u64) -> (u64, u64) {
    let (low, carry) = c0.overflowing_add(a);
    (low, c1 + u64::from(carry))
}
238
/// Add `a * b` into the 192-bit accumulator `(c0, c1, c2)`. The caller must
/// guarantee the carry into `c2` cannot overflow.
fn muladd(a: u64, b: u64, c0: u64, c1: u64, c2: u64) -> (u64, u64, u64) {
    // a*b + c0 cannot overflow u128: (2^64-1)^2 + (2^64-1) = 2^128 - 2^64.
    let acc = (a as u128) * (b as u128) + (c0 as u128);
    let low = acc as u64;
    let high = (acc >> 64) as u64; // at most 0xFFFFFFFFFFFFFFFF
    let (mid, carry) = c1.overflowing_add(high);
    (low, mid, c2 + u64::from(carry))
}
252
/// Add `a * b` into the 128-bit accumulator `(c0, c1)`. The caller must
/// guarantee that `c1` cannot overflow.
fn muladd_fast(a: u64, b: u64, c0: u64, c1: u64) -> (u64, u64) {
    // a*b + c0 cannot overflow u128: (2^64-1)^2 + (2^64-1) = 2^128 - 2^64.
    let acc = (a as u128) * (b as u128) + (c0 as u128);
    (acc as u64, c1 + ((acc >> 64) as u64))
}