openvm_algebra_circuit/fp2_chip/muldiv.rs

1use std::{
2    cell::RefCell,
3    rc::Rc,
4    sync::{Arc, Mutex},
5};
6
7use openvm_algebra_transpiler::Fp2Opcode;
8use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory};
9use openvm_circuit_derive::InstructionExecutor;
10use openvm_circuit_primitives::var_range::{
11    SharedVariableRangeCheckerChip, VariableRangeCheckerBus,
12};
13use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
14use openvm_mod_circuit_builder::{
15    ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, SymbolicExpr,
16};
17use openvm_rv32_adapters::Rv32VecHeapAdapterChip;
18use openvm_stark_backend::p3_field::PrimeField32;
19
20use crate::Fp2;
21
22// Input: Fp2 * 2
23// Output: Fp2
24#[derive(Chip, ChipUsageGetter, InstructionExecutor)]
25pub struct Fp2MulDivChip<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize>(
26    pub  VmChipWrapper<
27        F,
28        Rv32VecHeapAdapterChip<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
29        FieldExpressionCoreChip,
30    >,
31);
32
33impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize>
34    Fp2MulDivChip<F, BLOCKS, BLOCK_SIZE>
35{
36    pub fn new(
37        adapter: Rv32VecHeapAdapterChip<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
38        config: ExprBuilderConfig,
39        offset: usize,
40        range_checker: SharedVariableRangeCheckerChip,
41        offline_memory: Arc<Mutex<OfflineMemory<F>>>,
42    ) -> Self {
43        let (expr, is_mul_flag, is_div_flag) = fp2_muldiv_expr(config, range_checker.bus());
44        let core = FieldExpressionCoreChip::new(
45            expr,
46            offset,
47            vec![
48                Fp2Opcode::MUL as usize,
49                Fp2Opcode::DIV as usize,
50                Fp2Opcode::SETUP_MULDIV as usize,
51            ],
52            vec![is_mul_flag, is_div_flag],
53            range_checker,
54            "Fp2MulDiv",
55            false,
56        );
57        Self(VmChipWrapper::new(adapter, core, offline_memory))
58    }
59}
60
61pub fn fp2_muldiv_expr(
62    config: ExprBuilderConfig,
63    range_bus: VariableRangeCheckerBus,
64) -> (FieldExpr, usize, usize) {
65    config.check_valid();
66    let builder = ExprBuilder::new(config, range_bus.range_max_bits);
67    let builder = Rc::new(RefCell::new(builder));
68
69    let x = Fp2::new(builder.clone());
70    let mut y = Fp2::new(builder.clone());
71    let is_mul_flag = builder.borrow_mut().new_flag();
72    let is_div_flag = builder.borrow_mut().new_flag();
73    let (z_idx, mut z) = Fp2::new_var(builder.clone());
74
75    let mut lvar = Fp2::select(is_mul_flag, &x, &z);
76
77    let mut rvar = Fp2::select(is_mul_flag, &z, &x);
78    let fp2_constraint = lvar.mul(&mut y).sub(&mut rvar);
79    // When it's SETUP op, the constraints is z * y - x = 0, it still works as:
80    // x.c0 = x.c1 = p == 0, y.c0 = y.c1 = 0, so whatever z is, z * 0 - 0 = 0
81
82    z.save_output();
83    builder
84        .borrow_mut()
85        .set_constraint(z_idx.0, fp2_constraint.c0.expr);
86    builder
87        .borrow_mut()
88        .set_constraint(z_idx.1, fp2_constraint.c1.expr);
89
90    // Compute expression has to be done manually at the SymbolicExpr level.
91    // Otherwise it saves the quotient and introduces new variables.
92    let compute_z0_div = (&x.c0.expr * &y.c0.expr + &x.c1.expr * &y.c1.expr)
93        / (&y.c0.expr * &y.c0.expr + &y.c1.expr * &y.c1.expr);
94    let compute_z0_mul = &x.c0.expr * &y.c0.expr - &x.c1.expr * &y.c1.expr;
95    let compute_z0 = SymbolicExpr::Select(
96        is_mul_flag,
97        Box::new(compute_z0_mul),
98        Box::new(SymbolicExpr::Select(
99            is_div_flag,
100            Box::new(compute_z0_div),
101            Box::new(x.c0.expr.clone()),
102        )),
103    );
104    let compute_z1_div = (&x.c1.expr * &y.c0.expr - &x.c0.expr * &y.c1.expr)
105        / (&y.c0.expr * &y.c0.expr + &y.c1.expr * &y.c1.expr);
106    let compute_z1_mul = &x.c1.expr * &y.c0.expr + &x.c0.expr * &y.c1.expr;
107    let compute_z1 = SymbolicExpr::Select(
108        is_mul_flag,
109        Box::new(compute_z1_mul),
110        Box::new(SymbolicExpr::Select(
111            is_div_flag,
112            Box::new(compute_z1_div),
113            Box::new(x.c1.expr),
114        )),
115    );
116    builder.borrow_mut().set_compute(z_idx.0, compute_z0);
117    builder.borrow_mut().set_compute(z_idx.1, compute_z1);
118
119    let builder = builder.borrow().clone();
120    (
121        FieldExpr::new(builder, range_bus, true),
122        is_mul_flag,
123        is_div_flag,
124    )
125}
126
#[cfg(test)]
mod tests {

