// p3_poseidon2/external.rs

1use alloc::vec::Vec;
2
3use p3_field::FieldAlgebra;
4use p3_mds::MdsPermutation;
5use p3_symmetric::Permutation;
6use rand::distributions::{Distribution, Standard};
7use rand::Rng;
8
9/// Multiply a 4-element vector x by
10/// [ 5 7 1 3 ]
11/// [ 4 6 1 1 ]
12/// [ 1 3 5 7 ]
13/// [ 1 1 4 6 ].
14/// This uses the formula from the start of Appendix B in the Poseidon2 paper, with multiplications unrolled into additions.
15/// It is also the matrix used by the Horizon Labs implementation.
16#[inline(always)]
17fn apply_hl_mat4<FA>(x: &mut [FA; 4])
18where
19    FA: FieldAlgebra,
20{
21    let t0 = x[0].clone() + x[1].clone();
22    let t1 = x[2].clone() + x[3].clone();
23    let t2 = x[1].clone() + x[1].clone() + t1.clone();
24    let t3 = x[3].clone() + x[3].clone() + t0.clone();
25    let t4 = t1.double().double() + t3.clone();
26    let t5 = t0.double().double() + t2.clone();
27    let t6 = t3 + t5.clone();
28    let t7 = t2 + t4.clone();
29    x[0] = t6;
30    x[1] = t5;
31    x[2] = t7;
32    x[3] = t4;
33}

// It turns out we can find a 4x4 MDS matrix which is cheaper to apply than the one above.

37/// Multiply a 4-element vector x by:
38/// [ 2 3 1 1 ]
39/// [ 1 2 3 1 ]
40/// [ 1 1 2 3 ]
41/// [ 3 1 1 2 ].
42#[inline(always)]
43fn apply_mat4<FA>(x: &mut [FA; 4])
44where
45    FA: FieldAlgebra,
46{
47    let t01 = x[0].clone() + x[1].clone();
48    let t23 = x[2].clone() + x[3].clone();
49    let t0123 = t01.clone() + t23.clone();
50    let t01123 = t0123.clone() + x[1].clone();
51    let t01233 = t0123.clone() + x[3].clone();
52    // The order here is important. Need to overwrite x[0] and x[2] after x[1] and x[3].
53    x[3] = t01233.clone() + x[0].double(); // 3*x[0] + x[1] + x[2] + 2*x[3]
54    x[1] = t01123.clone() + x[2].double(); // x[0] + 2*x[1] + 3*x[2] + x[3]
55    x[0] = t01123 + t01; // 2*x[0] + 3*x[1] + x[2] + x[3]
56    x[2] = t01233 + t23; // x[0] + x[1] + 2*x[2] + 3*x[3]
57}
58
59/// The 4x4 MDS matrix used by the Horizon Labs implementation of Poseidon2.
60///
61/// This requires 10 additions and 4 doubles to compute.
62#[derive(Clone, Default)]
63pub struct HLMDSMat4;
64
65impl<FA: FieldAlgebra> Permutation<[FA; 4]> for HLMDSMat4 {
66    #[inline(always)]
67    fn permute_mut(&self, input: &mut [FA; 4]) {
68        apply_hl_mat4(input)
69    }
70}
71impl<FA: FieldAlgebra> MdsPermutation<FA, 4> for HLMDSMat4 {}
72
73/// The fastest 4x4 MDS matrix.
74///
75/// This requires 7 additions and 2 doubles to compute.
76#[derive(Clone, Default)]
77pub struct MDSMat4;
78
79impl<FA: FieldAlgebra> Permutation<[FA; 4]> for MDSMat4 {
80    #[inline(always)]
81    fn permute_mut(&self, input: &mut [FA; 4]) {
82        apply_mat4(input)
83    }
84}
85impl<FA: FieldAlgebra> MdsPermutation<FA, 4> for MDSMat4 {}
86
87/// Implement the matrix multiplication used by the external layer.
88///
89/// Given a 4x4 MDS matrix M, we multiply by the `4N x 4N` matrix
90/// `[[2M M  ... M], [M  2M ... M], ..., [M  M ... 2M]]`.
91#[inline(always)]
92pub fn mds_light_permutation<
93    FA: FieldAlgebra,
94    MdsPerm4: MdsPermutation<FA, 4>,
95    const WIDTH: usize,
96>(
97    state: &mut [FA; WIDTH],
98    mdsmat: &MdsPerm4,
99) {
100    match WIDTH {
101        2 => {
102            let sum = state[0].clone() + state[1].clone();
103            state[0] += sum.clone();
104            state[1] += sum;
105        }
106
107        3 => {
108            let sum = state[0].clone() + state[1].clone() + state[2].clone();
109            state[0] += sum.clone();
110            state[1] += sum.clone();
111            state[2] += sum;
112        }
113
114        4 | 8 | 12 | 16 | 20 | 24 => {
115            // First, we apply M_4 to each consecutive four elements of the state.
116            // In Appendix B's terminology, this replaces each x_i with x_i'.
117            for chunk in state.chunks_exact_mut(4) {
118                mdsmat.permute_mut(chunk.try_into().unwrap());
119            }
120            // Now, we apply the outer circulant matrix (to compute the y_i values).
121
122            // We first precompute the four sums of every four elements.
123            let sums: [FA; 4] = core::array::from_fn(|k| {
124                (0..WIDTH)
125                    .step_by(4)
126                    .map(|j| state[j + k].clone())
127                    .sum::<FA>()
128            });
129
130            // The formula for each y_i involves 2x_i' term and x_j' terms for each j that equals i mod 4.
131            // In other words, we can add a single copy of x_i' to the appropriate one of our precomputed sums
132            state
133                .iter_mut()
134                .enumerate()
135                .for_each(|(i, elem)| *elem += sums[i % 4].clone());
136        }
137
138        _ => {
139            panic!("Unsupported width");
140        }
141    }
142}
143
/// Round constants for the external (full) rounds of Poseidon2.
///
/// The constants are split into the rounds of the initial external layer and those of the
/// terminal external layer. Each round holds one constant per state element. The fields are
/// private; the constructors on this type require both halves to contain the same number of
/// rounds.
#[derive(Debug, Clone)]
pub struct ExternalLayerConstants<T, const WIDTH: usize> {
    /// One entry per initial external round. Should not change once initialised.
    initial: Vec<[T; WIDTH]>,
    /// One entry per terminal external round. Named `terminal` because `final` is a reserved
    /// keyword. Should not change once initialised.
    terminal: Vec<[T; WIDTH]>,
}

