ecdsa/
recovery.rs

1//! Public key recovery support.
2
3use crate::{Error, Result};
4
5#[cfg(feature = "signing")]
6use {
7    crate::{hazmat::SignPrimitive, SigningKey},
8    elliptic_curve::subtle::CtOption,
9    signature::{hazmat::PrehashSigner, DigestSigner, Signer},
10};
11
12#[cfg(feature = "verifying")]
13use {
14    crate::{hazmat::VerifyPrimitive, VerifyingKey},
15    elliptic_curve::{
16        bigint::CheckedAdd,
17        ops::{LinearCombination, Reduce},
18        point::DecompressPoint,
19        sec1::{self, FromEncodedPoint, ToEncodedPoint},
20        AffinePoint, FieldBytesEncoding, FieldBytesSize, Group, PrimeField, ProjectivePoint,
21    },
22    signature::hazmat::PrehashVerifier,
23};
24
25#[cfg(any(feature = "signing", feature = "verifying"))]
26use {
27    crate::{
28        hazmat::{bits2field, DigestPrimitive},
29        Signature, SignatureSize,
30    },
31    elliptic_curve::{
32        generic_array::ArrayLength, ops::Invert, CurveArithmetic, PrimeCurve, Scalar,
33    },
34    signature::digest::Digest,
35};
36
/// Recovery IDs, a.k.a. "recid".
///
/// This is an integer value `0`, `1`, `2`, or `3` included along with a
/// signature which is used during the recovery process to select the correct
/// public key from the signature.
///
/// It consists of two bits of information:
///
/// - low bit (`0b01`, set for IDs `1` and `3`): was the y-coordinate of the
///   affine point resulting from the fixed-base multiplication 𝑘×𝑮 odd? This
///   part of the algorithm functions similar to point decompression.
/// - hi bit (`0b10`, set for IDs `2` and `3`): did the affine x-coordinate of
///   𝑘×𝑮 overflow the order of the scalar field, requiring a reduction when
///   computing `r`?
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)]
pub struct RecoveryId(u8);

impl RecoveryId {
    /// Maximum supported value for the recovery ID (inclusive).
    pub const MAX: u8 = 3;

    /// Create a new [`RecoveryId`] from the following 1-bit arguments:
    ///
    /// - `is_y_odd`: is the affine y-coordinate of 𝑘×𝑮 odd? (stored in bit 0)
    /// - `is_x_reduced`: did the affine x-coordinate of 𝑘×𝑮 overflow the curve
    ///   order? (stored in bit 1)
    pub const fn new(is_y_odd: bool, is_x_reduced: bool) -> Self {
        Self(((is_x_reduced as u8) << 1) | (is_y_odd as u8))
    }

    /// Did the affine x-coordinate of 𝑘×𝑮 overflow the curve order?
    ///
    /// Reports whether bit 1 of the ID is set.
    pub const fn is_x_reduced(self) -> bool {
        (self.0 & 0b10) != 0
    }

    /// Is the affine y-coordinate of 𝑘×𝑮 odd?
    ///
    /// Reports whether bit 0 of the ID is set.
    pub const fn is_y_odd(self) -> bool {
        (self.0 & 1) != 0
    }

    /// Convert a `u8` into a [`RecoveryId`].
    ///
    /// Returns `None` when `byte` is greater than [`RecoveryId::MAX`].
    pub const fn from_byte(byte: u8) -> Option<Self> {
        if byte > Self::MAX {
            return None;
        }

        Some(Self(byte))
    }

