// halo2derive/field/arith.rs

1use proc_macro2::TokenStream;
2use quote::format_ident as fmtid;
3use quote::quote;
4
5fn select(cond: bool, this: TokenStream, other: TokenStream) -> TokenStream {
6    if cond {
7        this
8    } else {
9        other
10    }
11}
12
13pub(crate) fn impl_arith(field: &syn::Ident, num_limbs: usize, inv: u64) -> TokenStream {
14    let impl_add = impl_add(field, num_limbs);
15    let impl_sub = impl_sub(field, num_limbs);
16    let impl_neg = impl_neg(field, num_limbs);
17    let impl_mont = impl_mont(field, num_limbs, inv);
18    let impl_from_mont = impl_from_mont(field, num_limbs, inv);
19    let impl_mul = impl_mul(field, num_limbs, false);
20    let impl_square = impl_square(field, num_limbs);
21    let wide_num_limbs = num_limbs * 2;
22    quote::quote! {
23        impl #field {
24            #[inline(always)]
25            pub const fn add(&self, rhs: &Self) -> Self {
26                #impl_add
27            }
28
29            #[inline]
30            pub const fn double(&self) -> Self {
31                self.add(self)
32            }
33
34            #[inline(always)]
35            pub const fn sub(&self, rhs: &Self) -> Self {
36                #impl_sub
37            }
38
39            #[inline(always)]
40            pub const fn neg(&self) -> Self {
41                #impl_neg
42            }
43
44            #[inline(always)]
45            pub const fn mul(&self, rhs: &Self) -> Self{
46                #impl_mul
47            }
48
49            #[inline(always)]
50            pub const fn square(&self) -> Self{
51                #impl_square
52            }
53
54            #[inline(always)]
55            pub(crate) const fn montgomery_reduce(r: &[u64; #wide_num_limbs]) -> Self {
56                #impl_mont
57            }
58
59            #[inline(always)]
60            pub(crate) const fn from_mont(&self) -> [u64; #num_limbs] {
61                #impl_from_mont
62            }
63        }
64    }
65}
66
67pub(crate) fn impl_arith_always_const(
68    field: &syn::Ident,
69    num_limbs: usize,
70    inv: u64,
71) -> TokenStream {
72    let impl_sub = impl_sub(field, num_limbs);
73    let impl_mont = impl_mont(field, num_limbs, inv);
74    let impl_mul = impl_mul(field, num_limbs, true);
75    let wide_num_limbs = num_limbs * 2;
76    quote::quote! {
77        impl #field {
78
79            #[inline(always)]
80            pub(crate) const fn sub_const(&self, rhs: &Self) -> Self {
81                #impl_sub
82            }
83
84
85            #[inline(always)]
86            pub(crate) const fn mul_const(&self, rhs: &Self) -> Self{
87                #impl_mul
88            }
89
90            #[inline(always)]
91            pub(crate) const fn montgomery_reduce_const(r: &[u64; #wide_num_limbs]) -> Self {
92                #impl_mont
93            }
94        }
95    }
96}
97
98fn impl_mul(field: &syn::Ident, num_limbs: usize, constant: bool) -> TokenStream {
99    let mut gen = quote! { use crate::arithmetic::{adc, sbb, mac}; };
100    for i in 0..num_limbs {
101        for j in 0..num_limbs {
102            let r_out = fmtid!("r_{}", i + j);
103            let r_next = fmtid!("r_{}", i + j + 1);
104            let r_in = select(i == 0, quote! {0}, quote! {#r_out});
105            let carry_in = select(j == 0, quote! {0}, quote! {carry});
106            let carry_out = select(j == num_limbs - 1, quote! {#r_next}, quote! {carry});
107            gen.extend(
108                quote! { let (#r_out, #carry_out) = mac(#r_in, self.0[#i], rhs.0[#j], #carry_in); },
109            );
110        }
111    }
112
113    let r: Vec<_> = (0..num_limbs * 2).map(|i| fmtid!("r_{}", i)).collect();
114    let mont_red = if constant {
115        quote! { #field::montgomery_reduce_const(&[#(#r),*]) }
116    } else {
117        quote! { #field::montgomery_reduce(&[#(#r),*]) }
118    };
119    quote! {
120        #gen
121        #mont_red
122    }
123}
124
125fn impl_square(field: &syn::Ident, num_limbs: usize) -> TokenStream {
126    let mut gen = quote! { use crate::arithmetic::{adc, sbb, mac}; };
127    for i in 0..num_limbs - 1 {
128        let start_index = i * 2 + 1;
129        for j in 0..num_limbs - i - 1 {
130            let r_out = fmtid!("r_{}", start_index + j);
131            let r_in = select(i == 0, quote! {0}, quote! {#r_out});
132            let r_next = fmtid!("r_{}", start_index + j + 1);
133            let carry_in = select(j == 0, quote! {0}, quote! {carry});
134            let carry_out = select(j == num_limbs - i - 2, quote! {#r_next}, quote! {carry});
135            let j = i + j + 1;
136            gen.extend(quote! { let (#r_out, #carry_out) = mac(#r_in, self.0[#i], self.0[#j], #carry_in); });
137        }
138    }
139
140    for i in (1..num_limbs * 2).rev() {
141        let (r_cur, r_next) = (fmtid!("r_{}", i), fmtid!("r_{}", i - 1));
142        if i == num_limbs * 2 - 1 {
143            gen.extend(quote! { let #r_cur = #r_next >> 63; });
144        } else if i == 1 {
145            gen.extend(quote! { let #r_cur = (#r_cur << 1); });
146        } else {
147            gen.extend(quote! { let #r_cur = (#r_cur << 1) | (#r_next >> 63); });
148        }
149    }
150
151    for i in 0..num_limbs {
152        let index = i * 2;
153        let r_cur = fmtid!("r_{}", index);
154        let r_next = fmtid!("r_{}", index + 1);
155        let r_cur_in = select(i == 0, quote! {0}, quote! {#r_cur});
156        let carry_in = select(i == 0, quote! {0}, quote! {carry});
157        let carry_out = select(i == num_limbs - 1, quote! {_}, quote! {carry});
158        gen.extend(quote! {
159            let (#r_cur, carry) = mac(#r_cur_in, self.0[#i], self.0[#i], #carry_in);
160            let (#r_next, #carry_out) = adc(0, #r_next, carry);
161        });
162    }
163
164    let r: Vec<_> = (0..num_limbs * 2).map(|i| fmtid!("r_{}", i)).collect();
165    quote! {
166        #gen
167        #field::montgomery_reduce(&[#(#r),*])
168    }
169}
170
171fn impl_add(field: &syn::Ident, num_limbs: usize) -> TokenStream {
172    let mut gen = quote! { use crate::arithmetic::{adc, sbb}; };
173
174    (0..num_limbs).for_each(|i| {
175        let carry = select(i == 0, quote! {0}, quote! {carry});
176        let d_i = fmtid!("d_{}", i);
177        gen.extend(quote! { let ( #d_i, carry) = adc(self.0[#i], rhs.0[#i], #carry); });
178    });
179
180    // Attempt to subtract the modulus, to ensure the value
181    // is smaller than the modulus.
182    (0..num_limbs).for_each(|i| {
183        let borrow = select(i == 0, quote! {0}, quote! {borrow});
184        let d_i = fmtid!("d_{}", i);
185        gen.extend(quote! { let (#d_i, borrow) = sbb(#d_i, Self::MODULUS_LIMBS[#i], #borrow); });
186    });
187    gen.extend(quote! {let (_, borrow) = sbb(carry, 0, borrow);});
188
189    (0..num_limbs).for_each(|i| {
190        let carry_in = select(i == 0, quote! {0}, quote! {carry});
191        let carry_out = select(i == num_limbs - 1, quote! {_}, quote! {carry});
192        let d_i = fmtid!("d_{}", i);
193        gen.extend(
194            quote! { let (#d_i, #carry_out) = adc(#d_i, Self::MODULUS_LIMBS[#i] & borrow, #carry_in); },
195        );
196    });
197
198    let ret: Vec<_> = (0..num_limbs).map(|i| fmtid!("d_{}", i)).collect();
199
200    quote! {
201        #gen
202        #field([#(#ret),*])
203    }
204}
205
206fn impl_sub(field: &syn::Ident, num_limbs: usize) -> TokenStream {
207    let mut gen = quote! { use crate::arithmetic::{adc, sbb}; };
208
209    (0..num_limbs).for_each(|i| {
210        let borrow = select(i == 0, quote! {0}, quote! {borrow});
211        let d_i = fmtid!("d_{}", i);
212        gen.extend(quote! { let (#d_i, borrow) = sbb(self.0[#i], rhs.0[#i], #borrow); });
213    });
214
215    (0..num_limbs).for_each(|i| {
216        let carry_in = select(i == 0, quote! {0}, quote! {carry});
217        let carry_out = select(i == num_limbs - 1, quote! {_}, quote! {carry});
218        let d_i = fmtid!("d_{}", i);
219        gen.extend(
220            quote! { let (#d_i, #carry_out) = adc(#d_i, Self::MODULUS_LIMBS[#i] & borrow, #carry_in); },
221        );
222    });
223
224    let ret: Vec<_> = (0..num_limbs).map(|i| fmtid!("d_{}", i)).collect();
225
226    quote! {
227        #gen
228        #field([#(#ret),*])
229    }
230}
231
232fn impl_neg(field: &syn::Ident, num_limbs: usize) -> TokenStream {
233    let mut gen = quote! { use crate::arithmetic::{adc, sbb}; };
234
235    (0..num_limbs).for_each(|i| {
236        let borrow_in = select(i == 0, quote! {0}, quote! {borrow});
237        let borrow_out = select(i == num_limbs - 1, quote! {_}, quote! {borrow});
238        let d_i = fmtid!("d_{}", i);
239        gen.extend(quote! { let (#d_i, #borrow_out) = sbb(Self::MODULUS_LIMBS[#i], self.0[#i], #borrow_in); })
240    });
241
242    let mask_limbs: Vec<_> = (0..num_limbs)
243        .map(|i| quote::quote! { self.0[#i] })
244        .collect();
245    gen.extend(quote! { let mask = (((#(#mask_limbs)|*) == 0) as u64).wrapping_sub(1); });
246
247    let ret: Vec<_> = (0..num_limbs)
248        .map(|i| {
249            let d_i = fmtid!("d_{}", i);
250            quote! { #d_i & mask }
251        })
252        .collect();
253
254    quote! {
255        #gen
256        #field([#(#ret),*])
257    }
258}
259
260fn impl_mont(field: &syn::Ident, num_limbs: usize, inv: u64) -> TokenStream {
261    let mut gen = quote! { use crate::arithmetic::{adc, sbb, mac}; };
262
263    for i in 0..num_limbs {
264        if i == 0 {
265            gen.extend(quote! { let k = r[0].wrapping_mul(#inv); });
266
267            for j in 0..num_limbs {
268                let r_out = fmtid!("r_{}", j);
269                let r_out = select(j == 0, quote! {_}, quote! {#r_out});
270                let carry_in = select(j == 0, quote! {0}, quote! {carry});
271                gen.extend(quote! { let (#r_out, carry) = mac(r[#j], k, Self::MODULUS_LIMBS[#j], #carry_in); });
272            }
273            let r_out = fmtid!("r_{}", num_limbs);
274            gen.extend(quote! { let (#r_out, carry2) = adc(r[#num_limbs], 0, carry); });
275        } else {
276            let r_i = fmtid!("r_{}", i);
277            gen.extend(quote! { let k = #r_i.wrapping_mul(#inv); });
278
279            for j in 0..num_limbs {
280                let r_in = fmtid!("r_{}", j + i);
281                let r_out = select(j == 0, quote! {_}, quote! {#r_in});
282                let carry_in = select(j == 0, quote! {0}, quote! {carry});
283                gen.extend(quote! { let (#r_out, carry) = mac(#r_in, k, Self::MODULUS_LIMBS[#j], #carry_in); });
284            }
285            let idx = num_limbs + i;
286            let r_out = fmtid!("r_{}", idx);
287            gen.extend(quote! { let (#r_out, carry2) = adc(r[#idx], carry2, carry); });
288        }
289    }
290
291    (0..num_limbs).for_each(|i| {
292        let borrow = select(i == 0, quote! {0}, quote! {borrow});
293        let d_i = fmtid!("d_{}", i);
294        let r_in = fmtid!("r_{}", num_limbs + i);
295        gen.extend(quote! { let (#d_i, borrow) = sbb(#r_in, Self::MODULUS_LIMBS[#i], #borrow); });
296    });
297
298    gen.extend(quote! {let (_, borrow) = sbb(carry2, 0, borrow);});
299
300    (0..num_limbs).for_each(|i| {
301        let carry_in = select(i == 0, quote! {0}, quote! {carry});
302        let carry_out = select(i == num_limbs - 1, quote! {_}, quote! {carry});
303        let d_i = fmtid!("d_{}", i);
304        gen.extend(
305            quote! { let (#d_i, #carry_out) = adc(#d_i, Self::MODULUS_LIMBS[#i] & borrow, #carry_in); },
306        );
307    });
308    let ret: Vec<_> = (0..num_limbs).map(|i| fmtid!("d_{}", i)).collect();
309
310    quote! {
311        #gen
312        #field([#(#ret),*])
313    }
314}
315
316fn impl_from_mont(field: &syn::Ident, num_limbs: usize, inv: u64) -> TokenStream {
317    let mut gen = quote! { use crate::arithmetic::{adc, sbb, mac}; };
318
319    for i in 0..num_limbs {
320        let r_i = fmtid!("r_{}", i);
321        if i == 0 {
322            gen.extend(quote! { let k = self.0[0].wrapping_mul(#inv); });
323        } else {
324            gen.extend(quote! { let k = #r_i.wrapping_mul(#inv); });
325        }
326
327        for j in 0..num_limbs {
328            let r_ij = fmtid!("r_{}", (j + i) % num_limbs);
329            let r_out = select(j == 0, quote! {_}, quote! {#r_ij});
330            let r_ij = select(i == 0, quote! {self.0[#j]}, quote! {#r_ij});
331            let carry_in = select(j == 0, quote! {0}, quote! {#r_i});
332            gen.extend(
333                quote! { let (#r_out, #r_i) = mac(#r_ij, k, Self::MODULUS_LIMBS[#j], #carry_in); },
334            );
335        }
336    }
337    let ret: Vec<_> = (0..num_limbs).map(|i| fmtid!("r_{}", i)).collect();
338    quote! {
339        #gen
340        #field([#(#ret),*]).sub(&#field(Self::MODULUS_LIMBS)).0
341    }
342}