openvm_ecc_guest/
weierstrass.rs

1use alloc::vec::Vec;
2use core::ops::{AddAssign, Mul};
3
4use openvm_algebra_guest::{Field, IntMod};
5
6use super::group::Group;
7
8/// Short Weierstrass curve affine point.
9pub trait WeierstrassPoint: Sized {
10    /// The `a` coefficient in the Weierstrass curve equation `y^2 = x^3 + a x + b`.
11    const CURVE_A: Self::Coordinate;
12    /// The `b` coefficient in the Weierstrass curve equation `y^2 = x^3 + a x + b`.
13    const CURVE_B: Self::Coordinate;
14    const IDENTITY: Self;
15
16    type Coordinate: Field;
17
18    /// The concatenated `x, y` coordinates of the affine point, where
19    /// coordinates are in little endian.
20    ///
21    /// **Warning**: The memory layout of `Self` is expected to pack
22    /// `x` and `y` contigously with no unallocated space in between.
23    fn as_le_bytes(&self) -> &[u8];
24
25    /// Raw constructor without asserting point is on the curve.
26    fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self;
27    fn into_coords(self) -> (Self::Coordinate, Self::Coordinate);
28    fn x(&self) -> &Self::Coordinate;
29    fn y(&self) -> &Self::Coordinate;
30    fn x_mut(&mut self) -> &mut Self::Coordinate;
31    fn y_mut(&mut self) -> &mut Self::Coordinate;
32
33    /// Hazmat: Assumes self != +- p2 and self != identity and p2 != identity.
34    fn add_ne_nonidentity(&self, p2: &Self) -> Self;
35    /// Hazmat: Assumes self != +- p2 and self != identity and p2 != identity.
36    fn add_ne_assign_nonidentity(&mut self, p2: &Self);
37    /// Hazmat: Assumes self != +- p2 and self != identity and p2 != identity.
38    fn sub_ne_nonidentity(&self, p2: &Self) -> Self;
39    /// Hazmat: Assumes self != +- p2 and self != identity and p2 != identity.
40    fn sub_ne_assign_nonidentity(&mut self, p2: &Self);
41    /// Hazmat: Assumes self != identity and 2 * self != identity.
42    fn double_nonidentity(&self) -> Self;
43    /// Hazmat: Assumes self != identity and 2 * self != identity.
44    fn double_assign_nonidentity(&mut self);
45
46    fn from_xy(x: Self::Coordinate, y: Self::Coordinate) -> Option<Self>
47    where
48        for<'a> &'a Self::Coordinate: Mul<&'a Self::Coordinate, Output = Self::Coordinate>,
49    {
50        if x == Self::Coordinate::ZERO && y == Self::Coordinate::ZERO {
51            Some(Self::IDENTITY)
52        } else {
53            Self::from_xy_nonidentity(x, y)
54        }
55    }
56
57    fn from_xy_nonidentity(x: Self::Coordinate, y: Self::Coordinate) -> Option<Self>
58    where
59        for<'a> &'a Self::Coordinate: Mul<&'a Self::Coordinate, Output = Self::Coordinate>,
60    {
61        let lhs = &y * &y;
62        let rhs = &x * &x * &x + &Self::CURVE_A * &x + &Self::CURVE_B;
63        if lhs != rhs {
64            return None;
65        }
66        Some(Self::from_xy_unchecked(x, y))
67    }
68}
69
/// Hint produced by [`FromCompressed::hint_decompress`].
///
/// - If `possible` is `true`, then `sqrt` is the decompressed `y`-coordinate.
/// - If `possible` is `false`, then `sqrt` is such that `sqrt^2 = rhs * non_qr`, where `rhs`
///   is the right-hand side of the curve equation and `non_qr` is the curve's quadratic
///   non-residue.
///
/// This is only a hint: none of the above is guaranteed, and consumers must verify
/// whichever relation `possible` claims.
#[derive(Debug, Clone)]
pub struct DecompressionHint<T> {
    /// Whether the hinted `x`-coordinate is claimed to decompress to a curve point.
    pub possible: bool,
    /// The claimed `y`-coordinate if `possible`, otherwise a claimed square root of
    /// `rhs * non_qr`.
    pub sqrt: T,
}
78
79pub trait FromCompressed<Coordinate> {
80    /// Given `x`-coordinate,
81    ///
82    /// Decompresses a point from its x-coordinate and a recovery identifier which indicates
83    /// the parity of the y-coordinate. Given the x-coordinate, this function attempts to find the
84    /// corresponding y-coordinate that satisfies the elliptic curve equation. If successful, it
85    /// returns the point as an instance of Self. If the point cannot be decompressed, it returns None.
86    fn decompress(x: Coordinate, rec_id: &u8) -> Option<Self>
87    where
88        Self: core::marker::Sized;
89
90    /// If it exists, hints the unique `y` coordinate that is less than `Coordinate::MODULUS`
91    /// such that `(x, y)` is a point on the curve and `y` has parity equal to `rec_id`.
92    /// If such `y` does not exist, hints a coordinate `sqrt` such that `sqrt^2 = rhs * non_qr`
93    /// where `rhs` is the rhs of the curve equation and `non_qr` is the non-quadratic residue
94    /// for this curve that was initialized in the setup function.
95    ///
96    /// This is only a hint, and the returned value does not guarantee any of the above properties.
97    /// They must be checked separately. Normal users should use `decompress` directly.
98    ///
99    /// Returns None if the `DecompressionHint::possible` flag in the hint stream is not a boolean.
100    fn hint_decompress(x: &Coordinate, rec_id: &u8) -> Option<DecompressionHint<Coordinate>>;
101}
102
/// Bridge between openvm point/scalar types and external curve types (for example, types
/// implementing `CurveArithmetic`).
///
/// Implement this for an external curve, naming the openvm point and scalar types that
/// correspond to it.
pub trait IntrinsicCurve {
    /// Scalar type used for the coefficients of [`msm`](Self::msm).
    type Scalar: Clone;
    /// Point type used for the bases of, and the result of, [`msm`](Self::msm).
    type Point: Clone;

