// pasta_curves/arithmetic/fields.rs
//! This module contains the `Field` abstraction that allows us to write
//! code that generalizes over a pair of fields.
3
use core::mem::size_of;

use static_assertions::const_assert;

#[cfg(feature = "sqrt-table")]
use alloc::{boxed::Box, vec::Vec};
#[cfg(feature = "sqrt-table")]
use core::marker::PhantomData;

#[cfg(feature = "sqrt-table")]
use subtle::Choice;
15
16const_assert!(size_of::<usize>() >= 4);
17
18/// An internal trait that exposes additional operations related to calculating square roots of
19/// prime-order finite fields.
20pub(crate) trait SqrtTableHelpers: ff::PrimeField {
21 /// Raise this field element to the power $(t-1)/2$.
22 ///
23 /// Field implementations may override this to use an efficient addition chain.
24 fn pow_by_t_minus1_over2(&self) -> Self;
25
26 /// Gets the lower 32 bits of this field element when expressed
27 /// canonically.
28 fn get_lower_32(&self) -> u32;
29}
30
/// Parameters for a perfect hash function used in square root computation.
#[cfg(feature = "sqrt-table")]
#[cfg_attr(docsrs, doc(cfg(feature = "sqrt-table")))]
#[derive(Debug)]
struct SqrtHasher<F: SqrtTableHelpers> {
    /// Salt XORed into the low 32 bits of the element before reduction.
    hash_xor: u32,
    /// Modulus applied to the salted value; also the length of the `inv` table
    /// that `SqrtTables::new` allocates.
    hash_mod: usize,
    /// Ties the hasher to the field type `F` without storing a field element.
    marker: PhantomData<F>,
}
40
#[cfg(feature = "sqrt-table")]
impl<F: SqrtTableHelpers> SqrtHasher<F> {
    /// Returns a perfect hash of x for use with SqrtTables::inv.
    ///
    /// The result is always in `0..hash_mod`.
    ///
    /// # Panics
    ///
    /// Panics if `hash_mod` is zero (remainder by zero).
    fn hash(&self, x: &F) -> usize {
        // This is just the simplest constant-time perfect hash construction that could
        // possibly work. The 32 low-order bits are unique within the 2^S order subgroup,
        // then the xor acts as "salt" to injectively randomize the output when taken modulo
        // `hash_mod`. Since the table is small, we do not need anything more complicated.
        let salted = x.get_lower_32() ^ self.hash_xor;
        (salted as usize) % self.hash_mod
    }
}
52
/// Tables used for square root computation.
///
/// Built by `SqrtTables::new`. With `g = ROOT_OF_UNITY`, table `g{i}` stores
/// `g^(j * 256^i)` at index `j`.
#[cfg(feature = "sqrt-table")]
#[cfg_attr(docsrs, doc(cfg(feature = "sqrt-table")))]
#[derive(Debug)]
pub(crate) struct SqrtTables<F: SqrtTableHelpers> {
    /// Perfect hash used to index `inv`.
    hasher: SqrtHasher<F>,
    /// Maps the hash of `g^(j * 256^3)` to `(256 - j) & 0xFF`; unused slots hold 1.
    inv: Vec<u8>,
    /// `g^j` for `j` in `0..256`.
    g0: Box<[F; 256]>,
    /// `g^(j * 256)` for `j` in `0..256`.
    g1: Box<[F; 256]>,
    /// `g^(j * 256^2)` for `j` in `0..256`.
    g2: Box<[F; 256]>,
    /// `g^(j * 256^3)` for `j` in `0..=128` (the table is truncated to 129 entries).
    g3: Box<[F; 129]>,
}
65
#[cfg(feature = "sqrt-table")]
impl<F: SqrtTableHelpers> SqrtTables<F> {
    /// Build tables given parameters for the perfect hash.
    ///
    /// For `i` in `0..4`, table `i` holds `ROOT_OF_UNITY^(j * 256^i)` at index `j`
    /// for `j` in `0..256`; the last table is truncated to indices `0..=128` after
    /// it has been used to fill the inverse-lookup table `inv`.
    ///
    /// # Panics
    ///
    /// Panics if `hash_mod` is zero, or if `(hash_xor, hash_mod)` does not map the
    /// 256 entries of the untruncated last table to distinct slots.
    pub fn new(hash_xor: u32, hash_mod: usize) -> Self {
        use alloc::vec;

        let hasher = SqrtHasher {
            hash_xor,
            hash_mod,
            marker: PhantomData,
        };

        // Lazily yields one 256-entry power table per 8-bit window.
        let mut gtab = (0..4).scan(F::ROOT_OF_UNITY, |gi, _| {
            // gi == ROOT_OF_UNITY^(256^i)
            let gtab_i: Vec<F> = (0..256)
                .scan(F::ONE, |acc, _| {
                    let res = *acc;
                    *acc *= *gi;
                    Some(res)
                })
                .collect();
            // gtab_i[255] == gi^255, so this advances gi to gi^256 for the next window.
            *gi = gtab_i[255] * *gi;
            Some(gtab_i)
        });
        let gtab_0 = gtab.next().expect("scan over 0..4 yields table 0");
        let gtab_1 = gtab.next().expect("scan over 0..4 yields table 1");
        let gtab_2 = gtab.next().expect("scan over 0..4 yields table 2");
        let mut gtab_3 = gtab.next().expect("scan over 0..4 yields table 3");
        assert!(gtab.next().is_none());

        // Now invert gtab[3].
        let mut inv: Vec<u8> = vec![1; hash_mod];
        for (j, gtab_3_j) in gtab_3.iter().enumerate() {
            let hash = hasher.hash(gtab_3_j);
            // 1 is the last value to be assigned, so this ensures there are no collisions.
            assert!(inv[hash] == 1);
            inv[hash] = ((256 - j) & 0xFF) as u8;
        }

        // Only indices 0..=128 of the last table are ever read (see `sqrt_common`).
        gtab_3.truncate(129);

        Self {
            hasher,
            inv,
            g0: gtab_0
                .into_boxed_slice()
                .try_into()
                .expect("table 0 has exactly 256 entries"),
            g1: gtab_1
                .into_boxed_slice()
                .try_into()
                .expect("table 1 has exactly 256 entries"),
            g2: gtab_2
                .into_boxed_slice()
                .try_into()
                .expect("table 2 has exactly 256 entries"),
            g3: gtab_3
                .into_boxed_slice()
                .try_into()
                .expect("table 3 was truncated to exactly 129 entries"),
        }
    }

