openvm_native_compiler/constraints/halo2/poseidon2_perm.rs
1use snark_verifier_sdk::snark_verifier::halo2_base::{
5 gates::GateInstructions,
6 safe_types::SafeBool,
7 utils::ScalarField,
8 AssignedValue, Context,
9 QuantumCell::{self, Constant},
10};
11
12#[derive(Clone, Debug)]
13pub struct Poseidon2State<F: ScalarField, const T: usize> {
14 pub s: [AssignedValue<F>; T],
15}
16
17#[derive(Debug, Clone)]
18pub struct Poseidon2Params<F: ScalarField, const T: usize> {
19 pub rounds_f: usize,
21 pub rounds_p: usize,
22 pub mat_internal_diag_m_1: [F; T],
23 pub external_rc: Vec<[F; T]>,
24 pub internal_rc: Vec<F>,
25}
26
27impl<F: ScalarField, const T: usize> Poseidon2Params<F, T> {
28 pub fn new(
29 rounds_f: usize,
30 rounds_p: usize,
31 mat_internal_diag_m_1: [F; T],
32 external_rc: Vec<[F; T]>,
33 internal_rc: Vec<F>,
34 ) -> Self {
35 Self {
36 rounds_f,
37 rounds_p,
38 mat_internal_diag_m_1,
39 external_rc,
40 internal_rc,
41 }
42 }
43}
44
45impl<F: ScalarField, const T: usize> Poseidon2State<F, T> {
46 pub fn new(state: [AssignedValue<F>; T]) -> Self {
47 Self { s: state }
48 }
49 pub fn permutation(
53 &mut self,
54 ctx: &mut Context<F>,
55 gate: &impl GateInstructions<F>,
56 params: &Poseidon2Params<F, T>,
57 ) {
58 let rounds_f_beginning = params.rounds_f / 2;
59
60 self.matmul_external(ctx, gate);
62 for r in 0..rounds_f_beginning {
63 self.add_rc(ctx, gate, params.external_rc[r]);
64 self.sbox(ctx, gate);
65 self.matmul_external(ctx, gate);
66 }
67
68 for r in 0..params.rounds_p {
69 self.s[0] = gate.add(ctx, self.s[0], Constant(params.internal_rc[r]));
70 self.s[0] = Self::x_power5(ctx, gate, self.s[0]);
71 self.matmul_internal(ctx, gate, params.mat_internal_diag_m_1);
72 }
73
74 for r in rounds_f_beginning..params.rounds_f {
75 self.add_rc(ctx, gate, params.external_rc[r]);
76 self.sbox(ctx, gate);
77 self.matmul_external(ctx, gate);
78 }
79 }
80
81 pub fn select(
83 &mut self,
84 ctx: &mut Context<F>,
85 gate: &impl GateInstructions<F>,
86 selector: SafeBool<F>,
87 set_to: &Self,
88 ) {
89 for i in 0..T {
90 self.s[i] = gate.select(ctx, set_to.s[i], self.s[i], *selector.as_ref());
91 }
92 }
93
94 fn x_power5(
95 ctx: &mut Context<F>,
96 gate: &impl GateInstructions<F>,
97 x: AssignedValue<F>,
98 ) -> AssignedValue<F> {
99 let x2 = gate.mul(ctx, x, x);
100 let x4 = gate.mul(ctx, x2, x2);
101 gate.mul(ctx, x, x4)
102 }
103
104 fn sbox(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) {
105 for x in self.s.iter_mut() {
106 *x = Self::x_power5(ctx, gate, *x);
107 }
108 }
109
110 fn matmul_external(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) {
111 assert_eq!(T, 3);
113
114 let sum = gate.sum(ctx, self.s.iter().copied());
116 for (i, x) in self.s.iter_mut().enumerate() {
117 if i % 2 == 0 {
119 ctx.assign_region(
120 [
121 QuantumCell::Witness(*x.value() + sum.value()),
122 QuantumCell::Existing(*x),
123 QuantumCell::Constant(-F::ONE),
124 QuantumCell::Existing(sum),
125 ],
126 [0],
127 );
128 *x = ctx.get(-4);
129 } else {
130 ctx.assign_region(
131 [
132 QuantumCell::Existing(*x),
133 QuantumCell::Constant(F::ONE),
134 QuantumCell::Witness(*x.value() + sum.value()),
135 ],
136 [-1],
137 );
138 *x = ctx.get(-1);
139 }
140 }
141 }
142
143 fn add_rc(
144 &mut self,
145 ctx: &mut Context<F>,
146 gate: &impl GateInstructions<F>,
147 round_constants: [F; T],
148 ) {
149 for (x, rc) in self.s.iter_mut().zip(round_constants.iter()) {
150 *x = gate.add(ctx, *x, Constant(*rc));
151 }
152 }
153
154 fn matmul_internal(
155 &mut self,
156 ctx: &mut Context<F>,
157 gate: &impl GateInstructions<F>,
158 mat_internal_diag_m_1: [F; T],
159 ) {
160 assert_eq!(T, 3);
161 let sum = gate.sum(ctx, self.s.iter().copied());
162 for i in 0..T {
163 if i % 2 == 0 {
165 ctx.assign_region(
166 [
167 QuantumCell::Witness(
168 *self.s[i].value() * mat_internal_diag_m_1[i] + sum.value(),
169 ),
170 QuantumCell::Existing(self.s[i]),
171 QuantumCell::Constant(-mat_internal_diag_m_1[i]),
172 QuantumCell::Existing(sum),
173 ],
174 [0],
175 );
176 self.s[i] = ctx.get(-4);
177 } else {
178 ctx.assign_region(
179 [
180 QuantumCell::Existing(self.s[i]),
181 QuantumCell::Constant(mat_internal_diag_m_1[i]),
182 QuantumCell::Witness(
183 *self.s[i].value() * mat_internal_diag_m_1[i] + sum.value(),
184 ),
185 ],
186 [-1],
187 );
188 self.s[i] = ctx.get(-1);
189 }
190 }
191 }
192}