// openvm_algebra_circuit/modular_chip/muldiv.rs
1use std::{
2 cell::RefCell,
3 rc::Rc,
4 sync::{Arc, Mutex},
5};
6
7use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
8use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory};
9use openvm_circuit_derive::InstructionExecutor;
10use openvm_circuit_primitives::var_range::{
11 SharedVariableRangeCheckerChip, VariableRangeCheckerBus,
12};
13use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
14use openvm_mod_circuit_builder::{
15 ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable, SymbolicExpr,
16};
17use openvm_rv32_adapters::Rv32VecHeapAdapterChip;
18use openvm_stark_backend::p3_field::PrimeField32;
19
20pub fn muldiv_expr(
21 config: ExprBuilderConfig,
22 range_bus: VariableRangeCheckerBus,
23) -> (FieldExpr, usize, usize) {
24 config.check_valid();
25 let builder = ExprBuilder::new(config, range_bus.range_max_bits);
26 let builder = Rc::new(RefCell::new(builder));
27 let x = ExprBuilder::new_input(builder.clone());
28 let y = ExprBuilder::new_input(builder.clone());
29 let (z_idx, z) = builder.borrow_mut().new_var();
30 let mut z = FieldVariable::from_var(builder.clone(), z);
31 let is_mul_flag = builder.borrow_mut().new_flag();
32 let is_div_flag = builder.borrow_mut().new_flag();
33 let lvar = FieldVariable::select(is_mul_flag, &x, &z);
35 let rvar = FieldVariable::select(is_mul_flag, &z, &x);
36 let constraint = lvar * y.clone() - rvar;
38 builder.borrow_mut().set_constraint(z_idx, constraint.expr);
39 let compute = SymbolicExpr::Select(
40 is_mul_flag,
41 Box::new(x.expr.clone() * y.expr.clone()),
42 Box::new(SymbolicExpr::Select(
43 is_div_flag,
44 Box::new(x.expr.clone() / y.expr.clone()),
45 Box::new(x.expr.clone()),
46 )),
47 );
48 builder.borrow_mut().set_compute(z_idx, compute);
49 z.save_output();
50
51 let builder = builder.borrow().clone();
52
53 (
54 FieldExpr::new(builder, range_bus, true),
55 is_mul_flag,
56 is_div_flag,
57 )
58}
59
/// Chip executing the modular multiply/divide instruction family.
///
/// A newtype over [`VmChipWrapper`] that combines an `Rv32VecHeapAdapterChip`
/// with a [`FieldExpressionCoreChip`]. `ModularMulDivChip::new` wires the core
/// to the expression from [`muldiv_expr`] for the `MUL`, `DIV` and
/// `SETUP_MULDIV` opcodes.
///
/// `BLOCKS` and `BLOCK_SIZE` are passed to the adapter for both its read side
/// and its write side. The adapter's `2` is presumably the number of operands
/// read (`x` and `y`); confirm against `Rv32VecHeapAdapterChip`.
///
/// The `Chip`, `ChipUsageGetter` and `InstructionExecutor` impls come from
/// derive macros, presumably forwarding to the wrapped chip.
#[derive(Chip, ChipUsageGetter, InstructionExecutor)]
pub struct ModularMulDivChip<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize>(
    /// Adapter + field-expression core pair that performs the instruction.
    pub VmChipWrapper<
        F,
        Rv32VecHeapAdapterChip<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
        FieldExpressionCoreChip,
    >,
);
68
69impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize>
70 ModularMulDivChip<F, BLOCKS, BLOCK_SIZE>
71{
72 pub fn new(
73 adapter: Rv32VecHeapAdapterChip<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
74 config: ExprBuilderConfig,
75 offset: usize,
76 range_checker: SharedVariableRangeCheckerChip,
77 offline_memory: Arc<Mutex<OfflineMemory<F>>>,
78 ) -> Self {
79 let (expr, is_mul_flag, is_div_flag) = muldiv_expr(config, range_checker.bus());
80 let core = FieldExpressionCoreChip::new(
81 expr,
82 offset,
83 vec![
84 Rv32ModularArithmeticOpcode::MUL as usize,
85 Rv32ModularArithmeticOpcode::DIV as usize,
86 Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize,
87 ],
88 vec![is_mul_flag, is_div_flag],
89 range_checker,
90 "ModularMulDiv",
91 false,
92 );
93 Self(VmChipWrapper::new(adapter, core, offline_memory))
94 }
95}