1use core::array;
2use core::borrow::Borrow;
3
4use p3_air::{Air, AirBuilder, BaseAir};
5use p3_field::{PrimeCharacteristicRing, PrimeField64};
6use p3_matrix::Matrix;
7use p3_matrix::dense::RowMajorMatrix;
8use rand::rngs::SmallRng;
9use rand::{Rng, SeedableRng};
10
11use crate::columns::{KeccakCols, NUM_KECCAK_COLS};
12use crate::constants::rc_value_bit;
13use crate::round_flags::eval_round_flags;
14use crate::{BITS_PER_LIMB, NUM_ROUNDS, NUM_ROUNDS_MIN_1, U64_LIMBS, generate_trace_rows};
15
/// AIR (algebraic intermediate representation) for the Keccak permutation.
///
/// The type carries no configuration of its own: its trace width is reported by the
/// `BaseAir` impl and its constraints are emitted by the `Air` impl's `eval`.
///
/// NOTE(review): the constraints pack `BITS_PER_LIMB` bits into each field-element limb, so
/// the field is presumably required to be large enough that such limbs never wrap — confirm
/// against the crate constants before using a small field.
#[derive(Debug, Clone, Copy, Default)]
pub struct KeccakAir {}
19
20impl KeccakAir {
21 pub fn generate_trace_rows<F: PrimeField64>(
22 &self,
23 num_hashes: usize,
24 extra_capacity_bits: usize,
25 ) -> RowMajorMatrix<F> {
26 let mut rng = SmallRng::seed_from_u64(1);
27 let inputs = (0..num_hashes).map(|_| rng.random()).collect();
28 generate_trace_rows(inputs, extra_capacity_bits)
29 }
30}
31
32impl<F> BaseAir<F> for KeccakAir {
33 fn width(&self) -> usize {
34 NUM_KECCAK_COLS
35 }
36}
37
/// Constraint system for [`KeccakAir`].
///
/// NOTE(review): each trace row appears to encode one round of the permutation, selected by
/// `step_flags` (indexed by round `0..NUM_ROUNDS`). The flag constraints themselves are
/// emitted by `eval_round_flags`; confirm there that exactly one flag is set per row.
impl<AB: AirBuilder> Air<AB> for KeccakAir {
    /// Emits every constraint over the current row (`local`) and, for transition
    /// constraints, the following row (`next`).
    ///
    /// The order of the `builder` calls below defines the order of the constraints, so it
    /// should not be changed casually.
    ///
    /// # Panics
    ///
    /// Panics (via `expect`) if the builder's main trace window does not provide rows 0 and 1.
    #[inline]
    fn eval(&self, builder: &mut AB) {
        // Step/round flag constraints, defined in `round_flags`.
        eval_round_flags(builder);

        let main = builder.main();
        let (local, next) = (
            main.row_slice(0).expect("The matrix is empty?"),
            main.row_slice(1).expect("The matrix only has 1 row?"),
        );
        // View the raw rows through the structured `KeccakCols` column layout.
        let local: &KeccakCols<AB::Var> = (*local).borrow();
        let next: &KeccakCols<AB::Var> = (*next).borrow();

        let first_step = local.step_flags[0].clone();
        let final_step = local.step_flags[NUM_ROUNDS_MIN_1].clone();
        let not_final_step = AB::Expr::ONE - final_step;

        // On the first step, the round input A must equal the preimage, limb by limb.
        for y in 0..5 {
            for x in 0..5 {
                builder
                    .when(first_step.clone())
                    .assert_zeros::<U64_LIMBS, _>(array::from_fn(|limb| {
                        local.preimage[y][x][limb].clone() - local.a[y][x][limb].clone()
                    }));
            }
        }

        // Unless this is the final step, the preimage is copied unchanged into the next row.
        // `when_transition()` keeps this from wrapping around from the last row.
        for y in 0..5 {
            for x in 0..5 {
                builder
                    .when(not_final_step.clone())
                    .when_transition()
                    .assert_zeros::<U64_LIMBS, _>(array::from_fn(|limb| {
                        local.preimage[y][x][limb].clone() - next.preimage[y][x][limb].clone()
                    }));
            }
        }

        // The export flag is boolean...
        builder.assert_bool(local.export.clone());

        // ...and may only be set on the final step.
        builder
            .when(not_final_step.clone())
            .assert_zero(local.export.clone());

        // C'[x][z] = xor(C[x][z], C[x - 1][z], C[x + 1][z - 1]), with x taken mod 5 and z mod 64
        // ((x + 4) % 5 is x - 1 and (z + 63) % 64 is z - 1).
        // C is range-checked to bits first, because the XOR polynomial used by `xor3` only
        // computes XOR on boolean inputs. Given boolean C, C' is an XOR of bits and so is
        // itself boolean; no separate booleanity check on C' is emitted.
        for x in 0..5 {
            builder.assert_bools(local.c[x].clone());
            builder.assert_zeros::<64, _>(array::from_fn(|z| {
                let xor = local.c[x][z].clone().into().xor3(
                    &local.c[(x + 4) % 5][z].clone().into(),
                    &local.c[(x + 1) % 5][(z + 63) % 64].clone().into(),
                );
                local.c_prime[x][z].clone() - xor
            }));
        }

        // Theta: A[y][x][z] = xor(A'[y][x][z], C[x - 1][z], C[x + 1][z - 1]).
        // By the C' identity above this equals xor(A'[y][x][z], C[x][z], C'[x][z]), which is
        // the form checked here. Each limb of A must equal the packing of its BITS_PER_LIMB
        // bits. The fold walks the bit range in reverse and doubles the accumulator, so bit
        // `limb * BITS_PER_LIMB` is the least significant (little-endian within a limb).
        // A' is also constrained to bits, which in turn range-checks the limbs of A.
        for y in 0..5 {
            for x in 0..5 {
                let get_bit = |z: usize| {
                    local.a_prime[y][x][z].clone().into().xor3(
                        &local.c[x][z].clone().into(),
                        &local.c_prime[x][z].clone().into(),
                    )
                };

                builder.assert_bools(local.a_prime[y][x].clone());

                builder.assert_zeros::<U64_LIMBS, _>(array::from_fn(|limb| {
                    let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB)
                        .rev()
                        .fold(AB::Expr::ZERO, |acc, z| {
                            acc.double() + get_bit(z)
                        });
                    computed_limb - local.a[y][x][limb].clone()
                }));
            }
        }

