openvm_algebra_circuit/modular_chip/muldiv.rs

1use std::{
2    cell::RefCell,
3    rc::Rc,
4    sync::{Arc, Mutex},
5};
6
7use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
8use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory};
9use openvm_circuit_derive::InstructionExecutor;
10use openvm_circuit_primitives::var_range::{
11    SharedVariableRangeCheckerChip, VariableRangeCheckerBus,
12};
13use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
14use openvm_mod_circuit_builder::{
15    ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable, SymbolicExpr,
16};
17use openvm_rv32_adapters::Rv32VecHeapAdapterChip;
18use openvm_stark_backend::p3_field::PrimeField32;
19
20pub fn muldiv_expr(
21    config: ExprBuilderConfig,
22    range_bus: VariableRangeCheckerBus,
23) -> (FieldExpr, usize, usize) {
24    config.check_valid();
25    let builder = ExprBuilder::new(config, range_bus.range_max_bits);
26    let builder = Rc::new(RefCell::new(builder));
27    let x = ExprBuilder::new_input(builder.clone());
28    let y = ExprBuilder::new_input(builder.clone());
29    let (z_idx, z) = builder.borrow_mut().new_var();
30    let mut z = FieldVariable::from_var(builder.clone(), z);
31    let is_mul_flag = builder.borrow_mut().new_flag();
32    let is_div_flag = builder.borrow_mut().new_flag();
33    // constraint is x * y = z, or z * y = x
34    let lvar = FieldVariable::select(is_mul_flag, &x, &z);
35    let rvar = FieldVariable::select(is_mul_flag, &z, &x);
36    // When it's SETUP op, x = p == 0, y = 0, both flags are false, and it still works: z * 0 - x =
37    // 0, whatever z is.
38    let constraint = lvar * y.clone() - rvar;
39    builder.borrow_mut().set_constraint(z_idx, constraint.expr);
40    let compute = SymbolicExpr::Select(
41        is_mul_flag,
42        Box::new(x.expr.clone() * y.expr.clone()),
43        Box::new(SymbolicExpr::Select(
44            is_div_flag,
45            Box::new(x.expr.clone() / y.expr.clone()),
46            Box::new(x.expr.clone()),
47        )),
48    );
49    builder.borrow_mut().set_compute(z_idx, compute);
50    z.save_output();
51
52    let builder = builder.borrow().clone();
53
54    (
55        FieldExpr::new(builder, range_bus, true),
56        is_mul_flag,
57        is_div_flag,
58    )
59}
60
/// VM chip for the modular `MUL`, `DIV` and `SETUP_MULDIV` instructions.
///
/// A newtype around [`VmChipWrapper`] that combines an `Rv32VecHeapAdapterChip` with a
/// [`FieldExpressionCoreChip`]. `ModularMulDivChip::new` builds the core from the expression
/// returned by [`muldiv_expr`].
///
/// The `Chip`, `ChipUsageGetter` and `InstructionExecutor` implementations are generated by
/// the derive macros. NOTE(review): presumably they delegate to the wrapped `VmChipWrapper`;
/// confirm against the derive crates.
///
/// - `F`: field type, bounded by `PrimeField32`.
/// - `BLOCKS`, `BLOCK_SIZE`: each appears twice in the adapter's generic parameters.
///   NOTE(review): presumably once for the read side and once for the write side; confirm
///   against `Rv32VecHeapAdapterChip`.
#[derive(Chip, ChipUsageGetter, InstructionExecutor)]
pub struct ModularMulDivChip<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize>(
    // Adapter + field-expression core, wrapped together.
    pub  VmChipWrapper<
        F,
        Rv32VecHeapAdapterChip<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
        FieldExpressionCoreChip,
    >,
);
69
70impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize>
71    ModularMulDivChip<F, BLOCKS, BLOCK_SIZE>
72{
73    pub fn new(
74        adapter: Rv32VecHeapAdapterChip<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
75        config: ExprBuilderConfig,
76        offset: usize,
77        range_checker: SharedVariableRangeCheckerChip,
78        offline_memory: Arc<Mutex<OfflineMemory<F>>>,
79    ) -> Self {
80        let (expr, is_mul_flag, is_div_flag) = muldiv_expr(config, range_checker.bus());
81        let core = FieldExpressionCoreChip::new(
82            expr,
83            offset,
84            vec![
85                Rv32ModularArithmeticOpcode::MUL as usize,
86                Rv32ModularArithmeticOpcode::DIV as usize,
87                Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize,
88            ],
89            vec![is_mul_flag, is_div_flag],
90            range_checker,
91            "ModularMulDiv",
92            false,
93        );
94        Self(VmChipWrapper::new(adapter, core, offline_memory))
95    }
96}