152impl<T, const WIDTH: usize> ExternalLayerConstants<T, WIDTH> {
153    pub fn new(initial: Vec<[T; WIDTH]>, terminal: Vec<[T; WIDTH]>) -> Self {
154        assert_eq!(
155            initial.len(),
156            terminal.len(),
157            "The number of initial and terminal external rounds should be equal."
158        );
159        Self { initial, terminal }
160    }
161
162    pub fn new_from_rng<R: Rng>(external_round_number: usize, rng: &mut R) -> Self
163    where
164        Standard: Distribution<[T; WIDTH]>,
165    {
166        let half_f = external_round_number / 2;
167        assert_eq!(
168            2 * half_f,
169            external_round_number,
170            "The total number of external rounds should be even"
171        );
172        let initial_constants = rng.sample_iter(Standard).take(half_f).collect();
173        let terminal_constants = rng.sample_iter(Standard).take(half_f).collect();
174
175        Self::new(initial_constants, terminal_constants)
176    }
177
178    pub fn new_from_saved_array<U, const N: usize>(
179        [initial, terminal]: [[[U; WIDTH]; N]; 2],
180        conversion_fn: fn([U; WIDTH]) -> [T; WIDTH],
181    ) -> Self
182    where
183        T: Clone,
184    {
185        let initial_consts = initial.map(conversion_fn).to_vec();
186        let terminal_consts = terminal.map(conversion_fn).to_vec();
187        Self::new(initial_consts, terminal_consts)
188    }
189
190    pub fn get_initial_constants(&self) -> &Vec<[T; WIDTH]> {
191        &self.initial
192    }
193
194    pub fn get_terminal_constants(&self) -> &Vec<[T; WIDTH]> {
195        &self.terminal
196    }
197}
198
199/// Initialize an external layer from a set of constants.
200pub trait ExternalLayerConstructor<FA, const WIDTH: usize>
201where
202    FA: FieldAlgebra,
203{
204    /// A constructor which internally will convert the supplied
205    /// constants into the appropriate form for the implementation.
206    fn new_from_constants(external_constants: ExternalLayerConstants<FA::F, WIDTH>) -> Self;
207}
208
209/// A trait containing all data needed to implement the external layers of Poseidon2.
210pub trait ExternalLayer<FA, const WIDTH: usize, const D: u64>: Sync + Clone
211where
212    FA: FieldAlgebra,
213{
214    // permute_state_initial, permute_state_terminal are split as the Poseidon2 specifications are slightly different
215    // with the initial rounds involving an extra matrix multiplication.
216
217    /// Perform the initial external layers of the Poseidon2 permutation on the given state.
218    fn permute_state_initial(&self, state: &mut [FA; WIDTH]);
219
220    /// Perform the terminal external layers of the Poseidon2 permutation on the given state.
221    fn permute_state_terminal(&self, state: &mut [FA; WIDTH]);
222}
223
224/// A helper method which allow any field to easily implement the terminal External Layer.
225#[inline]
226pub fn external_terminal_permute_state<
227    FA: FieldAlgebra,
228    CT: Copy, // Whatever type the constants are stored as.
229    MdsPerm4: MdsPermutation<FA, 4>,
230    const WIDTH: usize,
231>(
232    state: &mut [FA; WIDTH],
233    terminal_external_constants: &[[CT; WIDTH]],
234    add_rc_and_sbox: fn(&mut FA, CT),
235    mat4: &MdsPerm4,
236) {
237    for elem in terminal_external_constants.iter() {
238        state
239            .iter_mut()
240            .zip(elem.iter())
241            .for_each(|(s, &rc)| add_rc_and_sbox(s, rc));
242        mds_light_permutation(state, mat4);
243    }
244}
245
246/// A helper method which allow any field to easily implement the initial External Layer.
247#[inline]
248pub fn external_initial_permute_state<
249    FA: FieldAlgebra,
250    CT: Copy, // Whatever type the constants are stored as.
251    MdsPerm4: MdsPermutation<FA, 4>,
252    const WIDTH: usize,
253>(
254    state: &mut [FA; WIDTH],
255    initial_external_constants: &[[CT; WIDTH]],
256    add_rc_and_sbox: fn(&mut FA, CT),
257    mat4: &MdsPerm4,
258) {
259    mds_light_permutation(state, mat4);
260    // After the initial mds_light_permutation, the remaining layers are identical
261    // to the terminal permutation simply with different constants.
262    external_terminal_permute_state(state, initial_external_constants, add_rc_and_sbox, mat4)
263}