openvm_algebra_circuit/modular_chip/
muldiv.rs1use std::{
2 cell::RefCell,
3 rc::Rc,
4 sync::{Arc, Mutex},
5};
6
7use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
8use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory};
9use openvm_circuit_derive::InstructionExecutor;
10use openvm_circuit_primitives::var_range::{
11 SharedVariableRangeCheckerChip, VariableRangeCheckerBus,
12};
13use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
14use openvm_mod_circuit_builder::{
15 ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable, SymbolicExpr,
16};
17use openvm_rv32_adapters::Rv32VecHeapAdapterChip;
18use openvm_stark_backend::p3_field::PrimeField32;
19
20pub fn muldiv_expr(
21 config: ExprBuilderConfig,
22 range_bus: VariableRangeCheckerBus,
23) -> (FieldExpr, usize, usize) {
24 config.check_valid();
25 let builder = ExprBuilder::new(config, range_bus.range_max_bits);
26 let builder = Rc::new(RefCell::new(builder));
27 let x = ExprBuilder::new_input(builder.clone());
28 let y = ExprBuilder::new_input(builder.clone());
29 let (z_idx, z) = builder.borrow_mut().new_var();
30 let mut z = FieldVariable::from_var(builder.clone(), z);
31 let is_mul_flag = builder.borrow_mut().new_flag();
32 let is_div_flag = builder.borrow_mut().new_flag();
33 let lvar = FieldVariable::select(is_mul_flag, &x, &z);
35 let rvar = FieldVariable::select(is_mul_flag, &z, &x);
36 let constraint = lvar * y.clone() - rvar;
39 builder.borrow_mut().set_constraint(z_idx, constraint.expr);
40 let compute = SymbolicExpr::Select(
41 is_mul_flag,
42 Box::new(x.expr.clone() * y.expr.clone()),
43 Box::new(SymbolicExpr::Select(
44 is_div_flag,
45 Box::new(x.expr.clone() / y.expr.clone()),
46 Box::new(x.expr.clone()),
47 )),
48 );
49 builder.borrow_mut().set_compute(z_idx, compute);
50 z.save_output();
51
52 let builder = builder.borrow().clone();
53
54 (
55 FieldExpr::new(builder, range_bus, true),
56 is_mul_flag,
57 is_div_flag,
58 )
59}
60
/// VM chip for modular multiplication and division.
///
/// A thin newtype over [`VmChipWrapper`] that pairs an [`Rv32VecHeapAdapterChip`] (two
/// inputs, `BLOCKS` blocks of `BLOCK_SIZE` for both reads and writes) with a
/// [`FieldExpressionCoreChip`]. The `Chip`, `ChipUsageGetter` and `InstructionExecutor`
/// impls are generated by derive macros, presumably delegating to the wrapped value
/// (NOTE(review): confirm against the derive crates).
#[derive(Chip, ChipUsageGetter, InstructionExecutor)]
pub struct ModularMulDivChip<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize>(
    pub VmChipWrapper<
        F,
        Rv32VecHeapAdapterChip<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
        FieldExpressionCoreChip,
    >,
);
69
70impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize>
71 ModularMulDivChip<F, BLOCKS, BLOCK_SIZE>
72{
73 pub fn new(
74 adapter: Rv32VecHeapAdapterChip<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
75 config: ExprBuilderConfig,
76 offset: usize,
77 range_checker: SharedVariableRangeCheckerChip,
78 offline_memory: Arc<Mutex<OfflineMemory<F>>>,
79 ) -> Self {
80 let (expr, is_mul_flag, is_div_flag) = muldiv_expr(config, range_checker.bus());
81 let core = FieldExpressionCoreChip::new(
82 expr,
83 offset,
84 vec![
85 Rv32ModularArithmeticOpcode::MUL as usize,
86 Rv32ModularArithmeticOpcode::DIV as usize,
87 Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize,
88 ],
89 vec![is_mul_flag, is_div_flag],
90 range_checker,
91 "ModularMulDiv",
92 false,
93 );
94 Self(VmChipWrapper::new(adapter, core, offline_memory))
95 }
96}