1#![allow(clippy::op_ref)]
2
3use ff::{Field, FromUniformBytes, PrimeField};
4use pasta_curves::arithmetic::CurveExt;
5use static_assertions::const_assert;
6use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
7
8use crate::ff_ext::Legendre;
9
10fn hash_to_field<F: FromUniformBytes<64>>(
13 method: &str,
14 curve_id: &str,
15 domain_prefix: &str,
16 message: &[u8],
17 buf: &mut [F; 2],
18) {
19 assert!(domain_prefix.len() < 256);
20 assert!((18 + method.len() + curve_id.len() + domain_prefix.len()) < 256);
21
22 const CHUNKLEN: usize = 64;
25 const_assert!(CHUNKLEN * 2 < 256);
26
27 const R_IN_BYTES: usize = 128;
29
30 let personal = [0u8; 16];
31 let empty_hasher = blake2b_simd::Params::new()
32 .hash_length(CHUNKLEN)
33 .personal(&personal)
34 .to_state();
35
36 let b_0 = empty_hasher
37 .clone()
38 .update(&[0; R_IN_BYTES])
39 .update(message)
40 .update(&[0, (CHUNKLEN * 2) as u8, 0])
41 .update(domain_prefix.as_bytes())
42 .update(b"-")
43 .update(curve_id.as_bytes())
44 .update(b"_XMD:BLAKE2b_")
45 .update(method.as_bytes())
46 .update(b"_RO_")
47 .update(&[(18 + method.len() + curve_id.len() + domain_prefix.len()) as u8])
48 .finalize();
49
50 let b_1 = empty_hasher
51 .clone()
52 .update(b_0.as_array())
53 .update(&[1])
54 .update(domain_prefix.as_bytes())
55 .update(b"-")
56 .update(curve_id.as_bytes())
57 .update(b"_XMD:BLAKE2b_")
58 .update(method.as_bytes())
59 .update(b"_RO_")
60 .update(&[(18 + method.len() + curve_id.len() + domain_prefix.len()) as u8])
61 .finalize();
62
63 let b_2 = {
64 let mut empty_hasher = empty_hasher;
65 for (l, r) in b_0.as_array().iter().zip(b_1.as_array().iter()) {
66 empty_hasher.update(&[*l ^ *r]);
67 }
68 empty_hasher
69 .update(&[2])
70 .update(domain_prefix.as_bytes())
71 .update(b"-")
72 .update(curve_id.as_bytes())
73 .update(b"_XMD:BLAKE2b_")
74 .update(method.as_bytes())
75 .update(b"_RO_")
76 .update(&[(18 + method.len() + curve_id.len() + domain_prefix.len()) as u8])
77 .finalize()
78 };
79
80 for (big, buf) in [b_1, b_2].iter().zip(buf.iter_mut()) {
81 let mut little = [0u8; CHUNKLEN];
82 little.copy_from_slice(big.as_array());
83 little.reverse();
84 *buf = F::from_uniform_bytes(&little);
85 }
86}
87
/// Maps the field element `u` to a point of `C` with the simplified
/// Shallue-van de Woestijne-Ulas (simplified SWU) method.
///
/// The body transcribes, line by line, the straight-line `map_to_curve_simple_swu`
/// procedure of RFC 9380, Appendix F.2. The numbered comments below refer to that
/// procedure's steps. `sgn0` is realised as `is_odd`, and every `CMOV` is a
/// `conditional_select` rather than an `if`.
///
/// `z` is the method's constant `Z`. The RFC requires it to be a non-square in the
/// base field. NOTE(review): this is not checked here beyond the assertion inside
/// `sqrt_ratio`.
///
/// The result `(x, y)` is returned as `C::new_jacobian(x, y, 1)`.
///
/// # Panics
///
/// - if `tv4` is zero at step 25, which happens when `C::a()` is zero or `z` is zero;
/// - if `sqrt_ratio` panics (see its assertion);
/// - if `C::new_jacobian` rejects the computed coordinates.
// NOTE(review): this function has only two parameters, so the
// `too_many_arguments` allow below appears to be unnecessary.
#[allow(clippy::too_many_arguments)]
pub(crate) fn simple_svdw_map_to_curve<C>(u: C::Base, z: C::Base) -> C
where
    C: CurveExt,
{
    let zero = C::Base::ZERO;
    let one = C::Base::ONE;
    let a = C::a();
    let b = C::b();

    // 1. tv1 = u^2
    let tv1 = u.square();
    // 2. tv1 = Z * tv1
    let tv1 = z * tv1;
    // 3. tv2 = tv1^2
    let tv2 = tv1.square();
    // 4. tv2 = tv2 + tv1
    let tv2 = tv2 + tv1;
    // 5. tv3 = tv2 + 1
    let tv3 = tv2 + one;
    // 6. tv3 = B * tv3
    let tv3 = b * tv3;
    // 7. tv4 = CMOV(Z, -tv2, tv2 != 0)  -- avoids a zero denominator when tv2 == 0
    let tv2_is_not_zero = !tv2.ct_eq(&zero);
    let tv4 = C::Base::conditional_select(&z, &-tv2, tv2_is_not_zero);
    // 8. tv4 = A * tv4
    let tv4 = a * tv4;
    // 9. tv2 = tv3^2
    let tv2 = tv3.square();
    // 10. tv6 = tv4^2
    let tv6 = tv4.square();
    // 11. tv5 = A * tv6
    let tv5 = a * tv6;
    // 12. tv2 = tv2 + tv5
    let tv2 = tv2 + tv5;
    // 13. tv2 = tv2 * tv3
    let tv2 = tv2 * tv3;
    // 14. tv6 = tv6 * tv4
    let tv6 = tv6 * tv4;
    // 15. tv5 = B * tv6
    let tv5 = b * tv6;
    // 16. tv2 = tv2 + tv5
    let tv2 = tv2 + tv5;
    // 17. x = tv1 * tv3
    let x = tv1 * tv3;
    // 18. (is_gx1_square, y1) = sqrt_ratio(tv2, tv6)
    let (is_gx1_square, y1) = sqrt_ratio(&tv2, &tv6, &z);
    // 19. y = tv1 * u
    let y = tv1 * u;
    // 20. y = y * y1
    let y = y * y1;
    // 21. x = CMOV(x, tv3, is_gx1_square)
    let x = C::Base::conditional_select(&x, &tv3, is_gx1_square);
    // 22. y = CMOV(y, y1, is_gx1_square)
    let y = C::Base::conditional_select(&y, &y1, is_gx1_square);
    // 23. e1 = sgn0(u) == sgn0(y)
    let e1 = u.is_odd().ct_eq(&y.is_odd());
    // 24. y = CMOV(-y, y, e1)  -- give y the same sign as u
    let y = C::Base::conditional_select(&-y, &y, e1);
    // 25. x = x / tv4  -- panics if tv4 == 0 (A == 0 or Z == 0)
    let x = x * tv4.invert().unwrap();
    // 26. return (x, y), as the Jacobian triple (x, y, 1)
    C::new_jacobian(x, y, one).unwrap()
}
153
154#[allow(clippy::type_complexity)]
155pub(crate) fn simple_svdw_hash_to_curve<'a, C>(
156 curve_id: &'static str,
157 domain_prefix: &'a str,
158 z: C::Base,
159) -> Box<dyn Fn(&[u8]) -> C + 'a>
160where
161 C: CurveExt,
162 C::Base: FromUniformBytes<64>,
163{
164 Box::new(move |message| {
165 let mut us = [C::Base::ZERO; 2];
166 hash_to_field("SSWU", curve_id, domain_prefix, message, &mut us);
167
168 let [q0, q1]: [C; 2] = us.map(|u| simple_svdw_map_to_curve(u, z));
169
170 let r = q0 + &q1;
171 debug_assert!(bool::from(r.is_on_curve()));
172 r
173 })
174}
175
/// Maps the field element `u` to a point of `C` with the Shallue-van de Woestijne
/// (SvdW) method.
///
/// The body transcribes, line by line, the straight-line `map_to_curve_svdw`
/// procedure of RFC 9380, Appendix F.1. The numbered comments below refer to that
/// procedure's steps. `c1..c4` are the method's precomputed constants as produced
/// by `svdw_precomputed_constants(z)`, and `z` is the constant `Z`. `sgn0` is
/// realised as `is_odd`, and every `CMOV` is a `conditional_select` rather than an
/// `if`.
///
/// The result `(x, y)` is returned as `C::new_jacobian(x, y, 1)`.
///
/// # Panics
///
/// - if `gx` has no square root at step 33 (not expected for a valid `Z` and
///   matching constants);
/// - if `C::new_jacobian` rejects the computed coordinates.
// NOTE(review): with six parameters this stays below clippy's default
// `too_many_arguments` threshold, so the allow may be unnecessary.
#[allow(clippy::too_many_arguments)]
pub(crate) fn svdw_map_to_curve<C>(
    u: C::Base,
    c1: C::Base,
    c2: C::Base,
    c3: C::Base,
    c4: C::Base,
    z: C::Base,
) -> C
where
    C: CurveExt,
    C::Base: Legendre,
{
    let one = C::Base::ONE;
    let a = C::a();
    let b = C::b();

    // 1. tv1 = u^2
    let tv1 = u.square();
    // 2. tv1 = tv1 * c1
    let tv1 = tv1 * c1;
    // 3. tv2 = 1 + tv1
    let tv2 = one + tv1;
    // 4. tv1 = 1 - tv1
    let tv1 = one - tv1;
    // 5. tv3 = tv1 * tv2
    let tv3 = tv1 * tv2;
    // 6. tv3 = inv0(tv3)  -- inverse that maps 0 to 0
    let tv3 = tv3.invert().unwrap_or(C::Base::ZERO);
    // 7. tv4 = u * tv1
    let tv4 = u * tv1;
    // 8. tv4 = tv4 * tv3
    let tv4 = tv4 * tv3;
    // 9. tv4 = tv4 * c3
    let tv4 = tv4 * c3;
    // 10. x1 = c2 - tv4
    let x1 = c2 - tv4;
    // 11. gx1 = x1^2
    let gx1 = x1.square();
    // 12. gx1 = gx1 + A
    let gx1 = gx1 + a;
    // 13. gx1 = gx1 * x1
    let gx1 = gx1 * x1;
    // 14. gx1 = gx1 + B
    let gx1 = gx1 + b;
    // 15. e1 = is_square(gx1), computed as "not a quadratic non-residue"
    //     (presumably zero is not reported as a non-residue; confirm in ff_ext::Legendre)
    let e1 = !gx1.ct_quadratic_non_residue();
    // 16. x2 = c2 + tv4
    let x2 = c2 + tv4;
    // 17. gx2 = x2^2
    let gx2 = x2.square();
    // 18. gx2 = gx2 + A
    let gx2 = gx2 + a;
    // 19. gx2 = gx2 * x2
    let gx2 = gx2 * x2;
    // 20. gx2 = gx2 + B
    let gx2 = gx2 + b;
    // 21. e2 = is_square(gx2) AND NOT e1  -- uses `&` on Choice, no short-circuit
    let e2 = !gx2.ct_quadratic_non_residue() & (!e1);
    // 22. x3 = tv2^2
    let x3 = tv2.square();
    // 23. x3 = x3 * tv3
    let x3 = x3 * tv3;
    // 24. x3 = x3^2
    let x3 = x3.square();
    // 25. x3 = x3 * c4
    let x3 = x3 * c4;
    // 26. x3 = x3 + Z
    let x3 = x3 + z;
    // 27. x = CMOV(x3, x1, e1)  -- x1 if gx1 is square, else x3
    let x = C::Base::conditional_select(&x3, &x1, e1);
    // 28. x = CMOV(x, x2, e2)  -- x2 if gx2 is square and gx1 is not
    let x = C::Base::conditional_select(&x, &x2, e2);
    // 29. gx = x^2
    let gx = x.square();
    // 30. gx = gx + A
    let gx = gx + a;
    // 31. gx = gx * x
    let gx = gx * x;
    // 32. gx = gx + B
    let gx = gx + b;
    // 33. y = sqrt(gx)  -- panics if gx is not a square
    let y = gx.sqrt().unwrap();
    // 34. e3 = sgn0(u) == sgn0(y)
    let e3 = u.is_odd().ct_eq(&y.is_odd());
    // 35. y = CMOV(-y, y, e3)  -- give y the same sign as u
    let y = C::Base::conditional_select(&-y, &y, e3);
    // 36. return (x, y), as the Jacobian triple (x, y, 1)
    C::new_jacobian(x, y, one).unwrap()
}
266
267fn sqrt_ratio<F: PrimeField>(num: &F, div: &F, z: &F) -> (Choice, F) {
270 let a = div.invert().unwrap_or(F::ZERO) * num;
286 let b = a * z;
287 let sqrt_a = a.sqrt();
288 let sqrt_b = b.sqrt();
289
290 let num_is_zero = num.is_zero();
291 let div_is_zero = div.is_zero();
292 let is_square = sqrt_a.is_some();
293 let is_nonsquare = sqrt_b.is_some();
294 assert!(bool::from(
295 num_is_zero | div_is_zero | (is_square ^ is_nonsquare)
296 ));
297
298 (
299 is_square & (num_is_zero | !div_is_zero),
300 CtOption::conditional_select(&sqrt_b, &sqrt_a, is_square).unwrap(),
301 )
302}
303
304#[allow(clippy::type_complexity)]
306pub(crate) fn svdw_hash_to_curve<'a, C>(
307 curve_id: &'static str,
308 domain_prefix: &'a str,
309 z: C::Base,
310) -> Box<dyn Fn(&[u8]) -> C + 'a>
311where
312 C: CurveExt,
313 C::Base: FromUniformBytes<64> + Legendre,
314{
315 let [c1, c2, c3, c4] = svdw_precomputed_constants::<C>(z);
316
317 Box::new(move |message| {
318 let mut us = [C::Base::ZERO; 2];
319 hash_to_field("SVDW", curve_id, domain_prefix, message, &mut us);
320
321 let [q0, q1]: [C; 2] = us.map(|u| svdw_map_to_curve(u, c1, c2, c3, c4, z));
322
323 let r = q0 + &q1;
324 debug_assert!(bool::from(r.is_on_curve()));
325 r
326 })
327}
328
329pub(crate) fn svdw_precomputed_constants<C: CurveExt>(z: C::Base) -> [C::Base; 4] {
330 let a = C::a();
331 let b = C::b();
332 let one = C::Base::ONE;
333 let three = one + one + one;
334 let four = three + one;
335 let tmp = three * z.square() + four * a;
336
337 let c1 = (z.square() + a) * z + b;
339 let c2 = -z * C::Base::TWO_INV;
341 let c3 = {
343 let c3 = (-c1 * tmp).sqrt().unwrap();
344 C::Base::conditional_select(&c3, &-c3, c3.is_odd())
345 };
346 let c4 = -four * c1 * tmp.invert().unwrap();
348
349 [c1, c2, c3, c4]
350}