openvm_ecc_guest/msm.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
use alloc::{vec, vec::Vec};
use core::ops::{Add, Neg};
use openvm_algebra_guest::IntMod;
use super::Group;
/// Multi-scalar multiplication via Pippenger's algorithm.
///
/// Returns `sum_i coeffs[i] * bases[i]`. Scalars and points are paired with `zip`, so if the
/// two slices have different lengths the surplus elements of the longer one are ignored.
///
/// Each scalar is read as little-endian bytes and cut into signed (Booth) digits of `c` bits
/// by `get_booth_index`. For every window, the bases are accumulated into `2^(c - 1)` buckets
/// indexed by `|digit|` (negated when the digit is negative). The buckets are combined with
/// summation by parts. Windows are processed from most to least significant, with `c`
/// doublings between consecutive windows.
// Reference: https://github.com/privacy-scaling-explorations/halo2curves/blob/8771fe5a5d54fc03e74dbc8915db5dad3ab46a83/src/msm.rs#L335
// FIXME[jpw]: there are many memcpy in this function
pub fn msm<EcPoint: Group, Scalar: IntMod>(coeffs: &[Scalar], bases: &[EcPoint]) -> EcPoint
where
    for<'a> &'a EcPoint: Add<&'a EcPoint, Output = EcPoint>,
{
    let coeffs: Vec<_> = coeffs.iter().map(|c| c.as_le_bytes()).collect();
    // c: window size. Scalars are grouped into c-bit windows.
    let c = if bases.len() < 4 {
        1
    } else if bases.len() < 32 {
        3
    } else {
        // TODO: finetune this if needed
        bases.len().ilog2() as usize
    };
    let field_byte_size = Scalar::NUM_LIMBS;
    // OR all coefficients together to find the highest byte that is non-zero in any of them;
    // windows above it would only ever produce zero digits and are not processed.
    let mut acc_or = vec![0; field_byte_size];
    for coeff in &coeffs {
        for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) {
            *acc_limb |= *limb;
        }
    }
    let max_byte_size = field_byte_size
        - acc_or
            .iter()
            .rev()
            .position(|v| *v != 0)
            .unwrap_or(field_byte_size);
    if max_byte_size == 0 {
        // Every scalar is zero.
        return EcPoint::IDENTITY;
    }
    // The extra window absorbs the carry out of the top signed (Booth) digit.
    let number_of_windows = max_byte_size * 8_usize / c + 1;

    let mut acc = EcPoint::IDENTITY;
    for current_window in (0..number_of_windows).rev() {
        // Shift the partial result up by one window. During the most significant window
        // `acc` is still the identity, so doubling it would be wasted work.
        if current_window + 1 < number_of_windows {
            for _ in 0..c {
                acc.double_assign();
            }
        }

        // buckets[k] holds the signed sum of the bases whose digit in this window is
        // +(k + 1) or -(k + 1); `None` marks an empty bucket, so no identity is ever added.
        let mut buckets: Vec<Option<EcPoint>> = vec![None; 1 << (c - 1)];
        for (coeff, base) in coeffs.iter().zip(bases.iter()) {
            let digit = get_booth_index(current_window, c, coeff);
            if digit == 0 {
                continue;
            }
            let slot = &mut buckets[digit.unsigned_abs() as usize - 1];
            if digit > 0 {
                match slot {
                    Some(point) => point.add_assign(base),
                    None => *slot = Some(base.clone()),
                }
            } else {
                match slot {
                    Some(point) => point.sub_assign(base),
                    None => *slot = Some(base.clone().neg()),
                }
            }
        }

        // Summation by parts: adding the running suffix sum once per bucket yields
        // sum_k (k + 1) * buckets[k], e.g.
        // 3a + 2b + 1c = a +
        //                (a) + b +
        //                ((a) + b) + c
        let mut running_sum = EcPoint::IDENTITY;
        for bucket in buckets.into_iter().rev() {
            if let Some(point) = bucket {
                running_sum += point;
            }
            acc = acc.add(&running_sum);
        }
    }
    acc
}
// TODO: benchmark to see if this is faster.
/// Extracts the signed (Booth-encoded) digit of window `window_index` from the
/// little-endian scalar `el`, using windows of `window_size` bits.
///
/// Booth encoding:
/// * windows step by `window_size` bits;
/// * each window reads `window_size + 1` bits, so neighbouring windows overlap by one bit;
/// * a zero bit is implicitly appended below the least significant bit (window 0).
///
/// For `window_size = 3` (4-bit slices), the slice value `0..=15` maps to
/// `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3, -2, -2, -1, -1, 0]`.
///
/// Digits lie in `[-2^(window_size - 1), 2^(window_size - 1)]`. This halves the number of
/// buckets without preprocessing the scalars into a classic signed-digit representation.
/// Bytes past the end of `el` are treated as zero.
fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
    // Lowest bit read by this window. Window 0 starts at bit 0 and gets the implicit
    // zero bit shifted in below it instead.
    let start_bit = (window_index * window_size).saturating_sub(1);
    let (byte_offset, bit_offset) = (start_bit / 8, start_bit % 8);

    // Load up to four bytes starting at `byte_offset`; anything missing stays zero.
    let mut word = [0u8; 4];
    let tail = el.get(byte_offset..).unwrap_or(&[]);
    let len = tail.len().min(word.len());
    word[..len].copy_from_slice(&tail[..len]);
    let raw = u32::from_le_bytes(word);

    // Move the window's lowest bit to bit 0 (window 0 instead gains the zero bit below it).
    let aligned = if window_index == 0 {
        raw << 1
    } else {
        raw >> bit_offset
    };

    // Keep the `window_size + 1` bits belonging to this window.
    let slice = aligned & ((1u32 << (window_size + 1)) - 1);
    let magnitude = (slice + 1) >> 1;
    if (slice & (1u32 << window_size)) == 0 {
        // Top bit clear: non-negative digit.
        magnitude as i32
    } else {
        // Top bit set: the digit is `magnitude - 2^window_size`, within [-2^(window_size - 1), 0].
        magnitude as i32 - (1i32 << window_size)
    }
}