        // Column parity: diff = sum_y A'[y][x][z] - C'[x][z] must satisfy
        // diff * (diff - 2) * (diff - 4) = 0, i.e. diff is 0, 2 or 4. With A' and C' boolean,
        // this forces C'[x][z] to be the XOR of the five bits A'[0..5][x][z].
        for x in 0..5 {
            let four = AB::Expr::TWO.double();
            builder.assert_zeros::<64, _>(array::from_fn(|z| {
                let sum: AB::Expr = (0..5).map(|y| local.a_prime[y][x][z].clone().into()).sum();
                let diff = sum - local.c_prime[x][z].clone();
                diff.clone() * (diff.clone() - AB::Expr::TWO) * (diff - four.clone())
            }));
        }

        // Chi: A''[y][x] bit z = xor(B(x, y, z), andn(B(x + 1, y, z), B(x + 2, y, z))), with
        // x taken mod 5. On bits, `andn(a, b)` is (NOT a) AND b. The limbs of A'' are packed
        // little-endian, as in the theta check above.
        // NOTE(review): `KeccakCols::b` is defined elsewhere; it is presumably the rho/pi
        // rotation of the boolean A' (hence itself boolean), which would make this check also
        // range-check A''. Confirm.
        for y in 0..5 {
            for x in 0..5 {
                let get_bit = |z| {
                    let andn = local
                        .b((x + 1) % 5, y, z)
                        .into()
                        .andn(&local.b((x + 2) % 5, y, z).into());
                    andn.xor(&local.b(x, y, z).into())
                };
                builder.assert_zeros::<U64_LIMBS, _>(array::from_fn(|limb| {
                    let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB)
                        .rev()
                        .fold(AB::Expr::ZERO, |acc, z| acc.double() + get_bit(z));
                    computed_limb - local.a_prime_prime[y][x][limb].clone()
                }));
            }
        }

        // Iota, part 1: the explicit bits of A''[0][0] are boolean and pack (little-endian)
        // into the limbs of A''[0][0].
        builder.assert_bools(local.a_prime_prime_0_0_bits.clone());
        builder.assert_zeros::<U64_LIMBS, _>(array::from_fn(|limb| {
            let computed_a_prime_prime_0_0_limb = (limb * BITS_PER_LIMB
                ..(limb + 1) * BITS_PER_LIMB)
                .rev()
                .fold(AB::Expr::ZERO, |acc, z| {
                    acc.double() + local.a_prime_prime_0_0_bits[z].clone()
                });
            computed_a_prime_prime_0_0_limb - local.a_prime_prime[0][0][limb].clone()
        }));

        // Iota, part 2: bit i of the round constant is selected by the step flags as
        // sum_r step_flags[r] * rc_value_bit(r, i). This equals the active round's constant
        // bit when exactly one flag is set (assumed to be enforced by `eval_round_flags`;
        // TODO confirm). The result is then XORed with bit i of A''[0][0].
        let get_xored_bit = |i| {
            let mut rc_bit_i = AB::Expr::ZERO;
            for r in 0..NUM_ROUNDS {
                let this_round = local.step_flags[r].clone();
                let this_round_constant = AB::Expr::from_bool(rc_value_bit(r, i) != 0);
                rc_bit_i += this_round * this_round_constant;
            }

            rc_bit_i.xor(&local.a_prime_prime_0_0_bits[i].clone().into())
        };

        // The limbs of A'''[0][0] must equal the little-endian packing of those XORed bits.
        builder.assert_zeros::<U64_LIMBS, _>(array::from_fn(|limb| {
            let computed_a_prime_prime_prime_0_0_limb = (limb * BITS_PER_LIMB
                ..(limb + 1) * BITS_PER_LIMB)
                .rev()
                .fold(AB::Expr::ZERO, |acc, z| acc.double() + get_xored_bit(z));
            computed_a_prime_prime_prime_0_0_limb
                - local.a_prime_prime_prime_0_0_limbs[limb].clone()
        }));

        // Unless this is the final step (and never from the last row), this round's output
        // A''' must equal the next row's input A.
        // NOTE(review): `a_prime_prime_prime(y, x, limb)` is an accessor defined on `KeccakCols`,
        // presumably returning `a_prime_prime_prime_0_0_limbs` at (0, 0) and `a_prime_prime`
        // elsewhere. Confirm.
        for x in 0..5 {
            for y in 0..5 {
                builder
                    .when_transition()
                    .when(not_final_step.clone())
                    .assert_zeros::<U64_LIMBS, _>(array::from_fn(|limb| {
                        local.a_prime_prime_prime(y, x, limb) - next.a[y][x][limb].clone()
                    }));
            }
        }
    }
}