openvm_ecc_guest/
weierstrass.rs

1use alloc::vec::Vec;
2use core::ops::{AddAssign, Mul};
3
4use openvm_algebra_guest::{Field, IntMod};
5
6use super::group::Group;
7
8/// Short Weierstrass curve affine point.
9pub trait WeierstrassPoint: Sized {
10    /// The `a` coefficient in the Weierstrass curve equation `y^2 = x^3 + a x + b`.
11    const CURVE_A: Self::Coordinate;
12    /// The `b` coefficient in the Weierstrass curve equation `y^2 = x^3 + a x + b`.
13    const CURVE_B: Self::Coordinate;
14    const IDENTITY: Self;
15
16    type Coordinate: Field;
17
18    /// The concatenated `x, y` coordinates of the affine point, where
19    /// coordinates are in little endian.
20    ///
21    /// **Warning**: The memory layout of `Self` is expected to pack
22    /// `x` and `y` contigously with no unallocated space in between.
23    fn as_le_bytes(&self) -> &[u8];
24
25    /// Raw constructor without asserting point is on the curve.
26    fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self;
27    fn into_coords(self) -> (Self::Coordinate, Self::Coordinate);
28    fn x(&self) -> &Self::Coordinate;
29    fn y(&self) -> &Self::Coordinate;
30    fn x_mut(&mut self) -> &mut Self::Coordinate;
31    fn y_mut(&mut self) -> &mut Self::Coordinate;
32
33    /// Hazmat: Assumes self != +- p2 and self != identity and p2 != identity.
34    fn add_ne_nonidentity(&self, p2: &Self) -> Self;
35    /// Hazmat: Assumes self != +- p2 and self != identity and p2 != identity.
36    fn add_ne_assign_nonidentity(&mut self, p2: &Self);
37    /// Hazmat: Assumes self != +- p2 and self != identity and p2 != identity.
38    fn sub_ne_nonidentity(&self, p2: &Self) -> Self;
39    /// Hazmat: Assumes self != +- p2 and self != identity and p2 != identity.
40    fn sub_ne_assign_nonidentity(&mut self, p2: &Self);
41    /// Hazmat: Assumes self != identity and 2 * self != identity.
42    fn double_nonidentity(&self) -> Self;
43    /// Hazmat: Assumes self != identity and 2 * self != identity.
44    fn double_assign_nonidentity(&mut self);
45
46    fn from_xy(x: Self::Coordinate, y: Self::Coordinate) -> Option<Self>
47    where
48        for<'a> &'a Self::Coordinate: Mul<&'a Self::Coordinate, Output = Self::Coordinate>,
49    {
50        if x == Self::Coordinate::ZERO && y == Self::Coordinate::ZERO {
51            Some(Self::IDENTITY)
52        } else {
53            Self::from_xy_nonidentity(x, y)
54        }
55    }
56
57    fn from_xy_nonidentity(x: Self::Coordinate, y: Self::Coordinate) -> Option<Self>
58    where
59        for<'a> &'a Self::Coordinate: Mul<&'a Self::Coordinate, Output = Self::Coordinate>,
60    {
61        let lhs = &y * &y;
62        let rhs = &x * &x * &x + &Self::CURVE_A * &x + &Self::CURVE_B;
63        if lhs != rhs {
64            return None;
65        }
66        Some(Self::from_xy_unchecked(x, y))
67    }
68}
69
/// Hint for a point decompression, as produced by [`FromCompressed::hint_decompress`].
///
/// - If `possible` is `true`, then `sqrt` is the decompressed `y`-coordinate.
/// - If `possible` is `false`, then `sqrt` is such that `sqrt^2 = rhs * non_qr`, where `rhs` is
///   the right-hand side of the curve equation and `non_qr` is the curve's quadratic
///   non-residue. When `rhs` is nonzero this shows `rhs` is not a square, so no `y` exists.
///
/// This is only a hint: consumers must verify the stated relation themselves.
pub struct DecompressionHint<T> {
    /// `true` when `sqrt` is claimed to be a valid `y`-coordinate for the requested `x`.
    pub possible: bool,
    /// Either the decompressed `y`-coordinate or the non-residue witness (see type docs).
    pub sqrt: T,
}
78
79pub trait FromCompressed<Coordinate> {
80    /// Given `x`-coordinate,
81    ///
82    /// Decompresses a point from its x-coordinate and a recovery identifier which indicates
83    /// the parity of the y-coordinate. Given the x-coordinate, this function attempts to find the
84    /// corresponding y-coordinate that satisfies the elliptic curve equation. If successful, it
85    /// returns the point as an instance of Self. If the point cannot be decompressed, it returns
86    /// None.
87    fn decompress(x: Coordinate, rec_id: &u8) -> Option<Self>
88    where
89        Self: core::marker::Sized;
90
91    /// If it exists, hints the unique `y` coordinate that is less than `Coordinate::MODULUS`
92    /// such that `(x, y)` is a point on the curve and `y` has parity equal to `rec_id`.
93    /// If such `y` does not exist, hints a coordinate `sqrt` such that `sqrt^2 = rhs * non_qr`
94    /// where `rhs` is the rhs of the curve equation and `non_qr` is the non-quadratic residue
95    /// for this curve that was initialized in the setup function.
96    ///
97    /// This is only a hint, and the returned value does not guarantee any of the above properties.
98    /// They must be checked separately. Normal users should use `decompress` directly.
99    ///
100    /// Returns None if the `DecompressionHint::possible` flag in the hint stream is not a boolean.
101    fn hint_decompress(x: &Coordinate, rec_id: &u8) -> Option<DecompressionHint<Coordinate>>;
102}
103
/// Links an external curve definition (e.g. one providing `CurveArithmetic`) to the
/// openvm point and scalar types that implement it.
///
/// Implement this on the external curve type, naming the matching openvm `Scalar` and `Point`.
pub trait IntrinsicCurve {
    /// openvm scalar type used for the MSM coefficients.
    type Scalar: Clone;
    /// openvm point type used for the MSM bases and result.
    type Point: Clone;

