p3_air/
utils.rs

1//! A collection of utility functions helpful in defining AIRs.
2
3use core::array;
4
5use p3_field::integers::QuotientMap;
6use p3_field::{Field, PrimeCharacteristicRing};
7
8use crate::AirBuilder;
9
10/// Pack a collection of bits into a number.
11///
12/// Given `vec = [v_0, v_1, ..., v_n]` returns `v_0 + 2v_1 + ... + 2^n v_n`
13#[inline]
14pub fn pack_bits_le<R, Var, I>(iter: I) -> R
15where
16    R: PrimeCharacteristicRing,
17    Var: Into<R> + Clone,
18    I: DoubleEndedIterator<Item = Var>,
19{
20    iter.rev()
21        .map(Into::into)
22        .reduce(|acc, elem| acc.double() + elem)
23        .unwrap_or(R::ZERO)
24}
25
26/// Compute `xor` on a list of boolean field elements.
27///
28/// Verifies at debug time that all inputs are boolean.
29#[inline(always)]
30pub fn checked_xor<F: Field, const N: usize>(xs: &[F]) -> F {
31    xs.iter().fold(F::ZERO, |acc, x| {
32        debug_assert!(x.is_zero() || x.is_one());
33        acc.xor(x)
34    })
35}
36
37/// Compute `andnot` on a pair of boolean field elements.
38///
39/// Verifies at debug time that both inputs are boolean.
40#[inline(always)]
41pub fn checked_andn<F: Field>(x: F, y: F) -> F {
42    debug_assert!(x.is_zero() || x.is_one());
43    debug_assert!(y.is_zero() || y.is_one());
44    x.andn(&y)
45}
46
47/// Convert a 32-bit integer into an array of 32 0 or 1 field elements.
48///
49/// The output array is in little-endian order.
50#[inline]
51pub fn u32_to_bits_le<R: PrimeCharacteristicRing>(val: u32) -> [R; 32] {
52    array::from_fn(|i| R::from_bool(val & (1 << i) != 0))
53}
54
55/// Convert a 64-bit integer into an array of 64 0 or 1 field elements.
56///
57/// The output array is in little-endian order.
58#[inline]
59pub fn u64_to_bits_le<R: PrimeCharacteristicRing>(val: u64) -> [R; 64] {
60    array::from_fn(|i| R::from_bool(val & (1 << i) != 0))
61}
62
63/// Convert a 64-bit integer into an array of four field elements representing the 16 bit limb decomposition.
64///
65/// The output array is in little-endian order.
66#[inline]
67pub fn u64_to_16_bit_limbs<R: PrimeCharacteristicRing>(val: u64) -> [R; 4] {
68    array::from_fn(|i| R::from_u16((val >> (16 * i)) as u16))
69}
70
71/// Verify that `a = b + c + d mod 2^32`
72///
73/// We assume that a, b, c, d are all given as `2, 16` bit limbs (e.g. `a = a[0] + 2^16 a[1]`) and
74/// each `16` bit limb has been range checked to ensure it contains a value in `[0, 2^16)`.
75///
76/// This function assumes we are working over a field with characteristic `P > 3*2^16`.
77///
78/// # Panics
79///
80/// The function will panic if the characteristic of the field is less than or equal to 2^16.
81#[inline]
82pub fn add3<AB: AirBuilder>(
83    builder: &mut AB,
84    a: &[AB::Var; 2],
85    b: &[AB::Var; 2],
86    c: &[AB::Expr; 2],
87    d: &[AB::Expr; 2],
88) {
89    // Define:
90    //  acc    = a - b - c - d (mod P)
91    //  acc_16 = a[0] - b[0] - c[0] - d[0] (mod P)
92    //
93    // We perform 2 checks:
94    //
95    // (1) acc*(acc + 2^32)*(acc + 2*2^32) = 0.
96    // (2) acc_16*(acc_16 + 2^16)*(acc_16 + 2*2^16) = 0.
97    //
98    // We give a short proof for why this lets us conclude that a = b + c + d mod 2^32:
99    //
100    // As all 16 bit limbs have been range checked, we know that a, b, c, d lie in [0, 2^32) and hence
101    // a = b + c + d mod 2^32 if and only if, over the integers, a - b - c - d = 0, -2^32 or -2*2^32.
102    //
103    // Equation (1) verifies that a - b - c - d mod P = 0, -2^32 or -2*2^32.
104    //
105    // Field overflow cannot occur when computing acc_16 as our characteristic is larger than 3*2^16.
106    // Hence, equation (2) verifies that, over the integers, a[0] - b[0] - c[0] - d[0] = 0, -2^16 or -2*2^16.
107    // Either way we can immediately conclude that a - b - c - d = 0 mod 2^16.
108    //
109    // Now we can use the chinese remainder theorem to combine these results to conclude that
110    // a - b - c - d mod 2^16P = 0, -2^32 or -2*2^32.
111    //
112    // No overflow can occur mod 2^16 P as 2^16 P > 3*2^32 and a, b, c, d < 2^32. Hence we conclude that
113    // over the integers a - b - c - d = 0, -2^32 or -2*2^32 which implies a = b + c + d mod 2^32.
114
115    // By assumption P > 3*2^16 so 1 << 16 will be less than P. We use the checked version just to be safe.
116    // The compiler should optimize it away.
117    let two_16 =
118        <AB::Expr as PrimeCharacteristicRing>::PrimeSubfield::from_canonical_checked(1 << 16)
119            .unwrap();
120    let two_32 = two_16.square();
121
122    let acc_16 = a[0].clone() - b[0].clone() - c[0].clone() - d[0].clone();
123    let acc_32 = a[1].clone() - b[1].clone() - c[1].clone() - d[1].clone();
124    let acc = acc_16.clone() + acc_32.mul_2exp_u64(16);
125
126    builder.assert_zeros([
127        acc.clone()
128            * (acc.clone() + AB::Expr::from_prime_subfield(two_32))
129            * (acc + AB::Expr::from_prime_subfield(two_32.double())),
130        acc_16.clone()
131            * (acc_16.clone() + AB::Expr::from_prime_subfield(two_16))
132            * (acc_16 + AB::Expr::from_prime_subfield(two_16.double())),
133    ]);
134}
135
136/// Verify that `a = b + c mod 2^32`
137///
138/// We assume that a, b, c are all given as `2, 16` bit limbs (e.g. `a = a[0] + 2^16 a[1]`) and
139/// each `16` bit limb has been range checked to ensure it contains a value in `[0, 2^16)`.
140///
141/// This function assumes we are working over a field with characteristic `P > 2^17`.
142///
143/// # Panics
144///
145/// The function will panic if the characteristic of the field is less than or equal to 2^16.
146#[inline]
147pub fn add2<AB: AirBuilder>(
148    builder: &mut AB,
149    a: &[AB::Var; 2],
150    b: &[AB::Var; 2],
151    c: &[AB::Expr; 2],
152) {
153    // Define:
154    //  acc    = a - b - c (mod P)
155    //  acc_16 = a[0] - b[0] - c[0] (mod P)
156    //
157    // We perform 2 checks:
158    //
159    // (1) acc*(acc + 2^32) = 0.
160    // (2) acc_16*(acc_16 + 2^16) = 0.
161    //
162    // We give a short proof for why this lets us conclude that a = b + c mod 2^32:
163    //
164    // As all 16 bit limbs have been range checked, we know that a, b, c lie in [0, 2^32) and hence
165    // a = b + c mod 2^32 if and only if, over the integers, a - b - c = 0 or -2^32.
166    //
167    // Equation (1) verifies that either a - b - c = 0 mod P or a - b - c = -2^32 mod P.
168    //
169    // Field overflow cannot occur when computing acc_16 as our characteristic is larger than 2^17.
170    // Hence, equation (2) verifies that, over the integers, a[0] - b[0] - c[0] = 0 or -2^16.
171    // Either way we can immediately conclude that a - b - c = 0 mod 2^16.
172    //
173    // Now we can use the chinese remainder theorem to combine these results to conclude that
174    // either a - b - c = 0 mod 2^16 P or a - b - c = -2^32 mod 2^16 P.
175    //
176    // No overflow can occur mod 2^16 P as 2^16 P > 2^33 and a, b, c < 2^32. Hence we conclude that
177    // over the integers a - b - c = 0 or a - b - c = -2^32 which is equivalent to a = b + c mod 2^32.
178
179    // By assumption P > 2^17 so 1 << 16 will be less than P. We use the checked version just to be safe.
180    // The compiler should optimize it away.
181    let two_16 =
182        <AB::Expr as PrimeCharacteristicRing>::PrimeSubfield::from_canonical_checked(1 << 16)
183            .unwrap();
184    let two_32 = two_16.square();
185
186    let acc_16 = a[0].clone() - b[0].clone() - c[0].clone();
187    let acc_32 = a[1].clone() - b[1].clone() - c[1].clone();
188    let acc = acc_16.clone() + acc_32.mul_2exp_u64(16);
189
190    builder.assert_zeros([
191        acc.clone() * (acc + AB::Expr::from_prime_subfield(two_32)),
192        acc_16.clone() * (acc_16 + AB::Expr::from_prime_subfield(two_16)),
193    ]);
194}
195
196/// Verify that `a = (b ^ (c << shift))`
197///
198/// We assume that a is given as `2 16` bit limbs and both b and c are unpacked into 32 individual bits.
199/// We assume that the bits of b have been range checked but not the inputs in c or a. Both of these are
200/// range checked as part of this function.
201#[inline]
202pub fn xor_32_shift<AB: AirBuilder>(
203    builder: &mut AB,
204    a: &[AB::Var; 2],
205    b: &[AB::Var; 32],
206    c: &[AB::Var; 32],
207    shift: usize,
208) {
209    // First we range check all elements of c.
210    builder.assert_bools(c.clone());
211
212    // Next we compute (b ^ (c << shift)) and pack the result into two 16-bit integers.
213    let xor_shift_c_0_16 = b[..16].iter().enumerate().map(|(i, elem)| {
214        (elem.clone())
215            .into()
216            .xor(&c[(32 + i - shift) % 32].clone().into())
217    });
218    let sum_0_16: AB::Expr = pack_bits_le(xor_shift_c_0_16);
219
220    let xor_shift_c_16_32 = b[16..].iter().enumerate().map(|(i, elem)| {
221        (elem.clone())
222            .into()
223            .xor(&c[(32 + (i + 16) - shift) % 32].clone().into())
224    });
225    let sum_16_32: AB::Expr = pack_bits_le(xor_shift_c_16_32);
226
227    // As both b and c have been range checked to be boolean, all the (b ^ (c << shift))
228    // are also boolean and so this final check additionally has the effect of range checking a[0], a[1].
229    builder.assert_zeros([a[0].clone() - sum_0_16, a[1].clone() - sum_16_32]);
230}
231
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_baby_bear::BabyBear;

    use super::*;

    type F = BabyBear;

    #[test]
    fn test_pack_bits_le_various_patterns() {
        // Pattern: [1, 0, 1] as little-endian => 1 + 2*0 + 4*1 = 5
        let bits = [F::ONE, F::ZERO, F::ONE];
        let packed = pack_bits_le::<F, _, _>(bits.iter().cloned());
        assert_eq!(packed, F::from_u8(5));

        // Pattern: [1, 1, 0, 1] => 1 + 2*1 + 4*0 + 8*1 = 1 + 2 + 8 = 11
        let bits = [F::ONE, F::ONE, F::ZERO, F::ONE];
        let packed = pack_bits_le::<F, _, _>(bits.iter().cloned());
        assert_eq!(packed, F::from_u8(11));

        // Pattern: all zeros
        let bits = [F::ZERO; 5];
        let packed = pack_bits_le::<F, _, _>(bits.iter().cloned());
        assert_eq!(packed, F::ZERO);

        // Pattern: single one at the highest place
        let bits = [F::ZERO, F::ZERO, F::ZERO, F::ZERO, F::ONE];
        let packed = pack_bits_le::<F, _, _>(bits.iter().cloned());
        assert_eq!(packed, F::from_u8(16));
    }

    #[test]
    fn test_pack_bits_le_empty_is_zero() {
        // An empty bit sequence packs to zero.
        let packed = pack_bits_le::<F, F, _>(core::iter::empty());
        assert_eq!(packed, F::ZERO);
    }

    #[test]
    fn test_pack_bits_le_inverts_u32_to_bits_le() {
        // Packing the little-endian bit decomposition must give back the original value.
        let val: u32 = 0x1234_5678;
        let packed = pack_bits_le::<F, _, _>(u32_to_bits_le::<F>(val).into_iter());
        assert_eq!(packed, F::from_u64(u64::from(val)));
    }

    #[test]
    fn test_checked_xor_multiple_cases() {
        // Input: [1, 0, 1] => XOR(1 ^ 0 ^ 1) = 0
        let bits = vec![F::ONE, F::ZERO, F::ONE];
        let result = checked_xor::<F, 3>(&bits);
        assert_eq!(result, F::ZERO);

        // [1, 1, 1] => XOR = 1 ^ 1 ^ 1 = 1
        let bits = vec![F::ONE, F::ONE, F::ONE];
        let result = checked_xor::<F, 3>(&bits);
        assert_eq!(result, F::ONE);

        // [0, 0, 0] => XOR = 0
        let bits = vec![F::ZERO, F::ZERO, F::ZERO];
        let result = checked_xor::<F, 3>(&bits);
        assert_eq!(result, F::ZERO);

        // [1, 0, 1, 0] => XOR = 1 ^ 0 ^ 1 ^ 0 = 0
        let bits = vec![F::ONE, F::ZERO, F::ONE, F::ZERO];
        let result = checked_xor::<F, 4>(&bits);
        assert_eq!(result, F::ZERO);
    }

    #[test]
    fn test_checked_andn() {
        // `checked_andn(x, y)` computes `!x & y`.

        // x = 1, y = 0 => !1 & 0 = 0
        let result = checked_andn(F::ONE, F::ZERO);
        assert_eq!(result, F::ZERO);

        // x = 0, y = 1 => !0 & 1 = 1
        let result = checked_andn(F::ZERO, F::ONE);
        assert_eq!(result, F::ONE);

        // x = 0, y = 0 => !0 & 0 = 0
        let result = checked_andn(F::ZERO, F::ZERO);
        assert_eq!(result, F::ZERO);

        // x = 1, y = 1 => !1 & 1 = 0
        let result = checked_andn(F::ONE, F::ONE);
        assert_eq!(result, F::ZERO);
    }

    #[test]
    fn test_u32_to_bits_le() {
        // Convert 0b1010 (decimal 10) => [0, 1, 0, 1, ...]
        let bits = u32_to_bits_le::<F>(10);
        assert_eq!(bits[0], F::ZERO); // LSB first
        assert_eq!(bits[1], F::ONE);
        assert_eq!(bits[2], F::ZERO);
        assert_eq!(bits[3], F::ONE);

        for &bit in &bits[4..] {
            assert_eq!(bit, F::ZERO);
        }

        // Check 0 => all zeros
        let bits = u32_to_bits_le::<F>(0);
        assert!(bits.iter().all(|b| *b == F::ZERO));

        // Check max => all ones
        let bits = u32_to_bits_le::<F>(u32::MAX);
        assert!(bits.iter().all(|b| *b == F::ONE));
    }

    #[test]
    fn test_u64_to_bits_le() {
        // Convert 0b11 (decimal 3) => [1, 1, 0, ...]
        let bits = u64_to_bits_le::<F>(3);
        assert_eq!(bits[0], F::ONE);
        assert_eq!(bits[1], F::ONE);
        assert_eq!(bits[2], F::ZERO);

        for &bit in &bits[3..] {
            assert_eq!(bit, F::ZERO);
        }

        // Check 0 => all zeros
        let bits = u64_to_bits_le::<F>(0);
        assert!(bits.iter().all(|b| *b == F::ZERO));

        // Check max => all ones
        let bits = u64_to_bits_le::<F>(u64::MAX);
        assert!(bits.iter().all(|b| *b == F::ONE));
    }

    #[test]
    fn test_u64_to_16_bit_limbs() {
        // Convert 0x123456789ABCDEF0
        let val: u64 = 0x123456789ABCDEF0;
        let limbs = u64_to_16_bit_limbs::<F>(val);

        // Expected limbs (little endian): [0xDEF0, 0x9ABC, 0x5678, 0x1234]
        assert_eq!(limbs[0], F::from_u16(0xDEF0));
        assert_eq!(limbs[1], F::from_u16(0x9ABC));
        assert_eq!(limbs[2], F::from_u16(0x5678));
        assert_eq!(limbs[3], F::from_u16(0x1234));

        assert_eq!(
            limbs[0]
                + limbs[1].mul_2exp_u64(16)
                + limbs[2].mul_2exp_u64(32)
                + limbs[3].mul_2exp_u64(48),
            F::from_u64(val)
        );

        // Check zero
        let limbs = u64_to_16_bit_limbs::<F>(0);
        assert!(limbs.iter().all(|l| *l == F::ZERO));

        // Check max
        let limbs = u64_to_16_bit_limbs::<F>(u64::MAX);
        for l in limbs {
            assert_eq!(l, F::from_u64(0xFFFF));
        }

        // Check small value
        let val: u64 = 0x1234;
        let limbs = u64_to_16_bit_limbs::<F>(val);
        assert_eq!(limbs[0], F::from_u64(0x1234));
        assert!(limbs[1..].iter().all(|l| *l == F::ZERO));
    }
}