ark_ff_macros/
unroll.rs
1use std::borrow::Borrow;
55
56use syn::{
57 parse_quote, token::Brace, Block, Expr, ExprBlock, ExprForLoop, ExprIf, ExprLet, ExprRange,
58 Pat, PatIdent, RangeLimits, Stmt,
59};
60
61pub(crate) fn unroll_in_block(block: &Block, unroll_by: usize) -> Block {
63 let &Block {
64 ref brace_token,
65 ref stmts,
66 } = block;
67 let mut new_stmts = Vec::new();
68 for stmt in stmts.iter() {
69 if let Stmt::Expr(expr) = stmt {
70 new_stmts.push(Stmt::Expr(unroll(expr, unroll_by)));
71 } else if let Stmt::Semi(expr, semi) = stmt {
72 new_stmts.push(Stmt::Semi(unroll(expr, unroll_by), *semi));
73 } else {
74 new_stmts.push((*stmt).clone());
75 }
76 }
77 Block {
78 brace_token: *brace_token,
79 stmts: new_stmts,
80 }
81}
82
83fn unroll(expr: &Expr, unroll_by: usize) -> Expr {
86 if let Expr::ForLoop(for_loop) = expr {
88 let ExprForLoop {
89 ref attrs,
90 ref label,
91 ref pat,
92 expr: ref range,
93 ref body,
94 ..
95 } = *for_loop;
96
97 let new_body = unroll_in_block(body, unroll_by);
98
99 let forloop_with_body = |body| {
100 Expr::ForLoop(ExprForLoop {
101 body,
102 ..(*for_loop).clone()
103 })
104 };
105
106 if let Pat::Ident(PatIdent {
107 by_ref,
108 mutability,
109 ident,
110 subpat,
111 ..
112 }) = pat
113 {
114 if !by_ref.is_none() || !mutability.is_none() || !subpat.is_none() {
116 return forloop_with_body(new_body);
117 }
118 let idx = ident; if let Expr::Range(ExprRange {
121 from, limits, to, ..
122 }) = range.borrow()
123 {
124 let begin = match from {
126 Some(e) => e.clone(),
127 _ => Box::new(parse_quote!(0usize)),
128 };
129 let end = match to {
130 Some(e) => e.clone(),
131 _ => return forloop_with_body(new_body),
132 };
133 let end_is_closed = if let RangeLimits::Closed(_) = limits {
134 1usize
135 } else {
136 0
137 };
138 let end: Expr = parse_quote!(#end + #end_is_closed);
139
140 let preamble: Vec<Stmt> = parse_quote! {
141 let total_iters: usize = (#end).checked_sub(#begin).unwrap_or(0);
142 let num_loops = total_iters / #unroll_by;
143 let remainder = total_iters % #unroll_by;
144 };
145 let mut block = Block {
146 brace_token: Brace::default(),
147 stmts: preamble,
148 };
149 let mut loop_expr: ExprForLoop = parse_quote! {
150 for #idx in (0..num_loops) {
151 let mut #idx = #begin + #idx * #unroll_by;
152 }
153 };
154 let loop_block: Vec<Stmt> = parse_quote! {
155 if #idx < #end {
156 #new_body
157 }
158 #idx += 1;
159 };
160 let loop_body = (0..unroll_by).flat_map(|_| loop_block.clone());
161 loop_expr.body.stmts.extend(loop_body);
162 block.stmts.push(Stmt::Expr(Expr::ForLoop(loop_expr)));
163
164 block
166 .stmts
167 .push(parse_quote! { let mut #idx = #begin + num_loops * #unroll_by; });
168 let post_loop_block: Vec<Stmt> = parse_quote! {
170 if #idx < #end {
171 #new_body
172 }
173 #idx += 1;
174 };
175 let post_loop = (0..unroll_by).flat_map(|_| post_loop_block.clone());
176 block.stmts.extend(post_loop);
177
178 let mut attrs = attrs.clone();
179 attrs.extend(vec![parse_quote!(#[allow(unused)])]);
180 Expr::Block(ExprBlock {
181 attrs,
182 label: label.clone(),
183 block,
184 })
185 } else {
186 forloop_with_body(new_body)
187 }
188 } else {
189 forloop_with_body(new_body)
190 }
191 } else if let Expr::If(if_expr) = expr {
192 let ExprIf {
193 ref cond,
194 ref then_branch,
195 ref else_branch,
196 ..
197 } = *if_expr;
198 Expr::If(ExprIf {
199 cond: Box::new(unroll(cond, unroll_by)),
200 then_branch: unroll_in_block(then_branch, unroll_by),
201 else_branch: else_branch
202 .as_ref()
203 .map(|x| (x.0, Box::new(unroll(&x.1, unroll_by)))),
204 ..(*if_expr).clone()
205 })
206 } else if let Expr::Let(let_expr) = expr {
207 let ExprLet { ref expr, .. } = *let_expr;
208 Expr::Let(ExprLet {
209 expr: Box::new(unroll(expr, unroll_by)),
210 ..(*let_expr).clone()
211 })
212 } else if let Expr::Block(expr_block) = expr {
213 let ExprBlock { ref block, .. } = *expr_block;
214 Expr::Block(ExprBlock {
215 block: unroll_in_block(block, unroll_by),
216 ..(*expr_block).clone()
217 })
218 } else {
219 (*expr).clone()
220 }
221}
222
/// Checks the shape of `unroll_in_block`'s output. It covers one unrollable range loop and
/// two loops that must be left as `for` loops; the unrolled code is printed for manual
/// inspection.
#[test]
fn test_expand() {
    use quote::ToTokens;

    let for_loop: Block = parse_quote! {{
        let mut sum = 0;
        for i in 0..8 {
            sum += i;
        }
    }};
    let unrolled = unroll_in_block(&for_loop, 12);
    println!("{}", unrolled.to_token_stream());
    // The `let` is kept as-is and the range loop is replaced by a block expression.
    assert_eq!(unrolled.stmts.len(), 2);
    assert!(matches!(unrolled.stmts[0], Stmt::Local(_)));
    assert!(matches!(unrolled.stmts[1], Stmt::Expr(Expr::Block(_))));

    // Loops that cannot be unrolled stay `for` loops: an open-ended range and a `mut` pattern.
    let not_unrollable: Block = parse_quote! {{
        for i in 0.. {
            break;
        }
        for mut i in 0..8 {
            i += 1;
        }
    }};
    let kept = unroll_in_block(&not_unrollable, 4);
    assert_eq!(kept.stmts.len(), 2);
    assert!(kept
        .stmts
        .iter()
        .all(|stmt| matches!(stmt, Stmt::Expr(Expr::ForLoop(_)))));
}