1use super::{Scalar, MODULUS};
4use crate::ORDER;
5use elliptic_curve::{
6 bigint::{Limb, U256, U512},
7 subtle::{Choice, ConditionallySelectable},
8};
9
/// Limb-wise two's-complement negation of `MODULUS`: every limb is bit-inverted and one
/// is added to the lowest limb. Read as a 256-bit number (presumably least-significant
/// limb first — confirm against the definition of `MODULUS`), this equals
/// `2^256 - MODULUS` as long as `MODULUS[0]` is non-zero, so that the `+ 1` cannot
/// carry into the next limb.
///
/// NOTE(review): the code visible in this file only reads limbs 0 and 1 of this constant.
const NEG_MODULUS: [u64; 4] = [!MODULUS[0] + 1, !MODULUS[1], !MODULUS[2], !MODULUS[3]];
12
/// Newtype around a 512-bit unsigned integer, used as the unreduced ("wide") form of a
/// scalar: it is what `mul_wide` and `from_bytes` produce and what `reduce` /
/// `reduce_nonzero` turn back into a `Scalar`.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct WideScalar(pub(super) U512);
15
16impl WideScalar {
17 pub const fn from_bytes(bytes: &[u8; 64]) -> Self {
18 Self(U512::from_be_slice(bytes))
19 }
20
21 #[inline(always)] pub fn mul_wide(a: &Scalar, b: &Scalar) -> Self {
24 let a = a.0.to_words();
25 let b = b.0.to_words();
26
27 let c0 = 0;
29 let c1 = 0;
30 let c2 = 0;
31
32 let (c0, c1) = muladd_fast(a[0], b[0], c0, c1);
34 let (l0, c0, c1) = (c0, c1, 0);
35 let (c0, c1, c2) = muladd(a[0], b[1], c0, c1, c2);
36 let (c0, c1, c2) = muladd(a[1], b[0], c0, c1, c2);
37 let (l1, c0, c1, c2) = (c0, c1, c2, 0);
38 let (c0, c1, c2) = muladd(a[0], b[2], c0, c1, c2);
39 let (c0, c1, c2) = muladd(a[1], b[1], c0, c1, c2);
40 let (c0, c1, c2) = muladd(a[2], b[0], c0, c1, c2);
41 let (l2, c0, c1, c2) = (c0, c1, c2, 0);
42 let (c0, c1, c2) = muladd(a[0], b[3], c0, c1, c2);
43 let (c0, c1, c2) = muladd(a[1], b[2], c0, c1, c2);
44 let (c0, c1, c2) = muladd(a[2], b[1], c0, c1, c2);
45 let (c0, c1, c2) = muladd(a[3], b[0], c0, c1, c2);
46 let (l3, c0, c1, c2) = (c0, c1, c2, 0);
47 let (c0, c1, c2) = muladd(a[1], b[3], c0, c1, c2);
48 let (c0, c1, c2) = muladd(a[2], b[2], c0, c1, c2);
49 let (c0, c1, c2) = muladd(a[3], b[1], c0, c1, c2);
50 let (l4, c0, c1, c2) = (c0, c1, c2, 0);
51 let (c0, c1, c2) = muladd(a[2], b[3], c0, c1, c2);
52 let (c0, c1, c2) = muladd(a[3], b[2], c0, c1, c2);
53 let (l5, c0, c1, _c2) = (c0, c1, c2, 0);
54 let (c0, c1) = muladd_fast(a[3], b[3], c0, c1);
55 let (l6, c0, _c1) = (c0, c1, 0);
56 let l7 = c0;
57
58 Self(U512::from_words([l0, l1, l2, l3, l4, l5, l6, l7]))
59 }
60
61 pub(crate) fn mul_shift_vartime(a: &Scalar, b: &Scalar, shift: usize) -> Scalar {
65 debug_assert!(shift >= 256);
66
67 let l = Self::mul_wide(a, b).0.to_words();
68 let shiftlimbs = shift >> 6;
69 let shiftlow = shift & 0x3F;
70 let shifthigh = 64 - shiftlow;
71
72 let r0 = if shift < 512 {
73 let lo = l[shiftlimbs] >> shiftlow;
74 let hi = if shift < 448 && shiftlow != 0 {
75 l[1 + shiftlimbs] << shifthigh
76 } else {
77 0
78 };
79 hi | lo
80 } else {
81 0
82 };
83
84 let r1 = if shift < 448 {
85 let lo = l[1 + shiftlimbs] >> shiftlow;
86 let hi = if shift < 384 && shiftlow != 0 {
87 l[2 + shiftlimbs] << shifthigh
88 } else {
89 0
90 };
91 hi | lo
92 } else {
93 0
94 };
95
96 let r2 = if shift < 384 {
97 let lo = l[2 + shiftlimbs] >> shiftlow;
98 let hi = if shift < 320 && shiftlow != 0 {
99 l[3 + shiftlimbs] << shifthigh
100 } else {
101 0
102 };
103 hi | lo
104 } else {
105 0
106 };
107
108 let r3 = if shift < 320 {
109 l[3 + shiftlimbs] >> shiftlow
110 } else {
111 0
112 };
113
114 let res = Scalar(U256::from_words([r0, r1, r2, r3]));
115
116 let c = (l[(shift - 1) >> 6] >> ((shift - 1) & 0x3f)) & 1;
118 Scalar::conditional_select(&res, &res.add(&Scalar::ONE), Choice::from(c as u8))
119 }
120
    /// Reduces the 512-bit value to a `Scalar`, modulo `ORDER` when
    /// `modulus_minus_one` is `false` and modulo `ORDER - 1` when it is `true`.
    ///
    /// The reduction repeatedly folds the limbs above bit 256 back into the low
    /// limbs by multiplying them with the negated modulus (assuming `NEG_MODULUS`
    /// is `2^256 - ORDER`, `2^256` is congruent to it modulo `ORDER`), in three
    /// passes, followed by one conditional subtraction of the modulus.
    fn reduce_impl(&self, modulus_minus_one: bool) -> Scalar {
        // Negating (modulus - 1) instead of modulus adds one to the lowest limb.
        let neg_modulus0 = if modulus_minus_one {
            NEG_MODULUS[0] + 1
        } else {
            NEG_MODULUS[0]
        };
        let modulus = if modulus_minus_one {
            ORDER.wrapping_sub(&U256::ONE)
        } else {
            ORDER
        };

        // n0..n3 are the upper four limbs of the input; w[0..4] are the lower four.
        let w = self.0.to_words();
        let n0 = w[4];
        let n1 = w[5];
        let n2 = w[6];
        let n3 = w[7];

        // Pass 1: m[0..=6] = w[0..=3] + n[0..=3] * negated modulus.
        // (c0, c1, c2) is a three-limb accumulator; after each column its lowest limb
        // is stored as the next m limb and the accumulator is shifted down one limb.
        // Limbs of `n` are added directly two columns up (`sumadd(n0, ..)` etc.), i.e.
        // the code treats the third limb of the negated modulus as 1 and the fourth
        // as 0 — presumably matching NEG_MODULUS[2] and NEG_MODULUS[3]; confirm.
        let c0 = w[0];
        let c1 = 0;
        let c2 = 0;
        let (c0, c1) = muladd_fast(n0, neg_modulus0, c0, c1);
        let (m0, c0, c1) = (c0, c1, 0);
        let (c0, c1) = sumadd_fast(w[1], c0, c1);
        let (c0, c1, c2) = muladd(n1, neg_modulus0, c0, c1, c2);
        let (c0, c1, c2) = muladd(n0, NEG_MODULUS[1], c0, c1, c2);
        let (m1, c0, c1, c2) = (c0, c1, c2, 0);
        let (c0, c1, c2) = sumadd(w[2], c0, c1, c2);
        let (c0, c1, c2) = muladd(n2, neg_modulus0, c0, c1, c2);
        let (c0, c1, c2) = muladd(n1, NEG_MODULUS[1], c0, c1, c2);
        let (c0, c1, c2) = sumadd(n0, c0, c1, c2);
        let (m2, c0, c1, c2) = (c0, c1, c2, 0);
        let (c0, c1, c2) = sumadd(w[3], c0, c1, c2);
        let (c0, c1, c2) = muladd(n3, neg_modulus0, c0, c1, c2);
        let (c0, c1, c2) = muladd(n2, NEG_MODULUS[1], c0, c1, c2);
        let (c0, c1, c2) = sumadd(n1, c0, c1, c2);
        let (m3, c0, c1, c2) = (c0, c1, c2, 0);
        let (c0, c1, c2) = muladd(n3, NEG_MODULUS[1], c0, c1, c2);
        let (c0, c1, c2) = sumadd(n2, c0, c1, c2);
        let (m4, c0, c1, _c2) = (c0, c1, c2, 0);
        let (c0, c1) = sumadd_fast(n3, c0, c1);
        let (m5, c0, _c1) = (c0, c1, 0);
        // The result of pass 1 is expected to fit in six limbs plus one bit.
        debug_assert!(c0 <= 1);
        let m6 = c0;

        // Pass 2: p[0..=4] = m[0..=3] + m[4..=6] * negated modulus, same scheme as pass 1.
        let c0 = m0;
        let c1 = 0;
        let c2 = 0;
        let (c0, c1) = muladd_fast(m4, neg_modulus0, c0, c1);
        let (p0, c0, c1) = (c0, c1, 0);
        let (c0, c1) = sumadd_fast(m1, c0, c1);
        let (c0, c1, c2) = muladd(m5, neg_modulus0, c0, c1, c2);
        let (c0, c1, c2) = muladd(m4, NEG_MODULUS[1], c0, c1, c2);
        // NOTE(review): unlike the other extractions that follow a three-limb `muladd`,
        // this one does not shift `c2` down into `c1`; that is only equivalent if `c2`
        // is still zero here, which presumably holds because both negated-modulus limbs
        // are small enough for the column sum to fit in 128 bits — confirm.
        let (p1, c0, c1) = (c0, c1, 0);
        let (c0, c1, c2) = sumadd(m2, c0, c1, c2);
        let (c0, c1, c2) = muladd(m6, neg_modulus0, c0, c1, c2);
        let (c0, c1, c2) = muladd(m5, NEG_MODULUS[1], c0, c1, c2);
        let (c0, c1, c2) = sumadd(m4, c0, c1, c2);
        let (p2, c0, c1, _c2) = (c0, c1, c2, 0);
        let (c0, c1) = sumadd_fast(m3, c0, c1);
        let (c0, c1) = muladd_fast(m6, NEG_MODULUS[1], c0, c1);
        let (c0, c1) = sumadd_fast(m5, c0, c1);
        let (p3, c0, _c1) = (c0, c1, 0);
        let p4 = c0 + m6;
        // The result of pass 2 is expected to fit in four limbs plus a value of at most 2.
        debug_assert!(p4 <= 2);

        // Pass 3: r[0..=3] = p[0..=3] + p4 * negated modulus, using `c` as a running
        // 128-bit carry; whatever remains in `c` afterwards is the carry out of r3.
        let mut c = (p0 as u128) + (neg_modulus0 as u128) * (p4 as u128);
        let r0 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
        c >>= 64;
        c += (p1 as u128) + (NEG_MODULUS[1] as u128) * (p4 as u128);
        let r1 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
        c >>= 64;
        c += (p2 as u128) + (p4 as u128);
        let r2 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
        c >>= 64;
        c += p3 as u128;
        let r3 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
        c >>= 64;

        // Final conditional subtraction: compute r - modulus (note that `r2` is
        // re-bound here from a limb to the 256-bit difference) and select the
        // difference when the subtraction did not borrow or when pass 3 carried
        // out of the top limb; otherwise keep r. The top bit of the returned borrow
        // limb is used as the underflow flag.
        let r = U256::from([r0, r1, r2, r3]);
        let (r2, underflow) = r.sbb(&modulus, Limb::ZERO);
        let high_bit = Choice::from(c as u8);
        let underflow = Choice::from((underflow.0 >> 63) as u8);
        Scalar(U256::conditional_select(&r, &r2, !underflow | high_bit))
    }
