openvm_native_compiler/ir/
var.rs
1use std::array;
2
3use itertools::izip;
4use serde::{Deserialize, Serialize};
5
6use super::{Builder, Config, Ptr, RVar};
7
8pub trait Variable<C: Config>: Clone {
9 type Expression: From<Self>;
10
11 fn uninit(builder: &mut Builder<C>) -> Self;
12
13 fn assign(&self, src: Self::Expression, builder: &mut Builder<C>);
14
15 fn assert_eq(
16 lhs: impl Into<Self::Expression>,
17 rhs: impl Into<Self::Expression>,
18 builder: &mut Builder<C>,
19 );
20
21 fn eval(builder: &mut Builder<C>, expr: impl Into<Self::Expression>) -> Self {
22 let dst: Self = builder.uninit();
23 dst.assign(expr.into(), builder);
24 dst
25 }
26}
27
/// Addressing information passed to [`MemVariable::load`] and [`MemVariable::store`].
///
/// For an array `[T; N]`, element `i` is accessed with the same `index` and
/// `size` as the array, but with `offset` advanced by `i * T::size_of()`. So
/// `offset` is measured in the same units as [`MemVariable::size_of`].
///
/// NOTE(review): how `index`, `offset` and `size` combine into a final address is
/// decided by the builder and the `MemVariable` impls, which are not visible here.
/// Confirm before relying on a particular formula.
// serde derives encode fields in declaration order; do not reorder the fields
// below, or positional encodings of this struct will change.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct MemIndex<N> {
    /// Index component as an `RVar`. Presumably scaled by `size`; TODO confirm.
    pub index: RVar<N>,
    /// Static offset, in `MemVariable::size_of` units. Array elements shift this field.
    pub offset: usize,
    /// Size associated with `index`. Presumably an element stride; TODO confirm.
    pub size: usize,
}
34
/// A [`Variable`] that can be loaded from and stored to memory addressed by a
/// pointer plus a [`MemIndex`].
pub trait MemVariable<C: Config>: Variable<C> {
    /// Footprint of one value of this type, in the same units as [`MemIndex::offset`].
    fn size_of() -> usize;
    /// Loads `self` from the location given by `ptr` and `index`.
    ///
    /// Takes `&self`, so the load is presumably recorded via `builder`; confirm with implementors.
    fn load(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>);
    /// Stores `self` to the location given by `ptr` and `index`, using `builder`.
    fn store(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>);
}
42
/// Types that can be constructed from a constant value using a [`Builder`].
pub trait FromConstant<C: Config> {
    /// The constant value type this variable is built from.
    type Constant;

    /// Creates a new `Self` holding `value`, using `builder`.
    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self;
}
48
49impl<C: Config, T: Variable<C>, const N: usize> Variable<C> for [T; N] {
50 type Expression = [T; N];
51
52 fn uninit(builder: &mut Builder<C>) -> Self {
53 array::from_fn(|_| T::uninit(builder))
54 }
55
56 fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
57 self.iter()
58 .zip(src)
59 .for_each(|(d, s)| d.assign(s.into(), builder));
60 }
61
62 fn assert_eq(
63 lhs: impl Into<Self::Expression>,
64 rhs: impl Into<Self::Expression>,
65 builder: &mut Builder<C>,
66 ) {
67 izip!(lhs.into(), rhs.into()).for_each(|(l, r)| T::assert_eq(l, r, builder));
68 }
69}
70
71impl<C: Config, T: MemVariable<C>, const N: usize> MemVariable<C> for [T; N] {
72 fn size_of() -> usize {
73 N * T::size_of()
74 }
75
76 fn load(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
77 for (i, v) in self.iter().enumerate() {
78 let mut v_idx = index;
79 v_idx.offset += i * T::size_of();
80 v.load(ptr, v_idx, builder);
81 }
82 }
83
84 fn store(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
85 for (i, v) in self.iter().enumerate() {
86 let mut v_idx = index;
87 v_idx.offset += i * T::size_of();
88 v.store(ptr, v_idx, builder);
89 }
90 }
91}