// poseidon_primitives/poseidon/primitives/grain.rs
//! The Grain LFSR in self-shrinking mode, as used by Poseidon.
use std::marker::PhantomData;
use bitvec::prelude::*;
use ff::{FromUniformBytes, PrimeField};
/// Number of bits in the Grain LFSR state. Also used as the `next_bit` value that
/// signals "no unread bits left; generate a fresh batch".
const STATE: usize = 80;
/// The kind of field a Poseidon instance is defined over.
///
/// The value returned by `FieldType::tag` is written into the first two bits of the
/// initial Grain state.
#[derive(Debug, Clone, Copy)]
pub(super) enum FieldType {
/// GF(2^n)
// Never constructed in this file: `Grain::new` always seeds with `PrimeOrder`. The
// variant is kept so the tag encoding documents both values.
#[allow(dead_code)]
Binary,
/// GF(p)
PrimeOrder,
}
impl FieldType {
fn tag(&self) -> u8 {
match self {
FieldType::Binary => 0,
FieldType::PrimeOrder => 1,
}
}
}
/// The S-box used by a Poseidon instance.
///
/// The value returned by `SboxType::tag` is written into four bits of the initial
/// Grain state, immediately after the two field-type bits.
#[derive(Debug, Clone, Copy)]
pub(super) enum SboxType {
/// x^alpha
Pow,
/// x^(-1)
// Not constructed anywhere in this file; presumably kept so the tag encoding covers
// both S-box kinds -- confirm against callers before removing the allow.
#[allow(dead_code)]
Inv,
}
impl SboxType {
fn tag(&self) -> u8 {
match self {
SboxType::Pow => 0,
SboxType::Inv => 1,
}
}
}
/// Bit generator built on an 80-bit LFSR, used to derive elements of the field `F`.
///
/// Iterating over a `Grain` yields the self-shrunken bit stream; the
/// `next_field_element*` methods pack those bits into field elements.
pub(super) struct Grain<F> {
/// The LFSR register (80 bits, stored in bytes with most-significant-bit-first
/// ordering).
state: BitArr!(for 80, in u8, Msb0),
/// Index into `state` of the next unread bit. Equal to `STATE` when every generated
/// bit has been consumed and a fresh batch of eight must be produced.
next_bit: usize,
/// Ties the generator to the field type `F` without storing a value of it.
_field: PhantomData<F>,
}
impl<F: PrimeField> Grain<F> {
    /// Creates a Grain instance seeded from the given Poseidon parameters.
    ///
    /// The 80-bit initial state encodes, in order: the field type (2 bits, always
    /// prime-order here), the S-box type (4 bits), `F::NUM_BITS` (12 bits), `t`
    /// (12 bits), `r_f` (10 bits) and `r_p` (10 bits); the remaining 30 bits are set
    /// to 1. In Poseidon's notation `t` is the state width and `r_f` / `r_p` are the
    /// numbers of full and partial rounds.
    ///
    /// Only the low bits of each value that fit in its slot are encoded.
    pub(super) fn new(sbox: SboxType, t: u16, r_f: u16, r_p: u16) -> Self {
        // Start with every bit set; the parameter encoding overwrites bits 0..50.
        let mut state = bitarr![u8, Msb0; 1; STATE];
        let mut set_bits = |offset: usize, len: usize, value: u16| {
            // Poseidon reference impl sets initial state bits in MSB order.
            for i in 0..len {
                state.set(offset + len - 1 - i, (value >> i) & 1 != 0);
            }
        };
        set_bits(0, 2, FieldType::PrimeOrder.tag() as u16);
        set_bits(2, 4, sbox.tag() as u16);
        set_bits(6, 12, F::NUM_BITS as u16);
        set_bits(18, 12, t);
        set_bits(30, 10, r_f);
        set_bits(40, 10, r_p);

        let mut grain = Grain {
            state,
            next_bit: STATE,
            _field: PhantomData,
        };

        // Discard the first 160 bits (20 batches of 8). Resetting `next_bit` marks
        // each freshly generated batch as already consumed.
        for _ in 0..20 {
            grain.load_next_8_bits();
            grain.next_bit = STATE;
        }

        grain
    }

