openvm_rv32im_circuit/load_sign_extend/
core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::{
7    arch::*,
8    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
9};
10use openvm_circuit_primitives::{
11    utils::select,
12    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
13    AlignedBytesBorrow,
14};
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_instructions::{
17    instruction::Instruction,
18    program::DEFAULT_PC_STEP,
19    riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS},
20    LocalOpcode,
21};
22use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
23use openvm_stark_backend::{
24    interaction::InteractionBuilder,
25    p3_air::BaseAir,
26    p3_field::{Field, FieldAlgebra, PrimeField32},
27    rap::BaseAirWithPublicValues,
28};
29
30use crate::adapters::{LoadStoreInstruction, Rv32LoadStoreAdapterFiller};
31
/// Columns for the LoadSignExtend core chip.
///
/// This chip handles byte/halfword-into-word conversions through sign extension:
/// it uses `read_data` to construct `write_data`.
///
/// `prev_data` columns are not used in constraints defined in the CoreAir, but are
/// used in constraints by the Adapter.
///
/// `shifted_read_data` is `read_data` rotated by `(shift_amount & 2)` limbs. This
/// reduces the number of opcode flags needed: using the shifted data we can generate
/// `write_data` as if the shift amount were 0 for LOADH and 0 or 1 for LOADB.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct LoadSignExtendCoreCols<T, const NUM_CELLS: usize> {
    /// This chip treats LOADB with an even shift (flag0) and LOADB with an odd shift
    /// (flag1) as different instructions.
    pub opcode_loadb_flag0: T,
    pub opcode_loadb_flag1: T,
    pub opcode_loadh_flag: T,

    /// Bit 1 of the shift amount, i.e. `(shift_amount & 2) != 0`.
    pub shift_most_sig_bit: T,
    // The bit that is extended to the remaining bits
    pub data_most_sig_bit: T,

    /// `read_data` rotated left by `shift_amount & 2` limbs.
    pub shifted_read_data: [T; NUM_CELLS],
    /// Previous register contents (constrained by the adapter, not by this core AIR).
    pub prev_data: [T; NUM_CELLS],
}
53
/// Core AIR for sign-extending loads (LOADB/LOADH).
#[derive(Debug, Clone, derive_new::new)]
pub struct LoadSignExtendCoreAir<const NUM_CELLS: usize, const LIMB_BITS: usize> {
    /// Bus used to range-check the most significant limb minus its sign bit.
    pub range_bus: VariableRangeCheckerBus,
}
58
impl<F: Field, const NUM_CELLS: usize, const LIMB_BITS: usize> BaseAir<F>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
{
    /// Trace width of this AIR: the number of columns in `LoadSignExtendCoreCols`.
    fn width(&self) -> usize {
        LoadSignExtendCoreCols::<F, NUM_CELLS>::width()
    }
}
66
// This AIR exposes no public values; the default (empty) impl suffices.
impl<F: Field, const NUM_CELLS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
{
}
71
impl<AB, I, const NUM_CELLS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<([AB::Var; NUM_CELLS], [AB::Expr; NUM_CELLS])>,
    I::Writes: From<[[AB::Expr; NUM_CELLS]; 1]>,
    I::ProcessedInstruction: From<LoadStoreInstruction<AB::Expr>>,
{
    /// Evaluates all core constraints for LOADB/LOADH sign extension and returns the
    /// adapter context (reads, the constructed write, and the processed instruction).
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LoadSignExtendCoreCols<AB::Var, NUM_CELLS> = (*local_core).borrow();
        let LoadSignExtendCoreCols::<AB::Var, NUM_CELLS> {
            shifted_read_data,
            prev_data,
            opcode_loadb_flag0: is_loadb0,
            opcode_loadb_flag1: is_loadb1,
            opcode_loadh_flag: is_loadh,
            data_most_sig_bit,
            shift_most_sig_bit,
        } = *cols;

        let flags = [is_loadb0, is_loadb1, is_loadh];

        // Each flag must be boolean, and their sum (is_valid) must also be boolean,
        // so at most one flag is set per row.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag
        });

        builder.assert_bool(is_valid.clone());
        builder.assert_bool(data_most_sig_bit);
        builder.assert_bool(shift_most_sig_bit);

        // Recover the global opcode from the flags; both LOADB sub-flags map to LOADB.
        let expected_opcode = (is_loadb0 + is_loadb1) * AB::F::from_canonical_u8(LOADB as u8)
            + is_loadh * AB::F::from_canonical_u8(LOADH as u8)
            + AB::Expr::from_canonical_usize(Rv32LoadStoreOpcode::CLASS_OFFSET);

        // All-ones limb when the sign bit is set, zero otherwise.
        let limb_mask = data_most_sig_bit * AB::Expr::from_canonical_u32((1 << LIMB_BITS) - 1);

        // there are three parts to write_data:
        // - 1st limb is always shifted_read_data
        // - 2nd to (NUM_CELLS/2)th limbs are read_data if loadh and sign extended if loadb
        // - (NUM_CELLS/2 + 1)th to last limbs are always sign extended limbs
        let write_data: [AB::Expr; NUM_CELLS] = array::from_fn(|i| {
            if i == 0 {
                (is_loadh + is_loadb0) * shifted_read_data[i].into()
                    + is_loadb1 * shifted_read_data[i + 1].into()
            } else if i < NUM_CELLS / 2 {
                shifted_read_data[i] * is_loadh + (is_loadb0 + is_loadb1) * limb_mask.clone()
            } else {
                limb_mask.clone()
            }
        });

        // Constrain that most_sig_bit is correct.
        // Which limb carries the sign bit depends on the (sub-)opcode.
        let most_sig_limb = shifted_read_data[0] * is_loadb0
            + shifted_read_data[1] * is_loadb1
            + shifted_read_data[NUM_CELLS / 2 - 1] * is_loadh;

        // (most_sig_limb - sign_bit * 2^(LIMB_BITS-1)) must fit in LIMB_BITS-1 bits,
        // which forces data_most_sig_bit to equal the top bit of most_sig_limb.
        self.range_bus
            .range_check(
                most_sig_limb
                    - data_most_sig_bit * AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)),
                LIMB_BITS - 1,
            )
            .eval(builder, is_valid.clone());

        // Unshift the shifted_read_data to get the original read_data
        // (rotate right by 2 limbs when shift_most_sig_bit is set).
        let read_data = array::from_fn(|i| {
            select(
                shift_most_sig_bit,
                shifted_read_data[(i + NUM_CELLS - 2) % NUM_CELLS],
                shifted_read_data[i],
            )
        });
        // Full shift amount = 2 * shift_most_sig_bit + (1 iff LOADB with odd shift).
        let load_shift_amount = shift_most_sig_bit * AB::Expr::TWO + is_loadb1;

        AdapterAirContext {
            to_pc: None,
            reads: (prev_data, read_data).into(),
            writes: [write_data].into(),
            instruction: LoadStoreInstruction {
                is_valid: is_valid.clone(),
                opcode: expected_opcode,
                // This chip only handles loads, so is_load coincides with is_valid.
                is_load: is_valid,
                load_shift_amount,
                store_shift_amount: AB::Expr::ZERO,
            }
            .into(),
        }
    }

    /// Opcodes handled by this AIR start at the load/store class offset.
    fn start_offset(&self) -> usize {
        Rv32LoadStoreOpcode::CLASS_OFFSET
    }
}
172
/// Per-instruction execution record for LOADB/LOADH, written by the executor and
/// replayed by the trace filler to populate `LoadSignExtendCoreCols`.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct LoadSignExtendCoreRecord<const NUM_CELLS: usize> {
    /// True for LOADB, false for LOADH.
    pub is_byte: bool,
    /// Byte offset of the access within the aligned word (`ptr_val % 4`).
    pub shift_amount: u8,
    /// Limbs read from memory (unshifted).
    pub read_data: [u8; NUM_CELLS],
    /// Previous register contents.
    pub prev_data: [u8; NUM_CELLS],
}
181
/// Preflight executor for sign-extending loads; `A` is the memory adapter.
#[derive(Clone, Copy, derive_new::new)]
pub struct LoadSignExtendExecutor<A, const NUM_CELLS: usize, const LIMB_BITS: usize> {
    adapter: A,
}
186
/// Trace filler for sign-extending loads: converts execution records into trace rows.
#[derive(Clone, derive_new::new)]
pub struct LoadSignExtendFiller<
    A = Rv32LoadStoreAdapterFiller,
    const NUM_CELLS: usize = RV32_REGISTER_NUM_LIMBS,
    const LIMB_BITS: usize = RV32_CELL_BITS,
