openvm_rv32im_circuit/loadstore/
core.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::arch::{
4    AdapterAirContext, AdapterRuntimeContext, Result, VmAdapterInterface, VmCoreAir, VmCoreChip,
5};
6use openvm_circuit_primitives_derive::AlignedBorrow;
7use openvm_instructions::{instruction::Instruction, LocalOpcode};
8use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
9use openvm_stark_backend::{
10    interaction::InteractionBuilder,
11    p3_air::{AirBuilder, BaseAir},
12    p3_field::{Field, FieldAlgebra, PrimeField32},
13    rap::BaseAirWithPublicValues,
14};
15use serde::{de::DeserializeOwned, Deserialize, Serialize};
16use serde_big_array::BigArray;
17
18use crate::adapters::LoadStoreInstruction;
19
/// Internal label for every (opcode, shift) pair handled by this chip.
/// The digit appended to each variant name is the shift amount.
///
/// The declaration order is significant: a variant is cast to `usize` and used as an
/// index into the `opcode_flags` vector built in `LoadStoreCoreAir::eval`, and the
/// flag patterns written by `generate_trace_row` follow the same order. Do not reorder
/// variants without updating both places.
#[derive(Debug, Clone, Copy)]
enum InstructionOpcode {
    // Indices 0..4: selected when flags[0], flags[1], flags[2], flags[3] (respectively)
    // is the single flag set to 2.
    LoadW0,
    LoadHu0,
    LoadHu2,
    LoadBu0,
    // Indices 4..8: selected when flags[0], flags[1], flags[2], flags[3] (respectively)
    // is the single flag set to 1.
    LoadBu1,
    LoadBu2,
    LoadBu3,
    StoreW0,
    // Indices 8..14: selected when two distinct flags are both 1, in the pair order
    // (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).
    StoreH0,
    StoreH2,
    StoreB0,
    StoreB1,
    StoreB2,
    StoreB3,
}
37
38use InstructionOpcode::*;
39
/// LoadStore Core Chip handles byte/halfword into word conversions and unsigned extends.
/// This chip uses read_data and prev_data to constrain the write_data.
/// It also handles the shifting in case of not 4 byte aligned instructions.
/// This chip treats each (opcode, shift) pair as a separate instruction.
///
/// The struct is `#[repr(C)]` and borrowed directly from a trace row slice, so the
/// field order defines the column layout and must not be changed.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct LoadStoreCoreCols<T, const NUM_CELLS: usize> {
    /// Four selector cells that jointly encode the (opcode, shift) pair.
    /// The AIR constrains every cell, and the sum of all four, to lie in {0, 1, 2}.
    pub flags: [T; 4],
    /// we need to keep the degree of is_valid and is_load to 1
    ///
    /// Boolean; the AIR forces it to 0 whenever all flags are 0.
    pub is_valid: T,
    /// Constrained to equal the sum of the load-opcode selectors derived from `flags`.
    pub is_load: T,

    /// Data handed to the adapter as the second element of the reads tuple.
    pub read_data: [T; NUM_CELLS],
    /// Data handed to the adapter as the first element of the reads tuple.
    pub prev_data: [T; NUM_CELLS],
    /// write_data will be constrained against read_data and prev_data
    /// depending on the opcode and the shift amount
    pub write_data: [T; NUM_CELLS],
}
58
/// Execution record produced by `execute_instruction` and later consumed by
/// `generate_trace_row` to fill one row of [`LoadStoreCoreCols`].
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Serialize + DeserializeOwned")]
pub struct LoadStoreCoreRecord<F, const NUM_CELLS: usize> {
    /// Local opcode of the executed instruction (global opcode minus the chip offset).
    pub opcode: Rv32LoadStoreOpcode,
    /// Shift amount supplied by the adapter alongside the read data, as a canonical `u32`.
    pub shift: u32,
    /// Second array of the adapter reads (`reads[1]`).
    #[serde(with = "BigArray")]
    pub read_data: [F; NUM_CELLS],
    /// First array of the adapter reads (`reads[0]`).
    #[serde(with = "BigArray")]
    pub prev_data: [F; NUM_CELLS],
    /// Result of `run_write_data` for this opcode, shift, `read_data` and `prev_data`.
    #[serde(with = "BigArray")]
    pub write_data: [F; NUM_CELLS],
}
72
/// AIR for the load/store core chip, generic over the number of cells per word.
#[derive(Debug, Clone)]
pub struct LoadStoreCoreAir<const NUM_CELLS: usize> {
    /// Global opcode offset of this chip's opcode class. It is returned by
    /// `start_offset` and subtracted from global opcodes to obtain the local
    /// `Rv32LoadStoreOpcode` index.
    pub offset: usize,
}
77
impl<F: Field, const NUM_CELLS: usize> BaseAir<F> for LoadStoreCoreAir<NUM_CELLS> {
    /// Number of core trace columns, i.e. the width of [`LoadStoreCoreCols`].
    fn width(&self) -> usize {
        LoadStoreCoreCols::<F, NUM_CELLS>::width()
    }
}
83
// Empty impl: every method falls back to the trait's defaults.
// NOTE(review): presumably this means the AIR exposes no public values — confirm
// against the `BaseAirWithPublicValues` definition.
impl<F: Field, const NUM_CELLS: usize> BaseAirWithPublicValues<F> for LoadStoreCoreAir<NUM_CELLS> {}
85
impl<AB, I, const NUM_CELLS: usize> VmCoreAir<AB, I> for LoadStoreCoreAir<NUM_CELLS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<([AB::Var; NUM_CELLS], [AB::Expr; NUM_CELLS])>,
    I::Writes: From<[[AB::Expr; NUM_CELLS]; 1]>,
    I::ProcessedInstruction: From<LoadStoreInstruction<AB::Expr>>,
{
    /// Emits the core constraints for one row and returns the context for the adapter AIR.
    ///
    /// The constraints (a) restrict `flags` to a valid encoding of one (opcode, shift)
    /// pair, (b) tie `is_valid` and `is_load` to that encoding, and (c) force every
    /// `write_data` cell to equal the value selected from `read_data`, `prev_data` or
    /// zero for that pair. The returned context carries the reads `(prev_data, read_data)`,
    /// the single write `write_data`, and the instruction description (validity, global
    /// opcode expression, load flag and the load/store shift amounts).
    ///
    /// `_from_pc` is unused by this chip.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LoadStoreCoreCols<AB::Var, NUM_CELLS> = (*local_core).borrow();
        let LoadStoreCoreCols::<AB::Var, NUM_CELLS> {
            read_data,
            prev_data,
            write_data,
            flags,
            is_valid,
            is_load,
        } = *cols;

        // (x - 1)(x - 2): vanishes exactly when x is 1 or 2. Multiplied by x it
        // vanishes exactly when x is 0, 1 or 2.
        let get_expr_12 = |x: &AB::Expr| (x.clone() - AB::Expr::ONE) * (x.clone() - AB::Expr::TWO);

        builder.assert_bool(is_valid);
        // Each flag must be in {0, 1, 2}; accumulate their sum while doing so.
        let sum = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_zero(flag * get_expr_12(&flag.into()));
            acc + flag
        });
        // The sum of the flags must also be in {0, 1, 2}.
        builder.assert_zero(sum.clone() * get_expr_12(&sum));
        // when sum is 0, is_valid must be 0
        builder.when(get_expr_12(&sum)).assert_zero(is_valid);

        // We will use the InstructionOpcode enum to encode the opcodes
        // the appended digit to each opcode is the shift amount
        //
        // `opcode_flags` is indexed by `InstructionOpcode as usize`; the three loops
        // below must therefore push selectors in the enum's declaration order.
        let inv_2 = AB::F::from_canonical_u32(2).inverse();
        let mut opcode_flags = vec![];
        // Indices 0..4: flag * (flag - 1) / 2 is 1 when flag == 2 and 0 when flag is 0 or 1.
        for flag in flags {
            opcode_flags.push(flag * (flag - AB::F::ONE) * inv_2);
        }
        // Indices 4..8: -flag * (sum - 2) is 1 when flag == 1 and sum == 1, and 0 when
        // flag == 0 or sum == 2.
        for flag in flags {
            opcode_flags.push(flag * (sum.clone() - AB::F::TWO) * AB::F::NEG_ONE);
        }
        // Indices 8..14: products of distinct flag pairs (i < j), equal to 1 when both are 1.
        (0..4).for_each(|i| {
            ((i + 1)..4).for_each(|j| opcode_flags.push(flags[i] * flags[j]));
        });

        // Sum of the selectors of the listed (opcode, shift) pairs.
        let opcode_when = |idxs: &[InstructionOpcode]| -> AB::Expr {
            idxs.iter().fold(AB::Expr::ZERO, |acc, &idx| {
                acc + opcode_flags[idx as usize].clone()
            })
        };

        // Constrain that is_load matches the opcode
        builder.assert_eq(
            is_load,
            opcode_when(&[LoadW0, LoadHu0, LoadHu2, LoadBu0, LoadBu1, LoadBu2, LoadBu3]),
        );
        builder.when(is_load).assert_one(is_valid);

        // There are three parts to write_data:
        // - 1st limb is always read_data
        // - 2nd to (NUM_CELLS/2)th limbs are:
        //   - read_data if loadw/loadhu/storew/storeh
        //   - prev_data if storeb
        //   - zero if loadbu
        // - (NUM_CELLS/2 + 1)th to last limbs are:
        //   - read_data if loadw/storew
        //   - prev_data if storeb/storeh
        //   - zero if loadbu/loadhu
        // Shifting needs to be carefully handled in case by case basis
        // refer to [run_write_data] for the expected behavior in each case
        //
        // NOTE(review): the explicit `i == 1`, `i == 2`, `i == 3` store cases and the
        // `read_data[1..=3]` accesses in the load case look written for NUM_CELLS == 4 —
        // confirm before instantiating with another width.
        for (i, cell) in write_data.iter().enumerate() {
            // handling loads, expected_load_val = 0 if a store operation is happening
            let expected_load_val = if i == 0 {
                opcode_when(&[LoadW0, LoadHu0, LoadBu0]) * read_data[0]
                    + opcode_when(&[LoadBu1]) * read_data[1]
                    + opcode_when(&[LoadHu2, LoadBu2]) * read_data[2]
                    + opcode_when(&[LoadBu3]) * read_data[3]
            } else if i < NUM_CELLS / 2 {
                opcode_when(&[LoadW0, LoadHu0]) * read_data[i]
                    + opcode_when(&[LoadHu2]) * read_data[i + 2]
            } else {
                opcode_when(&[LoadW0]) * read_data[i]
            };

            // handling stores, expected_store_val = 0 if a load operation is happening
            let expected_store_val = if i == 0 {
                opcode_when(&[StoreW0, StoreH0, StoreB0]) * read_data[i]
                    + opcode_when(&[StoreH2, StoreB1, StoreB2, StoreB3]) * prev_data[i]
            } else if i == 1 {
                opcode_when(&[StoreB1]) * read_data[i - 1]
                    + opcode_when(&[StoreW0, StoreH0]) * read_data[i]
                    + opcode_when(&[StoreH2, StoreB0, StoreB2, StoreB3]) * prev_data[i]
            } else if i == 2 {
                opcode_when(&[StoreH2, StoreB2]) * read_data[i - 2]
                    + opcode_when(&[StoreW0]) * read_data[i]
                    + opcode_when(&[StoreH0, StoreB0, StoreB1, StoreB3]) * prev_data[i]
            } else if i == 3 {
                opcode_when(&[StoreB3]) * read_data[i - 3]
                    + opcode_when(&[StoreH2]) * read_data[i - 2]
                    + opcode_when(&[StoreW0]) * read_data[i]
                    + opcode_when(&[StoreH0, StoreB0, StoreB1, StoreB2]) * prev_data[i]
            } else {
                // Cells past index 3: only reachable when NUM_CELLS > 4.
                opcode_when(&[StoreW0]) * read_data[i]
                    + opcode_when(&[StoreB0, StoreB1, StoreB2, StoreB3]) * prev_data[i]
                    + opcode_when(&[StoreH0])
                        * if i < NUM_CELLS / 2 {
                            read_data[i]
                        } else {
                            prev_data[i]
                        }
                    + opcode_when(&[StoreH2])
                        * if i - 2 < NUM_CELLS / 2 {
                            read_data[i - 2]
                        } else {
                            prev_data[i]
                        }
            };
            // At most one selector is active, so the two sums never both contribute.
            let expected_val = expected_load_val + expected_store_val;
            builder.assert_eq(*cell, expected_val);
        }

        // Rebuild the local opcode from the selectors, then convert it to a global
        // opcode expression via `expr_to_global_expr`.
        let expected_opcode = opcode_when(&[LoadW0]) * AB::Expr::from_canonical_u8(LOADW as u8)
            + opcode_when(&[LoadHu0, LoadHu2]) * AB::Expr::from_canonical_u8(LOADHU as u8)
            + opcode_when(&[LoadBu0, LoadBu1, LoadBu2, LoadBu3])
                * AB::Expr::from_canonical_u8(LOADBU as u8)
            + opcode_when(&[StoreW0]) * AB::Expr::from_canonical_u8(STOREW as u8)
            + opcode_when(&[StoreH0, StoreH2]) * AB::Expr::from_canonical_u8(STOREH as u8)
            + opcode_when(&[StoreB0, StoreB1, StoreB2, StoreB3])
                * AB::Expr::from_canonical_u8(STOREB as u8);
        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(self, expected_opcode);

        // Shift amount implied by the selected pair; zero for every store selector.
        let load_shift_amount = opcode_when(&[LoadBu1]) * AB::Expr::ONE
            + opcode_when(&[LoadHu2, LoadBu2]) * AB::Expr::TWO
            + opcode_when(&[LoadBu3]) * AB::Expr::from_canonical_u32(3);

        // Shift amount implied by the selected pair; zero for every load selector.
        let store_shift_amount = opcode_when(&[StoreB1]) * AB::Expr::ONE
            + opcode_when(&[StoreH2, StoreB2]) * AB::Expr::TWO
            + opcode_when(&[StoreB3]) * AB::Expr::from_canonical_u32(3);

        AdapterAirContext {
            to_pc: None,
            reads: (prev_data, read_data.map(|x| x.into())).into(),
            writes: [write_data.map(|x| x.into())].into(),
            instruction: LoadStoreInstruction {
                is_valid: is_valid.into(),
                opcode: expected_opcode,
                is_load: is_load.into(),
                load_shift_amount,
                store_shift_amount,
            }
            .into(),
        }
    }

    /// Global opcode offset of this chip, as configured at construction.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
