openvm_native_compiler/ir/
ptr.rs

1use core::ops::{Add, Sub};
2
3use openvm_stark_backend::p3_field::{Field, PrimeField};
4use serde::{Deserialize, Serialize};
5
6use super::{Builder, Config, DslIr, MemIndex, MemVariable, RVar, SymbolicVar, Var, Variable};
7
/// A pointer to a location in memory.
///
/// The pointer is a thin, `Copy` wrapper around a single [`Var`] that holds the address.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Ptr<N> {
    /// Variable holding the address this pointer refers to.
    pub address: Var<N>,
}
13
/// A symbolic pointer expression.
///
/// This is the `Expression` type associated with [`Ptr`] and the result type of pointer
/// arithmetic (`+` / `-`) in this module; it wraps a [`SymbolicVar`] address.
pub struct SymbolicPtr<N: Field> {
    /// Symbolic expression for the address.
    pub address: SymbolicVar<N>,
}
17
18impl<C: Config> Builder<C> {
19    /// Allocates an array on the heap.
20    pub fn alloc(&mut self, len: impl Into<RVar<C::N>>, size: usize) -> Ptr<C::N> {
21        assert!(
22            !self.flags.static_only,
23            "Cannot allocate memory in static mode"
24        );
25        let ptr = Ptr::uninit(self);
26        self.push(DslIr::Alloc(ptr, len.into(), size));
27        ptr
28    }
29
30    /// Loads a value from memory.
31    pub fn load<V: MemVariable<C>>(&mut self, var: V, ptr: Ptr<C::N>, index: MemIndex<C::N>) {
32        var.load(ptr, index, self);
33    }
34
35    /// Stores a value to memory.
36    pub fn store<V: MemVariable<C>>(&mut self, ptr: Ptr<C::N>, index: MemIndex<C::N>, value: V) {
37        value.store(ptr, index, self);
38    }
39
40    pub fn load_heap_ptr(&mut self) -> Ptr<C::N> {
41        assert!(
42            !self.flags.static_only,
43            "heap pointer only exists in dynamic mode"
44        );
45        let ptr = Ptr::uninit(self);
46        self.push(DslIr::LoadHeapPtr(ptr));
47        ptr
48    }
49
50    pub fn store_heap_ptr(&mut self, ptr: Ptr<C::N>) {
51        assert!(
52            !self.flags.static_only,
53            "heap pointer only exists in dynamic mode"
54        );
55        self.push(DslIr::StoreHeapPtr(ptr));
56    }
57}
58
59impl<C: Config> Variable<C> for Ptr<C::N> {
60    type Expression = SymbolicPtr<C::N>;
61
62    fn uninit(builder: &mut Builder<C>) -> Self {
63        Ptr {
64            address: Var::uninit(builder),
65        }
66    }
67
68    fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
69        self.address.assign(src.address, builder);
70    }
71
72    fn assert_eq(
73        lhs: impl Into<Self::Expression>,
74        rhs: impl Into<Self::Expression>,
75        builder: &mut Builder<C>,
76    ) {
77        Var::assert_eq(lhs.into().address, rhs.into().address, builder);
78    }
79}
80
81impl<C: Config> MemVariable<C> for Ptr<C::N> {
82    fn size_of() -> usize {
83        1
84    }
85
86    fn load(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
87        self.address.load(ptr, index, builder);
88    }
89
90    fn store(&self, ptr: Ptr<<C as Config>::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
91        self.address.store(ptr, index, builder);
92    }
93}
94
95impl<N: Field> From<Ptr<N>> for SymbolicPtr<N> {
96    fn from(ptr: Ptr<N>) -> Self {
97        SymbolicPtr {
98            address: SymbolicVar::from(ptr.address),
99        }
100    }
101}
102
103impl<N: Field> Add for Ptr<N> {
104    type Output = SymbolicPtr<N>;
105
106    fn add(self, rhs: Self) -> Self::Output {
107        SymbolicPtr {
108            address: self.address + rhs.address,
109        }
110    }
111}
112
113impl<N: Field> Sub for Ptr<N> {
114    type Output = SymbolicPtr<N>;
115
116    fn sub(self, rhs: Self) -> Self::Output {
117        SymbolicPtr {
118            address: self.address - rhs.address,
119        }
120    }
121}
122
123impl<N: Field> Add for SymbolicPtr<N> {
124    type Output = Self;
125
126    fn add(self, rhs: Self) -> Self {
127        Self {
128            address: self.address + rhs.address,
129        }
130    }
131}
132
133impl<N: Field> Sub for SymbolicPtr<N> {
134    type Output = Self;
135
136    fn sub(self, rhs: Self) -> Self {
137        Self {
138            address: self.address - rhs.address,
139        }
140    }
141}
142
143impl<N: Field> Add<Ptr<N>> for SymbolicPtr<N> {
144    type Output = Self;
145
146    fn add(self, rhs: Ptr<N>) -> Self {
147        Self {
148            address: self.address + rhs.address,
149        }
150    }
151}
152
153impl<N: Field> Sub<Ptr<N>> for SymbolicPtr<N> {
154    type Output = Self;
155
156    fn sub(self, rhs: Ptr<N>) -> Self {
157        Self {
158            address: self.address - rhs.address,
159        }
160    }
161}
162
163impl<N: Field> Add<SymbolicPtr<N>> for Ptr<N> {
164    type Output = SymbolicPtr<N>;
165
166    fn add(self, rhs: SymbolicPtr<N>) -> SymbolicPtr<N> {
167        SymbolicPtr {
168            address: self.address + rhs.address,
169        }
170    }
171}
172
173impl<N: Field> Sub<SymbolicPtr<N>> for Ptr<N> {
174    type Output = SymbolicPtr<N>;
175
176    fn sub(self, rhs: SymbolicPtr<N>) -> SymbolicPtr<N> {
177        SymbolicPtr {
178            address: self.address - rhs.address,
179        }
180    }
181}
182
183impl<N: Field, RHS: Into<SymbolicVar<N>>> Add<RHS> for Ptr<N> {
184    type Output = SymbolicPtr<N>;
185
186    fn add(self, rhs: RHS) -> SymbolicPtr<N> {
187        SymbolicPtr::from(self) + rhs.into()
188    }
189}
190
191impl<N: Field, RHS: Into<SymbolicVar<N>>> Add<RHS> for SymbolicPtr<N> {
192    type Output = SymbolicPtr<N>;
193
194    fn add(self, rhs: RHS) -> SymbolicPtr<N> {
195        SymbolicPtr {
196            address: self.address + rhs.into(),
197        }
198    }
199}
200
201impl<N: PrimeField, RHS: Into<SymbolicVar<N>>> Sub<RHS> for Ptr<N> {
202    type Output = SymbolicPtr<N>;
203
204    fn sub(self, rhs: RHS) -> SymbolicPtr<N> {
205        SymbolicPtr::from(self) - rhs.into()
206    }
207}
208
209impl<N: PrimeField, RHS: Into<SymbolicVar<N>>> Sub<RHS> for SymbolicPtr<N> {
210    type Output = SymbolicPtr<N>;
211
212    fn sub(self, rhs: RHS) -> SymbolicPtr<N> {
213        SymbolicPtr {
214            address: self.address - rhs.into(),
215        }
216    }
217}