openvm_sdk/keygen/
perm.rs

1use std::cmp::Reverse;
2
3use openvm_circuit::arch::{CONNECTOR_AIR_ID, PROGRAM_AIR_ID, PUBLIC_VALUES_AIR_ID};
4use openvm_continuations::verifier::common::types::SpecialAirIds;
5
/// A reordering of AIR ids.
///
/// `perm[i]` is the original AIR id that ends up at position `i` once the
/// permutation is applied (see `AirIdPermutation::permute`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AirIdPermutation {
    /// `perm[new_id] = old_id`; expected to be a permutation of `0..perm.len()`.
    pub perm: Vec<usize>,
}
9
10impl AirIdPermutation {
11    pub fn compute(heights: &[usize]) -> AirIdPermutation {
12        let mut height_with_air_id: Vec<_> = heights.iter().copied().enumerate().collect();
13        height_with_air_id.sort_by_key(|(_, h)| Reverse(*h));
14        AirIdPermutation {
15            perm: height_with_air_id
16                .into_iter()
17                .map(|(a_id, _)| a_id)
18                .collect(),
19        }
20    }
21    pub fn get_special_air_ids(&self) -> SpecialAirIds {
22        let perm_len = self.perm.len();
23        let mut ret = SpecialAirIds {
24            program_air_id: perm_len,
25            connector_air_id: perm_len,
26            public_values_air_id: perm_len,
27        };
28        for (i, &air_id) in self.perm.iter().enumerate() {
29            if air_id == PROGRAM_AIR_ID {
30                ret.program_air_id = i;
31            } else if air_id == CONNECTOR_AIR_ID {
32                ret.connector_air_id = i;
33            } else if air_id == PUBLIC_VALUES_AIR_ID {
34                ret.public_values_air_id = i;
35            }
36        }
37        debug_assert_ne!(ret.program_air_id, perm_len, "Program AIR not found");
38        debug_assert_ne!(ret.connector_air_id, perm_len, "Connector AIR not found");
39        debug_assert_ne!(
40            ret.public_values_air_id, perm_len,
41            "Public Values AIR not found"
42        );
43        ret
44    }
45    /// arr[i] <- arr[perm[i]]
46    pub(crate) fn permute<T>(&self, arr: &mut [T]) {
47        debug_assert_eq!(arr.len(), self.perm.len());
48        let mut perm = self.perm.clone();
49        for i in 0..perm.len() {
50            if perm[i] != i {
51                let mut curr = i;
52                loop {
53                    let target = perm[curr];
54                    perm[curr] = curr;
55                    if perm[target] == target {
56                        break;
57                    }
58                    arr.swap(curr, target);
59                    curr = target;
60                }
61            }
62        }
63    }
64}
65
#[cfg(test)]
mod tests {
    use openvm_circuit::arch::{CONNECTOR_AIR_ID, PROGRAM_AIR_ID, PUBLIC_VALUES_AIR_ID};

    use crate::keygen::perm::AirIdPermutation;

    /// Out-of-place reference for `permute`: `out[i] = arr[perm[i]]`.
    fn permute_reference<T: Clone>(perm: &[usize], arr: &[T]) -> Vec<T> {
        perm.iter().map(|&j| arr[j].clone()).collect()
    }

    #[test]
    fn test_air_id_permutation() {
        {
            let perm = AirIdPermutation {
                perm: vec![2, 0, 1, 3],
            };
            let mut arr = vec![0, 100, 200, 300];
            perm.permute(&mut arr);
            assert_eq!(arr, vec![200, 0, 100, 300]);
        }
        {
            let perm = AirIdPermutation {
                perm: vec![0, 1, 2, 3],
            };
            let mut arr = vec![0, 100, 200, 300];
            perm.permute(&mut arr);
            assert_eq!(arr, vec![0, 100, 200, 300]);
        }
    }

    #[test]
    fn test_permute_multiple_cycles() {
        // Cases:
        // - two disjoint 2-cycles (0 3)(2 4) around a fixed point;
        // - a single 5-cycle;
        // - the empty permutation.
        for perm in [vec![3, 1, 4, 0, 2], vec![1, 2, 3, 4, 0], vec![]] {
            let p = AirIdPermutation { perm: perm.clone() };
            let original: Vec<usize> = (0..perm.len()).map(|i| 10 * i + 7).collect();
            let mut arr = original.clone();
            p.permute(&mut arr);
            assert_eq!(arr, permute_reference(&perm, &original));
        }
    }

    #[test]
    fn test_compute_orders_by_descending_height() {
        let perm = AirIdPermutation::compute(&[4, 16, 4, 32, 8]);
        // Ids 0 and 2 tie at height 4 and must keep ascending id order.
        assert_eq!(perm.perm, vec![3, 1, 4, 0, 2]);
        assert!(AirIdPermutation::compute(&[]).perm.is_empty());
    }

    #[test]
    fn test_get_special_air_ids() {
        // A reversed identity moves id `a` to position `n - 1 - a`, whatever
        // the concrete values of the special ids are.
        let n = PROGRAM_AIR_ID
            .max(CONNECTOR_AIR_ID)
            .max(PUBLIC_VALUES_AIR_ID)
            + 1;
        let perm = AirIdPermutation {
            perm: (0..n).rev().collect(),
        };
        let ids = perm.get_special_air_ids();
        assert_eq!(ids.program_air_id, n - 1 - PROGRAM_AIR_ID);
        assert_eq!(ids.connector_air_id, n - 1 - CONNECTOR_AIR_ID);
        assert_eq!(ids.public_values_air_id, n - 1 - PUBLIC_VALUES_AIR_ID);
    }
}