openvm_sdk/keygen/
perm.rs

1use std::cmp::Reverse;
2
3#[cfg(feature = "evm-prove")]
4use openvm_continuations::verifier::common::types::SpecialAirIds;
5
/// Permutation of the AIR IDs to order them by forced trace heights.
///
/// `perm[i]` is the original AIR ID that ends up at position `i` after reordering.
pub(crate) struct AirIdPermutation {
    pub perm: Vec<usize>,
}

impl AirIdPermutation {
    /// Builds the permutation that orders AIR IDs by descending height.
    ///
    /// `heights[a]` is the height associated with AIR ID `a`. The sort is stable, so AIRs
    /// with equal heights keep their ascending AIR ID order.
    pub fn compute(heights: &[u32]) -> AirIdPermutation {
        let mut perm: Vec<usize> = (0..heights.len()).collect();
        perm.sort_by_key(|&air_id| Reverse(heights[air_id]));
        AirIdPermutation { perm }
    }

    /// Returns the positions (after permutation) of the program, connector and
    /// public-values AIRs.
    ///
    /// Each field is the index `i` for which `perm[i]` equals the corresponding AIR ID
    /// constant. An AIR that does not appear in `perm` is left as `perm.len()`. That case
    /// is only reported by `debug_assert`s, so release builds return the value silently.
    #[cfg(feature = "evm-prove")]
    pub fn get_special_air_ids(&self) -> SpecialAirIds {
        use openvm_circuit::arch::{CONNECTOR_AIR_ID, PROGRAM_AIR_ID, PUBLIC_VALUES_AIR_ID};

        let perm_len = self.perm.len();
        // `perm_len` acts as a "not found" sentinel until the matching position is seen.
        let mut ret = SpecialAirIds {
            program_air_id: perm_len,
            connector_air_id: perm_len,
            public_values_air_id: perm_len,
        };
        for (i, &air_id) in self.perm.iter().enumerate() {
            if air_id == PROGRAM_AIR_ID {
                ret.program_air_id = i;
            } else if air_id == CONNECTOR_AIR_ID {
                ret.connector_air_id = i;
            } else if air_id == PUBLIC_VALUES_AIR_ID {
                ret.public_values_air_id = i;
            }
        }
        debug_assert_ne!(ret.program_air_id, perm_len, "Program AIR not found");
        debug_assert_ne!(ret.connector_air_id, perm_len, "Connector AIR not found");
        debug_assert_ne!(
            ret.public_values_air_id, perm_len,
            "Public Values AIR not found"
        );
        ret
    }

    /// Reorders `arr` in place so that `arr[i] <- arr[perm[i]]`.
    ///
    /// The right-hand side refers to the values before the call. The permutation is
    /// applied cycle by cycle with swaps. This is O(n) time, plus one scratch copy of
    /// `perm` used to mark positions that are already final. `perm` is assumed to be a
    /// permutation of `0..perm.len()`, which [`AirIdPermutation::compute`] guarantees.
    ///
    /// # Panics
    ///
    /// Panics if `arr.len() != self.perm.len()`, or if `perm` contains an out-of-range
    /// index.
    pub(crate) fn permute<T>(&self, arr: &mut [T]) {
        // Checked in release too: with only a debug assertion, a longer `arr` would have
        // just its prefix permuted without any error.
        assert_eq!(
            arr.len(),
            self.perm.len(),
            "array length must match permutation length"
        );
        let mut perm = self.perm.clone();
        for i in 0..perm.len() {
            if perm[i] != i {
                // Walk the cycle that starts at `i`. Each swap pulls the value from
                // `target` into its final slot `curr` and carries the original `arr[i]`
                // forward. Setting `perm[curr] = curr` marks `curr` as finalized.
                let mut curr = i;
                loop {
                    let target = perm[curr];
                    perm[curr] = curr;
                    // `target` already finalized means the walk has returned to the
                    // start of the cycle. The carried `arr[i]` now sits in slot `curr`.
                    if perm[target] == target {
                        break;
                    }
                    arr.swap(curr, target);
                    curr = target;
                }
            }
        }
    }
}
69
#[cfg(test)]
mod tests {
    use crate::keygen::perm::AirIdPermutation;

    #[test]
    fn test_air_id_permutation() {
        // A single 3-cycle plus a fixed point: arr[i] <- arr[perm[i]].
        {
            let perm = AirIdPermutation {
                perm: vec![2, 0, 1, 3],
            };
            let mut arr = vec![0, 100, 200, 300];
            perm.permute(&mut arr);
            assert_eq!(arr, vec![200, 0, 100, 300]);
        }
        // The identity permutation leaves the array untouched.
        {
            let perm = AirIdPermutation {
                perm: vec![0, 1, 2, 3],
            };
            let mut arr = vec![0, 100, 200, 300];
            perm.permute(&mut arr);
            assert_eq!(arr, vec![0, 100, 200, 300]);
        }
        // Two disjoint cycles, (0 3) and (1 2 4), exercise the walk across several cycles.
        {
            let perm = AirIdPermutation {
                perm: vec![3, 2, 4, 0, 1],
            };
            let mut arr = vec![0, 100, 200, 300, 400];
            perm.permute(&mut arr);
            assert_eq!(arr, vec![300, 200, 400, 0, 100]);
        }
        // An empty permutation is a no-op.
        {
            let perm = AirIdPermutation { perm: vec![] };
            let mut arr: Vec<u32> = vec![];
            perm.permute(&mut arr);
            assert!(arr.is_empty());
        }
    }

    #[test]
    fn test_air_id_permutation_compute() {
        // Sorted by descending height; equal heights keep ascending AIR ID order.
        let perm = AirIdPermutation::compute(&[5, 10, 5, 1, 10]);
        assert_eq!(perm.perm, vec![1, 4, 0, 2, 3]);
        // No AIRs yields an empty permutation.
        assert!(AirIdPermutation::compute(&[]).perm.is_empty());
    }
}