ark_ff_macros/
unroll.rs

//! An attribute-like procedural macro for (partially) unrolling `for` loops
//! over integer ranges.
3//!
4//! This crate provides the [`unroll_for_loops`](../attr.unroll_for_loops.html)
5//! attribute-like macro that can be applied to functions containing for-loops
6//! with integer bounds. This macro looks for loops to unroll and unrolls them
7//! at compile time.
8//!
9//!
10//! ## Usage
11//!
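//! Add `#[unroll_for_loops(N)]` above the function whose for loops you would
//! like to unroll, where `N` is the unroll factor. Every `for` loop whose
//! pattern is a plain identifier and whose iterator is a range with an upper
//! bound (`a..b` or `a..=b`, with `usize` bounds) is rewritten into a loop that
//! runs `N` copies of the body per iteration, followed by the leftover
//! iterations. The macro can't see inside complex code (e.g. for loops within
//! closures).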
16//!
17//!
18//! ## Example
19//!
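//! The following function computes a matrix-vector product and returns the
//! result as an array. Both for-loops (the outer loop over `col` and the nested
//! loop over `row`) are unrolled when `#[unroll_for_loops(12)]` is applied;
//! `mtx_vec_mul_2` is the same function without the attribute and is used to
//! check that both produce the same result.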
23//!
24//! ```rust
25//! use ark_ff_macros::unroll_for_loops;
26//!
27//! #[unroll_for_loops(12)]
28//! fn mtx_vec_mul(mtx: &[[f64; 5]; 5], vec: &[f64; 5]) -> [f64; 5] {
29//!     let mut out = [0.0; 5];
30//!     for col in 0..5 {
31//!         for row in 0..5 {
32//!             out[row] += mtx[col][row] * vec[col];
33//!         }
34//!     }
35//!     out
36//! }
37//!
38//! fn mtx_vec_mul_2(mtx: &[[f64; 5]; 5], vec: &[f64; 5]) -> [f64; 5] {
39//!     let mut out = [0.0; 5];
40//!     for col in 0..5 {
41//!         for row in 0..5 {
42//!             out[row] += mtx[col][row] * vec[col];
43//!         }
44//!     }
45//!     out
46//! }
47//! let a = [[1.0, 2.0, 3.0, 4.0, 5.0]; 5];
48//! let b = [7.9, 4.8, 3.8, 4.22, 5.2];
49//! assert_eq!(mtx_vec_mul(&a, &b), mtx_vec_mul_2(&a, &b));
50//! ```
51//!
52//! This code was adapted from the [`unroll`](https://crates.io/crates/unroll) crate.
53
54use std::borrow::Borrow;
55
56use syn::{
57    parse_quote, token::Brace, Block, Expr, ExprBlock, ExprForLoop, ExprIf, ExprLet, ExprRange,
58    Pat, PatIdent, RangeLimits, Stmt,
59};
60
61/// Routine to unroll for loops within a block
62pub(crate) fn unroll_in_block(block: &Block, unroll_by: usize) -> Block {
63    let &Block {
64        ref brace_token,
65        ref stmts,
66    } = block;
67    let mut new_stmts = Vec::new();
68    for stmt in stmts.iter() {
69        if let Stmt::Expr(expr) = stmt {
70            new_stmts.push(Stmt::Expr(unroll(expr, unroll_by)));
71        } else if let Stmt::Semi(expr, semi) = stmt {
72            new_stmts.push(Stmt::Semi(unroll(expr, unroll_by), *semi));
73        } else {
74            new_stmts.push((*stmt).clone());
75        }
76    }
77    Block {
78        brace_token: *brace_token,
79        stmts: new_stmts,
80    }
81}
82
83/// Routine to unroll a for loop statement, or return the statement unchanged if
84/// it's not a for loop.
85fn unroll(expr: &Expr, unroll_by: usize) -> Expr {
86    // impose a scope that we can break out of so we can return stmt without copying it.
87    if let Expr::ForLoop(for_loop) = expr {
88        let ExprForLoop {
89            ref attrs,
90            ref label,
91            ref pat,
92            expr: ref range,
93            ref body,
94            ..
95        } = *for_loop;
96
97        let new_body = unroll_in_block(body, unroll_by);
98
99        let forloop_with_body = |body| {
100            Expr::ForLoop(ExprForLoop {
101                body,
102                ..(*for_loop).clone()
103            })
104        };
105
106        if let Pat::Ident(PatIdent {
107            by_ref,
108            mutability,
109            ident,
110            subpat,
111            ..
112        }) = pat
113        {
114            // Don't know how to deal with these so skip and return the original.
115            if !by_ref.is_none() || !mutability.is_none() || !subpat.is_none() {
116                return forloop_with_body(new_body);
117            }
118            let idx = ident; // got the index variable name
119
120            if let Expr::Range(ExprRange {
121                from, limits, to, ..
122            }) = range.borrow()
123            {
124                // Parse `from` in `from..to`.
125                let begin = match from {
126                    Some(e) => e.clone(),
127                    _ => Box::new(parse_quote!(0usize)),
128                };
129                let end = match to {
130                    Some(e) => e.clone(),
131                    _ => return forloop_with_body(new_body),
132                };
133                let end_is_closed = if let RangeLimits::Closed(_) = limits {
134                    1usize
135                } else {
136                    0
137                };
138                let end: Expr = parse_quote!(#end + #end_is_closed);
139
140                let preamble: Vec<Stmt> = parse_quote! {
141                    let total_iters: usize = (#end).checked_sub(#begin).unwrap_or(0);
142                    let num_loops = total_iters / #unroll_by;
143                    let remainder = total_iters % #unroll_by;
144                };
145                let mut block = Block {
146                    brace_token: Brace::default(),
147                    stmts: preamble,
148                };
149                let mut loop_expr: ExprForLoop = parse_quote! {
150                    for #idx in (0..num_loops) {
151                        let mut #idx = #begin + #idx * #unroll_by;
152                    }
153                };
154                let loop_block: Vec<Stmt> = parse_quote! {
155                    if #idx < #end {
156                        #new_body
157                    }
158                    #idx += 1;
159                };
160                let loop_body = (0..unroll_by).flat_map(|_| loop_block.clone());
161                loop_expr.body.stmts.extend(loop_body);
162                block.stmts.push(Stmt::Expr(Expr::ForLoop(loop_expr)));
163
164                // idx = num_loops * unroll_by;
165                block
166                    .stmts
167                    .push(parse_quote! { let mut #idx = #begin + num_loops * #unroll_by; });
168                // if idx < remainder + num_loops * unroll_by { ... }
169                let post_loop_block: Vec<Stmt> = parse_quote! {
170                    if #idx < #end {
171                        #new_body
172                    }
173                    #idx += 1;
174                };
175                let post_loop = (0..unroll_by).flat_map(|_| post_loop_block.clone());
176                block.stmts.extend(post_loop);
177
178                let mut attrs = attrs.clone();
179                attrs.extend(vec![parse_quote!(#[allow(unused)])]);
180                Expr::Block(ExprBlock {
181                    attrs,
182                    label: label.clone(),
183                    block,
184                })
185            } else {
186                forloop_with_body(new_body)
187            }
188        } else {
189            forloop_with_body(new_body)
190        }
191    } else if let Expr::If(if_expr) = expr {
192        let ExprIf {
193            ref cond,
194            ref then_branch,
195            ref else_branch,
196            ..
197        } = *if_expr;
198        Expr::If(ExprIf {
199            cond: Box::new(unroll(cond, unroll_by)),
200            then_branch: unroll_in_block(then_branch, unroll_by),
201            else_branch: else_branch
202                .as_ref()
203                .map(|x| (x.0, Box::new(unroll(&x.1, unroll_by)))),
204            ..(*if_expr).clone()
205        })
206    } else if let Expr::Let(let_expr) = expr {
207        let ExprLet { ref expr, .. } = *let_expr;
208        Expr::Let(ExprLet {
209            expr: Box::new(unroll(expr, unroll_by)),
210            ..(*let_expr).clone()
211        })
212    } else if let Expr::Block(expr_block) = expr {
213        let ExprBlock { ref block, .. } = *expr_block;
214        Expr::Block(ExprBlock {
215            block: unroll_in_block(block, unroll_by),
216            ..(*expr_block).clone()
217        })
218    } else {
219        (*expr).clone()
220    }
221}
222
/// Checks that a `for` loop over an integer range is replaced by an unrolled
/// block expression. A loop over an arbitrary iterator stays a `for` loop, and
/// non-expression statements are left alone.
#[test]
fn test_expand() {
    use quote::ToTokens;
    let for_loop: Block = parse_quote! {{
        let mut sum = 0;
        for i in 0..8 {
            sum += i;
        }
        for x in v.iter() {
            sum += x;
        }
    }};
    let unrolled = unroll_in_block(&for_loop, 12);
    println!("{}", unrolled.to_token_stream());

    assert_eq!(unrolled.stmts.len(), 3);
    // The `let` statement is passed through untouched.
    assert!(matches!(unrolled.stmts[0], Stmt::Local(_)));
    // The range loop is rewritten into a block expression.
    assert!(matches!(unrolled.stmts[1], Stmt::Expr(Expr::Block(_))));
    // A loop over a non-range iterator cannot be unrolled.
    assert!(matches!(unrolled.stmts[2], Stmt::Expr(Expr::ForLoop(_))));
}