openvm_native_compiler/constraints/halo2/poseidon2_perm.rs

1//! Halo2 implementation of poseidon2 perm for Bn254Fr
2//! sbox degree 5
3
4use snark_verifier_sdk::snark_verifier::halo2_base::{
5    gates::GateInstructions,
6    safe_types::SafeBool,
7    utils::ScalarField,
8    AssignedValue, Context,
9    QuantumCell::{self, Constant},
10};
11
12#[derive(Clone, Debug)]
13pub struct Poseidon2State<F: ScalarField, const T: usize> {
14    pub s: [AssignedValue<F>; T],
15}
16
17#[derive(Debug, Clone)]
18pub struct Poseidon2Params<F: ScalarField, const T: usize> {
19    /// Number of full rounds
20    pub rounds_f: usize,
21    pub rounds_p: usize,
22    pub mat_internal_diag_m_1: [F; T],
23    pub external_rc: Vec<[F; T]>,
24    pub internal_rc: Vec<F>,
25}
26
27impl<F: ScalarField, const T: usize> Poseidon2Params<F, T> {
28    pub fn new(
29        rounds_f: usize,
30        rounds_p: usize,
31        mat_internal_diag_m_1: [F; T],
32        external_rc: Vec<[F; T]>,
33        internal_rc: Vec<F>,
34    ) -> Self {
35        Self {
36            rounds_f,
37            rounds_p,
38            mat_internal_diag_m_1,
39            external_rc,
40            internal_rc,
41        }
42    }
43}
44
45impl<F: ScalarField, const T: usize> Poseidon2State<F, T> {
46    pub fn new(state: [AssignedValue<F>; T]) -> Self {
47        Self { s: state }
48    }
49    /// Perform permutation on this state.
50    ///
51    /// ATTENTION: inputs.len() needs to be fixed at compile time.
52    pub fn permutation(
53        &mut self,
54        ctx: &mut Context<F>,
55        gate: &impl GateInstructions<F>,
56        params: &Poseidon2Params<F, T>,
57    ) {
58        let rounds_f_beginning = params.rounds_f / 2;
59
60        // First half of the full round
61        self.matmul_external(ctx, gate);
62        for r in 0..rounds_f_beginning {
63            self.add_rc(ctx, gate, params.external_rc[r]);
64            self.sbox(ctx, gate);
65            self.matmul_external(ctx, gate);
66        }
67
68        for r in 0..params.rounds_p {
69            self.s[0] = gate.add(ctx, self.s[0], Constant(params.internal_rc[r]));
70            self.s[0] = Self::x_power5(ctx, gate, self.s[0]);
71            self.matmul_internal(ctx, gate, params.mat_internal_diag_m_1);
72        }
73
74        for r in rounds_f_beginning..params.rounds_f {
75            self.add_rc(ctx, gate, params.external_rc[r]);
76            self.sbox(ctx, gate);
77            self.matmul_external(ctx, gate);
78        }
79    }
80
81    /// Constrains and set self to a specific state if `selector` is true.
82    pub fn select(
83        &mut self,
84        ctx: &mut Context<F>,
85        gate: &impl GateInstructions<F>,
86        selector: SafeBool<F>,
87        set_to: &Self,
88    ) {
89        for i in 0..T {
90            self.s[i] = gate.select(ctx, set_to.s[i], self.s[i], *selector.as_ref());
91        }
92    }
93
94    fn x_power5(
95        ctx: &mut Context<F>,
96        gate: &impl GateInstructions<F>,
97        x: AssignedValue<F>,
98    ) -> AssignedValue<F> {
99        let x2 = gate.mul(ctx, x, x);
100        let x4 = gate.mul(ctx, x2, x2);
101        gate.mul(ctx, x, x4)
102    }
103
104    fn sbox(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) {
105        for x in self.s.iter_mut() {
106            *x = Self::x_power5(ctx, gate, *x);
107        }
108    }
109
110    fn matmul_external(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) {
111        // Only doing T = 3 case
112        assert_eq!(T, 3);
113
114        // Matrix is circ(2, 1, 1)
115        let sum = gate.sum(ctx, self.s.iter().copied());
116        for (i, x) in self.s.iter_mut().enumerate() {
117            // This is the same as `*x = gate.add(ctx, *x, sum)` but we save a cell by reusing
118            // `sum`:
119            if i % 2 == 0 {
120                ctx.assign_region(
121                    [
122                        QuantumCell::Witness(*x.value() + sum.value()),
123                        QuantumCell::Existing(*x),
124                        QuantumCell::Constant(-F::ONE),
125                        QuantumCell::Existing(sum),
126                    ],
127                    [0],
128                );
129                *x = ctx.get(-4);
130            } else {
131                ctx.assign_region(
132                    [
133                        QuantumCell::Existing(*x),
134                        QuantumCell::Constant(F::ONE),
135                        QuantumCell::Witness(*x.value() + sum.value()),
136                    ],
137                    [-1],
138                );
139                *x = ctx.get(-1);
140            }
141        }
142    }
143
144    fn add_rc(
145        &mut self,
146        ctx: &mut Context<F>,
147        gate: &impl GateInstructions<F>,
148        round_constants: [F; T],
149    ) {
150        for (x, rc) in self.s.iter_mut().zip(round_constants.iter()) {
151            *x = gate.add(ctx, *x, Constant(*rc));
152        }
153    }
154
155    fn matmul_internal(
156        &mut self,
157        ctx: &mut Context<F>,
158        gate: &impl GateInstructions<F>,
159        mat_internal_diag_m_1: [F; T],
160    ) {
161        assert_eq!(T, 3);
162        let sum = gate.sum(ctx, self.s.iter().copied());
163        for i in 0..T {
164            // This is the same as `self.s[i] = gate.mul_add(ctx, self.s[i],
165            // Constant(mat_internal_diag_m_1[i]), sum)` but we save a cell by reusing `sum`.
166            if i % 2 == 0 {
167                ctx.assign_region(
168                    [
169                        QuantumCell::Witness(
170                            *self.s[i].value() * mat_internal_diag_m_1[i] + sum.value(),
171                        ),
172                        QuantumCell::Existing(self.s[i]),
173                        QuantumCell::Constant(-mat_internal_diag_m_1[i]),
174                        QuantumCell::Existing(sum),
175                    ],
176                    [0],
177                );
178                self.s[i] = ctx.get(-4);
179            } else {
180                ctx.assign_region(
181                    [
182                        QuantumCell::Existing(self.s[i]),
183                        QuantumCell::Constant(mat_internal_diag_m_1[i]),
184                        QuantumCell::Witness(
185                            *self.s[i].value() * mat_internal_diag_m_1[i] + sum.value(),
186                        ),
187                    ],
188                    [-1],
189                );
190                self.s[i] = ctx.get(-1);
191            }
192        }
193    }
194}