// openvm_rv32im_circuit/adapters/mul.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    marker::PhantomData,
4};
5
6use openvm_circuit::{
7    arch::{
8        AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge,
9        ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip,
10        VmAdapterInterface,
11    },
12    system::{
13        memory::{
14            offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
15            MemoryAddress, MemoryController, OfflineMemory, RecordId,
16        },
17        program::ProgramBus,
18    },
19};
20use openvm_circuit_primitives_derive::AlignedBorrow;
21use openvm_instructions::{
22    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS,
23};
24use openvm_stark_backend::{
25    interaction::InteractionBuilder,
26    p3_air::BaseAir,
27    p3_field::{Field, FieldAlgebra, PrimeField32},
28};
29use serde::{Deserialize, Serialize};
30
31use super::RV32_REGISTER_NUM_LIMBS;
32
33/// Reads instructions of the form OP a, b, c, d where \[a:4\]_d = \[b:4\]_d op \[c:4\]_d.
34/// Operand d can only be 1, and there is no immediate support.
35#[derive(Debug)]
36pub struct Rv32MultAdapterChip<F: Field> {
37    pub air: Rv32MultAdapterAir,
38    _marker: PhantomData<F>,
39}
40
41impl<F: PrimeField32> Rv32MultAdapterChip<F> {
42    pub fn new(
43        execution_bus: ExecutionBus,
44        program_bus: ProgramBus,
45        memory_bridge: MemoryBridge,
46    ) -> Self {
47        Self {
48            air: Rv32MultAdapterAir {
49                execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
50                memory_bridge,
51            },
52            _marker: PhantomData,
53        }
54    }
55}
56
/// Runtime record of the two register reads performed by `preprocess`.
///
/// NOTE: `#[repr(C)]` and serde derives are applied, so keep the field order stable.
#[repr(C)]
#[derive(Debug, Serialize, Deserialize)]
pub struct Rv32MultReadRecord {
    /// Reads from operand registers: record id of the read of the first source register
    /// (instruction operand `b`).
    pub rs1: RecordId,
    /// Record id of the read of the second source register (instruction operand `c`).
    pub rs2: RecordId,
}
64
/// Runtime record produced by `postprocess` and consumed by `generate_trace_row`.
///
/// NOTE: `#[repr(C)]` and serde derives are applied, so keep the field order stable.
#[repr(C)]
#[derive(Debug, Serialize, Deserialize)]
pub struct Rv32MultWriteRecord {
    /// Execution state (pc and timestamp) at the start of the instruction.
    pub from_state: ExecutionState<u32>,
    /// Write to destination register: record id of the write of the result to operand `a`.
    pub rd_id: RecordId,
}
72
/// Trace columns owned by the multiplication adapter.
///
/// NOTE: the struct is `#[repr(C)]` and borrowed directly from a row slice via
/// `AlignedBorrow`, so the field order is the column order; do not reorder fields.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct Rv32MultAdapterCols<T> {
    /// pc and timestamp at which the instruction starts executing.
    pub from_state: ExecutionState<T>,
    /// Pointer of the destination register (instruction operand `a`).
    pub rd_ptr: T,
    /// Pointer of the first source register (instruction operand `b`).
    pub rs1_ptr: T,
    /// Pointer of the second source register (instruction operand `c`).
    pub rs2_ptr: T,
    /// Auxiliary columns for the `rs1` read (index 0) and the `rs2` read (index 1).
    pub reads_aux: [MemoryReadAuxCols<T>; 2],
    /// Auxiliary columns for the write of the result into `rd`.
    pub writes_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
}
83
/// AIR for the RV32 multiplication adapter.
///
/// NOTE: `derive_new::new` takes the fields as parameters in declaration order, so reordering
/// the fields changes the generated constructor's signature.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32MultAdapterAir {
    /// Bridge over the execution and program buses; used for the instruction's state
    /// transition in `eval`.
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge used for the two register reads and the register write in `eval`.
    pub(super) memory_bridge: MemoryBridge,
}
89
90impl<F: Field> BaseAir<F> for Rv32MultAdapterAir {
91    fn width(&self) -> usize {
92        Rv32MultAdapterCols::<F>::width()
93    }
94}
95
/// Constraints for the multiplication adapter: two register reads, one register write and the
/// execution-bus transition, all gated by `ctx.instruction.is_valid`.
impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32MultAdapterAir {
    // Two reads and one write, each of `RV32_REGISTER_NUM_LIMBS` limbs.
    type Interface = BasicAdapterInterface<
        AB::Expr,
        MinimalInstruction<AB::Expr>,
        2,
        1,
        RV32_REGISTER_NUM_LIMBS,
        RV32_REGISTER_NUM_LIMBS,
    >;

    /// Emits the memory and execution interactions for one row.
    ///
    /// `ctx.reads[0]` / `ctx.reads[1]` are tied to the contents of registers `rs1_ptr` /
    /// `rs2_ptr`, and `ctx.writes[0]` is written to `rd_ptr`.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let local: &Rv32MultAdapterCols<_> = local.borrow();
        let timestamp = local.from_state.timestamp;
        // Each call returns the next access timestamp (from_state.timestamp + 0, + 1, + 2, ...);
        // afterwards `timestamp_delta` equals the number of memory accesses made.
        let mut timestamp_delta: usize = 0;
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

        // Read rs1 at timestamp + 0. The order of the three accesses must match the runtime
        // (`preprocess` reads rs1 then rs2, `postprocess` writes rd).
        self.memory_bridge
            .read(
                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rs1_ptr),
                ctx.reads[0].clone(),
                timestamp_pp(),
                &local.reads_aux[0],
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Read rs2 at timestamp + 1.
        self.memory_bridge
            .read(
                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rs2_ptr),
                ctx.reads[1].clone(),
                timestamp_pp(),
                &local.reads_aux[1],
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Write the result to rd at timestamp + 2.
        self.memory_bridge
            .write(
                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rd_ptr),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &local.writes_aux,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Operands appear to be in instruction order (a, b, c, d, e) = (rd, rs1, rs2,
        // RV32_REGISTER_AS, 0), matching how the runtime uses a, b, c, d. Per the method name,
        // pc advances by DEFAULT_PC_STEP unless the core supplies `ctx.to_pc`.
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    local.rd_ptr.into(),
                    local.rs1_ptr.into(),
                    local.rs2_ptr.into(),
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    AB::Expr::ZERO,
                ],
                local.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid);
    }

    /// Returns the `pc` column of `from_state`.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32MultAdapterCols<_> = local.borrow();
        cols.from_state.pc
    }
}
169
170impl<F: PrimeField32> VmAdapterChip<F> for Rv32MultAdapterChip<F> {
171    type ReadRecord = Rv32MultReadRecord;
172    type WriteRecord = Rv32MultWriteRecord;
173    type Air = Rv32MultAdapterAir;
174    type Interface = BasicAdapterInterface<
175        F,
176        MinimalInstruction<F>,
177        2,
178        1,
179        RV32_REGISTER_NUM_LIMBS,
180        RV32_REGISTER_NUM_LIMBS,
181    >;
182
183    fn preprocess(
184        &mut self,
185        memory: &mut MemoryController<F>,
186        instruction: &Instruction<F>,
187    ) -> Result<(
188        <Self::Interface as VmAdapterInterface<F>>::Reads,
189        Self::ReadRecord,
190    )> {
191        let Instruction { b, c, d, .. } = *instruction;
192
193        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
194
195        let rs1 = memory.read::<RV32_REGISTER_NUM_LIMBS>(d, b);
196        let rs2 = memory.read::<RV32_REGISTER_NUM_LIMBS>(d, c);
197
198        Ok((
199            [rs1.1, rs2.1],
200            Self::ReadRecord {
201                rs1: rs1.0,
202                rs2: rs2.0,
203            },
204        ))
205    }
206
207    fn postprocess(
208        &mut self,
209        memory: &mut MemoryController<F>,
210        instruction: &Instruction<F>,
211        from_state: ExecutionState<u32>,
212        output: AdapterRuntimeContext<F, Self::Interface>,
213        _read_record: &Self::ReadRecord,
214    ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
215        let Instruction { a, d, .. } = *instruction;
216        let (rd_id, _) = memory.write(d, a, output.writes[0]);
217
218        let timestamp_delta = memory.timestamp() - from_state.timestamp;
219        debug_assert!(
220            timestamp_delta == 3,
221            "timestamp delta is {}, expected 3",
222            timestamp_delta
223        );
224
225        Ok((
226            ExecutionState {
227                pc: from_state.pc + DEFAULT_PC_STEP,
228                timestamp: memory.timestamp(),
229            },
230            Self::WriteRecord { from_state, rd_id },
231        ))
232    }
233
234    fn generate_trace_row(
235        &self,
236        row_slice: &mut [F],
237        read_record: Self::ReadRecord,
238        write_record: Self::WriteRecord,
239        memory: &OfflineMemory<F>,
240    ) {
241        let aux_cols_factory = memory.aux_cols_factory();
242        let row_slice: &mut Rv32MultAdapterCols<_> = row_slice.borrow_mut();
243        row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);
244        let rd = memory.record_by_id(write_record.rd_id);
245        row_slice.rd_ptr = rd.pointer;
246        let rs1 = memory.record_by_id(read_record.rs1);
247        let rs2 = memory.record_by_id(read_record.rs2);
248        row_slice.rs1_ptr = rs1.pointer;
249        row_slice.rs2_ptr = rs2.pointer;
250        aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]);
251        aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]);
252        aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux);
253    }
254
255    fn air(&self) -> &Self::Air {
256        &self.air
257    }
258}