p3_keccak_air/
air.rs

1use alloc::vec::Vec;
2use core::borrow::Borrow;
3
4use p3_air::utils::{andn, xor, xor3};
5use p3_air::{Air, AirBuilder, BaseAir};
6use p3_field::{FieldAlgebra, PrimeField64};
7use p3_matrix::dense::RowMajorMatrix;
8use p3_matrix::Matrix;
9use rand::random;
10
11use crate::columns::{KeccakCols, NUM_KECCAK_COLS};
12use crate::constants::rc_value_bit;
13use crate::round_flags::eval_round_flags;
14use crate::{generate_trace_rows, BITS_PER_LIMB, NUM_ROUNDS, U64_LIMBS};
15
16/// Assumes the field size is at least 16 bits.
17#[derive(Debug)]
18pub struct KeccakAir {}
19
20impl KeccakAir {
21    pub fn generate_trace_rows<F: PrimeField64>(
22        &self,
23        num_hashes: usize,
24        extra_capacity_bits: usize,
25    ) -> RowMajorMatrix<F> {
26        let inputs = (0..num_hashes).map(|_| random()).collect::<Vec<_>>();
27        generate_trace_rows(inputs, extra_capacity_bits)
28    }
29}
30
31impl<F> BaseAir<F> for KeccakAir {
32    fn width(&self) -> usize {
33        NUM_KECCAK_COLS
34    }
35}
36
37impl<AB: AirBuilder> Air<AB> for KeccakAir {
38    #[inline]
39    fn eval(&self, builder: &mut AB) {
40        eval_round_flags(builder);
41
42        let main = builder.main();
43        let (local, next) = (main.row_slice(0), main.row_slice(1));
44        let local: &KeccakCols<AB::Var> = (*local).borrow();
45        let next: &KeccakCols<AB::Var> = (*next).borrow();
46
47        let first_step = local.step_flags[0];
48        let final_step = local.step_flags[NUM_ROUNDS - 1];
49        let not_final_step = AB::Expr::ONE - final_step;
50
51        // If this is the first step, the input A must match the preimage.
52        for y in 0..5 {
53            for x in 0..5 {
54                for limb in 0..U64_LIMBS {
55                    builder
56                        .when(first_step)
57                        .assert_eq(local.preimage[y][x][limb], local.a[y][x][limb]);
58                }
59            }
60        }
61
62        // If this is not the final step, the local and next preimages must match.
63        for y in 0..5 {
64            for x in 0..5 {
65                for limb in 0..U64_LIMBS {
66                    builder
67                        .when(not_final_step.clone())
68                        .when_transition()
69                        .assert_eq(local.preimage[y][x][limb], next.preimage[y][x][limb]);
70                }
71            }
72        }
73
74        // The export flag must be 0 or 1.
75        builder.assert_bool(local.export);
76
77        // If this is not the final step, the export flag must be off.
78        builder
79            .when(not_final_step.clone())
80            .assert_zero(local.export);
81
82        // C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]).
83        // Note that if all entries of C are boolean, the arithmetic generalization
84        // xor3 function only outputs 0, 1 and so this check also ensures that all
85        // entries of C'[x, z] are boolean.
86        for x in 0..5 {
87            for z in 0..64 {
88                // Check to ensure all entries of C are bools.
89                builder.assert_bool(local.c[x][z]);
90                let xor = xor3::<AB::Expr>(
91                    local.c[x][z].into(),
92                    local.c[(x + 4) % 5][z].into(),
93                    local.c[(x + 1) % 5][(z + 63) % 64].into(),
94                );
95                let c_prime = local.c_prime[x][z];
96                builder.assert_eq(c_prime, xor);
97            }
98        }
99
100        // Check that the input limbs are consistent with A' and D.
101        // A[x, y, z] = xor(A'[x, y, z], D[x, y, z])
102        //            = xor(A'[x, y, z], C[x - 1, z], C[x + 1, z - 1])
103        //            = xor(A'[x, y, z], C[x, z], C'[x, z]).
104        // The last step is valid based on the identity we checked above.
105        // It isn't required, but makes this check a bit cleaner.
106        // We also check that all entires of A' are bools.
107        // This has the side effect of also range checking the limbs of A.
108        for y in 0..5 {
109            for x in 0..5 {
110                let get_bit = |z| {
111                    let a_prime: AB::Var = local.a_prime[y][x][z];
112                    let c: AB::Var = local.c[x][z];
113                    let c_prime: AB::Var = local.c_prime[x][z];
114                    xor3::<AB::Expr>(a_prime.into(), c.into(), c_prime.into())
115                };
116
117                for limb in 0..U64_LIMBS {
118                    let a_limb = local.a[y][x][limb];
119                    let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB)
120                        .rev()
121                        .fold(AB::Expr::ZERO, |acc, z| {
122                            // Check to ensure all entries of A' are bools.
123                            builder.assert_bool(local.a_prime[y][x][z]);
124                            acc.double() + get_bit(z)
125                        });
126                    builder.assert_eq(computed_limb, a_limb);
127                }
128            }
129        }
130
131        // xor_{i=0}^4 A'[x, i, z] = C'[x, z], so for each x, z,
132        // diff * (diff - 2) * (diff - 4) = 0, where
133        // diff = sum_{i=0}^4 A'[x, i, z] - C'[x, z]
134        for x in 0..5 {
135            for z in 0..64 {
136                let sum: AB::Expr = (0..5).map(|y| local.a_prime[y][x][z].into()).sum();
137                let diff = sum - local.c_prime[x][z];
138                // This should be slightly faster than from_canonical_u8(4) for some fields.
139                let four = AB::Expr::TWO.double();
140                builder.assert_zero(diff.clone() * (diff.clone() - AB::Expr::TWO) * (diff - four));
141            }
142        }
143
144        // A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])).
145        // As B is a rotation of A', all entries must be bools and so
146        // this check also range checks A''.
147        for y in 0..5 {
148            for x in 0..5 {
149                let get_bit = |z| {
150                    let andn = andn::<AB::Expr>(
151                        local.b((x + 1) % 5, y, z).into(),
152                        local.b((x + 2) % 5, y, z).into(),
153                    );
154                    xor::<AB::Expr>(local.b(x, y, z).into(), andn)
155                };
156
157                for limb in 0..U64_LIMBS {
158                    let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB)
159                        .rev()
160                        .fold(AB::Expr::ZERO, |acc, z| acc.double() + get_bit(z));
161                    builder.assert_eq(computed_limb, local.a_prime_prime[y][x][limb]);
162                }
163            }
164        }
165
166        // A'''[0, 0] = A''[0, 0] XOR RC
167        for limb in 0..U64_LIMBS {
168            let computed_a_prime_prime_0_0_limb = (limb * BITS_PER_LIMB
169                ..(limb + 1) * BITS_PER_LIMB)
170                .rev()
171                .fold(AB::Expr::ZERO, |acc, z| {
172                    // Check to ensure the bits of A''[0, 0] are boolean.
173                    builder.assert_bool(local.a_prime_prime_0_0_bits[z]);
174                    acc.double() + local.a_prime_prime_0_0_bits[z]
175                });
176            let a_prime_prime_0_0_limb = local.a_prime_prime[0][0][limb];
177            builder.assert_eq(computed_a_prime_prime_0_0_limb, a_prime_prime_0_0_limb);
178        }
179
180        let get_xored_bit = |i| {
181            let mut rc_bit_i = AB::Expr::ZERO;
182            for r in 0..NUM_ROUNDS {
183                let this_round = local.step_flags[r];
184                let this_round_constant = AB::Expr::from_bool(rc_value_bit(r, i) != 0);
185                rc_bit_i += this_round * this_round_constant;
186            }
187
188            xor::<AB::Expr>(local.a_prime_prime_0_0_bits[i].into(), rc_bit_i)
189        };
190
191        for limb in 0..U64_LIMBS {
192            let a_prime_prime_prime_0_0_limb = local.a_prime_prime_prime_0_0_limbs[limb];
193            let computed_a_prime_prime_prime_0_0_limb = (limb * BITS_PER_LIMB
194                ..(limb + 1) * BITS_PER_LIMB)
195                .rev()
196                .fold(AB::Expr::ZERO, |acc, z| acc.double() + get_xored_bit(z));
197            builder.assert_eq(
198                computed_a_prime_prime_prime_0_0_limb,
199                a_prime_prime_prime_0_0_limb,
200            );
201        }
202
203        // Enforce that this round's output equals the next round's input.
204        for x in 0..5 {
205            for y in 0..5 {
206                for limb in 0..U64_LIMBS {
207                    let output = local.a_prime_prime_prime(y, x, limb);
208                    let input = next.a[y][x][limb];
209                    builder
210                        .when_transition()
211                        .when(not_final_step.clone())
212                        .assert_eq(output, input);
213                }
214            }
215        }
216    }
217}