#![allow(clippy::op_ref)]
use ff::{Field, FromUniformBytes, PrimeField};
use pasta_curves::arithmetic::CurveExt;
use static_assertions::const_assert;
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
use crate::ff_ext::Legendre;
/// Hashes `message` into two field elements written to `buf`.
///
/// This follows the shape of `expand_message_xmd` (RFC 9380 §5.3.1) with
/// BLAKE2b-512 as the hash and `len_in_bytes = 128`. The domain separation
/// tag is `"{domain_prefix}-{curve_id}_XMD:BLAKE2b_{method}_RO_"`, followed
/// by its length as a single byte. The two 64-byte output blocks `b_1` and
/// `b_2` are each byte-reversed and passed to `F::from_uniform_bytes`, one per
/// slot of `buf`. The reversal presumably converts big-endian hash output
/// into the byte order `from_uniform_bytes` expects; confirm against the
/// field implementation.
///
/// # Panics
///
/// Panics if the domain separation tag would be 256 bytes or longer.
fn hash_to_field<F: FromUniformBytes<64>>(
    method: &str,
    curve_id: &str,
    domain_prefix: &str,
    message: &[u8],
    buf: &mut [F; 2],
) {
    assert!(domain_prefix.len() < 256);
    assert!((18 + method.len() + curve_id.len() + domain_prefix.len()) < 256);

    // Size of each BLAKE2b output block; one block per field element.
    const CHUNKLEN: usize = 64;
    const_assert!(CHUNKLEN * 2 < 256);
    // Number of zero bytes in the Z_pad prefix of the first hash.
    const R_IN_BYTES: usize = 128;

    // DST' = DST || I2OSP(len(DST), 1). The fixed pieces "-", "_XMD:BLAKE2b_"
    // and "_RO_" contribute 1 + 13 + 4 = 18 bytes.
    let dst_len = 18 + method.len() + curve_id.len() + domain_prefix.len();
    let mut dst_prime: Vec<u8> = Vec::with_capacity(dst_len + 1);
    dst_prime.extend_from_slice(domain_prefix.as_bytes());
    dst_prime.extend_from_slice(b"-");
    dst_prime.extend_from_slice(curve_id.as_bytes());
    dst_prime.extend_from_slice(b"_XMD:BLAKE2b_");
    dst_prime.extend_from_slice(method.as_bytes());
    dst_prime.extend_from_slice(b"_RO_");
    dst_prime.push(dst_len as u8);

    let personal = [0u8; 16];
    let base_state = blake2b_simd::Params::new()
        .hash_length(CHUNKLEN)
        .personal(&personal)
        .to_state();

    // b_0 = H(Z_pad || msg || I2OSP(len_in_bytes, 2) || I2OSP(0, 1) || DST')
    let b_0 = base_state
        .clone()
        .update(&[0; R_IN_BYTES])
        .update(message)
        .update(&[0, (CHUNKLEN * 2) as u8, 0])
        .update(&dst_prime)
        .finalize();

    // b_1 = H(b_0 || I2OSP(1, 1) || DST')
    let b_1 = base_state
        .clone()
        .update(b_0.as_array())
        .update(&[1])
        .update(&dst_prime)
        .finalize();

    // b_2 = H(strxor(b_0, b_1) || I2OSP(2, 1) || DST')
    let mut xored = [0u8; CHUNKLEN];
    for ((out, lhs), rhs) in xored
        .iter_mut()
        .zip(b_0.as_array().iter())
        .zip(b_1.as_array().iter())
    {
        *out = *lhs ^ *rhs;
    }
    let mut last_state = base_state;
    let b_2 = last_state
        .update(&xored)
        .update(&[2])
        .update(&dst_prime)
        .finalize();

    // Reverse each block's bytes and reduce it into the field.
    for (block, slot) in [b_1, b_2].iter().zip(buf.iter_mut()) {
        let mut reversed = [0u8; CHUNKLEN];
        reversed.copy_from_slice(block.as_array());
        reversed.reverse();
        *slot = F::from_uniform_bytes(&reversed);
    }
}
/// Maps the field element `u` to a point of `C` using the Simplified
/// Shallue–van de Woestijne–Ulas (SSWU) method. This is the straight-line
/// procedure of RFC 9380 Appendix F.2.
///
/// `z` is the SSWU constant `Z`. `sqrt_ratio` asserts on a combination of
/// squareness results that can only occur when `Z` is a square, so `Z` must be
/// a non-square in `C::Base`.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `z` violates the requirement above (via `sqrt_ratio`).
/// - `tv4 = A * (Z or -tv2)` is zero, for example when `C::a()` is zero.
/// - `C::new_jacobian` rejects the computed coordinates.
pub(crate) fn simple_svdw_map_to_curve<C>(u: C::Base, z: C::Base) -> C
where
    C: CurveExt,
{
    let one = C::Base::ONE;
    let a = C::a();
    let b = C::b();

    // tv1 = Z * u^2, tv2 = tv1^2 + tv1
    let tv1 = z * u.square();
    let tv2 = tv1.square() + tv1;
    // tv3 = B * (tv2 + 1): numerator of the first x candidate.
    let tv3 = b * (tv2 + one);
    // tv4 = A * CMOV(Z, -tv2, tv2 != 0): shared denominator of both x candidates.
    // When tv2 == 0 it falls back to Z, which keeps tv4 non-zero for A != 0.
    let tv4 = a * C::Base::conditional_select(&z, &-tv2, !tv2.is_zero());

    // Express g(tv3 / tv4) as the ratio tv2 / tv6:
    //   tv2 = tv3^3 + A * tv3 * tv4^2 + B * tv4^3
    //   tv6 = tv4^3
    let tv6 = tv4.square();
    let tv2 = (tv3.square() + a * tv6) * tv3;
    let tv6 = tv6 * tv4;
    let tv2 = tv2 + b * tv6;

    // Numerator of the second x candidate.
    let x = tv1 * tv3;
    let (is_gx1_square, y1) = sqrt_ratio(&tv2, &tv6, &z);
    let y = tv1 * u * y1;

    // Keep the first candidate when g(x1) is square.
    let x = C::Base::conditional_select(&x, &tv3, is_gx1_square);
    let y = C::Base::conditional_select(&y, &y1, is_gx1_square);

    // Match sgn0(y) to sgn0(u).
    let e1 = u.is_odd().ct_eq(&y.is_odd());
    let y = C::Base::conditional_select(&-y, &y, e1);

    let x = x
        * tv4
            .invert()
            .expect("SSWU denominator tv4 must be non-zero (requires A != 0 and Z != 0)");
    C::new_jacobian(x, y, one).expect("SSWU map produced coordinates rejected by the curve")
}
/// Builds a hash-to-curve function for `C` based on the Simplified SWU map.
///
/// The returned closure does three things:
/// 1. Expands a message into two field elements with `hash_to_field`, using
///    method tag `"SSWU"`, the given `curve_id` and `domain_prefix`.
/// 2. Maps each element with `simple_svdw_map_to_curve` using the constant `z`.
/// 3. Returns the sum of the two points.
///
/// No cofactor clearing is applied. This assumes `C` has prime order; confirm
/// for each curve that uses it.
#[allow(clippy::type_complexity)]
pub(crate) fn simple_svdw_hash_to_curve<'a, C>(
    curve_id: &'static str,
    domain_prefix: &'a str,
    z: C::Base,
) -> Box<dyn Fn(&[u8]) -> C + 'a>
where
    C: CurveExt,
    C::Base: FromUniformBytes<64>,
{
    Box::new(move |message| {
        let mut field_elems = [C::Base::ZERO; 2];
        hash_to_field("SSWU", curve_id, domain_prefix, message, &mut field_elems);
        let [u0, u1] = field_elems;

        let first: C = simple_svdw_map_to_curve(u0, z);
        let second: C = simple_svdw_map_to_curve(u1, z);
        let sum = first + &second;

        debug_assert!(bool::from(sum.is_on_curve()));
        sum
    })
}
/// Maps the field element `u` to a point of `C` using the Shallue–van de
/// Woestijne method. This is the straight-line procedure of RFC 9380
/// Appendix F.1.
///
/// `c1..c4` are the constants produced by `svdw_precomputed_constants` for
/// the same `z`. Of the three candidate x-coordinates, the first one whose
/// `g(x) = x^3 + A*x + B` is a square (zero counts as a square) is chosen.
/// The y-coordinate takes the same sign (`is_odd`) as `u`.
///
/// # Panics
///
/// Panics if `g(x)` of the selected candidate has no square root, or if
/// `C::new_jacobian` rejects the point.
#[allow(clippy::too_many_arguments)]
pub(crate) fn svdw_map_to_curve<C>(
    u: C::Base,
    c1: C::Base,
    c2: C::Base,
    c3: C::Base,
    c4: C::Base,
    z: C::Base,
) -> C
where
    C: CurveExt,
    C::Base: Legendre,
{
    let one = C::Base::ONE;
    let (a, b) = (C::a(), C::b());
    // Right-hand side of the curve equation: g(x) = x^3 + A*x + B.
    let g = |x: C::Base| (x.square() + a) * x + b;

    let c1_u2 = u.square() * c1;
    let plus = one + c1_u2;
    let minus = one - c1_u2;
    // inv0: a zero product maps to zero instead of failing.
    let denom_inv = (minus * plus).invert().unwrap_or(C::Base::ZERO);
    let offset = u * minus * denom_inv * c3;

    // Candidates 1 and 2 sit symmetrically around c2.
    let x1 = c2 - offset;
    let x2 = c2 + offset;
    let use_x1 = !g(x1).ct_quadratic_non_residue();
    let use_x2 = !g(x2).ct_quadratic_non_residue() & !use_x1;

    // Candidate 3: c4 * (plus^2 * denom_inv)^2 + Z.
    let x3 = (plus.square() * denom_inv).square() * c4 + z;

    let x = C::Base::conditional_select(&x3, &x1, use_x1);
    let x = C::Base::conditional_select(&x, &x2, use_x2);

    let y = g(x).sqrt().unwrap();
    // Force sgn0(y) == sgn0(u).
    let same_sign = u.is_odd().ct_eq(&y.is_odd());
    let y = C::Base::conditional_select(&-y, &y, same_sign);

    C::new_jacobian(x, y, one).unwrap()
}
/// Square root of the ratio `num / div`, in the style of RFC 9380's
/// `sqrt_ratio`.
///
/// The ratio is computed as `num * inv0(div)`, so a zero `div` gives ratio 0.
/// The function then behaves as follows:
/// - If the ratio is a square, the returned value is its square root.
/// - Otherwise the returned value is a square root of `z * ratio`.
///
/// The returned flag is true when the ratio is a square, except when `div` is
/// zero and `num` is non-zero; in that case the flag is false.
///
/// # Panics
///
/// Panics if `num` and `div` are both non-zero and the ratio and `z * ratio`
/// are both squares or both non-squares. With a non-square `z` this cannot
/// happen.
fn sqrt_ratio<F: PrimeField>(num: &F, div: &F, z: &F) -> (Choice, F) {
    let ratio = *num * div.invert().unwrap_or(F::ZERO);
    let root = ratio.sqrt();
    let scaled_root = (ratio * z).sqrt();

    let num_zero = num.is_zero();
    let div_zero = div.is_zero();
    let has_root = root.is_some();
    let has_scaled_root = scaled_root.is_some();

    // Outside the zero cases, exactly one of the two ratios must be a square.
    let well_formed = num_zero | div_zero | (has_root ^ has_scaled_root);
    assert!(bool::from(well_formed));

    let flag = has_root & (num_zero | !div_zero);
    let value = CtOption::conditional_select(&scaled_root, &root, has_root).unwrap();
    (flag, value)
}
/// Builds a hash-to-curve function for `C` based on the Shallue–van de
/// Woestijne map.
///
/// The SvdW constants for `z` are computed once, when this function is
/// called, so an unsuitable `z` panics here rather than on first use. The
/// returned closure does three things:
/// 1. Expands a message into two field elements with `hash_to_field`, using
///    method tag `"SVDW"`.
/// 2. Maps each element with `svdw_map_to_curve`.
/// 3. Returns the sum of the two points.
///
/// No cofactor clearing is applied. This assumes `C` has prime order; confirm
/// for each curve that uses it.
#[allow(clippy::type_complexity)]
pub(crate) fn svdw_hash_to_curve<'a, C>(
    curve_id: &'static str,
    domain_prefix: &'a str,
    z: C::Base,
) -> Box<dyn Fn(&[u8]) -> C + 'a>
where
    C: CurveExt,
    C::Base: FromUniformBytes<64> + Legendre,
{
    let constants = svdw_precomputed_constants::<C>(z);
    Box::new(move |message| {
        let [c1, c2, c3, c4] = constants;

        let mut field_elems = [C::Base::ZERO; 2];
        hash_to_field("SVDW", curve_id, domain_prefix, message, &mut field_elems);
        let [u0, u1] = field_elems;

        let first: C = svdw_map_to_curve(u0, c1, c2, c3, c4, z);
        let second: C = svdw_map_to_curve(u1, c1, c2, c3, c4, z);
        let sum = first + &second;

        debug_assert!(bool::from(sum.is_on_curve()));
        sum
    })
}
/// Precomputes the Shallue–van de Woestijne constants `[c1, c2, c3, c4]` for
/// curve `C` and constant `z`, with `g(x) = x^3 + A*x + B`:
///
/// - `c1 = g(Z)`
/// - `c2 = -Z / 2`
/// - `c3 = sqrt(-g(Z) * (3*Z^2 + 4*A))`, taking the root whose `is_odd` is false
/// - `c4 = -4 * g(Z) / (3*Z^2 + 4*A)`
///
/// # Panics
///
/// Panics if `-g(Z) * (3*Z^2 + 4*A)` has no square root (checked first), or if
/// `3*Z^2 + 4*A` is zero.
pub(crate) fn svdw_precomputed_constants<C: CurveExt>(z: C::Base) -> [C::Base; 4] {
    let (a, b) = (C::a(), C::b());
    let one = C::Base::ONE;
    let two = one + one;
    let three = two + one;
    let four = two + two;

    let z_sq = z.square();
    let g_z = (z_sq + a) * z + b;
    let denom = three * z_sq + four * a;

    let neg_half_z = -z * C::Base::TWO_INV;

    // Choose the even square root, i.e. sgn0(c3) == 0.
    let root = (-g_z * denom).sqrt().unwrap();
    let even_root = C::Base::conditional_select(&root, &-root, root.is_odd());

    let c4 = -four * g_z * denom.invert().unwrap();

    [g_z, neg_half_z, even_root, c4]
}