halo2curves_axiom/
hash_to_curve.rs

1#![allow(clippy::op_ref)]
2
3use ff::{Field, FromUniformBytes, PrimeField};
4use pasta_curves::arithmetic::CurveExt;
5use static_assertions::const_assert;
6use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
7
8use crate::ff_ext::Legendre;
9
10/// Hashes over a message and writes the output to all of `buf`.
11/// Modified from https://github.com/zcash/pasta_curves/blob/7e3fc6a4919f6462a32b79dd226cb2587b7961eb/src/hashtocurve.rs#L11.
12fn hash_to_field<F: FromUniformBytes<64>>(
13    method: &str,
14    curve_id: &str,
15    domain_prefix: &str,
16    message: &[u8],
17    buf: &mut [F; 2],
18) {
19    assert!(domain_prefix.len() < 256);
20    assert!((18 + method.len() + curve_id.len() + domain_prefix.len()) < 256);
21
22    // Assume that the field size is 32 bytes and k is 256, where k is defined in
23    // <https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-10.html#name-security-considerations-3>.
24    const CHUNKLEN: usize = 64;
25    const_assert!(CHUNKLEN * 2 < 256);
26
27    // Input block size of BLAKE2b.
28    const R_IN_BYTES: usize = 128;
29
30    let personal = [0u8; 16];
31    let empty_hasher = blake2b_simd::Params::new()
32        .hash_length(CHUNKLEN)
33        .personal(&personal)
34        .to_state();
35
36    let b_0 = empty_hasher
37        .clone()
38        .update(&[0; R_IN_BYTES])
39        .update(message)
40        .update(&[0, (CHUNKLEN * 2) as u8, 0])
41        .update(domain_prefix.as_bytes())
42        .update(b"-")
43        .update(curve_id.as_bytes())
44        .update(b"_XMD:BLAKE2b_")
45        .update(method.as_bytes())
46        .update(b"_RO_")
47        .update(&[(18 + method.len() + curve_id.len() + domain_prefix.len()) as u8])
48        .finalize();
49
50    let b_1 = empty_hasher
51        .clone()
52        .update(b_0.as_array())
53        .update(&[1])
54        .update(domain_prefix.as_bytes())
55        .update(b"-")
56        .update(curve_id.as_bytes())
57        .update(b"_XMD:BLAKE2b_")
58        .update(method.as_bytes())
59        .update(b"_RO_")
60        .update(&[(18 + method.len() + curve_id.len() + domain_prefix.len()) as u8])
61        .finalize();
62
63    let b_2 = {
64        let mut empty_hasher = empty_hasher;
65        for (l, r) in b_0.as_array().iter().zip(b_1.as_array().iter()) {
66            empty_hasher.update(&[*l ^ *r]);
67        }
68        empty_hasher
69            .update(&[2])
70            .update(domain_prefix.as_bytes())
71            .update(b"-")
72            .update(curve_id.as_bytes())
73            .update(b"_XMD:BLAKE2b_")
74            .update(method.as_bytes())
75            .update(b"_RO_")
76            .update(&[(18 + method.len() + curve_id.len() + domain_prefix.len()) as u8])
77            .finalize()
78    };
79
80    for (big, buf) in [b_1, b_2].iter().zip(buf.iter_mut()) {
81        let mut little = [0u8; CHUNKLEN];
82        little.copy_from_slice(big.as_array());
83        little.reverse();
84        *buf = F::from_uniform_bytes(&little);
85    }
86}
87
88// Implementation of <https://datatracker.ietf.org/doc/html/rfc9380#name-simplified-swu-method>
89#[allow(clippy::too_many_arguments)]
90pub(crate) fn simple_svdw_map_to_curve<C>(u: C::Base, z: C::Base) -> C
91where
92    C: CurveExt,
93{
94    let zero = C::Base::ZERO;
95    let one = C::Base::ONE;
96    let a = C::a();
97    let b = C::b();
98
99    //1.  tv1 = u^2
100    let tv1 = u.square();
101    //2.  tv1 = Z * tv1
102    let tv1 = z * tv1;
103    //3.  tv2 = tv1^2
104    let tv2 = tv1.square();
105    //4.  tv2 = tv2 + tv1
106    let tv2 = tv2 + tv1;
107    //5.  tv3 = tv2 + 1
108    let tv3 = tv2 + one;
109    //6.  tv3 = B * tv3
110    let tv3 = b * tv3;
111    //7.  tv4 = CMOV(Z, -tv2, tv2 != 0) # tv4 = z if tv2 is 0 else tv4 = -tv2
112    let tv2_is_not_zero = !tv2.ct_eq(&zero);
113    let tv4 = C::Base::conditional_select(&z, &-tv2, tv2_is_not_zero);
114    //8.  tv4 = A * tv4
115    let tv4 = a * tv4;
116    //9.  tv2 = tv3^2
117    let tv2 = tv3.square();
118    //10. tv6 = tv4^2
119    let tv6 = tv4.square();
120    //11. tv5 = A * tv6
121    let tv5 = a * tv6;
122    //12. tv2 = tv2 + tv5
123    let tv2 = tv2 + tv5;
124    //13. tv2 = tv2 * tv3
125    let tv2 = tv2 * tv3;
126    //14. tv6 = tv6 * tv4
127    let tv6 = tv6 * tv4;
128    //15. tv5 = B * tv6
129    let tv5 = b * tv6;
130    //16. tv2 = tv2 + tv5
131    let tv2 = tv2 + tv5;
132    //17.   x = tv1 * tv3
133    let x = tv1 * tv3;
134    //18. (is_gx1_square, y1) = sqrt_ratio(tv2, tv6)
135    let (is_gx1_square, y1) = sqrt_ratio(&tv2, &tv6, &z);
136    //19.   y = tv1 * u
137    let y = tv1 * u;
138    //20.   y = y * y1
139    let y = y * y1;
140    //21.   x = CMOV(x, tv3, is_gx1_square)
141    let x = C::Base::conditional_select(&x, &tv3, is_gx1_square);
142    //22.   y = CMOV(y, y1, is_gx1_square)
143    let y = C::Base::conditional_select(&y, &y1, is_gx1_square);
144    //23.  e1 = sgn0(u) == sgn0(y)
145    let e1 = u.is_odd().ct_eq(&y.is_odd());
146    //24.   y = CMOV(-y, y, e1) # Select correct sign of y
147    let y = C::Base::conditional_select(&-y, &y, e1);
148    //25.   x = x / tv4
149    let x = x * tv4.invert().unwrap();
150    //26. return (x, y)
151    C::new_jacobian(x, y, one).unwrap()
152}
153
154#[allow(clippy::type_complexity)]
155pub(crate) fn simple_svdw_hash_to_curve<'a, C>(
156    curve_id: &'static str,
157    domain_prefix: &'a str,
158    z: C::Base,
159) -> Box<dyn Fn(&[u8]) -> C + 'a>
160where
161    C: CurveExt,
162    C::Base: FromUniformBytes<64>,
163{
164    Box::new(move |message| {
165        let mut us = [C::Base::ZERO; 2];
166        hash_to_field("SSWU", curve_id, domain_prefix, message, &mut us);
167
168        let [q0, q1]: [C; 2] = us.map(|u| simple_svdw_map_to_curve(u, z));
169
170        let r = q0 + &q1;
171        debug_assert!(bool::from(r.is_on_curve()));
172        r
173    })
174}
175
176#[allow(clippy::too_many_arguments)]
177pub(crate) fn svdw_map_to_curve<C>(
178    u: C::Base,
179    c1: C::Base,
180    c2: C::Base,
181    c3: C::Base,
182    c4: C::Base,
183    z: C::Base,
184) -> C
185where
186    C: CurveExt,
187    C::Base: Legendre,
188{
189    let one = C::Base::ONE;
190    let a = C::a();
191    let b = C::b();
192
193    // 1. tv1 = u^2
194    let tv1 = u.square();
195    // 2. tv1 = tv1 * c1
196    let tv1 = tv1 * c1;
197    // 3. tv2 = 1 + tv1
198    let tv2 = one + tv1;
199    // 4. tv1 = 1 - tv1
200    let tv1 = one - tv1;
201    // 5. tv3 = tv1 * tv2
202    let tv3 = tv1 * tv2;
203    // 6. tv3 = inv0(tv3)
204    let tv3 = tv3.invert().unwrap_or(C::Base::ZERO);
205    // 7. tv4 = u * tv1
206    let tv4 = u * tv1;
207    // 8. tv4 = tv4 * tv3
208    let tv4 = tv4 * tv3;
209    // 9. tv4 = tv4 * c3
210    let tv4 = tv4 * c3;
211    // 10. x1 = c2 - tv4
212    let x1 = c2 - tv4;
213    // 11. gx1 = x1^2
214    let gx1 = x1.square();
215    // 12. gx1 = gx1 + A
216    let gx1 = gx1 + a;
217    // 13. gx1 = gx1 * x1
218    let gx1 = gx1 * x1;
219    // 14. gx1 = gx1 + B
220    let gx1 = gx1 + b;
221    // 15. e1 = is_square(gx1)
222    let e1 = !gx1.ct_quadratic_non_residue();
223    // 16. x2 = c2 + tv4
224    let x2 = c2 + tv4;
225    // 17. gx2 = x2^2
226    let gx2 = x2.square();
227    // 18. gx2 = gx2 + A
228    let gx2 = gx2 + a;
229    // 19. gx2 = gx2 * x2
230    let gx2 = gx2 * x2;
231    // 20. gx2 = gx2 + B
232    let gx2 = gx2 + b;
233    // 21. e2 = is_square(gx2) AND NOT e1    # Avoid short-circuit logic ops
234    let e2 = !gx2.ct_quadratic_non_residue() & (!e1);
235    // 22. x3 = tv2^2
236    let x3 = tv2.square();
237    // 23. x3 = x3 * tv3
238    let x3 = x3 * tv3;
239    // 24. x3 = x3^2
240    let x3 = x3.square();
241    // 25. x3 = x3 * c4
242    let x3 = x3 * c4;
243    // 26. x3 = x3 + Z
244    let x3 = x3 + z;
245    // 27. x = CMOV(x3, x1, e1)    # x = x1 if gx1 is square, else x = x3
246    let x = C::Base::conditional_select(&x3, &x1, e1);
247    // 28. x = CMOV(x, x2, e2)    # x = x2 if gx2 is square and gx1 is not
248    let x = C::Base::conditional_select(&x, &x2, e2);
249    // 29. gx = x^2
250    let gx = x.square();
251    // 30. gx = gx + A
252    let gx = gx + a;
253    // 31. gx = gx * x
254    let gx = gx * x;
255    // 32. gx = gx + B
256    let gx = gx + b;
257    // 33. y = sqrt(gx)
258    let y = gx.sqrt().unwrap();
259    // 34. e3 = sgn0(u) == sgn0(y)
260    let e3 = u.is_odd().ct_eq(&y.is_odd());
261    // 35. y = CMOV(-y, y, e3)    # Select correct sign of y
262    let y = C::Base::conditional_select(&-y, &y, e3);
263    // 36. return (x, y)
264    C::new_jacobian(x, y, one).unwrap()
265}
266
267// Implement https://datatracker.ietf.org/doc/html/rfc9380#name-sqrt_ratio-for-any-field
268// Copied from ff sqrt_ratio_generic substituting F::ROOT_OF_UNITY for input Z
269fn sqrt_ratio<F: PrimeField>(num: &F, div: &F, z: &F) -> (Choice, F) {
270    // General implementation:
271    //
272    // a = num * inv0(div)
273    //   = {    0    if div is zero
274    //     { num/div otherwise
275    //
276    // b = z * a
277    //   = {      0      if div is zero
278    //     { z*num/div otherwise
279
280    // Since z is non-square, a and b are either both zero (and both square), or
281    // only one of them is square. We can therefore choose the square root to return
282    // based on whether a is square, but for the boolean output we need to handle the
283    // num != 0 && div == 0 case specifically.
284
285    let a = div.invert().unwrap_or(F::ZERO) * num;
286    let b = a * z;
287    let sqrt_a = a.sqrt();
288    let sqrt_b = b.sqrt();
289
290    let num_is_zero = num.is_zero();
291    let div_is_zero = div.is_zero();
292    let is_square = sqrt_a.is_some();
293    let is_nonsquare = sqrt_b.is_some();
294    assert!(bool::from(
295        num_is_zero | div_is_zero | (is_square ^ is_nonsquare)
296    ));
297
298    (
299        is_square & (num_is_zero | !div_is_zero),
300        CtOption::conditional_select(&sqrt_b, &sqrt_a, is_square).unwrap(),
301    )
302}
303
304/// Implementation of https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-10.html#section-6.6.1
305#[allow(clippy::type_complexity)]
306pub(crate) fn svdw_hash_to_curve<'a, C>(
307    curve_id: &'static str,
308    domain_prefix: &'a str,
309    z: C::Base,
310) -> Box<dyn Fn(&[u8]) -> C + 'a>
311where
312    C: CurveExt,
313    C::Base: FromUniformBytes<64> + Legendre,
314{
315    let [c1, c2, c3, c4] = svdw_precomputed_constants::<C>(z);
316
317    Box::new(move |message| {
318        let mut us = [C::Base::ZERO; 2];
319        hash_to_field("SVDW", curve_id, domain_prefix, message, &mut us);
320
321        let [q0, q1]: [C; 2] = us.map(|u| svdw_map_to_curve(u, c1, c2, c3, c4, z));
322
323        let r = q0 + &q1;
324        debug_assert!(bool::from(r.is_on_curve()));
325        r
326    })
327}
328
329pub(crate) fn svdw_precomputed_constants<C: CurveExt>(z: C::Base) -> [C::Base; 4] {
330    let a = C::a();
331    let b = C::b();
332    let one = C::Base::ONE;
333    let three = one + one + one;
334    let four = three + one;
335    let tmp = three * z.square() + four * a;
336
337    // 1. c1 = g(Z)
338    let c1 = (z.square() + a) * z + b;
339    // 2. c2 = -Z / 2
340    let c2 = -z * C::Base::TWO_INV;
341    // 3. c3 = sqrt(-g(Z) * (3 * Z^2 + 4 * A))    # sgn0(c3) MUST equal 0
342    let c3 = {
343        let c3 = (-c1 * tmp).sqrt().unwrap();
344        C::Base::conditional_select(&c3, &-c3, c3.is_odd())
345    };
346    // 4. c4 = -4 * g(Z) / (3 * Z^2 + 4 * A)
347    let c4 = -four * c1 * tmp.invert().unwrap();
348
349    [c1, c2, c3, c4]
350}