openvm_native_compiler/ir/poseidon.rs

1use openvm_native_compiler_derive::iter_zip;
2use openvm_stark_backend::p3_field::FieldAlgebra;
3
4use super::{Array, ArrayLike, Builder, Config, DslIr, Ext, Felt, MemIndex, Ptr, Usize, Var};
5
/// Number of field elements in the Poseidon2 permutation state.
pub const PERMUTATION_WIDTH: usize = 16;
/// Number of field elements the hashing helpers write into the state before each permutation.
pub const HASH_RATE: usize = 8;
/// Number of field elements kept from the state as the hash output.
pub const DIGEST_SIZE: usize = 8;
9
10impl<C: Config> Builder<C> {
11    /// Applies the Poseidon2 permutation to the given array.
12    ///
13    /// [Reference](https://docs.rs/p3-poseidon2/latest/p3_poseidon2/struct.Poseidon2.html)
14    pub fn poseidon2_permute(&mut self, array: &Array<C, Felt<C::F>>) -> Array<C, Felt<C::F>> {
15        let output = match array {
16            Array::Fixed(values) => {
17                assert_eq!(values.borrow().len(), PERMUTATION_WIDTH);
18                self.dyn_array::<Felt<C::F>>(Usize::from(PERMUTATION_WIDTH))
19            }
20            Array::Dyn(_, len) => self.dyn_array::<Felt<C::F>>(len.clone()),
21        };
22        self.operations.push(DslIr::Poseidon2PermuteBabyBear(
23            output.clone(),
24            array.clone(),
25        ));
26        output
27    }
28
29    /// Applies the Poseidon2 permutation to the given array.
30    ///
31    /// [Reference](https://docs.rs/p3-poseidon2/latest/p3_poseidon2/struct.Poseidon2.html)
32    pub fn poseidon2_permute_mut(&mut self, array: &Array<C, Felt<C::F>>) {
33        if let Array::Fixed(_) = array {
34            panic!("Poseidon2 permutation is not allowed on fixed arrays");
35        }
36        self.operations.push(DslIr::Poseidon2PermuteBabyBear(
37            array.clone(),
38            array.clone(),
39        ));
40    }
41
42    /// Applies the Poseidon2 compression function to the given array.
43    ///
44    /// [Reference](https://docs.rs/p3-symmetric/latest/p3_symmetric/struct.TruncatedPermutation.html)
45    pub fn poseidon2_compress(
46        &mut self,
47        left: &Array<C, Felt<C::F>>,
48        right: &Array<C, Felt<C::F>>,
49    ) -> Array<C, Felt<C::F>> {
50        let perm_width = PERMUTATION_WIDTH;
51        let input = self.dyn_array(perm_width);
52        for i in 0..DIGEST_SIZE {
53            let a = self.get(left, i);
54            let b = self.get(right, i);
55            self.set(&input, i, a);
56            self.set(&input, i + DIGEST_SIZE, b);
57        }
58        self.poseidon2_permute_mut(&input);
59        input
60    }
61
62    /// Applies the Poseidon2 compression to the given array.
63    ///
64    /// [Reference](https://docs.rs/p3-symmetric/latest/p3_symmetric/struct.TruncatedPermutation.html)
65    pub fn poseidon2_compress_x(
66        &mut self,
67        result: &Array<C, Felt<C::F>>,
68        left: &Array<C, Felt<C::F>>,
69        right: &Array<C, Felt<C::F>>,
70    ) {
71        self.operations.push(DslIr::Poseidon2CompressBabyBear(
72            result.clone(),
73            left.clone(),
74            right.clone(),
75        ));
76    }
77
78    pub fn poseidon2_hash_x(
79        &mut self,
80        array: &Array<C, Array<C, Felt<C::F>>>,
81    ) -> Array<C, Felt<C::F>> {
82        self.cycle_tracker_start("poseidon2-hash");
83        let perm_width = PERMUTATION_WIDTH;
84        let state: Array<C, Felt<C::F>> = self.dyn_array(perm_width);
85        self.range(0, perm_width).for_each(|idx_vec, builder| {
86            builder.set(&state, idx_vec[0], C::F::ZERO);
87        });
88
89        let address = self.eval(state.ptr().address);
90        let start: Var<_> = self.eval(address);
91        let end: Var<_> = self.eval(address + C::N::from_canonical_usize(HASH_RATE));
92        iter_zip!(self, array).for_each(|idx_vec, builder| {
93            let subarray = builder.iter_ptr_get(array, idx_vec[0]);
94            iter_zip!(builder, subarray).for_each(|ptr_vec, builder| {
95                let element = builder.iter_ptr_get(&subarray, ptr_vec[0]);
96                builder.cycle_tracker_start("poseidon2-hash-setup");
97                builder.store(
98                    Ptr { address },
99                    MemIndex {
100                        index: 0.into(),
101                        offset: 0,
102                        size: 1,
103                    },
104                    element,
105                );
106                builder.assign(&address, address + C::N::ONE);
107                builder.cycle_tracker_end("poseidon2-hash-setup");
108                builder.if_eq(address, end).then(|builder| {
109                    builder.poseidon2_permute_mut(&state);
110                    builder.assign(&address, start);
111                });
112            });
113        });
114
115        self.if_ne(address, start).then(|builder| {
116            builder.poseidon2_permute_mut(&state);
117        });
118
119        state.truncate(self, Usize::from(DIGEST_SIZE));
120        self.cycle_tracker_end("poseidon2-hash");
121        state
122    }
123
124    pub fn poseidon2_hash_ext(
125        &mut self,
126        array: &Array<C, Array<C, Ext<C::F, C::EF>>>,
127    ) -> Array<C, Felt<C::F>> {
128        self.cycle_tracker_start("poseidon2-hash-ext");
129        let hash_rate = HASH_RATE;
130        let perm_width = PERMUTATION_WIDTH;
131        let state: Array<C, Felt<C::F>> = self.dyn_array(perm_width);
132        self.range(hash_rate, perm_width)
133            .for_each(|i_vec, builder| {
134                builder.set(&state, i_vec[0], C::F::ZERO);
135            });
136
137        let idx: Var<_> = self.eval(C::N::ZERO);
138        self.range(0, array.len()).for_each(|i_vec, builder| {
139            let subarray = builder.get(array, i_vec[0]);
140            builder.range(0, subarray.len()).for_each(|j_vec, builder| {
141                let element = builder.get(&subarray, j_vec[0]);
142                let felts = builder.ext2felt(element);
143                for i in 0..4 {
144                    let felt = builder.get(&felts, i);
145                    builder.set_value(&state, idx, felt);
146                    builder.assign(&idx, idx + C::N::ONE);
147                    builder
148                        .if_eq(idx, C::N::from_canonical_usize(HASH_RATE))
149                        .then(|builder| {
150                            builder.poseidon2_permute_mut(&state);
151                            builder.assign(&idx, C::N::ZERO);
152                        });
153                }
154            });
155        });
156
157        self.if_ne(idx, C::N::ZERO).then(|builder| {
158            builder.poseidon2_permute_mut(&state);
159        });
160
161        state.truncate(self, Usize::from(DIGEST_SIZE));
162        self.cycle_tracker_end("poseidon2-hash-ext");
163        state
164    }
165}