openvm_ecc_guest/
weierstrass.rs

1use alloc::vec::Vec;
2use core::ops::Mul;
3
4use openvm_algebra_guest::{Field, IntMod};
5
6use super::group::Group;
7
8/// Short Weierstrass curve affine point.
9pub trait WeierstrassPoint: Clone + Sized {
10    /// The `a` coefficient in the Weierstrass curve equation `y^2 = x^3 + a x + b`.
11    const CURVE_A: Self::Coordinate;
12    /// The `b` coefficient in the Weierstrass curve equation `y^2 = x^3 + a x + b`.
13    const CURVE_B: Self::Coordinate;
14    const IDENTITY: Self;
15
16    type Coordinate: Field;
17
18    /// The concatenated `x, y` coordinates of the affine point, where
19    /// coordinates are in little endian.
20    ///
21    /// **Warning**: The memory layout of `Self` is expected to pack
22    /// `x` and `y` contiguously with no unallocated space in between.
23    fn as_le_bytes(&self) -> &[u8];
24
25    /// Raw constructor without asserting point is on the curve.
26    fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self;
27    fn into_coords(self) -> (Self::Coordinate, Self::Coordinate);
28    fn x(&self) -> &Self::Coordinate;
29    fn y(&self) -> &Self::Coordinate;
30    fn x_mut(&mut self) -> &mut Self::Coordinate;
31    fn y_mut(&mut self) -> &mut Self::Coordinate;
32
33    /// Calls any setup required for this curve. The implementation should internally use `OnceBool`
34    /// to ensure that setup is only called once.
35    fn set_up_once();
36
37    /// Add implementation that handles identity and whether points are equal or not.
38    ///
39    /// # Safety
40    /// - If `CHECK_SETUP` is true, checks if setup has been called for this curve and if not, calls
41    ///   `Self::set_up_once()`. Only set `CHECK_SETUP` to `false` if you are sure that setup has
42    ///   been called already.
43    fn add_assign_impl<const CHECK_SETUP: bool>(&mut self, p2: &Self);
44
45    /// Double implementation that handles identity.
46    ///
47    /// # Safety
48    /// - If `CHECK_SETUP` is true, checks if setup has been called for this curve and if not, calls
49    ///   `Self::set_up_once()`. Only set `CHECK_SETUP` to `false` if you are sure that setup has
50    ///   been called already.
51    fn double_assign_impl<const CHECK_SETUP: bool>(&mut self);
52
53    /// # Safety
54    /// - Assumes self != +- p2 and self != identity and p2 != identity.
55    /// - If `CHECK_SETUP` is true, checks if setup has been called for this curve and if not, calls
56    ///   `Self::set_up_once()`. Only set `CHECK_SETUP` to `false` if you are sure that setup has
57    ///   been called already.
58    unsafe fn add_ne_nonidentity<const CHECK_SETUP: bool>(&self, p2: &Self) -> Self;
59    /// # Safety
60    /// - Assumes self != +- p2 and self != identity and p2 != identity.
61    /// - If `CHECK_SETUP` is true, checks if setup has been called for this curve and if not, calls
62    ///   `Self::set_up_once()`. Only set `CHECK_SETUP` to `false` if you are sure that setup has
63    ///   been called already.
64    unsafe fn add_ne_assign_nonidentity<const CHECK_SETUP: bool>(&mut self, p2: &Self);
65    /// # Safety
66    /// - Assumes self != +- p2 and self != identity and p2 != identity.
67    /// - If `CHECK_SETUP` is true, checks if setup has been called for this curve and if not, calls
68    ///   `Self::set_up_once()`. Only set `CHECK_SETUP` to `false` if you are sure that setup has
69    ///   been called already.
70    unsafe fn sub_ne_nonidentity<const CHECK_SETUP: bool>(&self, p2: &Self) -> Self;
71    /// # Safety
72    /// - Assumes self != +- p2 and self != identity and p2 != identity.
73    /// - If `CHECK_SETUP` is true, checks if setup has been called for this curve and if not, calls
74    ///   `Self::set_up_once()`. Only set `CHECK_SETUP` to `false` if you are sure that setup has
75    ///   been called already.
76    unsafe fn sub_ne_assign_nonidentity<const CHECK_SETUP: bool>(&mut self, p2: &Self);
77    /// # Safety
78    /// - Assumes self != identity and 2 * self != identity.
79    /// - If `CHECK_SETUP` is true, checks if setup has been called for this curve and if not, calls
80    ///   `Self::set_up_once()`. Only set `CHECK_SETUP` to `false` if you are sure that setup has
81    ///   been called already.
82    unsafe fn double_nonidentity<const CHECK_SETUP: bool>(&self) -> Self;
83    /// # Safety
84    /// - Assumes self != identity and 2 * self != identity.
85    /// - If `CHECK_SETUP` is true, checks if setup has been called for this curve and if not, calls
86    ///   `Self::set_up_once()`. Only set `CHECK_SETUP` to `false` if you are sure that setup has
87    ///   been called already.
88    unsafe fn double_assign_nonidentity<const CHECK_SETUP: bool>(&mut self);
89
90    #[inline(always)]
91    fn from_xy(x: Self::Coordinate, y: Self::Coordinate) -> Option<Self>
92    where
93        for<'a> &'a Self::Coordinate: Mul<&'a Self::Coordinate, Output = Self::Coordinate>,
94    {
95        if x == Self::Coordinate::ZERO && y == Self::Coordinate::ZERO {
96            Some(Self::IDENTITY)
97        } else {
98            Self::from_xy_nonidentity(x, y)
99        }
100    }
101
102    #[inline(always)]
103    fn from_xy_nonidentity(x: Self::Coordinate, y: Self::Coordinate) -> Option<Self>
104    where
105        for<'a> &'a Self::Coordinate: Mul<&'a Self::Coordinate, Output = Self::Coordinate>,
106    {
107        let lhs = &y * &y;
108        let rhs = &x * &x * &x + &Self::CURVE_A * &x + &Self::CURVE_B;
109        if lhs != rhs {
110            return None;
111        }
112        Some(Self::from_xy_unchecked(x, y))
113    }
114}
115
/// Construction of a curve point from a compressed representation.
pub trait FromCompressed<Coordinate> {
    /// Decompresses a point from its `x`-coordinate and a recovery identifier which indicates
    /// the parity of the `y`-coordinate.
    ///
    /// Given the `x`-coordinate, this function attempts to find the corresponding
    /// `y`-coordinate that satisfies the elliptic curve equation. If successful, it returns
    /// the point as an instance of `Self`. If the point cannot be decompressed, it returns
    /// `None`.
    fn decompress(x: Coordinate, rec_id: &u8) -> Option<Self>
    where
        Self: core::marker::Sized;
}
128
/// A trait for elliptic curves that bridges the openvm types and external types with
/// CurveArithmetic etc. Implement this for external curves with corresponding openvm point and
/// scalar types.
pub trait IntrinsicCurve {
    /// Scalar type; used as the coefficient type of [`IntrinsicCurve::msm`].
    type Scalar: Clone;
    /// Point type; used as the base and result type of [`IntrinsicCurve::msm`].
    type Point: Clone;

