openvm_native_compiler/ir/var.rs

1use std::array;
2
3use itertools::izip;
4use serde::{Deserialize, Serialize};
5
6use super::{Builder, Config, Ptr, RVar};
7
8pub trait Variable<C: Config>: Clone {
9    type Expression: From<Self>;
10
11    fn uninit(builder: &mut Builder<C>) -> Self;
12
13    fn assign(&self, src: Self::Expression, builder: &mut Builder<C>);
14
15    fn assert_eq(
16        lhs: impl Into<Self::Expression>,
17        rhs: impl Into<Self::Expression>,
18        builder: &mut Builder<C>,
19    );
20
21    fn eval(builder: &mut Builder<C>, expr: impl Into<Self::Expression>) -> Self {
22        let dst: Self = builder.uninit();
23        dst.assign(expr.into(), builder);
24        dst
25    }
26}
27
28#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
29pub struct MemIndex<N> {
30    pub index: RVar<N>,
31    pub offset: usize,
32    pub size: usize,
33}
34
35pub trait MemVariable<C: Config>: Variable<C> {
36    fn size_of() -> usize;
37    /// Loads the variable from the heap.
38    fn load(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>);
39    /// Stores the variable to the heap.
40    fn store(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>);
41}
42
43pub trait FromConstant<C: Config> {
44    type Constant;
45
46    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self;
47}
48
49impl<C: Config, T: Variable<C>, const N: usize> Variable<C> for [T; N] {
50    type Expression = [T; N];
51
52    fn uninit(builder: &mut Builder<C>) -> Self {
53        array::from_fn(|_| T::uninit(builder))
54    }
55
56    fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
57        self.iter()
58            .zip(src)
59            .for_each(|(d, s)| d.assign(s.into(), builder));
60    }
61
62    fn assert_eq(
63        lhs: impl Into<Self::Expression>,
64        rhs: impl Into<Self::Expression>,
65        builder: &mut Builder<C>,
66    ) {
67        izip!(lhs.into(), rhs.into()).for_each(|(l, r)| T::assert_eq(l, r, builder));
68    }
69}
70
71impl<C: Config, T: MemVariable<C>, const N: usize> MemVariable<C> for [T; N] {
72    fn size_of() -> usize {
73        N * T::size_of()
74    }
75
76    fn load(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
77        for (i, v) in self.iter().enumerate() {
78            let mut v_idx = index;
79            v_idx.offset += i * T::size_of();
80            v.load(ptr, v_idx, builder);
81        }
82    }
83
84    fn store(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
85        for (i, v) in self.iter().enumerate() {
86            let mut v_idx = index;
87            v_idx.offset += i * T::size_of();
88            v.store(ptr, v_idx, builder);
89        }
90    }
91}