// p3_bn254_fr/poseidon2.rs
1use std::sync::OnceLock;
6
7use p3_field::FieldAlgebra;
8use p3_poseidon2::{
9 add_rc_and_sbox_generic, external_initial_permute_state, external_terminal_permute_state,
10 internal_permute_state, matmul_internal, ExternalLayer, ExternalLayerConstants,
11 ExternalLayerConstructor, HLMDSMat4, InternalLayer, InternalLayerConstructor, Poseidon2,
12};
13
14use crate::Bn254Fr;
15
/// Degree of the S-box monomial for the BN254 Poseidon2 instance.
///
/// Passed as the S-box-degree const parameter of `Poseidon2`, of the internal and
/// external layer traits, and of `add_rc_and_sbox_generic`.
const BN254_S_BOX_DEGREE: u64 = 5;
20
21pub type Poseidon2Bn254<const WIDTH: usize> = Poseidon2<
25 Bn254Fr,
26 Poseidon2ExternalLayerBn254<WIDTH>,
27 Poseidon2InternalLayerBn254,
28 WIDTH,
29 BN254_S_BOX_DEGREE,
30>;
31
/// State width (number of field elements) for which this file provides an internal layer.
const BN254_WIDTH: usize = 3;
34
35#[inline]
36fn get_diffusion_matrix_3() -> &'static [Bn254Fr; 3] {
37 static MAT_DIAG3_M_1: OnceLock<[Bn254Fr; 3]> = OnceLock::new();
38 MAT_DIAG3_M_1.get_or_init(|| [Bn254Fr::ONE, Bn254Fr::ONE, Bn254Fr::TWO])
39}
40
41#[derive(Debug, Clone, Default)]
42pub struct Poseidon2InternalLayerBn254 {
43 internal_constants: Vec<Bn254Fr>,
44}
45
46impl InternalLayerConstructor<Bn254Fr> for Poseidon2InternalLayerBn254 {
47 fn new_from_constants(internal_constants: Vec<Bn254Fr>) -> Self {
48 Self { internal_constants }
49 }
50}
51
52impl InternalLayer<Bn254Fr, BN254_WIDTH, BN254_S_BOX_DEGREE> for Poseidon2InternalLayerBn254 {
53 fn permute_state(&self, state: &mut [Bn254Fr; BN254_WIDTH]) {
55 internal_permute_state::<Bn254Fr, BN254_WIDTH, BN254_S_BOX_DEGREE>(
56 state,
57 |x| matmul_internal(x, *get_diffusion_matrix_3()),
58 &self.internal_constants,
59 )
60 }
61}
62
63pub type Poseidon2ExternalLayerBn254<const WIDTH: usize> = ExternalLayerConstants<Bn254Fr, WIDTH>;
64
65impl<const WIDTH: usize> ExternalLayerConstructor<Bn254Fr, WIDTH>
66 for Poseidon2ExternalLayerBn254<WIDTH>
67{
68 fn new_from_constants(external_constants: ExternalLayerConstants<Bn254Fr, WIDTH>) -> Self {
69 external_constants
70 }
71}
72
73impl<const WIDTH: usize> ExternalLayer<Bn254Fr, WIDTH, BN254_S_BOX_DEGREE>
74 for Poseidon2ExternalLayerBn254<WIDTH>
75{
76 fn permute_state_initial(&self, state: &mut [Bn254Fr; WIDTH]) {
78 external_initial_permute_state(
79 state,
80 self.get_initial_constants(),
81 add_rc_and_sbox_generic::<_, BN254_S_BOX_DEGREE>,
82 &HLMDSMat4,
83 );
84 }
85
86 fn permute_state_terminal(&self, state: &mut [Bn254Fr; WIDTH]) {
88 external_terminal_permute_state(
89 state,
90 self.get_terminal_constants(),
91 add_rc_and_sbox_generic::<_, BN254_S_BOX_DEGREE>,
92 &HLMDSMat4,
93 );
94 }
95}
96
#[cfg(test)]
mod tests {
    use ff::PrimeField;
    use p3_poseidon2::ExternalLayerConstants;
    use p3_symmetric::Permutation;
    use rand::Rng;
    use zkhash::ark_ff::{BigInteger, PrimeField as ark_PrimeField};
    use zkhash::fields::bn256::FpBN256 as ark_FpBN256;
    use zkhash::poseidon2::poseidon2::Poseidon2 as Poseidon2Ref;
    use zkhash::poseidon2::poseidon2_instance_bn256::{POSEIDON2_BN256_PARAMS, RC3};

