p3_keccak_air/
air.rs

1use core::array;
2use core::borrow::Borrow;
3
4use p3_air::{Air, AirBuilder, BaseAir};
5use p3_field::{PrimeCharacteristicRing, PrimeField64};
6use p3_matrix::Matrix;
7use p3_matrix::dense::RowMajorMatrix;
8use rand::rngs::SmallRng;
9use rand::{Rng, SeedableRng};
10
11use crate::columns::{KeccakCols, NUM_KECCAK_COLS};
12use crate::constants::rc_value_bit;
13use crate::round_flags::eval_round_flags;
14use crate::{BITS_PER_LIMB, NUM_ROUNDS, NUM_ROUNDS_MIN_1, U64_LIMBS, generate_trace_rows};
15
/// AIR (algebraic intermediate representation) for the Keccak-f[1600] permutation.
///
/// The struct carries no configuration: the trace layout comes entirely from
/// [`KeccakCols`] and the crate-level constants.
///
/// # Field requirements
///
/// Each 64-bit lane is stored as `U64_LIMBS` limbs of `BITS_PER_LIMB` bits, and the
/// constraints recompose limbs from their bits as field elements. The field order must
/// therefore exceed `2^BITS_PER_LIMB`, so that every limb value has a unique field
/// representation. (The previous note said "at least 16 bits"; `BITS_PER_LIMB` is
/// presumably 16. Confirm against the crate root.)
#[derive(Debug, Clone, Copy, Default)]
pub struct KeccakAir {}
19
20impl KeccakAir {
21    pub fn generate_trace_rows<F: PrimeField64>(
22        &self,
23        num_hashes: usize,
24        extra_capacity_bits: usize,
25    ) -> RowMajorMatrix<F> {
26        let mut rng = SmallRng::seed_from_u64(1);
27        let inputs = (0..num_hashes).map(|_| rng.random()).collect();
28        generate_trace_rows(inputs, extra_capacity_bits)
29    }
30}
31
32impl<F> BaseAir<F> for KeccakAir {
33    fn width(&self) -> usize {
34        NUM_KECCAK_COLS
35    }
36}
37
/// Constraint evaluation for [`KeccakAir`].
///
/// Every constraint below involves only the current row (`local`) and, for the
/// transition constraints, the row after it (`next`).
impl<AB: AirBuilder> Air<AB> for KeccakAir {
    /// Emits all constraints for one row of the Keccak trace.
    ///
    /// The row's round is identified by `step_flags`: `step_flags[0]` marks the first
    /// step and `step_flags[NUM_ROUNDS_MIN_1]` the final one. The flags are presumably
    /// one-hot, enforced by `eval_round_flags`, which is not visible here.
    ///
    /// The constraints are emitted in this order:
    /// 1. round-flag constraints, delegated to `eval_round_flags`;
    /// 2. on the first step, the state `A` equals `preimage`;
    /// 3. unless this is the final step, `preimage` is copied unchanged to the next row;
    /// 4. `export` is boolean and may only be set on the final step;
    /// 5. theta: `C` is boolean, and `C'[x, z]` is the xor of `C[x, z]`, `C[x - 1, z]`
    ///    and `C[x + 1, z - 1]`;
    /// 6. `A'` is boolean, and the limbs of `A` equal the bits of `A' xor C xor C'`;
    /// 7. for each column `x`, the xor of `A'[x, 0..5, z]` equals `C'[x, z]`;
    /// 8. chi: the limbs of `A''` equal `B xor andn(B[x + 1], B[x + 2])`;
    /// 9. iota: the bits of `A''[0, 0]` are boolean and recompose to its limbs, and
    ///    those bits xored with the round constant recompose to `A'''[0, 0]`;
    /// 10. unless this is the final step, this round's output `A'''` equals the next
    ///     row's input `A`.
    ///
    /// NOTE(review): constraint emission order is likely significant to the proving
    /// system (e.g. random-linear-combination folding). Confirm before reordering.
    ///
    /// # Panics
    ///
    /// Panics if `builder.main()` cannot supply row 0 (local) or row 1 (next).
    #[inline]
    fn eval(&self, builder: &mut AB) {
        eval_round_flags(builder);

        // Row 0 of the window is the current row, row 1 is the next row.
        let main = builder.main();
        let (local, next) = (
            main.row_slice(0).expect("The matrix is empty?"),
            main.row_slice(1).expect("The matrix only has 1 row?"),
        );
        let local: &KeccakCols<AB::Var> = (*local).borrow();
        let next: &KeccakCols<AB::Var> = (*next).borrow();

        let first_step = local.step_flags[0].clone();
        let final_step = local.step_flags[NUM_ROUNDS_MIN_1].clone();
        let not_final_step = AB::Expr::ONE - final_step;

        // If this is the first step, the input A must match the preimage.
        for y in 0..5 {
            for x in 0..5 {
                builder
                    .when(first_step.clone())
                    .assert_zeros::<U64_LIMBS, _>(array::from_fn(|limb| {
                        local.preimage[y][x][limb].clone() - local.a[y][x][limb].clone()
                    }));
            }
        }

        // If this is not the final step, the local and next preimages must match.
        // `when_transition` skips this on the last row, which has no successor.
        for y in 0..5 {
            for x in 0..5 {
                builder
                    .when(not_final_step.clone())
                    .when_transition()
                    .assert_zeros::<U64_LIMBS, _>(array::from_fn(|limb| {
                        local.preimage[y][x][limb].clone() - next.preimage[y][x][limb].clone()
                    }));
            }
        }

        // The export flag must be 0 or 1.
        builder.assert_bool(local.export.clone());

        // If this is not the final step, the export flag must be off.
        builder
            .when(not_final_step.clone())
            .assert_zero(local.export.clone());

        // C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]).
        // If all entries of C are boolean, the arithmetic generalization of
        // xor3 only outputs 0 or 1, so this check also ensures that all
        // entries of C'[x, z] are boolean.
        // Indices wrap mod 5 (x) and mod 64 (z). `x + 4` and `z + 63` express
        // `x - 1` and `z - 1` without usize underflow.
        for x in 0..5 {
            builder.assert_bools(local.c[x].clone());
            builder.assert_zeros::<64, _>(array::from_fn(|z| {
                let xor = local.c[x][z].clone().into().xor3(
                    &local.c[(x + 4) % 5][z].clone().into(),
                    &local.c[(x + 1) % 5][(z + 63) % 64].clone().into(),
                );
                local.c_prime[x][z].clone() - xor
            }));
        }

        // Check that the input limbs are consistent with A' and D.
        // A[x, y, z] = xor(A'[x, y, z], D[x, y, z])
        //            = xor(A'[x, y, z], C[x - 1, z], C[x + 1, z - 1])
        //            = xor(A'[x, y, z], C[x, z], C'[x, z]).
        // The last step is valid based on the identity we checked above.
        // It isn't required, but makes this check a bit cleaner.
        // We also check that all entries of A' are bools.
        // This has the side effect of also range checking the limbs of A.
        for y in 0..5 {
            for x in 0..5 {
                // Bit z of A[y][x], as implied by A' xor C xor C'.
                let get_bit = |z: usize| {
                    local.a_prime[y][x][z].clone().into().xor3(
                        &local.c[x][z].clone().into(),
                        &local.c_prime[x][z].clone().into(),
                    )
                };

                // Check that all entries of A'[y][x] are boolean.
                builder.assert_bools(local.a_prime[y][x].clone());

                builder.assert_zeros::<U64_LIMBS, _>(array::from_fn(|limb| {
                    let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB)
                        .rev()
                        .fold(AB::Expr::ZERO, |acc, z| {
                            // Horner-style recomposition over z in descending order: the
                            // lowest z of this limb ends up as the least significant bit.
                            acc.double() + get_bit(z)
                        });
                    computed_limb - local.a[y][x][limb].clone()
                }));
            }
        }

