openvm_native_recursion/
digest.rs

1use openvm_native_compiler::{
2    ir::{Array, Builder, Config, Ext, Felt, FromConstant, MemIndex, Ptr, Usize, Var, Variable},
3    prelude::MemVariable,
4};
5
6use crate::{outer_poseidon2::Poseidon2CircuitBuilder, vars::OuterDigestVariable};
7
8#[derive(Clone)]
9pub enum DigestVal<C: Config> {
10    F(Vec<C::F>),
11    N(Vec<C::N>),
12}
13
14impl<C: Config> DigestVal<C> {
15    pub fn len(&self) -> usize {
16        match self {
17            DigestVal::F(v) => v.len(),
18            DigestVal::N(v) => v.len(),
19        }
20    }
21    pub fn is_empty(&self) -> bool {
22        match self {
23            DigestVal::F(v) => v.is_empty(),
24            DigestVal::N(v) => v.is_empty(),
25        }
26    }
27}
28
29#[derive(Clone)]
30pub enum DigestVariable<C: Config> {
31    Felt(Array<C, Felt<C::F>>),
32    Var(Array<C, Var<C::N>>),
33}
34
35impl<C: Config> Variable<C> for DigestVariable<C> {
36    type Expression = Self;
37
38    fn uninit(builder: &mut Builder<C>) -> Self {
39        Self::Felt(builder.uninit())
40    }
41
42    fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
43        match (self, src) {
44            (Self::Felt(lhs), Self::Felt(rhs)) => builder.assign(lhs, rhs),
45            (Self::Var(lhs), Self::Var(rhs)) => builder.assign(lhs, rhs),
46            _ => panic!("Assignment types mismatch"),
47        }
48    }
49
50    fn assert_eq(
51        lhs: impl Into<Self::Expression>,
52        rhs: impl Into<Self::Expression>,
53        builder: &mut Builder<C>,
54    ) {
55        match (lhs.into(), rhs.into()) {
56            (Self::Felt(lhs), Self::Felt(rhs)) => builder.assert_eq::<Array<C, _>>(lhs, rhs),
57            (Self::Var(lhs), Self::Var(rhs)) => builder.assert_eq::<Array<C, _>>(lhs, rhs),
58            _ => panic!("Assertion types mismatch"),
59        }
60    }
61}
62
63impl<C: Config> MemVariable<C> for DigestVariable<C> {
64    fn size_of() -> usize {
65        Array::<C, Felt<C::F>>::size_of()
66    }
67
68    fn load(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
69        match self {
70            DigestVariable::Felt(array) => array.load(ptr, index, builder),
71            DigestVariable::Var(array) => array.load(ptr, index, builder),
72        }
73    }
74
75    fn store(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
76        match self {
77            DigestVariable::Felt(array) => array.store(ptr, index, builder),
78            DigestVariable::Var(array) => array.store(ptr, index, builder),
79        }
80    }
81}
82
83impl<C: Config> FromConstant<C> for DigestVariable<C> {
84    type Constant = DigestVal<C>;
85
86    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
87        match value {
88            DigestVal::F(value) => {
89                let array = builder.array(value.len());
90                for (i, val) in value.into_iter().enumerate() {
91                    let val = Felt::constant(val, builder);
92                    builder.set(&array, i, val);
93                }
94                Self::Felt(array)
95            }
96            DigestVal::N(value) => {
97                let array = builder.array(value.len());
98                for (i, val) in value.into_iter().enumerate() {
99                    let val = Var::constant(val, builder);
100                    builder.set(&array, i, val);
101                }
102                Self::Var(array)
103            }
104        }
105    }
106}
107
108impl<C: Config> DigestVariable<C> {
109    pub fn len(&self) -> Usize<C::N> {
110        match self {
111            DigestVariable::Felt(array) => array.len(),
112            DigestVariable::Var(array) => array.len(),
113        }
114    }
115    /// Cast to OuterDigestVariable. This should only be used in static mode.
116    pub fn into_outer_digest(self) -> OuterDigestVariable<C> {
117        match self {
118            DigestVariable::Var(array) => array.vec().try_into().unwrap(),
119            DigestVariable::Felt(_) => panic!("Trying to get Var array from Felt array"),
120        }
121    }
122    /// Cast to an inner digest. This should only be used in dynamic mode.
123    pub fn into_inner_digest(self) -> Array<C, Felt<C::F>> {
124        match self {
125            DigestVariable::Felt(array) => array,
126            DigestVariable::Var(_) => panic!("Trying to get Felt array from Var array"),
127        }
128    }
129}
130
131impl<C: Config> From<Array<C, Felt<C::F>>> for DigestVariable<C> {
132    fn from(value: Array<C, Felt<C::F>>) -> Self {
133        Self::Felt(value)
134    }
135}
136
137impl<C: Config> From<Array<C, Var<C::N>>> for DigestVariable<C> {
138    fn from(value: Array<C, Var<C::N>>) -> Self {
139        Self::Var(value)
140    }
141}
142
143pub trait CanPoseidon2Digest<C: Config> {
144    fn p2_digest(&self, builder: &mut Builder<C>) -> DigestVariable<C>;
145}
146
147impl<C: Config> CanPoseidon2Digest<C> for Array<C, Array<C, Felt<C::F>>> {
148    fn p2_digest(&self, builder: &mut Builder<C>) -> DigestVariable<C> {
149        if builder.flags.static_only {
150            let digest_vec = builder.p2_hash(&flatten_fixed(self));
151            DigestVariable::Var(builder.vec(digest_vec.to_vec()))
152        } else {
153            DigestVariable::Felt(builder.poseidon2_hash_x(self))
154        }
155    }
156}
157
158impl<C: Config> CanPoseidon2Digest<C> for Array<C, Array<C, Ext<C::F, C::EF>>> {
159    fn p2_digest(&self, builder: &mut Builder<C>) -> DigestVariable<C> {
160        if builder.flags.static_only {
161            let flat_felts: Vec<_> = flatten_fixed(self)
162                .into_iter()
163                .flat_map(|ext| builder.ext2felt_circuit(ext).to_vec())
164                .collect();
165            let digest_vec = builder.p2_hash(&flat_felts);
166            DigestVariable::Var(builder.vec(digest_vec.to_vec()))
167        } else {
168            DigestVariable::Felt(builder.poseidon2_hash_ext(self))
169        }
170    }
171}
172
173fn flatten_fixed<C: Config, V: MemVariable<C>>(arr: &Array<C, Array<C, V>>) -> Vec<V> {
174    arr.vec()
175        .into_iter()
176        .flat_map(|felt_arr| felt_arr.vec())
177        .collect()
178}