    /// Computes:
    ///
    /// * (true, sqrt(num/div)), if num and div are nonzero and num/div is a square in the field;
    /// * (true, 0), if num is zero;
    /// * (false, 0), if num is nonzero and div is zero;
    /// * (false, sqrt(ROOT_OF_UNITY * num/div)), if num and div are nonzero and num/div is a nonsquare in the field;
    ///
    /// where ROOT_OF_UNITY is a generator of the order 2^S subgroup (and therefore a nonsquare).
    ///
    /// The choice of root from sqrt is unspecified.
    ///
    /// # Panics
    ///
    /// Panics if, for nonzero `num` and `div`, the computed result is not exactly one
    /// of the two roots described above (an internal consistency check).
    pub fn sqrt_ratio(&self, num: &F, div: &F) -> (Choice, F) {
        // Based on:
        // * [Sarkar2020](https://eprint.iacr.org/2020/1407)
        // * [BDLSY2012](https://cr.yp.to/papers.html#ed25519)
        //
        // We need to calculate uv and v, where v = u^((T-1)/2), u = num/div, and p-1 = T * 2^S.
        // We can rewrite as follows:
        //
        //      v = (num/div)^((T-1)/2)
        //        = num^((T-1)/2) * div^(p-1 - (T-1)/2)    [Fermat's Little Theorem]
        //        =       "       * div^(T * 2^S - (T-1)/2)
        //        =       "       * div^((2^(S+1) - 1)*(T-1)/2 + 2^S)
        //        = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S)
        //
        // Let  w = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S - 1).
        // Then v = w * div, and uv = num * v / div = num * w.
        //
        // We calculate:
        //
        //      s = div^(2^S - 1) using an addition chain
        //      t = div^(2^(S+1) - 1) = s^2 * div
        //      w = (num * t)^((T-1)/2) * s using another addition chain
        //
        // then v and uv as above. The addition chains are given in
        // https://github.com/zcash/pasta/blob/master/addchain_sqrt.py .
        // The overall cost of this part is similar to a single full-width exponentiation,
        // regardless of S.

        // s = div^(2^S - 1)
        // Each step turns div^(2^k - 1) into div^(2^(2k) - 1); five steps starting from
        // k = 1 reach k = 32, i.e. this hard-codes S = 32.
        let s = (0..5u32).fold(*div, |d: F, i| Self::square_n(d, 1u32 << i) * d);

        // t == div^(2^(S+1) - 1)
        let t = s.square() * div;

        // w = (num * t)^((T-1)/2) * s
        let w = (t * num).pow_by_t_minus1_over2() * s;

        // v == u^((T-1)/2)
        let v = w * div;

        // uv = u * v
        let uv = w * num;

        let res = self.sqrt_common(&uv, &v);

        // Classify the result: res^2 * div equals num for a square ratio, and
        // ROOT_OF_UNITY * num for a nonsquare ratio.
        let sqdiv = res.square() * div;
        let is_square = (sqdiv - num).is_zero();
        let is_nonsquare = (sqdiv - F::ROOT_OF_UNITY * num).is_zero();
        assert!(bool::from(
            num.is_zero() | div.is_zero() | (is_square ^ is_nonsquare)
        ));

        (is_square, res)
    }

