// halo2_proofs/circuit/floor_planner/v1/strategy.rs

1use std::{
2    cmp,
3    collections::{BTreeSet, HashMap},
4    ops::Range,
5};
6
7use super::{RegionColumn, RegionShape};
8use crate::{circuit::RegionStart, plonk::Any};
9
/// A region allocated within a column.
///
/// Regions are ordered by `start`, then by `length`. This is the derived lexicographic order
/// over the fields in declaration order. Including `length` keeps `Ord` consistent with the
/// derived `PartialEq`/`Eq`, which `Ord` requires. Ordering by `start` alone made two regions
/// with the same start but different lengths compare `Equal` while being unequal, so a
/// `BTreeSet` would silently keep only the first of them on insert.
///
/// For disjoint, non-empty allocations every start is distinct, so this is the same order as
/// sorting by `start` alone.
#[derive(Clone, Default, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct AllocatedRegion {
    // The starting position of the region. Declared first so it is the primary sort key.
    start: usize,
    // The length of the region.
    length: usize,
}
30
/// A run of unallocated rows within a column.
pub(crate) struct EmptySpace {
    // First free row (inclusive).
    start: usize,
    // One past the last free row (exclusive); `None` means the space never ends.
    end: Option<usize>,
}

impl EmptySpace {
    /// Returns the free rows as a half-open range `start..end`, or `None` when the space is
    /// unbounded.
    pub(crate) fn range(&self) -> Option<Range<usize>> {
        let end = self.end?;
        Some(self.start..end)
    }
}
44
45/// Allocated rows within a column.
46///
47/// This is a set of [a_start, a_end) pairs representing disjoint allocated intervals.
48#[derive(Clone, Default, Debug)]
49pub struct Allocations(BTreeSet<AllocatedRegion>);
50
51impl Allocations {
52    /// Returns the row that forms the unbounded unallocated interval [row, None).
53    pub(crate) fn unbounded_interval_start(&self) -> usize {
54        self.0
55            .iter()
56            .last()
57            .map(|r| r.start + r.length)
58            .unwrap_or(0)
59    }
60
61    /// Return all the *unallocated* nonempty intervals intersecting [start, end).
62    ///
63    /// `end = None` represents an unbounded end.
64    pub(crate) fn free_intervals(
65        &self,
66        start: usize,
67        end: Option<usize>,
68    ) -> impl Iterator<Item = EmptySpace> + '_ {
69        self.0
70            .iter()
71            .map(Some)
72            .chain(Some(None))
73            .scan(start, move |row, region| {
74                Some(if let Some(region) = region {
75                    if end.map(|end| region.start >= end).unwrap_or(false) {
76                        None
77                    } else {
78                        let ret = if *row < region.start {
79                            Some(EmptySpace {
80                                start: *row,
81                                end: Some(region.start),
82                            })
83                        } else {
84                            None
85                        };
86
87                        *row = cmp::max(*row, region.start + region.length);
88
89                        ret
90                    }
91                } else if end.map(|end| *row < end).unwrap_or(true) {
92                    Some(EmptySpace { start: *row, end })
93                } else {
94                    None
95                })
96            })
97            .flatten()
98    }
99}
100
101/// Allocated rows within a circuit.
102pub type CircuitAllocations = HashMap<RegionColumn, Allocations>;
103
104/// - `start` is the current start row of the region (not of this column).
105/// - `slack` is the maximum number of rows the start could be moved down, taking into
106///   account prior columns.
107fn first_fit_region(
108    column_allocations: &mut CircuitAllocations,
109    region_columns: &[RegionColumn],
110    region_length: usize,
111    start: usize,
112    slack: Option<usize>,
113) -> Option<usize> {
114    let (c, remaining_columns) = match region_columns.split_first() {
115        Some(cols) => cols,
116        None => return Some(start),
117    };
118    let end = slack.map(|slack| start + region_length + slack);
119
120    // Iterate over the unallocated non-empty intervals in c that intersect [start, end).
121    for space in column_allocations
122        .entry(*c)
123        .or_default()
124        .clone()
125        .free_intervals(start, end)
126    {
127        // Do we have enough room for this column of the region in this interval?
128        let s_slack = space
129            .end
130            .map(|end| (end as isize - space.start as isize) - region_length as isize);
131        if let Some((slack, s_slack)) = slack.zip(s_slack) {
132            assert!(s_slack <= slack as isize);
133        }
134        if s_slack.unwrap_or(0) >= 0 {
135            let row = first_fit_region(
136                column_allocations,
137                remaining_columns,
138                region_length,
139                space.start,
140                s_slack.map(|s| s as usize),
141            );
142            if let Some(row) = row {
143                if let Some(end) = end {
144                    assert!(row + region_length <= end);
145                }
146                column_allocations
147                    .get_mut(c)
148                    .unwrap()
149                    .0
150                    .insert(AllocatedRegion {
151                        start: row,
152                        length: region_length,
153                    });
154                return Some(row);
155            }
156        }
157    }
158
159    // No placement worked; the caller will need to try other possibilities.
160    None
161}
162
163/// Positions the regions starting at the earliest row for which none of the columns are
164/// in use, taking into account gaps between earlier regions.
165fn slot_in(
166    region_shapes: Vec<RegionShape>,
167) -> (Vec<(RegionStart, RegionShape)>, CircuitAllocations) {
168    // Tracks the empty regions for each column.
169    let mut column_allocations: CircuitAllocations = Default::default();
170
171    let regions = region_shapes
172        .into_iter()
173        .map(|region| {
174            // Sort the region's columns to ensure determinism.
175            // - An unstable sort is fine, because region.columns() returns a set.
176            // - The sort order relies on Column's Ord implementation!
177            let mut region_columns: Vec<_> = region.columns().iter().cloned().collect();
178            region_columns.sort_unstable();
179
180            let region_start = first_fit_region(
181                &mut column_allocations,
182                &region_columns,
183                region.row_count(),
184                0,
185                None,
186            )
187            .expect("We can always fit a region somewhere");
188
189            (region_start.into(), region)
190        })
191        .collect();
192
193    // Return the column allocations for potential further processing.
194    (regions, column_allocations)
195}
196
197/// Sorts the regions by advice area and then lays them out with the [`slot_in`] strategy.
198pub fn slot_in_biggest_advice_first(
199    region_shapes: Vec<RegionShape>,
200) -> (Vec<RegionStart>, CircuitAllocations) {
201    let mut sorted_regions: Vec<_> = region_shapes.into_iter().collect();
202    sorted_regions.sort_unstable_by_key(|shape| {
203        // Count the number of advice columns
204        let advice_cols = shape
205            .columns()
206            .iter()
207            .filter(|c| match c {
208                RegionColumn::Column(c) => matches!(c.column_type(), Any::Advice),
209                _ => false,
210            })
211            .count();
212        // Sort by advice area (since this has the most contention).
213        advice_cols * shape.row_count()
214    });
215    sorted_regions.reverse();
216
217    // Lay out the sorted regions.
218    let (mut regions, column_allocations) = slot_in(sorted_regions);
219
220    // Un-sort the regions so they match the original indexing.
221    regions.sort_unstable_by_key(|(_, region)| region.region_index().0);
222    let regions = regions.into_iter().map(|(start, _)| start).collect();
223
224    (regions, column_allocations)
225}
226
#[test]
fn test_slot_in() {
    /// Builds a region shape that uses only the given advice columns.
    fn advice_region(index: usize, columns: &[usize], row_count: usize) -> RegionShape {
        RegionShape {
            region_index: index.into(),
            columns: columns
                .iter()
                .map(|&column| crate::plonk::Column::new(column, Any::Advice).into())
                .collect(),
            row_count,
        }
    }

    // Region 0 takes advice columns 0 and 1 from row 0. Region 1 fits beside it in column 2.
    // Region 2 needs columns 2 and 0, so it cannot start before row 15, where region 0 ends.
    let regions = vec![
        advice_region(0, &[0, 1], 15),
        advice_region(1, &[2], 10),
        advice_region(2, &[2, 0], 10),
    ];

    let starts: Vec<_> = slot_in(regions)
        .0
        .into_iter()
        .map(|(start, _)| start)
        .collect();
    assert_eq!(starts, vec![0.into(), 0.into(), 15.into()]);
}