openvm_native_compiler/constraints/halo2/
poseidon2_perm.rs1use snark_verifier_sdk::snark_verifier::halo2_base::{
5 gates::GateInstructions,
6 safe_types::SafeBool,
7 utils::ScalarField,
8 AssignedValue, Context,
9 QuantumCell::{self, Constant},
10};
11
12#[derive(Clone, Debug)]
13pub struct Poseidon2State<F: ScalarField, const T: usize> {
14 pub s: [AssignedValue<F>; T],
15}
16
17#[derive(Debug, Clone)]
18pub struct Poseidon2Params<F: ScalarField, const T: usize> {
19 pub rounds_f: usize,
21 pub rounds_p: usize,
22 pub mat_internal_diag_m_1: [F; T],
23 pub external_rc: Vec<[F; T]>,
24 pub internal_rc: Vec<F>,
25}
26
27impl<F: ScalarField, const T: usize> Poseidon2Params<F, T> {
28 pub fn new(
29 rounds_f: usize,
30 rounds_p: usize,
31 mat_internal_diag_m_1: [F; T],
32 external_rc: Vec<[F; T]>,
33 internal_rc: Vec<F>,
34 ) -> Self {
35 Self {
36 rounds_f,
37 rounds_p,
38 mat_internal_diag_m_1,
39 external_rc,
40 internal_rc,
41 }
42 }
43}
44
45impl<F: ScalarField, const T: usize> Poseidon2State<F, T> {
46 pub fn new(state: [AssignedValue<F>; T]) -> Self {
47 Self { s: state }
48 }
49 pub fn permutation(
53 &mut self,
54 ctx: &mut Context<F>,
55 gate: &impl GateInstructions<F>,
56 params: &Poseidon2Params<F, T>,
57 ) {
58 let rounds_f_beginning = params.rounds_f / 2;
59
60 self.matmul_external(ctx, gate);
62 for r in 0..rounds_f_beginning {
63 self.add_rc(ctx, gate, params.external_rc[r]);
64 self.sbox(ctx, gate);
65 self.matmul_external(ctx, gate);
66 }
67
68 for r in 0..params.rounds_p {
69 self.s[0] = gate.add(ctx, self.s[0], Constant(params.internal_rc[r]));
70 self.s[0] = Self::x_power5(ctx, gate, self.s[0]);
71 self.matmul_internal(ctx, gate, params.mat_internal_diag_m_1);
72 }
73
74 for r in rounds_f_beginning..params.rounds_f {
75 self.add_rc(ctx, gate, params.external_rc[r]);
76 self.sbox(ctx, gate);
77 self.matmul_external(ctx, gate);
78 }
79 }
80
81 pub fn select(
83 &mut self,
84 ctx: &mut Context<F>,
85 gate: &impl GateInstructions<F>,
86 selector: SafeBool<F>,
87 set_to: &Self,
88 ) {
89 for i in 0..T {
90 self.s[i] = gate.select(ctx, set_to.s[i], self.s[i], *selector.as_ref());
91 }
92 }
93
94 fn x_power5(
95 ctx: &mut Context<F>,
96 gate: &impl GateInstructions<F>,
97 x: AssignedValue<F>,
98 ) -> AssignedValue<F> {
99 let x2 = gate.mul(ctx, x, x);
100 let x4 = gate.mul(ctx, x2, x2);
101 gate.mul(ctx, x, x4)
102 }
103
104 fn sbox(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) {
105 for x in self.s.iter_mut() {
106 *x = Self::x_power5(ctx, gate, *x);
107 }
108 }
109
110 fn matmul_external(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) {
111 assert_eq!(T, 3);
113
114 let sum = gate.sum(ctx, self.s.iter().copied());
116 for (i, x) in self.s.iter_mut().enumerate() {
117 if i % 2 == 0 {
120 ctx.assign_region(
121 [
122 QuantumCell::Witness(*x.value() + sum.value()),
123 QuantumCell::Existing(*x),
124 QuantumCell::Constant(-F::ONE),
125 QuantumCell::Existing(sum),
126 ],
127 [0],
128 );
129 *x = ctx.get(-4);
130 } else {
131 ctx.assign_region(
132 [
133 QuantumCell::Existing(*x),
134 QuantumCell::Constant(F::ONE),
135 QuantumCell::Witness(*x.value() + sum.value()),
136 ],
137 [-1],
138 );
139 *x = ctx.get(-1);
140 }
141 }
142 }
143
144 fn add_rc(
145 &mut self,
146 ctx: &mut Context<F>,
147 gate: &impl GateInstructions<F>,
148 round_constants: [F; T],
149 ) {
150 for (x, rc) in self.s.iter_mut().zip(round_constants.iter()) {
151 *x = gate.add(ctx, *x, Constant(*rc));
152 }
153 }
154
155 fn matmul_internal(
156 &mut self,
157 ctx: &mut Context<F>,
158 gate: &impl GateInstructions<F>,
159 mat_internal_diag_m_1: [F; T],
160 ) {
161 assert_eq!(T, 3);
162 let sum = gate.sum(ctx, self.s.iter().copied());
163 for i in 0..T {
164 if i % 2 == 0 {
167 ctx.assign_region(
168 [
169 QuantumCell::Witness(
170 *self.s[i].value() * mat_internal_diag_m_1[i] + sum.value(),
171 ),
172 QuantumCell::Existing(self.s[i]),
173 QuantumCell::Constant(-mat_internal_diag_m_1[i]),
174 QuantumCell::Existing(sum),
175 ],
176 [0],
177 );
178 self.s[i] = ctx.get(-4);
179 } else {
180 ctx.assign_region(
181 [
182 QuantumCell::Existing(self.s[i]),
183 QuantumCell::Constant(mat_internal_diag_m_1[i]),
184 QuantumCell::Witness(
185 *self.s[i].value() * mat_internal_diag_m_1[i] + sum.value(),
186 ),
187 ],
188 [-1],
189 );
190 self.s[i] = ctx.get(-1);
191 }
192 }
193 }
194}