1use crate::{
2 loader::{
3 evm::{
4 code::{Precompiled, SolidityAssemblyCode},
5 fe_to_u256, modulus, u256_to_fe, U256, U512,
6 },
7 EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader,
8 },
9 util::{
10 arithmetic::{CurveAffine, FieldOps, PrimeField},
11 Itertools,
12 },
13};
14use hex;
15use std::{
16 cell::RefCell,
17 collections::HashMap,
18 fmt::{self, Debug},
19 iter,
20 ops::{Add, AddAssign, DerefMut, Mul, MulAssign, Neg, Sub, SubAssign},
21 rc::Rc,
22};
23
/// First memory offset handed out by [`EvmLoader::allocate`].
///
/// Everything below this offset is left untouched by the loader (presumably to respect the
/// 0x80 bytes Solidity reserves for scratch space, the free-memory pointer and the zero slot).
pub const MEM_PTR_START: usize = 0x80;
26
/// Symbolic value tracked by the loader while generating assembly.
#[derive(Clone, Debug)]
pub enum Value<T> {
    /// A value known at code-generation time.
    Constant(T),
    /// A value stored in EVM memory at the given offset.
    Memory(usize),
    /// The additive inverse of the inner value.
    Negated(Box<Value<T>>),
    /// The sum of two values.
    Sum(Box<Value<T>>, Box<Value<T>>),
    /// The product of two values.
    Product(Box<Value<T>>, Box<Value<T>>),
}
35
36impl<T: Debug> PartialEq for Value<T> {
37 fn eq(&self, other: &Self) -> bool {
38 self.identifier() == other.identifier()
39 }
40}
41
42impl<T: Debug> Value<T> {
43 fn identifier(&self) -> String {
44 match self {
45 Value::Constant(_) | Value::Memory(_) => format!("{self:?}"),
46 Value::Negated(value) => format!("-({value:?})"),
47 Value::Sum(lhs, rhs) => format!("({lhs:?} + {rhs:?})"),
48 Value::Product(lhs, rhs) => format!("({lhs:?} * {rhs:?})"),
49 }
50 }
51}
52
53#[derive(Clone, Debug)]
55pub struct EvmLoader {
56 base_modulus: U256,
57 scalar_modulus: U256,
58 code: RefCell<SolidityAssemblyCode>,
59 ptr: RefCell<usize>,
60 cache: RefCell<HashMap<String, usize>>,
61}
62
63fn hex_encode_u256(value: &U256) -> String {
64 format!("0x{}", hex::encode(value.to_be_bytes::<32>()))
65}
66
67impl EvmLoader {
68 pub fn new<Base, Scalar>() -> Rc<Self>
70 where
71 Base: PrimeField<Repr = [u8; 0x20]>,
72 Scalar: PrimeField<Repr = [u8; 32]>,
73 {
74 let base_modulus = modulus::<Base>();
75 let scalar_modulus = modulus::<Scalar>();
76 let code = SolidityAssemblyCode::new();
77
78 Rc::new(Self {
79 base_modulus,
80 scalar_modulus,
81 code: RefCell::new(code),
82 ptr: RefCell::new(MEM_PTR_START),
83 cache: Default::default(),
84 })
85 }
86
87 pub fn solidity_code(self: &Rc<Self>) -> String {
90 let code = "
91 // Revert if anything fails
92 if iszero(success) { revert(0, 0) }
93
94 // Return empty bytes on success
95 return(0, 0)"
96 .to_string();
97 self.code.borrow_mut().runtime_append(code);
98 self.code
99 .borrow()
100 .code(hex_encode_u256(&self.base_modulus), hex_encode_u256(&self.scalar_modulus))
101 }
102
103 pub fn allocate(self: &Rc<Self>, size: usize) -> usize {
105 let ptr = *self.ptr.borrow();
106 *self.ptr.borrow_mut() += size;
107 ptr
108 }
109
110 pub(crate) fn ptr(&self) -> usize {
111 *self.ptr.borrow()
112 }
113
114 pub(crate) fn code_mut(&self) -> impl DerefMut<Target = SolidityAssemblyCode> + '_ {
115 self.code.borrow_mut()
116 }
117
118 fn push(self: &Rc<Self>, scalar: &Scalar) -> String {
119 match scalar.value.clone() {
120 Value::Constant(constant) => {
121 format!("{constant}")
122 }
123 Value::Memory(ptr) => {
124 format!("mload({ptr:#x})")
125 }
126 Value::Negated(value) => {
127 let v = self.push(&self.scalar(*value));
128 format!("sub(f_q, {v})")
129 }
130 Value::Sum(lhs, rhs) => {
131 let lhs = self.push(&self.scalar(*lhs));
132 let rhs = self.push(&self.scalar(*rhs));
133 format!("addmod({lhs}, {rhs}, f_q)")
134 }
135 Value::Product(lhs, rhs) => {
136 let lhs = self.push(&self.scalar(*lhs));
137 let rhs = self.push(&self.scalar(*rhs));
138 format!("mulmod({lhs}, {rhs}, f_q)")
139 }
140 }
141 }
142
143 pub fn calldataload_scalar(self: &Rc<Self>, offset: usize) -> Scalar {
145 let ptr = self.allocate(0x20);
146 let code = format!("mstore({ptr:#x}, mod(calldataload({offset:#x}), f_q))");
147 self.code.borrow_mut().runtime_append(code);
148 self.scalar(Value::Memory(ptr))
149 }
150
151 pub fn calldataload_ec_point(self: &Rc<Self>, offset: usize) -> EcPoint {
154 let x_ptr = self.allocate(0x40);
155 let y_ptr = x_ptr + 0x20;
156 let x_cd_ptr = offset;
157 let y_cd_ptr = offset + 0x20;
158 let validate_code = self.validate_ec_point();
159 let code = format!(
160 "
161 {{
162 let x := calldataload({x_cd_ptr:#x})
163 mstore({x_ptr:#x}, x)
164 let y := calldataload({y_cd_ptr:#x})
165 mstore({y_ptr:#x}, y)
166 {validate_code}
167 }}"
168 );
169 self.code.borrow_mut().runtime_append(code);
170 self.ec_point(Value::Memory(x_ptr))
171 }
172
173 pub fn ec_point_from_limbs<const LIMBS: usize, const BITS: usize>(
175 self: &Rc<Self>,
176 x_limbs: [&Scalar; LIMBS],
177 y_limbs: [&Scalar; LIMBS],
178 ) -> EcPoint {
179 let ptr = self.allocate(0x40);
180 let mut code = String::new();
181 for (idx, limb) in x_limbs.iter().enumerate() {
182 let limb_i = self.push(limb);
183 let shift = idx * BITS;
184 if idx == 0 {
185 code.push_str(format!("let x := {limb_i}\n").as_str());
186 } else {
187 code.push_str(format!("x := add(x, shl({shift}, {limb_i}))\n").as_str());
188 }
189 }
190 let x_ptr = ptr;
191 code.push_str(format!("mstore({x_ptr}, x)\n").as_str());
192 for (idx, limb) in y_limbs.iter().enumerate() {
193 let limb_i = self.push(limb);
194 let shift = idx * BITS;
195 if idx == 0 {
196 code.push_str(format!("let y := {limb_i}\n").as_str());
197 } else {
198 code.push_str(format!("y := add(y, shl({shift}, {limb_i}))\n").as_str());
199 }
200 }
201 let y_ptr = ptr + 0x20;
202 code.push_str(format!("mstore({y_ptr}, y)\n").as_str());
203 let validate_code = self.validate_ec_point();
204 let code = format!(
205 "{{
206 {code}
207 {validate_code}
208 }}"
209 );
210 self.code.borrow_mut().runtime_append(code);
211 self.ec_point(Value::Memory(ptr))
212 }
213
214 fn validate_ec_point(self: &Rc<Self>) -> String {
215 "success := and(validate_ec_point(x, y), success)".to_string()
216 }
217
218 pub(crate) fn scalar(self: &Rc<Self>, value: Value<U256>) -> Scalar {
219 let value = if matches!(value, Value::Constant(_) | Value::Memory(_) | Value::Negated(_)) {
220 value
221 } else {
222 let identifier = value.identifier();
223 let some_ptr = self.cache.borrow().get(&identifier).cloned();
224 let ptr = if let Some(ptr) = some_ptr {
225 ptr
226 } else {
227 let v = self.push(&Scalar { loader: self.clone(), value });
228 let ptr = self.allocate(0x20);
229 self.code.borrow_mut().runtime_append(format!("mstore({ptr:#x}, {v})"));
230 self.cache.borrow_mut().insert(identifier, ptr);
231 ptr
232 };
233 Value::Memory(ptr)
234 };
235 Scalar { loader: self.clone(), value }
236 }
237
238 fn ec_point(self: &Rc<Self>, value: Value<(U256, U256)>) -> EcPoint {
239 EcPoint { loader: self.clone(), value }
240 }
241
242 pub fn keccak256(self: &Rc<Self>, ptr: usize, len: usize) -> usize {
245 let hash_ptr = self.allocate(0x20);
246 let code = format!("mstore({hash_ptr:#x}, keccak256({ptr:#x}, {len}))");
247 self.code.borrow_mut().runtime_append(code);
248 hash_ptr
249 }
250
251 pub fn copy_scalar(self: &Rc<Self>, scalar: &Scalar, ptr: usize) {
253 let scalar = self.push(scalar);
254 self.code.borrow_mut().runtime_append(format!("mstore({ptr:#x}, {scalar})"));
255 }
256
257 pub fn dup_scalar(self: &Rc<Self>, scalar: &Scalar) -> Scalar {
259 let ptr = self.allocate(0x20);
260 self.copy_scalar(scalar, ptr);
261 self.scalar(Value::Memory(ptr))
262 }
263
264 pub fn dup_ec_point(self: &Rc<Self>, value: &EcPoint) -> EcPoint {
266 let ptr = self.allocate(0x40);
267 match value.value {
268 Value::Constant((x, y)) => {
269 let x_ptr = ptr;
270 let y_ptr = ptr + 0x20;
271 let x = hex_encode_u256(&x);
272 let y = hex_encode_u256(&y);
273 let code = format!(
274 "mstore({x_ptr:#x}, {x})
275 mstore({y_ptr:#x}, {y})"
276 );
277 self.code.borrow_mut().runtime_append(code);
278 }
279 Value::Memory(src_ptr) => {
280 let x_ptr = ptr;
281 let y_ptr = ptr + 0x20;
282 let src_x = src_ptr;
283 let src_y = src_ptr + 0x20;
284 let code = format!(
285 "mstore({x_ptr:#x}, mload({src_x:#x}))
286 mstore({y_ptr:#x}, mload({src_y:#x}))"
287 );
288 self.code.borrow_mut().runtime_append(code);
289 }
290 Value::Negated(_) | Value::Sum(_, _) | Value::Product(_, _) => {
291 unreachable!()
292 }
293 }
294 self.ec_point(Value::Memory(ptr))
295 }
296
297 fn staticcall(self: &Rc<Self>, precompile: Precompiled, cd_ptr: usize, rd_ptr: usize) {
298 let (cd_len, rd_len) = match precompile {
299 Precompiled::BigModExp => (0xc0, 0x20),
300 Precompiled::Bn254Add => (0x80, 0x40),
301 Precompiled::Bn254ScalarMul => (0x60, 0x40),
302 Precompiled::Bn254Pairing => (0x180, 0x20),
303 };
304 let a = precompile as usize;
305 let code = format!("success := and(eq(staticcall(gas(), {a:#x}, {cd_ptr:#x}, {cd_len:#x}, {rd_ptr:#x}, {rd_len:#x}), 1), success)");
306 self.code.borrow_mut().runtime_append(code);
307 }
308
309 fn invert(self: &Rc<Self>, scalar: &Scalar) -> Scalar {
310 let rd_ptr = self.allocate(0x20);
311 let [cd_ptr, ..] = [
312 &self.scalar(Value::Constant(U256::from(0x20))),
313 &self.scalar(Value::Constant(U256::from(0x20))),
314 &self.scalar(Value::Constant(U256::from(0x20))),
315 scalar,
316 &self.scalar(Value::Constant(self.scalar_modulus - U256::from(2))),
317 &self.scalar(Value::Constant(self.scalar_modulus)),
318 ]
319 .map(|value| self.dup_scalar(value).ptr());
320 self.staticcall(Precompiled::BigModExp, cd_ptr, rd_ptr);
321 self.scalar(Value::Memory(rd_ptr))
322 }
323
324 fn ec_point_add(self: &Rc<Self>, lhs: &EcPoint, rhs: &EcPoint) -> EcPoint {
325 let rd_ptr = self.dup_ec_point(lhs).ptr();
326 self.dup_ec_point(rhs);
327 self.staticcall(Precompiled::Bn254Add, rd_ptr, rd_ptr);
328 self.ec_point(Value::Memory(rd_ptr))
329 }
330
331 fn ec_point_scalar_mul(self: &Rc<Self>, ec_point: &EcPoint, scalar: &Scalar) -> EcPoint {
332 let rd_ptr = self.dup_ec_point(ec_point).ptr();
333 self.dup_scalar(scalar);
334 self.staticcall(Precompiled::Bn254ScalarMul, rd_ptr, rd_ptr);
335 self.ec_point(Value::Memory(rd_ptr))
336 }
337
338 pub fn pairing(
340 self: &Rc<Self>,
341 lhs: &EcPoint,
342 g2: (U256, U256, U256, U256),
343 rhs: &EcPoint,
344 minus_s_g2: (U256, U256, U256, U256),
345 ) {
346 let rd_ptr = self.dup_ec_point(lhs).ptr();
347 self.allocate(0x80);
348 let g2_0 = hex_encode_u256(&g2.0);
349 let g2_0_ptr = rd_ptr + 0x40;
350 let g2_1 = hex_encode_u256(&g2.1);
351 let g2_1_ptr = rd_ptr + 0x60;
352 let g2_2 = hex_encode_u256(&g2.2);
353 let g2_2_ptr = rd_ptr + 0x80;
354 let g2_3 = hex_encode_u256(&g2.3);
355 let g2_3_ptr = rd_ptr + 0xa0;
356 let code = format!(
357 "mstore({g2_0_ptr:#x}, {g2_0})
358 mstore({g2_1_ptr:#x}, {g2_1})
359 mstore({g2_2_ptr:#x}, {g2_2})
360 mstore({g2_3_ptr:#x}, {g2_3})"
361 );
362 self.code.borrow_mut().runtime_append(code);
363 self.dup_ec_point(rhs);
364 self.allocate(0x80);
365 let minus_s_g2_0 = hex_encode_u256(&minus_s_g2.0);
366 let minus_s_g2_0_ptr = rd_ptr + 0x100;
367 let minus_s_g2_1 = hex_encode_u256(&minus_s_g2.1);
368 let minus_s_g2_1_ptr = rd_ptr + 0x120;
369 let minus_s_g2_2 = hex_encode_u256(&minus_s_g2.2);
370 let minus_s_g2_2_ptr = rd_ptr + 0x140;
371 let minus_s_g2_3 = hex_encode_u256(&minus_s_g2.3);
372 let minus_s_g2_3_ptr = rd_ptr + 0x160;
373 let code = format!(
374 "mstore({minus_s_g2_0_ptr:#x}, {minus_s_g2_0})
375 mstore({minus_s_g2_1_ptr:#x}, {minus_s_g2_1})
376 mstore({minus_s_g2_2_ptr:#x}, {minus_s_g2_2})
377 mstore({minus_s_g2_3_ptr:#x}, {minus_s_g2_3})"
378 );
379 self.code.borrow_mut().runtime_append(code);
380 self.staticcall(Precompiled::Bn254Pairing, rd_ptr, rd_ptr);
381 let code = format!("success := and(eq(mload({rd_ptr:#x}), 1), success)");
382 self.code.borrow_mut().runtime_append(code);
383 }
384
385 fn add(self: &Rc<Self>, lhs: &Scalar, rhs: &Scalar) -> Scalar {
386 if let (Value::Constant(lhs), Value::Constant(rhs)) = (&lhs.value, &rhs.value) {
387 let out = (U512::from(*lhs) + U512::from(*rhs)) % U512::from(self.scalar_modulus);
388 return self.scalar(Value::Constant(U256::from(out)));
389 }
390
391 self.scalar(Value::Sum(Box::new(lhs.value.clone()), Box::new(rhs.value.clone())))
392 }
393
394 fn sub(self: &Rc<Self>, lhs: &Scalar, rhs: &Scalar) -> Scalar {
395 if rhs.is_const() {
396 return self.add(lhs, &self.neg(rhs));
397 }
398
399 self.scalar(Value::Sum(
400 Box::new(lhs.value.clone()),
401 Box::new(Value::Negated(Box::new(rhs.value.clone()))),
402 ))
403 }
404
405 fn mul(self: &Rc<Self>, lhs: &Scalar, rhs: &Scalar) -> Scalar {
406 if let (Value::Constant(lhs), Value::Constant(rhs)) = (&lhs.value, &rhs.value) {
407 let out = (U512::from(*lhs) * U512::from(*rhs)) % U512::from(self.scalar_modulus);
408 return self.scalar(Value::Constant(U256::from(out)));
409 }
410
411 self.scalar(Value::Product(Box::new(lhs.value.clone()), Box::new(rhs.value.clone())))
412 }
413
414 fn neg(self: &Rc<Self>, scalar: &Scalar) -> Scalar {
415 if let Value::Constant(constant) = scalar.value {
416 return self.scalar(Value::Constant(self.scalar_modulus - constant));
417 }
418
419 self.scalar(Value::Negated(Box::new(scalar.value.clone())))
420 }
421}
422
#[cfg(test)]
impl EvmLoader {
    /// Hook backing `Loader::start_cost_metering` in tests; this loader does not meter gas.
    fn start_gas_metering(self: &Rc<Self>, _: &str) {}

    /// Hook backing `Loader::end_cost_metering` in tests; also a no-op.
    fn end_gas_metering(self: &Rc<Self>) {}

    /// Placeholder for reporting metered gas. It is a no-op with no callers, hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    fn print_gas_metering(self: &Rc<Self>, _: Vec<u64>) {}
}
438
439#[derive(Clone)]
441pub struct EcPoint {
442 loader: Rc<EvmLoader>,
443 value: Value<(U256, U256)>,
444}
445
446impl EcPoint {
447 pub(crate) fn loader(&self) -> &Rc<EvmLoader> {
448 &self.loader
449 }
450
451 pub(crate) fn value(&self) -> Value<(U256, U256)> {
452 self.value.clone()
453 }
454
455 pub(crate) fn ptr(&self) -> usize {
456 match self.value {
457 Value::Memory(ptr) => ptr,
458 _ => unreachable!(),
459 }
460 }
461}
462
463impl Debug for EcPoint {
464 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
465 f.debug_struct("EcPoint").field("value", &self.value).finish()
466 }
467}
468
469impl PartialEq for EcPoint {
470 fn eq(&self, other: &Self) -> bool {
471 self.value == other.value
472 }
473}
474
475impl<C> LoadedEcPoint<C> for EcPoint
476where
477 C: CurveAffine,
478 C::ScalarExt: PrimeField<Repr = [u8; 0x20]>,
479{
480 type Loader = Rc<EvmLoader>;
481
482 fn loader(&self) -> &Rc<EvmLoader> {
483 &self.loader
484 }
485}
486
487#[derive(Clone)]
489pub struct Scalar {
490 loader: Rc<EvmLoader>,
491 value: Value<U256>,
492}
493
494impl Scalar {
495 pub(crate) fn loader(&self) -> &Rc<EvmLoader> {
496 &self.loader
497 }
498
499 pub(crate) fn value(&self) -> Value<U256> {
500 self.value.clone()
501 }
502
503 pub(crate) fn is_const(&self) -> bool {
504 matches!(self.value, Value::Constant(_))
505 }
506
507 pub(crate) fn ptr(&self) -> usize {
508 match self.value {
509 Value::Memory(ptr) => ptr,
510 _ => *self.loader.cache.borrow().get(&self.value.identifier()).unwrap(),
511 }
512 }
513}
514
515impl Debug for Scalar {
516 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
517 f.debug_struct("Scalar").field("value", &self.value).finish()
518 }
519}
520
521impl Add for Scalar {
522 type Output = Self;
523
524 fn add(self, rhs: Self) -> Self {
525 self.loader.add(&self, &rhs)
526 }
527}
528
529impl Sub for Scalar {
530 type Output = Self;
531
532 fn sub(self, rhs: Self) -> Self {
533 self.loader.sub(&self, &rhs)
534 }
535}
536
537impl Mul for Scalar {
538 type Output = Self;
539
540 fn mul(self, rhs: Self) -> Self {
541 self.loader.mul(&self, &rhs)
542 }
543}
544
545impl Neg for Scalar {
546 type Output = Self;
547
548 fn neg(self) -> Self {
549 self.loader.neg(&self)
550 }
551}
552
553impl<'a> Add<&'a Self> for Scalar {
554 type Output = Self;
555
556 fn add(self, rhs: &'a Self) -> Self {
557 self.loader.add(&self, rhs)
558 }
559}
560
561impl<'a> Sub<&'a Self> for Scalar {
562 type Output = Self;
563
564 fn sub(self, rhs: &'a Self) -> Self {
565 self.loader.sub(&self, rhs)
566 }
567}
568
569impl<'a> Mul<&'a Self> for Scalar {
570 type Output = Self;
571
572 fn mul(self, rhs: &'a Self) -> Self {
573 self.loader.mul(&self, rhs)
574 }
575}
576
577impl AddAssign for Scalar {
578 fn add_assign(&mut self, rhs: Self) {
579 *self = self.loader.add(self, &rhs);
580 }
581}
582
583impl SubAssign for Scalar {
584 fn sub_assign(&mut self, rhs: Self) {
585 *self = self.loader.sub(self, &rhs);
586 }
587}
588
589impl MulAssign for Scalar {
590 fn mul_assign(&mut self, rhs: Self) {
591 *self = self.loader.mul(self, &rhs);
592 }
593}
594
595impl<'a> AddAssign<&'a Self> for Scalar {
596 fn add_assign(&mut self, rhs: &'a Self) {
597 *self = self.loader.add(self, rhs);
598 }
599}
600
601impl<'a> SubAssign<&'a Self> for Scalar {
602 fn sub_assign(&mut self, rhs: &'a Self) {
603 *self = self.loader.sub(self, rhs);
604 }
605}
606
607impl<'a> MulAssign<&'a Self> for Scalar {
608 fn mul_assign(&mut self, rhs: &'a Self) {
609 *self = self.loader.mul(self, rhs);
610 }
611}
612
613impl FieldOps for Scalar {
614 fn invert(&self) -> Option<Scalar> {
615 Some(self.loader.invert(self))
616 }
617}
618
619impl PartialEq for Scalar {
620 fn eq(&self, other: &Self) -> bool {
621 self.value == other.value
622 }
623}
624
625impl<F: PrimeField<Repr = [u8; 0x20]>> LoadedScalar<F> for Scalar {
626 type Loader = Rc<EvmLoader>;
627
628 fn loader(&self) -> &Self::Loader {
629 &self.loader
630 }
631
632 fn pow_var(&self, _exp: &Self, _exp_max_bits: usize) -> Self {
633 todo!()
634 }
635}
636
637impl<C> EcPointLoader<C> for Rc<EvmLoader>
638where
639 C: CurveAffine,
640 C::Scalar: PrimeField<Repr = [u8; 0x20]>,
641{
642 type LoadedEcPoint = EcPoint;
643
644 fn ec_point_load_const(&self, value: &C) -> EcPoint {
645 let coordinates = value.coordinates().unwrap();
646 let [x, y] = [coordinates.x(), coordinates.y()]
647 .map(|coordinate| U256::try_from_le_slice(coordinate.to_repr().as_ref()).unwrap());
648 self.ec_point(Value::Constant((x, y)))
649 }
650
651 fn ec_point_assert_eq(&self, _: &str, _: &EcPoint, _: &EcPoint) {
652 unimplemented!()
653 }
654
655 fn multi_scalar_multiplication(
656 pairs: &[(&<Self as ScalarLoader<C::Scalar>>::LoadedScalar, &EcPoint)],
657 ) -> EcPoint {
658 pairs
659 .iter()
660 .cloned()
661 .map(|(scalar, ec_point)| match scalar.value {
662 Value::Constant(constant) if U256::from(1) == constant => ec_point.clone(),
663 _ => ec_point.loader.ec_point_scalar_mul(ec_point, scalar),
664 })
665 .reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point))
666 .expect("pairs should not be empty")
667 }
668}
669
670impl<F: PrimeField<Repr = [u8; 0x20]>> ScalarLoader<F> for Rc<EvmLoader> {
671 type LoadedScalar = Scalar;
672
673 fn load_const(&self, value: &F) -> Scalar {
674 self.scalar(Value::Constant(fe_to_u256(*value)))
675 }
676
677 fn assert_eq(&self, _: &str, _: &Scalar, _: &Scalar) {
678 unimplemented!()
679 }
680
681 fn sum_with_coeff_and_const(&self, values: &[(F, &Scalar)], constant: F) -> Scalar {
682 if values.is_empty() {
683 return self.load_const(&constant);
684 }
685
686 let push_addend = |(coeff, value): &(F, &Scalar)| {
687 assert_ne!(*coeff, F::ZERO);
688 match (*coeff == F::ONE, &value.value) {
689 (true, _) => self.push(value),
690 (false, Value::Constant(value)) => self.push(
691 &self.scalar(Value::Constant(fe_to_u256(*coeff * u256_to_fe::<F>(*value)))),
692 ),
693 (false, _) => {
694 let value = self.push(value);
695 let coeff = self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff))));
696 format!("mulmod({value}, {coeff}, f_q)")
697 }
698 }
699 };
700
701 let mut values = values.iter();
702 let initial_value = if constant == F::ZERO {
703 push_addend(values.next().unwrap())
704 } else {
705 self.push(&self.scalar(Value::Constant(fe_to_u256(constant))))
706 };
707
708 let mut code = format!("let result := {initial_value}\n");
709 for value in values {
710 let v = push_addend(value);
711 let addend = format!("result := addmod({v}, result, f_q)\n");
712 code.push_str(addend.as_str());
713 }
714
715 let ptr = self.allocate(0x20);
716 code.push_str(format!("mstore({ptr}, result)").as_str());
717 self.code.borrow_mut().runtime_append(format!(
718 "{{
719 {code}
720 }}"
721 ));
722
723 self.scalar(Value::Memory(ptr))
724 }
725
726 fn sum_products_with_coeff_and_const(
727 &self,
728 values: &[(F, &Scalar, &Scalar)],
729 constant: F,
730 ) -> Scalar {
731 if values.is_empty() {
732 return self.load_const(&constant);
733 }
734
735 let push_addend = |(coeff, lhs, rhs): &(F, &Scalar, &Scalar)| {
736 assert_ne!(*coeff, F::ZERO);
737 match (*coeff == F::ONE, &lhs.value, &rhs.value) {
738 (_, Value::Constant(lhs), Value::Constant(rhs)) => {
739 self.push(&self.scalar(Value::Constant(fe_to_u256(
740 *coeff * u256_to_fe::<F>(*lhs) * u256_to_fe::<F>(*rhs),
741 ))))
742 }
743 (_, value @ Value::Memory(_), Value::Constant(constant))
744 | (_, Value::Constant(constant), value @ Value::Memory(_)) => {
745 let v1 = self.push(&self.scalar(value.clone()));
746 let v2 =
747 self.push(&self.scalar(Value::Constant(fe_to_u256(
748 *coeff * u256_to_fe::<F>(*constant),
749 ))));
750 format!("mulmod({v1}, {v2}, f_q)")
751 }
752 (true, _, _) => {
753 let rhs = self.push(rhs);
754 let lhs = self.push(lhs);
755 format!("mulmod({rhs}, {lhs}, f_q)")
756 }
757 (false, _, _) => {
758 let rhs = self.push(rhs);
759 let lhs = self.push(lhs);
760 let value = self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff))));
761 format!("mulmod({rhs}, mulmod({lhs}, {value}, f_q), f_q)")
762 }
763 }
764 };
765
766 let mut values = values.iter();
767 let initial_value = if constant == F::ZERO {
768 push_addend(values.next().unwrap())
769 } else {
770 self.push(&self.scalar(Value::Constant(fe_to_u256(constant))))
771 };
772
773 let mut code = format!("let result := {initial_value}\n");
774 for value in values {
775 let v = push_addend(value);
776 let addend = format!("result := addmod({v}, result, f_q)\n");
777 code.push_str(addend.as_str());
778 }
779
780 let ptr = self.allocate(0x20);
781 code.push_str(format!("mstore({ptr}, result)").as_str());
782 self.code.borrow_mut().runtime_append(format!(
783 "{{
784 {code}
785 }}"
786 ));
787
788 self.scalar(Value::Memory(ptr))
789 }
790
791 fn batch_invert<'a>(values: impl IntoIterator<Item = &'a mut Scalar>) {
801 let values = values.into_iter().collect_vec();
802 let loader = &values.first().unwrap().loader;
803 let products = iter::once(values[0].clone())
804 .chain(
805 iter::repeat_with(|| loader.allocate(0x20))
806 .map(|ptr| loader.scalar(Value::Memory(ptr)))
807 .take(values.len() - 1),
808 )
809 .collect_vec();
810
811 let initial_value = loader.push(products.first().unwrap());
812 let mut code = format!("let prod := {initial_value}\n");
813 for (value, product) in values.iter().zip(products.iter()).skip(1) {
814 let v = loader.push(value);
815 let ptr = product.ptr();
816 code.push_str(
817 format!(
818 "
819 prod := mulmod({v}, prod, f_q)
820 mstore({ptr:#x}, prod)
821 "
822 )
823 .as_str(),
824 );
825 }
826 loader.code.borrow_mut().runtime_append(format!(
827 "{{
828 {code}
829 }}"
830 ));
831
832 let inv = loader.push(&loader.invert(products.last().unwrap()));
833
834 let mut code = format!(
835 "
836 let inv := {inv}
837 let v
838 "
839 );
840 for (value, product) in
841 values.iter().rev().zip(products.iter().rev().skip(1).map(Some).chain(iter::once(None)))
842 {
843 if let Some(product) = product {
844 let val_ptr = value.ptr();
845 let prod_ptr = product.ptr();
846 let v = loader.push(value);
847 code.push_str(
848 format!(
849 "
850 v := {v}
851 mstore({val_ptr}, mulmod(mload({prod_ptr:#x}), inv, f_q))
852 inv := mulmod(v, inv, f_q)
853 "
854 )
855 .as_str(),
856 );
857 } else {
858 let ptr = value.ptr();
859 code.push_str(format!("mstore({ptr:#x}, inv)\n").as_str());
860 }
861 }
862 loader.code.borrow_mut().runtime_append(format!(
863 "{{
864 {code}
865 }}"
866 ));
867 }
868}
869
870impl<C> Loader<C> for Rc<EvmLoader>
871where
872 C: CurveAffine,
873 C::Scalar: PrimeField<Repr = [u8; 0x20]>,
874{
875 #[cfg(test)]
876 fn start_cost_metering(&self, identifier: &str) {
877 self.start_gas_metering(identifier)
878 }
879
880 #[cfg(test)]
881 fn end_cost_metering(&self) {
882 self.end_gas_metering()
883 }
884}