halo2_ecc/fields/vector.rs

use halo2_base::{
    gates::GateInstructions,
    utils::{BigPrimeField, ScalarField},
    AssignedValue, Context,
};
use itertools::Itertools;
use std::{
    marker::PhantomData,
    ops::{Index, IndexMut},
};

use crate::bigint::{CRTInteger, ProperCrtUint};

use super::{fp::Reduced, FieldChip, FieldExtConstructor, PrimeFieldChip, Selectable};

/// A fixed length vector of `FieldPoint`s
#[repr(transparent)]
#[derive(Clone, Debug)]
pub struct FieldVector<T>(pub Vec<T>);

impl<T> Index<usize> for FieldVector<T> {
    type Output = T;

    fn index(&self, index: usize) -> &Self::Output {
        &self.0[index]
    }
}

impl<T> IndexMut<usize> for FieldVector<T> {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.0[index]
    }
}

impl<T> AsRef<[T]> for FieldVector<T> {
    fn as_ref(&self) -> &[T] {
        &self.0
    }
}

impl<'a, T: Clone, U: From<T>> From<&'a FieldVector<T>> for FieldVector<U> {
    fn from(other: &'a FieldVector<T>) -> Self {
        FieldVector(other.clone().into_iter().map(Into::into).collect())
    }
}

impl<F: ScalarField> From<FieldVector<ProperCrtUint<F>>> for FieldVector<CRTInteger<F>> {
    fn from(other: FieldVector<ProperCrtUint<F>>) -> Self {
        FieldVector(other.into_iter().map(|x| x.0).collect())
    }
}

impl<T, Fp> From<FieldVector<Reduced<T, Fp>>> for FieldVector<T> {
    fn from(value: FieldVector<Reduced<T, Fp>>) -> Self {
        FieldVector(value.0.into_iter().map(|x| x.0).collect())
    }
}

impl<T> IntoIterator for FieldVector<T> {
    type Item = T;
    type IntoIter = std::vec::IntoIter<T>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

/// Contains common functionality for vector operations that can be derived from those of the underlying `FpChip`
#[derive(Clone, Copy, Debug)]
pub struct FieldVectorChip<'fp, F: BigPrimeField, FpChip: FieldChip<F>> {
    pub fp_chip: &'fp FpChip,
    _f: PhantomData<F>,
}

impl<'fp, F, FpChip> FieldVectorChip<'fp, F, FpChip>
where
    F: BigPrimeField,
    FpChip: PrimeFieldChip<F>,
    FpChip::FieldType: BigPrimeField,
{
    pub fn new(fp_chip: &'fp FpChip) -> Self {
        Self { fp_chip, _f: PhantomData }
    }

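    /// Returns the `GateInstructions` of the underlying `fp_chip`.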
    pub fn gate(&self) -> &impl GateInstructions<F> {
        self.fp_chip.gate()
    }

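    /// Multiplies each coefficient of `a` by the same field point `fp_point`, without
    /// carrying. The outputs are unreduced `UnsafeFieldPoint`s.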
    pub fn fp_mul_no_carry<FP>(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = FP>,
        fp_point: impl Into<FpChip::UnsafeFieldPoint>,
    ) -> FieldVector<FpChip::UnsafeFieldPoint>
    where
        FP: Into<FpChip::UnsafeFieldPoint>,
    {
        let fp_point = fp_point.into();
        FieldVector(
            a.into_iter().map(|a| self.fp_chip.mul_no_carry(ctx, a, fp_point.clone())).collect(),
        )
    }

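    /// Coefficient-wise select: returns `a` if `sel` is 1 and `b` if `sel` is 0, per the
    /// underlying `FpChip`'s `select`.
    ///
    /// # Assumptions
    /// * `a` and `b` have the same length (otherwise `zip_eq` panics).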
    pub fn select<FP>(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = FP>,
        b: impl IntoIterator<Item = FP>,
        sel: AssignedValue<F>,
    ) -> FieldVector<FP>
    where
        FpChip: Selectable<F, FP>,
    {
        FieldVector(
            a.into_iter().zip_eq(b).map(|(a, b)| self.fp_chip.select(ctx, a, b, sel)).collect(),
        )
    }

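    /// Loads the coefficients of the field extension element `fe` as private witnesses,
    /// one `FieldPoint` per coefficient.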
    pub fn load_private<FieldExt, const DEGREE: usize>(
        &self,
        ctx: &mut Context<F>,
        fe: FieldExt,
    ) -> FieldVector<FpChip::FieldPoint>
    where
        FieldExt: FieldExtConstructor<FpChip::FieldType, DEGREE>,
    {
        FieldVector(fe.coeffs().into_iter().map(|a| self.fp_chip.load_private(ctx, a)).collect())
    }

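    /// Loads the coefficients of the field extension element `c` as constants.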
    pub fn load_constant<FieldExt, const DEGREE: usize>(
        &self,
        ctx: &mut Context<F>,
        c: FieldExt,
    ) -> FieldVector<FpChip::FieldPoint>
    where
        FieldExt: FieldExtConstructor<FpChip::FieldType, DEGREE>,
    {
        FieldVector(c.coeffs().into_iter().map(|a| self.fp_chip.load_constant(ctx, a)).collect())
    }

    // signed overflow BigInt functions
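    /// Coefficient-wise addition `a + b`, without carrying.
    ///
    /// # Assumptions
    /// * `a` and `b` have the same length (otherwise `zip_eq` panics).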
    pub fn add_no_carry<A, B>(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = A>,
        b: impl IntoIterator<Item = B>,
    ) -> FieldVector<FpChip::UnsafeFieldPoint>
    where
        A: Into<FpChip::UnsafeFieldPoint>,
        B: Into<FpChip::UnsafeFieldPoint>,
    {
        FieldVector(
            a.into_iter().zip_eq(b).map(|(a, b)| self.fp_chip.add_no_carry(ctx, a, b)).collect(),
        )
    }

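    /// Adds the constant field extension element `c` to `a` coefficient-wise, without
    /// carrying.
    ///
    /// # Assumptions
    /// * `a` has exactly `DEGREE` entries, matching `c.coeffs()` (otherwise `zip_eq` panics).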
    pub fn add_constant_no_carry<A, FieldExt, const DEGREE: usize>(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = A>,
        c: FieldExt,
    ) -> FieldVector<FpChip::UnsafeFieldPoint>
    where
        A: Into<FpChip::UnsafeFieldPoint>,
        FieldExt: FieldExtConstructor<FpChip::FieldType, DEGREE>,
    {
        let c_coeffs = c.coeffs();
        FieldVector(
            a.into_iter()
                .zip_eq(c_coeffs)
                .map(|(a, c)| self.fp_chip.add_constant_no_carry(ctx, a, c))
                .collect(),
        )
    }

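    /// Coefficient-wise subtraction `a - b`, without carrying.
    ///
    /// # Assumptions
    /// * `a` and `b` have the same length (otherwise `zip_eq` panics).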
    pub fn sub_no_carry<A, B>(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = A>,
        b: impl IntoIterator<Item = B>,
    ) -> FieldVector<FpChip::UnsafeFieldPoint>
    where
        A: Into<FpChip::UnsafeFieldPoint>,
        B: Into<FpChip::UnsafeFieldPoint>,
    {
        FieldVector(
            a.into_iter().zip_eq(b).map(|(a, b)| self.fp_chip.sub_no_carry(ctx, a, b)).collect(),
        )
    }

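    /// Negates each coefficient of `a`.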
    pub fn negate(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = FpChip::FieldPoint>,
    ) -> FieldVector<FpChip::FieldPoint> {
        FieldVector(a.into_iter().map(|a| self.fp_chip.negate(ctx, a)).collect())
    }

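    /// Multiplies each coefficient of `a` by the signed constant `c`, without carrying.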
    pub fn scalar_mul_no_carry<A>(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = A>,
        c: i64,
    ) -> FieldVector<FpChip::UnsafeFieldPoint>
    where
        A: Into<FpChip::UnsafeFieldPoint>,
    {
        FieldVector(a.into_iter().map(|a| self.fp_chip.scalar_mul_no_carry(ctx, a, c)).collect())
    }

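    /// Coefficient-wise computation of `a * c + b` for a signed constant `c`, without
    /// carrying (per the underlying `FpChip`'s `scalar_mul_and_add_no_carry`).
    ///
    /// # Assumptions
    /// * `a` and `b` have the same length (otherwise `zip_eq` panics).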
    pub fn scalar_mul_and_add_no_carry<A, B>(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = A>,
        b: impl IntoIterator<Item = B>,
        c: i64,
    ) -> FieldVector<FpChip::UnsafeFieldPoint>
    where
        A: Into<FpChip::UnsafeFieldPoint>,
        B: Into<FpChip::UnsafeFieldPoint>,
    {
        FieldVector(
            a.into_iter()
                .zip_eq(b)
                .map(|(a, b)| self.fp_chip.scalar_mul_and_add_no_carry(ctx, a, b, c))
                .collect(),
        )
    }

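    /// Constrains each coefficient of `a` to carry to zero modulo the `FpChip` modulus.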
    pub fn check_carry_mod_to_zero(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = FpChip::UnsafeFieldPoint>,
    ) {
        for coeff in a {
            self.fp_chip.check_carry_mod_to_zero(ctx, coeff);
        }
    }

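    /// Carries each coefficient of `a` modulo the `FpChip` modulus, converting
    /// `UnsafeFieldPoint`s into proper `FieldPoint`s.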
    pub fn carry_mod(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = FpChip::UnsafeFieldPoint>,
    ) -> FieldVector<FpChip::FieldPoint> {
        FieldVector(a.into_iter().map(|coeff| self.fp_chip.carry_mod(ctx, coeff)).collect())
    }

    /// # Assumptions
    /// * `max_bits <= n * k` where `n = self.fp_chip.limb_bits` and `k = self.fp_chip.num_limbs`
    /// * `a[i].truncation.limbs.len() = self.fp_chip.num_limbs` for all `i = 0..a.len()`
    pub fn range_check<A>(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = A>,
        max_bits: usize,
    ) where
        A: Into<FpChip::FieldPoint>,
    {
        for coeff in a {
            self.fp_chip.range_check(ctx, coeff, max_bits);
        }
    }

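    /// Enforces that each coefficient of `a` is less than the `FpChip` modulus, returning
    /// `ReducedFieldPoint`s.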
    pub fn enforce_less_than(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = FpChip::FieldPoint>,
    ) -> FieldVector<FpChip::ReducedFieldPoint> {
        FieldVector(a.into_iter().map(|coeff| self.fp_chip.enforce_less_than(ctx, coeff)).collect())
    }

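    /// Returns the AND of `fp_chip.is_soft_zero` over all coefficients of `a`, i.e., 1
    /// exactly when every coefficient is (soft) zero.
    ///
    /// # Assumptions
    /// * `a` is nonempty (otherwise the final `unwrap` panics).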
    pub fn is_soft_zero(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = FpChip::FieldPoint>,
    ) -> AssignedValue<F> {
        let mut prev = None;
        for a_coeff in a {
            let coeff = self.fp_chip.is_soft_zero(ctx, a_coeff);
            if let Some(p) = prev {
                let new = self.gate().and(ctx, coeff, p);
                prev = Some(new);
            } else {
                prev = Some(coeff);
            }
        }
        prev.unwrap()
    }

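    /// Returns the OR of `fp_chip.is_soft_nonzero` over all coefficients of `a`, i.e., 1
    /// exactly when at least one coefficient is (soft) nonzero.
    ///
    /// # Assumptions
    /// * `a` is nonempty (otherwise the final `unwrap` panics).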
    pub fn is_soft_nonzero(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = FpChip::FieldPoint>,
    ) -> AssignedValue<F> {
        let mut prev = None;
        for a_coeff in a {
            let coeff = self.fp_chip.is_soft_nonzero(ctx, a_coeff);
            if let Some(p) = prev {
                let new = self.gate().or(ctx, coeff, p);
                prev = Some(new);
            } else {
                prev = Some(coeff);
            }
        }
        prev.unwrap()
    }

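    /// Returns the AND of `fp_chip.is_zero` over all coefficients of `a`, i.e., 1 exactly
    /// when every coefficient is zero.
    ///
    /// # Assumptions
    /// * `a` is nonempty (otherwise the final `unwrap` panics).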
    pub fn is_zero(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = FpChip::FieldPoint>,
    ) -> AssignedValue<F> {
        let mut prev = None;
        for a_coeff in a {
            let coeff = self.fp_chip.is_zero(ctx, a_coeff);
            if let Some(p) = prev {
                let new = self.gate().and(ctx, coeff, p);
                prev = Some(new);
            } else {
                prev = Some(coeff);
            }
        }
        prev.unwrap()
    }

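    /// Returns the AND of `fp_chip.is_equal_unenforced` over corresponding coefficients of
    /// `a` and `b`, i.e., 1 exactly when the two vectors are equal coefficient-wise.
    ///
    /// # Assumptions
    /// * `a` and `b` are nonempty and of the same length (otherwise `zip_eq` or the final
    ///   `unwrap` panics).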
    pub fn is_equal_unenforced(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = FpChip::ReducedFieldPoint>,
        b: impl IntoIterator<Item = FpChip::ReducedFieldPoint>,
    ) -> AssignedValue<F> {
        let mut acc = None;
        for (a_coeff, b_coeff) in a.into_iter().zip_eq(b) {
            let coeff = self.fp_chip.is_equal_unenforced(ctx, a_coeff, b_coeff);
            if let Some(c) = acc {
                acc = Some(self.gate().and(ctx, coeff, c));
            } else {
                acc = Some(coeff);
            }
        }
        acc.unwrap()
    }

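    /// Constrains `a` and `b` to be equal coefficient-wise.
    ///
    /// # Assumptions
    /// * `a` and `b` have the same length (otherwise `zip_eq` panics).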
    pub fn assert_equal(
        &self,
        ctx: &mut Context<F>,
        a: impl IntoIterator<Item = FpChip::FieldPoint>,
        b: impl IntoIterator<Item = FpChip::FieldPoint>,
    ) {
        // Use `zip_eq` (matching every other binary method above) so a length mismatch
        // panics instead of silently skipping the extra coefficients' constraints.
        for (a_coeff, b_coeff) in a.into_iter().zip_eq(b) {
            self.fp_chip.assert_equal(ctx, a_coeff, b_coeff)
        }
    }
}

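/// Implements the coefficient-wise subset of the `FieldChip` trait methods for a field
/// extension chip by forwarding each one to an inner `FieldVectorChip` stored as field `0`
/// of a tuple struct.
///
/// A minimal sketch of the intended call site (illustrative only; the extension-field
/// chips in this crate follow this newtype-plus-macro pattern, with trait bounds elided
/// here):
///
/// ```ignore
/// pub struct Fp2Chip<'fp, F: BigPrimeField, FpChip: FieldChip<F>, Fp2>(
///     pub FieldVectorChip<'fp, F, FpChip>,
///     PhantomData<Fp2>,
/// );
///
/// impl<'fp, F, FpChip, Fp2> FieldChip<F> for Fp2Chip<'fp, F, FpChip, Fp2>
/// /* where bounds elided */
/// {
///     // Extension-specific methods (e.g. `mul_no_carry`) are written by hand;
///     // the coefficient-wise ones are generated here:
///     impl_field_ext_chip_common!();
/// }
/// ```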
#[macro_export]
macro_rules! impl_field_ext_chip_common {
    // Implementation of the functions in `FieldChip` trait for field extensions that can be derived from `FieldVectorChip`
    () => {
        fn native_modulus(&self) -> &BigUint {
            self.0.fp_chip.native_modulus()
        }

        fn range(&self) -> &Self::RangeChip {
            self.0.fp_chip.range()
        }

        fn limb_bits(&self) -> usize {
            self.0.fp_chip.limb_bits()
        }

        fn load_private(&self, ctx: &mut Context<F>, fe: Self::FieldType) -> Self::FieldPoint {
            self.0.load_private(ctx, fe)
        }

        fn load_constant(&self, ctx: &mut Context<F>, fe: Self::FieldType) -> Self::FieldPoint {
            self.0.load_constant(ctx, fe)
        }

        fn add_no_carry(
            &self,
            ctx: &mut Context<F>,
            a: impl Into<Self::UnsafeFieldPoint>,
            b: impl Into<Self::UnsafeFieldPoint>,
        ) -> Self::UnsafeFieldPoint {
            self.0.add_no_carry(ctx, a.into(), b.into())
        }

        fn add_constant_no_carry(
            &self,
            ctx: &mut Context<F>,
            a: impl Into<Self::UnsafeFieldPoint>,
            c: Self::FieldType,
        ) -> Self::UnsafeFieldPoint {
            self.0.add_constant_no_carry(ctx, a.into(), c)
        }

        fn sub_no_carry(
            &self,
            ctx: &mut Context<F>,
            a: impl Into<Self::UnsafeFieldPoint>,
            b: impl Into<Self::UnsafeFieldPoint>,
        ) -> Self::UnsafeFieldPoint {
            self.0.sub_no_carry(ctx, a.into(), b.into())
        }

        fn negate(&self, ctx: &mut Context<F>, a: Self::FieldPoint) -> Self::FieldPoint {
            self.0.negate(ctx, a)
        }

        fn scalar_mul_no_carry(
            &self,
            ctx: &mut Context<F>,
            a: impl Into<Self::UnsafeFieldPoint>,
            c: i64,
        ) -> Self::UnsafeFieldPoint {
            self.0.scalar_mul_no_carry(ctx, a.into(), c)
        }

        fn scalar_mul_and_add_no_carry(
            &self,
            ctx: &mut Context<F>,
            a: impl Into<Self::UnsafeFieldPoint>,
            b: impl Into<Self::UnsafeFieldPoint>,
            c: i64,
        ) -> Self::UnsafeFieldPoint {
            self.0.scalar_mul_and_add_no_carry(ctx, a.into(), b.into(), c)
        }

        fn check_carry_mod_to_zero(&self, ctx: &mut Context<F>, a: Self::UnsafeFieldPoint) {
            self.0.check_carry_mod_to_zero(ctx, a);
        }

        fn carry_mod(&self, ctx: &mut Context<F>, a: Self::UnsafeFieldPoint) -> Self::FieldPoint {
            self.0.carry_mod(ctx, a)
        }

        /// # Assumptions
        /// * `max_bits <= n * k` where `n = self.0.fp_chip.limb_bits` and `k = self.0.fp_chip.num_limbs`
        /// * `a[i].truncation.limbs.len() = self.0.fp_chip.num_limbs` for all `i = 0..a.len()`
        fn range_check(
            &self,
            ctx: &mut Context<F>,
            a: impl Into<Self::FieldPoint>,
            max_bits: usize,
        ) {
            self.0.range_check(ctx, a.into(), max_bits)
        }

        fn enforce_less_than(
            &self,
            ctx: &mut Context<F>,
            a: Self::FieldPoint,
        ) -> Self::ReducedFieldPoint {
            self.0.enforce_less_than(ctx, a)
        }

        fn is_soft_zero(
            &self,
            ctx: &mut Context<F>,
            a: impl Into<Self::FieldPoint>,
        ) -> AssignedValue<F> {
            let a = a.into();
            self.0.is_soft_zero(ctx, a)
        }

        fn is_soft_nonzero(
            &self,
            ctx: &mut Context<F>,
            a: impl Into<Self::FieldPoint>,
        ) -> AssignedValue<F> {
            let a = a.into();
            self.0.is_soft_nonzero(ctx, a)
        }

        fn is_zero(
            &self,
            ctx: &mut Context<F>,
            a: impl Into<Self::FieldPoint>,
        ) -> AssignedValue<F> {
            let a = a.into();
            self.0.is_zero(ctx, a)
        }

        fn is_equal_unenforced(
            &self,
            ctx: &mut Context<F>,
            a: Self::ReducedFieldPoint,
            b: Self::ReducedFieldPoint,
        ) -> AssignedValue<F> {
            self.0.is_equal_unenforced(ctx, a, b)
        }

        fn assert_equal(
            &self,
            ctx: &mut Context<F>,
            a: impl Into<Self::FieldPoint>,
            b: impl Into<Self::FieldPoint>,
        ) {
            let a = a.into();
            let b = b.into();
            self.0.assert_equal(ctx, a, b)
        }
    };
}