openvm_native_compiler/ir/
ptr.rs
1use core::ops::{Add, Sub};
2
3use openvm_stark_backend::p3_field::{Field, PrimeField};
4use serde::{Deserialize, Serialize};
5
6use super::{Builder, Config, DslIr, MemIndex, MemVariable, RVar, SymbolicVar, Var, Variable};
7
8#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
10pub struct Ptr<N> {
11 pub address: Var<N>,
12}
13
14pub struct SymbolicPtr<N: Field> {
15 pub address: SymbolicVar<N>,
16}
17
18impl<C: Config> Builder<C> {
19 pub fn alloc(&mut self, len: impl Into<RVar<C::N>>, size: usize) -> Ptr<C::N> {
21 assert!(
22 !self.flags.static_only,
23 "Cannot allocate memory in static mode"
24 );
25 let ptr = Ptr::uninit(self);
26 self.push(DslIr::Alloc(ptr, len.into(), size));
27 ptr
28 }
29
30 pub fn load<V: MemVariable<C>>(&mut self, var: V, ptr: Ptr<C::N>, index: MemIndex<C::N>) {
32 var.load(ptr, index, self);
33 }
34
35 pub fn store<V: MemVariable<C>>(&mut self, ptr: Ptr<C::N>, index: MemIndex<C::N>, value: V) {
37 value.store(ptr, index, self);
38 }
39
40 pub fn load_heap_ptr(&mut self) -> Ptr<C::N> {
41 assert!(
42 !self.flags.static_only,
43 "heap pointer only exists in dynamic mode"
44 );
45 let ptr = Ptr::uninit(self);
46 self.push(DslIr::LoadHeapPtr(ptr));
47 ptr
48 }
49
50 pub fn store_heap_ptr(&mut self, ptr: Ptr<C::N>) {
51 assert!(
52 !self.flags.static_only,
53 "heap pointer only exists in dynamic mode"
54 );
55 self.push(DslIr::StoreHeapPtr(ptr));
56 }
57}
58
59impl<C: Config> Variable<C> for Ptr<C::N> {
60 type Expression = SymbolicPtr<C::N>;
61
62 fn uninit(builder: &mut Builder<C>) -> Self {
63 Ptr {
64 address: Var::uninit(builder),
65 }
66 }
67
68 fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
69 self.address.assign(src.address, builder);
70 }
71
72 fn assert_eq(
73 lhs: impl Into<Self::Expression>,
74 rhs: impl Into<Self::Expression>,
75 builder: &mut Builder<C>,
76 ) {
77 Var::assert_eq(lhs.into().address, rhs.into().address, builder);
78 }
79}
80
81impl<C: Config> MemVariable<C> for Ptr<C::N> {
82 fn size_of() -> usize {
83 1
84 }
85
86 fn load(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
87 self.address.load(ptr, index, builder);
88 }
89
90 fn store(&self, ptr: Ptr<<C as Config>::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
91 self.address.store(ptr, index, builder);
92 }
93}
94
95impl<N: Field> From<Ptr<N>> for SymbolicPtr<N> {
96 fn from(ptr: Ptr<N>) -> Self {
97 SymbolicPtr {
98 address: SymbolicVar::from(ptr.address),
99 }
100 }
101}
102
103impl<N: Field> Add for Ptr<N> {
104 type Output = SymbolicPtr<N>;
105
106 fn add(self, rhs: Self) -> Self::Output {
107 SymbolicPtr {
108 address: self.address + rhs.address,
109 }
110 }
111}
112
113impl<N: Field> Sub for Ptr<N> {
114 type Output = SymbolicPtr<N>;
115
116 fn sub(self, rhs: Self) -> Self::Output {
117 SymbolicPtr {
118 address: self.address - rhs.address,
119 }
120 }
121}
122
123impl<N: Field> Add for SymbolicPtr<N> {
124 type Output = Self;
125
126 fn add(self, rhs: Self) -> Self {
127 Self {
128 address: self.address + rhs.address,
129 }
130 }
131}
132
133impl<N: Field> Sub for SymbolicPtr<N> {
134 type Output = Self;
135
136 fn sub(self, rhs: Self) -> Self {
137 Self {
138 address: self.address - rhs.address,
139 }
140 }
141}
142
143impl<N: Field> Add<Ptr<N>> for SymbolicPtr<N> {
144 type Output = Self;
145
146 fn add(self, rhs: Ptr<N>) -> Self {
147 Self {
148 address: self.address + rhs.address,
149 }
150 }
151}
152
153impl<N: Field> Sub<Ptr<N>> for SymbolicPtr<N> {
154 type Output = Self;
155
156 fn sub(self, rhs: Ptr<N>) -> Self {
157 Self {
158 address: self.address - rhs.address,
159 }
160 }
161}
162
163impl<N: Field> Add<SymbolicPtr<N>> for Ptr<N> {
164 type Output = SymbolicPtr<N>;
165
166 fn add(self, rhs: SymbolicPtr<N>) -> SymbolicPtr<N> {
167 SymbolicPtr {
168 address: self.address + rhs.address,
169 }
170 }
171}
172
173impl<N: Field> Sub<SymbolicPtr<N>> for Ptr<N> {
174 type Output = SymbolicPtr<N>;
175
176 fn sub(self, rhs: SymbolicPtr<N>) -> SymbolicPtr<N> {
177 SymbolicPtr {
178 address: self.address - rhs.address,
179 }
180 }
181}
182
183impl<N: Field, RHS: Into<SymbolicVar<N>>> Add<RHS> for Ptr<N> {
184 type Output = SymbolicPtr<N>;
185
186 fn add(self, rhs: RHS) -> SymbolicPtr<N> {
187 SymbolicPtr::from(self) + rhs.into()
188 }
189}
190
191impl<N: Field, RHS: Into<SymbolicVar<N>>> Add<RHS> for SymbolicPtr<N> {
192 type Output = SymbolicPtr<N>;
193
194 fn add(self, rhs: RHS) -> SymbolicPtr<N> {
195 SymbolicPtr {
196 address: self.address + rhs.into(),
197 }
198 }
199}
200
201impl<N: PrimeField, RHS: Into<SymbolicVar<N>>> Sub<RHS> for Ptr<N> {
202 type Output = SymbolicPtr<N>;
203
204 fn sub(self, rhs: RHS) -> SymbolicPtr<N> {
205 SymbolicPtr::from(self) - rhs.into()
206 }
207}
208
209impl<N: PrimeField, RHS: Into<SymbolicVar<N>>> Sub<RHS> for SymbolicPtr<N> {
210 type Output = SymbolicPtr<N>;
211
212 fn sub(self, rhs: RHS) -> SymbolicPtr<N> {
213 SymbolicPtr {
214 address: self.address - rhs.into(),
215 }
216 }
217}