p3_air/utils.rs
1//! A collection of utility functions helpful in defining AIRs.
2
3use core::array;
4
5use p3_field::integers::QuotientMap;
6use p3_field::{Field, PrimeCharacteristicRing};
7
8use crate::AirBuilder;
9
10/// Pack a collection of bits into a number.
11///
12/// Given `vec = [v_0, v_1, ..., v_n]` returns `v_0 + 2v_1 + ... + 2^n v_n`
13#[inline]
14pub fn pack_bits_le<R, Var, I>(iter: I) -> R
15where
16 R: PrimeCharacteristicRing,
17 Var: Into<R> + Clone,
18 I: DoubleEndedIterator<Item = Var>,
19{
20 iter.rev()
21 .map(Into::into)
22 .reduce(|acc, elem| acc.double() + elem)
23 .unwrap_or(R::ZERO)
24}
25
26/// Compute `xor` on a list of boolean field elements.
27///
28/// Verifies at debug time that all inputs are boolean.
29#[inline(always)]
30pub fn checked_xor<F: Field, const N: usize>(xs: &[F]) -> F {
31 xs.iter().fold(F::ZERO, |acc, x| {
32 debug_assert!(x.is_zero() || x.is_one());
33 acc.xor(x)
34 })
35}
36
37/// Compute `andnot` on a pair of boolean field elements.
38///
39/// Verifies at debug time that both inputs are boolean.
40#[inline(always)]
41pub fn checked_andn<F: Field>(x: F, y: F) -> F {
42 debug_assert!(x.is_zero() || x.is_one());
43 debug_assert!(y.is_zero() || y.is_one());
44 x.andn(&y)
45}
46
47/// Convert a 32-bit integer into an array of 32 0 or 1 field elements.
48///
49/// The output array is in little-endian order.
50#[inline]
51pub fn u32_to_bits_le<R: PrimeCharacteristicRing>(val: u32) -> [R; 32] {
52 array::from_fn(|i| R::from_bool(val & (1 << i) != 0))
53}
54
55/// Convert a 64-bit integer into an array of 64 0 or 1 field elements.
56///
57/// The output array is in little-endian order.
58#[inline]
59pub fn u64_to_bits_le<R: PrimeCharacteristicRing>(val: u64) -> [R; 64] {
60 array::from_fn(|i| R::from_bool(val & (1 << i) != 0))
61}
62
63/// Convert a 64-bit integer into an array of four field elements representing the 16 bit limb decomposition.
64///
65/// The output array is in little-endian order.
66#[inline]
67pub fn u64_to_16_bit_limbs<R: PrimeCharacteristicRing>(val: u64) -> [R; 4] {
68 array::from_fn(|i| R::from_u16((val >> (16 * i)) as u16))
69}
70
71/// Verify that `a = b + c + d mod 2^32`
72///
73/// We assume that a, b, c, d are all given as `2, 16` bit limbs (e.g. `a = a[0] + 2^16 a[1]`) and
74/// each `16` bit limb has been range checked to ensure it contains a value in `[0, 2^16)`.
75///
76/// This function assumes we are working over a field with characteristic `P > 3*2^16`.
77///
78/// # Panics
79///
80/// The function will panic if the characteristic of the field is less than or equal to 2^16.
81#[inline]
82pub fn add3<AB: AirBuilder>(
83 builder: &mut AB,
84 a: &[AB::Var; 2],
85 b: &[AB::Var; 2],
86 c: &[AB::Expr; 2],
87 d: &[AB::Expr; 2],
88) {
89 // Define:
90 // acc = a - b - c - d (mod P)
91 // acc_16 = a[0] - b[0] - c[0] - d[0] (mod P)
92 //
93 // We perform 2 checks:
94 //
95 // (1) acc*(acc + 2^32)*(acc + 2*2^32) = 0.
96 // (2) acc_16*(acc_16 + 2^16)*(acc_16 + 2*2^16) = 0.
97 //
98 // We give a short proof for why this lets us conclude that a = b + c + d mod 2^32:
99 //
100 // As all 16 bit limbs have been range checked, we know that a, b, c, d lie in [0, 2^32) and hence
101 // a = b + c + d mod 2^32 if and only if, over the integers, a - b - c - d = 0, -2^32 or -2*2^32.
102 //
103 // Equation (1) verifies that a - b - c - d mod P = 0, -2^32 or -2*2^32.
104 //
105 // Field overflow cannot occur when computing acc_16 as our characteristic is larger than 3*2^16.
106 // Hence, equation (2) verifies that, over the integers, a[0] - b[0] - c[0] - d[0] = 0, -2^16 or -2*2^16.
107 // Either way we can immediately conclude that a - b - c - d = 0 mod 2^16.
108 //
109 // Now we can use the chinese remainder theorem to combine these results to conclude that
110 // a - b - c - d mod 2^16P = 0, -2^32 or -2*2^32.
111 //
112 // No overflow can occur mod 2^16 P as 2^16 P > 3*2^32 and a, b, c, d < 2^32. Hence we conclude that
113 // over the integers a - b - c - d = 0, -2^32 or -2*2^32 which implies a = b + c + d mod 2^32.
114
115 // By assumption P > 3*2^16 so 1 << 16 will be less than P. We use the checked version just to be safe.
116 // The compiler should optimize it away.
117 let two_16 =
118 <AB::Expr as PrimeCharacteristicRing>::PrimeSubfield::from_canonical_checked(1 << 16)
119 .unwrap();
120 let two_32 = two_16.square();
121
122 let acc_16 = a[0].clone() - b[0].clone() - c[0].clone() - d[0].clone();
123 let acc_32 = a[1].clone() - b[1].clone() - c[1].clone() - d[1].clone();
124 let acc = acc_16.clone() + acc_32.mul_2exp_u64(16);
125
126 builder.assert_zeros([
127 acc.clone()
128 * (acc.clone() + AB::Expr::from_prime_subfield(two_32))
129 * (acc + AB::Expr::from_prime_subfield(two_32.double())),
130 acc_16.clone()
131 * (acc_16.clone() + AB::Expr::from_prime_subfield(two_16))
132 * (acc_16 + AB::Expr::from_prime_subfield(two_16.double())),
133 ]);
134}
135
136/// Verify that `a = b + c mod 2^32`
137///
138/// We assume that a, b, c are all given as `2, 16` bit limbs (e.g. `a = a[0] + 2^16 a[1]`) and
139/// each `16` bit limb has been range checked to ensure it contains a value in `[0, 2^16)`.
140///
141/// This function assumes we are working over a field with characteristic `P > 2^17`.
142///
143/// # Panics
144///
145/// The function will panic if the characteristic of the field is less than or equal to 2^16.
146#[inline]
147pub fn add2<AB: AirBuilder>(
148 builder: &mut AB,
149 a: &[AB::Var; 2],
150 b: &[AB::Var; 2],
151 c: &[AB::Expr; 2],
152) {
153 // Define:
154 // acc = a - b - c (mod P)
155 // acc_16 = a[0] - b[0] - c[0] (mod P)
156 //
157 // We perform 2 checks:
158 //
159 // (1) acc*(acc + 2^32) = 0.
160 // (2) acc_16*(acc_16 + 2^16) = 0.
161 //
162 // We give a short proof for why this lets us conclude that a = b + c mod 2^32:
163 //
164 // As all 16 bit limbs have been range checked, we know that a, b, c lie in [0, 2^32) and hence
165 // a = b + c mod 2^32 if and only if, over the integers, a - b - c = 0 or -2^32.
166 //
167 // Equation (1) verifies that either a - b - c = 0 mod P or a - b - c = -2^32 mod P.
168 //
169 // Field overflow cannot occur when computing acc_16 as our characteristic is larger than 2^17.
170 // Hence, equation (2) verifies that, over the integers, a[0] - b[0] - c[0] = 0 or -2^16.
171 // Either way we can immediately conclude that a - b - c = 0 mod 2^16.
172 //
173 // Now we can use the chinese remainder theorem to combine these results to conclude that
174 // either a - b - c = 0 mod 2^16 P or a - b - c = -2^32 mod 2^16 P.
175 //
176 // No overflow can occur mod 2^16 P as 2^16 P > 2^33 and a, b, c < 2^32. Hence we conclude that
177 // over the integers a - b - c = 0 or a - b - c = -2^32 which is equivalent to a = b + c mod 2^32.
178
179 // By assumption P > 2^17 so 1 << 16 will be less than P. We use the checked version just to be safe.
180 // The compiler should optimize it away.
181 let two_16 =
182 <AB::Expr as PrimeCharacteristicRing>::PrimeSubfield::from_canonical_checked(1 << 16)
183 .unwrap();
184 let two_32 = two_16.square();
185
186 let acc_16 = a[0].clone() - b[0].clone() - c[0].clone();
187 let acc_32 = a[1].clone() - b[1].clone() - c[1].clone();
188 let acc = acc_16.clone() + acc_32.mul_2exp_u64(16);
189
190 builder.assert_zeros([
191 acc.clone() * (acc + AB::Expr::from_prime_subfield(two_32)),
192 acc_16.clone() * (acc_16 + AB::Expr::from_prime_subfield(two_16)),
193 ]);
194}
195
196/// Verify that `a = (b ^ (c << shift))`
197///
198/// We assume that a is given as `2 16` bit limbs and both b and c are unpacked into 32 individual bits.
199/// We assume that the bits of b have been range checked but not the inputs in c or a. Both of these are
200/// range checked as part of this function.
201#[inline]
202pub fn xor_32_shift<AB: AirBuilder>(
203 builder: &mut AB,
204 a: &[AB::Var; 2],
205 b: &[AB::Var; 32],
206 c: &[AB::Var; 32],
207 shift: usize,
208) {
209 // First we range check all elements of c.
210 builder.assert_bools(c.clone());
211
212 // Next we compute (b ^ (c << shift)) and pack the result into two 16-bit integers.
213 let xor_shift_c_0_16 = b[..16].iter().enumerate().map(|(i, elem)| {
214 (elem.clone())
215 .into()
216 .xor(&c[(32 + i - shift) % 32].clone().into())
217 });
218 let sum_0_16: AB::Expr = pack_bits_le(xor_shift_c_0_16);
219
220 let xor_shift_c_16_32 = b[16..].iter().enumerate().map(|(i, elem)| {
221 (elem.clone())
222 .into()
223 .xor(&c[(32 + (i + 16) - shift) % 32].clone().into())
224 });
225 let sum_16_32: AB::Expr = pack_bits_le(xor_shift_c_16_32);
226
227 // As both b and c have been range checked to be boolean, all the (b ^ (c << shift))
228 // are also boolean and so this final check additionally has the effect of range checking a[0], a[1].
229 builder.assert_zeros([a[0].clone() - sum_0_16, a[1].clone() - sum_16_32]);
230}
231
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;

    use super::*;

    type F = BabyBear;

    /// `pack_bits_le` must weight element `i` by `2^i`.
    #[test]
    fn test_pack_bits_le_various_patterns() {
        // Pattern: [1, 0, 1] as little-endian => 1 + 2*0 + 4*1 = 5
        let bits = [F::ONE, F::ZERO, F::ONE];
        let packed = pack_bits_le::<F, _, _>(bits.iter().cloned());
        assert_eq!(packed, F::from_u8(5));

        // Pattern: [1, 1, 0, 1] => 1 + 2*1 + 4*0 + 8*1 = 1 + 2 + 8 = 11
        let bits = [F::ONE, F::ONE, F::ZERO, F::ONE];
        let packed = pack_bits_le::<F, _, _>(bits.iter().cloned());
        assert_eq!(packed, F::from_u8(11));

        // Pattern: all zeros
        let bits = [F::ZERO; 5];
        let packed = pack_bits_le::<F, _, _>(bits.iter().cloned());
        assert_eq!(packed, F::ZERO);

        // Pattern: single one at the highest place
        let bits = [F::ZERO, F::ZERO, F::ZERO, F::ZERO, F::ONE];
        let packed = pack_bits_le::<F, _, _>(bits.iter().cloned());
        assert_eq!(packed, F::from_u8(16));
    }

    /// `checked_xor` must return the parity of its boolean inputs.
    #[test]
    fn test_checked_xor_multiple_cases() {
        // Input: [1, 0, 1] => XOR(1 ^ 0 ^ 1) = 0
        let bits = [F::ONE, F::ZERO, F::ONE];
        let result = checked_xor::<F, 3>(&bits);
        assert_eq!(result, F::ZERO);

        // [1, 1, 1] => XOR = 1 ^ 1 ^ 1 = 1
        let bits = [F::ONE, F::ONE, F::ONE];
        let result = checked_xor::<F, 3>(&bits);
        assert_eq!(result, F::ONE);

        // [0, 0, 0] => XOR = 0
        let bits = [F::ZERO, F::ZERO, F::ZERO];
        let result = checked_xor::<F, 3>(&bits);
        assert_eq!(result, F::ZERO);

        // [1, 0, 1, 0] => XOR = 1 ^ 0 ^ 1 ^ 0 = 0
        let bits = [F::ONE, F::ZERO, F::ONE, F::ZERO];
        let result = checked_xor::<F, 4>(&bits);
        assert_eq!(result, F::ZERO);
    }

    /// `checked_andn(x, y)` must equal `!x & y` on boolean inputs.
    #[test]
    fn test_checked_andn() {
        // x = 1, y = 0 => !1 & 0 = 0
        let result = checked_andn(F::ONE, F::ZERO);
        assert_eq!(result, F::ZERO);

        // x = 0, y = 1 => !0 & 1 = 1
        let result = checked_andn(F::ZERO, F::ONE);
        assert_eq!(result, F::ONE);

        // x = 0, y = 0 => !0 & 0 = 0
        let result = checked_andn(F::ZERO, F::ZERO);
        assert_eq!(result, F::ZERO);

        // x = 1, y = 1 => !1 & 1 = 0
        let result = checked_andn(F::ONE, F::ONE);
        assert_eq!(result, F::ZERO);
    }

    /// `u32_to_bits_le` must place the least significant bit at index 0.
    #[test]
    fn test_u32_to_bits_le() {
        // Convert 0b1010 (decimal 10) => [0, 1, 0, 1, ...]
        let bits = u32_to_bits_le::<F>(10);
        assert_eq!(bits[0], F::ZERO); // LSB first
        assert_eq!(bits[1], F::ONE);
        assert_eq!(bits[2], F::ZERO);
        assert_eq!(bits[3], F::ONE);

        for &bit in &bits[4..] {
            assert_eq!(bit, F::ZERO);
        }

        // Check 0 => all zeros
        let bits = u32_to_bits_le::<F>(0);
        assert!(bits.iter().all(|b| *b == F::ZERO));

        // Check max => all ones
        let bits = u32_to_bits_le::<F>(u32::MAX);
        assert!(bits.iter().all(|b| *b == F::ONE));
    }

    /// `u64_to_bits_le` must place the least significant bit at index 0.
    #[test]
    fn test_u64_to_bits_le() {
        // Convert 0b11 (decimal 3) => [1, 1, 0, ...]
        let bits = u64_to_bits_le::<F>(3);
        assert_eq!(bits[0], F::ONE);
        assert_eq!(bits[1], F::ONE);
        assert_eq!(bits[2], F::ZERO);

        for &bit in &bits[3..] {
            assert_eq!(bit, F::ZERO);
        }

        // Check 0 => all zeros
        let bits = u64_to_bits_le::<F>(0);
        assert!(bits.iter().all(|b| *b == F::ZERO));

        // Check max => all ones
        let bits = u64_to_bits_le::<F>(u64::MAX);
        assert!(bits.iter().all(|b| *b == F::ONE));
    }

    /// `u64_to_16_bit_limbs` must return the limbs least significant first.
    #[test]
    fn test_u64_to_16_bit_limbs() {
        // Convert 0x123456789ABCDEF0
        let val: u64 = 0x123456789ABCDEF0;
        let limbs = u64_to_16_bit_limbs::<F>(val);

        // Expected limbs (little endian): [0xDEF0, 0x9ABC, 0x5678, 0x1234]
        assert_eq!(limbs[0], F::from_u16(0xDEF0));
        assert_eq!(limbs[1], F::from_u16(0x9ABC));
        assert_eq!(limbs[2], F::from_u16(0x5678));
        assert_eq!(limbs[3], F::from_u16(0x1234));

        // Recombining the limbs must give back the original value (as a field element).
        assert_eq!(
            limbs[0]
                + limbs[1].mul_2exp_u64(16)
                + limbs[2].mul_2exp_u64(32)
                + limbs[3].mul_2exp_u64(48),
            F::from_u64(val)
        );

        // Check zero
        let limbs = u64_to_16_bit_limbs::<F>(0);
        assert!(limbs.iter().all(|l| *l == F::ZERO));

        // Check max
        let limbs = u64_to_16_bit_limbs::<F>(u64::MAX);
        for l in limbs {
            assert_eq!(l, F::from_u64(0xFFFF));
        }

        // Check small value
        let val: u64 = 0x1234;
        let limbs = u64_to_16_bit_limbs::<F>(val);
        assert_eq!(limbs[0], F::from_u64(0x1234));
        assert!(limbs[1..].iter().all(|l| *l == F::ZERO));
    }
}