// pasta_curves/arithmetic/fields.rs

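//! Internal helpers for prime-field arithmetic: support for table-based square roots
//! (the tables themselves require the `sqrt-table` feature) and the 64-bit limb
//! add-with-carry, subtract-with-borrow, and multiply-accumulate primitives.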
3
4use core::mem::size_of;
5
6use static_assertions::const_assert;
7
8#[cfg(feature = "sqrt-table")]
9use alloc::{boxed::Box, vec::Vec};
10#[cfg(feature = "sqrt-table")]
11use core::marker::PhantomData;
12
13#[cfg(feature = "sqrt-table")]
14use subtle::Choice;
15
16const_assert!(size_of::<usize>() >= 4);
17
18/// An internal trait that exposes additional operations related to calculating square roots of
19/// prime-order finite fields.
20pub(crate) trait SqrtTableHelpers: ff::PrimeField {
21    /// Raise this field element to the power $(t-1)/2$.
22    ///
23    /// Field implementations may override this to use an efficient addition chain.
24    fn pow_by_t_minus1_over2(&self) -> Self;
25
26    /// Gets the lower 32 bits of this field element when expressed
27    /// canonically.
28    fn get_lower_32(&self) -> u32;
29}
30
/// Salted perfect hash over the low 32 bits of a field element, used to index the
/// `inv` table of `SqrtTables`.
#[cfg(feature = "sqrt-table")]
#[cfg_attr(docsrs, doc(cfg(feature = "sqrt-table")))]
#[derive(Debug)]
struct SqrtHasher<F: SqrtTableHelpers> {
    /// Salt XORed into the low 32 bits before reduction.
    hash_xor: u32,
    /// Number of buckets, i.e. the length of the `inv` table.
    hash_mod: usize,
    /// Ties the hasher to the field type `F` without storing an element.
    marker: PhantomData<F>,
}

#[cfg(feature = "sqrt-table")]
impl<F: SqrtTableHelpers> SqrtHasher<F> {
    /// Maps `x` to a slot of `SqrtTables::inv`.
    ///
    /// This is the simplest constant-time perfect hash that works here: the low 32 bits
    /// are unique within the 2^S-order subgroup, and the XOR salt injectively scrambles
    /// them so that reducing modulo `hash_mod` keeps them distinct. The table is small,
    /// so nothing more elaborate is needed.
    fn hash(&self, x: &F) -> usize {
        let salted = x.get_lower_32() ^ self.hash_xor;
        salted as usize % self.hash_mod
    }
}
52
/// Precomputed tables for table-based square-root computation.
#[cfg(feature = "sqrt-table")]
#[cfg_attr(docsrs, doc(cfg(feature = "sqrt-table")))]
#[derive(Debug)]
pub(crate) struct SqrtTables<F: SqrtTableHelpers> {
    /// Perfect hash used to index `inv`.
    hasher: SqrtHasher<F>,
    /// For `x == ROOT_OF_UNITY^(256^3 * j)`, `inv[hasher.hash(&x)] == (256 - j) & 0xFF`;
    /// slots no such `x` hashes to hold 1.
    inv: Vec<u8>,
    /// `g0[j] == ROOT_OF_UNITY^j` for `j` in `0..256`.
    g0: Box<[F; 256]>,
    /// `g1[j] == ROOT_OF_UNITY^(256 * j)` for `j` in `0..256`.
    g1: Box<[F; 256]>,
    /// `g2[j] == ROOT_OF_UNITY^(256^2 * j)` for `j` in `0..256`.
    g2: Box<[F; 256]>,
    /// `g3[j] == ROOT_OF_UNITY^(256^3 * j)` for `j` in `0..=128` (truncated from 256).
    g3: Box<[F; 129]>,
}
65
#[cfg(feature = "sqrt-table")]
impl<F: SqrtTableHelpers> SqrtTables<F> {
    /// Builds the tables for the perfect hash with parameters `hash_xor` and `hash_mod`.
    ///
    /// The hash must be injective on the 256 elements `ROOT_OF_UNITY^(2^24 * j)` for `j` in
    /// `0..256`; `hash_mod` becomes the length of the `inv` table.
    ///
    /// # Panics
    ///
    /// Panics if `F::S != 32` (the four 8-bit tables and the addition chain in `sqrt_ratio`
    /// cover exactly 32 bits of 2-adicity), if `hash_mod < 256`, or if the hash collides on
    /// the elements listed above.
    pub fn new(hash_xor: u32, hash_mod: usize) -> Self {
        use alloc::vec;

        // Check the preconditions up front, so that a mismatch fails here with a clear message
        // instead of through an opaque assertion (or a zero divisor) later on.
        assert_eq!(F::S, 32, "SqrtTables requires a field with 2-adicity S = 32");
        assert!(
            hash_mod >= 256,
            "SqrtTables requires hash_mod >= 256 to hash 256 distinct elements"
        );

        let hasher = SqrtHasher {
            hash_xor,
            hash_mod,
            marker: PhantomData,
        };

        // Yields gtab_i[j] == ROOT_OF_UNITY^(256^i * j) for i in 0..4 and j in 0..256.
        let mut gtab = (0..4).scan(F::ROOT_OF_UNITY, |gi, _| {
            // gi == ROOT_OF_UNITY^(256^i)
            let gtab_i: Vec<F> = (0..256)
                .scan(F::ONE, |acc, _| {
                    let res = *acc;
                    *acc *= *gi;
                    Some(res)
                })
                .collect();
            // gi^255 * gi == gi^256 == ROOT_OF_UNITY^(256^(i+1))
            *gi = gtab_i[255] * *gi;
            Some(gtab_i)
        });
        let gtab_0 = gtab.next().unwrap();
        let gtab_1 = gtab.next().unwrap();
        let gtab_2 = gtab.next().unwrap();
        let mut gtab_3 = gtab.next().unwrap();
        assert_eq!(gtab.next(), None);

        // Invert gtab_3: inv[hash(gtab_3[j])] = (256 - j) mod 256. Slots that no entry
        // hashes to keep the initial value 1.
        let mut inv: Vec<u8> = vec![1; hash_mod];
        for (j, gtab_3_j) in gtab_3.iter().enumerate() {
            let hash = hasher.hash(gtab_3_j);
            // 1 is written only for the last entry (j = 255), so an earlier write to the
            // same slot leaves a value other than 1 and exposes the collision.
            assert!(inv[hash] == 1, "SqrtTables: perfect hash collision");
            inv[hash] = ((256 - j) & 0xFF) as u8;
        }

        // sqrt_common indexes g3 with at most 0x80000000 >> 24 == 128.
        gtab_3.truncate(129);

        SqrtTables::<F> {
            hasher,
            inv,
            g0: gtab_0.into_boxed_slice().try_into().unwrap(),
            g1: gtab_1.into_boxed_slice().try_into().unwrap(),
            g2: gtab_2.into_boxed_slice().try_into().unwrap(),
            g3: gtab_3.into_boxed_slice().try_into().unwrap(),
        }
    }

    /// Squares `x` repeatedly, `n` times, returning `x^(2^n)`.
    fn square_n(x: F, n: u32) -> F {
        (0..n).fold(x, |acc, _| acc.square())
    }

    /// Computes:
    ///
    /// * (true,  sqrt(num/div)),                 if num and div are nonzero and num/div is a square in the field;
    /// * (true,  0),                             if num is zero;
    /// * (false, 0),                             if num is nonzero and div is zero;
    /// * (false, sqrt(ROOT_OF_UNITY * num/div)), if num and div are nonzero and num/div is a nonsquare in the field;
    ///
    /// where ROOT_OF_UNITY is a generator of the order 2^S subgroup (and therefore a nonsquare).
    ///
    /// The choice of root from sqrt is unspecified.
    ///
    /// # Panics
    ///
    /// Panics if the internal consistency check on the result fails; this should not happen
    /// for tables built by `new`.
    pub fn sqrt_ratio(&self, num: &F, div: &F) -> (Choice, F) {
        // Based on:
        // * [Sarkar2020](https://eprint.iacr.org/2020/1407)
        // * [BDLSY2012](https://cr.yp.to/papers.html#ed25519)
        //
        // We need to calculate uv and v, where v = u^((T-1)/2), u = num/div, and p-1 = T * 2^S.
        // We can rewrite as follows:
        //
        //      v = (num/div)^((T-1)/2)
        //        = num^((T-1)/2) * div^(p-1 - (T-1)/2)    [Fermat's Little Theorem]
        //        =       "       * div^(T * 2^S - (T-1)/2)
        //        =       "       * div^((2^(S+1) - 1)*(T-1)/2 + 2^S)
        //        = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S)
        //
        // Let  w = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S - 1).
        // Then v = w * div, and uv = num * v / div = num * w.
        //
        // We calculate:
        //
        //      s = div^(2^S - 1) using an addition chain
        //      t = div^(2^(S+1) - 1) = s^2 * div
        //      w = (num * t)^((T-1)/2) * s using another addition chain
        //
        // then v and uv as above. The addition chains are given in
        // https://github.com/zcash/pasta/blob/master/addchain_sqrt.py .
        // The overall cost of this part is similar to a single full-width exponentiation,
        // regardless of S.

        // s = div^(2^S - 1) with S = 32: each step maps d = div^(2^k - 1) to
        // d^(2^k) * d = div^(2^(2k) - 1), doubling k from 1 to 32.
        let s = (0..5).fold(*div, |d: F, i| Self::square_n(d, 1 << i) * d);

        // t == div^(2^(S+1) - 1)
        let t = s.square() * div;

        // w = (num * t)^((T-1)/2) * s
        let w = (t * num).pow_by_t_minus1_over2() * s;

        // v == u^((T-1)/2)
        let v = w * div;

        // uv = u * v
        let uv = w * num;

        let res = self.sqrt_common(&uv, &v);

        let sqdiv = res.square() * div;
        let is_square = (sqdiv - num).is_zero();
        let is_nonsquare = (sqdiv - F::ROOT_OF_UNITY * num).is_zero();
        assert!(bool::from(
            num.is_zero() | div.is_zero() | (is_square ^ is_nonsquare)
        ));

        (is_square, res)
    }

    /// Equivalent to `sqrt_ratio(u, &F::ONE)`, but skips the work needed for a divisor.
    ///
    /// # Panics
    ///
    /// Panics if the internal consistency check on the result fails; this should not happen
    /// for tables built by `new`.
    pub fn sqrt_alt(&self, u: &F) -> (Choice, F) {
        let v = u.pow_by_t_minus1_over2();
        let uv = *u * v;

        let res = self.sqrt_common(&uv, &v);

        let sq = res.square();
        let is_square = (sq - u).is_zero();
        let is_nonsquare = (sq - F::ROOT_OF_UNITY * u).is_zero();
        assert!(bool::from(u.is_zero() | (is_square ^ is_nonsquare)));

        (is_square, res)
    }

    /// Common part of `sqrt_ratio` and `sqrt_alt`: returns their square-root output given
    /// `v = u^((T-1)/2)` and `uv = u * v`.
    ///
    /// Recovers the exponent `e < 2^32` with `u^T * ROOT_OF_UNITY^e == 1` eight bits at a
    /// time via the `inv` table, then returns `uv * ROOT_OF_UNITY^ceil(e/2)`.
    fn sqrt_common(&self, uv: &F, v: &F) -> F {
        let inv = |x: F| self.inv[self.hasher.hash(&x)] as usize;

        // x3 == uv * v == u^T; x2, x1, x0 are its 2^8-th, 2^16-th and 2^24-th powers.
        let x3 = *uv * v;
        let x2 = Self::square_n(x3, 8);
        let x1 = Self::square_n(x2, 8);
        let x0 = Self::square_n(x1, 8);

        // Bits 0..8 of e: now 1 == x0 * ROOT_OF_UNITY^(t_ << 24).
        let mut t_ = inv(x0);
        assert!(t_ < 0x100);
        let alpha = x1 * self.g2[t_];

        // Bits 8..16 of e: now 1 == x1 * ROOT_OF_UNITY^(t_ << 16).
        t_ += inv(alpha) << 8;
        assert!(t_ < 0x10000);
        let alpha = x2 * self.g1[t_ & 0xFF] * self.g2[t_ >> 8];

        // Bits 16..24 of e: now 1 == x2 * ROOT_OF_UNITY^(t_ << 8).
        t_ += inv(alpha) << 16;
        assert!(t_ < 0x1000000);
        let alpha = x3 * self.g0[t_ & 0xFF] * self.g1[(t_ >> 8) & 0xFF] * self.g2[t_ >> 16];

        // Bits 24..32 of e: now 1 == x3 * ROOT_OF_UNITY^t_.
        t_ += inv(alpha) << 24;

        // Halve, rounding up. Widen to u64 because t_ + 1 can reach 2^32, which would
        // overflow a 32-bit usize.
        t_ = (((t_ as u64) + 1) >> 1) as usize;
        assert!(t_ <= 0x80000000);

        *uv * self.g0[t_ & 0xFF]
            * self.g1[(t_ >> 8) & 0xFF]
            * self.g2[(t_ >> 16) & 0xFF]
            * self.g3[t_ >> 24]
    }
}
238
/// Adds `a`, `b` and `carry` as 128-bit integers and splits the sum into its low 64 bits
/// (the result limb) and its high 64 bits (the carry into the next limb).
#[inline(always)]
pub(crate) const fn adc(a: u64, b: u64, carry: u64) -> (u64, u64) {
    let sum = a as u128 + b as u128 + carry as u128;
    let lo = sum as u64;
    let hi = (sum >> 64) as u64;
    (lo, hi)
}
245
/// Computes `a - (b + borrow_bit)` as a 64-bit limb subtraction, where `borrow_bit` is the
/// most significant bit of `borrow` (all other bits of `borrow` are ignored).
///
/// Returns the difference modulo 2^64 and the outgoing borrow, which is `0` when no
/// underflow occurred and `u64::MAX` otherwise, so it can be fed straight into the next call.
#[inline(always)]
pub(crate) const fn sbb(a: u64, b: u64, borrow: u64) -> (u64, u64) {
    let borrow_bit = (borrow >> 63) as u128;
    let subtrahend = b as u128 + borrow_bit;
    let diff = (a as u128).wrapping_sub(subtrahend);
    (diff as u64, (diff >> 64) as u64)
}
252
/// Multiply-accumulate on 64-bit limbs: computes `a + b * c + carry` as a 128-bit integer
/// and returns its low 64 bits and high 64 bits.
///
/// The total never exceeds `(2^64 - 1)^2 + 2 * (2^64 - 1) = 2^128 - 1`, so the 128-bit
/// arithmetic cannot overflow.
#[inline(always)]
pub(crate) const fn mac(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) {
    let product = b as u128 * c as u128;
    let total = product + a as u128 + carry as u128;
    (total as u64, (total >> 64) as u64)
}