    use super::*;
    use crate::FFBn254Fr;

    /// Converts an arkworks BN254 scalar (the type used by the `zkhash` reference
    /// implementation) into a [`Bn254Fr`] via its little-endian byte encoding.
    ///
    /// # Panics
    ///
    /// Panics if the encoding is shorter than `FFBn254Fr`'s `Repr`, or if `from_repr`
    /// rejects the bytes (message: "Invalid field element").
    fn bn254_from_ark_ff(input: ark_FpBN256) -> Bn254Fr {
        let bytes = input.into_bigint().to_bytes_le();

        let mut repr = <FFBn254Fr as PrimeField>::Repr::default();
        let repr_len = repr.as_ref().len();
        // Fill the repr from the front of the encoding; any extra trailing bytes are ignored.
        repr.as_mut().copy_from_slice(&bytes[..repr_len]);

        // `from_repr` returns a constant-time `CtOption`; converting it to `Option`
        // replaces the `is_some()` check followed by `unwrap()`.
        let value = Option::<FFBn254Fr>::from(FFBn254Fr::from_repr(repr))
            .expect("Invalid field element");
        Bn254Fr { value }
    }

    /// Converts a slice of reference-implementation field elements into `[Bn254Fr; N]`.
    ///
    /// # Panics
    ///
    /// Panics if `input.len() != N`, or if any element fails [`bn254_from_ark_ff`].
    fn bn254_array_from_ark_ff<const N: usize>(input: &[ark_FpBN256]) -> [Bn254Fr; N] {
        let converted: Vec<Bn254Fr> = input.iter().cloned().map(bn254_from_ark_ff).collect();
        converted.try_into().unwrap_or_else(|v: Vec<Bn254Fr>| {
            panic!("expected {} field elements, got {}", N, v.len())
        })
    }

    /// Checks the BN254 Poseidon2 permutation against the `zkhash` reference
    /// implementation on a random width-3 input.
    #[test]
    fn test_poseidon2_bn254() {
        const WIDTH: usize = 3;
        const ROUNDS_F: usize = 8;
        const ROUNDS_P: usize = 56;

        type F = Bn254Fr;

        let mut rng = rand::thread_rng();

        let poseidon2_ref = Poseidon2Ref::new(&POSEIDON2_BN256_PARAMS);

        // The reference stores one full row of constants per round, in this order:
        // ROUNDS_F / 2 initial external rounds, ROUNDS_P internal rounds, then the
        // remaining (terminal) external rounds.
        let round_constants: Vec<[F; WIDTH]> = RC3
            .iter()
            .map(|row| bn254_array_from_ark_ff(row))
            .collect();

        let (initial, rest) = round_constants.split_at(ROUNDS_F / 2);
        let (internal, terminal) = rest.split_at(ROUNDS_P);

        // Only the first entry of each internal row is kept (presumably because internal
        // rounds add a constant to the first state element only).
        let internal_round_constants: Vec<F> = internal.iter().map(|row| row[0]).collect();
        let external_round_constants =
            ExternalLayerConstants::new(initial.to_vec(), terminal.to_vec());
        let poseidon2 = Poseidon2Bn254::new(external_round_constants, internal_round_constants);

        let input_ark_ff = rng.gen::<[ark_FpBN256; WIDTH]>();
        let input: [F; WIDTH] = bn254_array_from_ark_ff(&input_ark_ff);

        let output_ref = poseidon2_ref.permutation(&input_ark_ff);
        let expected: [F; WIDTH] = bn254_array_from_ark_ff(&output_ref);

        let mut output = input;
        poseidon2.permute_mut(&mut output);

        assert_eq!(output, expected);
    }
}