    /// Convert this [`RecoveryId`] into a `u8`.
    pub const fn to_byte(self) -> u8 {
        self.0
    }
}
89
#[cfg(feature = "verifying")]
impl RecoveryId {
    /// Hash `msg` with the curve's default digest, then use trial recovery to
    /// find the recovery ID under which `signature` recovers `verifying_key`.
    ///
    /// Returns an error if no recovery ID yields the given key.
    pub fn trial_recovery_from_msg<C>(
        verifying_key: &VerifyingKey<C>,
        msg: &[u8],
        signature: &Signature<C>,
    ) -> Result<Self>
    where
        C: DigestPrimitive + PrimeCurve + CurveArithmetic,
        AffinePoint<C>:
            DecompressPoint<C> + FromEncodedPoint<C> + ToEncodedPoint<C> + VerifyPrimitive<C>,
        FieldBytesSize<C>: sec1::ModulusSize,
        SignatureSize<C>: ArrayLength<u8>,
    {
        let msg_digest = C::Digest::new_with_prefix(msg);
        Self::trial_recovery_from_digest(verifying_key, msg_digest, signature)
    }

    /// Finalize `digest`, then use trial recovery to find the recovery ID
    /// under which `signature` recovers `verifying_key`.
    ///
    /// Returns an error if no recovery ID yields the given key.
    pub fn trial_recovery_from_digest<C, D>(
        verifying_key: &VerifyingKey<C>,
        digest: D,
        signature: &Signature<C>,
    ) -> Result<Self>
    where
        C: PrimeCurve + CurveArithmetic,
        D: Digest,
        AffinePoint<C>:
            DecompressPoint<C> + FromEncodedPoint<C> + ToEncodedPoint<C> + VerifyPrimitive<C>,
        FieldBytesSize<C>: sec1::ModulusSize,
        SignatureSize<C>: ArrayLength<u8>,
    {
        let prehash = digest.finalize();
        Self::trial_recovery_from_prehash(verifying_key, &prehash, signature)
    }

