openvm_native_compiler/constraints/halo2/
compiler.rs

1use std::{
2    collections::HashMap,
3    fmt::Debug,
4    marker::PhantomData,
5    panic::{catch_unwind, AssertUnwindSafe},
6    sync::{Arc, LazyLock},
7};
8
9use itertools::Itertools;
10#[cfg(feature = "bench-metrics")]
11use openvm_circuit::metrics::cycle_tracker::CycleTracker;
12use openvm_stark_backend::p3_field::{ExtensionField, Field, FieldAlgebra, PrimeField};
13use openvm_stark_sdk::{p3_baby_bear::BabyBear, p3_bn254_fr::Bn254Fr};
14use snark_verifier_sdk::snark_verifier::{
15    halo2_base::{
16        gates::{
17            circuit::builder::BaseCircuitBuilder, GateChip, GateInstructions, RangeChip,
18            RangeInstructions,
19        },
20        halo2_proofs::halo2curves::bn256::Fr,
21        utils::{biguint_to_fe, decompose_fe_to_u64_limbs, ScalarField},
22        AssignedValue, Context, QuantumCell,
23    },
24    util::arithmetic::{Field as _, PrimeField as _},
25};
26
27use super::stats::Halo2Stats;
28use crate::{
29    constraints::halo2::{
30        baby_bear::{
31            AssignedBabyBear, AssignedBabyBearExt4, BabyBearChip, BabyBearExt4, BabyBearExt4Chip,
32        },
33        poseidon2_perm::{Poseidon2Params, Poseidon2State},
34    },
35    ir::{Config, DslIr, TracedVec, Witness},
36};
37
38const POSEIDON2_T: usize = 3;
39static POSEIDON2_PARAMS: LazyLock<Poseidon2Params<Fr, POSEIDON2_T>> = LazyLock::new(|| {
40    use zkhash::{
41        ark_ff::{BigInteger, PrimeField as _},
42        fields::bn256::FpBN256 as ark_FpBN256,
43        poseidon2::poseidon2_instance_bn256::{MAT_DIAG3_M_1, RC3},
44    };
45
46    fn convert_fr(input: ark_FpBN256) -> Fr {
47        Fr::from_bytes_le(&input.into_bigint().to_bytes_le())
48    }
49    const T: usize = 3;
50    let rounds_f = 8;
51    let rounds_p = 56;
52    let mut round_constants: Vec<[Fr; T]> = RC3
53        .iter()
54        .map(|vec| {
55            vec.iter()
56                .cloned()
57                .map(convert_fr)
58                .collect::<Vec<_>>()
59                .try_into()
60                .unwrap()
61        })
62        .collect();
63
64    let rounds_f_beginning = rounds_f / 2;
65    let p_end = rounds_f_beginning + rounds_p;
66    let internal_round_constants = round_constants
67        .drain(rounds_f_beginning..p_end)
68        .map(|vec| vec[0])
69        .collect::<Vec<_>>();
70    let external_round_constants = round_constants;
71    Poseidon2Params {
72        rounds_f,
73        rounds_p,
74        mat_internal_diag_m_1: MAT_DIAG3_M_1
75            .iter()
76            .copied()
77            .map(convert_fr)
78            .collect_vec()
79            .try_into()
80            .unwrap(),
81        external_rc: external_round_constants,
82        internal_rc: internal_round_constants,
83    }
84});
85
86/// The backend for the Halo2 constraint compiler.
87#[derive(Debug, Clone)]
88pub struct Halo2ConstraintCompiler<C: Config> {
89    pub num_public_values: usize,
90    #[allow(unused_variables)]
91    pub profiling: bool,
92    pub phantom: PhantomData<C>,
93}
94
95#[derive(Debug, Clone, Default)]
96pub struct Halo2State<C: Config> {
97    // halo2 stuff
98    pub builder: BaseCircuitBuilder<Fr>,
99    // Unassigned values: provided by the prover outside of constraint compiler
100    // map from name as string to halo2 assigned value
101    pub vars: HashMap<u32, Fr>,
102    pub felts: HashMap<u32, C::F>,
103    pub exts: HashMap<u32, C::EF>,
104}
105
106impl<C: Config> Halo2State<C> {
107    pub fn load_witness(&mut self, witness: Witness<C>) {
108        for (i, x) in witness.vars.iter().enumerate() {
109            self.vars.insert(i as u32, convert_fr(x));
110        }
111        for (i, x) in witness.felts.into_iter().enumerate() {
112            self.felts.insert(i as u32, x);
113        }
114        for (i, x) in witness.exts.into_iter().enumerate() {
115            self.exts.insert(i as u32, x);
116        }
117    }
118}
119
120impl<C: Config + Debug> Halo2ConstraintCompiler<C> {
121    pub fn new(num_public_values: usize) -> Self {
122        Self {
123            num_public_values,
124            profiling: false,
125            phantom: PhantomData,
126        }
127    }
128    pub fn with_profiling(mut self) -> Self {
129        self.profiling = true;
130        self
131    }
132    // Create halo2-lib constraints from a list of operations in the DSL.
133    // Assume: C::N = C::F = C::EF is type Fr
134    pub fn constrain_halo2(&self, halo2_state: &mut Halo2State<C>, operations: TracedVec<DslIr<C>>)
135    where
136        C: Config<N = Bn254Fr, F = BabyBear, EF = BabyBearExt4>,
137    {
138        #[cfg(feature = "bench-metrics")]
139        let mut cell_tracker = CycleTracker::new();
140        let range = Arc::new(halo2_state.builder.range_chip());
141        let f_chip = Arc::new(BabyBearChip::new(range.clone()));
142        let ext_chip = BabyBearExt4Chip::new(Arc::clone(&f_chip));
143        let gate = f_chip.gate();
144        let ctx = halo2_state.builder.main(0);
145        let mut public_values = vec![ctx.load_zero(); self.num_public_values];
146
147        // Local variables for referencing during the course of constraint building
148        let mut vars = HashMap::new();
149        let mut felts = HashMap::<u32, AssignedBabyBear>::new();
150        let mut exts = HashMap::<u32, AssignedBabyBearExt4>::new();
151
152        #[cfg(feature = "bench-metrics")]
153        let mut old_stats = stats_snapshot(ctx, range.clone());
154        for (instruction, backtrace) in operations {
155            #[cfg(feature = "bench-metrics")]
156            if self.profiling {
157                old_stats = stats_snapshot(ctx, range.clone());
158            }
159            let res = catch_unwind(AssertUnwindSafe(|| {
160                match instruction {
161                    DslIr::ImmV(a, b) => {
162                        let x = ctx.load_constant(convert_fr(&b));
163                        vars.insert(a.0, x);
164                    }
165                    DslIr::ImmF(a, b) => {
166                        let x = f_chip.load_constant(ctx, b);
167                        felts.insert(a.0, x);
168                    }
169                    DslIr::ImmE(a, b) => {
170                        let x = ext_chip.load_constant(ctx, b);
171                        exts.insert(a.0, x);
172                    }
173                    DslIr::AddV(a, b, c) => {
174                        let x = gate.add(ctx, vars[&b.0], vars[&c.0]);
175                        vars.insert(a.0, x);
176                    }
177                    DslIr::AddVI(a, b, c) => {
178                        let x = if c.is_zero() {
179                            vars[&b.0]
180                        } else {
181                            let tmp = ctx.load_constant(convert_fr(&c));
182                            gate.add(ctx, vars[&b.0], tmp)
183                        };
184                        vars.insert(a.0, x);
185                    }
186                    DslIr::AddF(a, b, c) => {
187                        let x = f_chip.add(ctx, felts[&b.0], felts[&c.0]);
188                        felts.insert(a.0, x);
189                    }
190                    DslIr::AddFI(a, b, c) => {
191                        let x = if c.is_zero() {
192                            felts[&b.0]
193                        } else {
194                            let tmp = f_chip.load_constant(ctx, c);
195                            f_chip.add(ctx, felts[&b.0], tmp)
196                        };
197                        felts.insert(a.0, x);
198                    }
199                    DslIr::AddE(a, b, c) => {
200                        let x = ext_chip.add(ctx, exts[&b.0], exts[&c.0]);
201                        exts.insert(a.0, x);
202                    }
203                    DslIr::AddEF(a, b, c) => {
204                        let mut x = exts[&b.0];
205                        x.0[0] = f_chip.add(ctx, x.0[0], felts[&c.0]);
206                        exts.insert(a.0, x);
207                    }
208                    DslIr::AddEFI(a, b, c) => {
209                        let x = if c.is_zero() {
210                            exts[&b.0]
211                        } else {
212                            let tmp = f_chip.load_constant(ctx, c);
213                            let mut x = exts[&b.0];
214                            x.0[0] = f_chip.add(ctx, x.0[0], tmp);
215                            x
216                        };
217                        exts.insert(a.0, x);
218                    }
219                    DslIr::AddEI(a, b, c) => {
220                        let x = if c.is_zero() {
221                            exts[&b.0]
222                        } else {
223                            let tmp = ext_chip.load_constant(ctx, c);
224                            ext_chip.add(ctx, exts[&b.0], tmp)
225                        };
226                        exts.insert(a.0, x);
227                    }
228                    DslIr::AddEFFI(a, b, c) => {
229                        let mut x = ext_chip.load_constant(ctx, c);
230                        x.0[0] = f_chip.add(ctx, x.0[0], felts[&b.0]);
231                        exts.insert(a.0, x);
232                    }
233                    DslIr::SubV(a, b, c) => {
234                        let x = gate.sub(ctx, vars[&b.0], vars[&c.0]);
235                        vars.insert(a.0, x);
236                    }
237                    DslIr::SubF(a, b, c) => {
238                        let x = f_chip.sub(ctx, felts[&b.0], felts[&c.0]);
239                        felts.insert(a.0, x);
240                    }
241                    DslIr::SubE(a, b, c) => {
242                        let x = ext_chip.sub(ctx, exts[&b.0], exts[&c.0]);
243                        exts.insert(a.0, x);
244                    }
245                    DslIr::SubEF(a, b, c) => {
246                        let mut x = exts[&b.0];
247                        x.0[0] = f_chip.sub(ctx, x.0[0], felts[&c.0]);
248                        exts.insert(a.0, x);
249                    }
250                    DslIr::SubEI(a, b, c) => {
251                        let x = if c.is_zero() {
252                            exts[&b.0]
253                        } else {
254                            let tmp = ext_chip.load_constant(ctx, c);
255                            ext_chip.sub(ctx, exts[&b.0], tmp)
256                        };
257                        exts.insert(a.0, x);
258                    }
259                    DslIr::SubVIN(a, b, c) => {
260                        let tmp = ctx.load_constant(convert_fr(&b));
261                        let x = gate.sub(ctx, tmp, vars[&c.0]);
262                        vars.insert(a.0, x);
263                    }
264                    DslIr::SubEIN(a, b, c) => {
265                        let tmp = ext_chip.load_constant(ctx, b);
266                        let x = ext_chip.sub(ctx, tmp, exts[&c.0]);
267                        exts.insert(a.0, x);
268                    }
269                    DslIr::SubEFI(a, b, c) => {
270                        let x = if c.is_zero() {
271                            exts[&b.0]
272                        } else {
273                            let tmp = f_chip.load_constant(ctx, c);
274                            let mut x = exts[&b.0];
275                            x.0[0] = f_chip.sub(ctx, x.0[0], tmp);
276                            x
277                        };
278                        exts.insert(a.0, x);
279                    }
280                    DslIr::MulV(a, b, c) => {
281                        let x = gate.mul(ctx, vars[&b.0], vars[&c.0]);
282                        vars.insert(a.0, x);
283                    }
284                    DslIr::MulVI(a, b, c) => {
285                        let x = if c.is_one() {
286                            vars[&b.0]
287                        } else if c.is_zero() {
288                            ctx.load_zero()
289                        } else {
290                            let tmp = ctx.load_constant(convert_fr(&c));
291                            gate.mul(ctx, vars[&b.0], tmp)
292                        };
293                        vars.insert(a.0, x);
294                    }
295                    DslIr::MulF(a, b, c) => {
296                        let x = f_chip.mul(ctx, felts[&b.0], felts[&c.0]);
297                        felts.insert(a.0, x);
298                    }
299                    DslIr::MulFI(a, b, c) => {
300                        let x = if c.is_one() {
301                            felts[&b.0]
302                        } else if c.is_zero() {
303                            f_chip.load_constant(ctx, BabyBear::ZERO)
304                        } else {
305                            let tmp = f_chip.load_constant(ctx, c);
306                            f_chip.mul(ctx, felts[&b.0], tmp)
307                        };
308                        felts.insert(a.0, x);
309                    }
310                    DslIr::MulE(a, b, c) => {
311                        let x = ext_chip.mul(ctx, exts[&b.0], exts[&c.0]);
312                        exts.insert(a.0, x);
313                    }
314                    DslIr::MulEI(a, b, c) => {
315                        let x = if c.is_one() {
316                            exts[&b.0]
317                        } else if c.is_zero() {
318                            ext_chip.load_constant(ctx, BabyBearExt4::ZERO)
319                        } else {
320                            let tmp = ext_chip.load_constant(ctx, c);
321                            ext_chip.mul(ctx, exts[&b.0], tmp)
322                        };
323                        exts.insert(a.0, x);
324                    }
325                    DslIr::MulEF(a, b, c) => {
326                        let x = ext_chip.scalar_mul(ctx, exts[&b.0], felts[&c.0]);
327                        exts.insert(a.0, x);
328                    }
329                    DslIr::MulEFI(a, b, c) => {
330                        let x = if c.is_one() {
331                            exts[&b.0]
332                        } else if c.is_zero() {
333                            ext_chip.load_constant(ctx, BabyBearExt4::ZERO)
334                        } else {
335                            let tmp = f_chip.load_constant(ctx, c);
336                            ext_chip.scalar_mul(ctx, exts[&b.0], tmp)
337                        };
338                        exts.insert(a.0, x);
339                    }
340                    DslIr::DivF(a, b, c) => {
341                        let x = f_chip.div(ctx, felts[&b.0], felts[&c.0]);
342                        felts.insert(a.0, x);
343                    }
344                    DslIr::DivFIN(a, b, c) => {
345                        // a = b / c
346                        let tmp = f_chip.load_constant(ctx, b);
347                        let x = if b.is_zero() {
348                            tmp
349                        } else {
350                            f_chip.div(ctx, tmp, felts[&c.0])
351                        };
352                        felts.insert(a.0, x);
353                    }
354                    DslIr::DivE(a, b, c) => {
355                        let x = ext_chip.div(ctx, exts[&b.0], exts[&c.0]);
356                        exts.insert(a.0, x);
357                    }
358                    DslIr::DivEIN(a, b, c) => {
359                        let tmp = ext_chip.load_constant(ctx, b);
360                        let x = if b.is_zero() {
361                            tmp
362                        } else {
363                            ext_chip.div(ctx, tmp, exts[&c.0])
364                        };
365                        exts.insert(a.0, x);
366                    }
367                    DslIr::NegE(a, b) => {
368                        let x = ext_chip.neg(ctx, exts[&b.0]);
369                        exts.insert(a.0, x);
370                    }
371                    DslIr::CastFV(a, b) => {
372                        let felt = felts[&b.0];
373                        let reduced_felt = f_chip.reduce(ctx, felt);
374                        vars.insert(a.0, reduced_felt.value);
375                    }
376                    DslIr::CircuitNum2BitsF(value, output) => {
377                        let val = f_chip.reduce(ctx, felts[&value.0]);
378                        let x = gate.num_to_bits(ctx, val.value, 32); // C::F::bits());
379                        assert!(output.len() <= x.len());
380                        for (o, x) in output.into_iter().zip(x) {
381                            vars.insert(o.0, x);
382                        }
383                    }
384                    DslIr::CircuitVarTo64BitsF(value, output) => {
385                        let x = vars[&value.0];
386                        let limbs = var_to_u64_limbs(ctx, &range, gate, x);
387                        for (o, l) in output.into_iter().zip(limbs) {
388                            felts.insert(o.0, l);
389                        }
390                    }
391                    DslIr::CircuitPoseidon2Permute(state_vars) => {
392                        let mut state =
393                            Poseidon2State::<Fr, POSEIDON2_T>::new(state_vars.map(|x| vars[&x.0]));
394                        state.permutation(ctx, gate, &*POSEIDON2_PARAMS);
395                        for i in 0..POSEIDON2_T {
396                            *vars.get_mut(&state_vars[i].0).unwrap() = state.s[i];
397                        }
398                    }
399                    DslIr::CircuitSelectV(cond, a, b, out) => {
400                        let x = gate.select(ctx, vars[&a.0], vars[&b.0], vars[&cond.0]);
401                        vars.insert(out.0, x);
402                    }
403                    DslIr::CircuitSelectF(cond, a, b, out) => {
404                        let x = f_chip.select(ctx, vars[&cond.0], felts[&a.0], felts[&b.0]);
405                        felts.insert(out.0, x);
406                    }
407                    DslIr::CircuitSelectE(cond, a, b, out) => {
408                        let x = ext_chip.select(ctx, vars[&cond.0], exts[&a.0], exts[&b.0]);
409                        exts.insert(out.0, x);
410                    }
411                    DslIr::CircuitExt2Felt(a, b) => {
412                        for (i, x) in a.iter().enumerate() {
413                            felts.insert(x.0, exts[&b.0].0[i]);
414                        }
415                    }
416                    DslIr::AssertEqV(a, b) => {
417                        ctx.constrain_equal(&vars[&a.0], &vars[&b.0]);
418                    }
419                    DslIr::AssertEqVI(a, b) => {
420                        gate.assert_is_const(ctx, &vars[&a.0], &convert_fr(&b));
421                    }
422                    DslIr::AssertEqF(a, b) => {
423                        f_chip.assert_equal(ctx, felts[&a.0], felts[&b.0]);
424                    }
425                    DslIr::AssertEqFI(a, b) => {
426                        if b.is_zero() {
427                            f_chip.assert_zero(ctx, felts[&a.0]);
428                        } else {
429                            let tmp = f_chip.load_constant(ctx, b);
430                            f_chip.assert_equal(ctx, felts[&a.0], tmp);
431                        }
432                    }
433                    DslIr::AssertEqE(a, b) => {
434                        ext_chip.assert_equal(ctx, exts[&a.0], exts[&b.0]);
435                    }
436                    DslIr::AssertEqEI(a, b) => {
437                        // Note: we could check if each coordinate of `b` is zero separately for a little more efficiency,
438                        // but omitting to simplify the code
439                        if b.is_zero() {
440                            ext_chip.assert_zero(ctx, exts[&a.0]);
441                        } else {
442                            let tmp = ext_chip.load_constant(ctx, b);
443                            ext_chip.assert_equal(ctx, exts[&a.0], tmp);
444                        }
445                    }
446                    DslIr::PrintV(a) => {
447                        println!("PrintV: {:?}", vars[&a.0].value());
448                    }
449                    DslIr::PrintF(a) => {
450                        println!("PrintF: {:?}", felts[&a.0].to_baby_bear());
451                    }
452                    DslIr::PrintE(a) => {
453                        println!("PrintE:");
454                        for x in exts[&a.0].0.iter() {
455                            println!("{:?}", x.to_baby_bear());
456                        }
457                    }
458                    DslIr::WitnessVar(a, b) => {
459                        let x = ctx.load_witness(halo2_state.vars[&b]);
460                        vars.insert(a.0, x);
461                    }
462                    DslIr::WitnessFelt(a, b) => {
463                        let x = f_chip.load_witness(ctx, halo2_state.felts[&b]);
464                        felts.insert(a.0, x);
465                    }
466                    DslIr::WitnessExt(a, b) => {
467                        let x = ext_chip.load_witness(ctx, halo2_state.exts[&b]);
468                        exts.insert(a.0, x);
469                    }
470                    DslIr::CircuitFelts2Ext(a, b) => {
471                        let x = AssignedBabyBearExt4(
472                            a.iter()
473                                .map(|a| felts[&a.0])
474                                .collect_vec()
475                                .try_into()
476                                .unwrap(),
477                        );
478                        exts.insert(b.0, x);
479                    }
480                    DslIr::CircuitFeltReduce(a) => {
481                        let x = f_chip.reduce_max_bits(ctx, felts[&a.0]);
482                        felts.insert(a.0, x);
483                    }
484                    DslIr::CircuitExtReduce(a) => {
485                        let x = ext_chip.reduce_max_bits(ctx, exts[&a.0]);
486                        exts.insert(a.0, x);
487                    }
488                    DslIr::CircuitLessThan(a, b) => {
489                        range.range_check(ctx, vars[&a.0], C::F::bits());
490                        range.range_check(ctx, vars[&b.0], C::F::bits());
491                        range.check_less_than(ctx, vars[&a.0], vars[&b.0], C::F::bits());
492                    }
493                    DslIr::CycleTrackerStart(_name) => {
494                        #[cfg(feature = "bench-metrics")]
495                        cell_tracker.start(_name);
496                    }
497                    DslIr::CycleTrackerEnd(_name) => {
498                        #[cfg(feature = "bench-metrics")]
499                        cell_tracker.end(_name);
500                    }
501                    DslIr::CircuitPublish(val, index) => {
502                        public_values[index] = vars[&val.0];
503                    }
504                    _ => panic!("unsupported {:?}", instruction),
505                }
506            }));
507            if res.is_err() {
508                if let Some(mut backtrace) = backtrace {
509                    backtrace.resolve();
510                    eprintln!("openvm circuit failure; backtrace:\n{:?}", backtrace);
511                }
512                res.unwrap();
513            }
514            #[cfg(feature = "bench-metrics")]
515            if self.profiling {
516                let mut new_stats = stats_snapshot(ctx, range.clone());
517                new_stats.diff(&old_stats);
518                new_stats.increment(cell_tracker.get_full_name());
519            }
520        }
521
522        halo2_state.builder.assigned_instances = vec![public_values];
523    }
524}
525
526/// Assumes F is Bn254 Fr and converts to halo2 Fr type
527pub fn convert_fr<F: PrimeField>(a: &F) -> Fr {
528    biguint_to_fe(&a.as_canonical_biguint())
529}
530
531#[allow(dead_code)]
532pub fn convert_efr<F: PrimeField, EF: ExtensionField<F>>(a: &EF) -> Vec<Fr> {
533    let slc = a.as_base_slice();
534    slc.iter()
535        .map(|x| biguint_to_fe(&x.as_canonical_biguint()))
536        .collect()
537}
538
539// Unfortunately `builder.statistics()` cannot be called when `ctx` exists.
540#[allow(dead_code)] // used only in bench-metrics
541fn stats_snapshot(ctx: &Context<Fr>, range_chip: Arc<RangeChip<Fr>>) -> Halo2Stats {
542    Halo2Stats {
543        total_gate_cell: ctx.advice.len(),
544        // Note[Xinding]: this is inaccurate because of duplicated constants. But it's too slow if we always
545        // check for duplicates.
546        total_fixed: ctx.copy_manager.lock().unwrap().constant_equalities.len(),
547        total_lookup_cell: range_chip.lookup_manager()[0].total_rows(),
548    }
549}
550
551#[allow(dead_code)]
552fn is_babybear_ir<C: Config>(ir: &DslIr<C>) -> bool {
553    matches!(
554        ir,
555        DslIr::ImmF(_, _)
556            | DslIr::AddF(_, _, _)
557            | DslIr::AddFI(_, _, _)
558            | DslIr::SubF(_, _, _)
559            | DslIr::MulF(_, _, _)
560            | DslIr::MulFI(_, _, _)
561            | DslIr::DivFIN(_, _, _)
562            | DslIr::CircuitSelectF(_, _, _, _)
563            | DslIr::AssertEqF(_, _)
564            | DslIr::AssertEqFI(_, _)
565            | DslIr::WitnessFelt(_, _)
566            | DslIr::CircuitFelts2Ext(_, _)
567            | DslIr::CircuitFeltReduce(_)
568            | DslIr::CircuitExtReduce(_)
569            | DslIr::CircuitLessThan(..)
570            | DslIr::ImmE(_, _)
571            | DslIr::AddE(_, _, _)
572            | DslIr::AddEF(_, _, _)
573            | DslIr::AddEFI(_, _, _)
574            | DslIr::AddEI(_, _, _)
575            | DslIr::AddEFFI(_, _, _)
576            | DslIr::SubE(_, _, _)
577            | DslIr::SubEF(_, _, _)
578            | DslIr::SubEI(_, _, _)
579            | DslIr::SubEIN(_, _, _)
580            | DslIr::SubEFI(_, _, _)
581            | DslIr::MulE(_, _, _)
582            | DslIr::MulEI(_, _, _)
583            | DslIr::MulEF(_, _, _)
584            | DslIr::MulEFI(_, _, _)
585            | DslIr::DivE(_, _, _)
586            | DslIr::DivEIN(_, _, _)
587            | DslIr::NegE(_, _)
588            | DslIr::CircuitSelectE(_, _, _, _)
589            | DslIr::AssertEqE(_, _)
590            | DslIr::AssertEqEI(_, _)
591            | DslIr::WitnessExt(_, _)
592            | DslIr::CastFV(_, _)
593    )
594}
595
596fn fr_to_u64_limbs(fr: &Fr) -> [u64; 4] {
597    // We need 64-bit limbs but `decompose_fe_to_u64_limbs` only support `bit_len < 64`.
598    let limbs = decompose_fe_to_u64_limbs(fr, 8, 32);
599    std::array::from_fn(|i| limbs[2 * i] + limbs[2 * i + 1] * (1 << 32))
600}
601
602fn var_to_u64_limbs(
603    ctx: &mut Context<Fr>,
604    range: &RangeChip<Fr>,
605    gate: &GateChip<Fr>,
606    x: AssignedValue<Fr>,
607) -> [AssignedBabyBear; 4] {
608    let limbs = fr_to_u64_limbs(x.value()).map(|limb| ctx.load_witness(Fr::from(limb)));
609    let factors = [
610        Fr::from([1, 0, 0, 0]),
611        Fr::from([0, 1, 0, 0]),
612        Fr::from([0, 0, 1, 0]),
613        Fr::from([0, 0, 0, 1]),
614    ];
615    let sum = gate.inner_product(ctx, limbs, factors.map(QuantumCell::Constant));
616    ctx.constrain_equal(&sum, &x);
617    let fr_bound_limbs = fr_to_u64_limbs(&(Fr::ZERO - Fr::ONE));
618    let ret = std::array::from_fn(|i| {
619        let limb = limbs[i];
620        let bits = if i < 3 {
621            range.range_check(ctx, limb, 64);
622            64
623        } else {
624            range.check_less_than_safe(ctx, limbs[3], fr_bound_limbs[3] + 1);
625            (Fr::NUM_BITS - 3 * 64) as usize
626        };
627        AssignedBabyBear {
628            value: limb,
629            max_bits: bits,
630        }
631    });
632    // Constraint decomposition doesn't overflow.
633    // Whether limbs[i] == fr_bound_limbs[i] so far
634    let mut on_bound = gate.is_equal(
635        ctx,
636        limbs[3],
637        QuantumCell::Constant(Fr::from(fr_bound_limbs[3])),
638    );
639    for i in (0..3).rev() {
640        // limbs[i] > fr_bound_limbs[i]
641        let li_gt_bd = range.is_less_than(
642            ctx,
643            QuantumCell::Constant(Fr::from(fr_bound_limbs[i])),
644            limbs[i],
645            64,
646        );
647        let li_out_bd = gate.add(ctx, on_bound, li_gt_bd);
648        // on_bound  li_gt_bd  result
649        //    1         1       fail
650        //    1         0       pass
651        //    0         1       pass
652        //    0         0       pass
653        gate.assert_bit(ctx, li_out_bd);
654        // Update on_bound except the last limb
655        if i > 0 {
656            debug_assert_ne!(fr_bound_limbs[i], 0, "This should never happen for Bn254Fr");
657            // on_bound && limbs[i] - fr_bound_limbs[i] == 0
658            let diff = gate.sub_mul(
659                ctx,
660                QuantumCell::Constant(Fr::from(fr_bound_limbs[i])),
661                on_bound,
662                limbs[i],
663            );
664            on_bound = gate.is_zero(ctx, diff);
665        }
666    }
667    ret
668}
669
/// Checks the native limb decomposition that `var_to_u64_limbs` loads as its witness.
// TODO: also exercise the in-circuit constraints of `var_to_u64_limbs` with a mock prover.
#[test]
fn test_var_to_u64_limbs() {
    assert_eq!(fr_to_u64_limbs(&Fr::ZERO), [0u64; 4]);
    // A value filling the whole first limb must not spill into the second one.
    assert_eq!(fr_to_u64_limbs(&Fr::from(u64::MAX)), [u64::MAX, 0, 0, 0]);
    // Round trip through the raw little-endian limb constructor.
    let limbs = [1u64, 2, 3, 4];
    assert_eq!(fr_to_u64_limbs(&Fr::from(limbs)), limbs);
    // Largest field element: the bound used by the in-circuit overflow check.
    assert_eq!(
        fr_to_u64_limbs(&(Fr::ZERO - Fr::ONE)),
        [
            0x43e1f593f0000000,
            0x2833e84879b97091,
            0xb85045b68181585d,
            0x30644e72e131a029,
        ]
    );
}