poseidon_primitives/poseidon/primitives/
grain.rs

1//! The Grain LFSR in self-shrinking mode, as used by Poseidon.
2
3use std::marker::PhantomData;
4
5use bitvec::prelude::*;
6use ff::{FromUniformBytes, PrimeField};
7
/// Length of the Grain LFSR state, in bits.
const STATE: usize = 80;
9
10#[derive(Debug, Clone, Copy)]
11pub(super) enum FieldType {
12    /// GF(2^n)
13    #[allow(dead_code)]
14    Binary,
15    /// GF(p)
16    PrimeOrder,
17}
18
19impl FieldType {
20    fn tag(&self) -> u8 {
21        match self {
22            FieldType::Binary => 0,
23            FieldType::PrimeOrder => 1,
24        }
25    }
26}
27
28#[derive(Debug, Clone, Copy)]
29pub(super) enum SboxType {
30    /// x^alpha
31    Pow,
32    /// x^(-1)
33    #[allow(dead_code)]
34    Inv,
35}
36
37impl SboxType {
38    fn tag(&self) -> u8 {
39        match self {
40            SboxType::Pow => 0,
41            SboxType::Inv => 1,
42        }
43    }
44}
45
46pub(super) struct Grain<F> {
47    state: BitArr!(for 80, in u8, Msb0),
48    next_bit: usize,
49    _field: PhantomData<F>,
50}
51
52impl<F: PrimeField> Grain<F> {
53    pub(super) fn new(sbox: SboxType, t: u16, r_f: u16, r_p: u16) -> Self {
54        // Initialize the LFSR state.
55        let mut state = bitarr![u8, Msb0; 1; STATE];
56        let mut set_bits = |offset: usize, len, value| {
57            // Poseidon reference impl sets initial state bits in MSB order.
58            for i in 0..len {
59                *state.get_mut(offset + len - 1 - i).unwrap() = (value >> i) & 1 != 0;
60            }
61        };
62        set_bits(0, 2, FieldType::PrimeOrder.tag() as u16);
63        set_bits(2, 4, sbox.tag() as u16);
64        set_bits(6, 12, F::NUM_BITS as u16);
65        set_bits(18, 12, t);
66        set_bits(30, 10, r_f);
67        set_bits(40, 10, r_p);
68
69        let mut grain = Grain {
70            state,
71            next_bit: STATE,
72            _field: PhantomData,
73        };
74
75        // Discard the first 160 bits.
76        for _ in 0..20 {
77            grain.load_next_8_bits();
78            grain.next_bit = STATE;
79        }
80
81        grain
82    }
83
84    fn load_next_8_bits(&mut self) {
85        let mut new_bits = 0u8;
86        for i in 0..8 {
87            new_bits |= ((self.state[i + 62]
88                ^ self.state[i + 51]
89                ^ self.state[i + 38]
90                ^ self.state[i + 23]
91                ^ self.state[i + 13]
92                ^ self.state[i]) as u8)
93                << i;
94        }
95        self.state.rotate_left(8);
96        self.next_bit -= 8;
97        for i in 0..8 {
98            *self.state.get_mut(self.next_bit + i).unwrap() = (new_bits >> i) & 1 != 0;
99        }
100    }
101
102    fn get_next_bit(&mut self) -> bool {
103        if self.next_bit == STATE {
104            self.load_next_8_bits();
105        }
106        let ret = self.state[self.next_bit];
107        self.next_bit += 1;
108        ret
109    }
110
111    /// Returns the next field element from this Grain instantiation.
112    pub(super) fn next_field_element(&mut self) -> F {
113        // Loop until we get an element in the field.
114        loop {
115            let mut bytes = F::Repr::default();
116
117            // Poseidon reference impl interprets the bits as a repr in MSB order, because
118            // it's easy to do that in Python. Meanwhile, our field elements all use LSB
119            // order. There's little motivation to diverge from the reference impl; these
120            // are all constants, so we aren't introducing big-endianness into the rest of
121            // the circuit (assuming unkeyed Poseidon, but we probably wouldn't want to
122            // implement Grain inside a circuit, so we'd use a different round constant
123            // derivation function there).
124            let view = bytes.as_mut();
125            for (i, bit) in self.take(F::NUM_BITS as usize).enumerate() {
126                // If we diverged from the reference impl and interpreted the bits in LSB
127                // order, we would remove this line.
128                let i = F::NUM_BITS as usize - 1 - i;
129
130                view[i / 8] |= if bit { 1 << (i % 8) } else { 0 };
131            }
132
133            if let Some(f) = F::from_repr_vartime(bytes) {
134                break f;
135            }
136        }
137    }
138
139    /// Returns the next field element from this Grain instantiation, without using
140    /// rejection sampling.
141    pub(super) fn next_field_element_without_rejection(&mut self) -> F
142    where
143        F: FromUniformBytes<64>,
144    {
145        let mut bytes = [0u8; 64];
146
147        // Poseidon reference impl interprets the bits as a repr in MSB order, because
148        // it's easy to do that in Python. Additionally, it does not use rejection
149        // sampling in cases where the constants don't specifically need to be uniformly
150        // random for security. We do not provide APIs that take a field-element-sized
151        // array and reduce it modulo the field order, because those are unsafe APIs to
152        // offer generally (accidentally using them can lead to divergence in consensus
153        // systems due to not rejecting canonical forms).
154        //
155        // Given that we don't want to diverge from the reference implementation, we hack
156        // around this restriction by serializing the bits into a 64-byte array and then
157        // calling F::from_bytes_wide. PLEASE DO NOT COPY THIS INTO YOUR OWN CODE!
158        let view = bytes.as_mut();
159        for (i, bit) in self.take(F::NUM_BITS as usize).enumerate() {
160            // If we diverged from the reference impl and interpreted the bits in LSB
161            // order, we would remove this line.
162            let i = F::NUM_BITS as usize - 1 - i;
163
164            view[i / 8] |= if bit { 1 << (i % 8) } else { 0 };
165        }
166
167        F::from_uniform_bytes(&bytes)
168    }
169}
170
171impl<F: PrimeField> Iterator for Grain<F> {
172    type Item = bool;
173
174    fn next(&mut self) -> Option<Self::Item> {
175        // Evaluate bits in pairs:
176        // - If the first bit is a 1, output the second bit.
177        // - If the first bit is a 0, discard the second bit.
178        while !self.get_next_bit() {
179            self.get_next_bit();
180        }
181        Some(self.get_next_bit())
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::super::pasta::Fp;
    use super::{Grain, SboxType};

    /// Smoke test: a Grain instance for Pasta `Fp` can produce a field element.
    #[test]
    fn grain() {
        let mut grain = Grain::<Fp>::new(SboxType::Pow, 3, 8, 56);
        let _f = grain.next_field_element();
    }

    /// Two generators seeded with identical parameters must agree element-for-element.
    #[test]
    fn grain_is_deterministic() {
        let mut a = Grain::<Fp>::new(SboxType::Pow, 3, 8, 56);
        let mut b = Grain::<Fp>::new(SboxType::Pow, 3, 8, 56);
        for _ in 0..8 {
            assert_eq!(a.next_field_element(), b.next_field_element());
        }
    }

    /// Changing a seed parameter (here `r_p`) must change the output stream.
    #[test]
    fn grain_depends_on_parameters() {
        let mut a = Grain::<Fp>::new(SboxType::Pow, 3, 8, 56);
        let mut b = Grain::<Fp>::new(SboxType::Pow, 3, 8, 57);
        let xs: Vec<Fp> = (0..4).map(|_| a.next_field_element()).collect();
        let ys: Vec<Fp> = (0..4).map(|_| b.next_field_element()).collect();
        assert_ne!(xs, ys);
    }
}