        // xor_{i=0}^4 A'[x, i, z] = C'[x, z], so for each x, z,
        // diff * (diff - 2) * (diff - 4) = 0, where
        // diff = sum_{i=0}^4 A'[x, i, z] - C'[x, z]
        // A' and C' are boolean (checked above), so diff lies in [-1, 5]. The xor
        // matches exactly when diff is even, and the even values in that range are
        // 0, 2 and 4.
        for x in 0..5 {
            let four = AB::Expr::TWO.double();
            builder.assert_zeros::<64, _>(array::from_fn(|z| {
                let sum: AB::Expr = (0..5).map(|y| local.a_prime[y][x][z].clone().into()).sum();
                let diff = sum - local.c_prime[x][z].clone();
                diff.clone() * (diff.clone() - AB::Expr::TWO) * (diff - four.clone())
            }));
        }

        // A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])).
        // As B is a rotation of A', all entries must be bools and so
        // this check also range checks A''.
        // `local.b(x, y, z)` is an accessor on KeccakCols (not visible here) that
        // presumably reads the appropriately rotated/permuted bit of A'. Confirm in
        // columns.rs.
        for y in 0..5 {
            for x in 0..5 {
                let get_bit = |z| {
                    let andn = local
                        .b((x + 1) % 5, y, z)
                        .into()
                        .andn(&local.b((x + 2) % 5, y, z).into());
                    andn.xor(&local.b(x, y, z).into())
                };
                builder.assert_zeros::<U64_LIMBS, _>(array::from_fn(|limb| {
                    let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB)
                        .rev()
                        .fold(AB::Expr::ZERO, |acc, z| acc.double() + get_bit(z));
                    computed_limb - local.a_prime_prime[y][x][limb].clone()
                }));
            }
        }

        // A'''[0, 0] = A''[0, 0] XOR RC
        // Check to ensure the bits of A''[0, 0] are boolean.
        builder.assert_bools(local.a_prime_prime_0_0_bits.clone());
        // Tie those bits to the limbs of A''[0, 0] (same low-bit-first recomposition).
        builder.assert_zeros::<U64_LIMBS, _>(array::from_fn(|limb| {
            let computed_a_prime_prime_0_0_limb = (limb * BITS_PER_LIMB
                ..(limb + 1) * BITS_PER_LIMB)
                .rev()
                .fold(AB::Expr::ZERO, |acc, z| {
                    acc.double() + local.a_prime_prime_0_0_bits[z].clone()
                });
            computed_a_prime_prime_0_0_limb - local.a_prime_prime[0][0][limb].clone()
        }));

        // Bit i of A''[0, 0] xor RC. The round-constant bit is selected as
        // sum_r step_flags[r] * RC_r[i], which equals the current round's bit when
        // exactly one step flag is set.
        let get_xored_bit = |i| {
            let mut rc_bit_i = AB::Expr::ZERO;
            for r in 0..NUM_ROUNDS {
                let this_round = local.step_flags[r].clone();
                let this_round_constant = AB::Expr::from_bool(rc_value_bit(r, i) != 0);
                rc_bit_i += this_round * this_round_constant;
            }

            rc_bit_i.xor(&local.a_prime_prime_0_0_bits[i].clone().into())
        };

        builder.assert_zeros::<U64_LIMBS, _>(array::from_fn(|limb| {
            let computed_a_prime_prime_prime_0_0_limb = (limb * BITS_PER_LIMB
                ..(limb + 1) * BITS_PER_LIMB)
                .rev()
                .fold(AB::Expr::ZERO, |acc, z| acc.double() + get_xored_bit(z));
            computed_a_prime_prime_prime_0_0_limb
                - local.a_prime_prime_prime_0_0_limbs[limb].clone()
        }));

        // Enforce that this round's output equals the next round's input.
        // Skipped on the final step (the next row starts a new hash) and on the last row.
        for x in 0..5 {
            for y in 0..5 {
                builder
                    .when_transition()
                    .when(not_final_step.clone())
                    .assert_zeros::<U64_LIMBS, _>(array::from_fn(|limb| {
                        local.a_prime_prime_prime(y, x, limb) - next.a[y][x][limb].clone()
                    }));
            }
        }
    }
}