// openvm_native_compiler/ir/collections.rs

1use alloc::rc::Rc;
2use std::cell::RefCell;
3
4use itertools::Itertools;
5use serde::{Deserialize, Serialize};
6
7use super::{
8    Builder, Config, FromConstant, MemIndex, MemVariable, Ptr, RVar, Usize, Var, Variable,
9};
10
11/// A logical array.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub enum Array<C: Config, T> {
14    /// Array of some local variables or constants, which can only be manipulated statically. It
15    /// only exists in the DSL syntax and isn't backed by memory.
16    Fixed(Rc<RefCell<Vec<Option<T>>>>),
17    /// Array on heap. Index access can use variables. Length could be determined on runtime but
18    /// cannot change after initialization.
19    Dyn(Ptr<C::N>, Usize<C::N>),
20}
21
22impl<C: Config, V: MemVariable<C>> Array<C, V> {
23    /// Gets a right value of the array.
24    pub fn vec(&self) -> Vec<V> {
25        match self {
26            Self::Fixed(vec) => vec.borrow().iter().map(|x| x.clone().unwrap()).collect(),
27            _ => panic!("array is dynamic, not fixed"),
28        }
29    }
30
31    pub fn ptr(&self) -> Ptr<C::N> {
32        match *self {
33            Array::Dyn(ptr, _) => ptr,
34            Array::Fixed(_) => panic!("cannot retrieve pointer for a compile-time array"),
35        }
36    }
37
38    /// Gets the length of the array as a variable inside the DSL.
39    pub fn len(&self) -> Usize<C::N> {
40        match self {
41            Self::Fixed(vec) => Usize::from(vec.borrow().len()),
42            Self::Dyn(_, len) => len.clone(),
43        }
44    }
45
46    /// Shifts the array by `shift` elements.
47    /// !Attention!: the behavior of `Fixed` and `Dyn` is different. For Dyn, the shift is a view
48    /// and shares memory with the original. For `Fixed`, `set`/`set_value` on slices won't impact
49    /// the original array.
50    pub fn shift(&self, builder: &mut Builder<C>, shift: impl Into<RVar<C::N>>) -> Array<C, V> {
51        match self {
52            Self::Fixed(v) => {
53                let shift = shift.into();
54                if let RVar::Const(_) = shift {
55                    let shift = shift.value();
56                    Array::Fixed(Rc::new(RefCell::new(v.borrow()[shift..].to_vec())))
57                } else {
58                    panic!("Cannot shift a fixed array with a variable shift");
59                }
60            }
61            Self::Dyn(ptr, len) => {
62                let len = RVar::from(len.clone());
63                let shift = shift.into();
64                let new_ptr = builder.eval(*ptr + shift * RVar::from(V::size_of()));
65                let new_length = builder.eval(len - shift);
66                Array::Dyn(new_ptr, Usize::Var(new_length))
67            }
68        }
69    }
70
71    /// Truncates the array to `len` elements.
72    pub fn truncate(&self, builder: &mut Builder<C>, len: Usize<C::N>) {
73        match self {
74            Self::Fixed(v) => {
75                let len = len.value();
76                v.borrow_mut().truncate(len);
77            }
78            Self::Dyn(_, old_len) => {
79                builder.assign(old_len, len);
80            }
81        };
82    }
83
84    /// Slices the array from `start` to `end`.
85    /// !Attention!: the behavior of `Fixed` and `Dyn` is different. For Dyn, the shift is a view
86    /// and shares memory with the original. For `Fixed`, `set`/`set_value` on slices won't impact
87    /// the original array.
88    pub fn slice(
89        &self,
90        builder: &mut Builder<C>,
91        start: impl Into<RVar<C::N>>,
92        end: impl Into<RVar<C::N>>,
93    ) -> Array<C, V> {
94        let start = start.into();
95        let end = end.into();
96        match self {
97            Self::Fixed(v) => {
98                if let (RVar::Const(_), RVar::Const(_)) = (&start, &end) {
99                    Array::Fixed(Rc::new(RefCell::new(
100                        v.borrow()[start.value()..end.value()].to_vec(),
101                    )))
102                } else {
103                    panic!("Cannot slice a fixed array with a variable start or end");
104                }
105            }
106            Self::Dyn(ptr, _) => {
107                let slice_len = builder.eval(end - start);
108                let address = builder.eval(ptr.address + start * RVar::from(V::size_of()));
109                let ptr = Ptr { address };
110                Array::Dyn(ptr, Usize::Var(slice_len))
111            }
112        }
113    }
114}
115
116impl<C: Config> Builder<C> {
117    /// Initialize an array of fixed length `len`. The entries will be uninitialized.
118    pub fn array<V: MemVariable<C>>(&mut self, len: impl Into<RVar<C::N>>) -> Array<C, V> {
119        let len = len.into();
120        if self.flags.static_only {
121            self.uninit_fixed_array(len.value())
122        } else {
123            self.dyn_array(len)
124        }
125    }
126
127    /// Creates an array from a vector.
128    pub fn vec<V: MemVariable<C>>(&mut self, v: Vec<V>) -> Array<C, V> {
129        Array::Fixed(Rc::new(RefCell::new(
130            v.into_iter().map(|x| Some(x)).collect(),
131        )))
132    }
133
134    /// Create an uninitialized Array::Fixed.
135    pub fn uninit_fixed_array<V: Variable<C>>(&mut self, len: usize) -> Array<C, V> {
136        Array::Fixed(Rc::new(RefCell::new(vec![None::<V>; len])))
137    }
138
139    /// Creates a dynamic array for a length.
140    pub fn dyn_array<V: MemVariable<C>>(&mut self, len: impl Into<RVar<C::N>>) -> Array<C, V> {
141        let len: Var<_> = self.eval(len.into());
142        let ptr = self.alloc(len, V::size_of());
143        Array::Dyn(ptr, Usize::Var(len))
144    }
145
146    /// Reads an element from an array.
147    ///
148    /// For `Array::Dyn`, this emits an unchecked load at `index`.
149    ///
150    /// # Safety
151    ///
152    /// When `slice` is `Array::Dyn`, the caller must have already established that
153    /// `index < slice.len()`. Otherwise this may read invalid memory and compromise soundness.
154    ///
155    /// When `slice` is `Array::Fixed`, `index` must be a constant in bounds, and the element must
156    /// have been initialized.
157    pub fn get<V: MemVariable<C>, I: Into<RVar<C::N>>>(
158        &mut self,
159        slice: &Array<C, V>,
160        index: I,
161    ) -> V {
162        let index = index.into();
163
164        match slice {
165            Array::Fixed(slice) => {
166                if let RVar::Const(_) = index {
167                    let idx = index.value();
168                    if let Some(ele) = &slice.borrow()[idx] {
169                        ele.clone()
170                    } else {
171                        panic!("Cannot get an uninitialized element in a fixed slice");
172                    }
173                } else {
174                    panic!("Cannot index into a fixed slice with a variable size")
175                }
176            }
177            Array::Dyn(ptr, _) => {
178                let index = MemIndex {
179                    index,
180                    offset: 0,
181                    size: V::size_of(),
182                };
183                let var: V = self.uninit();
184                self.load(var.clone(), *ptr, index);
185                var
186            }
187        }
188    }
189
190    /// Intended to be used with `ptr` from `zip`. Assumes that:
191    /// - if `slice` is `Array::Fixed`, then `ptr` is a constant index in [0, slice.len()).
192    /// - if `slice` is `Array::Dyn`, then `ptr` is a variable iterator over the entries of `slice`.
193    ///
194    /// In both cases, loads and returns the corresponding element of `slice`.
195    pub fn iter_ptr_get<V: MemVariable<C>>(&mut self, slice: &Array<C, V>, ptr: RVar<C::N>) -> V {
196        match slice {
197            Array::Fixed(v) => {
198                if let RVar::Const(_) = ptr {
199                    let idx = ptr.value();
200                    v.borrow()[idx].clone().unwrap()
201                } else {
202                    panic!("Cannot index into a fixed slice with a variable index")
203                }
204            }
205            Array::Dyn(_, _) => {
206                let index = MemIndex {
207                    index: 0.into(),
208                    offset: 0,
209                    size: V::size_of(),
210                };
211                let var: V = self.uninit();
212                self.load(
213                    var.clone(),
214                    Ptr {
215                        address: match ptr {
216                            RVar::Const(_) => panic!(
217                                "iter_ptr_get on dynamic array not supported for constant ptr"
218                            ),
219                            RVar::Val(v) => v,
220                        },
221                    },
222                    index,
223                );
224                var
225            }
226        }
227    }
228
229    /// Intended to be used with `ptr` from `zip`. Assumes that:
230    /// - if `slice` is `Array::Fixed`, then `ptr` is a constant index in [0, slice.len()).
231    /// - if `slice` is `Array::Dyn`, then `ptr` is a variable iterator over the entries of `slice`.
232    ///
233    /// In both cases, stores the given `value` at the corresponding element of `slice`.
234    pub fn iter_ptr_set<V: MemVariable<C>, Expr: Into<V::Expression>>(
235        &mut self,
236        slice: &Array<C, V>,
237        ptr: RVar<C::N>,
238        value: Expr,
239    ) {
240        match slice {
241            Array::Fixed(v) => {
242                if let RVar::Const(_) = ptr {
243                    let idx = ptr.value();
244                    let value = self.eval(value);
245                    v.borrow_mut()[idx] = Some(value);
246                } else {
247                    panic!("Cannot index into a fixed slice with a variable index")
248                }
249            }
250            Array::Dyn(_, _) => {
251                let value: V = self.eval(value);
252                self.store(
253                    Ptr {
254                        address: match ptr {
255                            RVar::Const(_) => panic!(
256                                "iter_ptr_set on dynamic array not supported for constant ptr"
257                            ),
258                            RVar::Val(v) => v,
259                        },
260                    },
261                    MemIndex {
262                        index: 0.into(),
263                        offset: 0,
264                        size: V::size_of(),
265                    },
266                    value,
267                );
268            }
269        }
270    }
271
272    pub fn set<V: MemVariable<C>, I: Into<RVar<C::N>>, Expr: Into<V::Expression>>(
273        &mut self,
274        slice: &Array<C, V>,
275        index: I,
276        value: Expr,
277    ) {
278        let index = index.into();
279
280        match slice {
281            Array::Fixed(v) => {
282                if let RVar::Const(_) = index {
283                    let idx = index.value();
284                    let value = self.eval(value);
285                    v.borrow_mut()[idx] = Some(value);
286                } else {
287                    panic!("Cannot index into a fixed slice with a variable index")
288                }
289            }
290            Array::Dyn(ptr, _) => {
291                let index = MemIndex {
292                    index,
293                    offset: 0,
294                    size: V::size_of(),
295                };
296                let value: V = self.eval(value);
297                self.store(*ptr, index, value);
298            }
299        }
300    }
301
302    pub fn set_value<V: MemVariable<C>, I: Into<RVar<C::N>>>(
303        &mut self,
304        slice: &Array<C, V>,
305        index: I,
306        value: V,
307    ) {
308        let index = index.into();
309
310        match slice {
311            Array::Fixed(v) => {
312                if let RVar::Const(_) = index {
313                    let idx = index.value();
314                    v.borrow_mut()[idx] = Some(value);
315                } else {
316                    panic!("Cannot index into a fixed slice with a variable size")
317                }
318            }
319            Array::Dyn(ptr, _) => {
320                let index = MemIndex {
321                    index,
322                    offset: 0,
323                    size: V::size_of(),
324                };
325                self.store(*ptr, index, value);
326            }
327        }
328    }
329}
330
331impl<C: Config, T: MemVariable<C>> Variable<C> for Array<C, T> {
332    type Expression = Self;
333
334    fn uninit(builder: &mut Builder<C>) -> Self {
335        Array::Dyn(builder.uninit(), builder.uninit())
336    }
337
338    fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
339        match (self, src.clone()) {
340            (Array::Dyn(lhs_ptr, lhs_len), Array::Dyn(rhs_ptr, rhs_len)) => {
341                builder.assign(lhs_ptr, rhs_ptr);
342                builder.assign(lhs_len, rhs_len);
343            }
344            (Array::Fixed(lhs_rc), Array::Fixed(rhs_rc)) => {
345                *lhs_rc.borrow_mut() = rhs_rc.borrow().clone();
346            }
347            _ => unreachable!(),
348        }
349    }
350
351    fn assert_eq(
352        lhs: impl Into<Self::Expression>,
353        rhs: impl Into<Self::Expression>,
354        builder: &mut Builder<C>,
355    ) {
356        let lhs = lhs.into();
357        let rhs = rhs.into();
358
359        match (lhs.clone(), rhs.clone()) {
360            (Array::Fixed(lhs), Array::Fixed(rhs)) => {
361                // No need to compare if they are the same reference. The same reference will
362                // also cause borrow errors in the following loop.
363                if Rc::ptr_eq(&lhs, &rhs) {
364                    return;
365                }
366                for (l, r) in lhs.borrow().iter().zip_eq(rhs.borrow().iter()) {
367                    assert!(l.is_some(), "lhs array is not fully initialized");
368                    assert!(r.is_some(), "rhs array is not fully initialized");
369                    T::assert_eq(
370                        T::Expression::from(l.as_ref().unwrap().clone()),
371                        T::Expression::from(r.as_ref().unwrap().clone()),
372                        builder,
373                    );
374                }
375            }
376            (Array::Dyn(_, lhs_len), Array::Dyn(_, rhs_len)) => {
377                builder.assert_eq::<Usize<_>>(lhs_len.clone(), rhs_len);
378
379                builder.range(0, lhs_len).for_each(|idx_vec, builder| {
380                    let a = builder.get(&lhs, idx_vec[0]);
381                    let b = builder.get(&rhs, idx_vec[0]);
382                    builder.assert_eq::<T>(a, b);
383                });
384            }
385            _ => panic!("cannot compare arrays of different types"),
386        }
387    }
388
389    // The default version calls `uninit`. If `expr` is `Fixed`, it will be converted into `Dyn`.
390    fn eval(_builder: &mut Builder<C>, expr: impl Into<Self::Expression>) -> Self {
391        expr.into()
392    }
393}
394
395impl<C: Config, T: MemVariable<C>> MemVariable<C> for Array<C, T> {
396    fn size_of() -> usize {
397        2
398    }
399
400    fn load(&self, src: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
401        match self {
402            Array::Dyn(dst, Usize::Var(len)) => {
403                let mut index = index;
404                dst.load(src, index, builder);
405                index.offset += <Ptr<C::N> as MemVariable<C>>::size_of();
406                len.load(src, index, builder);
407            }
408            _ => unreachable!(),
409        }
410    }
411
412    fn store(&self, dst: Ptr<<C as Config>::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
413        match self {
414            Array::Dyn(src, Usize::Var(len)) => {
415                let mut index = index;
416                src.store(dst, index, builder);
417                index.offset += <Ptr<C::N> as MemVariable<C>>::size_of();
418                len.store(dst, index, builder);
419            }
420            _ => unreachable!(),
421        }
422    }
423}
424
425impl<C: Config, V: FromConstant<C> + MemVariable<C>> FromConstant<C> for Array<C, V> {
426    type Constant = Vec<V::Constant>;
427
428    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
429        let array = builder.dyn_array(value.len());
430        for (i, val) in value.into_iter().enumerate() {
431            let val = V::constant(val, builder);
432            builder.set(&array, i, val);
433        }
434        array
435    }
436}
437
438/// Unsafe transmute from array of one type to another.
439///
440/// SAFETY: only use this if the memory layout of types `S` and `T` align.
441/// Only usable for `Array::Dyn`, will panic otherwise.
442pub fn unsafe_array_transmute<C: Config, S, T>(arr: Array<C, S>) -> Array<C, T> {
443    if let Array::Dyn(ptr, len) = arr {
444        Array::Dyn(ptr, len)
445    } else {
446        unreachable!()
447    }
448}
449
450#[allow(clippy::len_without_is_empty)]
451pub trait ArrayLike<C: Config> {
452    fn len(&self) -> Usize<C::N>;
453
454    fn ptr(&self) -> Ptr<C::N>;
455
456    fn is_fixed(&self) -> bool;
457
458    fn element_size_of(&self) -> usize;
459}
460
461impl<C: Config, T: MemVariable<C>> ArrayLike<C> for Array<C, T> {
462    fn len(&self) -> Usize<C::N> {
463        self.len()
464    }
465
466    fn ptr(&self) -> Ptr<C::N> {
467        self.ptr()
468    }
469
470    fn is_fixed(&self) -> bool {
471        match self {
472            Array::Fixed(_) => true,
473            Array::Dyn(_, _) => false,
474        }
475    }
476
477    fn element_size_of(&self) -> usize {
478        T::size_of()
479    }
480}