openvm_native_compiler/ir/
var.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
use std::array;

use itertools::izip;
use serde::{Deserialize, Serialize};

use super::{Builder, Config, Ptr, RVar};

/// A value handle used while building a program with a [`Builder`].
///
/// Every variable converts into its associated [`Variable::Expression`] type
/// (enforced by the `From<Self>` bound), so a variable can be used anywhere an
/// expression of its kind is expected.
pub trait Variable<C: Config>: Clone {
    /// Expression type that can be assigned to, or compared against, this variable.
    type Expression: From<Self>;

    /// Creates a new variable of this type without giving it a value.
    fn uninit(builder: &mut Builder<C>) -> Self;

    /// Assigns the expression `src` to `self` through `builder`.
    fn assign(&self, src: Self::Expression, builder: &mut Builder<C>);

    /// Asserts, through `builder`, that `lhs` and `rhs` are equal.
    fn assert_eq(
        lhs: impl Into<Self::Expression>,
        rhs: impl Into<Self::Expression>,
        builder: &mut Builder<C>,
    );

    /// Asserts, through `builder`, that `lhs` and `rhs` differ.
    fn assert_ne(
        lhs: impl Into<Self::Expression>,
        rhs: impl Into<Self::Expression>,
        builder: &mut Builder<C>,
    );

    /// Materializes `expr` into a freshly created variable and returns it.
    ///
    /// Equivalent to obtaining an uninitialized variable from `builder` and
    /// then calling [`Variable::assign`] on it with `expr`.
    fn eval(builder: &mut Builder<C>, expr: impl Into<Self::Expression>) -> Self {
        let fresh: Self = builder.uninit();
        fresh.assign(expr.into(), builder);
        fresh
    }
}

/// Location of a value relative to a base pointer, passed to
/// [`MemVariable::load`] / [`MemVariable::store`].
///
/// The type is `Copy`, so callers can derive per-element locations by copying a
/// base index and adjusting `offset` (as the `[T; N]` impls in this file do).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct MemIndex<N> {
    /// Runtime/compile-time index component.
    /// NOTE(review): exact semantics of `RVar` indexing are defined elsewhere — confirm.
    pub index: RVar<N>,
    /// Additional static offset, in the same units as `MemVariable::size_of()`
    /// (array impls advance it by `i * T::size_of()` per element).
    pub offset: usize,
    /// Size component of the index.
    /// NOTE(review): not read in this file; presumably the element stride — verify against callers.
    pub size: usize,
}

/// A [`Variable`] that can be moved between variables and memory addressed by a
/// pointer plus a [`MemIndex`].
pub trait MemVariable<C: Config>: Variable<C> {
    /// Number of memory slots this variable occupies, in the same units as
    /// [`MemIndex::offset`].
    fn size_of() -> usize;
    /// Loads the variable from the heap.
    ///
    /// Reads from the location described by `ptr` and `index` into `self`.
    fn load(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>);
    /// Stores the variable to the heap.
    ///
    /// Writes `self` to the location described by `ptr` and `index`.
    fn store(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>);
}

/// Types that can be constructed from a constant value while building a program.
pub trait FromConstant<C: Config> {
    /// The constant (host-side) value type this variable is built from.
    type Constant;

    /// Creates a `Self` holding `value`, using `builder`.
    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self;
}

/// Fixed-size arrays of variables are variables themselves; every operation is
/// applied element-wise.
impl<C: Config, T: Variable<C>, const N: usize> Variable<C> for [T; N] {
    type Expression = [T; N];

    /// Creates `N` element variables via `T::uninit`, in index order.
    fn uninit(builder: &mut Builder<C>) -> Self {
        array::from_fn(|_| T::uninit(builder))
    }

    /// Assigns `src[i]` to `self[i]` for every index `i`.
    fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
        for (dst, value) in self.iter().zip(src) {
            dst.assign(value.into(), builder);
        }
    }

    /// Asserts element-wise equality: `lhs[i] == rhs[i]` for every index `i`.
    fn assert_eq(
        lhs: impl Into<Self::Expression>,
        rhs: impl Into<Self::Expression>,
        builder: &mut Builder<C>,
    ) {
        izip!(lhs.into(), rhs.into()).for_each(|(l, r)| T::assert_eq(l, r, builder));
    }

    /// Not supported for arrays.
    ///
    /// Two arrays differ when *at least one* element differs. That is a
    /// disjunction, and it cannot be expressed as per-element assertions:
    /// asserting `lhs[i] != rhs[i]` for every `i` would be strictly stronger.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn assert_ne(
        _lhs: impl Into<Self::Expression>,
        _rhs: impl Into<Self::Expression>,
        _builder: &mut Builder<C>,
    ) {
        unimplemented!("assert_ne cannot be implemented for arrays")
    }
}

impl<C: Config, T: MemVariable<C>, const N: usize> MemVariable<C> for [T; N] {
    fn size_of() -> usize {
        N * T::size_of()
    }

    fn load(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
        for (i, v) in self.iter().enumerate() {
            let mut v_idx = index;
            v_idx.offset += i * T::size_of();
            v.load(ptr, v_idx, builder);
        }
    }

    fn store(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
        for (i, v) in self.iter().enumerate() {
            let mut v_idx = index;
            v_idx.offset += i * T::size_of();
            v.store(ptr, v_idx, builder);
        }
    }
}