    use halo2curves_axiom::{bn256::Fq2, ff::Field};
    use itertools::Itertools;
    use openvm_algebra_transpiler::Fp2Opcode;
    use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS};
    use openvm_circuit_primitives::bitwise_op_lookup::{
        BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
    };
    use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode};
    use openvm_mod_circuit_builder::{
        test_utils::{biguint_to_limbs, bn254_fq2_to_biguint_vec, bn254_fq_to_biguint},
        ExprBuilderConfig,
    };
    use openvm_pairing_guest::bn254::BN254_MODULUS;
    use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip};
    use openvm_stark_backend::p3_field::FieldAlgebra;
    use openvm_stark_sdk::p3_baby_bear::BabyBear;
    use rand::{rngs::StdRng, SeedableRng};

    use super::Fp2MulDivChip;

    const NUM_LIMBS: usize = 32;
    const LIMB_BITS: usize = 8;
    type F = BabyBear;

    /// Checks the expression natively against halo2curves' BN254 Fq2 arithmetic, then runs
    /// SETUP, MUL and DIV instructions through the chip and verifies the resulting trace.
    #[test]
    fn test_fp2_muldiv() {
        let mut tester: VmChipTestBuilder<F> = VmChipTestBuilder::default();
        let modulus = BN254_MODULUS.clone();
        let config = ExprBuilderConfig {
            modulus: modulus.clone(),
            num_limbs: NUM_LIMBS,
            limb_bits: LIMB_BITS,
        };

        let bitwise_chip = SharedBitwiseOperationLookupChip::<RV32_CELL_BITS>::new(
            BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS),
        );
        let adapter = Rv32VecHeapAdapterChip::<F, 2, 2, 2, NUM_LIMBS, NUM_LIMBS>::new(
            tester.execution_bus(),
            tester.program_bus(),
            tester.memory_bridge(),
            tester.address_bits(),
            bitwise_chip.clone(),
        );
        let mut chip = Fp2MulDivChip::new(
            adapter,
            config,
            Fp2Opcode::CLASS_OFFSET,
            tester.range_checker(),
            tester.offline_memory_mutex_arc(),
        );
        assert_eq!(
            chip.0.core.expr().builder.num_variables,
            2,
            "Fp2MulDiv should only introduce new z Fp2 variable (2 Fp var)"
        );

        // Native evaluation: (flags, expected output) for MUL and then DIV.
        let mut rng = StdRng::seed_from_u64(42);
        let x = Fq2::random(&mut rng);
        let y = Fq2::random(&mut rng);
        let inputs = [x.c0, x.c1, y.c0, y.c1].map(bn254_fq_to_biguint);

        let cases = [
            (vec![true, false], bn254_fq2_to_biguint_vec(x * y)),
            (
                vec![false, true],
                bn254_fq2_to_biguint_vec(x * y.invert().unwrap()),
            ),
        ];
        for (flags, expected) in cases {
            let actual = chip
                .0
                .core
                .expr()
                .execute_with_output(inputs.to_vec(), flags);
            assert_eq!(actual.len(), 2);
            assert_eq!(actual[0], expected[0]);
            assert_eq!(actual[1], expected[1]);
        }

        // Encode each coordinate as NUM_LIMBS limbs of LIMB_BITS bits.
        let to_limbs = |range: std::ops::Range<usize>| {
            inputs[range]
                .iter()
                .map(|value| {
                    biguint_to_limbs::<NUM_LIMBS>(value.clone(), LIMB_BITS)
                        .map(BabyBear::from_canonical_u32)
                })
                .collect_vec()
        };
        let x_limbs = to_limbs(0..2);
        let y_limbs = to_limbs(2..4);
        let modulus =
            biguint_to_limbs::<NUM_LIMBS>(modulus, LIMB_BITS).map(BabyBear::from_canonical_u32);
        let zero = [BabyBear::ZERO; NUM_LIMBS];

        let offset = chip.0.core.air.offset;
        let setup_instruction = rv32_write_heap_default(
            &mut tester,
            vec![modulus, zero],
            vec![zero; 2],
            offset + Fp2Opcode::SETUP_MULDIV as usize,
        );
        let mul_instruction = rv32_write_heap_default(
            &mut tester,
            x_limbs.clone(),
            y_limbs.clone(),
            offset + Fp2Opcode::MUL as usize,
        );
        let div_instruction = rv32_write_heap_default(
            &mut tester,
            x_limbs,
            y_limbs,
            offset + Fp2Opcode::DIV as usize,
        );

        // SETUP must run first; then MUL and DIV.
        for instruction in [&setup_instruction, &mul_instruction, &div_instruction] {
            tester.execute(&mut chip, instruction);
        }

        let tester = tester.build().load(chip).load(bitwise_chip).finalize();
        tester.simple_test().expect("Verification failed");
    }
}
251}