    /// Multi-scalar multiplication of `bases` by `coeffs`.
    ///
    /// Implementations are free to exploit curve-specific structure, for example a prime
    /// group order.
    fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point;
}
116
117// MSM using preprocessed table (windowed method)
118// Reference: modified from https://github.com/arkworks-rs/algebra/blob/master/ec/src/scalar_mul/mod.rs
119//
120// We specialize to Weierstrass curves and further make optimizations for when the curve order is
121// prime.
122
123/// Cached precomputations of scalar multiples of several base points.
124/// - `window_bits` is the window size used for the precomputation
125/// - `max_scalar_bits` is the maximum size of the scalars that will be multiplied
126/// - `table` is the precomputed table
127pub struct CachedMulTable<'a, C: IntrinsicCurve> {
128    /// Window bits. Must be > 0.
129    /// For alignment, we currently require this to divide 8 (bits in a byte).
130    pub window_bits: usize,
131    pub bases: &'a [C::Point],
132    /// `table[i][j] = (j + 2) * bases[i]` for `j + 2 < 2 ** window_bits`
133    table: Vec<Vec<C::Point>>,
134    /// Needed to return reference to the identity point.
135    identity: C::Point,
136}
137
138impl<'a, C: IntrinsicCurve> CachedMulTable<'a, C>
139where
140    C::Point: WeierstrassPoint + Group,
141    C::Scalar: IntMod,
142{
143    /// Constructor when each element of `bases` has prime torsion or is identity.
144    ///
145    /// Assumes that `window_bits` is less than (number of bits - 1) of the order of
146    /// subgroup generated by each non-identity `base`.
147    pub fn new_with_prime_order(bases: &'a [C::Point], window_bits: usize) -> Self {
148        assert!(window_bits > 0);
149        let window_size = 1 << window_bits;
150        let table = bases
151            .iter()
152            .map(|base| {
153                if base.is_identity() {
154                    vec![<C::Point as Group>::IDENTITY; window_size - 2]
155                } else {
156                    let mut multiples = Vec::with_capacity(window_size - 2);
157                    for _ in 0..window_size - 2 {
158                        // Because the order of `base` is prime, we are guaranteed that
159                        // j * base != identity,
160                        // j * base != +- base for j > 1,
161                        // j * base + base != identity
162                        let multiple = multiples
163                            .last()
164                            .map(|last| WeierstrassPoint::add_ne_nonidentity(last, base))
165                            .unwrap_or_else(|| base.double_nonidentity());
166                        multiples.push(multiple);
167                    }
168                    multiples
169                }
170            })
171            .collect();
172
173        Self {
174            window_bits,
175            bases,
176            table,
177            identity: <C::Point as Group>::IDENTITY,
178        }
179    }
180
181    fn get_multiple(&self, base_idx: usize, scalar: usize) -> &C::Point {
182        if scalar == 0 {
183            &self.identity
184        } else if scalar == 1 {
185            unsafe { self.bases.get_unchecked(base_idx) }
186        } else {
187            unsafe { self.table.get_unchecked(base_idx).get_unchecked(scalar - 2) }
188        }
189    }
190
191    /// Computes `sum scalars[i] * bases[i]`.
192    ///
193    /// For implementation simplicity, currently only implemented when
194    /// `window_bits` divides 8 (number of bits in a byte).
195    pub fn windowed_mul(&self, scalars: &[C::Scalar]) -> C::Point {
196        assert_eq!(8 % self.window_bits, 0);
197        assert_eq!(scalars.len(), self.bases.len());
198        let windows_per_byte = 8 / self.window_bits;
199
200        let num_windows = C::Scalar::NUM_LIMBS * windows_per_byte;
201        let mask = (1u8 << self.window_bits) - 1;
202
203        // The current byte index (little endian) at the current step of the
204        // windowed method, across all scalars.
205        let mut limb_idx = C::Scalar::NUM_LIMBS;
206        // The current bit (little endian) within the current byte of the windowed
207        // method. The window will look at bits `bit_idx..bit_idx + window_bits`.
208        // bit_idx will always be in range [0, 8)
209        let mut bit_idx = 0;
210
211        let mut res = <C::Point as Group>::IDENTITY;
212        for outer in 0..num_windows {
213            if bit_idx == 0 {
214                limb_idx -= 1;
215                bit_idx = 8 - self.window_bits;
216            } else {
217                bit_idx -= self.window_bits;
218            }
219
220            if outer != 0 {
221                for _ in 0..self.window_bits {
222                    // Note: this handles identity
223                    res.double_assign();
224                }
225            }
226            for (base_idx, scalar) in scalars.iter().enumerate() {
227                let scalar = (scalar.as_le_bytes()[limb_idx] >> bit_idx) & mask;
228                let summand = self.get_multiple(base_idx, scalar as usize);
229                // handles identity
230                res.add_assign(summand);
231            }
232        }
233        res
234    }
235}
236
/// Macro to generate a newtype wrapper for [AffinePoint](crate::AffinePoint)
/// that implements elliptic curve operations by using the underlying field operations according to
/// the [formulas](https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html) for short Weierstrass curves.
///
/// Usage: `impl_sw_affine!(StructName, FieldType, three, b)`, where `three` is a constant
/// expression for `3` of type `FieldType` (used in point doubling) and `b` is the `b`
/// coefficient of the curve. The generated implementation assumes `a = 0`.
///
/// The following imports are required:
/// ```rust
/// use core::ops::AddAssign;
///
/// use openvm_algebra_guest::{DivUnsafe, Field};
/// use openvm_ecc_guest::{weierstrass::WeierstrassPoint, AffinePoint, Group};
/// ```
#[macro_export]
macro_rules! impl_sw_affine {
    // Assumes `a = 0` in curve equation. `$three` should be a constant expression for `3` of type
    // `$field`.
    ($struct_name:ident, $field:ty, $three:expr, $b:expr) => {
        /// A newtype wrapper for [AffinePoint] that implements elliptic curve operations
        /// by using the underlying field operations according to the [formulas](https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html) for short Weierstrass curves.
        #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
        #[repr(transparent)]
        pub struct $struct_name(AffinePoint<$field>);

        impl $struct_name {
            /// Constructs a point from affine coordinates without checking that it lies on
            /// the curve.
            pub const fn new(x: $field, y: $field) -> Self {
                Self(AffinePoint::new(x, y))
            }
        }

        impl WeierstrassPoint for $struct_name {
            const CURVE_A: $field = <$field>::ZERO;
            const CURVE_B: $field = $b;
            // The identity is encoded as the affine pair `(0, 0)`.
            const IDENTITY: Self = Self(AffinePoint::new(<$field>::ZERO, <$field>::ZERO));

            type Coordinate = $field;

            fn as_le_bytes(&self) -> &[u8] {
                // SAFETY: the slice covers exactly `size_of::<Self>()` bytes starting at `self`
                // and borrows `self` for its lifetime. Assumes the coordinate field type has a
                // little-endian internal representation and that `x` and `y` are packed
                // contiguously (see `WeierstrassPoint::as_le_bytes`).
                unsafe {
                    &*core::ptr::slice_from_raw_parts(
                        self as *const Self as *const u8,
                        core::mem::size_of::<Self>(),
                    )
                }
            }
            fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self {
                Self(AffinePoint::new(x, y))
            }
            fn into_coords(self) -> (Self::Coordinate, Self::Coordinate) {
                (self.0.x, self.0.y)
            }
            fn x(&self) -> &Self::Coordinate {
                &self.0.x
            }
            fn y(&self) -> &Self::Coordinate {
                &self.0.y
            }
            fn x_mut(&mut self) -> &mut Self::Coordinate {
                &mut self.0.x
            }
            fn y_mut(&mut self) -> &mut Self::Coordinate {
                &mut self.0.y
            }

            fn double_nonidentity(&self) -> Self {
                use ::openvm_algebra_guest::DivUnsafe;
                // lambda = (3*x1^2+a)/(2*y1)
                // assume a = 0; `$three` is the caller-supplied constant 3 in the field
                let lambda = (&$three * self.x() * self.x()).div_unsafe(self.y() + self.y());
                // x3 = lambda^2-x1-x1
                let x3 = &lambda * &lambda - self.x() - self.x();
                // y3 = lambda * (x1-x3) - y1
                let y3 = lambda * (self.x() - &x3) - self.y();
                Self(AffinePoint::new(x3, y3))
            }

            fn double_assign_nonidentity(&mut self) {
                *self = self.double_nonidentity();
            }

            fn add_ne_nonidentity(&self, p2: &Self) -> Self {
                use ::openvm_algebra_guest::DivUnsafe;
                // lambda = (y2-y1)/(x2-x1)
                // x3 = lambda^2-x1-x2
                // y3 = lambda*(x1-x3)-y1
                let lambda = (p2.y() - self.y()).div_unsafe(p2.x() - self.x());
                let x3 = &lambda * &lambda - self.x() - p2.x();
                let y3 = lambda * (self.x() - &x3) - self.y();
                Self(AffinePoint::new(x3, y3))
            }

            fn add_ne_assign_nonidentity(&mut self, p2: &Self) {
                *self = self.add_ne_nonidentity(p2);
            }

            fn sub_ne_nonidentity(&self, p2: &Self) -> Self {
                use ::openvm_algebra_guest::DivUnsafe;
                // self - p2 = self + (x2, -y2):
                // lambda = (y2+y1)/(x1-x2)
                // x3 = lambda^2-x1-x2
                // y3 = lambda*(x1-x3)-y1
                let lambda = (p2.y() + self.y()).div_unsafe(self.x() - p2.x());
                let x3 = &lambda * &lambda - self.x() - p2.x();
                let y3 = lambda * (self.x() - &x3) - self.y();
                Self(AffinePoint::new(x3, y3))
            }

            fn sub_ne_assign_nonidentity(&mut self, p2: &Self) {
                *self = self.sub_ne_nonidentity(p2);
            }
        }

        impl core::ops::Neg for $struct_name {
            type Output = Self;

            fn neg(mut self) -> Self::Output {
                self.0.y.neg_assign();
                self
            }
        }

        impl core::ops::Neg for &$struct_name {
            type Output = $struct_name;

            fn neg(self) -> Self::Output {
                // Delegates to the by-value `Neg` impl above via the operator.
                -self.clone()
            }
        }

        impl From<$struct_name> for AffinePoint<$field> {
            fn from(value: $struct_name) -> Self {
                value.0
            }
        }

        impl From<AffinePoint<$field>> for $struct_name {
            fn from(value: AffinePoint<$field>) -> Self {
                Self(value)
            }
        }
    };
}
377
/// Derives `Group` plus the `Add`/`AddAssign`/`Sub`/`SubAssign` operator family for
/// `$struct_name` from its `WeierstrassPoint` implementation.
///
/// Requirements at the invocation site:
/// - `$struct_name` already implements `WeierstrassPoint`, with coordinates of type `$field`;
/// - `Neg` is implemented for `&$struct_name`;
/// - `Group` and `WeierstrassPoint` are in scope.
#[macro_export]
macro_rules! impl_sw_group_ops {
    ($struct_name:ident, $field:ty) => {
        impl Group for $struct_name {
            type SelfRef<'a> = &'a Self;

            const IDENTITY: Self = <Self as WeierstrassPoint>::IDENTITY;

            fn double(&self) -> Self {
                // 2 * identity == identity.
                if self.is_identity() {
                    return self.clone();
                }
                self.double_nonidentity()
            }

            fn double_assign(&mut self) {
                if self.is_identity() {
                    return;
                }
                self.double_assign_nonidentity();
            }
        }

        impl core::ops::Add<&$struct_name> for $struct_name {
            type Output = Self;

            fn add(mut self, p2: &$struct_name) -> Self::Output {
                self += p2;
                self
            }
        }

        impl core::ops::Add for $struct_name {
            type Output = Self;

            fn add(self, rhs: Self) -> Self::Output {
                self + &rhs
            }
        }

        impl core::ops::Add<&$struct_name> for &$struct_name {
            type Output = $struct_name;

            fn add(self, p2: &$struct_name) -> Self::Output {
                if self.is_identity() {
                    return p2.clone();
                }
                if p2.is_identity() {
                    return self.clone();
                }
                if self.x() == p2.x() {
                    // Equal x-coordinates: either p2 == -self (sum is the identity) or
                    // p2 == self (sum is 2 * self).
                    let y_sum = self.y() + p2.y();
                    return if y_sum == <$field as openvm_algebra_guest::Field>::ZERO {
                        <$struct_name as WeierstrassPoint>::IDENTITY
                    } else {
                        self.double_nonidentity()
                    };
                }
                self.add_ne_nonidentity(p2)
            }
        }

        impl core::ops::AddAssign<&$struct_name> for $struct_name {
            fn add_assign(&mut self, p2: &$struct_name) {
                if self.is_identity() {
                    *self = p2.clone();
                    return;
                }
                if p2.is_identity() {
                    return;
                }
                if self.x() == p2.x() {
                    let y_sum = self.y() + p2.y();
                    if y_sum == <$field as openvm_algebra_guest::Field>::ZERO {
                        *self = <$struct_name as WeierstrassPoint>::IDENTITY;
                    } else {
                        self.double_assign_nonidentity();
                    }
                    return;
                }
                self.add_ne_assign_nonidentity(p2);
            }
        }

        impl core::ops::AddAssign for $struct_name {
            fn add_assign(&mut self, rhs: Self) {
                *self += &rhs;
            }
        }

        impl core::ops::Sub<&$struct_name> for $struct_name {
            type Output = Self;

            fn sub(self, rhs: &$struct_name) -> Self::Output {
                &self - rhs
            }
        }

        impl core::ops::Sub for $struct_name {
            type Output = $struct_name;

            fn sub(self, rhs: Self) -> Self::Output {
                self - &rhs
            }
        }

        impl core::ops::Sub<&$struct_name> for &$struct_name {
            type Output = $struct_name;

            fn sub(self, p2: &$struct_name) -> Self::Output {
                if p2.is_identity() {
                    return self.clone();
                }
                if self.is_identity() {
                    return -p2;
                }
                if self.x() == p2.x() {
                    // Equal x-coordinates: p2 == self gives the identity; otherwise
                    // p2 == -self, so self - p2 == 2 * self.
                    return if self.y() == p2.y() {
                        <$struct_name as WeierstrassPoint>::IDENTITY
                    } else {
                        self.double_nonidentity()
                    };
                }
                self.sub_ne_nonidentity(p2)
            }
        }

        impl core::ops::SubAssign<&$struct_name> for $struct_name {
            fn sub_assign(&mut self, p2: &$struct_name) {
                if p2.is_identity() {
                    return;
                }
                if self.is_identity() {
                    *self = -p2;
                    return;
                }
                if self.x() == p2.x() {
                    if self.y() == p2.y() {
                        *self = <$struct_name as WeierstrassPoint>::IDENTITY;
                    } else {
                        self.double_assign_nonidentity();
                    }
                    return;
                }
                self.sub_ne_assign_nonidentity(p2);
            }
        }

        impl core::ops::SubAssign for $struct_name {
            fn sub_assign(&mut self, rhs: Self) {
                *self -= &rhs;
            }
        }
    };
}