openvm_native_recursion/
outer_poseidon2.rs

1use itertools::Itertools;
2use openvm_native_compiler::ir::{Builder, Config, DslIr, Felt, Var};
3use openvm_stark_backend::p3_field::{Field, FieldAlgebra};
4
5use crate::{utils::reduce_32, vars::OuterDigestVariable, OUTER_DIGEST_SIZE};
6
/// Width of the Poseidon2 permutation state, counted in native-field (`C::N`) elements.
pub const SPONGE_SIZE: usize = 3;
/// Number of state lanes overwritten with input before each permutation.
pub const RATE: usize = 2;
/// Cycle-tracker label wrapped around every Poseidon2 gadget emitted by this module.
const POSEIDON_CELL_TRACKER_NAME: &str = "PoseidonCell";
10
11pub trait Poseidon2CircuitBuilder<C: Config> {
12    fn p2_permute_mut(&mut self, state: [Var<C::N>; SPONGE_SIZE]);
13    #[allow(dead_code)]
14    fn p2_hash(&mut self, input: &[Felt<C::F>]) -> OuterDigestVariable<C>;
15    #[allow(dead_code)]
16    fn p2_compress(&mut self, input: [OuterDigestVariable<C>; RATE]) -> OuterDigestVariable<C>;
17}
18
19impl<C: Config> Poseidon2CircuitBuilder<C> for Builder<C> {
20    fn p2_permute_mut(&mut self, state: [Var<C::N>; SPONGE_SIZE]) {
21        self.cycle_tracker_start(POSEIDON_CELL_TRACKER_NAME);
22        p2_permute_mut_impl(self, state);
23        self.cycle_tracker_end(POSEIDON_CELL_TRACKER_NAME);
24    }
25
26    fn p2_hash(&mut self, input: &[Felt<C::F>]) -> OuterDigestVariable<C> {
27        self.cycle_tracker_start(POSEIDON_CELL_TRACKER_NAME);
28        assert_eq!(C::N::bits(), openvm_stark_sdk::p3_bn254_fr::Bn254Fr::bits());
29        assert_eq!(
30            C::F::bits(),
31            openvm_stark_sdk::p3_baby_bear::BabyBear::bits()
32        );
33        let num_f_elms = C::N::bits() / C::F::bits();
34        let mut state: [Var<C::N>; SPONGE_SIZE] = [
35            self.eval(C::N::ZERO),
36            self.eval(C::N::ZERO),
37            self.eval(C::N::ZERO),
38        ];
39        // <Poseidon2 RATE> * <Felt per Var>
40        let felt_per_chunk = RATE * num_f_elms;
41        for block_chunk in &input.iter().chunks(felt_per_chunk) {
42            for (chunk_id, chunk) in (&block_chunk.chunks(num_f_elms)).into_iter().enumerate() {
43                let chunk = chunk.collect_vec().into_iter().copied().collect::<Vec<_>>();
44                state[chunk_id] = reduce_32(self, chunk.as_slice());
45            }
46            p2_permute_mut_impl(self, state);
47        }
48        self.cycle_tracker_end(POSEIDON_CELL_TRACKER_NAME);
49
50        [state[0]]
51    }
52
53    fn p2_compress(&mut self, input: [OuterDigestVariable<C>; 2]) -> OuterDigestVariable<C> {
54        self.cycle_tracker_start(POSEIDON_CELL_TRACKER_NAME);
55        let state: [Var<C::N>; SPONGE_SIZE] = [
56            self.eval(input[0][0]),
57            self.eval(input[1][0]),
58            self.eval(C::N::ZERO),
59        ];
60        p2_permute_mut_impl(self, state);
61        self.cycle_tracker_end(POSEIDON_CELL_TRACKER_NAME);
62        [state[0]; OUTER_DIGEST_SIZE]
63    }
64}
65
66fn p2_permute_mut_impl<C: Config>(builder: &mut Builder<C>, state: [Var<C::N>; SPONGE_SIZE]) {
67    builder.push(DslIr::CircuitPoseidon2Permute(state))
68}