openvm_native_compiler/constraints/halo2/
compiler.rs

1use std::{
2    collections::HashMap,
3    fmt::Debug,
4    marker::PhantomData,
5    panic::{catch_unwind, AssertUnwindSafe},
6    sync::{Arc, LazyLock},
7};
8
9use itertools::Itertools;
10#[cfg(feature = "metrics")]
11use openvm_circuit::metrics::cycle_tracker::CycleTracker;
12use openvm_stark_backend::p3_field::{ExtensionField, Field, FieldAlgebra, PrimeField};
13use openvm_stark_sdk::{p3_baby_bear::BabyBear, p3_bn254_fr::Bn254Fr};
14use snark_verifier_sdk::snark_verifier::{
15    halo2_base::{
16        gates::{
17            circuit::builder::BaseCircuitBuilder, GateChip, GateInstructions, RangeChip,
18            RangeInstructions,
19        },
20        halo2_proofs::halo2curves::bn256::Fr,
21        utils::{biguint_to_fe, decompose_fe_to_u64_limbs, ScalarField},
22        AssignedValue, Context, QuantumCell,
23    },
24    util::arithmetic::{Field as _, PrimeField as _},
25};
26
27use super::stats::Halo2Stats;
28use crate::{
29    constraints::halo2::{
30        baby_bear::{
31            AssignedBabyBear, AssignedBabyBearExt4, BabyBearChip, BabyBearExt4, BabyBearExt4Chip,
32        },
33        poseidon2_perm::{Poseidon2Params, Poseidon2State},
34    },
35    ir::{Config, DslIr, TracedVec, Witness},
36};
37
38const POSEIDON2_T: usize = 3;
39static POSEIDON2_PARAMS: LazyLock<Poseidon2Params<Fr, POSEIDON2_T>> = LazyLock::new(|| {
40    use zkhash::{
41        ark_ff::{BigInteger, PrimeField as _},
42        fields::bn256::FpBN256 as ark_FpBN256,
43        poseidon2::poseidon2_instance_bn256::{MAT_DIAG3_M_1, RC3},
44    };
45
46    fn convert_fr(input: ark_FpBN256) -> Fr {
47        Fr::from_bytes_le(&input.into_bigint().to_bytes_le())
48    }
49    const T: usize = 3;
50    let rounds_f = 8;
51    let rounds_p = 56;
52    let mut round_constants: Vec<[Fr; T]> = RC3
53        .iter()
54        .map(|vec| {
55            vec.iter()
56                .cloned()
57                .map(convert_fr)
58                .collect::<Vec<_>>()
59                .try_into()
60                .unwrap()
61        })
62        .collect();
63
64    let rounds_f_beginning = rounds_f / 2;
65    let p_end = rounds_f_beginning + rounds_p;
66    let internal_round_constants = round_constants
67        .drain(rounds_f_beginning..p_end)
68        .map(|vec| vec[0])
69        .collect::<Vec<_>>();
70    let external_round_constants = round_constants;
71    Poseidon2Params {
72        rounds_f,
73        rounds_p,
74        mat_internal_diag_m_1: MAT_DIAG3_M_1
75            .iter()
76            .copied()
77            .map(convert_fr)
78            .collect_vec()
79            .try_into()
80            .unwrap(),
81        external_rc: external_round_constants,
82        internal_rc: internal_round_constants,
83    }
84});
85
86/// The backend for the Halo2 constraint compiler.
87#[derive(Debug, Clone)]
88pub struct Halo2ConstraintCompiler<C: Config> {
89    pub num_public_values: usize,
90    #[allow(unused_variables)]
91    pub profiling: bool,
92    pub phantom: PhantomData<C>,
93}
94
95#[derive(Debug, Clone, Default)]
96pub struct Halo2State<C: Config> {
97    // halo2 stuff
98    pub builder: BaseCircuitBuilder<Fr>,
99    // Unassigned values: provided by the prover outside of constraint compiler
100    // map from name as string to halo2 assigned value
101    pub vars: HashMap<u32, Fr>,
102    pub felts: HashMap<u32, C::F>,
103    pub exts: HashMap<u32, C::EF>,
104}
105
106impl<C: Config> Halo2State<C> {
107    pub fn load_witness(&mut self, witness: Witness<C>) {
108        for (i, x) in witness.vars.iter().enumerate() {
109            self.vars.insert(i as u32, convert_fr(x));
110        }
111        for (i, x) in witness.felts.into_iter().enumerate() {
112            self.felts.insert(i as u32, x);
113        }
114        for (i, x) in witness.exts.into_iter().enumerate() {
115            self.exts.insert(i as u32, x);
116        }
117    }
118}
119
120impl<C: Config + Debug> Halo2ConstraintCompiler<C> {
121    pub fn new(num_public_values: usize) -> Self {
122        Self {
123            num_public_values,
124            profiling: false,
125            phantom: PhantomData,
126        }
127    }
128    pub fn with_profiling(mut self) -> Self {
129        self.profiling = true;
130        self
131    }
132    // Create halo2-lib constraints from a list of operations in the DSL.
133    // Assume: C::N = C::F = C::EF is type Fr
134    pub fn constrain_halo2(&self, halo2_state: &mut Halo2State<C>, operations: TracedVec<DslIr<C>>)
135    where
136        C: Config<N = Bn254Fr, F = BabyBear, EF = BabyBearExt4>,
137    {
138        #[cfg(feature = "metrics")]
139        let mut cell_tracker = CycleTracker::new();
140        let range = Arc::new(halo2_state.builder.range_chip());
141        let f_chip = Arc::new(BabyBearChip::new(range.clone()));
142        let ext_chip = BabyBearExt4Chip::new(Arc::clone(&f_chip));
143        let gate = f_chip.gate();
144        let ctx = halo2_state.builder.main(0);
145        let mut public_values = vec![ctx.load_zero(); self.num_public_values];
146
147        // Local variables for referencing during the course of constraint building
148        let mut vars = HashMap::new();
149        let mut felts = HashMap::<u32, AssignedBabyBear>::new();
150        let mut exts = HashMap::<u32, AssignedBabyBearExt4>::new();
151
152        #[cfg(feature = "metrics")]
153        let mut old_stats = stats_snapshot(ctx, range.clone());
154        for (instruction, backtrace) in operations {
155            #[cfg(feature = "metrics")]
156            if self.profiling {
157                old_stats = stats_snapshot(ctx, range.clone());
158            }
159            let res = catch_unwind(AssertUnwindSafe(|| {
160                match instruction {
161                    DslIr::ImmV(a, b) => {
162                        let x = ctx.load_constant(convert_fr(&b));
163                        vars.insert(a.0, x);
164                    }
165                    DslIr::ImmF(a, b) => {
166                        let x = f_chip.load_constant(ctx, b);
167                        felts.insert(a.0, x);
168                    }
169                    DslIr::ImmE(a, b) => {
170                        let x = ext_chip.load_constant(ctx, b);
171                        exts.insert(a.0, x);
172                    }
173                    DslIr::AddV(a, b, c) => {
174                        let x = gate.add(ctx, vars[&b.0], vars[&c.0]);
175                        vars.insert(a.0, x);
176                    }
177                    DslIr::AddVI(a, b, c) => {
178                        let x = if c.is_zero() {
179                            vars[&b.0]
180                        } else {
181                            let tmp = ctx.load_constant(convert_fr(&c));
182                            gate.add(ctx, vars[&b.0], tmp)
183                        };
184                        vars.insert(a.0, x);
185                    }
186                    DslIr::AddF(a, b, c) => {
187                        let x = f_chip.add(ctx, felts[&b.0], felts[&c.0]);
188                        felts.insert(a.0, x);
189                    }
190                    DslIr::AddFI(a, b, c) => {
191                        let x = if c.is_zero() {
192                            felts[&b.0]
193                        } else {
194                            let tmp = f_chip.load_constant(ctx, c);
195                            f_chip.add(ctx, felts[&b.0], tmp)
196                        };
197                        felts.insert(a.0, x);
198                    }
199                    DslIr::AddE(a, b, c) => {
200                        let x = ext_chip.add(ctx, exts[&b.0], exts[&c.0]);
201                        exts.insert(a.0, x);
202                    }
203                    DslIr::AddEF(a, b, c) => {
204                        let mut x = exts[&b.0];
205                        x.0[0] = f_chip.add(ctx, x.0[0], felts[&c.0]);
206                        exts.insert(a.0, x);
207                    }
208                    DslIr::AddEFI(a, b, c) => {
209                        let x = if c.is_zero() {
210                            exts[&b.0]
211                        } else {
212                            let tmp = f_chip.load_constant(ctx, c);
213                            let mut x = exts[&b.0];
214                            x.0[0] = f_chip.add(ctx, x.0[0], tmp);
215                            x
216                        };
217                        exts.insert(a.0, x);
218                    }
219                    DslIr::AddEI(a, b, c) => {
220                        let x = if c.is_zero() {
221                            exts[&b.0]
222                        } else {
223                            let tmp = ext_chip.load_constant(ctx, c);
224                            ext_chip.add(ctx, exts[&b.0], tmp)
225                        };
226                        exts.insert(a.0, x);
227                    }
228                    DslIr::AddEFFI(a, b, c) => {
229                        let mut x = ext_chip.load_constant(ctx, c);
230                        x.0[0] = f_chip.add(ctx, x.0[0], felts[&b.0]);
231                        exts.insert(a.0, x);
232                    }
233                    DslIr::SubV(a, b, c) => {
234                        let x = gate.sub(ctx, vars[&b.0], vars[&c.0]);
235                        vars.insert(a.0, x);
236                    }
237                    DslIr::SubF(a, b, c) => {
238                        let x = f_chip.sub(ctx, felts[&b.0], felts[&c.0]);
239                        felts.insert(a.0, x);
240                    }
241                    DslIr::SubE(a, b, c) => {
242                        let x = ext_chip.sub(ctx, exts[&b.0], exts[&c.0]);
243                        exts.insert(a.0, x);
244                    }
245                    DslIr::SubEF(a, b, c) => {
246                        let mut x = exts[&b.0];
247                        x.0[0] = f_chip.sub(ctx, x.0[0], felts[&c.0]);
248                        exts.insert(a.0, x);
249                    }
250                    DslIr::SubEI(a, b, c) => {
251                        let x = if c.is_zero() {
252                            exts[&b.0]
253                        } else {
254                            let tmp = ext_chip.load_constant(ctx, c);
255                            ext_chip.sub(ctx, exts[&b.0], tmp)
256                        };
257                        exts.insert(a.0, x);
258                    }
259                    DslIr::SubVIN(a, b, c) => {
260                        let tmp = ctx.load_constant(convert_fr(&b));
261                        let x = gate.sub(ctx, tmp, vars[&c.0]);
262                        vars.insert(a.0, x);
263                    }
264                    DslIr::SubEIN(a, b, c) => {
265                        let tmp = ext_chip.load_constant(ctx, b);
266                        let x = ext_chip.sub(ctx, tmp, exts[&c.0]);
267                        exts.insert(a.0, x);
268                    }
269                    DslIr::SubEFI(a, b, c) => {
270                        let x = if c.is_zero() {
271                            exts[&b.0]
272                        } else {
273                            let tmp = f_chip.load_constant(ctx, c);
274                            let mut x = exts[&b.0];
275                            x.0[0] = f_chip.sub(ctx, x.0[0], tmp);
276                            x
277                        };
278                        exts.insert(a.0, x);
279                    }
280                    DslIr::MulV(a, b, c) => {
281                        let x = gate.mul(ctx, vars[&b.0], vars[&c.0]);
282                        vars.insert(a.0, x);
283                    }
284                    DslIr::MulVI(a, b, c) => {
285                        let x = if c.is_one() {
286                            vars[&b.0]
287                        } else if c.is_zero() {
288                            ctx.load_zero()
289                        } else {
290                            let tmp = ctx.load_constant(convert_fr(&c));
291                            gate.mul(ctx, vars[&b.0], tmp)
292                        };
293                        vars.insert(a.0, x);
294                    }
295                    DslIr::MulF(a, b, c) => {
296                        let x = f_chip.mul(ctx, felts[&b.0], felts[&c.0]);
297                        felts.insert(a.0, x);
298                    }
299                    DslIr::MulFI(a, b, c) => {
300                        let x = if c.is_one() {
301                            felts[&b.0]
302                        } else if c.is_zero() {
303                            f_chip.load_constant(ctx, BabyBear::ZERO)
304                        } else {
305                            let tmp = f_chip.load_constant(ctx, c);
306                            f_chip.mul(ctx, felts[&b.0], tmp)
307                        };
308                        felts.insert(a.0, x);
309                    }
310                    DslIr::MulE(a, b, c) => {
311                        let x = ext_chip.mul(ctx, exts[&b.0], exts[&c.0]);
312                        exts.insert(a.0, x);
313                    }
314                    DslIr::MulEI(a, b, c) => {
315                        let x = if c.is_one() {
316                            exts[&b.0]
317                        } else if c.is_zero() {
318                            ext_chip.load_constant(ctx, BabyBearExt4::ZERO)
319                        } else {
320                            let tmp = ext_chip.load_constant(ctx, c);
321                            ext_chip.mul(ctx, exts[&b.0], tmp)
322                        };
323                        exts.insert(a.0, x);
324                    }
325                    DslIr::MulEF(a, b, c) => {
326                        let x = ext_chip.scalar_mul(ctx, exts[&b.0], felts[&c.0]);
327                        exts.insert(a.0, x);
328                    }
329                    DslIr::MulEFI(a, b, c) => {
330                        let x = if c.is_one() {
331                            exts[&b.0]
332                        } else if c.is_zero() {
333                            ext_chip.load_constant(ctx, BabyBearExt4::ZERO)
334                        } else {
335                            let tmp = f_chip.load_constant(ctx, c);
336                            ext_chip.scalar_mul(ctx, exts[&b.0], tmp)
337                        };
338                        exts.insert(a.0, x);
339                    }
340                    DslIr::DivF(a, b, c) => {
341                        let x = f_chip.div(ctx, felts[&b.0], felts[&c.0]);
342                        felts.insert(a.0, x);
343                    }
344                    DslIr::DivFIN(a, b, c) => {
345                        // a = b / c
346                        let tmp = f_chip.load_constant(ctx, b);
347                        let x = if b.is_zero() {
348                            tmp
349                        } else {
350                            f_chip.div(ctx, tmp, felts[&c.0])
351                        };
352                        felts.insert(a.0, x);
353                    }
354                    DslIr::DivE(a, b, c) => {
355                        let x = ext_chip.div(ctx, exts[&b.0], exts[&c.0]);
356                        exts.insert(a.0, x);
357                    }
358                    DslIr::DivEIN(a, b, c) => {
359                        let tmp = ext_chip.load_constant(ctx, b);
360                        let x = if b.is_zero() {
361                            tmp
362                        } else {
363                            ext_chip.div(ctx, tmp, exts[&c.0])
364                        };
365                        exts.insert(a.0, x);
366                    }
367                    DslIr::NegE(a, b) => {
368                        let x = ext_chip.neg(ctx, exts[&b.0]);
369                        exts.insert(a.0, x);
370                    }
371                    DslIr::CastFV(a, b) => {
372                        let felt = felts[&b.0];
373                        let reduced_felt = f_chip.reduce(ctx, felt);
374                        vars.insert(a.0, reduced_felt.value);
375                    }
376                    DslIr::CircuitNum2BitsF(value, output) => {
377                        let val = f_chip.reduce(ctx, felts[&value.0]);
378                        let x = gate.num_to_bits(ctx, val.value, 32); // C::F::bits());
379                        assert!(output.len() <= x.len());
380                        for (o, x) in output.into_iter().zip(x) {
381                            vars.insert(o.0, x);
382                        }
383                    }
384                    DslIr::CircuitVarTo64BitsF(value, output) => {
385                        let x = vars[&value.0];
386                        let limbs = var_to_u64_limbs(ctx, &range, gate, x);
387                        for (o, l) in output.into_iter().zip(limbs) {
388                            felts.insert(o.0, l);
389                        }
390                    }
391                    DslIr::CircuitPoseidon2Permute(state_vars) => {
392                        let mut state =
393                            Poseidon2State::<Fr, POSEIDON2_T>::new(state_vars.map(|x| vars[&x.0]));
394                        state.permutation(ctx, gate, &*POSEIDON2_PARAMS);
395                        for i in 0..POSEIDON2_T {
396                            *vars.get_mut(&state_vars[i].0).unwrap() = state.s[i];
397                        }
398                    }
399                    DslIr::CircuitSelectV(cond, a, b, out) => {
400                        let x = gate.select(ctx, vars[&a.0], vars[&b.0], vars[&cond.0]);
401                        vars.insert(out.0, x);
402                    }
403                    DslIr::CircuitSelectF(cond, a, b, out) => {
404                        let x = f_chip.select(ctx, vars[&cond.0], felts[&a.0], felts[&b.0]);
405                        felts.insert(out.0, x);
406                    }
407                    DslIr::CircuitSelectE(cond, a, b, out) => {
408                        let x = ext_chip.select(ctx, vars[&cond.0], exts[&a.0], exts[&b.0]);
409                        exts.insert(out.0, x);
410                    }
411                    DslIr::CircuitExt2Felt(a, b) => {
412                        for (i, x) in a.iter().enumerate() {
413                            felts.insert(x.0, exts[&b.0].0[i]);
414                        }
415                    }
416                    DslIr::AssertEqV(a, b) => {
417                        ctx.constrain_equal(&vars[&a.0], &vars[&b.0]);
418                    }
419                    DslIr::AssertEqVI(a, b) => {
420                        gate.assert_is_const(ctx, &vars[&a.0], &convert_fr(&b));
421                    }
422                    DslIr::AssertEqF(a, b) => {
423                        f_chip.assert_equal(ctx, felts[&a.0], felts[&b.0]);
424                    }
425                    DslIr::AssertEqFI(a, b) => {
426                        if b.is_zero() {
427                            f_chip.assert_zero(ctx, felts[&a.0]);
428                        } else {
429                            let tmp = f_chip.load_constant(ctx, b);
430                            f_chip.assert_equal(ctx, felts[&a.0], tmp);
431                        }
432                    }
433                    DslIr::AssertEqE(a, b) => {
434                        ext_chip.assert_equal(ctx, exts[&a.0], exts[&b.0]);
435                    }
436                    DslIr::AssertEqEI(a, b) => {
437                        // Note: we could check if each coordinate of `b` is zero separately for a
438                        // little more efficiency, but omitting to simplify
439                        // the code
440                        if b.is_zero() {
441                            ext_chip.assert_zero(ctx, exts[&a.0]);
442                        } else {
443                            let tmp = ext_chip.load_constant(ctx, b);
444                            ext_chip.assert_equal(ctx, exts[&a.0], tmp);
445                        }
446                    }
447                    DslIr::PrintV(a) => {
448                        println!("PrintV: {:?}", vars[&a.0].value());
449                    }
450                    DslIr::PrintF(a) => {
451                        println!("PrintF: {:?}", felts[&a.0].to_baby_bear());
452                    }
453                    DslIr::PrintE(a) => {
454                        println!("PrintE:");
455                        for x in exts[&a.0].0.iter() {
456                            println!("{:?}", x.to_baby_bear());
457                        }
458                    }
459                    DslIr::WitnessVar(a, b) => {
460                        let x = ctx.load_witness(halo2_state.vars[&b]);
461                        vars.insert(a.0, x);
462                    }
463                    DslIr::WitnessFelt(a, b) => {
464                        let x = f_chip.load_witness(ctx, halo2_state.felts[&b]);
465                        felts.insert(a.0, x);
466                    }
467                    DslIr::WitnessExt(a, b) => {
468                        let x = ext_chip.load_witness(ctx, halo2_state.exts[&b]);
469                        exts.insert(a.0, x);
470                    }
471                    DslIr::CircuitFelts2Ext(a, b) => {
472                        let x = AssignedBabyBearExt4(
473                            a.iter()
474                                .map(|a| felts[&a.0])
475                                .collect_vec()
476                                .try_into()
477                                .unwrap(),
478                        );
479                        exts.insert(b.0, x);
480                    }
481                    DslIr::CircuitFeltReduce(a) => {
482                        let x = f_chip.reduce_max_bits(ctx, felts[&a.0]);
483                        felts.insert(a.0, x);
484                    }
485                    DslIr::CircuitExtReduce(a) => {
486                        let x = ext_chip.reduce_max_bits(ctx, exts[&a.0]);
487                        exts.insert(a.0, x);
488                    }
489                    DslIr::CircuitLessThan(a, b) => {
490                        range.range_check(ctx, vars[&a.0], C::F::bits());
491                        range.range_check(ctx, vars[&b.0], C::F::bits());
492                        range.check_less_than(ctx, vars[&a.0], vars[&b.0], C::F::bits());
493                    }
494                    DslIr::CycleTrackerStart(_name) => {
495                        #[cfg(feature = "metrics")]
496                        cell_tracker.start(_name);
497                    }
498                    DslIr::CycleTrackerEnd(_name) => {
499                        #[cfg(feature = "metrics")]
500                        cell_tracker.end(_name);
501                    }
502                    DslIr::CircuitPublish(val, index) => {
503                        public_values[index] = vars[&val.0];
504                    }
505                    _ => panic!("unsupported {:?}", instruction),
506                }
507            }));
508            if res.is_err() {
509                if let Some(mut backtrace) = backtrace {
510                    backtrace.resolve();
511                    eprintln!("openvm circuit failure; backtrace:\n{:?}", backtrace);
512                }
513                res.unwrap();
514            }
515            #[cfg(feature = "metrics")]
516            if self.profiling {
517                let mut new_stats = stats_snapshot(ctx, range.clone());
518                new_stats.diff(&old_stats);
519                new_stats.increment(cell_tracker.get_full_name());
520            }
521        }
522
523        halo2_state.builder.assigned_instances = vec![public_values];
524    }
525}
526
527/// Assumes F is Bn254 Fr and converts to halo2 Fr type
528pub fn convert_fr<F: PrimeField>(a: &F) -> Fr {
529    biguint_to_fe(&a.as_canonical_biguint())
530}
531
532#[allow(dead_code)]
533pub fn convert_efr<F: PrimeField, EF: ExtensionField<F>>(a: &EF) -> Vec<Fr> {
534    let slc = a.as_base_slice();
535    slc.iter()
536        .map(|x| biguint_to_fe(&x.as_canonical_biguint()))
537        .collect()
538}
539
540// Unfortunately `builder.statistics()` cannot be called when `ctx` exists.
541#[allow(dead_code)] // used only in metrics
542fn stats_snapshot(ctx: &Context<Fr>, range_chip: Arc<RangeChip<Fr>>) -> Halo2Stats {
543    Halo2Stats {
544        total_gate_cell: ctx.advice.len(),
545        // Note[Xinding]: this is inaccurate because of duplicated constants. But it's too slow if
546        // we always check for duplicates.
547        total_fixed: ctx.copy_manager.lock().unwrap().constant_equalities.len(),
548        total_lookup_cell: range_chip.lookup_manager()[0].total_rows(),
549    }
550}
551
552#[allow(dead_code)]
553fn is_babybear_ir<C: Config>(ir: &DslIr<C>) -> bool {
554    matches!(
555        ir,
556        DslIr::ImmF(_, _)
557            | DslIr::AddF(_, _, _)
558            | DslIr::AddFI(_, _, _)
559            | DslIr::SubF(_, _, _)
560            | DslIr::MulF(_, _, _)
561            | DslIr::MulFI(_, _, _)
562            | DslIr::DivFIN(_, _, _)
563            | DslIr::CircuitSelectF(_, _, _, _)
564            | DslIr::AssertEqF(_, _)
565            | DslIr::AssertEqFI(_, _)
566            | DslIr::WitnessFelt(_, _)
567            | DslIr::CircuitFelts2Ext(_, _)
568            | DslIr::CircuitFeltReduce(_)
569            | DslIr::CircuitExtReduce(_)
570            | DslIr::CircuitLessThan(..)
571            | DslIr::ImmE(_, _)
572            | DslIr::AddE(_, _, _)
573            | DslIr::AddEF(_, _, _)
574            | DslIr::AddEFI(_, _, _)
575            | DslIr::AddEI(_, _, _)
576            | DslIr::AddEFFI(_, _, _)
577            | DslIr::SubE(_, _, _)
578            | DslIr::SubEF(_, _, _)
579            | DslIr::SubEI(_, _, _)
580            | DslIr::SubEIN(_, _, _)
581            | DslIr::SubEFI(_, _, _)
582            | DslIr::MulE(_, _, _)
583            | DslIr::MulEI(_, _, _)
584            | DslIr::MulEF(_, _, _)
585            | DslIr::MulEFI(_, _, _)
586            | DslIr::DivE(_, _, _)
587            | DslIr::DivEIN(_, _, _)
588            | DslIr::NegE(_, _)
589            | DslIr::CircuitSelectE(_, _, _, _)
590            | DslIr::AssertEqE(_, _)
591            | DslIr::AssertEqEI(_, _)
592            | DslIr::WitnessExt(_, _)
593            | DslIr::CastFV(_, _)
594    )
595}
596
597fn fr_to_u64_limbs(fr: &Fr) -> [u64; 4] {
598    // We need 64-bit limbs but `decompose_fe_to_u64_limbs` only support `bit_len < 64`.
599    let limbs = decompose_fe_to_u64_limbs(fr, 8, 32);
600    std::array::from_fn(|i| limbs[2 * i] + limbs[2 * i + 1] * (1 << 32))
601}
602
603fn var_to_u64_limbs(
604    ctx: &mut Context<Fr>,
605    range: &RangeChip<Fr>,
606    gate: &GateChip<Fr>,
607    x: AssignedValue<Fr>,
608) -> [AssignedBabyBear; 4] {
609    let limbs = fr_to_u64_limbs(x.value()).map(|limb| ctx.load_witness(Fr::from(limb)));
610    let factors = [
611        Fr::from([1, 0, 0, 0]),
612        Fr::from([0, 1, 0, 0]),
613        Fr::from([0, 0, 1, 0]),
614        Fr::from([0, 0, 0, 1]),
615    ];
616    let sum = gate.inner_product(ctx, limbs, factors.map(QuantumCell::Constant));
617    ctx.constrain_equal(&sum, &x);
618    let fr_bound_limbs = fr_to_u64_limbs(&(Fr::ZERO - Fr::ONE));
619    let ret = std::array::from_fn(|i| {
620        let limb = limbs[i];
621        let bits = if i < 3 {
622            range.range_check(ctx, limb, 64);
623            64
624        } else {
625            range.check_less_than_safe(ctx, limbs[3], fr_bound_limbs[3] + 1);
626            (Fr::NUM_BITS - 3 * 64) as usize
627        };
628        AssignedBabyBear {
629            value: limb,
630            max_bits: bits,
631        }
632    });
633    // Constraint decomposition doesn't overflow.
634    // Whether limbs[i] == fr_bound_limbs[i] so far
635    let mut on_bound = gate.is_equal(
636        ctx,
637        limbs[3],
638        QuantumCell::Constant(Fr::from(fr_bound_limbs[3])),
639    );
640    for i in (0..3).rev() {
641        // limbs[i] > fr_bound_limbs[i]
642        let li_gt_bd = range.is_less_than(
643            ctx,
644            QuantumCell::Constant(Fr::from(fr_bound_limbs[i])),
645            limbs[i],
646            64,
647        );
648        let li_out_bd = gate.add(ctx, on_bound, li_gt_bd);
649        // on_bound  li_gt_bd  result
650        //    1         1       fail
651        //    1         0       pass
652        //    0         1       pass
653        //    0         0       pass
654        gate.assert_bit(ctx, li_out_bd);
655        // Update on_bound except the last limb
656        if i > 0 {
657            debug_assert_ne!(fr_bound_limbs[i], 0, "This should never happen for Bn254Fr");
658            // on_bound && limbs[i] - fr_bound_limbs[i] == 0
659            let diff = gate.sub_mul(
660                ctx,
661                QuantumCell::Constant(Fr::from(fr_bound_limbs[i])),
662                on_bound,
663                limbs[i],
664            );
665            on_bound = gate.is_zero(ctx, diff);
666        }
667    }
668    ret
669}
670
#[test]
fn test_var_to_u64_limbs() {
    // TODO: placeholder — this test currently makes no assertions.
}