// p3_keccak_air/air.rs
1use alloc::vec::Vec;
2use core::borrow::Borrow;
3
4use p3_air::utils::{andn, xor, xor3};
5use p3_air::{Air, AirBuilder, BaseAir};
6use p3_field::{FieldAlgebra, PrimeField64};
7use p3_matrix::dense::RowMajorMatrix;
8use p3_matrix::Matrix;
9use rand::random;
10
11use crate::columns::{KeccakCols, NUM_KECCAK_COLS};
12use crate::constants::rc_value_bit;
13use crate::round_flags::eval_round_flags;
14use crate::{generate_trace_rows, BITS_PER_LIMB, NUM_ROUNDS, U64_LIMBS};
15
/// Constraint system (AIR) for the Keccak round function.
///
/// The type has no fields: all of its behaviour comes from its trait
/// implementations, which describe a trace whose rows are laid out as
/// `KeccakCols` (`NUM_KECCAK_COLS` columns wide).
#[derive(Debug)]
pub struct KeccakAir {}
19
20impl KeccakAir {
21 pub fn generate_trace_rows<F: PrimeField64>(
22 &self,
23 num_hashes: usize,
24 extra_capacity_bits: usize,
25 ) -> RowMajorMatrix<F> {
26 let inputs = (0..num_hashes).map(|_| random()).collect::<Vec<_>>();
27 generate_trace_rows(inputs, extra_capacity_bits)
28 }
29}
30
31impl<F> BaseAir<F> for KeccakAir {
32 fn width(&self) -> usize {
33 NUM_KECCAK_COLS
34 }
35}
36
37impl<AB: AirBuilder> Air<AB> for KeccakAir {
38 #[inline]
39 fn eval(&self, builder: &mut AB) {
40 eval_round_flags(builder);
41
42 let main = builder.main();
43 let (local, next) = (main.row_slice(0), main.row_slice(1));
44 let local: &KeccakCols<AB::Var> = (*local).borrow();
45 let next: &KeccakCols<AB::Var> = (*next).borrow();
46
47 let first_step = local.step_flags[0];
48 let final_step = local.step_flags[NUM_ROUNDS - 1];
49 let not_final_step = AB::Expr::ONE - final_step;
50
51 for y in 0..5 {
53 for x in 0..5 {
54 for limb in 0..U64_LIMBS {
55 builder
56 .when(first_step)
57 .assert_eq(local.preimage[y][x][limb], local.a[y][x][limb]);
58 }
59 }
60 }
61
62 for y in 0..5 {
64 for x in 0..5 {
65 for limb in 0..U64_LIMBS {
66 builder
67 .when(not_final_step.clone())
68 .when_transition()
69 .assert_eq(local.preimage[y][x][limb], next.preimage[y][x][limb]);
70 }
71 }
72 }
73
74 builder.assert_bool(local.export);
76
77 builder
79 .when(not_final_step.clone())
80 .assert_zero(local.export);
81
82 for x in 0..5 {
87 for z in 0..64 {
88 builder.assert_bool(local.c[x][z]);
90 let xor = xor3::<AB::Expr>(
91 local.c[x][z].into(),
92 local.c[(x + 4) % 5][z].into(),
93 local.c[(x + 1) % 5][(z + 63) % 64].into(),
94 );
95 let c_prime = local.c_prime[x][z];
96 builder.assert_eq(c_prime, xor);
97 }
98 }
99
100 for y in 0..5 {
109 for x in 0..5 {
110 let get_bit = |z| {
111 let a_prime: AB::Var = local.a_prime[y][x][z];
112 let c: AB::Var = local.c[x][z];
113 let c_prime: AB::Var = local.c_prime[x][z];
114 xor3::<AB::Expr>(a_prime.into(), c.into(), c_prime.into())
115 };
116
117 for limb in 0..U64_LIMBS {
118 let a_limb = local.a[y][x][limb];
119 let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB)
120 .rev()
121 .fold(AB::Expr::ZERO, |acc, z| {
122 builder.assert_bool(local.a_prime[y][x][z]);
124 acc.double() + get_bit(z)
125 });
126 builder.assert_eq(computed_limb, a_limb);
127 }
128 }
129 }
130
131 for x in 0..5 {
135 for z in 0..64 {
136 let sum: AB::Expr = (0..5).map(|y| local.a_prime[y][x][z].into()).sum();
137 let diff = sum - local.c_prime[x][z];
138 let four = AB::Expr::TWO.double();
140 builder.assert_zero(diff.clone() * (diff.clone() - AB::Expr::TWO) * (diff - four));
141 }
142 }
143
144 for y in 0..5 {
148 for x in 0..5 {
149 let get_bit = |z| {
150 let andn = andn::<AB::Expr>(
151 local.b((x + 1) % 5, y, z).into(),
152 local.b((x + 2) % 5, y, z).into(),
153 );
154 xor::<AB::Expr>(local.b(x, y, z).into(), andn)
155 };
156
157 for limb in 0..U64_LIMBS {
158 let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB)
159 .rev()
160 .fold(AB::Expr::ZERO, |acc, z| acc.double() + get_bit(z));
161 builder.assert_eq(computed_limb, local.a_prime_prime[y][x][limb]);
162 }
163 }
164 }
165
166 for limb in 0..U64_LIMBS {
168 let computed_a_prime_prime_0_0_limb = (limb * BITS_PER_LIMB
169 ..(limb + 1) * BITS_PER_LIMB)
170 .rev()
171 .fold(AB::Expr::ZERO, |acc, z| {
172 builder.assert_bool(local.a_prime_prime_0_0_bits[z]);
174 acc.double() + local.a_prime_prime_0_0_bits[z]
175 });
176 let a_prime_prime_0_0_limb = local.a_prime_prime[0][0][limb];
177 builder.assert_eq(computed_a_prime_prime_0_0_limb, a_prime_prime_0_0_limb);
178 }
179
180 let get_xored_bit = |i| {
181 let mut rc_bit_i = AB::Expr::ZERO;
182 for r in 0..NUM_ROUNDS {
183 let this_round = local.step_flags[r];
184 let this_round_constant = AB::Expr::from_bool(rc_value_bit(r, i) != 0);
185 rc_bit_i += this_round * this_round_constant;
186 }
187
188 xor::<AB::Expr>(local.a_prime_prime_0_0_bits[i].into(), rc_bit_i)
189 };
190
191 for limb in 0..U64_LIMBS {
192 let a_prime_prime_prime_0_0_limb = local.a_prime_prime_prime_0_0_limbs[limb];
193 let computed_a_prime_prime_prime_0_0_limb = (limb * BITS_PER_LIMB
194 ..(limb + 1) * BITS_PER_LIMB)
195 .rev()
196 .fold(AB::Expr::ZERO, |acc, z| acc.double() + get_xored_bit(z));
197 builder.assert_eq(
198 computed_a_prime_prime_prime_0_0_limb,
199 a_prime_prime_prime_0_0_limb,
200 );
201 }
202
203 for x in 0..5 {
205 for y in 0..5 {
206 for limb in 0..U64_LIMBS {
207 let output = local.a_prime_prime_prime(y, x, limb);
208 let input = next.a[y][x][limb];
209 builder
210 .when_transition()
211 .when(not_final_step.clone())
212 .assert_eq(output, input);
213 }
214 }
215 }
216 }
217}