openvm_rv32im_circuit/adapters/mul.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{
4    arch::{
5        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
6        BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir,
7    },
8    system::memory::{
9        offline_checker::{
10            MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
11            MemoryWriteBytesAuxRecord,
12        },
13        online::TracingMemory,
14        MemoryAddress, MemoryAuxColsFactory,
15    },
16};
17use openvm_circuit_primitives::AlignedBytesBorrow;
18use openvm_circuit_primitives_derive::AlignedBorrow;
19use openvm_instructions::{
20    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS,
21};
22use openvm_stark_backend::{
23    interaction::InteractionBuilder,
24    p3_air::BaseAir,
25    p3_field::{Field, FieldAlgebra, PrimeField32},
26};
27
28use super::{tracing_write, RV32_REGISTER_NUM_LIMBS};
29use crate::adapters::tracing_read;
30
31#[repr(C)]
32#[derive(AlignedBorrow)]
33pub struct Rv32MultAdapterCols<T> {
34    pub from_state: ExecutionState<T>,
35    pub rd_ptr: T,
36    pub rs1_ptr: T,
37    pub rs2_ptr: T,
38    pub reads_aux: [MemoryReadAuxCols<T>; 2],
39    pub writes_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
40}
41
42/// Reads instructions of the form OP a, b, c, d where \[a:4\]_d = \[b:4\]_d op \[c:4\]_d.
43/// Operand d can only be 1, and there is no immediate support.
44#[derive(Clone, Copy, Debug, derive_new::new)]
45pub struct Rv32MultAdapterAir {
46    pub(super) execution_bridge: ExecutionBridge,
47    pub(super) memory_bridge: MemoryBridge,
48}
49
50impl<F: Field> BaseAir<F> for Rv32MultAdapterAir {
51    fn width(&self) -> usize {
52        Rv32MultAdapterCols::<F>::width()
53    }
54}
55
56impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32MultAdapterAir {
57    type Interface = BasicAdapterInterface<
58        AB::Expr,
59        MinimalInstruction<AB::Expr>,
60        2,
61        1,
62        RV32_REGISTER_NUM_LIMBS,
63        RV32_REGISTER_NUM_LIMBS,
64    >;
65
66    fn eval(
67        &self,
68        builder: &mut AB,
69        local: &[AB::Var],
70        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
71    ) {
72        let local: &Rv32MultAdapterCols<_> = local.borrow();
73        let timestamp = local.from_state.timestamp;
74        let mut timestamp_delta: usize = 0;
75        let mut timestamp_pp = || {
76            timestamp_delta += 1;
77            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
78        };
79
80        self.memory_bridge
81            .read(
82                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rs1_ptr),
83                ctx.reads[0].clone(),
84                timestamp_pp(),
85                &local.reads_aux[0],
86            )
87            .eval(builder, ctx.instruction.is_valid.clone());
88
89        self.memory_bridge
90            .read(
91                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rs2_ptr),
92                ctx.reads[1].clone(),
93                timestamp_pp(),
94                &local.reads_aux[1],
95            )
96            .eval(builder, ctx.instruction.is_valid.clone());
97
98        self.memory_bridge
99            .write(
100                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rd_ptr),
101                ctx.writes[0].clone(),
102                timestamp_pp(),
103                &local.writes_aux,
104            )
105            .eval(builder, ctx.instruction.is_valid.clone());
106
107        self.execution_bridge
108            .execute_and_increment_or_set_pc(
109                ctx.instruction.opcode,
110                [
111                    local.rd_ptr.into(),
112                    local.rs1_ptr.into(),
113                    local.rs2_ptr.into(),
114                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
115                    AB::Expr::ZERO,
116                ],
117                local.from_state,
118                AB::F::from_canonical_usize(timestamp_delta),
119                (DEFAULT_PC_STEP, ctx.to_pc),
120            )
121            .eval(builder, ctx.instruction.is_valid);
122    }
123
124    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
125        let cols: &Rv32MultAdapterCols<_> = local.borrow();
126        cols.from_state.pc
127    }
128}
129
130#[repr(C)]
131#[derive(AlignedBytesBorrow, Debug)]
132pub struct Rv32MultAdapterRecord {
133    pub from_pc: u32,
134    pub from_timestamp: u32,
135
136    pub rd_ptr: u32,
137    pub rs1_ptr: u32,
138    pub rs2_ptr: u32,
139
140    pub reads_aux: [MemoryReadAuxRecord; 2],
141    pub writes_aux: MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS>,
142}
143
144#[derive(Clone, Copy, derive_new::new)]
145pub struct Rv32MultAdapterExecutor;
146
147#[derive(Clone, Copy, derive_new::new)]
148pub struct Rv32MultAdapterFiller;
149
150impl<F> AdapterTraceExecutor<F> for Rv32MultAdapterExecutor
151where
152    F: PrimeField32,
153{
154    const WIDTH: usize = size_of::<Rv32MultAdapterCols<u8>>();
155    type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2];
156    type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1];
157    type RecordMut<'a> = &'a mut Rv32MultAdapterRecord;
158
159    #[inline(always)]
160    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
161        record.from_pc = pc;
162        record.from_timestamp = memory.timestamp;
163    }
164
165    #[inline(always)]
166    fn read(
167        &self,
168        memory: &mut TracingMemory,
169        instruction: &Instruction<F>,
170        record: &mut Self::RecordMut<'_>,
171    ) -> Self::ReadData {
172        let &Instruction { b, c, d, .. } = instruction;
173
174        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
175
176        record.rs1_ptr = b.as_canonical_u32();
177        let rs1 = tracing_read(
178            memory,
179            RV32_REGISTER_AS,
180            b.as_canonical_u32(),
181            &mut record.reads_aux[0].prev_timestamp,
182        );
183        record.rs2_ptr = c.as_canonical_u32();
184        let rs2 = tracing_read(
185            memory,
186            RV32_REGISTER_AS,
187            c.as_canonical_u32(),
188            &mut record.reads_aux[1].prev_timestamp,
189        );
190
191        [rs1, rs2]
192    }
193
194    #[inline(always)]
195    fn write(
196        &self,
197        memory: &mut TracingMemory,
198        instruction: &Instruction<F>,
199        data: Self::WriteData,
200        record: &mut Self::RecordMut<'_>,
201    ) {
202        let &Instruction { a, d, .. } = instruction;
203
204        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
205
206        record.rd_ptr = a.as_canonical_u32();
207        tracing_write(
208            memory,
209            RV32_REGISTER_AS,
210            a.as_canonical_u32(),
211            data[0],
212            &mut record.writes_aux.prev_timestamp,
213            &mut record.writes_aux.prev_data,
214        )
215    }
216}
217
218impl<F: PrimeField32> AdapterTraceFiller<F> for Rv32MultAdapterFiller {
219    const WIDTH: usize = size_of::<Rv32MultAdapterCols<u8>>();
220
221    #[inline(always)]
222    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
223        // SAFETY:
224        // - caller ensures `adapter_row` contains a valid record representation that was previously
225        //   written by the executor
226        // - get_record_from_slice correctly interprets the bytes as Rv32MultAdapterRecord
227        let record: &Rv32MultAdapterRecord = unsafe { get_record_from_slice(&mut adapter_row, ()) };
228        let adapter_row: &mut Rv32MultAdapterCols<F> = adapter_row.borrow_mut();
229
230        let timestamp = record.from_timestamp;
231
232        adapter_row
233            .writes_aux
234            .set_prev_data(record.writes_aux.prev_data.map(F::from_canonical_u8));
235        mem_helper.fill(
236            record.writes_aux.prev_timestamp,
237            timestamp + 2,
238            adapter_row.writes_aux.as_mut(),
239        );
240
241        mem_helper.fill(
242            record.reads_aux[1].prev_timestamp,
243            timestamp + 1,
244            adapter_row.reads_aux[1].as_mut(),
245        );
246
247        mem_helper.fill(
248            record.reads_aux[0].prev_timestamp,
249            timestamp,
250            adapter_row.reads_aux[0].as_mut(),
251        );
252
253        adapter_row.rs2_ptr = F::from_canonical_u32(record.rs2_ptr);
254        adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr);
255        adapter_row.rd_ptr = F::from_canonical_u32(record.rd_ptr);
256
257        adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
258        adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc);
259    }
260}