    /// Multi-scalar multiplication.
    /// The implementation may be specialized to use properties of the curve
    /// (e.g., if the curve order is prime).
    fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point;
}
141
142// MSM using preprocessed table (windowed method)
143// Reference: modified from https://github.com/arkworks-rs/algebra/blob/master/ec/src/scalar_mul/mod.rs
144//
145// We specialize to Weierstrass curves and further make optimizations for when the curve order is
146// prime.
147
/// Cached precomputations of scalar multiples of several base points.
/// - `window_bits` is the window size used for the precomputation
/// - `bases` are the base points whose multiples are cached
/// - `table` is the precomputed table
pub struct CachedMulTable<'a, C: IntrinsicCurve> {
    /// Window bits. Must be > 0.
    /// For alignment, we currently require this to divide 8 (bits in a byte).
    pub window_bits: usize,
    /// The base points; `bases[i]` itself serves as the multiple `1 * bases[i]`.
    pub bases: &'a [C::Point],
    /// `table[i][j] = (j + 2) * bases[i]` for `j + 2 < 2 ** window_bits`
    table: Vec<Vec<C::Point>>,
    /// Needed to return reference to the identity point (the multiple `0 * bases[i]`).
    identity: C::Point,
}
162
163impl<'a, C: IntrinsicCurve> CachedMulTable<'a, C>
164where
165    C::Point: WeierstrassPoint + Group,
166    C::Scalar: IntMod,
167{
168    /// Constructor when each element of `bases` has prime torsion or is identity.
169    ///
170    /// Assumes that `window_bits` is less than (number of bits - 1) of the order of
171    /// subgroup generated by each non-identity `base`.
172    #[inline]
173    pub fn new_with_prime_order(bases: &'a [C::Point], window_bits: usize) -> Self {
174        C::Point::set_up_once();
175        assert!(window_bits > 0);
176        let window_size = 1 << window_bits;
177        let table = bases
178            .iter()
179            .map(|base| {
180                if base.is_identity() {
181                    vec![<C::Point as Group>::IDENTITY; window_size - 2]
182                } else {
183                    let mut multiples = Vec::with_capacity(window_size - 2);
184                    for _ in 0..window_size - 2 {
185                        // Because the order of `base` is prime, we are guaranteed that
186                        // j * base != identity,
187                        // j * base != +- base for j > 1,
188                        // j * base + base != identity
189                        let multiple = multiples
190                            .last()
191                            .map(|last| unsafe {
192                                WeierstrassPoint::add_ne_nonidentity::<false>(last, base)
193                            })
194                            .unwrap_or_else(|| unsafe { base.double_nonidentity::<false>() });
195                        multiples.push(multiple);
196                    }
197                    multiples
198                }
199            })
200            .collect();
201
202        Self {
203            window_bits,
204            bases,
205            table,
206            identity: <C::Point as Group>::IDENTITY,
207        }
208    }
209
210    #[inline(always)]
211    fn get_multiple(&self, base_idx: usize, scalar: usize) -> &C::Point {
212        if scalar == 0 {
213            &self.identity
214        } else if scalar == 1 {
215            unsafe { self.bases.get_unchecked(base_idx) }
216        } else {
217            unsafe { self.table.get_unchecked(base_idx).get_unchecked(scalar - 2) }
218        }
219    }
220
221    /// Computes `sum scalars[i] * bases[i]`.
222    ///
223    /// For implementation simplicity, currently only implemented when
224    /// `window_bits` divides 8 (number of bits in a byte).
225    #[inline]
226    pub fn windowed_mul(&self, scalars: &[C::Scalar]) -> C::Point {
227        C::Point::set_up_once();
228        assert_eq!(8 % self.window_bits, 0);
229        assert_eq!(scalars.len(), self.bases.len());
230        let windows_per_byte = 8 / self.window_bits;
231
232        let num_windows = C::Scalar::NUM_LIMBS * windows_per_byte;
233        let mask = (1u8 << self.window_bits) - 1;
234
235        // The current byte index (little endian) at the current step of the
236        // windowed method, across all scalars.
237        let mut limb_idx = C::Scalar::NUM_LIMBS;
238        // The current bit (little endian) within the current byte of the windowed
239        // method. The window will look at bits `bit_idx..bit_idx + window_bits`.
240        // bit_idx will always be in range [0, 8)
241        let mut bit_idx = 0;
242
243        let mut res = <C::Point as Group>::IDENTITY;
244        for outer in 0..num_windows {
245            if bit_idx == 0 {
246                limb_idx -= 1;
247                bit_idx = 8 - self.window_bits;
248            } else {
249                bit_idx -= self.window_bits;
250            }
251
252            if outer != 0 {
253                for _ in 0..self.window_bits {
254                    // Note: this handles identity
255                    // setup has been called above
256                    res.double_assign_impl::<false>();
257                }
258            }
259            for (base_idx, scalar) in scalars.iter().enumerate() {
260                let scalar = (scalar.as_le_bytes()[limb_idx] >> bit_idx) & mask;
261                let summand = self.get_multiple(base_idx, scalar as usize);
262                // handles identity
263                // setup has been called above
264                res.add_assign_impl::<false>(summand);
265            }
266        }
267        res
268    }
269}
270
/// Macro to generate a newtype wrapper for [AffinePoint](crate::AffinePoint)
/// that implements elliptic curve operations by using the underlying field operations according to
/// the [formulas](https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html) for short Weierstrass curves.
///
/// Arguments: `$struct_name` is the name of the generated type, `$field` the coordinate field
/// type, `$three` an expression for the field element `3`, and `$b` the curve coefficient `b`.
/// The curve coefficient `a` is assumed to be zero.
///
/// The following imports are required:
/// ```rust
/// use core::ops::AddAssign;
///
/// use openvm_algebra_guest::{DivUnsafe, Field};
/// use openvm_ecc_guest::{weierstrass::WeierstrassPoint, AffinePoint, Group};
/// ```
#[macro_export]
macro_rules! impl_sw_affine {
    // Assumes `a = 0` in curve equation. `$three` should be a constant expression for `3` of type
    // `$field`.
    ($struct_name:ident, $field:ty, $three:expr, $b:expr) => {
        /// A newtype wrapper for [AffinePoint] that implements elliptic curve operations
        /// by using the underlying field operations according to the [formulas](https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html) for short Weierstrass curves.
        #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
        #[repr(transparent)]
        pub struct $struct_name(AffinePoint<$field>);

        impl $struct_name {
            /// Wraps the given coordinates without checking that they lie on the curve.
            pub const fn new(x: $field, y: $field) -> Self {
                Self(AffinePoint::new(x, y))
            }
        }

        impl WeierstrassPoint for $struct_name {
            const CURVE_A: $field = <$field>::ZERO;
            const CURVE_B: $field = $b;
            // The identity is represented by the coordinate pair `(0, 0)`.
            const IDENTITY: Self = Self(AffinePoint::new(<$field>::ZERO, <$field>::ZERO));

            type Coordinate = $field;

            /// SAFETY: assumes that the coordinate field type has internal representation in
            /// little-endian.
            fn as_le_bytes(&self) -> &[u8] {
                // SAFETY: `self` is a valid reference, so reading `size_of::<Self>()` bytes
                // starting at its address stays inside the referenced value.
                // NOTE(review): this also assumes the wrapped point has no padding bytes;
                // confirm against the `AffinePoint` definition.
                unsafe {
                    &*core::ptr::slice_from_raw_parts(
                        self as *const Self as *const u8,
                        core::mem::size_of::<Self>(),
                    )
                }
            }
            fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self {
                Self(AffinePoint::new(x, y))
            }
            fn into_coords(self) -> (Self::Coordinate, Self::Coordinate) {
                (self.0.x, self.0.y)
            }
            fn x(&self) -> &Self::Coordinate {
                &self.0.x
            }
            fn y(&self) -> &Self::Coordinate {
                &self.0.y
            }
            fn x_mut(&mut self) -> &mut Self::Coordinate {
                &mut self.0.x
            }
            fn y_mut(&mut self) -> &mut Self::Coordinate {
                &mut self.0.y
            }

            fn set_up_once() {
                // There are no special opcodes for curve operations in this case, so no additional
                // setup is required.
                //
                // Since the `Self::Coordinate` is abstract, any set up required by the field is not
                // handled here.
            }

            fn add_assign_impl<const CHECK_SETUP: bool>(&mut self, p2: &Self) {
                if self == &<Self as WeierstrassPoint>::IDENTITY {
                    *self = p2.clone();
                } else if p2 == &<Self as WeierstrassPoint>::IDENTITY {
                    // do nothing
                } else if self.x() == p2.x() {
                    // Same x-coordinate: either p2 == -self (sum is the identity) or
                    // p2 == self (sum is a doubling).
                    if self.y() + p2.y() == <Self::Coordinate as openvm_algebra_guest::Field>::ZERO
                    {
                        *self = <Self as WeierstrassPoint>::IDENTITY;
                    } else {
                        // SAFETY: self is not the identity and y1 + y2 != 0.
                        unsafe {
                            self.double_assign_nonidentity::<CHECK_SETUP>();
                        }
                    }
                } else {
                    // SAFETY: both points are non-identity and have distinct x-coordinates,
                    // hence self != +- p2.
                    unsafe {
                        self.add_ne_assign_nonidentity::<CHECK_SETUP>(p2);
                    }
                }
            }

            #[inline(always)]
            fn double_assign_impl<const CHECK_SETUP: bool>(&mut self) {
                if self != &<Self as WeierstrassPoint>::IDENTITY {
                    unsafe {
                        self.double_assign_nonidentity::<CHECK_SETUP>();
                    }
                }
            }

            unsafe fn double_nonidentity<const CHECK_SETUP: bool>(&self) -> Self {
                use openvm_algebra_guest::DivUnsafe;
                // lambda = (3*x1^2+a)/(2*y1)
                // assume a = 0
                // The constant `3` is taken from the `$three` macro argument rather than from
                // whatever item named `THREE` happens to be in scope at the call site.
                let lambda = (&$three * self.x() * self.x()).div_unsafe(self.y() + self.y());
                // x3 = lambda^2-x1-x1
                let x3 = &lambda * &lambda - self.x() - self.x();
                // y3 = lambda * (x1-x3) - y1
                let y3 = lambda * (self.x() - &x3) - self.y();
                Self(AffinePoint::new(x3, y3))
            }

            #[inline(always)]
            unsafe fn double_assign_nonidentity<const CHECK_SETUP: bool>(&mut self) {
                *self = self.double_nonidentity::<CHECK_SETUP>();
            }

            unsafe fn add_ne_nonidentity<const CHECK_SETUP: bool>(&self, p2: &Self) -> Self {
                use openvm_algebra_guest::DivUnsafe;
                // lambda = (y2-y1)/(x2-x1)
                // x3 = lambda^2-x1-x2
                // y3 = lambda*(x1-x3)-y1
                let lambda = (p2.y() - self.y()).div_unsafe(p2.x() - self.x());
                let x3 = &lambda * &lambda - self.x() - p2.x();
                let y3 = lambda * (self.x() - &x3) - self.y();
                Self(AffinePoint::new(x3, y3))
            }

            #[inline(always)]
            unsafe fn add_ne_assign_nonidentity<const CHECK_SETUP: bool>(&mut self, p2: &Self) {
                *self = self.add_ne_nonidentity::<CHECK_SETUP>(p2);
            }

            unsafe fn sub_ne_nonidentity<const CHECK_SETUP: bool>(&self, p2: &Self) -> Self {
                use openvm_algebra_guest::DivUnsafe;
                // lambda = (y2+y1)/(x1-x2)
                // x3 = lambda^2-x1-x2
                // y3 = lambda*(x1-x3)-y1
                let lambda = (p2.y() + self.y()).div_unsafe(self.x() - p2.x());
                let x3 = &lambda * &lambda - self.x() - p2.x();
                let y3 = lambda * (self.x() - &x3) - self.y();
                Self(AffinePoint::new(x3, y3))
            }

            #[inline(always)]
            unsafe fn sub_ne_assign_nonidentity<const CHECK_SETUP: bool>(&mut self, p2: &Self) {
                *self = self.sub_ne_nonidentity::<CHECK_SETUP>(p2);
            }
        }

        impl core::ops::Neg for $struct_name {
            type Output = Self;

            // Negation flips the sign of the y-coordinate only.
            #[inline(always)]
            fn neg(mut self) -> Self::Output {
                self.0.y.neg_assign();
                self
            }
        }

        impl core::ops::Neg for &$struct_name {
            type Output = $struct_name;

            #[inline(always)]
            fn neg(self) -> Self::Output {
                self.clone().neg()
            }
        }

        impl From<$struct_name> for AffinePoint<$field> {
            fn from(value: $struct_name) -> Self {
                value.0
            }
        }

        impl From<AffinePoint<$field>> for $struct_name {
            fn from(value: AffinePoint<$field>) -> Self {
                Self(value)
            }
        }
    };
}
454
/// Implements `Group` on `$struct_name` assuming that `$struct_name` implements `WeierstrassPoint`.
/// Assumes that `Neg` is implemented for `&$struct_name`.
///
/// Also implements `Add`, `AddAssign`, `Sub` and `SubAssign` for owned and borrowed
/// right-hand sides. `Group` and `WeierstrassPoint` must be in scope where the macro is invoked;
/// the `core::ops` traits are referenced by full path and need not be imported.
#[macro_export]
macro_rules! impl_sw_group_ops {
    ($struct_name:ident, $field:ty) => {
        impl Group for $struct_name {
            type SelfRef<'a> = &'a Self;

            const IDENTITY: Self = <Self as WeierstrassPoint>::IDENTITY;

            #[inline(always)]
            fn double(&self) -> Self {
                if self.is_identity() {
                    self.clone()
                } else {
                    unsafe { self.double_nonidentity::<true>() }
                }
            }

            #[inline(always)]
            fn double_assign(&mut self) {
                self.double_assign_impl::<true>();
            }

            // This implementation is the same as the default implementation in the `Group` trait,
            // but it was found that overriding the default implementation reduced the cycle count
            // by 50% on the ecrecover benchmark.
            // We hypothesize that this is due to compiler optimizations that are not possible when
            // the `is_identity` function is defined in a different source file.
            #[inline(always)]
            fn is_identity(&self) -> bool {
                self == &<Self as Group>::IDENTITY
            }
        }

        impl core::ops::Add<&$struct_name> for $struct_name {
            type Output = Self;

            #[inline(always)]
            fn add(mut self, p2: &$struct_name) -> Self::Output {
                // Fully qualified so that the call does not depend on call-site imports.
                core::ops::AddAssign::add_assign(&mut self, p2);
                self
            }
        }

        impl core::ops::Add for $struct_name {
            type Output = Self;

            #[inline(always)]
            fn add(self, rhs: Self) -> Self::Output {
                // Dispatches to `Add<&$struct_name> for $struct_name`.
                core::ops::Add::add(self, &rhs)
            }
        }

        impl core::ops::Add<&$struct_name> for &$struct_name {
            type Output = $struct_name;

            #[inline(always)]
            fn add(self, p2: &$struct_name) -> Self::Output {
                if self.is_identity() {
                    p2.clone()
                } else if p2.is_identity() {
                    self.clone()
                } else if WeierstrassPoint::x(self) == WeierstrassPoint::x(p2) {
                    // Same x-coordinate: p2 is either -self or self.
                    if self.y() + p2.y() == <$field as openvm_algebra_guest::Field>::ZERO {
                        <$struct_name as WeierstrassPoint>::IDENTITY
                    } else {
                        unsafe { self.double_nonidentity::<true>() }
                    }
                } else {
                    unsafe { self.add_ne_nonidentity::<true>(p2) }
                }
            }
        }

        impl core::ops::AddAssign<&$struct_name> for $struct_name {
            #[inline(always)]
            fn add_assign(&mut self, p2: &$struct_name) {
                self.add_assign_impl::<true>(p2);
            }
        }

        impl core::ops::AddAssign for $struct_name {
            #[inline(always)]
            fn add_assign(&mut self, rhs: Self) {
                // Dispatches to `AddAssign<&$struct_name> for $struct_name`.
                core::ops::AddAssign::add_assign(self, &rhs);
            }
        }

        impl core::ops::Sub<&$struct_name> for $struct_name {
            type Output = Self;

            #[inline(always)]
            fn sub(self, rhs: &$struct_name) -> Self::Output {
                core::ops::Sub::sub(&self, rhs)
            }
        }

        impl core::ops::Sub for $struct_name {
            type Output = $struct_name;

            #[inline(always)]
            fn sub(self, rhs: Self) -> Self::Output {
                // Dispatches to `Sub<&$struct_name> for $struct_name`.
                core::ops::Sub::sub(self, &rhs)
            }
        }

        impl core::ops::Sub<&$struct_name> for &$struct_name {
            type Output = $struct_name;

            #[inline(always)]
            fn sub(self, p2: &$struct_name) -> Self::Output {
                if p2.is_identity() {
                    self.clone()
                } else if self.is_identity() {
                    core::ops::Neg::neg(p2)
                } else if WeierstrassPoint::x(self) == WeierstrassPoint::x(p2) {
                    // Same x-coordinate: p2 is either self (difference is the identity) or
                    // -self (difference is 2 * self).
                    if self.y() == p2.y() {
                        <$struct_name as WeierstrassPoint>::IDENTITY
                    } else {
                        unsafe { self.double_nonidentity::<true>() }
                    }
                } else {
                    unsafe { self.sub_ne_nonidentity::<true>(p2) }
                }
            }
        }

        impl core::ops::SubAssign<&$struct_name> for $struct_name {
            #[inline(always)]
            fn sub_assign(&mut self, p2: &$struct_name) {
                if p2.is_identity() {
                    // do nothing
                } else if self.is_identity() {
                    *self = core::ops::Neg::neg(p2);
                } else if WeierstrassPoint::x(self) == WeierstrassPoint::x(p2) {
                    // Same x-coordinate: p2 is either self or -self.
                    if self.y() == p2.y() {
                        *self = <$struct_name as WeierstrassPoint>::IDENTITY
                    } else {
                        unsafe {
                            self.double_assign_nonidentity::<true>();
                        }
                    }
                } else {
                    unsafe {
                        self.sub_ne_assign_nonidentity::<true>(p2);
                    }
                }
            }
        }

        impl core::ops::SubAssign for $struct_name {
            #[inline(always)]
            fn sub_assign(&mut self, rhs: Self) {
                // Dispatches to `SubAssign<&$struct_name> for $struct_name`.
                core::ops::SubAssign::sub_assign(self, &rhs);
            }
        }
    };
}
614}