openvm_native_compiler/ir/
collections.rs

1use alloc::rc::Rc;
2use std::cell::RefCell;
3
4use itertools::Itertools;
5use serde::{Deserialize, Serialize};
6
7use super::{
8    Builder, Config, FromConstant, MemIndex, MemVariable, Ptr, RVar, Usize, Var, Variable,
9};
10
11/// A logical array.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub enum Array<C: Config, T> {
14    /// Array of some local variables or constants, which can only be manipulated statically. It
15    /// only exists in the DSL syntax and isn't backed by memory.
16    Fixed(Rc<RefCell<Vec<Option<T>>>>),
17    /// Array on heap. Index access can use variables. Length could be determined on runtime but
18    /// cannot change after initialization.
19    Dyn(Ptr<C::N>, Usize<C::N>),
20}
21
22impl<C: Config, V: MemVariable<C>> Array<C, V> {
23    /// Gets a right value of the array.
24    pub fn vec(&self) -> Vec<V> {
25        match self {
26            Self::Fixed(vec) => vec.borrow().iter().map(|x| x.clone().unwrap()).collect(),
27            _ => panic!("array is dynamic, not fixed"),
28        }
29    }
30
31    pub fn ptr(&self) -> Ptr<C::N> {
32        match *self {
33            Array::Dyn(ptr, _) => ptr,
34            Array::Fixed(_) => panic!("cannot retrieve pointer for a compile-time array"),
35        }
36    }
37
38    /// Gets the length of the array as a variable inside the DSL.
39    pub fn len(&self) -> Usize<C::N> {
40        match self {
41            Self::Fixed(vec) => Usize::from(vec.borrow().len()),
42            Self::Dyn(_, len) => len.clone(),
43        }
44    }
45
46    /// Shifts the array by `shift` elements.
47    /// !Attention!: the behavior of `Fixed` and `Dyn` is different. For Dyn, the shift is a view
48    /// and shares memory with the original. For `Fixed`, `set`/`set_value` on slices won't impact
49    /// the original array.
50    pub fn shift(&self, builder: &mut Builder<C>, shift: impl Into<RVar<C::N>>) -> Array<C, V> {
51        match self {
52            Self::Fixed(v) => {
53                let shift = shift.into();
54                if let RVar::Const(_) = shift {
55                    let shift = shift.value();
56                    Array::Fixed(Rc::new(RefCell::new(v.borrow()[shift..].to_vec())))
57                } else {
58                    panic!("Cannot shift a fixed array with a variable shift");
59                }
60            }
61            Self::Dyn(ptr, len) => {
62                let len = RVar::from(len.clone());
63                let shift = shift.into();
64                let new_ptr = builder.eval(*ptr + shift * RVar::from(V::size_of()));
65                let new_length = builder.eval(len - shift);
66                Array::Dyn(new_ptr, Usize::Var(new_length))
67            }
68        }
69    }
70
71    /// Truncates the array to `len` elements.
72    pub fn truncate(&self, builder: &mut Builder<C>, len: Usize<C::N>) {
73        match self {
74            Self::Fixed(v) => {
75                let len = len.value();
76                v.borrow_mut().truncate(len);
77            }
78            Self::Dyn(_, old_len) => {
79                builder.assign(old_len, len);
80            }
81        };
82    }
83
84    /// Slices the array from `start` to `end`.
85    /// !Attention!: the behavior of `Fixed` and `Dyn` is different. For Dyn, the shift is a view
86    /// and shares memory with the original. For `Fixed`, `set`/`set_value` on slices won't impact
87    /// the original array.
88    pub fn slice(
89        &self,
90        builder: &mut Builder<C>,
91        start: impl Into<RVar<C::N>>,
92        end: impl Into<RVar<C::N>>,
93    ) -> Array<C, V> {
94        let start = start.into();
95        let end = end.into();
96        match self {
97            Self::Fixed(v) => {
98                if let (RVar::Const(_), RVar::Const(_)) = (&start, &end) {
99                    Array::Fixed(Rc::new(RefCell::new(
100                        v.borrow()[start.value()..end.value()].to_vec(),
101                    )))
102                } else {
103                    panic!("Cannot slice a fixed array with a variable start or end");
104                }
105            }
106            Self::Dyn(ptr, _) => {
107                let slice_len = builder.eval(end - start);
108                let address = builder.eval(ptr.address + start * RVar::from(V::size_of()));
109                let ptr = Ptr { address };
110                Array::Dyn(ptr, Usize::Var(slice_len))
111            }
112        }
113    }
114}
115
116impl<C: Config> Builder<C> {
117    /// Initialize an array of fixed length `len`. The entries will be uninitialized.
118    pub fn array<V: MemVariable<C>>(&mut self, len: impl Into<RVar<C::N>>) -> Array<C, V> {
119        let len = len.into();
120        if self.flags.static_only {
121            self.uninit_fixed_array(len.value())
122        } else {
123            self.dyn_array(len)
124        }
125    }
126
127    /// Creates an array from a vector.
128    pub fn vec<V: MemVariable<C>>(&mut self, v: Vec<V>) -> Array<C, V> {
129        Array::Fixed(Rc::new(RefCell::new(
130            v.into_iter().map(|x| Some(x)).collect(),
131        )))
132    }
133
134    /// Create an uninitialized Array::Fixed.
135    pub fn uninit_fixed_array<V: Variable<C>>(&mut self, len: usize) -> Array<C, V> {
136        Array::Fixed(Rc::new(RefCell::new(vec![None::<V>; len])))
137    }
138
139    /// Creates a dynamic array for a length.
140    pub fn dyn_array<V: MemVariable<C>>(&mut self, len: impl Into<RVar<C::N>>) -> Array<C, V> {
141        let len: Var<_> = self.eval(len.into());
142        let ptr = self.alloc(len, V::size_of());
143        Array::Dyn(ptr, Usize::Var(len))
144    }
145
146    pub fn get<V: MemVariable<C>, I: Into<RVar<C::N>>>(
147        &mut self,
148        slice: &Array<C, V>,
149        index: I,
150    ) -> V {
151        let index = index.into();
152
153        match slice {
154            Array::Fixed(slice) => {
155                if let RVar::Const(_) = index {
156                    let idx = index.value();
157                    if let Some(ele) = &slice.borrow()[idx] {
158                        ele.clone()
159                    } else {
160                        panic!("Cannot get an uninitialized element in a fixed slice");
161                    }
162                } else {
163                    panic!("Cannot index into a fixed slice with a variable size")
164                }
165            }
166            Array::Dyn(ptr, _) => {
167                let index = MemIndex {
168                    index,
169                    offset: 0,
170                    size: V::size_of(),
171                };
172                let var: V = self.uninit();
173                self.load(var.clone(), *ptr, index);
174                var
175            }
176        }
177    }
178
179    /// Intended to be used with `ptr` from `zip`. Assumes that:
180    /// - if `slice` is `Array::Fixed`, then `ptr` is a constant index in [0, slice.len()).
181    /// - if `slice` is `Array::Dyn`, then `ptr` is a variable iterator over the entries of `slice`.
182    ///
183    /// In both cases, loads and returns the corresponding element of `slice`.
184    pub fn iter_ptr_get<V: MemVariable<C>>(&mut self, slice: &Array<C, V>, ptr: RVar<C::N>) -> V {
185        match slice {
186            Array::Fixed(v) => {
187                if let RVar::Const(_) = ptr {
188                    let idx = ptr.value();
189                    v.borrow()[idx].clone().unwrap()
190                } else {
191                    panic!("Cannot index into a fixed slice with a variable index")
192                }
193            }
194            Array::Dyn(_, _) => {
195                let index = MemIndex {
196                    index: 0.into(),
197                    offset: 0,
198                    size: V::size_of(),
199                };
200                let var: V = self.uninit();
201                self.load(
202                    var.clone(),
203                    Ptr {
204                        address: match ptr {
205                            RVar::Const(_) => panic!(
206                                "iter_ptr_get on dynamic array not supported for constant ptr"
207                            ),
208                            RVar::Val(v) => v,
209                        },
210                    },
211                    index,
212                );
213                var
214            }
215        }
216    }
217
218    /// Intended to be used with `ptr` from `zip`. Assumes that:
219    /// - if `slice` is `Array::Fixed`, then `ptr` is a constant index in [0, slice.len()).
220    /// - if `slice` is `Array::Dyn`, then `ptr` is a variable iterator over the entries of `slice`.
221    ///
222    /// In both cases, stores the given `value` at the corresponding element of `slice`.
223    pub fn iter_ptr_set<V: MemVariable<C>, Expr: Into<V::Expression>>(
224        &mut self,
225        slice: &Array<C, V>,
226        ptr: RVar<C::N>,
227        value: Expr,
228    ) {
229        match slice {
230            Array::Fixed(v) => {
231                if let RVar::Const(_) = ptr {
232                    let idx = ptr.value();
233                    let value = self.eval(value);
234                    v.borrow_mut()[idx] = Some(value);
235                } else {
236                    panic!("Cannot index into a fixed slice with a variable index")
237                }
238            }
239            Array::Dyn(_, _) => {
240                let value: V = self.eval(value);
241                self.store(
242                    Ptr {
243                        address: match ptr {
244                            RVar::Const(_) => panic!(
245                                "iter_ptr_set on dynamic array not supported for constant ptr"
246                            ),
247                            RVar::Val(v) => v,
248                        },
249                    },
250                    MemIndex {
251                        index: 0.into(),
252                        offset: 0,
253                        size: V::size_of(),
254                    },
255                    value,
256                );
257            }
258        }
259    }
260
261    pub fn set<V: MemVariable<C>, I: Into<RVar<C::N>>, Expr: Into<V::Expression>>(
262        &mut self,
263        slice: &Array<C, V>,
264        index: I,
265        value: Expr,
266    ) {
267        let index = index.into();
268
269        match slice {
270            Array::Fixed(v) => {
271                if let RVar::Const(_) = index {
272                    let idx = index.value();
273                    let value = self.eval(value);
274                    v.borrow_mut()[idx] = Some(value);
275                } else {
276                    panic!("Cannot index into a fixed slice with a variable index")
277                }
278            }
279            Array::Dyn(ptr, _) => {
280                let index = MemIndex {
281                    index,
282                    offset: 0,
283                    size: V::size_of(),
284                };
285                let value: V = self.eval(value);
286                self.store(*ptr, index, value);
287            }
288        }
289    }
290
291    pub fn set_value<V: MemVariable<C>, I: Into<RVar<C::N>>>(
292        &mut self,
293        slice: &Array<C, V>,
294        index: I,
295        value: V,
296    ) {
297        let index = index.into();
298
299        match slice {
300            Array::Fixed(v) => {
301                if let RVar::Const(_) = index {
302                    let idx = index.value();
303                    v.borrow_mut()[idx] = Some(value);
304                } else {
305                    panic!("Cannot index into a fixed slice with a variable size")
306                }
307            }
308            Array::Dyn(ptr, _) => {
309                let index = MemIndex {
310                    index,
311                    offset: 0,
312                    size: V::size_of(),
313                };
314                self.store(*ptr, index, value);
315            }
316        }
317    }
318}
319
320impl<C: Config, T: MemVariable<C>> Variable<C> for Array<C, T> {
321    type Expression = Self;
322
323    fn uninit(builder: &mut Builder<C>) -> Self {
324        Array::Dyn(builder.uninit(), builder.uninit())
325    }
326
327    fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
328        match (self, src.clone()) {
329            (Array::Dyn(lhs_ptr, lhs_len), Array::Dyn(rhs_ptr, rhs_len)) => {
330                builder.assign(lhs_ptr, rhs_ptr);
331                builder.assign(lhs_len, rhs_len);
332            }
333            (Array::Fixed(lhs_rc), Array::Fixed(rhs_rc)) => {
334                *lhs_rc.borrow_mut() = rhs_rc.borrow().clone();
335            }
336            _ => unreachable!(),
337        }
338    }
339
340    fn assert_eq(
341        lhs: impl Into<Self::Expression>,
342        rhs: impl Into<Self::Expression>,
343        builder: &mut Builder<C>,
344    ) {
345        let lhs = lhs.into();
346        let rhs = rhs.into();
347
348        match (lhs.clone(), rhs.clone()) {
349            (Array::Fixed(lhs), Array::Fixed(rhs)) => {
350                // No need to compare if they are the same reference. The same reference will
351                // also cause borrow errors in the following loop.
352                if Rc::ptr_eq(&lhs, &rhs) {
353                    return;
354                }
355                for (l, r) in lhs.borrow().iter().zip_eq(rhs.borrow().iter()) {
356                    assert!(l.is_some(), "lhs array is not fully initialized");
357                    assert!(r.is_some(), "rhs array is not fully initialized");
358                    T::assert_eq(
359                        T::Expression::from(l.as_ref().unwrap().clone()),
360                        T::Expression::from(r.as_ref().unwrap().clone()),
361                        builder,
362                    );
363                }
364            }
365            (Array::Dyn(_, lhs_len), Array::Dyn(_, rhs_len)) => {
366                builder.assert_eq::<Usize<_>>(lhs_len.clone(), rhs_len);
367
368                builder.range(0, lhs_len).for_each(|idx_vec, builder| {
369                    let a = builder.get(&lhs, idx_vec[0]);
370                    let b = builder.get(&rhs, idx_vec[0]);
371                    builder.assert_eq::<T>(a, b);
372                });
373            }
374            _ => panic!("cannot compare arrays of different types"),
375        }
376    }
377
378    // The default version calls `uninit`. If `expr` is `Fixed`, it will be converted into `Dyn`.
379    fn eval(_builder: &mut Builder<C>, expr: impl Into<Self::Expression>) -> Self {
380        expr.into()
381    }
382}
383
384impl<C: Config, T: MemVariable<C>> MemVariable<C> for Array<C, T> {
385    fn size_of() -> usize {
386        2
387    }
388
389    fn load(&self, src: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
390        match self {
391            Array::Dyn(dst, Usize::Var(len)) => {
392                let mut index = index;
393                dst.load(src, index, builder);
394                index.offset += <Ptr<C::N> as MemVariable<C>>::size_of();
395                len.load(src, index, builder);
396            }
397            _ => unreachable!(),
398        }
399    }
400
401    fn store(&self, dst: Ptr<<C as Config>::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
402        match self {
403            Array::Dyn(src, Usize::Var(len)) => {
404                let mut index = index;
405                src.store(dst, index, builder);
406                index.offset += <Ptr<C::N> as MemVariable<C>>::size_of();
407                len.store(dst, index, builder);
408            }
409            _ => unreachable!(),
410        }
411    }
412}
413
414impl<C: Config, V: FromConstant<C> + MemVariable<C>> FromConstant<C> for Array<C, V> {
415    type Constant = Vec<V::Constant>;
416
417    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
418        let array = builder.dyn_array(value.len());
419        for (i, val) in value.into_iter().enumerate() {
420            let val = V::constant(val, builder);
421            builder.set(&array, i, val);
422        }
423        array
424    }
425}
426
427/// Unsafe transmute from array of one type to another.
428///
429/// SAFETY: only use this if the memory layout of types `S` and `T` align.
430/// Only usable for `Array::Dyn`, will panic otherwise.
431pub fn unsafe_array_transmute<C: Config, S, T>(arr: Array<C, S>) -> Array<C, T> {
432    if let Array::Dyn(ptr, len) = arr {
433        Array::Dyn(ptr, len)
434    } else {
435        unreachable!()
436    }
437}
438
439#[allow(clippy::len_without_is_empty)]
440pub trait ArrayLike<C: Config> {
441    fn len(&self) -> Usize<C::N>;
442
443    fn ptr(&self) -> Ptr<C::N>;
444
445    fn is_fixed(&self) -> bool;
446
447    fn element_size_of(&self) -> usize;
448}
449
450impl<C: Config, T: MemVariable<C>> ArrayLike<C> for Array<C, T> {
451    fn len(&self) -> Usize<C::N> {
452        self.len()
453    }
454
455    fn ptr(&self) -> Ptr<C::N> {
456        self.ptr()
457    }
458
459    fn is_fixed(&self) -> bool {
460        match self {
461            Array::Fixed(_) => true,
462            Array::Dyn(_, _) => false,
463        }
464    }
465
466    fn element_size_of(&self) -> usize {
467        T::size_of()
468    }
469}