    /// Multi-scalar multiplication: computes `sum_i coeffs[i] * bases[i]`.
    ///
    /// Implementations are free to specialize on properties of the curve, such as the
    /// curve order being prime.
    fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point;
}
114
115// MSM using preprocessed table (windowed method)
116// Reference: modified from https://github.com/arkworks-rs/algebra/blob/master/ec/src/scalar_mul/mod.rs
117//
118// We specialize to Weierstrass curves and further make optimizations for when the curve order is prime.
119
120/// Cached precomputations of scalar multiples of several base points.
121/// - `window_bits` is the window size used for the precomputation
122/// - `max_scalar_bits` is the maximum size of the scalars that will be multiplied
123/// - `table` is the precomputed table
124pub struct CachedMulTable<'a, C: IntrinsicCurve> {
125    /// Window bits. Must be > 0.
126    /// For alignment, we currently require this to divide 8 (bits in a byte).
127    pub window_bits: usize,
128    pub bases: &'a [C::Point],
129    /// `table[i][j] = (j + 2) * bases[i]` for `j + 2 < 2 ** window_bits`
130    table: Vec<Vec<C::Point>>,
131    /// Needed to return reference to the identity point.
132    identity: C::Point,
133}
134
135impl<'a, C: IntrinsicCurve> CachedMulTable<'a, C>
136where
137    C::Point: WeierstrassPoint + Group,
138    C::Scalar: IntMod,
139{
140    /// Constructor when each element of `bases` has prime torsion or is identity.
141    ///
142    /// Assumes that `window_bits` is less than (number of bits - 1) of the order of
143    /// subgroup generated by each non-identity `base`.
144    pub fn new_with_prime_order(bases: &'a [C::Point], window_bits: usize) -> Self {
145        assert!(window_bits > 0);
146        let window_size = 1 << window_bits;
147        let table = bases
148            .iter()
149            .map(|base| {
150                if base.is_identity() {
151                    vec![<C::Point as Group>::IDENTITY; window_size - 2]
152                } else {
153                    let mut multiples = Vec::with_capacity(window_size - 2);
154                    for _ in 0..window_size - 2 {
155                        // Because the order of `base` is prime, we are guaranteed that
156                        // j * base != identity,
157                        // j * base != +- base for j > 1,
158                        // j * base + base != identity
159                        let multiple = multiples
160                            .last()
161                            .map(|last| WeierstrassPoint::add_ne_nonidentity(last, base))
162                            .unwrap_or_else(|| base.double_nonidentity());
163                        multiples.push(multiple);
164                    }
165                    multiples
166                }
167            })
168            .collect();
169
170        Self {
171            window_bits,
172            bases,
173            table,
174            identity: <C::Point as Group>::IDENTITY,
175        }
176    }
177
178    fn get_multiple(&self, base_idx: usize, scalar: usize) -> &C::Point {
179        if scalar == 0 {
180            &self.identity
181        } else if scalar == 1 {
182            unsafe { self.bases.get_unchecked(base_idx) }
183        } else {
184            unsafe { self.table.get_unchecked(base_idx).get_unchecked(scalar - 2) }
185        }
186    }
187
188    /// Computes `sum scalars[i] * bases[i]`.
189    ///
190    /// For implementation simplicity, currently only implemented when
191    /// `window_bits` divides 8 (number of bits in a byte).
192    pub fn windowed_mul(&self, scalars: &[C::Scalar]) -> C::Point {
193        assert_eq!(8 % self.window_bits, 0);
194        assert_eq!(scalars.len(), self.bases.len());
195        let windows_per_byte = 8 / self.window_bits;
196
197        let num_windows = C::Scalar::NUM_LIMBS * windows_per_byte;
198        let mask = (1u8 << self.window_bits) - 1;
199
200        // The current byte index (little endian) at the current step of the
201        // windowed method, across all scalars.
202        let mut limb_idx = C::Scalar::NUM_LIMBS;
203        // The current bit (little endian) within the current byte of the windowed
204        // method. The window will look at bits `bit_idx..bit_idx + window_bits`.
205        // bit_idx will always be in range [0, 8)
206        let mut bit_idx = 0;
207
208        let mut res = <C::Point as Group>::IDENTITY;
209        for outer in 0..num_windows {
210            if bit_idx == 0 {
211                limb_idx -= 1;
212                bit_idx = 8 - self.window_bits;
213            } else {
214                bit_idx -= self.window_bits;
215            }
216
217            if outer != 0 {
218                for _ in 0..self.window_bits {
219                    // Note: this handles identity
220                    res.double_assign();
221                }
222            }
223            for (base_idx, scalar) in scalars.iter().enumerate() {
224                let scalar = (scalar.as_le_bytes()[limb_idx] >> bit_idx) & mask;
225                let summand = self.get_multiple(base_idx, scalar as usize);
226                // handles identity
227                res.add_assign(summand);
228            }
229        }
230        res
231    }
232}
233
/// Macro to generate a newtype wrapper for [AffinePoint](crate::AffinePoint)
/// that implements elliptic curve operations by using the underlying field operations according to the
/// [formulas](https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html) for short Weierstrass curves.
///
/// Invocation: `impl_sw_affine!($struct_name, $field, $three, $b)`, where
/// - `$struct_name` is the name of the generated newtype,
/// - `$field` is the coordinate field type,
/// - `$three` is an expression of type `$field` equal to `3`,
/// - `$b` is a constant expression for the curve coefficient `b`.
///
/// The curve coefficient `a` is assumed to be `0`. The identity is encoded as `(0, 0)`, which
/// presumes `(0, 0)` is not itself on the curve (i.e. `b != 0`).
///
/// Besides `WeierstrassPoint`, the macro implements `Neg` (by value and by reference) and
/// conversions to and from `AffinePoint<$field>`.
///
/// The following imports are required:
/// ```rust
/// use core::ops::AddAssign;
///
/// use openvm_algebra_guest::{DivUnsafe, Field};
/// use openvm_ecc_guest::{AffinePoint, Group, weierstrass::WeierstrassPoint};
/// ```
#[macro_export]
macro_rules! impl_sw_affine {
    // Assumes `a = 0` in curve equation. `$three` should be a constant expression for `3` of type `$field`.
    ($struct_name:ident, $field:ty, $three:expr, $b:expr) => {
        /// A newtype wrapper for [AffinePoint] that implements elliptic curve operations
        /// by using the underlying field operations according to the [formulas](https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html) for short Weierstrass curves.
        #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
        #[repr(transparent)]
        pub struct $struct_name(AffinePoint<$field>);

        impl $struct_name {
            /// Wraps the affine point `(x, y)` without checking that it lies on the curve.
            pub const fn new(x: $field, y: $field) -> Self {
                Self(AffinePoint::new(x, y))
            }
        }

        impl WeierstrassPoint for $struct_name {
            const CURVE_A: $field = <$field>::ZERO;
            const CURVE_B: $field = $b;
            const IDENTITY: Self = Self(AffinePoint::new(<$field>::ZERO, <$field>::ZERO));

            type Coordinate = $field;

            /// Views the point as the raw bytes of its `x` and `y` coordinates.
            fn as_le_bytes(&self) -> &[u8] {
                // SAFETY: the slice covers exactly the `size_of::<Self>()` bytes of `*self` and
                // its lifetime is tied to `&self`, so it stays in bounds and valid. Reading the
                // bytes as little-endian `x || y` assumes `$field` stores its value in little
                // endian with no padding bytes (see the warning on the trait method).
                unsafe {
                    &*core::ptr::slice_from_raw_parts(
                        self as *const Self as *const u8,
                        core::mem::size_of::<Self>(),
                    )
                }
            }
            fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self {
                Self(AffinePoint::new(x, y))
            }
            fn into_coords(self) -> (Self::Coordinate, Self::Coordinate) {
                (self.0.x, self.0.y)
            }
            fn x(&self) -> &Self::Coordinate {
                &self.0.x
            }
            fn y(&self) -> &Self::Coordinate {
                &self.0.y
            }
            fn x_mut(&mut self) -> &mut Self::Coordinate {
                &mut self.0.x
            }
            fn y_mut(&mut self) -> &mut Self::Coordinate {
                &mut self.0.y
            }

            fn double_nonidentity(&self) -> Self {
                use ::openvm_algebra_guest::DivUnsafe;
                // lambda = (3*x1^2+a)/(2*y1)
                // assume a = 0; `$three` is the caller-supplied constant 3 in `$field`
                let lambda = (&$three * self.x() * self.x()).div_unsafe(self.y() + self.y());
                // x3 = lambda^2-x1-x1
                let x3 = &lambda * &lambda - self.x() - self.x();
                // y3 = lambda * (x1-x3) - y1
                let y3 = lambda * (self.x() - &x3) - self.y();
                Self(AffinePoint::new(x3, y3))
            }

            fn double_assign_nonidentity(&mut self) {
                *self = self.double_nonidentity();
            }

            fn add_ne_nonidentity(&self, p2: &Self) -> Self {
                use ::openvm_algebra_guest::DivUnsafe;
                // lambda = (y2-y1)/(x2-x1)
                // x3 = lambda^2-x1-x2
                // y3 = lambda*(x1-x3)-y1
                let lambda = (p2.y() - self.y()).div_unsafe(p2.x() - self.x());
                let x3 = &lambda * &lambda - self.x() - p2.x();
                let y3 = lambda * (self.x() - &x3) - self.y();
                Self(AffinePoint::new(x3, y3))
            }

            fn add_ne_assign_nonidentity(&mut self, p2: &Self) {
                *self = self.add_ne_nonidentity(p2);
            }

            fn sub_ne_nonidentity(&self, p2: &Self) -> Self {
                use ::openvm_algebra_guest::DivUnsafe;
                // Computes self + (-p2), where -p2 = (x2, -y2):
                // lambda = (y2+y1)/(x1-x2)
                // x3 = lambda^2-x1-x2
                // y3 = lambda*(x1-x3)-y1
                let lambda = (p2.y() + self.y()).div_unsafe(self.x() - p2.x());
                let x3 = &lambda * &lambda - self.x() - p2.x();
                let y3 = lambda * (self.x() - &x3) - self.y();
                Self(AffinePoint::new(x3, y3))
            }

            fn sub_ne_assign_nonidentity(&mut self, p2: &Self) {
                *self = self.sub_ne_nonidentity(p2);
            }
        }

        impl core::ops::Neg for $struct_name {
            type Output = Self;

            /// Negates the point by negating its `y`-coordinate; `(0, 0)` maps to itself.
            fn neg(mut self) -> Self::Output {
                self.0.y.neg_assign();
                self
            }
        }

        impl core::ops::Neg for &$struct_name {
            type Output = $struct_name;

            fn neg(self) -> Self::Output {
                // Fully qualified so the caller does not need `core::ops::Neg` in scope.
                core::ops::Neg::neg(self.clone())
            }
        }

        impl From<$struct_name> for AffinePoint<$field> {
            fn from(value: $struct_name) -> Self {
                value.0
            }
        }

        impl From<AffinePoint<$field>> for $struct_name {
            fn from(value: AffinePoint<$field>) -> Self {
                Self(value)
            }
        }
    };
}
373
/// Implements `Group` and the `Add`/`Sub`/`AddAssign`/`SubAssign` operators (by value and by
/// reference) for `$struct_name`.
///
/// Each operator first handles the identity on either side and the equal-`x` case
/// (`p2 == self` or `p2 == -self`), then dispatches to the hazmat `*_nonidentity` methods of
/// `WeierstrassPoint`.
///
/// Requirements at the invocation site:
/// - `$struct_name` implements `WeierstrassPoint` with coordinate type `$field`;
/// - `Neg` is implemented for `&$struct_name`;
/// - `Group` and `WeierstrassPoint` are in scope.
#[macro_export]
macro_rules! impl_sw_group_ops {
    ($struct_name:ident, $field:ty) => {
        impl Group for $struct_name {
            type SelfRef<'a> = &'a Self;

            const IDENTITY: Self = <Self as WeierstrassPoint>::IDENTITY;

            /// Returns `2 * self`.
            ///
            /// NOTE(review): only the identity is special-cased here; a point with
            /// `2 * self == identity` is still passed to `double_nonidentity`, whose
            /// documented precondition excludes it.
            fn double(&self) -> Self {
                if !self.is_identity() {
                    return self.double_nonidentity();
                }
                self.clone()
            }

            /// In-place form of `double`; leaves the identity unchanged.
            fn double_assign(&mut self) {
                if self.is_identity() {
                    return;
                }
                self.double_assign_nonidentity();
            }
        }

        impl core::ops::Add<&$struct_name> for $struct_name {
            type Output = Self;

            fn add(self, p2: &$struct_name) -> Self::Output {
                let mut sum = self;
                core::ops::AddAssign::add_assign(&mut sum, p2);
                sum
            }
        }

        impl core::ops::Add for $struct_name {
            type Output = Self;

            fn add(self, rhs: Self) -> Self::Output {
                core::ops::Add::add(self, &rhs)
            }
        }

        impl core::ops::Add<&$struct_name> for &$struct_name {
            type Output = $struct_name;

            fn add(self, p2: &$struct_name) -> Self::Output {
                if self.is_identity() {
                    return p2.clone();
                }
                if p2.is_identity() {
                    return self.clone();
                }
                if self.x() == p2.x() {
                    // Same `x`: either `p2 == -self` (sum is the identity) or `p2 == self`.
                    return if self.y() + p2.y() == <$field as openvm_algebra_guest::Field>::ZERO {
                        <$struct_name as WeierstrassPoint>::IDENTITY
                    } else {
                        self.double_nonidentity()
                    };
                }
                self.add_ne_nonidentity(p2)
            }
        }

        impl core::ops::AddAssign<&$struct_name> for $struct_name {
            fn add_assign(&mut self, p2: &$struct_name) {
                if self.is_identity() {
                    *self = p2.clone();
                    return;
                }
                if p2.is_identity() {
                    return;
                }
                if self.x() == p2.x() {
                    // Same `x`: either `p2 == -self` (sum is the identity) or `p2 == self`.
                    if self.y() + p2.y() == <$field as openvm_algebra_guest::Field>::ZERO {
                        *self = <$struct_name as WeierstrassPoint>::IDENTITY;
                    } else {
                        self.double_assign_nonidentity();
                    }
                    return;
                }
                self.add_ne_assign_nonidentity(p2);
            }
        }

        impl core::ops::AddAssign for $struct_name {
            fn add_assign(&mut self, rhs: Self) {
                core::ops::AddAssign::add_assign(self, &rhs);
            }
        }

        impl core::ops::Sub<&$struct_name> for $struct_name {
            type Output = Self;

            fn sub(self, rhs: &$struct_name) -> Self::Output {
                &self - rhs
            }
        }

        impl core::ops::Sub for $struct_name {
            type Output = Self;

            fn sub(self, rhs: Self) -> Self::Output {
                self - &rhs
            }
        }

        impl core::ops::Sub<&$struct_name> for &$struct_name {
            type Output = $struct_name;

            fn sub(self, p2: &$struct_name) -> Self::Output {
                if p2.is_identity() {
                    return self.clone();
                }
                if self.is_identity() {
                    return core::ops::Neg::neg(p2);
                }
                if self.x() == p2.x() {
                    // Same `x`: either `p2 == self` (difference is the identity) or
                    // `p2 == -self` (difference is `2 * self`).
                    return if self.y() == p2.y() {
                        <$struct_name as WeierstrassPoint>::IDENTITY
                    } else {
                        self.double_nonidentity()
                    };
                }
                self.sub_ne_nonidentity(p2)
            }
        }

        impl core::ops::SubAssign<&$struct_name> for $struct_name {
            fn sub_assign(&mut self, p2: &$struct_name) {
                if p2.is_identity() {
                    return;
                }
                if self.is_identity() {
                    *self = core::ops::Neg::neg(p2);
                    return;
                }
                if self.x() == p2.x() {
                    // Same `x`: either `p2 == self` (difference is the identity) or
                    // `p2 == -self` (difference is `2 * self`).
                    if self.y() == p2.y() {
                        *self = <$struct_name as WeierstrassPoint>::IDENTITY;
                    } else {
                        self.double_assign_nonidentity();
                    }
                    return;
                }
                self.sub_ne_assign_nonidentity(p2);
            }
        }

        impl core::ops::SubAssign for $struct_name {
            fn sub_assign(&mut self, rhs: Self) {
                core::ops::SubAssign::sub_assign(self, &rhs);
            }
        }
    };
}