openvm_native_compiler/constraints/halo2/
compiler.rs

1use std::{
2    collections::HashMap,
3    fmt::Debug,
4    marker::PhantomData,
5    panic::{catch_unwind, AssertUnwindSafe},
6    sync::{Arc, LazyLock},
7};
8
9use itertools::Itertools;
10use num_bigint::BigUint;
11#[cfg(feature = "metrics")]
12use openvm_circuit::metrics::cycle_tracker::CycleTracker;
13use openvm_stark_backend::p3_field::{
14    ExtensionField, Field, PrimeCharacteristicRing, PrimeField, PrimeField32,
15};
16use openvm_stark_sdk::{p3_baby_bear::BabyBear, p3_bn254::Bn254};
17use snark_verifier_sdk::snark_verifier::{
18    halo2_base::{
19        gates::{
20            circuit::builder::BaseCircuitBuilder, GateChip, GateInstructions, RangeChip,
21            RangeInstructions,
22        },
23        halo2_proofs::halo2curves::bn256::Fr,
24        utils::{biguint_to_fe, decompose_fe_to_u64_limbs, fe_to_biguint, ScalarField},
25        AssignedValue, Context, QuantumCell,
26    },
27    util::arithmetic::{Field as _, PrimeField as _},
28};
29
30use super::stats::Halo2Stats;
31use crate::{
32    constraints::halo2::{
33        baby_bear::{
34            AssignedBabyBear, AssignedBabyBearExt4, BabyBearChip, BabyBearExt4, BabyBearExt4Chip,
35        },
36        poseidon2_perm::{Poseidon2Params, Poseidon2State},
37    },
38    ir::{Config, DslIr, TracedVec, Witness},
39};
40
41const POSEIDON2_T: usize = 3;
42static POSEIDON2_PARAMS: LazyLock<Poseidon2Params<Fr, POSEIDON2_T>> = LazyLock::new(|| {
43    use zkhash::{
44        ark_ff::{BigInteger, PrimeField as _},
45        fields::bn256::FpBN256 as ark_FpBN256,
46        poseidon2::poseidon2_instance_bn256::{MAT_DIAG3_M_1, RC3},
47    };
48
49    fn convert_fr(input: ark_FpBN256) -> Fr {
50        Fr::from_bytes_le(&input.into_bigint().to_bytes_le())
51    }
52    const T: usize = 3;
53    let rounds_f = 8;
54    let rounds_p = 56;
55    let mut round_constants: Vec<[Fr; T]> = RC3
56        .iter()
57        .map(|vec| {
58            vec.iter()
59                .cloned()
60                .map(convert_fr)
61                .collect::<Vec<_>>()
62                .try_into()
63                .unwrap()
64        })
65        .collect();
66
67    let rounds_f_beginning = rounds_f / 2;
68    let p_end = rounds_f_beginning + rounds_p;
69    let internal_round_constants = round_constants
70        .drain(rounds_f_beginning..p_end)
71        .map(|vec| vec[0])
72        .collect::<Vec<_>>();
73    let external_round_constants = round_constants;
74    Poseidon2Params {
75        rounds_f,
76        rounds_p,
77        mat_internal_diag_m_1: MAT_DIAG3_M_1
78            .iter()
79            .copied()
80            .map(convert_fr)
81            .collect_vec()
82            .try_into()
83            .unwrap(),
84        external_rc: external_round_constants,
85        internal_rc: internal_round_constants,
86    }
87});
88
89/// The backend for the Halo2 constraint compiler.
90#[derive(Debug, Clone)]
91pub struct Halo2ConstraintCompiler<C: Config> {
92    pub num_public_values: usize,
93    #[allow(unused_variables)]
94    pub profiling: bool,
95    pub phantom: PhantomData<C>,
96}
97
98#[derive(Debug, Clone, Default)]
99pub struct Halo2State<C: Config> {
100    // halo2 stuff
101    pub builder: BaseCircuitBuilder<Fr>,
102    // Unassigned values: provided by the prover outside of constraint compiler
103    // map from name as string to halo2 assigned value
104    pub vars: HashMap<u32, Fr>,
105    pub felts: HashMap<u32, C::F>,
106    pub exts: HashMap<u32, C::EF>,
107}
108
109impl<C: Config> Halo2State<C> {
110    pub fn load_witness(&mut self, witness: Witness<C>) {
111        for (i, x) in witness.vars.iter().enumerate() {
112            self.vars.insert(i as u32, convert_fr(x));
113        }
114        for (i, x) in witness.felts.into_iter().enumerate() {
115            self.felts.insert(i as u32, x);
116        }
117        for (i, x) in witness.exts.into_iter().enumerate() {
118            self.exts.insert(i as u32, x);
119        }
120    }
121}
122
123impl<C: Config + Debug> Halo2ConstraintCompiler<C> {
124    pub fn new(num_public_values: usize) -> Self {
125        Self {
126            num_public_values,
127            profiling: false,
128            phantom: PhantomData,
129        }
130    }
131    pub fn with_profiling(mut self) -> Self {
132        self.profiling = true;
133        self
134    }
135    // Create halo2-lib constraints from a list of operations in the DSL.
136    // Assume: C::N = C::F = C::EF is type Fr
137    pub fn constrain_halo2(&self, halo2_state: &mut Halo2State<C>, operations: TracedVec<DslIr<C>>)
138    where
139        C: Config<N = Bn254, F = BabyBear, EF = BabyBearExt4>,
140    {
141        #[cfg(feature = "metrics")]
142        let mut cell_tracker = CycleTracker::new();
143        let range = Arc::new(halo2_state.builder.range_chip());
144        let f_chip = Arc::new(BabyBearChip::new(range.clone()));
145        let ext_chip = BabyBearExt4Chip::new(Arc::clone(&f_chip));
146        let gate = f_chip.gate();
147        let ctx = halo2_state.builder.main(0);
148        let mut public_values = vec![ctx.load_zero(); self.num_public_values];
149
150        // Local variables for referencing during the course of constraint building
151        let mut vars = HashMap::new();
152        let mut felts = HashMap::<u32, AssignedBabyBear>::new();
153        let mut exts = HashMap::<u32, AssignedBabyBearExt4>::new();
154
155        #[cfg(feature = "metrics")]
156        let mut old_stats = stats_snapshot(ctx, range.clone());
157        for (instruction, backtrace) in operations {
158            #[cfg(feature = "metrics")]
159            if self.profiling {
160                old_stats = stats_snapshot(ctx, range.clone());
161            }
162            let res = catch_unwind(AssertUnwindSafe(|| {
163                match instruction {
164                    DslIr::ImmV(a, b) => {
165                        let x = ctx.load_constant(convert_fr(&b));
166                        vars.insert(a.0, x);
167                    }
168                    DslIr::ImmF(a, b) => {
169                        let x = f_chip.load_constant(ctx, b);
170                        felts.insert(a.0, x);
171                    }
172                    DslIr::ImmE(a, b) => {
173                        let x = ext_chip.load_constant(ctx, b);
174                        exts.insert(a.0, x);
175                    }
176                    DslIr::AddV(a, b, c) => {
177                        let x = gate.add(ctx, vars[&b.0], vars[&c.0]);
178                        vars.insert(a.0, x);
179                    }
180                    DslIr::AddVI(a, b, c) => {
181                        let x = if c.is_zero() {
182                            vars[&b.0]
183                        } else {
184                            let tmp = ctx.load_constant(convert_fr(&c));
185                            gate.add(ctx, vars[&b.0], tmp)
186                        };
187                        vars.insert(a.0, x);
188                    }
189                    DslIr::AddF(a, b, c) => {
190                        let x = f_chip.add(ctx, felts[&b.0], felts[&c.0]);
191                        felts.insert(a.0, x);
192                    }
193                    DslIr::AddFI(a, b, c) => {
194                        let x = if c.is_zero() {
195                            felts[&b.0]
196                        } else {
197                            let tmp = f_chip.load_constant(ctx, c);
198                            f_chip.add(ctx, felts[&b.0], tmp)
199                        };
200                        felts.insert(a.0, x);
201                    }
202                    DslIr::AddE(a, b, c) => {
203                        let x = ext_chip.add(ctx, exts[&b.0], exts[&c.0]);
204                        exts.insert(a.0, x);
205                    }
206                    DslIr::AddEF(a, b, c) => {
207                        let mut x = exts[&b.0];
208                        x.0[0] = f_chip.add(ctx, x.0[0], felts[&c.0]);
209                        exts.insert(a.0, x);
210                    }
211                    DslIr::AddEFI(a, b, c) => {
212                        let x = if c.is_zero() {
213                            exts[&b.0]
214                        } else {
215                            let tmp = f_chip.load_constant(ctx, c);
216                            let mut x = exts[&b.0];
217                            x.0[0] = f_chip.add(ctx, x.0[0], tmp);
218                            x
219                        };
220                        exts.insert(a.0, x);
221                    }
222                    DslIr::AddEI(a, b, c) => {
223                        let x = if c.is_zero() {
224                            exts[&b.0]
225                        } else {
226                            let tmp = ext_chip.load_constant(ctx, c);
227                            ext_chip.add(ctx, exts[&b.0], tmp)
228                        };
229                        exts.insert(a.0, x);
230                    }
231                    DslIr::AddEFFI(a, b, c) => {
232                        let mut x = ext_chip.load_constant(ctx, c);
233                        x.0[0] = f_chip.add(ctx, x.0[0], felts[&b.0]);
234                        exts.insert(a.0, x);
235                    }
236                    DslIr::SubV(a, b, c) => {
237                        let x = gate.sub(ctx, vars[&b.0], vars[&c.0]);
238                        vars.insert(a.0, x);
239                    }
240                    DslIr::SubF(a, b, c) => {
241                        let x = f_chip.sub(ctx, felts[&b.0], felts[&c.0]);
242                        felts.insert(a.0, x);
243                    }
244                    DslIr::SubE(a, b, c) => {
245                        let x = ext_chip.sub(ctx, exts[&b.0], exts[&c.0]);
246                        exts.insert(a.0, x);
247                    }
248                    DslIr::SubEF(a, b, c) => {
249                        let mut x = exts[&b.0];
250                        x.0[0] = f_chip.sub(ctx, x.0[0], felts[&c.0]);
251                        exts.insert(a.0, x);
252                    }
253                    DslIr::SubEI(a, b, c) => {
254                        let x = if c.is_zero() {
255                            exts[&b.0]
256                        } else {
257                            let tmp = ext_chip.load_constant(ctx, c);
258                            ext_chip.sub(ctx, exts[&b.0], tmp)
259                        };
260                        exts.insert(a.0, x);
261                    }
262                    DslIr::SubVIN(a, b, c) => {
263                        let tmp = ctx.load_constant(convert_fr(&b));
264                        let x = gate.sub(ctx, tmp, vars[&c.0]);
265                        vars.insert(a.0, x);
266                    }
267                    DslIr::SubEIN(a, b, c) => {
268                        let tmp = ext_chip.load_constant(ctx, b);
269                        let x = ext_chip.sub(ctx, tmp, exts[&c.0]);
270                        exts.insert(a.0, x);
271                    }
272                    DslIr::SubEFI(a, b, c) => {
273                        let x = if c.is_zero() {
274                            exts[&b.0]
275                        } else {
276                            let tmp = f_chip.load_constant(ctx, c);
277                            let mut x = exts[&b.0];
278                            x.0[0] = f_chip.sub(ctx, x.0[0], tmp);
279                            x
280                        };
281                        exts.insert(a.0, x);
282                    }
283                    DslIr::MulV(a, b, c) => {
284                        let x = gate.mul(ctx, vars[&b.0], vars[&c.0]);
285                        vars.insert(a.0, x);
286                    }
287                    DslIr::MulVI(a, b, c) => {
288                        let x = if c.is_one() {
289                            vars[&b.0]
290                        } else if c.is_zero() {
291                            ctx.load_zero()
292                        } else {
293                            let tmp = ctx.load_constant(convert_fr(&c));
294                            gate.mul(ctx, vars[&b.0], tmp)
295                        };
296                        vars.insert(a.0, x);
297                    }
298                    DslIr::MulF(a, b, c) => {
299                        let x = f_chip.mul(ctx, felts[&b.0], felts[&c.0]);
300                        felts.insert(a.0, x);
301                    }
302                    DslIr::MulFI(a, b, c) => {
303                        let x = if c.is_one() {
304                            felts[&b.0]
305                        } else if c.is_zero() {
306                            f_chip.load_constant(ctx, BabyBear::ZERO)
307                        } else {
308                            let tmp = f_chip.load_constant(ctx, c);
309                            f_chip.mul(ctx, felts[&b.0], tmp)
310                        };
311                        felts.insert(a.0, x);
312                    }
313                    DslIr::MulE(a, b, c) => {
314                        let x = ext_chip.mul(ctx, exts[&b.0], exts[&c.0]);
315                        exts.insert(a.0, x);
316                    }
317                    DslIr::MulEI(a, b, c) => {
318                        let x = if c.is_one() {
319                            exts[&b.0]
320                        } else if c.is_zero() {
321                            ext_chip.load_constant(ctx, BabyBearExt4::ZERO)
322                        } else {
323                            let tmp = ext_chip.load_constant(ctx, c);
324                            ext_chip.mul(ctx, exts[&b.0], tmp)
325                        };
326                        exts.insert(a.0, x);
327                    }
328                    DslIr::MulEF(a, b, c) => {
329                        let x = ext_chip.scalar_mul(ctx, exts[&b.0], felts[&c.0]);
330                        exts.insert(a.0, x);
331                    }
332                    DslIr::MulEFI(a, b, c) => {
333                        let x = if c.is_one() {
334                            exts[&b.0]
335                        } else if c.is_zero() {
336                            ext_chip.load_constant(ctx, BabyBearExt4::ZERO)
337                        } else {
338                            let tmp = f_chip.load_constant(ctx, c);
339                            ext_chip.scalar_mul(ctx, exts[&b.0], tmp)
340                        };
341                        exts.insert(a.0, x);
342                    }
343                    DslIr::DivF(a, b, c) => {
344                        let x = f_chip.div(ctx, felts[&b.0], felts[&c.0]);
345                        felts.insert(a.0, x);
346                    }
347                    DslIr::DivFIN(a, b, c) => {
348                        // a = b / c
349                        let tmp = f_chip.load_constant(ctx, b);
350                        let x = if b.is_zero() {
351                            tmp
352                        } else {
353                            f_chip.div(ctx, tmp, felts[&c.0])
354                        };
355                        felts.insert(a.0, x);
356                    }
357                    DslIr::DivE(a, b, c) => {
358                        let x = ext_chip.div(ctx, exts[&b.0], exts[&c.0]);
359                        exts.insert(a.0, x);
360                    }
361                    DslIr::DivEIN(a, b, c) => {
362                        let tmp = ext_chip.load_constant(ctx, b);
363                        let x = if b.is_zero() {
364                            tmp
365                        } else {
366                            ext_chip.div(ctx, tmp, exts[&c.0])
367                        };
368                        exts.insert(a.0, x);
369                    }
370                    DslIr::NegE(a, b) => {
371                        let x = ext_chip.neg(ctx, exts[&b.0]);
372                        exts.insert(a.0, x);
373                    }
374                    DslIr::CastFV(a, b) => {
375                        let felt = felts[&b.0];
376                        let reduced_felt = f_chip.reduce(ctx, felt);
377                        vars.insert(a.0, reduced_felt.value);
378                    }
379                    DslIr::CircuitNum2BitsF(value, output) => {
380                        let val = f_chip.reduce(ctx, felts[&value.0]);
381                        let x = gate.num_to_bits(ctx, val.value, 32); // C::F::bits());
382                        assert!(output.len() <= x.len());
383                        for (o, x) in output.into_iter().zip(x) {
384                            vars.insert(o.0, x);
385                        }
386                    }
387                    DslIr::CircuitVarTo64BitsF(value, output) => {
388                        let x = vars[&value.0];
389                        let limbs = var_to_u64_limbs(ctx, &range, gate, x);
390                        for (o, l) in output.into_iter().zip(limbs) {
391                            felts.insert(o.0, l);
392                        }
393                    }
394                    DslIr::CircuitVarToFieldOrderLimbsF(value, output) => {
395                        let x = vars[&value.0];
396                        let limbs =
397                            var_to_babybear_field_order_limbs(ctx, &range, gate, x, output.len());
398                        for (o, l) in output.into_iter().zip_eq(limbs) {
399                            felts.insert(o.0, l);
400                        }
401                    }
402                    DslIr::CircuitPoseidon2Permute(state_vars) => {
403                        let mut state =
404                            Poseidon2State::<Fr, POSEIDON2_T>::new(state_vars.map(|x| vars[&x.0]));
405                        state.permutation(ctx, gate, &*POSEIDON2_PARAMS);
406                        for i in 0..POSEIDON2_T {
407                            *vars.get_mut(&state_vars[i].0).unwrap() = state.s[i];
408                        }
409                    }
410                    DslIr::CircuitSelectV(cond, a, b, out) => {
411                        let x = gate.select(ctx, vars[&a.0], vars[&b.0], vars[&cond.0]);
412                        vars.insert(out.0, x);
413                    }
414                    DslIr::CircuitSelectF(cond, a, b, out) => {
415                        let x = f_chip.select(ctx, vars[&cond.0], felts[&a.0], felts[&b.0]);
416                        felts.insert(out.0, x);
417                    }
418                    DslIr::CircuitSelectE(cond, a, b, out) => {
419                        let x = ext_chip.select(ctx, vars[&cond.0], exts[&a.0], exts[&b.0]);
420                        exts.insert(out.0, x);
421                    }
422                    DslIr::CircuitExt2Felt(a, b) => {
423                        for (i, x) in a.iter().enumerate() {
424                            felts.insert(x.0, exts[&b.0].0[i]);
425                        }
426                    }
427                    DslIr::AssertEqV(a, b) => {
428                        ctx.constrain_equal(&vars[&a.0], &vars[&b.0]);
429                    }
430                    DslIr::AssertEqVI(a, b) => {
431                        gate.assert_is_const(ctx, &vars[&a.0], &convert_fr(&b));
432                    }
433                    DslIr::AssertEqF(a, b) => {
434                        f_chip.assert_equal(ctx, felts[&a.0], felts[&b.0]);
435                    }
436                    DslIr::AssertEqFI(a, b) => {
437                        if b.is_zero() {
438                            f_chip.assert_zero(ctx, felts[&a.0]);
439                        } else {
440                            let tmp = f_chip.load_constant(ctx, b);
441                            f_chip.assert_equal(ctx, felts[&a.0], tmp);
442                        }
443                    }
444                    DslIr::AssertEqE(a, b) => {
445                        ext_chip.assert_equal(ctx, exts[&a.0], exts[&b.0]);
446                    }
447                    DslIr::AssertEqEI(a, b) => {
448                        // Note: we could check if each coordinate of `b` is zero separately for a
449                        // little more efficiency, but omitting to simplify
450                        // the code
451                        if b.is_zero() {
452                            ext_chip.assert_zero(ctx, exts[&a.0]);
453                        } else {
454                            let tmp = ext_chip.load_constant(ctx, b);
455                            ext_chip.assert_equal(ctx, exts[&a.0], tmp);
456                        }
457                    }
458                    DslIr::PrintV(a) => {
459                        println!("PrintV: {:?}", vars[&a.0].value());
460                    }
461                    DslIr::PrintF(a) => {
462                        println!("PrintF: {:?}", felts[&a.0].to_baby_bear());
463                    }
464                    DslIr::PrintE(a) => {
465                        println!("PrintE:");
466                        for x in exts[&a.0].0.iter() {
467                            println!("{:?}", x.to_baby_bear());
468                        }
469                    }
470                    DslIr::WitnessVar(a, b) => {
471                        let x = ctx.load_witness(halo2_state.vars[&b]);
472                        vars.insert(a.0, x);
473                    }
474                    DslIr::WitnessFelt(a, b) => {
475                        let x = f_chip.load_witness(ctx, halo2_state.felts[&b]);
476                        felts.insert(a.0, x);
477                    }
478                    DslIr::WitnessExt(a, b) => {
479                        let x = ext_chip.load_witness(ctx, halo2_state.exts[&b]);
480                        exts.insert(a.0, x);
481                    }
482                    DslIr::CircuitFelts2Ext(a, b) => {
483                        let x = AssignedBabyBearExt4(
484                            a.iter()
485                                .map(|a| felts[&a.0])
486                                .collect_vec()
487                                .try_into()
488                                .unwrap(),
489                        );
490                        exts.insert(b.0, x);
491                    }
492                    DslIr::CircuitFeltReduce(a) => {
493                        let x = f_chip.reduce_max_bits(ctx, felts[&a.0]);
494                        felts.insert(a.0, x);
495                    }
496                    DslIr::CircuitExtReduce(a) => {
497                        let x = ext_chip.reduce_max_bits(ctx, exts[&a.0]);
498                        exts.insert(a.0, x);
499                    }
500                    DslIr::CircuitLessThan(a, b) => {
501                        range.range_check(ctx, vars[&a.0], C::F::bits());
502                        range.range_check(ctx, vars[&b.0], C::F::bits());
503                        range.check_less_than(ctx, vars[&a.0], vars[&b.0], C::F::bits());
504                    }
505                    DslIr::CycleTrackerStart(_name) => {
506                        #[cfg(feature = "metrics")]
507                        cell_tracker.start(_name);
508                    }
509                    DslIr::CycleTrackerEnd(_name) => {
510                        #[cfg(feature = "metrics")]
511                        cell_tracker.end(_name);
512                    }
513                    DslIr::CircuitPublish(val, index) => {
514                        public_values[index] = vars[&val.0];
515                    }
516                    _ => panic!("unsupported {instruction:?}"),
517                }
518            }));
519            if res.is_err() {
520                if let Some(mut backtrace) = backtrace {
521                    backtrace.resolve();
522                    eprintln!("openvm circuit failure; backtrace:\n{backtrace:?}");
523                }
524                #[allow(clippy::panicking_unwrap)]
525                res.unwrap();
526            }
527            #[cfg(feature = "metrics")]
528            if self.profiling {
529                let mut new_stats = stats_snapshot(ctx, range.clone());
530                new_stats.diff(&old_stats);
531                new_stats.increment(cell_tracker.get_full_name());
532            }
533        }
534
535        halo2_state.builder.assigned_instances = vec![public_values];
536    }
537}
538
539/// Assumes F is Bn254 Fr and converts to halo2 Fr type
540pub fn convert_fr<F: PrimeField>(a: &F) -> Fr {
541    biguint_to_fe(&a.as_canonical_biguint())
542}
543
544#[allow(dead_code)]
545pub fn convert_efr<F: PrimeField, EF: ExtensionField<F>>(a: &EF) -> Vec<Fr> {
546    let slc = a.as_basis_coefficients_slice();
547    slc.iter()
548        .map(|x| biguint_to_fe(&x.as_canonical_biguint()))
549        .collect()
550}
551
552// Unfortunately `builder.statistics()` cannot be called when `ctx` exists.
553#[allow(dead_code)] // used only in metrics
554fn stats_snapshot(ctx: &Context<Fr>, range_chip: Arc<RangeChip<Fr>>) -> Halo2Stats {
555    Halo2Stats {
556        total_gate_cell: ctx.advice.len(),
557        // Note[Xinding]: this is inaccurate because of duplicated constants. But it's too slow if
558        // we always check for duplicates.
559        total_fixed: ctx.copy_manager.lock().unwrap().constant_equalities.len(),
560        total_lookup_cell: range_chip.lookup_manager()[0].total_rows(),
561    }
562}
563
564#[allow(dead_code)]
565fn is_babybear_ir<C: Config>(ir: &DslIr<C>) -> bool {
566    matches!(
567        ir,
568        DslIr::ImmF(_, _)
569            | DslIr::AddF(_, _, _)
570            | DslIr::AddFI(_, _, _)
571            | DslIr::SubF(_, _, _)
572            | DslIr::MulF(_, _, _)
573            | DslIr::MulFI(_, _, _)
574            | DslIr::DivFIN(_, _, _)
575            | DslIr::CircuitSelectF(_, _, _, _)
576            | DslIr::AssertEqF(_, _)
577            | DslIr::AssertEqFI(_, _)
578            | DslIr::WitnessFelt(_, _)
579            | DslIr::CircuitFelts2Ext(_, _)
580            | DslIr::CircuitFeltReduce(_)
581            | DslIr::CircuitExtReduce(_)
582            | DslIr::CircuitLessThan(..)
583            | DslIr::ImmE(_, _)
584            | DslIr::AddE(_, _, _)
585            | DslIr::AddEF(_, _, _)
586            | DslIr::AddEFI(_, _, _)
587            | DslIr::AddEI(_, _, _)
588            | DslIr::AddEFFI(_, _, _)
589            | DslIr::SubE(_, _, _)
590            | DslIr::SubEF(_, _, _)
591            | DslIr::SubEI(_, _, _)
592            | DslIr::SubEIN(_, _, _)
593            | DslIr::SubEFI(_, _, _)
594            | DslIr::MulE(_, _, _)
595            | DslIr::MulEI(_, _, _)
596            | DslIr::MulEF(_, _, _)
597            | DslIr::MulEFI(_, _, _)
598            | DslIr::DivE(_, _, _)
599            | DslIr::DivEIN(_, _, _)
600            | DslIr::NegE(_, _)
601            | DslIr::CircuitSelectE(_, _, _, _)
602            | DslIr::AssertEqE(_, _)
603            | DslIr::AssertEqEI(_, _)
604            | DslIr::WitnessExt(_, _)
605            | DslIr::CastFV(_, _)
606    )
607}
608
609fn fr_to_u64_limbs(fr: &Fr) -> [u64; 4] {
610    // We need 64-bit limbs but `decompose_fe_to_u64_limbs` only support `bit_len < 64`.
611    let limbs = decompose_fe_to_u64_limbs(fr, 8, 32);
612    std::array::from_fn(|i| limbs[2 * i] + limbs[2 * i + 1] * (1 << 32))
613}
614
615/// Bounds for the base-BabyBear decomposition (depend on `num_limbs = k`).
616struct BaseBabyBearDecompBounds {
617    /// floor((q-1) / p^k) as a field element.
618    top_quotient_max_fe: Fr,
619    /// floor((q-1) / p^k) + 1.
620    top_quotient_max_plus_one: BigUint,
621    /// One more than the maximum lower part when top quotient is at its max.
622    lower_max_plus_one: BigUint,
623    /// p^k as a BigUint.
624    pow_k: BigUint,
625}
626
627fn base_baby_bear_decomp_bounds(num_limbs: usize) -> BaseBabyBearDecompBounds {
628    let p = BigUint::from(BabyBear::ORDER_U32);
629    let one = BigUint::from(1u64);
630    let modulus = fe_to_biguint(&(Fr::ZERO - Fr::ONE)) + &one;
631    let modulus_minus_one = &modulus - &one;
632    let pow_k = p.pow(num_limbs as u32);
633    let q_k_max = &modulus_minus_one / &pow_k;
634    let lower_max = modulus_minus_one - &q_k_max * &pow_k;
635    BaseBabyBearDecompBounds {
636        top_quotient_max_fe: biguint_to_fe(&q_k_max),
637        top_quotient_max_plus_one: &q_k_max + &one,
638        lower_max_plus_one: lower_max + one,
639        pow_k,
640    }
641}
642
643/// Decompose a Bn254 element `packed` into `num_limbs` little-endian base-|F| (BabyBear) digits
644/// using the witness-and-verify pattern.
645///
646/// Constrains `packed = top_quotient * p^k + sum(limbs[i] * p^i)` where `k = num_limbs` and `p`
647/// is the BabyBear modulus, including the boundary check that pins the witness to the canonical
648/// decomposition: when `top_quotient == floor((q-1) / p^k)`, the lower part must not exceed
649/// `(q-1) - top_quotient * p^k`. Without this check, the same `packed` would admit multiple
650/// valid limb decompositions.
651fn var_to_babybear_field_order_limbs(
652    ctx: &mut Context<Fr>,
653    range: &RangeChip<Fr>,
654    gate: &GateChip<Fr>,
655    packed: AssignedValue<Fr>,
656    num_limbs: usize,
657) -> Vec<AssignedBabyBear> {
658    let bounds = base_baby_bear_decomp_bounds(num_limbs);
659    let p = BigUint::from(BabyBear::ORDER_U32);
660    let one = BigUint::from(1u64);
661
662    // Witness: compute digits and top quotient out-of-circuit.
663    let mut value = fe_to_biguint(packed.value());
664    let output_digits_big: Vec<BigUint> = (0..num_limbs)
665        .map(|_| {
666            let digit = &value % &p;
667            value /= &p;
668            digit
669        })
670        .collect();
671    let top_quotient_big = value;
672
673    // Load each digit witness.
674    let output_digits: Vec<AssignedBabyBear> = output_digits_big
675        .iter()
676        .map(|digit| {
677            let v = ctx.load_witness(biguint_to_fe(digit));
678            AssignedBabyBear {
679                value: v,
680                max_bits: 31,
681            }
682        })
683        .collect();
684
685    // Load top quotient witness.
686    let top_quotient = ctx.load_witness(biguint_to_fe(&top_quotient_big));
687
688    // Constrain.
689    for digit in &output_digits {
690        range.check_less_than_safe(ctx, digit.value, BabyBear::ORDER_U32 as u64);
691    }
692    let top_quotient_valid =
693        range.is_big_less_than_safe(ctx, top_quotient, bounds.top_quotient_max_plus_one.clone());
694    gate.assert_is_const(ctx, &top_quotient_valid, &Fr::ONE);
695
696    // Verify recomposition: lower = sum(digit_i * p^i)
697    let lower = gate.inner_product(
698        ctx,
699        output_digits.iter().map(|d| d.value),
700        std::iter::successors(Some(one), |power| Some(power * &p))
701            .take(num_limbs)
702            .map(|power| QuantumCell::Constant(biguint_to_fe(&power))),
703    );
704
705    // packed == top_quotient * p^k + lower
706    let recomposed = gate.mul_add(
707        ctx,
708        top_quotient,
709        QuantumCell::Constant(biguint_to_fe(&bounds.pow_k)),
710        lower,
711    );
712    ctx.constrain_equal(&packed, &recomposed);
713
714    // Boundary check: when top_quotient is at max, lower must not exceed the remainder.
715    let at_top_boundary = gate.is_equal(
716        ctx,
717        top_quotient,
718        QuantumCell::Constant(bounds.top_quotient_max_fe),
719    );
720    // Range-check lower against p^k (not lower_max_plus_one) because lower can be
721    // any value in [0, p^k - 1] and p^k may need more bits than lower_max_plus_one.
722    // Using is_big_less_than_safe with lower_max_plus_one would derive an insufficient
723    // range_bits from lower_max_plus_one.bits() (e.g. 152 vs the 155 bits needed).
724    let lower_range_bits =
725        (bounds.pow_k.bits() as usize).div_ceil(range.lookup_bits()) * range.lookup_bits();
726    range.range_check(ctx, lower, lower_range_bits);
727    let lower_is_valid = range.is_less_than(
728        ctx,
729        lower,
730        QuantumCell::Constant(biguint_to_fe(&bounds.lower_max_plus_one)),
731        lower_range_bits,
732    );
733    let lower_is_invalid = gate.not(ctx, lower_is_valid);
734    let lower_violation = gate.mul(ctx, at_top_boundary, lower_is_invalid);
735    gate.assert_is_const(ctx, &lower_violation, &Fr::ZERO);
736
737    output_digits
738}
739
740fn var_to_u64_limbs(
741    ctx: &mut Context<Fr>,
742    range: &RangeChip<Fr>,
743    gate: &GateChip<Fr>,
744    x: AssignedValue<Fr>,
745) -> [AssignedBabyBear; 4] {
746    let limbs = fr_to_u64_limbs(x.value()).map(|limb| ctx.load_witness(Fr::from(limb)));
747    let factors = [
748        Fr::from([1, 0, 0, 0]),
749        Fr::from([0, 1, 0, 0]),
750        Fr::from([0, 0, 1, 0]),
751        Fr::from([0, 0, 0, 1]),
752    ];
753    let sum = gate.inner_product(ctx, limbs, factors.map(QuantumCell::Constant));
754    ctx.constrain_equal(&sum, &x);
755    let fr_bound_limbs = fr_to_u64_limbs(&(Fr::ZERO - Fr::ONE));
756    let ret = std::array::from_fn(|i| {
757        let limb = limbs[i];
758        let bits = if i < 3 {
759            range.range_check(ctx, limb, 64);
760            64
761        } else {
762            range.check_less_than_safe(ctx, limbs[3], fr_bound_limbs[3] + 1);
763            (Fr::NUM_BITS - 3 * 64) as usize
764        };
765        AssignedBabyBear {
766            value: limb,
767            max_bits: bits,
768        }
769    });
770    // Constraint decomposition doesn't overflow.
771    // Whether limbs[i] == fr_bound_limbs[i] so far
772    let mut on_bound = gate.is_equal(
773        ctx,
774        limbs[3],
775        QuantumCell::Constant(Fr::from(fr_bound_limbs[3])),
776    );
777    for i in (0..3).rev() {
778        // limbs[i] > fr_bound_limbs[i]
779        let li_gt_bd = range.is_less_than(
780            ctx,
781            QuantumCell::Constant(Fr::from(fr_bound_limbs[i])),
782            limbs[i],
783            64,
784        );
785        let li_out_bd = gate.add(ctx, on_bound, li_gt_bd);
786        // on_bound  li_gt_bd  result
787        //    1         1       fail
788        //    1         0       pass
789        //    0         1       pass
790        //    0         0       pass
791        gate.assert_bit(ctx, li_out_bd);
792        // Update on_bound except the last limb
793        if i > 0 {
794            debug_assert_ne!(fr_bound_limbs[i], 0, "This should never happen for Bn254");
795            // on_bound && limbs[i] - fr_bound_limbs[i] == 0
796            let diff = gate.sub_mul(
797                ctx,
798                QuantumCell::Constant(Fr::from(fr_bound_limbs[i])),
799                on_bound,
800                limbs[i],
801            );
802            on_bound = gate.is_zero(ctx, diff);
803        }
804    }
805    ret
806}
807
/// Checks the out-of-circuit 64-bit limb splitting that `var_to_u64_limbs` witnesses, including
/// the limbs of the largest field element that it uses as the overflow bound.
// TODO: add an in-circuit test that runs `var_to_u64_limbs` itself against a mock prover.
#[test]
fn test_var_to_u64_limbs() {
    assert_eq!(fr_to_u64_limbs(&Fr::ZERO), [0u64; 4]);
    assert_eq!(fr_to_u64_limbs(&Fr::ONE), [1, 0, 0, 0]);
    assert_eq!(fr_to_u64_limbs(&Fr::from(u64::MAX)), [u64::MAX, 0, 0, 0]);
    // q - 1 for the BN254 scalar field, least-significant limb first.
    assert_eq!(
        fr_to_u64_limbs(&(Fr::ZERO - Fr::ONE)),
        [
            0x43e1f593f0000000,
            0x2833e84879b97091,
            0xb85045b68181585d,
            0x30644e72e131a029,
        ]
    );
}