213
214 #[inline(always)] pub(super) fn reduce(&self) -> Scalar {
216 self.reduce_impl(false)
217 }
218
219 pub(super) fn reduce_nonzero(&self) -> Scalar {
220 self.reduce_impl(true) + Scalar::ONE
221 }
222}
223
/// Adds `a` to the three-limb accumulator `(c0, c1, c2)` (least-significant limb
/// first) and returns the updated limbs.
///
/// The top limb `c2` must not overflow; a plain `+` is used for it.
fn sumadd(a: u64, c0: u64, c1: u64, c2: u64) -> (u64, u64, u64) {
    // View the two low limbs as one 128-bit value so a single addition
    // propagates the carry through both of them.
    let low_pair = (u128::from(c1) << 64) | u128::from(c0);
    let (sum, carried) = low_pair.overflowing_add(u128::from(a));
    (sum as u64, (sum >> 64) as u64, c2 + u64::from(carried))
}
231
/// Adds `a` to the two-limb accumulator `(c0, c1)` (least-significant limb first)
/// and returns the updated limbs.
///
/// The top limb `c1` must not overflow; a plain `+` is used for it.
fn sumadd_fast(a: u64, c0: u64, c1: u64) -> (u64, u64) {
    let sum = u128::from(c0) + u128::from(a);
    (sum as u64, c1 + (sum >> 64) as u64)
}
238
/// Adds the product `a * b` to the three-limb accumulator `(c0, c1, c2)`
/// (least-significant limb first) and returns the updated limbs.
///
/// The top limb `c2` must not overflow; a plain `+` is used for it.
fn muladd(a: u64, b: u64, c0: u64, c1: u64, c2: u64) -> (u64, u64, u64) {
    // a * b + c0 <= (2^64 - 1)^2 + (2^64 - 1) < 2^128, so this sum cannot overflow.
    let low = u128::from(a) * u128::from(b) + u128::from(c0);
    // Upper half of `low` (product high word plus carry) goes into the middle limb.
    let mid = u128::from(c1) + (low >> 64);
    (low as u64, mid as u64, c2 + (mid >> 64) as u64)
}
252
/// Adds the product `a * b` to the two-limb accumulator `(c0, c1)`
/// (least-significant limb first) and returns the updated limbs.
///
/// The top limb `c1` must not overflow; a plain `+` is used for it.
fn muladd_fast(a: u64, b: u64, c0: u64, c1: u64) -> (u64, u64) {
    // a * b + c0 <= (2^64 - 1)^2 + (2^64 - 1) < 2^128, so this sum cannot overflow.
    let low = u128::from(a) * u128::from(b) + u128::from(c0);
    (low as u64, c1 + (low >> 64) as u64)
}