// snark_verifier/loader/evm/loader.rs

1use crate::{
2    loader::{
3        evm::{
4            code::{Precompiled, SolidityAssemblyCode},
5            fe_to_u256, modulus, u256_to_fe, U256, U512,
6        },
7        EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader,
8    },
9    util::{
10        arithmetic::{CurveAffine, FieldOps, PrimeField},
11        Itertools,
12    },
13};
14use hex;
15use std::{
16    cell::RefCell,
17    collections::HashMap,
18    fmt::{self, Debug},
19    iter,
20    ops::{Add, AddAssign, DerefMut, Mul, MulAssign, Neg, Sub, SubAssign},
21    rc::Rc,
22};
23
/// First memory offset handed out by the loader's bump allocator.
///
/// Solidity reserves the first four 32-byte words of memory (scratch space at 0x00-0x3f,
/// the free memory pointer at 0x40 and the zero slot at 0x60), so allocation starts at 0x80.
pub const MEM_PTR_START: usize = 4 * 0x20;
26
/// Symbolic value tracked while generating Yul.
///
/// Values are either known at generation time (`Constant`), live in EVM memory
/// (`Memory`), or are unevaluated expressions built from other values.
#[derive(Clone, Debug)]
pub enum Value<T> {
    /// A value known at code-generation time.
    Constant(T),
    /// A word stored in EVM memory at the given byte offset.
    Memory(usize),
    /// The additive inverse of the inner value.
    Negated(Box<Value<T>>),
    /// The sum of the two inner values.
    Sum(Box<Value<T>>, Box<Value<T>>),
    /// The product of the two inner values.
    Product(Box<Value<T>>, Box<Value<T>>),
}
35
36impl<T: Debug> PartialEq for Value<T> {
37    fn eq(&self, other: &Self) -> bool {
38        self.identifier() == other.identifier()
39    }
40}
41
42impl<T: Debug> Value<T> {
43    fn identifier(&self) -> String {
44        match self {
45            Value::Constant(_) | Value::Memory(_) => format!("{self:?}"),
46            Value::Negated(value) => format!("-({value:?})"),
47            Value::Sum(lhs, rhs) => format!("({lhs:?} + {rhs:?})"),
48            Value::Product(lhs, rhs) => format!("({lhs:?} * {rhs:?})"),
49        }
50    }
51}
52
53/// `Loader` implementation for generating yul code as EVM verifier.
54#[derive(Clone, Debug)]
55pub struct EvmLoader {
56    base_modulus: U256,
57    scalar_modulus: U256,
58    code: RefCell<SolidityAssemblyCode>,
59    ptr: RefCell<usize>,
60    cache: RefCell<HashMap<String, usize>>,
61}
62
63fn hex_encode_u256(value: &U256) -> String {
64    format!("0x{}", hex::encode(value.to_be_bytes::<32>()))
65}
66
67impl EvmLoader {
68    /// Initialize a [`EvmLoader`] with base and scalar field.
69    pub fn new<Base, Scalar>() -> Rc<Self>
70    where
71        Base: PrimeField<Repr = [u8; 0x20]>,
72        Scalar: PrimeField<Repr = [u8; 32]>,
73    {
74        let base_modulus = modulus::<Base>();
75        let scalar_modulus = modulus::<Scalar>();
76        let code = SolidityAssemblyCode::new();
77
78        Rc::new(Self {
79            base_modulus,
80            scalar_modulus,
81            code: RefCell::new(code),
82            ptr: RefCell::new(MEM_PTR_START),
83            cache: Default::default(),
84        })
85    }
86
87    /// Returns generated Solidity code. This is "Solidity" code that is wrapped in an assembly block.
88    /// In other words, it's basically just assembly (equivalently, Yul).
89    pub fn solidity_code(self: &Rc<Self>) -> String {
90        let code = "
91            // Revert if anything fails
92            if iszero(success) { revert(0, 0) }
93
94            // Return empty bytes on success
95            return(0, 0)"
96            .to_string();
97        self.code.borrow_mut().runtime_append(code);
98        self.code
99            .borrow()
100            .code(hex_encode_u256(&self.base_modulus), hex_encode_u256(&self.scalar_modulus))
101    }
102
103    /// Allocates memory chunk with given `size` and returns pointer.
104    pub fn allocate(self: &Rc<Self>, size: usize) -> usize {
105        let ptr = *self.ptr.borrow();
106        *self.ptr.borrow_mut() += size;
107        ptr
108    }
109
110    pub(crate) fn ptr(&self) -> usize {
111        *self.ptr.borrow()
112    }
113
114    pub(crate) fn code_mut(&self) -> impl DerefMut<Target = SolidityAssemblyCode> + '_ {
115        self.code.borrow_mut()
116    }
117
118    fn push(self: &Rc<Self>, scalar: &Scalar) -> String {
119        match scalar.value.clone() {
120            Value::Constant(constant) => {
121                format!("{constant}")
122            }
123            Value::Memory(ptr) => {
124                format!("mload({ptr:#x})")
125            }
126            Value::Negated(value) => {
127                let v = self.push(&self.scalar(*value));
128                format!("sub(f_q, {v})")
129            }
130            Value::Sum(lhs, rhs) => {
131                let lhs = self.push(&self.scalar(*lhs));
132                let rhs = self.push(&self.scalar(*rhs));
133                format!("addmod({lhs}, {rhs}, f_q)")
134            }
135            Value::Product(lhs, rhs) => {
136                let lhs = self.push(&self.scalar(*lhs));
137                let rhs = self.push(&self.scalar(*rhs));
138                format!("mulmod({lhs}, {rhs}, f_q)")
139            }
140        }
141    }
142
143    /// Calldata load a field element.
144    pub fn calldataload_scalar(self: &Rc<Self>, offset: usize) -> Scalar {
145        let ptr = self.allocate(0x20);
146        let code = format!("mstore({ptr:#x}, mod(calldataload({offset:#x}), f_q))");
147        self.code.borrow_mut().runtime_append(code);
148        self.scalar(Value::Memory(ptr))
149    }
150
151    /// Calldata load an elliptic curve point and validate it's on affine plane.
152    /// Note that identity will cause the verification to fail.
153    pub fn calldataload_ec_point(self: &Rc<Self>, offset: usize) -> EcPoint {
154        let x_ptr = self.allocate(0x40);
155        let y_ptr = x_ptr + 0x20;
156        let x_cd_ptr = offset;
157        let y_cd_ptr = offset + 0x20;
158        let validate_code = self.validate_ec_point();
159        let code = format!(
160            "
161        {{
162            let x := calldataload({x_cd_ptr:#x})
163            mstore({x_ptr:#x}, x)
164            let y := calldataload({y_cd_ptr:#x})
165            mstore({y_ptr:#x}, y)
166            {validate_code}
167        }}"
168        );
169        self.code.borrow_mut().runtime_append(code);
170        self.ec_point(Value::Memory(x_ptr))
171    }
172
173    /// Decode an elliptic curve point from limbs.
174    pub fn ec_point_from_limbs<const LIMBS: usize, const BITS: usize>(
175        self: &Rc<Self>,
176        x_limbs: [&Scalar; LIMBS],
177        y_limbs: [&Scalar; LIMBS],
178    ) -> EcPoint {
179        let ptr = self.allocate(0x40);
180        let mut code = String::new();
181        for (idx, limb) in x_limbs.iter().enumerate() {
182            let limb_i = self.push(limb);
183            let shift = idx * BITS;
184            if idx == 0 {
185                code.push_str(format!("let x := {limb_i}\n").as_str());
186            } else {
187                code.push_str(format!("x := add(x, shl({shift}, {limb_i}))\n").as_str());
188            }
189        }
190        let x_ptr = ptr;
191        code.push_str(format!("mstore({x_ptr}, x)\n").as_str());
192        for (idx, limb) in y_limbs.iter().enumerate() {
193            let limb_i = self.push(limb);
194            let shift = idx * BITS;
195            if idx == 0 {
196                code.push_str(format!("let y := {limb_i}\n").as_str());
197            } else {
198                code.push_str(format!("y := add(y, shl({shift}, {limb_i}))\n").as_str());
199            }
200        }
201        let y_ptr = ptr + 0x20;
202        code.push_str(format!("mstore({y_ptr}, y)\n").as_str());
203        let validate_code = self.validate_ec_point();
204        let code = format!(
205            "{{
206            {code}
207            {validate_code}
208        }}"
209        );
210        self.code.borrow_mut().runtime_append(code);
211        self.ec_point(Value::Memory(ptr))
212    }
213
214    fn validate_ec_point(self: &Rc<Self>) -> String {
215        "success := and(validate_ec_point(x, y), success)".to_string()
216    }
217
218    pub(crate) fn scalar(self: &Rc<Self>, value: Value<U256>) -> Scalar {
219        let value = if matches!(value, Value::Constant(_) | Value::Memory(_) | Value::Negated(_)) {
220            value
221        } else {
222            let identifier = value.identifier();
223            let some_ptr = self.cache.borrow().get(&identifier).cloned();
224            let ptr = if let Some(ptr) = some_ptr {
225                ptr
226            } else {
227                let v = self.push(&Scalar { loader: self.clone(), value });
228                let ptr = self.allocate(0x20);
229                self.code.borrow_mut().runtime_append(format!("mstore({ptr:#x}, {v})"));
230                self.cache.borrow_mut().insert(identifier, ptr);
231                ptr
232            };
233            Value::Memory(ptr)
234        };
235        Scalar { loader: self.clone(), value }
236    }
237
238    fn ec_point(self: &Rc<Self>, value: Value<(U256, U256)>) -> EcPoint {
239        EcPoint { loader: self.clone(), value }
240    }
241
242    /// Performs `KECCAK256` on `memory[ptr..ptr+len]` and returns pointer of
243    /// hash.
244    pub fn keccak256(self: &Rc<Self>, ptr: usize, len: usize) -> usize {
245        let hash_ptr = self.allocate(0x20);
246        let code = format!("mstore({hash_ptr:#x}, keccak256({ptr:#x}, {len}))");
247        self.code.borrow_mut().runtime_append(code);
248        hash_ptr
249    }
250
251    /// Copies a field element into given `ptr`.
252    pub fn copy_scalar(self: &Rc<Self>, scalar: &Scalar, ptr: usize) {
253        let scalar = self.push(scalar);
254        self.code.borrow_mut().runtime_append(format!("mstore({ptr:#x}, {scalar})"));
255    }
256
257    /// Allocates a new field element and copies the given value into it.
258    pub fn dup_scalar(self: &Rc<Self>, scalar: &Scalar) -> Scalar {
259        let ptr = self.allocate(0x20);
260        self.copy_scalar(scalar, ptr);
261        self.scalar(Value::Memory(ptr))
262    }
263
264    /// Allocates a new elliptic curve point and copies the given value into it.
265    pub fn dup_ec_point(self: &Rc<Self>, value: &EcPoint) -> EcPoint {
266        let ptr = self.allocate(0x40);
267        match value.value {
268            Value::Constant((x, y)) => {
269                let x_ptr = ptr;
270                let y_ptr = ptr + 0x20;
271                let x = hex_encode_u256(&x);
272                let y = hex_encode_u256(&y);
273                let code = format!(
274                    "mstore({x_ptr:#x}, {x})
275                    mstore({y_ptr:#x}, {y})"
276                );
277                self.code.borrow_mut().runtime_append(code);
278            }
279            Value::Memory(src_ptr) => {
280                let x_ptr = ptr;
281                let y_ptr = ptr + 0x20;
282                let src_x = src_ptr;
283                let src_y = src_ptr + 0x20;
284                let code = format!(
285                    "mstore({x_ptr:#x}, mload({src_x:#x}))
286                    mstore({y_ptr:#x}, mload({src_y:#x}))"
287                );
288                self.code.borrow_mut().runtime_append(code);
289            }
290            Value::Negated(_) | Value::Sum(_, _) | Value::Product(_, _) => {
291                unreachable!()
292            }
293        }
294        self.ec_point(Value::Memory(ptr))
295    }
296
297    fn staticcall(self: &Rc<Self>, precompile: Precompiled, cd_ptr: usize, rd_ptr: usize) {
298        let (cd_len, rd_len) = match precompile {
299            Precompiled::BigModExp => (0xc0, 0x20),
300            Precompiled::Bn254Add => (0x80, 0x40),
301            Precompiled::Bn254ScalarMul => (0x60, 0x40),
302            Precompiled::Bn254Pairing => (0x180, 0x20),
303        };
304        let a = precompile as usize;
305        let code = format!("success := and(eq(staticcall(gas(), {a:#x}, {cd_ptr:#x}, {cd_len:#x}, {rd_ptr:#x}, {rd_len:#x}), 1), success)");
306        self.code.borrow_mut().runtime_append(code);
307    }
308
309    fn invert(self: &Rc<Self>, scalar: &Scalar) -> Scalar {
310        let rd_ptr = self.allocate(0x20);
311        let [cd_ptr, ..] = [
312            &self.scalar(Value::Constant(U256::from(0x20))),
313            &self.scalar(Value::Constant(U256::from(0x20))),
314            &self.scalar(Value::Constant(U256::from(0x20))),
315            scalar,
316            &self.scalar(Value::Constant(self.scalar_modulus - U256::from(2))),
317            &self.scalar(Value::Constant(self.scalar_modulus)),
318        ]
319        .map(|value| self.dup_scalar(value).ptr());
320        self.staticcall(Precompiled::BigModExp, cd_ptr, rd_ptr);
321        self.scalar(Value::Memory(rd_ptr))
322    }
323
324    fn ec_point_add(self: &Rc<Self>, lhs: &EcPoint, rhs: &EcPoint) -> EcPoint {
325        let rd_ptr = self.dup_ec_point(lhs).ptr();
326        self.dup_ec_point(rhs);
327        self.staticcall(Precompiled::Bn254Add, rd_ptr, rd_ptr);
328        self.ec_point(Value::Memory(rd_ptr))
329    }
330
331    fn ec_point_scalar_mul(self: &Rc<Self>, ec_point: &EcPoint, scalar: &Scalar) -> EcPoint {
332        let rd_ptr = self.dup_ec_point(ec_point).ptr();
333        self.dup_scalar(scalar);
334        self.staticcall(Precompiled::Bn254ScalarMul, rd_ptr, rd_ptr);
335        self.ec_point(Value::Memory(rd_ptr))
336    }
337
338    /// Performs pairing.
339    pub fn pairing(
340        self: &Rc<Self>,
341        lhs: &EcPoint,
342        g2: (U256, U256, U256, U256),
343        rhs: &EcPoint,
344        minus_s_g2: (U256, U256, U256, U256),
345    ) {
346        let rd_ptr = self.dup_ec_point(lhs).ptr();
347        self.allocate(0x80);
348        let g2_0 = hex_encode_u256(&g2.0);
349        let g2_0_ptr = rd_ptr + 0x40;
350        let g2_1 = hex_encode_u256(&g2.1);
351        let g2_1_ptr = rd_ptr + 0x60;
352        let g2_2 = hex_encode_u256(&g2.2);
353        let g2_2_ptr = rd_ptr + 0x80;
354        let g2_3 = hex_encode_u256(&g2.3);
355        let g2_3_ptr = rd_ptr + 0xa0;
356        let code = format!(
357            "mstore({g2_0_ptr:#x}, {g2_0})
358            mstore({g2_1_ptr:#x}, {g2_1})
359            mstore({g2_2_ptr:#x}, {g2_2})
360            mstore({g2_3_ptr:#x}, {g2_3})"
361        );
362        self.code.borrow_mut().runtime_append(code);
363        self.dup_ec_point(rhs);
364        self.allocate(0x80);
365        let minus_s_g2_0 = hex_encode_u256(&minus_s_g2.0);
366        let minus_s_g2_0_ptr = rd_ptr + 0x100;
367        let minus_s_g2_1 = hex_encode_u256(&minus_s_g2.1);
368        let minus_s_g2_1_ptr = rd_ptr + 0x120;
369        let minus_s_g2_2 = hex_encode_u256(&minus_s_g2.2);
370        let minus_s_g2_2_ptr = rd_ptr + 0x140;
371        let minus_s_g2_3 = hex_encode_u256(&minus_s_g2.3);
372        let minus_s_g2_3_ptr = rd_ptr + 0x160;
373        let code = format!(
374            "mstore({minus_s_g2_0_ptr:#x}, {minus_s_g2_0})
375            mstore({minus_s_g2_1_ptr:#x}, {minus_s_g2_1})
376            mstore({minus_s_g2_2_ptr:#x}, {minus_s_g2_2})
377            mstore({minus_s_g2_3_ptr:#x}, {minus_s_g2_3})"
378        );
379        self.code.borrow_mut().runtime_append(code);
380        self.staticcall(Precompiled::Bn254Pairing, rd_ptr, rd_ptr);
381        let code = format!("success := and(eq(mload({rd_ptr:#x}), 1), success)");
382        self.code.borrow_mut().runtime_append(code);
383    }
384
385    fn add(self: &Rc<Self>, lhs: &Scalar, rhs: &Scalar) -> Scalar {
386        if let (Value::Constant(lhs), Value::Constant(rhs)) = (&lhs.value, &rhs.value) {
387            let out = (U512::from(*lhs) + U512::from(*rhs)) % U512::from(self.scalar_modulus);
388            return self.scalar(Value::Constant(U256::from(out)));
389        }
390
391        self.scalar(Value::Sum(Box::new(lhs.value.clone()), Box::new(rhs.value.clone())))
392    }
393
394    fn sub(self: &Rc<Self>, lhs: &Scalar, rhs: &Scalar) -> Scalar {
395        if rhs.is_const() {
396            return self.add(lhs, &self.neg(rhs));
397        }
398
399        self.scalar(Value::Sum(
400            Box::new(lhs.value.clone()),
401            Box::new(Value::Negated(Box::new(rhs.value.clone()))),
402        ))
403    }
404
405    fn mul(self: &Rc<Self>, lhs: &Scalar, rhs: &Scalar) -> Scalar {
406        if let (Value::Constant(lhs), Value::Constant(rhs)) = (&lhs.value, &rhs.value) {
407            let out = (U512::from(*lhs) * U512::from(*rhs)) % U512::from(self.scalar_modulus);
408            return self.scalar(Value::Constant(U256::from(out)));
409        }
410
411        self.scalar(Value::Product(Box::new(lhs.value.clone()), Box::new(rhs.value.clone())))
412    }
413
414    fn neg(self: &Rc<Self>, scalar: &Scalar) -> Scalar {
415        if let Value::Constant(constant) = scalar.value {
416            return self.scalar(Value::Constant(self.scalar_modulus - constant));
417        }
418
419        self.scalar(Value::Negated(Box::new(scalar.value.clone())))
420    }
421}
422
#[cfg(test)]
impl EvmLoader {
    /// Gas metering hook used by cost-metering tests; currently a no-op.
    fn start_gas_metering(self: &Rc<Self>, _identifier: &str) {}

    /// Gas metering hook used by cost-metering tests; currently a no-op.
    fn end_gas_metering(self: &Rc<Self>) {}

    /// Placeholder for reporting metered gas; currently a no-op.
    #[allow(dead_code)] // kept for parity with other loaders' metering helpers
    fn print_gas_metering(self: &Rc<Self>, _costs: Vec<u64>) {}
}
438
439/// Elliptic curve point.
440#[derive(Clone)]
441pub struct EcPoint {
442    loader: Rc<EvmLoader>,
443    value: Value<(U256, U256)>,
444}
445
446impl EcPoint {
447    pub(crate) fn loader(&self) -> &Rc<EvmLoader> {
448        &self.loader
449    }
450
451    pub(crate) fn value(&self) -> Value<(U256, U256)> {
452        self.value.clone()
453    }
454
455    pub(crate) fn ptr(&self) -> usize {
456        match self.value {
457            Value::Memory(ptr) => ptr,
458            _ => unreachable!(),
459        }
460    }
461}
462
463impl Debug for EcPoint {
464    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
465        f.debug_struct("EcPoint").field("value", &self.value).finish()
466    }
467}
468
469impl PartialEq for EcPoint {
470    fn eq(&self, other: &Self) -> bool {
471        self.value == other.value
472    }
473}
474
475impl<C> LoadedEcPoint<C> for EcPoint
476where
477    C: CurveAffine,
478    C::ScalarExt: PrimeField<Repr = [u8; 0x20]>,
479{
480    type Loader = Rc<EvmLoader>;
481
482    fn loader(&self) -> &Rc<EvmLoader> {
483        &self.loader
484    }
485}
486
487/// Field element.
488#[derive(Clone)]
489pub struct Scalar {
490    loader: Rc<EvmLoader>,
491    value: Value<U256>,
492}
493
494impl Scalar {
495    pub(crate) fn loader(&self) -> &Rc<EvmLoader> {
496        &self.loader
497    }
498
499    pub(crate) fn value(&self) -> Value<U256> {
500        self.value.clone()
501    }
502
503    pub(crate) fn is_const(&self) -> bool {
504        matches!(self.value, Value::Constant(_))
505    }
506
507    pub(crate) fn ptr(&self) -> usize {
508        match self.value {
509            Value::Memory(ptr) => ptr,
510            _ => *self.loader.cache.borrow().get(&self.value.identifier()).unwrap(),
511        }
512    }
513}
514
515impl Debug for Scalar {
516    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
517        f.debug_struct("Scalar").field("value", &self.value).finish()
518    }
519}
520
521impl Add for Scalar {
522    type Output = Self;
523
524    fn add(self, rhs: Self) -> Self {
525        self.loader.add(&self, &rhs)
526    }
527}
528
529impl Sub for Scalar {
530    type Output = Self;
531
532    fn sub(self, rhs: Self) -> Self {
533        self.loader.sub(&self, &rhs)
534    }
535}
536
537impl Mul for Scalar {
538    type Output = Self;
539
540    fn mul(self, rhs: Self) -> Self {
541        self.loader.mul(&self, &rhs)
542    }
543}
544
545impl Neg for Scalar {
546    type Output = Self;
547
548    fn neg(self) -> Self {
549        self.loader.neg(&self)
550    }
551}
552
553impl<'a> Add<&'a Self> for Scalar {
554    type Output = Self;
555
556    fn add(self, rhs: &'a Self) -> Self {
557        self.loader.add(&self, rhs)
558    }
559}
560
561impl<'a> Sub<&'a Self> for Scalar {
562    type Output = Self;
563
564    fn sub(self, rhs: &'a Self) -> Self {
565        self.loader.sub(&self, rhs)
566    }
567}
568
569impl<'a> Mul<&'a Self> for Scalar {
570    type Output = Self;
571
572    fn mul(self, rhs: &'a Self) -> Self {
573        self.loader.mul(&self, rhs)
574    }
575}
576
577impl AddAssign for Scalar {
578    fn add_assign(&mut self, rhs: Self) {
579        *self = self.loader.add(self, &rhs);
580    }
581}
582
583impl SubAssign for Scalar {
584    fn sub_assign(&mut self, rhs: Self) {
585        *self = self.loader.sub(self, &rhs);
586    }
587}
588
589impl MulAssign for Scalar {
590    fn mul_assign(&mut self, rhs: Self) {
591        *self = self.loader.mul(self, &rhs);
592    }
593}
594
595impl<'a> AddAssign<&'a Self> for Scalar {
596    fn add_assign(&mut self, rhs: &'a Self) {
597        *self = self.loader.add(self, rhs);
598    }
599}
600
601impl<'a> SubAssign<&'a Self> for Scalar {
602    fn sub_assign(&mut self, rhs: &'a Self) {
603        *self = self.loader.sub(self, rhs);
604    }
605}
606
607impl<'a> MulAssign<&'a Self> for Scalar {
608    fn mul_assign(&mut self, rhs: &'a Self) {
609        *self = self.loader.mul(self, rhs);
610    }
611}
612
613impl FieldOps for Scalar {
614    fn invert(&self) -> Option<Scalar> {
615        Some(self.loader.invert(self))
616    }
617}
618
619impl PartialEq for Scalar {
620    fn eq(&self, other: &Self) -> bool {
621        self.value == other.value
622    }
623}
624
625impl<F: PrimeField<Repr = [u8; 0x20]>> LoadedScalar<F> for Scalar {
626    type Loader = Rc<EvmLoader>;
627
628    fn loader(&self) -> &Self::Loader {
629        &self.loader
630    }
631
632    fn pow_var(&self, _exp: &Self, _exp_max_bits: usize) -> Self {
633        todo!()
634    }
635}
636
637impl<C> EcPointLoader<C> for Rc<EvmLoader>
638where
639    C: CurveAffine,
640    C::Scalar: PrimeField<Repr = [u8; 0x20]>,
641{
642    type LoadedEcPoint = EcPoint;
643
644    fn ec_point_load_const(&self, value: &C) -> EcPoint {
645        let coordinates = value.coordinates().unwrap();
646        let [x, y] = [coordinates.x(), coordinates.y()]
647            .map(|coordinate| U256::try_from_le_slice(coordinate.to_repr().as_ref()).unwrap());
648        self.ec_point(Value::Constant((x, y)))
649    }
650
651    fn ec_point_assert_eq(&self, _: &str, _: &EcPoint, _: &EcPoint) {
652        unimplemented!()
653    }
654
655    fn multi_scalar_multiplication(
656        pairs: &[(&<Self as ScalarLoader<C::Scalar>>::LoadedScalar, &EcPoint)],
657    ) -> EcPoint {
658        pairs
659            .iter()
660            .cloned()
661            .map(|(scalar, ec_point)| match scalar.value {
662                Value::Constant(constant) if U256::from(1) == constant => ec_point.clone(),
663                _ => ec_point.loader.ec_point_scalar_mul(ec_point, scalar),
664            })
665            .reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point))
666            .expect("pairs should not be empty")
667    }
668}
669
670impl<F: PrimeField<Repr = [u8; 0x20]>> ScalarLoader<F> for Rc<EvmLoader> {
671    type LoadedScalar = Scalar;
672
673    fn load_const(&self, value: &F) -> Scalar {
674        self.scalar(Value::Constant(fe_to_u256(*value)))
675    }
676
677    fn assert_eq(&self, _: &str, _: &Scalar, _: &Scalar) {
678        unimplemented!()
679    }
680
681    fn sum_with_coeff_and_const(&self, values: &[(F, &Scalar)], constant: F) -> Scalar {
682        if values.is_empty() {
683            return self.load_const(&constant);
684        }
685
686        let push_addend = |(coeff, value): &(F, &Scalar)| {
687            assert_ne!(*coeff, F::ZERO);
688            match (*coeff == F::ONE, &value.value) {
689                (true, _) => self.push(value),
690                (false, Value::Constant(value)) => self.push(
691                    &self.scalar(Value::Constant(fe_to_u256(*coeff * u256_to_fe::<F>(*value)))),
692                ),
693                (false, _) => {
694                    let value = self.push(value);
695                    let coeff = self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff))));
696                    format!("mulmod({value}, {coeff}, f_q)")
697                }
698            }
699        };
700
701        let mut values = values.iter();
702        let initial_value = if constant == F::ZERO {
703            push_addend(values.next().unwrap())
704        } else {
705            self.push(&self.scalar(Value::Constant(fe_to_u256(constant))))
706        };
707
708        let mut code = format!("let result := {initial_value}\n");
709        for value in values {
710            let v = push_addend(value);
711            let addend = format!("result := addmod({v}, result, f_q)\n");
712            code.push_str(addend.as_str());
713        }
714
715        let ptr = self.allocate(0x20);
716        code.push_str(format!("mstore({ptr}, result)").as_str());
717        self.code.borrow_mut().runtime_append(format!(
718            "{{
719            {code}
720        }}"
721        ));
722
723        self.scalar(Value::Memory(ptr))
724    }
725
726    fn sum_products_with_coeff_and_const(
727        &self,
728        values: &[(F, &Scalar, &Scalar)],
729        constant: F,
730    ) -> Scalar {
731        if values.is_empty() {
732            return self.load_const(&constant);
733        }
734
735        let push_addend = |(coeff, lhs, rhs): &(F, &Scalar, &Scalar)| {
736            assert_ne!(*coeff, F::ZERO);
737            match (*coeff == F::ONE, &lhs.value, &rhs.value) {
738                (_, Value::Constant(lhs), Value::Constant(rhs)) => {
739                    self.push(&self.scalar(Value::Constant(fe_to_u256(
740                        *coeff * u256_to_fe::<F>(*lhs) * u256_to_fe::<F>(*rhs),
741                    ))))
742                }
743                (_, value @ Value::Memory(_), Value::Constant(constant))
744                | (_, Value::Constant(constant), value @ Value::Memory(_)) => {
745                    let v1 = self.push(&self.scalar(value.clone()));
746                    let v2 =
747                        self.push(&self.scalar(Value::Constant(fe_to_u256(
748                            *coeff * u256_to_fe::<F>(*constant),
749                        ))));
750                    format!("mulmod({v1}, {v2}, f_q)")
751                }
752                (true, _, _) => {
753                    let rhs = self.push(rhs);
754                    let lhs = self.push(lhs);
755                    format!("mulmod({rhs}, {lhs}, f_q)")
756                }
757                (false, _, _) => {
758                    let rhs = self.push(rhs);
759                    let lhs = self.push(lhs);
760                    let value = self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff))));
761                    format!("mulmod({rhs}, mulmod({lhs}, {value}, f_q), f_q)")
762                }
763            }
764        };
765
766        let mut values = values.iter();
767        let initial_value = if constant == F::ZERO {
768            push_addend(values.next().unwrap())
769        } else {
770            self.push(&self.scalar(Value::Constant(fe_to_u256(constant))))
771        };
772
773        let mut code = format!("let result := {initial_value}\n");
774        for value in values {
775            let v = push_addend(value);
776            let addend = format!("result := addmod({v}, result, f_q)\n");
777            code.push_str(addend.as_str());
778        }
779
780        let ptr = self.allocate(0x20);
781        code.push_str(format!("mstore({ptr}, result)").as_str());
782        self.code.borrow_mut().runtime_append(format!(
783            "{{
784            {code}
785        }}"
786        ));
787
788        self.scalar(Value::Memory(ptr))
789    }
790
791    // batch_invert algorithm
792    // n := values.len() - 1
793    // input : values[0], ..., values[n]
794    // output : values[0]^{-1}, ..., values[n]^{-1}
795    // 1. products[i] <- values[0] * ... * values[i], i = 1, ..., n
796    // 2. inv <- (products[n])^{-1}
797    // 3. v_n <- values[n]
798    // 4. values[n] <- products[n - 1] * inv (values[n]^{-1})
799    // 5. inv <- v_n * inv
800    fn batch_invert<'a>(values: impl IntoIterator<Item = &'a mut Scalar>) {
801        let values = values.into_iter().collect_vec();
802        let loader = &values.first().unwrap().loader;
803        let products = iter::once(values[0].clone())
804            .chain(
805                iter::repeat_with(|| loader.allocate(0x20))
806                    .map(|ptr| loader.scalar(Value::Memory(ptr)))
807                    .take(values.len() - 1),
808            )
809            .collect_vec();
810
811        let initial_value = loader.push(products.first().unwrap());
812        let mut code = format!("let prod := {initial_value}\n");
813        for (value, product) in values.iter().zip(products.iter()).skip(1) {
814            let v = loader.push(value);
815            let ptr = product.ptr();
816            code.push_str(
817                format!(
818                    "
819                prod := mulmod({v}, prod, f_q)
820                mstore({ptr:#x}, prod)
821            "
822                )
823                .as_str(),
824            );
825        }
826        loader.code.borrow_mut().runtime_append(format!(
827            "{{
828            {code}
829        }}"
830        ));
831
832        let inv = loader.push(&loader.invert(products.last().unwrap()));
833
834        let mut code = format!(
835            "
836            let inv := {inv}
837            let v
838        "
839        );
840        for (value, product) in
841            values.iter().rev().zip(products.iter().rev().skip(1).map(Some).chain(iter::once(None)))
842        {
843            if let Some(product) = product {
844                let val_ptr = value.ptr();
845                let prod_ptr = product.ptr();
846                let v = loader.push(value);
847                code.push_str(
848                    format!(
849                        "
850                    v := {v}
851                    mstore({val_ptr}, mulmod(mload({prod_ptr:#x}), inv, f_q))
852                    inv := mulmod(v, inv, f_q)
853                "
854                    )
855                    .as_str(),
856                );
857            } else {
858                let ptr = value.ptr();
859                code.push_str(format!("mstore({ptr:#x}, inv)\n").as_str());
860            }
861        }
862        loader.code.borrow_mut().runtime_append(format!(
863            "{{
864            {code}
865        }}"
866        ));
867    }
868}
869
870impl<C> Loader<C> for Rc<EvmLoader>
871where
872    C: CurveAffine,
873    C::Scalar: PrimeField<Repr = [u8; 0x20]>,
874{
875    #[cfg(test)]
876    fn start_cost_metering(&self, identifier: &str) {
877        self.start_gas_metering(identifier)
878    }
879
880    #[cfg(test)]
881    fn end_cost_metering(&self) {
882        self.end_gas_metering()
883    }
884}