    /// Given a public key, a prehashed message, and a signature, try every
    /// recovery ID from `0` through [`RecoveryId::MAX`] in ascending order and
    /// return the first one whose recovered key equals `verifying_key`.
    ///
    /// Returns an error if no recovery ID yields the given key.
    pub fn trial_recovery_from_prehash<C>(
        verifying_key: &VerifyingKey<C>,
        prehash: &[u8],
        signature: &Signature<C>,
    ) -> Result<Self>
    where
        C: PrimeCurve + CurveArithmetic,
        AffinePoint<C>:
            DecompressPoint<C> + FromEncodedPoint<C> + ToEncodedPoint<C> + VerifyPrimitive<C>,
        FieldBytesSize<C>: sec1::ModulusSize,
        SignatureSize<C>: ArrayLength<u8>,
    {
        (0..=Self::MAX)
            .map(RecoveryId)
            .find(|&candidate| {
                // A candidate that fails to recover any key is simply skipped.
                VerifyingKey::recover_from_prehash(prehash, signature, candidate)
                    .map_or(false, |recovered| verifying_key == &recovered)
            })
            .ok_or_else(Error::new)
    }
}
157
158impl TryFrom<u8> for RecoveryId {
159    type Error = Error;
160
161    fn try_from(byte: u8) -> Result<Self> {
162        Self::from_byte(byte).ok_or_else(Error::new)
163    }
164}
165
166impl From<RecoveryId> for u8 {
167    fn from(id: RecoveryId) -> u8 {
168        id.0
169    }
170}
171
#[cfg(feature = "signing")]
impl<C> SigningKey<C>
where
    C: PrimeCurve + CurveArithmetic + DigestPrimitive,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>> + SignPrimitive<C>,
    SignatureSize<C>: ArrayLength<u8>,
{
    /// Sign the given message prehash, returning a signature and recovery ID.
    ///
    /// The prehash is converted to field bytes with `bits2field` and signed
    /// via `try_sign_prehashed_rfc6979`, using the curve's default digest and
    /// no additional data. Returns an error if either step fails or if the
    /// signing primitive does not report a recovery ID.
    pub fn sign_prehash_recoverable(&self, prehash: &[u8]) -> Result<(Signature<C>, RecoveryId)> {
        let field_bytes = bits2field::<C>(prehash)?;
        let secret_scalar = self.as_nonzero_scalar();
        let (signature, maybe_recovery_id) =
            secret_scalar.try_sign_prehashed_rfc6979::<C::Digest>(&field_bytes, &[])?;

        match maybe_recovery_id {
            Some(recovery_id) => Ok((signature, recovery_id)),
            None => Err(Error::new()),
        }
    }

    /// Sign the given message digest, returning a signature and recovery ID.
    ///
    /// The digest is finalized and its output is signed as a prehash.
    pub fn sign_digest_recoverable<D>(&self, msg_digest: D) -> Result<(Signature<C>, RecoveryId)>
    where
        D: Digest,
    {
        let prehash = msg_digest.finalize();
        self.sign_prehash_recoverable(&prehash)
    }

    /// Sign the given message, hashing it with the curve's default digest
    /// function, and returning a signature and recovery ID.
    pub fn sign_recoverable(&self, msg: &[u8]) -> Result<(Signature<C>, RecoveryId)> {
        let msg_digest = C::Digest::new_with_prefix(msg);
        self.sign_digest_recoverable(msg_digest)
    }
}
203
#[cfg(feature = "signing")]
impl<C, D> DigestSigner<D, (Signature<C>, RecoveryId)> for SigningKey<C>
where
    C: PrimeCurve + CurveArithmetic + DigestPrimitive,
    D: Digest,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>> + SignPrimitive<C>,
    SignatureSize<C>: ArrayLength<u8>,
{
    /// Trait entry point that forwards to the inherent
    /// `sign_digest_recoverable` method.
    fn try_sign_digest(&self, digest: D) -> Result<(Signature<C>, RecoveryId)> {
        Self::sign_digest_recoverable(self, digest)
    }
}
216
#[cfg(feature = "signing")]
impl<C> PrehashSigner<(Signature<C>, RecoveryId)> for SigningKey<C>
where
    C: PrimeCurve + CurveArithmetic + DigestPrimitive,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>> + SignPrimitive<C>,
    SignatureSize<C>: ArrayLength<u8>,
{
    /// Trait entry point that forwards to the inherent
    /// `sign_prehash_recoverable` method.
    fn sign_prehash(&self, prehashed_msg: &[u8]) -> Result<(Signature<C>, RecoveryId)> {
        Self::sign_prehash_recoverable(self, prehashed_msg)
    }
}
228
#[cfg(feature = "signing")]
impl<C> Signer<(Signature<C>, RecoveryId)> for SigningKey<C>
where
    C: PrimeCurve + CurveArithmetic + DigestPrimitive,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>> + SignPrimitive<C>,
    SignatureSize<C>: ArrayLength<u8>,
{
    /// Trait entry point that forwards to the inherent `sign_recoverable`
    /// method.
    fn try_sign(&self, message: &[u8]) -> Result<(Signature<C>, RecoveryId)> {
        Self::sign_recoverable(self, message)
    }
}
240
#[cfg(feature = "verifying")]
impl<C> VerifyingKey<C>
where
    C: PrimeCurve + CurveArithmetic,
    AffinePoint<C>:
        DecompressPoint<C> + FromEncodedPoint<C> + ToEncodedPoint<C> + VerifyPrimitive<C>,
    FieldBytesSize<C>: sec1::ModulusSize,
    SignatureSize<C>: ArrayLength<u8>,
{
    /// Recover a [`VerifyingKey`] from the given message, signature, and
    /// [`RecoveryId`].
    ///
    /// The message is first hashed using this curve's [`DigestPrimitive`].
    pub fn recover_from_msg(
        msg: &[u8],
        signature: &Signature<C>,
        recovery_id: RecoveryId,
    ) -> Result<Self>
    where
        C: DigestPrimitive,
    {
        let msg_digest = C::Digest::new_with_prefix(msg);
        Self::recover_from_digest(msg_digest, signature, recovery_id)
    }

    /// Recover a [`VerifyingKey`] from the given message [`Digest`],
    /// signature, and [`RecoveryId`].
    ///
    /// The digest is finalized and its output is used as the prehash.
    pub fn recover_from_digest<D>(
        msg_digest: D,
        signature: &Signature<C>,
        recovery_id: RecoveryId,
    ) -> Result<Self>
    where
        D: Digest,
    {
        let prehash = msg_digest.finalize();
        Self::recover_from_prehash(&prehash, signature, recovery_id)
    }

    /// Recover a [`VerifyingKey`] from the given `prehash` of a message, the
    /// signature over that prehashed message, and a [`RecoveryId`].
    ///
    /// Returns an error if the prehash is rejected by `bits2field`, if
    /// restoring the unreduced x-coordinate overflows, if the x-coordinate
    /// does not decompress to a curve point, if the recovered point is not
    /// accepted as a verifying key, or if the signature does not verify
    /// under the recovered key.
    pub fn recover_from_prehash(
        prehash: &[u8],
        signature: &Signature<C>,
        recovery_id: RecoveryId,
    ) -> Result<Self> {
        let (r, s) = signature.split_scalars();
        let z = <Scalar<C> as Reduce<C::Uint>>::reduce_bytes(&bits2field::<C>(prehash)?);

        // Rebuild the x-coordinate of the ephemeral point. When the recovery
        // ID says it was reduced, add the curve order back; that addition
        // must not overflow, otherwise the ID is inconsistent with `r`.
        let r_repr = r.to_repr();
        let x_bytes = if recovery_id.is_x_reduced() {
            Option::<C::Uint>::from(C::Uint::decode_field_bytes(&r_repr).checked_add(&C::ORDER))
                .ok_or_else(Error::new)?
                .encode_field_bytes()
        } else {
            r_repr
        };

        // Decompress the ephemeral point using the y-parity bit of the ID.
        let y_is_odd = u8::from(recovery_id.is_y_odd());
        let r_affine =
            Option::<AffinePoint<C>>::from(AffinePoint::<C>::decompress(&x_bytes, y_is_odd.into()))
                .ok_or_else(Error::new)?;
        let r_point = ProjectivePoint::<C>::from(r_affine);

        // Candidate public key: u1 * G + u2 * R, with u1 = -(z / r), u2 = s / r.
        let r_inverse = *r.invert();
        let u1 = -(r_inverse * z);
        let u2 = r_inverse * *s;
        let recovered_point = ProjectivePoint::<C>::lincomb(
            &ProjectivePoint::<C>::generator(),
            &u1,
            &r_point,
            &u2,
        );
        let recovered_key = Self::from_affine(recovered_point.into())?;

        // Ensure signature verifies with the recovered key
        recovered_key.verify_prehash(prehash, signature)?;

        Ok(recovered_key)
    }
}
318
#[cfg(test)]
mod tests {
    use super::RecoveryId;