248
/// Core chip for load/store instructions; it owns the matching [`LoadStoreCoreAir`].
#[derive(Debug)]
pub struct LoadStoreCoreChip<const NUM_CELLS: usize> {
    /// AIR definition, which also stores the opcode offset used during execution.
    pub air: LoadStoreCoreAir<NUM_CELLS>,
}
253
254impl<const NUM_CELLS: usize> LoadStoreCoreChip<NUM_CELLS> {
255    pub fn new(offset: usize) -> Self {
256        Self {
257            air: LoadStoreCoreAir { offset },
258        }
259    }
260}
261
262impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_CELLS: usize> VmCoreChip<F, I>
263    for LoadStoreCoreChip<NUM_CELLS>
264where
265    I::Reads: Into<([[F; NUM_CELLS]; 2], F)>,
266    I::Writes: From<[[F; NUM_CELLS]; 1]>,
267{
268    type Record = LoadStoreCoreRecord<F, NUM_CELLS>;
269    type Air = LoadStoreCoreAir<NUM_CELLS>;
270
271    #[allow(clippy::type_complexity)]
272    fn execute_instruction(
273        &self,
274        instruction: &Instruction<F>,
275        _from_pc: u32,
276        reads: I::Reads,
277    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
278        let local_opcode =
279            Rv32LoadStoreOpcode::from_usize(instruction.opcode.local_opcode_idx(self.air.offset));
280
281        let (reads, shift_amount) = reads.into();
282        let shift = shift_amount.as_canonical_u32();
283        let prev_data = reads[0];
284        let read_data = reads[1];
285        let write_data = run_write_data(local_opcode, read_data, prev_data, shift);
286        let output = AdapterRuntimeContext::without_pc([write_data]);
287
288        Ok((
289            output,
290            LoadStoreCoreRecord {
291                opcode: local_opcode,
292                shift,
293                prev_data,
294                read_data,
295                write_data,
296            },
297        ))
298    }
299
300    fn get_opcode_name(&self, opcode: usize) -> String {
301        format!(
302            "{:?}",
303            Rv32LoadStoreOpcode::from_usize(opcode - self.air.offset)
304        )
305    }
306
307    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
308        let core_cols: &mut LoadStoreCoreCols<F, NUM_CELLS> = row_slice.borrow_mut();
309        let opcode = record.opcode;
310        let flags = &mut core_cols.flags;
311        *flags = [F::ZERO; 4];
312        match (opcode, record.shift) {
313            (LOADW, 0) => flags[0] = F::TWO,
314            (LOADHU, 0) => flags[1] = F::TWO,
315            (LOADHU, 2) => flags[2] = F::TWO,
316            (LOADBU, 0) => flags[3] = F::TWO,
317
318            (LOADBU, 1) => flags[0] = F::ONE,
319            (LOADBU, 2) => flags[1] = F::ONE,
320            (LOADBU, 3) => flags[2] = F::ONE,
321            (STOREW, 0) => flags[3] = F::ONE,
322
323            (STOREH, 0) => (flags[0], flags[1]) = (F::ONE, F::ONE),
324            (STOREH, 2) => (flags[0], flags[2]) = (F::ONE, F::ONE),
325            (STOREB, 0) => (flags[0], flags[3]) = (F::ONE, F::ONE),
326            (STOREB, 1) => (flags[1], flags[2]) = (F::ONE, F::ONE),
327            (STOREB, 2) => (flags[1], flags[3]) = (F::ONE, F::ONE),
328            (STOREB, 3) => (flags[2], flags[3]) = (F::ONE, F::ONE),
329            _ => unreachable!(),
330        };
331        core_cols.prev_data = record.prev_data;
332        core_cols.read_data = record.read_data;
333        core_cols.is_valid = F::ONE;
334        core_cols.is_load = F::from_bool([LOADW, LOADHU, LOADBU].contains(&opcode));
335        core_cols.write_data = record.write_data;
336    }
337
338    fn air(&self) -> &Self::Air {
339        &self.air
340    }
341}
342
343pub(super) fn run_write_data<F: PrimeField32, const NUM_CELLS: usize>(
344    opcode: Rv32LoadStoreOpcode,
345    read_data: [F; NUM_CELLS],
346    prev_data: [F; NUM_CELLS],
347    shift: u32,
348) -> [F; NUM_CELLS] {
349    let shift = shift as usize;
350    let mut write_data = read_data;
351    match (opcode, shift) {
352        (LOADW, 0) => (),
353        (LOADBU, 0) | (LOADBU, 1) | (LOADBU, 2) | (LOADBU, 3) => {
354            for cell in write_data.iter_mut().take(NUM_CELLS).skip(1) {
355                *cell = F::ZERO;
356            }
357            write_data[0] = read_data[shift];
358        }
359        (LOADHU, 0) | (LOADHU, 2) => {
360            for cell in write_data.iter_mut().take(NUM_CELLS).skip(NUM_CELLS / 2) {
361                *cell = F::ZERO;
362            }
363            for (i, cell) in write_data.iter_mut().take(NUM_CELLS / 2).enumerate() {
364                *cell = read_data[i + shift];
365            }
366        }
367        (STOREW, 0) => (),
368        (STOREB, 0) | (STOREB, 1) | (STOREB, 2) | (STOREB, 3) => {
369            write_data = prev_data;
370            write_data[shift] = read_data[0];
371        }
372        (STOREH, 0) | (STOREH, 2) => {
373            write_data = prev_data;
374            write_data[shift..(NUM_CELLS / 2 + shift)]
375                .copy_from_slice(&read_data[..(NUM_CELLS / 2)]);
376        }
377        // Currently the adapter AIR requires `ptr_val` to be aligned to the data size in bytes.
378        // The circuit requires that `shift = ptr_val % 4` so that `ptr_val - shift` is a multiple of 4.
379        // This requirement is non-trivial to remove, because we use it to ensure that `ptr_val - shift + 4 <= 2^pointer_max_bits`.
380        _ => unreachable!(
381            "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}"
382        ),
383    };
384    write_data
385}