1use derive_more::derive::From;
2use num_bigint::BigUint;
3use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
4use openvm_circuit::{
5 self,
6 arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError},
7 system::phantom::PhantomChip,
8};
9use openvm_circuit_derive::{AnyEnum, InstructionExecutor};
10use openvm_circuit_primitives::bitwise_op_lookup::{
11 BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
12};
13use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
14use openvm_instructions::{LocalOpcode, VmOpcode};
15use openvm_mod_circuit_builder::ExprBuilderConfig;
16use openvm_rv32_adapters::{Rv32IsEqualModAdapterChip, Rv32VecHeapAdapterChip};
17use openvm_stark_backend::p3_field::PrimeField32;
18use serde::{Deserialize, Serialize};
19use serde_with::{serde_as, DisplayFromStr};
20use strum::EnumCount;
21
22use crate::modular_chip::{
23 ModularAddSubChip, ModularIsEqualChip, ModularIsEqualCoreChip, ModularMulDivChip,
24};
25
/// VM extension that registers RV32 modular-arithmetic chips for a list of moduli.
///
/// When built, each entry of `supported_modulus` is assigned its own block of
/// `Rv32ModularArithmeticOpcode` opcodes. Moduli needing at most 32 bytes get the 32-byte
/// chip variants and moduli needing at most 48 bytes get the 48-byte variants. Building
/// panics ("Modulus too large") for anything larger.
#[serde_as]
#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
pub struct ModularExtension {
    /// Moduli to support, in opcode-block order: modulus `i` is assigned the `i`-th block.
    /// Serialized through the `Display`/`FromStr` impls of `BigUint` (`DisplayFromStr`),
    /// i.e. as a list of strings.
    #[serde_as(as = "Vec<DisplayFromStr>")]
    pub supported_modulus: Vec<BigUint>,
}
32
/// Executor chips that `ModularExtension::build` may register.
///
/// There is one variant per (operation kind, element size) pair. `build` uses the `*_32`
/// variants for moduli that fit in 32 bytes and the `*_48` variants for moduli that fit
/// in 48 bytes.
#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From)]
pub enum ModularExtensionExecutor<F: PrimeField32> {
    /// Add/sub chip for moduli of at most 32 bytes.
    /// NOTE(review): generics `1, 32` presumably mean 1 block of 32 bytes — confirm.
    ModularAddSubRv32_32(ModularAddSubChip<F, 1, 32>),
    /// Mul/div chip for moduli of at most 32 bytes.
    ModularMulDivRv32_32(ModularMulDivChip<F, 1, 32>),
    /// Is-equal chip for moduli of at most 32 bytes.
    /// NOTE(review): the trailing `32` presumably is the total limb count — confirm.
    ModularIsEqualRv32_32(ModularIsEqualChip<F, 1, 32, 32>),
    /// Add/sub chip for moduli of 33..=48 bytes.
    /// NOTE(review): generics `3, 16` presumably mean 3 blocks of 16 bytes — confirm.
    ModularAddSubRv32_48(ModularAddSubChip<F, 3, 16>),
    /// Mul/div chip for moduli of 33..=48 bytes.
    ModularMulDivRv32_48(ModularMulDivChip<F, 3, 16>),
    /// Is-equal chip for moduli of 33..=48 bytes.
    ModularIsEqualRv32_48(ModularIsEqualChip<F, 3, 16, 48>),
}
44
/// Non-executor (periphery) chips that this extension can contribute to the VM inventory.
#[derive(ChipUsageGetter, Chip, AnyEnum, From)]
pub enum ModularExtensionPeriphery<F: PrimeField32> {
    /// Shared bitwise-operation lookup chip. `ModularExtension::build` adds one only when
    /// the builder does not already contain a `SharedBitwiseOperationLookupChip<8>`.
    BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>),
    /// Phantom chip variant.
    /// NOTE(review): `ModularExtension::build` never constructs this variant; presumably
    /// kept so the periphery type matches other extensions — confirm before removing.
    Phantom(PhantomChip<F>),
}
51
52impl<F: PrimeField32> VmExtension<F> for ModularExtension {
53 type Executor = ModularExtensionExecutor<F>;
54 type Periphery = ModularExtensionPeriphery<F>;
55
56 fn build(
57 &self,
58 builder: &mut VmInventoryBuilder<F>,
59 ) -> Result<VmInventory<Self::Executor, Self::Periphery>, VmInventoryError> {
60 let mut inventory = VmInventory::new();
61 let SystemPort {
62 execution_bus,
63 program_bus,
64 memory_bridge,
65 } = builder.system_port();
66 let range_checker = builder.system_base().range_checker_chip.clone();
67 let bitwise_lu_chip = if let Some(&chip) = builder
68 .find_chip::<SharedBitwiseOperationLookupChip<8>>()
69 .first()
70 {
71 chip.clone()
72 } else {
73 let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx());
74 let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus);
75 inventory.add_periphery_chip(chip.clone());
76 chip
77 };
78 let offline_memory = builder.system_base().offline_memory();
79 let address_bits = builder.system_config().memory_config.pointer_max_bits;
80
81 let addsub_opcodes = (Rv32ModularArithmeticOpcode::ADD as usize)
82 ..=(Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize);
83 let muldiv_opcodes = (Rv32ModularArithmeticOpcode::MUL as usize)
84 ..=(Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize);
85 let iseq_opcodes = (Rv32ModularArithmeticOpcode::IS_EQ as usize)
86 ..=(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize);
87
88 for (i, modulus) in self.supported_modulus.iter().enumerate() {
89 let bytes = modulus.bits().div_ceil(8);
91 let start_offset =
92 Rv32ModularArithmeticOpcode::CLASS_OFFSET + i * Rv32ModularArithmeticOpcode::COUNT;
93
94 let config32 = ExprBuilderConfig {
95 modulus: modulus.clone(),
96 num_limbs: 32,
97 limb_bits: 8,
98 };
99 let config48 = ExprBuilderConfig {
100 modulus: modulus.clone(),
101 num_limbs: 48,
102 limb_bits: 8,
103 };
104 let adapter_chip_32 = Rv32VecHeapAdapterChip::new(
105 execution_bus,
106 program_bus,
107 memory_bridge,
108 address_bits,
109 bitwise_lu_chip.clone(),
110 );
111 let adapter_chip_48 = Rv32VecHeapAdapterChip::new(
112 execution_bus,
113 program_bus,
114 memory_bridge,
115 address_bits,
116 bitwise_lu_chip.clone(),
117 );
118
119 if bytes <= 32 {
120 let addsub_chip = ModularAddSubChip::new(
121 adapter_chip_32.clone(),
122 config32.clone(),
123 start_offset,
124 range_checker.clone(),
125 offline_memory.clone(),
126 );
127 inventory.add_executor(
128 ModularExtensionExecutor::ModularAddSubRv32_32(addsub_chip),
129 addsub_opcodes
130 .clone()
131 .map(|x| VmOpcode::from_usize(x + start_offset)),
132 )?;
133 let muldiv_chip = ModularMulDivChip::new(
134 adapter_chip_32.clone(),
135 config32.clone(),
136 start_offset,
137 range_checker.clone(),
138 offline_memory.clone(),
139 );
140 inventory.add_executor(
141 ModularExtensionExecutor::ModularMulDivRv32_32(muldiv_chip),
142 muldiv_opcodes
143 .clone()
144 .map(|x| VmOpcode::from_usize(x + start_offset)),
145 )?;
146 let isequal_chip = ModularIsEqualChip::new(
147 Rv32IsEqualModAdapterChip::new(
148 execution_bus,
149 program_bus,
150 memory_bridge,
151 address_bits,
152 bitwise_lu_chip.clone(),
153 ),
154 ModularIsEqualCoreChip::new(
155 modulus.clone(),
156 bitwise_lu_chip.clone(),
157 start_offset,
158 ),
159 offline_memory.clone(),
160 );
161 inventory.add_executor(
162 ModularExtensionExecutor::ModularIsEqualRv32_32(isequal_chip),
163 iseq_opcodes
164 .clone()
165 .map(|x| VmOpcode::from_usize(x + start_offset)),
166 )?;
167 } else if bytes <= 48 {
168 let addsub_chip = ModularAddSubChip::new(
169 adapter_chip_48.clone(),
170 config48.clone(),
171 start_offset,
172 range_checker.clone(),
173 offline_memory.clone(),
174 );
175 inventory.add_executor(
176 ModularExtensionExecutor::ModularAddSubRv32_48(addsub_chip),
177 addsub_opcodes
178 .clone()
179 .map(|x| VmOpcode::from_usize(x + start_offset)),
180 )?;
181 let muldiv_chip = ModularMulDivChip::new(
182 adapter_chip_48.clone(),
183 config48.clone(),
184 start_offset,
185 range_checker.clone(),
186 offline_memory.clone(),
187 );
188 inventory.add_executor(
189 ModularExtensionExecutor::ModularMulDivRv32_48(muldiv_chip),
190 muldiv_opcodes
191 .clone()
192 .map(|x| VmOpcode::from_usize(x + start_offset)),
193 )?;
194 let isequal_chip = ModularIsEqualChip::new(
195 Rv32IsEqualModAdapterChip::new(
196 execution_bus,
197 program_bus,
198 memory_bridge,
199 address_bits,
200 bitwise_lu_chip.clone(),
201 ),
202 ModularIsEqualCoreChip::new(
203 modulus.clone(),
204 bitwise_lu_chip.clone(),
205 start_offset,
206 ),
207 offline_memory.clone(),
208 );
209 inventory.add_executor(
210 ModularExtensionExecutor::ModularIsEqualRv32_48(isequal_chip),
211 iseq_opcodes
212 .clone()
213 .map(|x| VmOpcode::from_usize(x + start_offset)),
214 )?;
215 } else {
216 panic!("Modulus too large");
217 }
218 }
219
220 Ok(inventory)
221 }
222}