    /// Same as sqrt_ratio(u, one()) but more efficient.
    ///
    /// # Panics
    ///
    /// Panics if, for nonzero `u`, the computed result squares to neither `u` nor
    /// `ROOT_OF_UNITY * u` (an internal consistency check).
    pub fn sqrt_alt(&self, u: &F) -> (Choice, F) {
        let v = u.pow_by_t_minus1_over2();
        let uv = *u * v;

        let res = self.sqrt_common(&uv, &v);

        let sq = res.square();
        let is_square = (sq - u).is_zero();
        let is_nonsquare = (sq - F::ROOT_OF_UNITY * u).is_zero();
        assert!(bool::from(u.is_zero() | (is_square ^ is_nonsquare)));

        (is_square, res)
    }

    /// Common part of sqrt_ratio and sqrt_alt: return their result given v = u^((T-1)/2) and uv = u * v.
    ///
    /// Recovers, one 8-bit window at a time via the `inv` table, an exponent `t` with
    /// `1 == (uv * v) * ROOT_OF_UNITY^t`, then returns `uv * ROOT_OF_UNITY^((t + 1) >> 1)`
    /// assembled from the `g0`..`g3` tables.
    fn sqrt_common(&self, uv: &F, v: &F) -> F {
        // Byte stored in `self.inv` for `x`, located through the perfect hash.
        let inv = |x: F| self.inv[self.hasher.hash(&x)] as usize;

        let x3 = *uv * v;
        let x2 = Self::square_n(x3, 8);
        let x1 = Self::square_n(x2, 8);
        let x0 = Self::square_n(x1, 8);

        // i = 0, 1
        let mut t_ = inv(x0); // = t >> 16
        // 1 == x0 * ROOT_OF_UNITY^(t_ << 24)
        assert!(t_ < 0x100);
        let alpha = x1 * self.g2[t_];

        // i = 2
        t_ += inv(alpha) << 8; // = t >> 8
        // 1 == x1 * ROOT_OF_UNITY^(t_ << 16)
        assert!(t_ < 0x10000);
        let alpha = x2 * self.g1[t_ & 0xFF] * self.g2[t_ >> 8];

        // i = 3
        t_ += inv(alpha) << 16; // = t
        // 1 == x2 * ROOT_OF_UNITY^(t_ << 8)
        assert!(t_ < 0x1000000);
        let alpha = x3 * self.g0[t_ & 0xFF] * self.g1[(t_ >> 8) & 0xFF] * self.g2[t_ >> 16];

        t_ += inv(alpha) << 24; // = t << 1
        // 1 == x3 * ROOT_OF_UNITY^t_
        // Round-up halving is done in u64 so that `+ 1` cannot overflow a 32-bit usize.
        t_ = (((t_ as u64) + 1) >> 1) as usize;
        assert!(t_ <= 0x80000000);

        // t_ <= 0x80000000 means the top window index is at most 128, matching g3's length.
        *uv * self.g0[t_ & 0xFF]
            * self.g1[(t_ >> 8) & 0xFF]
            * self.g2[(t_ >> 16) & 0xFF]
            * self.g3[t_ >> 24]
    }

    /// Squares `x` `n` times in a row, returning `x^(2^n)`.
    fn square_n(x: F, n: u32) -> F {
        (0..n).fold(x, |acc, _| acc.square())
    }
}
238
/// Compute a + b + carry, returning the result and the new carry over.
///
/// The sum is formed in 128 bits, so it cannot overflow. The returned tuple is
/// `(low 64 bits of the sum, high 64 bits of the sum)`.
#[inline(always)]
pub(crate) const fn adc(a: u64, b: u64, carry: u64) -> (u64, u64) {
    let ret = (a as u128) + (b as u128) + (carry as u128);
    (ret as u64, (ret >> 64) as u64)
}
245
/// Compute a - (b + borrow), returning the result and the new borrow.
///
/// Only the most significant bit of `borrow` is consulted (`borrow >> 63`), so an
/// incoming borrow is expected in the same form this function returns it: `0` for
/// "no borrow" and `u64::MAX` for "borrow". The subtraction wraps in 128 bits; the
/// returned tuple is `(low 64 bits, high 64 bits)` of that wrapped difference, so
/// the new borrow is `u64::MAX` on underflow and `0` otherwise.
#[inline(always)]
pub(crate) const fn sbb(a: u64, b: u64, borrow: u64) -> (u64, u64) {
    let ret = (a as u128).wrapping_sub((b as u128) + ((borrow >> 63) as u128));
    (ret as u64, (ret >> 64) as u64)
}
252
/// Compute a + (b * c) + carry, returning the result and the new carry over.
///
/// The computation is done in 128 bits and cannot overflow: the largest possible
/// value, with every input equal to `u64::MAX`, is exactly `u128::MAX`. The returned
/// tuple is `(low 64 bits, high 64 bits)` of the result.
#[inline(always)]
pub(crate) const fn mac(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) {
    let ret = (a as u128) + ((b as u128) * (c as u128)) + (carry as u128);
    (ret as u64, (ret >> 64) as u64)
}