openvm_native_compiler/constraints/halo2/
poseidon2_perm.rs

//! Halo2 implementation of the Poseidon2 permutation for Bn254Fr.
//! The S-box has degree 5 (x -> x^5). Only state width T = 3 is supported by the linear layers.
3
4use snark_verifier_sdk::snark_verifier::halo2_base::{
5    gates::GateInstructions,
6    safe_types::SafeBool,
7    utils::ScalarField,
8    AssignedValue, Context,
9    QuantumCell::{self, Constant},
10};
11
12#[derive(Clone, Debug)]
13pub struct Poseidon2State<F: ScalarField, const T: usize> {
14    pub s: [AssignedValue<F>; T],
15}
16
17#[derive(Debug, Clone)]
18pub struct Poseidon2Params<F: ScalarField, const T: usize> {
19    /// Number of full rounds
20    pub rounds_f: usize,
21    pub rounds_p: usize,
22    pub mat_internal_diag_m_1: [F; T],
23    pub external_rc: Vec<[F; T]>,
24    pub internal_rc: Vec<F>,
25}
26
27impl<F: ScalarField, const T: usize> Poseidon2Params<F, T> {
28    pub fn new(
29        rounds_f: usize,
30        rounds_p: usize,
31        mat_internal_diag_m_1: [F; T],
32        external_rc: Vec<[F; T]>,
33        internal_rc: Vec<F>,
34    ) -> Self {
35        Self {
36            rounds_f,
37            rounds_p,
38            mat_internal_diag_m_1,
39            external_rc,
40            internal_rc,
41        }
42    }
43}
44
45impl<F: ScalarField, const T: usize> Poseidon2State<F, T> {
46    pub fn new(state: [AssignedValue<F>; T]) -> Self {
47        Self { s: state }
48    }
49    /// Perform permutation on this state.
50    ///
51    /// ATTENTION: inputs.len() needs to be fixed at compile time.
52    pub fn permutation(
53        &mut self,
54        ctx: &mut Context<F>,
55        gate: &impl GateInstructions<F>,
56        params: &Poseidon2Params<F, T>,
57    ) {
58        let rounds_f_beginning = params.rounds_f / 2;
59
60        // First half of the full round
61        self.matmul_external(ctx, gate);
62        for r in 0..rounds_f_beginning {
63            self.add_rc(ctx, gate, params.external_rc[r]);
64            self.sbox(ctx, gate);
65            self.matmul_external(ctx, gate);
66        }
67
68        for r in 0..params.rounds_p {
69            self.s[0] = gate.add(ctx, self.s[0], Constant(params.internal_rc[r]));
70            self.s[0] = Self::x_power5(ctx, gate, self.s[0]);
71            self.matmul_internal(ctx, gate, params.mat_internal_diag_m_1);
72        }
73
74        for r in rounds_f_beginning..params.rounds_f {
75            self.add_rc(ctx, gate, params.external_rc[r]);
76            self.sbox(ctx, gate);
77            self.matmul_external(ctx, gate);
78        }
79    }
80
81    /// Constrains and set self to a specific state if `selector` is true.
82    pub fn select(
83        &mut self,
84        ctx: &mut Context<F>,
85        gate: &impl GateInstructions<F>,
86        selector: SafeBool<F>,
87        set_to: &Self,
88    ) {
89        for i in 0..T {
90            self.s[i] = gate.select(ctx, set_to.s[i], self.s[i], *selector.as_ref());
91        }
92    }
93
94    fn x_power5(
95        ctx: &mut Context<F>,
96        gate: &impl GateInstructions<F>,
97        x: AssignedValue<F>,
98    ) -> AssignedValue<F> {
99        let x2 = gate.mul(ctx, x, x);
100        let x4 = gate.mul(ctx, x2, x2);
101        gate.mul(ctx, x, x4)
102    }
103
104    fn sbox(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) {
105        for x in self.s.iter_mut() {
106            *x = Self::x_power5(ctx, gate, *x);
107        }
108    }
109
110    fn matmul_external(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) {
111        // Only doing T = 3 case
112        assert_eq!(T, 3);
113
114        // Matrix is circ(2, 1, 1)
115        let sum = gate.sum(ctx, self.s.iter().copied());
116        for (i, x) in self.s.iter_mut().enumerate() {
117            // This is the same as `*x = gate.add(ctx, *x, sum)` but we save a cell by reusing `sum`:
118            if i % 2 == 0 {
119                ctx.assign_region(
120                    [
121                        QuantumCell::Witness(*x.value() + sum.value()),
122                        QuantumCell::Existing(*x),
123                        QuantumCell::Constant(-F::ONE),
124                        QuantumCell::Existing(sum),
125                    ],
126                    [0],
127                );
128                *x = ctx.get(-4);
129            } else {
130                ctx.assign_region(
131                    [
132                        QuantumCell::Existing(*x),
133                        QuantumCell::Constant(F::ONE),
134                        QuantumCell::Witness(*x.value() + sum.value()),
135                    ],
136                    [-1],
137                );
138                *x = ctx.get(-1);
139            }
140        }
141    }
142
143    fn add_rc(
144        &mut self,
145        ctx: &mut Context<F>,
146        gate: &impl GateInstructions<F>,
147        round_constants: [F; T],
148    ) {
149        for (x, rc) in self.s.iter_mut().zip(round_constants.iter()) {
150            *x = gate.add(ctx, *x, Constant(*rc));
151        }
152    }
153
154    fn matmul_internal(
155        &mut self,
156        ctx: &mut Context<F>,
157        gate: &impl GateInstructions<F>,
158        mat_internal_diag_m_1: [F; T],
159    ) {
160        assert_eq!(T, 3);
161        let sum = gate.sum(ctx, self.s.iter().copied());
162        for i in 0..T {
163            // This is the same as `self.s[i] = gate.mul_add(ctx, self.s[i], Constant(mat_internal_diag_m_1[i]), sum)` but we save a cell by reusing `sum`.
164            if i % 2 == 0 {
165                ctx.assign_region(
166                    [
167                        QuantumCell::Witness(
168                            *self.s[i].value() * mat_internal_diag_m_1[i] + sum.value(),
169                        ),
170                        QuantumCell::Existing(self.s[i]),
171                        QuantumCell::Constant(-mat_internal_diag_m_1[i]),
172                        QuantumCell::Existing(sum),
173                    ],
174                    [0],
175                );
176                self.s[i] = ctx.get(-4);
177            } else {
178                ctx.assign_region(
179                    [
180                        QuantumCell::Existing(self.s[i]),
181                        QuantumCell::Constant(mat_internal_diag_m_1[i]),
182                        QuantumCell::Witness(
183                            *self.s[i].value() * mat_internal_diag_m_1[i] + sum.value(),
184                        ),
185                    ],
186                    [-1],
187                );
188                self.s[i] = ctx.get(-1);
189            }
190        }
191    }
192}