> {
    adapter: A,
    /// Shared range checker; must use the same bus as the AIR's `range_bus`.
    pub range_checker_chip: SharedVariableRangeCheckerChip,
}
196
impl<F, A, RA, const NUM_CELLS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
    for LoadSignExtendExecutor<A, NUM_CELLS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            ReadData = (([u32; NUM_CELLS], [u8; NUM_CELLS]), u8),
            WriteData = [u32; NUM_CELLS],
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (
            A::RecordMut<'buf>,
            &'buf mut LoadSignExtendCoreRecord<NUM_CELLS>,
        ),
    >,
{
    /// Human-readable opcode name, resolved relative to the load/store class offset.
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!(
            "{:?}",
            Rv32LoadStoreOpcode::from_usize(opcode - Rv32LoadStoreOpcode::CLASS_OFFSET)
        )
    }

    /// Executes one LOADB/LOADH instruction: reads through the adapter, fills the
    /// core record, computes the sign-extended write value, writes it back, and
    /// advances the PC by one instruction step.
    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let Instruction { opcode, .. } = instruction;

        let local_opcode = Rv32LoadStoreOpcode::from_usize(
            opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
        );

        // Allocate the adapter and core records for this row from the arena.
        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        // Per A::ReadData, tmp = ((prev_data, read_data), shift_amount).
        let tmp = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record);

        core_record.is_byte = local_opcode == LOADB;
        // NOTE(review): `as u8` truncates each u32 limb — assumes the adapter
        // returns limbs that fit in a byte; confirm against the adapter impl.
        core_record.prev_data = tmp.0 .0.map(|x| x as u8);
        core_record.read_data = tmp.0 .1;
        core_record.shift_amount = tmp.1;

        let write_data = run_write_data_sign_extend(
            local_opcode,
            core_record.read_data,
            core_record.shift_amount as usize,
        );

        self.adapter.write(
            state.memory,
            instruction,
            write_data.map(u32::from),
            &mut adapter_record,
        );

        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}
