// openvm_algebra_circuit/modular_chip/muldiv.rs

1use std::{
2    cell::RefCell,
3    rc::Rc,
4    sync::{Arc, Mutex},
5};
6
7use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
8use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory};
9use openvm_circuit_derive::InstructionExecutor;
10use openvm_circuit_primitives::var_range::{
11    SharedVariableRangeCheckerChip, VariableRangeCheckerBus,
12};
13use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
14use openvm_mod_circuit_builder::{
15    ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable, SymbolicExpr,
16};
17use openvm_rv32_adapters::Rv32VecHeapAdapterChip;
18use openvm_stark_backend::p3_field::PrimeField32;
19
20pub fn muldiv_expr(
21    config: ExprBuilderConfig,
22    range_bus: VariableRangeCheckerBus,
23) -> (FieldExpr, usize, usize) {
24    config.check_valid();
25    let builder = ExprBuilder::new(config, range_bus.range_max_bits);
26    let builder = Rc::new(RefCell::new(builder));
27    let x = ExprBuilder::new_input(builder.clone());
28    let y = ExprBuilder::new_input(builder.clone());
29    let (z_idx, z) = builder.borrow_mut().new_var();
30    let mut z = FieldVariable::from_var(builder.clone(), z);
31    let is_mul_flag = builder.borrow_mut().new_flag();
32    let is_div_flag = builder.borrow_mut().new_flag();
33    // constraint is x * y = z, or z * y = x
34    let lvar = FieldVariable::select(is_mul_flag, &x, &z);
35    let rvar = FieldVariable::select(is_mul_flag, &z, &x);
36    // When it's SETUP op, x = p == 0, y = 0, both flags are false, and it still works: z * 0 - x = 0, whatever z is.
37    let constraint = lvar * y.clone() - rvar;
38    builder.borrow_mut().set_constraint(z_idx, constraint.expr);
39    let compute = SymbolicExpr::Select(
40        is_mul_flag,
41        Box::new(x.expr.clone() * y.expr.clone()),
42        Box::new(SymbolicExpr::Select(
43            is_div_flag,
44            Box::new(x.expr.clone() / y.expr.clone()),
45            Box::new(x.expr.clone()),
46        )),
47    );
48    builder.borrow_mut().set_compute(z_idx, compute);
49    z.save_output();
50
51    let builder = builder.borrow().clone();
52
53    (
54        FieldExpr::new(builder, range_bus, true),
55        is_mul_flag,
56        is_div_flag,
57    )
58}
59
60#[derive(Chip, ChipUsageGetter, InstructionExecutor)]
61pub struct ModularMulDivChip<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize>(
62    pub  VmChipWrapper<
63        F,
64        Rv32VecHeapAdapterChip<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
65        FieldExpressionCoreChip,
66    >,
67);
68
69impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize>
70    ModularMulDivChip<F, BLOCKS, BLOCK_SIZE>
71{
72    pub fn new(
73        adapter: Rv32VecHeapAdapterChip<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
74        config: ExprBuilderConfig,
75        offset: usize,
76        range_checker: SharedVariableRangeCheckerChip,
77        offline_memory: Arc<Mutex<OfflineMemory<F>>>,
78    ) -> Self {
79        let (expr, is_mul_flag, is_div_flag) = muldiv_expr(config, range_checker.bus());
80        let core = FieldExpressionCoreChip::new(
81            expr,
82            offset,
83            vec![
84                Rv32ModularArithmeticOpcode::MUL as usize,
85                Rv32ModularArithmeticOpcode::DIV as usize,
86                Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize,
87            ],
88            vec![is_mul_flag, is_div_flag],
89            range_checker,
90            "ModularMulDiv",
91            false,
92        );
93        Self(VmChipWrapper::new(adapter, core, offline_memory))
94    }
95}