1use openvm_native_compiler::{
2 ir::{Array, Builder, Config, Ext, Felt, FromConstant, MemIndex, Ptr, Usize, Var, Variable},
3 prelude::MemVariable,
4};
5
6use crate::{outer_poseidon2::Poseidon2CircuitBuilder, vars::OuterDigestVariable};
7
8#[derive(Clone)]
9pub enum DigestVal<C: Config> {
10 F(Vec<C::F>),
11 N(Vec<C::N>),
12}
13
14impl<C: Config> DigestVal<C> {
15 pub fn len(&self) -> usize {
16 match self {
17 DigestVal::F(v) => v.len(),
18 DigestVal::N(v) => v.len(),
19 }
20 }
21 pub fn is_empty(&self) -> bool {
22 match self {
23 DigestVal::F(v) => v.is_empty(),
24 DigestVal::N(v) => v.is_empty(),
25 }
26 }
27}
28
29#[derive(Clone)]
30pub enum DigestVariable<C: Config> {
31 Felt(Array<C, Felt<C::F>>),
32 Var(Array<C, Var<C::N>>),
33}
34
35impl<C: Config> Variable<C> for DigestVariable<C> {
36 type Expression = Self;
37
38 fn uninit(builder: &mut Builder<C>) -> Self {
39 Self::Felt(builder.uninit())
40 }
41
42 fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
43 match (self, src) {
44 (Self::Felt(lhs), Self::Felt(rhs)) => builder.assign(lhs, rhs),
45 (Self::Var(lhs), Self::Var(rhs)) => builder.assign(lhs, rhs),
46 _ => panic!("Assignment types mismatch"),
47 }
48 }
49
50 fn assert_eq(
51 lhs: impl Into<Self::Expression>,
52 rhs: impl Into<Self::Expression>,
53 builder: &mut Builder<C>,
54 ) {
55 match (lhs.into(), rhs.into()) {
56 (Self::Felt(lhs), Self::Felt(rhs)) => builder.assert_eq::<Array<C, _>>(lhs, rhs),
57 (Self::Var(lhs), Self::Var(rhs)) => builder.assert_eq::<Array<C, _>>(lhs, rhs),
58 _ => panic!("Assertion types mismatch"),
59 }
60 }
61}
62
63impl<C: Config> MemVariable<C> for DigestVariable<C> {
64 fn size_of() -> usize {
65 Array::<C, Felt<C::F>>::size_of()
66 }
67
68 fn load(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
69 match self {
70 DigestVariable::Felt(array) => array.load(ptr, index, builder),
71 DigestVariable::Var(array) => array.load(ptr, index, builder),
72 }
73 }
74
75 fn store(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
76 match self {
77 DigestVariable::Felt(array) => array.store(ptr, index, builder),
78 DigestVariable::Var(array) => array.store(ptr, index, builder),
79 }
80 }
81}
82
83impl<C: Config> FromConstant<C> for DigestVariable<C> {
84 type Constant = DigestVal<C>;
85
86 fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
87 match value {
88 DigestVal::F(value) => {
89 let array = builder.array(value.len());
90 for (i, val) in value.into_iter().enumerate() {
91 let val = Felt::constant(val, builder);
92 builder.set(&array, i, val);
93 }
94 Self::Felt(array)
95 }
96 DigestVal::N(value) => {
97 let array = builder.array(value.len());
98 for (i, val) in value.into_iter().enumerate() {
99 let val = Var::constant(val, builder);
100 builder.set(&array, i, val);
101 }
102 Self::Var(array)
103 }
104 }
105 }
106}
107
108impl<C: Config> DigestVariable<C> {
109 pub fn len(&self) -> Usize<C::N> {
110 match self {
111 DigestVariable::Felt(array) => array.len(),
112 DigestVariable::Var(array) => array.len(),
113 }
114 }
115 pub fn into_outer_digest(self) -> OuterDigestVariable<C> {
117 match self {
118 DigestVariable::Var(array) => array.vec().try_into().unwrap(),
119 DigestVariable::Felt(_) => panic!("Trying to get Var array from Felt array"),
120 }
121 }
122 pub fn into_inner_digest(self) -> Array<C, Felt<C::F>> {
124 match self {
125 DigestVariable::Felt(array) => array,
126 DigestVariable::Var(_) => panic!("Trying to get Felt array from Var array"),
127 }
128 }
129}
130
131impl<C: Config> From<Array<C, Felt<C::F>>> for DigestVariable<C> {
132 fn from(value: Array<C, Felt<C::F>>) -> Self {
133 Self::Felt(value)
134 }
135}
136
137impl<C: Config> From<Array<C, Var<C::N>>> for DigestVariable<C> {
138 fn from(value: Array<C, Var<C::N>>) -> Self {
139 Self::Var(value)
140 }
141}
142
143pub trait CanPoseidon2Digest<C: Config> {
144 fn p2_digest(&self, builder: &mut Builder<C>) -> DigestVariable<C>;
145}
146
147impl<C: Config> CanPoseidon2Digest<C> for Array<C, Array<C, Felt<C::F>>> {
148 fn p2_digest(&self, builder: &mut Builder<C>) -> DigestVariable<C> {
149 if builder.flags.static_only {
150 let digest_vec = builder.p2_hash(&flatten_fixed(self));
151 DigestVariable::Var(builder.vec(digest_vec.to_vec()))
152 } else {
153 DigestVariable::Felt(builder.poseidon2_hash_x(self))
154 }
155 }
156}
157
158impl<C: Config> CanPoseidon2Digest<C> for Array<C, Array<C, Ext<C::F, C::EF>>> {
159 fn p2_digest(&self, builder: &mut Builder<C>) -> DigestVariable<C> {
160 if builder.flags.static_only {
161 let flat_felts: Vec<_> = flatten_fixed(self)
162 .into_iter()
163 .flat_map(|ext| builder.ext2felt_circuit(ext).to_vec())
164 .collect();
165 let digest_vec = builder.p2_hash(&flat_felts);
166 DigestVariable::Var(builder.vec(digest_vec.to_vec()))
167 } else {
168 DigestVariable::Felt(builder.poseidon2_hash_ext(self))
169 }
170 }
171}
172
173fn flatten_fixed<C: Config, V: MemVariable<C>>(arr: &Array<C, Array<C, V>>) -> Vec<V> {
174 arr.vec()
175 .into_iter()
176 .flat_map(|felt_arr| felt_arr.vec())
177 .collect()
178}