    /// `new` packs `is_x_reduced` into bit 1 and `is_y_odd` into bit 0.
    #[test]
    fn new() {
        let cases = [
            (false, false, 0u8),
            (true, false, 1),
            (false, true, 2),
            (true, true, 3),
        ];

        for (is_y_odd, is_x_reduced, expected) in cases {
            assert_eq!(RecoveryId::new(is_y_odd, is_x_reduced).to_byte(), expected);
        }
    }

    /// Bytes `0..=3` round-trip; every other byte value is rejected.
    #[test]
    fn try_from() {
        for n in 0u8..=u8::MAX {
            let converted = RecoveryId::try_from(n);

            if n <= 3 {
                assert_eq!(converted.unwrap().to_byte(), n);
            } else {
                assert!(converted.is_err());
            }
        }
    }

    /// The x-reduced flag is set exactly for IDs 2 and 3.
    #[test]
    fn is_x_reduced() {
        for (byte, expected) in [(0u8, false), (1, false), (2, true), (3, true)] {
            assert_eq!(RecoveryId::try_from(byte).unwrap().is_x_reduced(), expected);
        }
    }

    /// The y-odd flag is set exactly for IDs 1 and 3.
    #[test]
    fn is_y_odd() {
        for (byte, expected) in [(0u8, false), (1, true), (2, false), (3, true)] {
            assert_eq!(RecoveryId::try_from(byte).unwrap().is_y_odd(), expected);
        }
    }
}