k256/arithmetic/
mul.rs

1//! From libsecp256k1:
2//!
3//! The Secp256k1 curve has an endomorphism, where lambda * (x, y) = (beta * x, y), where
4//! lambda is {0x53,0x63,0xad,0x4c,0xc0,0x5c,0x30,0xe0,0xa5,0x26,0x1c,0x02,0x88,0x12,0x64,0x5a,
5//!         0x12,0x2e,0x22,0xea,0x20,0x81,0x66,0x78,0xdf,0x02,0x96,0x7c,0x1b,0x23,0xbd,0x72}
6//!
7//! "Guide to Elliptic Curve Cryptography" (Hankerson, Menezes, Vanstone) gives an algorithm
8//! (algorithm 3.74) to find k1 and k2 given k, such that k1 + k2 * lambda == k mod n, and k1
9//! and k2 have a small size.
10//! It relies on constants a1, b1, a2, b2. These constants for the value of lambda above are:
11//!
12//! - a1 =      {0x30,0x86,0xd2,0x21,0xa7,0xd4,0x6b,0xcd,0xe8,0x6c,0x90,0xe4,0x92,0x84,0xeb,0x15}
13//! - b1 =     -{0xe4,0x43,0x7e,0xd6,0x01,0x0e,0x88,0x28,0x6f,0x54,0x7f,0xa9,0x0a,0xbf,0xe4,0xc3}
14//! - a2 = {0x01,0x14,0xca,0x50,0xf7,0xa8,0xe2,0xf3,0xf6,0x57,0xc1,0x10,0x8d,0x9d,0x44,0xcf,0xd8}
15//! - b2 =      {0x30,0x86,0xd2,0x21,0xa7,0xd4,0x6b,0xcd,0xe8,0x6c,0x90,0xe4,0x92,0x84,0xeb,0x15}
16//!
//! The algorithm then computes c1 = round(b2 * k / n) and c2 = round((-b1) * k / n), and gives
18//! k1 = k - (c1*a1 + c2*a2) and k2 = -(c1*b1 + c2*b2). Instead, we use modular arithmetic, and
19//! compute k1 as k - k2 * lambda, avoiding the need for constants a1 and a2.
20//!
21//! g1, g2 are precomputed constants used to replace division with a rounded multiplication
22//! when decomposing the scalar for an endomorphism-based point multiplication.
23//!
24//! The possibility of using precomputed estimates is mentioned in "Guide to Elliptic Curve
25//! Cryptography" (Hankerson, Menezes, Vanstone) in section 3.5.
26//!
27//! The derivation is described in the paper "Efficient Software Implementation of Public-Key
28//! Cryptography on Sensor Networks Using the MSP430X Microcontroller" (Gouvea, Oliveira, Lopez),
29//! Section 4.3 (here we use a somewhat higher-precision estimate):
30//! d = a1*b2 - b1*a2
31//! g1 = round((2^384)*b2/d)
32//! g2 = round((2^384)*(-b1)/d)
33//!
34//! (Note that 'd' is also equal to the curve order here because `[a1,b1]` and `[a2,b2]` are found
35//! as outputs of the Extended Euclidean Algorithm on inputs 'order' and 'lambda').
36
// The generator table used with `precomputed-tables` lives in a `once_cell::sync::Lazy`,
// which needs either `std` or `critical-section` to provide synchronization.
#[cfg(all(
    feature = "precomputed-tables",
    not(feature = "std"),
    not(feature = "critical-section")
))]
compile_error!("`precomputed-tables` feature requires either `critical-section` or `std`");
42
43use crate::arithmetic::{
44    scalar::{Scalar, WideScalar},
45    ProjectivePoint,
46};
47
48use core::ops::{Mul, MulAssign};
49use elliptic_curve::ops::LinearCombinationExt as LinearCombination;
50use elliptic_curve::{
51    ops::MulByGenerator,
52    scalar::IsHigh,
53    subtle::{Choice, ConditionallySelectable, ConstantTimeEq},
54};
55
56#[cfg(feature = "precomputed-tables")]
57use once_cell::sync::Lazy;
58
59/// Lookup table containing precomputed values `[p, 2p, 3p, ..., 8p]`
60#[derive(Copy, Clone, Default)]
61struct LookupTable([ProjectivePoint; 8]);
62
63impl From<&ProjectivePoint> for LookupTable {
64    fn from(p: &ProjectivePoint) -> Self {
65        let mut points = [*p; 8];
66        for j in 0..7 {
67            points[j + 1] = p + &points[j];
68        }
69        LookupTable(points)
70    }
71}
72
impl LookupTable {
    /// Given -8 <= x <= 8, returns x * p in constant time.
    ///
    /// Here `p` is the point the table was built from (entry `j` holds `(j + 1) * p`).
    /// Every entry is read on every call, and selection goes through `subtle`'s
    /// `ct_eq`/`conditional_assign`, so which entries are touched does not depend on `x`.
    fn select(&self, x: i8) -> ProjectivePoint {
        debug_assert!(x >= -8);
        debug_assert!(x <= 8);

        // Compute xabs = |x|
        // `x >> 7` is an arithmetic shift of an i8: xmask is 0 when x >= 0 and -1 (all bits
        // set) when x < 0. `(x + xmask) ^ xmask` is then x or -x respectively, with no branch.
        let xmask = x >> 7;
        let xabs = (x + xmask) ^ xmask;

        // Get an array element in constant time
        // All 8 entries are scanned; only the one whose multiplier equals xabs is kept.
        // When xabs == 0 nothing matches and t stays the identity.
        let mut t = ProjectivePoint::IDENTITY;
        for j in 1..9 {
            let c = (xabs as u8).ct_eq(&(j as u8));
            t.conditional_assign(&self.0[j - 1], c);
        }
        // Now t == |x| * p.

        // `xmask & 1` is 1 exactly when x was negative; negate the result in that case.
        let neg_mask = Choice::from((xmask & 1) as u8);
        t.conditional_assign(&-t, neg_mask);
        // Now t == x * p.

        t
    }
}
98
/// `-lambda mod n`, as big-endian bytes, where `lambda` is the endomorphism scalar given in the
/// module docs. Storing the negation lets `decompose_scalar` compute `k - r2 * lambda` as
/// `k + r2 * MINUS_LAMBDA`.
const MINUS_LAMBDA: Scalar = Scalar::from_bytes_unchecked(&[
    0xac, 0x9c, 0x52, 0xb3, 0x3f, 0xa3, 0xcf, 0x1f, 0x5a, 0xd9, 0xe3, 0xfd, 0x77, 0xed, 0x9b, 0xa4,
    0xa8, 0x80, 0xb9, 0xfc, 0x8e, 0xc7, 0x39, 0xc2, 0xe0, 0xcf, 0xc8, 0x10, 0xb5, 0x12, 0x83, 0xcf,
]);
103
/// `-b1`, as big-endian bytes. The module docs give `b1 = -{0xe4,0x43,...,0xe4,0xc3}`, so this
/// is that 128-bit magnitude, zero-padded to 32 bytes.
const MINUS_B1: Scalar = Scalar::from_bytes_unchecked(&[
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0xe4, 0x43, 0x7e, 0xd6, 0x01, 0x0e, 0x88, 0x28, 0x6f, 0x54, 0x7f, 0xa9, 0x0a, 0xbf, 0xe4, 0xc3,
]);
108
/// `-b2 mod n`, as big-endian bytes, with `b2 = {0x30,0x86,...,0xeb,0x15}` from the module docs.
const MINUS_B2: Scalar = Scalar::from_bytes_unchecked(&[
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
    0x8a, 0x28, 0x0a, 0xc5, 0x07, 0x74, 0x34, 0x6d, 0xd7, 0x65, 0xcd, 0xa8, 0x3d, 0xb1, 0x56, 0x2c,
]);
113
/// `g1 = round(2^384 * b2 / d)` from the module docs, as big-endian bytes. `decompose_scalar`
/// uses it via `WideScalar::mul_shift_vartime(k, &G1, 384)` to estimate `k * b2 / n` without
/// a division.
const G1: Scalar = Scalar::from_bytes_unchecked(&[
    0x30, 0x86, 0xd2, 0x21, 0xa7, 0xd4, 0x6b, 0xcd, 0xe8, 0x6c, 0x90, 0xe4, 0x92, 0x84, 0xeb, 0x15,
    0x3d, 0xaa, 0x8a, 0x14, 0x71, 0xe8, 0xca, 0x7f, 0xe8, 0x93, 0x20, 0x9a, 0x45, 0xdb, 0xb0, 0x31,
]);
118
/// `g2 = round(2^384 * (-b1) / d)` from the module docs, as big-endian bytes.
/// `decompose_scalar` uses it via `WideScalar::mul_shift_vartime(k, &G2, 384)` to estimate
/// `k * (-b1) / n` without a division.
const G2: Scalar = Scalar::from_bytes_unchecked(&[
    0xe4, 0x43, 0x7e, 0xd6, 0x01, 0x0e, 0x88, 0x28, 0x6f, 0x54, 0x7f, 0xa9, 0x0a, 0xbf, 0xe4, 0xc4,
    0x22, 0x12, 0x08, 0xac, 0x9d, 0xf5, 0x06, 0xc6, 0x15, 0x71, 0xb4, 0xae, 0x8a, 0xc4, 0x7f, 0x71,
]);
123
124/*
125 * Proof for decompose_scalar's bounds.
126 *
127 * Let
128 *  - epsilon1 = 2^256 * |g1/2^384 - b2/d|
129 *  - epsilon2 = 2^256 * |g2/2^384 - (-b1)/d|
130 *  - c1 = round(k*g1/2^384)
131 *  - c2 = round(k*g2/2^384)
132 *
133 * Lemma 1: |c1 - k*b2/d| < 2^-1 + epsilon1
134 *
135 *    |c1 - k*b2/d|
136 *  =
137 *    |c1 - k*g1/2^384 + k*g1/2^384 - k*b2/d|
138 * <=   {triangle inequality}
139 *    |c1 - k*g1/2^384| + |k*g1/2^384 - k*b2/d|
140 *  =
141 *    |c1 - k*g1/2^384| + k*|g1/2^384 - b2/d|
142 * <    {rounding in c1 and 0 <= k < 2^256}
143 *    2^-1 + 2^256 * |g1/2^384 - b2/d|
144 *  =   {definition of epsilon1}
145 *    2^-1 + epsilon1
146 *
147 * Lemma 2: |c2 - k*(-b1)/d| < 2^-1 + epsilon2
148 *
149 *    |c2 - k*(-b1)/d|
150 *  =
151 *    |c2 - k*g2/2^384 + k*g2/2^384 - k*(-b1)/d|
152 * <=   {triangle inequality}
153 *    |c2 - k*g2/2^384| + |k*g2/2^384 - k*(-b1)/d|
154 *  =
155 *    |c2 - k*g2/2^384| + k*|g2/2^384 - (-b1)/d|
156 * <    {rounding in c2 and 0 <= k < 2^256}
157 *    2^-1 + 2^256 * |g2/2^384 - (-b1)/d|
158 *  =   {definition of epsilon2}
159 *    2^-1 + epsilon2
160 *
161 * Let
162 *  - k1 = k - c1*a1 - c2*a2
163 *  - k2 = - c1*b1 - c2*b2
164 *
165 * Lemma 3: |k1| < (a1 + a2 + 1)/2 < 2^128
166 *
167 *    |k1|
168 *  =   {definition of k1}
169 *    |k - c1*a1 - c2*a2|
170 *  =   {(a1*b2 - b1*a2)/n = 1}
171 *    |k*(a1*b2 - b1*a2)/n - c1*a1 - c2*a2|
172 *  =
173 *    |a1*(k*b2/n - c1) + a2*(k*(-b1)/n - c2)|
174 * <=   {triangle inequality}
175 *    a1*|k*b2/n - c1| + a2*|k*(-b1)/n - c2|
176 * <    {Lemma 1 and Lemma 2}
 *    a1*(2^-1 + epsilon1) + a2*(2^-1 + epsilon2)
178 * <    {rounding up to an integer}
179 *    (a1 + a2 + 1)/2
180 * <    {rounding up to a power of 2}
181 *    2^128
182 *
183 * Lemma 4: |k2| < (-b1 + b2)/2 + 1 < 2^128
184 *
185 *    |k2|
186 *  =   {definition of k2}
 *    |- c1*b1 - c2*b2|
188 *  =   {(b1*b2 - b1*b2)/n = 0}
189 *    |k*(b1*b2 - b1*b2)/n - c1*b1 - c2*b2|
190 *  =
191 *    |b1*(k*b2/n - c1) + b2*(k*(-b1)/n - c2)|
192 * <=   {triangle inequality}
193 *    (-b1)*|k*b2/n - c1| + b2*|k*(-b1)/n - c2|
194 * <    {Lemma 1 and Lemma 2}
 *    (-b1)*(2^-1 + epsilon1) + b2*(2^-1 + epsilon2)
196 * <    {rounding up to an integer}
197 *    (-b1 + b2)/2 + 1
198 * <    {rounding up to a power of 2}
199 *    2^128
200 *
201 * Let
202 *  - r2 = k2 mod n
203 *  - r1 = k - r2*lambda mod n.
204 *
205 * Notice that r1 is defined such that r1 + r2 * lambda == k (mod n).
206 *
207 * Lemma 5: r1 == k1 mod n.
208 *
209 *    r1
210 * ==   {definition of r1 and r2}
211 *    k - k2*lambda
212 * ==   {definition of k2}
213 *    k - (- c1*b1 - c2*b2)*lambda
214 * ==
215 *    k + c1*b1*lambda + c2*b2*lambda
216 * ==  {a1 + b1*lambda == 0 mod n and a2 + b2*lambda == 0 mod n}
217 *    k - c1*a1 - c2*a2
218 * ==  {definition of k1}
219 *    k1
220 *
221 * From Lemma 3, Lemma 4, Lemma 5 and the definition of r2, we can conclude that
222 *
223 *  - either r1 < 2^128 or -r1 mod n < 2^128
224 *  - either r2 < 2^128 or -r2 mod n < 2^128.
225 *
226 * Q.E.D.
227 */
228
229/// Find r1 and r2 given k, such that r1 + r2 * lambda == k mod n.
230fn decompose_scalar(k: &Scalar) -> (Scalar, Scalar) {
231    // these _vartime calls are constant time since the shift amount is constant
232    let c1 = WideScalar::mul_shift_vartime(k, &G1, 384) * MINUS_B1;
233    let c2 = WideScalar::mul_shift_vartime(k, &G2, 384) * MINUS_B2;
234    let r2 = c1 + c2;
235    let r1 = k + r2 * MINUS_LAMBDA;
236
237    (r1, r2)
238}
239
/// Signed radix-16 digits of a scalar, least significant digit first.
///
/// This is a newtype rather than a bare `[i8; D]` so that it can carry its own `Default` impl,
/// which builds the placeholder entries of the scratch buffers passed to `lincomb`. `Default`
/// is written by hand (not derived) because the standard library only implements `Default`
/// for arrays of up to 32 elements, and `D` is 33 or 65 here. `Clone`/`Copy` are derived so
/// those placeholders can be repeated with `[value; N]` and `vec![value; n]`.
#[derive(Clone, Copy)]
struct Radix16Decomposition<const D: usize>([i8; D]);
245
246impl<const D: usize> Radix16Decomposition<D> {
247    /// Returns an object containing a decomposition
248    /// `[a_0, ..., a_D]` such that `sum(a_j * 2^(j * 4)) == x`,
249    /// and `-8 <= a_j <= 7`.
250    /// Assumes `x < 2^(4*(D-1))`.
251    fn new(x: &Scalar) -> Self {
252        debug_assert!((x >> (4 * (D - 1))).is_zero().unwrap_u8() == 1);
253
254        // The resulting decomposition can be negative, so, despite the limit on `x`,
255        // we need an additional byte to store the carry.
256        let mut output = [0i8; D];
257
258        // Step 1: change radix.
259        // Convert from radix 256 (bytes) to radix 16 (nibbles)
260        let bytes = x.to_bytes();
261        for i in 0..(D - 1) / 2 {
262            output[2 * i] = (bytes[31 - i] & 0xf) as i8;
263            output[2 * i + 1] = ((bytes[31 - i] >> 4) & 0xf) as i8;
264        }
265
266        // Step 2: recenter coefficients from [0,16) to [-8,8)
267        for i in 0..(D - 1) {
268            let carry = (output[i] + 8) >> 4;
269            output[i] -= carry << 4;
270            output[i + 1] += carry;
271        }
272
273        Self(output)
274    }
275}
276
277impl<const D: usize> Default for Radix16Decomposition<D> {
278    fn default() -> Self {
279        Self([0i8; D])
280    }
281}
282
283impl<const N: usize> LinearCombination<[(ProjectivePoint, Scalar); N]> for ProjectivePoint {
284    fn lincomb_ext(points_and_scalars: &[(ProjectivePoint, Scalar); N]) -> Self {
285        let mut tables = [(LookupTable::default(), LookupTable::default()); N];
286        let mut digits = [(
287            Radix16Decomposition::<33>::default(),
288            Radix16Decomposition::<33>::default(),
289        ); N];
290
291        lincomb(points_and_scalars, &mut tables, &mut digits)
292    }
293}
294
#[cfg(feature = "alloc")]
impl LinearCombination<[(ProjectivePoint, Scalar)]> for ProjectivePoint {
    /// Slice-based linear combination; the scratch space is heap-allocated with one entry per
    /// input.
    fn lincomb_ext(points_and_scalars: &[(ProjectivePoint, Scalar)]) -> Self {
        let count = points_and_scalars.len();
        let empty_tables = (LookupTable::default(), LookupTable::default());
        let empty_digits = (
            Radix16Decomposition::<33>::default(),
            Radix16Decomposition::<33>::default(),
        );

        let mut tables = vec![empty_tables; count];
        let mut digits = vec![empty_digits; count];
        lincomb(points_and_scalars, &mut tables, &mut digits)
    }
}
311
312fn lincomb(
313    xks: &[(ProjectivePoint, Scalar)],
314    tables: &mut [(LookupTable, LookupTable)],
315    digits: &mut [(Radix16Decomposition<33>, Radix16Decomposition<33>)],
316) -> ProjectivePoint {
317    xks.iter().enumerate().for_each(|(i, (x, k))| {
318        let (r1, r2) = decompose_scalar(k);
319        let x_beta = x.endomorphism();
320        let (r1_sign, r2_sign) = (r1.is_high(), r2.is_high());
321
322        let (r1_c, r2_c) = (
323            Scalar::conditional_select(&r1, &-r1, r1_sign),
324            Scalar::conditional_select(&r2, &-r2, r2_sign),
325        );
326
327        tables[i] = (
328            LookupTable::from(&ProjectivePoint::conditional_select(x, &-*x, r1_sign)),
329            LookupTable::from(&ProjectivePoint::conditional_select(
330                &x_beta, &-x_beta, r2_sign,
331            )),
332        );
333
334        digits[i] = (
335            Radix16Decomposition::<33>::new(&r1_c),
336            Radix16Decomposition::<33>::new(&r2_c),
337        )
338    });
339
340    let mut acc = ProjectivePoint::IDENTITY;
341    for component in 0..xks.len() {
342        let (digit1, digit2) = digits[component];
343        let (table1, table2) = tables[component];
344
345        acc += &table1.select(digit1.0[32]);
346        acc += &table2.select(digit2.0[32]);
347    }
348
349    for i in (0..32).rev() {
350        for _j in 0..4 {
351            acc = acc.double();
352        }
353
354        for component in 0..xks.len() {
355            let (digit1, digit2) = digits[component];
356            let (table1, table2) = tables[component];
357
358            acc += &table1.select(digit1.0[i]);
359            acc += &table2.select(digit2.0[i]);
360        }
361    }
362    acc
363}
364
/// Lazily computed basepoint table.
///
/// Built by `precompute_gen_lookup_table` on first access through `once_cell::sync::Lazy`,
/// then shared by every later `mul_by_generator` call.
#[cfg(feature = "precomputed-tables")]
static GEN_LOOKUP_TABLE: Lazy<[LookupTable; 33]> = Lazy::new(precompute_gen_lookup_table);
368
/// Builds the generator table: entry `i` is the `LookupTable` of `2^(8i) * G`.
///
/// Successive entries are eight doublings (two radix-16 windows) apart rather than four, which
/// halves the size of the precomputed data.
#[cfg(feature = "precomputed-tables")]
fn precompute_gen_lookup_table() -> [LookupTable; 33] {
    let mut tables = [LookupTable::default(); 33];
    let mut base = ProjectivePoint::GENERATOR;

    for table in tables.iter_mut() {
        *table = LookupTable::from(&base);
        base = (0..8).fold(base, |point, _| point.double());
    }

    tables
}
384
385impl MulByGenerator for ProjectivePoint {
386    /// Calculates `k * G`, where `G` is the generator.
387    #[cfg(not(feature = "precomputed-tables"))]
388    fn mul_by_generator(k: &Scalar) -> ProjectivePoint {
389        ProjectivePoint::GENERATOR * k
390    }
391
392    /// Calculates `k * G`, where `G` is the generator.
393    #[cfg(feature = "precomputed-tables")]
394    fn mul_by_generator(k: &Scalar) -> ProjectivePoint {
395        let digits = Radix16Decomposition::<65>::new(k);
396        let table = *GEN_LOOKUP_TABLE;
397        let mut acc = table[32].select(digits.0[64]);
398        let mut acc2 = ProjectivePoint::IDENTITY;
399        for i in (0..32).rev() {
400            acc2 += &table[i].select(digits.0[i * 2 + 1]);
401            acc += &table[i].select(digits.0[i * 2]);
402        }
403        // This is the price of halving the precomputed table size (from 60kb to 30kb)
404        // The performance hit is minor, about 3%.
405        for _ in 0..4 {
406            acc2 = acc2.double();
407        }
408        acc + acc2
409    }
410}
411
412#[inline(always)]
413fn mul(x: &ProjectivePoint, k: &Scalar) -> ProjectivePoint {
414    ProjectivePoint::lincomb_ext(&[(*x, *k)])
415}
416
417impl Mul<Scalar> for ProjectivePoint {
418    type Output = ProjectivePoint;
419
420    fn mul(self, other: Scalar) -> ProjectivePoint {
421        mul(&self, &other)
422    }
423}
424
425impl Mul<&Scalar> for &ProjectivePoint {
426    type Output = ProjectivePoint;
427
428    fn mul(self, other: &Scalar) -> ProjectivePoint {
429        mul(self, other)
430    }
431}
432
433impl Mul<&Scalar> for ProjectivePoint {
434    type Output = ProjectivePoint;
435
436    fn mul(self, other: &Scalar) -> ProjectivePoint {
437        mul(&self, other)
438    }
439}
440
441impl MulAssign<Scalar> for ProjectivePoint {
442    fn mul_assign(&mut self, rhs: Scalar) {
443        *self = mul(self, &rhs);
444    }
445}
446
447impl MulAssign<&Scalar> for ProjectivePoint {
448    fn mul_assign(&mut self, rhs: &Scalar) {
449        *self = mul(self, rhs);
450    }
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arithmetic::{ProjectivePoint, Scalar};
    use elliptic_curve::{
        ops::{LinearCombination as _, MulByGenerator},
        rand_core::OsRng,
        Field, Group,
    };

    #[test]
    fn test_lincomb() {
        let x = ProjectivePoint::random(&mut OsRng);
        let y = ProjectivePoint::random(&mut OsRng);
        let k = Scalar::random(&mut OsRng);
        let l = Scalar::random(&mut OsRng);

        let reference = &x * &k + &y * &l;
        let test = ProjectivePoint::lincomb(&x, &k, &y, &l);
        assert_eq!(reference, test);
    }

    #[test]
    fn test_mul_by_generator() {
        let k = Scalar::random(&mut OsRng);
        let reference = &ProjectivePoint::GENERATOR * &k;
        let test = ProjectivePoint::mul_by_generator(&k);
        assert_eq!(reference, test);
    }

    /// Deterministic scalars at the ends of the range (0, 1, 2 and -1 = n - 1), which random
    /// scalars essentially never hit, exercise the decomposition and sign handling.
    #[test]
    fn test_mul_edge_scalars() {
        let x = ProjectivePoint::random(&mut OsRng);
        let two = Scalar::ONE + Scalar::ONE;
        let minus_one = -Scalar::ONE;

        assert_eq!(x * Scalar::ZERO, ProjectivePoint::IDENTITY);
        assert_eq!(x * Scalar::ONE, x);
        assert_eq!(x * two, x.double());
        assert_eq!(x * minus_one, -x);
        assert_eq!(ProjectivePoint::IDENTITY * minus_one, ProjectivePoint::IDENTITY);
    }

    #[test]
    fn test_mul_by_generator_edge_scalars() {
        let g = ProjectivePoint::GENERATOR;

        assert_eq!(ProjectivePoint::mul_by_generator(&Scalar::ZERO), ProjectivePoint::IDENTITY);
        assert_eq!(ProjectivePoint::mul_by_generator(&Scalar::ONE), g);
        assert_eq!(ProjectivePoint::mul_by_generator(&-Scalar::ONE), -g);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_lincomb_slice() {
        let x = ProjectivePoint::random(&mut OsRng);
        let y = ProjectivePoint::random(&mut OsRng);
        let k = Scalar::random(&mut OsRng);
        let l = Scalar::random(&mut OsRng);

        let reference = &x * &k + &y * &l;
        let points_and_scalars = vec![(x, k), (y, l)];

        let test = ProjectivePoint::lincomb_ext(points_and_scalars.as_slice());
        assert_eq!(reference, test);
    }
}