// openvm_algebra_circuit/modular_chip/muldiv.rs

1use std::{cell::RefCell, rc::Rc};
2
3use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
4use openvm_circuit::{
5    arch::ExecutionBridge,
6    system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper},
7};
8use openvm_circuit_primitives::{
9    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
10    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
11};
12use openvm_instructions::riscv::RV32_CELL_BITS;
13use openvm_mod_circuit_builder::{
14    ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, FieldExpressionExecutor,
15    FieldExpressionFiller, FieldVariable, SymbolicExpr,
16};
17use openvm_rv32_adapters::{
18    Rv32VecHeapAdapterAir, Rv32VecHeapAdapterExecutor, Rv32VecHeapAdapterFiller,
19};
20
21use super::{ModularAir, ModularChip, ModularExecutor};
22use crate::FieldExprVecHeapExecutor;
23
24pub fn muldiv_expr(
25    config: ExprBuilderConfig,
26    range_bus: VariableRangeCheckerBus,
27) -> (FieldExpr, usize, usize) {
28    config.check_valid();
29    let builder = ExprBuilder::new(config, range_bus.range_max_bits);
30    let builder = Rc::new(RefCell::new(builder));
31    let x = ExprBuilder::new_input(builder.clone());
32    let y = ExprBuilder::new_input(builder.clone());
33    let (z_idx, z) = (*builder).borrow_mut().new_var();
34    let mut z = FieldVariable::from_var(builder.clone(), z);
35    let is_mul_flag = (*builder).borrow_mut().new_flag();
36    let is_div_flag = (*builder).borrow_mut().new_flag();
37    // constraint is x * y = z, or z * y = x
38    let lvar = FieldVariable::select(is_mul_flag, &x, &z);
39    let rvar = FieldVariable::select(is_mul_flag, &z, &x);
40    // When it's SETUP op, x = p == 0, y = 0, both flags are false, and it still works: z * 0 - x =
41    // 0, whatever z is.
42    let constraint = lvar * y.clone() - rvar;
43    (*builder)
44        .borrow_mut()
45        .set_constraint(z_idx, constraint.expr);
46    let compute = SymbolicExpr::Select(
47        is_mul_flag,
48        Box::new(x.expr.clone() * y.expr.clone()),
49        Box::new(SymbolicExpr::Select(
50            is_div_flag,
51            Box::new(x.expr.clone() / y.expr.clone()),
52            Box::new(x.expr.clone()),
53        )),
54    );
55    (*builder).borrow_mut().set_compute(z_idx, compute);
56    z.save_output();
57
58    let builder = (*builder).borrow().clone();
59
60    (
61        FieldExpr::new(builder, range_bus, true),
62        is_mul_flag,
63        is_div_flag,
64    )
65}
66
67fn gen_base_expr(
68    config: ExprBuilderConfig,
69    range_checker_bus: VariableRangeCheckerBus,
70) -> (FieldExpr, Vec<usize>, Vec<usize>) {
71    let (expr, is_mul_flag, is_div_flag) = muldiv_expr(config, range_checker_bus);
72
73    let local_opcode_idx = vec![
74        Rv32ModularArithmeticOpcode::MUL as usize,
75        Rv32ModularArithmeticOpcode::DIV as usize,
76        Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize,
77    ];
78    let opcode_flag_idx = vec![is_mul_flag, is_div_flag];
79
80    (expr, local_opcode_idx, opcode_flag_idx)
81}
82
83pub fn get_modular_muldiv_air<const BLOCKS: usize, const BLOCK_SIZE: usize>(
84    exec_bridge: ExecutionBridge,
85    mem_bridge: MemoryBridge,
86    config: ExprBuilderConfig,
87    range_checker_bus: VariableRangeCheckerBus,
88    bitwise_lookup_bus: BitwiseOperationLookupBus,
89    pointer_max_bits: usize,
90    offset: usize,
91) -> ModularAir<BLOCKS, BLOCK_SIZE> {
92    let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker_bus);
93    ModularAir::new(
94        Rv32VecHeapAdapterAir::new(
95            exec_bridge,
96            mem_bridge,
97            bitwise_lookup_bus,
98            pointer_max_bits,
99        ),
100        FieldExpressionCoreAir::new(expr, offset, local_opcode_idx, opcode_flag_idx),
101    )
102}
103
104pub fn get_modular_muldiv_step<const BLOCKS: usize, const BLOCK_SIZE: usize>(
105    config: ExprBuilderConfig,
106    range_checker_bus: VariableRangeCheckerBus,
107    pointer_max_bits: usize,
108    offset: usize,
109) -> ModularExecutor<BLOCKS, BLOCK_SIZE> {
110    let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker_bus);
111
112    FieldExprVecHeapExecutor(FieldExpressionExecutor::new(
113        Rv32VecHeapAdapterExecutor::new(pointer_max_bits),
114        expr,
115        offset,
116        local_opcode_idx,
117        opcode_flag_idx,
118        "ModularMulDiv",
119    ))
120}
121
122pub fn get_modular_muldiv_chip<F, const BLOCKS: usize, const BLOCK_SIZE: usize>(
123    config: ExprBuilderConfig,
124    mem_helper: SharedMemoryHelper<F>,
125    range_checker: SharedVariableRangeCheckerChip,
126    bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
127    pointer_max_bits: usize,
128) -> ModularChip<F, BLOCKS, BLOCK_SIZE> {
129    let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker.bus());
130    ModularChip::new(
131        FieldExpressionFiller::new(
132            Rv32VecHeapAdapterFiller::new(pointer_max_bits, bitwise_lookup_chip),
133            expr,
134            local_opcode_idx,
135            opcode_flag_idx,
136            range_checker,
137            false,
138        ),
139        mem_helper,
140    )
141}