    /// Advances the LFSR by eight steps, leaving the eight new bits at the end of
    /// the state and pointing `next_bit` at the first of them.
    ///
    /// Must only be called when `next_bit == STATE`.
    fn load_next_8_bits(&mut self) {
        // Offsets (relative to the output position) of the state bits that are XORed
        // together to form one feedback bit.
        const TAPS: [usize; 6] = [62, 51, 38, 23, 13, 0];

        // All eight feedback bits are computed from the state as it is before the
        // shift. The highest index read is 7 + 62 = 69 < 80, so no feedback bit
        // depends on another bit of the same batch.
        let mut feedback = [false; 8];
        for (i, slot) in feedback.iter_mut().enumerate() {
            *slot = TAPS
                .iter()
                .fold(false, |acc, &tap| acc ^ self.state[i + tap]);
        }

        // Shift the eight oldest bits out and write the new batch into the slots
        // they vacate at the end of the register.
        self.state.rotate_left(8);
        self.next_bit -= 8;
        for (i, &bit) in feedback.iter().enumerate() {
            self.state.set(self.next_bit + i, bit);
        }
    }

    /// Returns the next raw (un-shrunken) LFSR output bit, generating a new batch
    /// of eight bits first if the current one is exhausted.
    fn get_next_bit(&mut self) -> bool {
        if self.next_bit == STATE {
            self.load_next_8_bits();
        }
        let bit = self.state[self.next_bit];
        self.next_bit += 1;
        bit
    }

    /// Draws `F::NUM_BITS` bits from the shrunken stream and ORs them into `repr`,
    /// a little-endian byte buffer that the caller must have zero-initialized.
    ///
    /// Poseidon reference impl interprets the bits as a repr in MSB order, because
    /// it's easy to do that in Python. Meanwhile, our field elements all use LSB
    /// order. There's little motivation to diverge from the reference impl; these
    /// are all constants, so we aren't introducing big-endianness into the rest of
    /// the circuit (assuming unkeyed Poseidon, but we probably wouldn't want to
    /// implement Grain inside a circuit, so we'd use a different round constant
    /// derivation function there).
    ///
    /// # Panics
    ///
    /// Panics if `repr` holds fewer than `F::NUM_BITS` bits.
    fn fill_bits_msb_first(&mut self, repr: &mut [u8]) {
        let num_bits = F::NUM_BITS as usize;
        for (i, bit) in self.by_ref().take(num_bits).enumerate() {
            // The first bit drawn is the most significant one. If we diverged from
            // the reference impl and interpreted the bits in LSB order, we would use
            // `i` directly.
            let pos = num_bits - 1 - i;
            repr[pos / 8] |= u8::from(bit) << (pos % 8);
        }
    }

    /// Returns the next field element from this Grain instantiation.
    ///
    /// Uses rejection sampling: candidate encodings that are not canonical
    /// representations of an element of `F` are discarded and a new candidate is
    /// drawn.
    pub(super) fn next_field_element(&mut self) -> F {
        // Loop until we get an element in the field.
        loop {
            let mut repr = F::Repr::default();
            self.fill_bits_msb_first(repr.as_mut());
            if let Some(element) = F::from_repr_vartime(repr) {
                break element;
            }
        }
    }

    /// Returns the next field element from this Grain instantiation, without using
    /// rejection sampling.
    pub(super) fn next_field_element_without_rejection(&mut self) -> F
    where
        F: FromUniformBytes<64>,
    {
        // Poseidon reference impl does not use rejection sampling in cases where the
        // constants don't specifically need to be uniformly random for security. We
        // do not provide APIs that take a field-element-sized array and reduce it
        // modulo the field order, because those are unsafe APIs to offer generally
        // (accidentally using them can lead to divergence in consensus systems due
        // to not rejecting non-canonical forms).
        //
        // Given that we don't want to diverge from the reference implementation, we
        // hack around this restriction by serializing the bits into a 64-byte array
        // and then calling F::from_uniform_bytes. PLEASE DO NOT COPY THIS INTO YOUR
        // OWN CODE!
        let mut bytes = [0u8; 64];
        self.fill_bits_msb_first(&mut bytes);
        F::from_uniform_bytes(&bytes)
    }
}
impl<F: PrimeField> Iterator for Grain<F> {
    type Item = bool;

    /// Yields the next bit of the self-shrunken stream. Never returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        // Raw bits are consumed two at a time: the first bit of each pair decides
        // whether the second is emitted (1) or thrown away (0).
        loop {
            let selector = self.get_next_bit();
            let candidate = self.get_next_bit();
            if selector {
                return Some(candidate);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::super::pasta::Fp;
    use super::{Grain, SboxType};

    /// Checks that a Grain instance can be built and sampled without panicking, and
    /// that the output depends only on the seeding parameters.
    #[test]
    fn grain() {
        let mut grain = Grain::<Fp>::new(SboxType::Pow, 3, 8, 56);
        let first = grain.next_field_element();

        // The generator has no external source of randomness, so a second instance
        // seeded with the same parameters must reproduce the same element.
        let mut replay = Grain::<Fp>::new(SboxType::Pow, 3, 8, 56);
        assert_eq!(first, replay.next_field_element());
    }
}