265
266impl<F, A, const NUM_CELLS: usize, const LIMB_BITS: usize> TraceFiller<F>
267    for LoadSignExtendFiller<A, NUM_CELLS, LIMB_BITS>
268where
269    F: PrimeField32,
270    A: 'static + AdapterTraceFiller<F>,
271{
272    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
273        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
274        // LoadSignExtendCoreCols::width() elements
275        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
276        self.adapter.fill_trace_row(mem_helper, adapter_row);
277        // SAFETY: core_row contains a valid LoadSignExtendCoreRecord written by the executor
278        // during trace generation
279        let record: &LoadSignExtendCoreRecord<NUM_CELLS> =
280            unsafe { get_record_from_slice(&mut core_row, ()) };
281
282        let core_row: &mut LoadSignExtendCoreCols<F, NUM_CELLS> = core_row.borrow_mut();
283
284        let shift = record.shift_amount;
285        let most_sig_limb = if record.is_byte {
286            record.read_data[shift as usize]
287        } else {
288            record.read_data[NUM_CELLS / 2 - 1 + shift as usize]
289        };
290
291        let most_sig_bit = most_sig_limb & (1 << 7);
292        self.range_checker_chip
293            .add_count((most_sig_limb - most_sig_bit) as u32, 7);
294
295        core_row.prev_data = record.prev_data.map(F::from_canonical_u8);
296        core_row.shifted_read_data = record.read_data.map(F::from_canonical_u8);
297        core_row.shifted_read_data.rotate_left((shift & 2) as usize);
298
299        core_row.data_most_sig_bit = F::from_bool(most_sig_bit != 0);
300        core_row.shift_most_sig_bit = F::from_bool(shift & 2 == 2);
301        core_row.opcode_loadh_flag = F::from_bool(!record.is_byte);
302        core_row.opcode_loadb_flag1 = F::from_bool(record.is_byte && ((shift & 1) == 1));
303        core_row.opcode_loadb_flag0 = F::from_bool(record.is_byte && ((shift & 1) == 0));
304    }
305}
306
307// Returns write_data
308#[inline(always)]
309pub(super) fn run_write_data_sign_extend<const NUM_CELLS: usize>(
310    opcode: Rv32LoadStoreOpcode,
311    read_data: [u8; NUM_CELLS],
312    shift: usize,
313) -> [u8; NUM_CELLS] {
314    match (opcode, shift) {
315        (LOADH, 0) | (LOADH, 2) => {
316            let ext = (read_data[NUM_CELLS / 2 - 1 + shift] >> 7) * u8::MAX;
317            array::from_fn(|i| {
318                if i < NUM_CELLS / 2 {
319                    read_data[i + shift]
320                } else {
321                    ext
322                }
323            })
324        }
325        (LOADB, 0) | (LOADB, 1) | (LOADB, 2) | (LOADB, 3) => {
326            let ext = (read_data[shift] >> 7) * u8::MAX;
327            array::from_fn(|i| {
328                if i == 0 {
329                    read_data[i + shift]
330                } else {
331                    ext
332                }
333            })
334        }
335        // Currently the adapter AIR requires `ptr_val` to be aligned to the data size in bytes.
336        // The circuit requires that `shift = ptr_val % 4` so that `ptr_val - shift` is a multiple of 4.
337        // This requirement is non-trivial to remove, because we use it to ensure that `ptr_val - shift + 4 <= 2^pointer_max_bits`.
338        _ => unreachable!(
339            "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}"
340        ),
341    }
342}