// openvm_rv32im_circuit/load_sign_extend/core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::{
7    arch::*,
8    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
9};
10use openvm_circuit_primitives::{
11    utils::select,
12    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
13    AlignedBytesBorrow,
14};
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_instructions::{
17    instruction::Instruction,
18    program::DEFAULT_PC_STEP,
19    riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS},
20    LocalOpcode,
21};
22use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
23use openvm_stark_backend::{
24    interaction::InteractionBuilder,
25    p3_air::BaseAir,
26    p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
27    rap::BaseAirWithPublicValues,
28};
29
30use crate::adapters::{LoadStoreInstruction, Rv32LoadStoreAdapterFiller};
31
/// Columns for the LoadSignExtend core chip, which converts a loaded byte (LOADB) or
/// halfword (LOADH) into a sign-extended word.
///
/// This chip uses `read_data` to construct `write_data`. The `prev_data` columns are not
/// used in the constraints defined by the core AIR, but are used in constraints by the
/// adapter. `shifted_read_data` is `read_data` rotated by `(shift_amount & 2)` limbs; this
/// reduces the number of opcode flags needed, because with the rotated data the write can
/// be generated as if the shift amount were 0 for loadh, and 0 or 1 for loadb.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct LoadSignExtendCoreCols<T, const NUM_CELLS: usize> {
    /// This chip treats loadb with 0 shift and loadb with 1 shift as different
    /// instructions, hence the two loadb flags. On a valid row at most one of the three
    /// flags is set (enforced by the AIR's booleanity checks on each flag and their sum).
    pub opcode_loadb_flag0: T,
    pub opcode_loadb_flag1: T,
    pub opcode_loadh_flag: T,

    // Boolean: bit 1 of the pointer shift amount, i.e. `(shift_amount & 2) != 0`.
    pub shift_most_sig_bit: T,
    // The most significant bit of the loaded value; this bit is extended into the
    // remaining (upper) limbs of write_data.
    pub data_most_sig_bit: T,

    // read_data rotated left by (shift_amount & 2) limbs (see trace filler).
    pub shifted_read_data: [T; NUM_CELLS],
    // Prior data at the write target; constrained by the adapter, not by the core AIR.
    pub prev_data: [T; NUM_CELLS],
}
53
/// Core AIR for sign-extending loads (LOADB/LOADH).
///
/// Holds the bus over which the non-sign bits of the most significant loaded limb are
/// range checked (see `eval`).
#[derive(Debug, Clone, derive_new::new)]
pub struct LoadSignExtendCoreAir<const NUM_CELLS: usize, const LIMB_BITS: usize> {
    pub range_bus: VariableRangeCheckerBus,
}
58
impl<F: Field, const NUM_CELLS: usize, const LIMB_BITS: usize> BaseAir<F>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
{
    /// Trace width of the core AIR: the number of field elements in
    /// `LoadSignExtendCoreCols` (provided by the `AlignedBorrow` derive).
    fn width(&self) -> usize {
        LoadSignExtendCoreCols::<F, NUM_CELLS>::width()
    }
}
66
// This AIR exposes no public values; the trait's default methods apply.
impl<F: Field, const NUM_CELLS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
{
}
71
impl<AB, I, const NUM_CELLS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<([AB::Var; NUM_CELLS], [AB::Expr; NUM_CELLS])>,
    I::Writes: From<[[AB::Expr; NUM_CELLS]; 1]>,
    I::ProcessedInstruction: From<LoadStoreInstruction<AB::Expr>>,
{
    /// Emits the core constraints for LOADB/LOADH: flag booleanity, correctness of the
    /// sign bit, the sign-extended `write_data`, and the reconstruction of the original
    /// (un-rotated) `read_data` that is handed back to the adapter.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LoadSignExtendCoreCols<AB::Var, NUM_CELLS> = (*local_core).borrow();
        let LoadSignExtendCoreCols::<AB::Var, NUM_CELLS> {
            shifted_read_data,
            prev_data,
            opcode_loadb_flag0: is_loadb0,
            opcode_loadb_flag1: is_loadb1,
            opcode_loadh_flag: is_loadh,
            data_most_sig_bit,
            shift_most_sig_bit,
        } = *cols;

        let flags = [is_loadb0, is_loadb1, is_loadh];

        // Each flag is boolean, and their sum `is_valid` is also constrained boolean
        // below, so at most one flag is set on a valid row (all zero on padding rows).
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag
        });

        builder.assert_bool(is_valid.clone());
        builder.assert_bool(data_most_sig_bit);
        builder.assert_bool(shift_most_sig_bit);

        // Recover the global opcode from the one-hot flags: both loadb flags map to
        // LOADB, then add the class offset.
        let expected_opcode = (is_loadb0 + is_loadb1) * AB::F::from_u8(LOADB as u8)
            + is_loadh * AB::F::from_u8(LOADH as u8)
            + AB::Expr::from_usize(Rv32LoadStoreOpcode::CLASS_OFFSET);

        // All-ones limb (2^LIMB_BITS - 1) when the sign bit is set, zero otherwise;
        // used to fill the sign-extended limbs of write_data.
        let limb_mask = data_most_sig_bit * AB::Expr::from_u32((1 << LIMB_BITS) - 1);

        // there are three parts to write_data:
        // - 1st limb is always shifted_read_data (limb 0 for loadh/loadb0, limb 1 for
        //   loadb1, since shifted_read_data has already absorbed the even part of shift)
        // - 2nd to (NUM_CELLS/2)th limbs are read_data if loadh and sign extended if loadb
        // - (NUM_CELLS/2 + 1)th to last limbs are always sign extended limbs
        let write_data: [AB::Expr; NUM_CELLS] = array::from_fn(|i| {
            if i == 0 {
                (is_loadh + is_loadb0) * shifted_read_data[i].into()
                    + is_loadb1 * shifted_read_data[i + 1].into()
            } else if i < NUM_CELLS / 2 {
                shifted_read_data[i] * is_loadh + (is_loadb0 + is_loadb1) * limb_mask.clone()
            } else {
                limb_mask.clone()
            }
        });

        // Constrain that most_sig_bit is correct: select the limb whose top bit is the
        // sign (limb 0/1 for loadb0/loadb1, the top halfword limb for loadh)...
        let most_sig_limb = shifted_read_data[0] * is_loadb0
            + shifted_read_data[1] * is_loadb1
            + shifted_read_data[NUM_CELLS / 2 - 1] * is_loadh;

        // ...then range check that, with data_most_sig_bit * 2^(LIMB_BITS-1) removed,
        // the remainder fits in LIMB_BITS - 1 bits. This forces data_most_sig_bit to
        // equal the top bit of most_sig_limb. Only counted on valid rows.
        self.range_bus
            .range_check(
                most_sig_limb - data_most_sig_bit * AB::Expr::from_u32(1 << (LIMB_BITS - 1)),
                LIMB_BITS - 1,
            )
            .eval(builder, is_valid.clone());

        // Unshift the shifted_read_data to get the original read_data: when
        // shift_most_sig_bit is set, rotate right by 2 limbs (index (i + NUM_CELLS - 2)
        // mod NUM_CELLS), undoing the rotate-left-by-2 applied in trace generation.
        let read_data = array::from_fn(|i| {
            select(
                shift_most_sig_bit,
                shifted_read_data[(i + NUM_CELLS - 2) % NUM_CELLS],
                shifted_read_data[i],
            )
        });
        // Full shift amount reconstructed for the adapter: bit 1 from
        // shift_most_sig_bit, bit 0 from the loadb1 flag.
        let load_shift_amount = shift_most_sig_bit * AB::Expr::TWO + is_loadb1;

        AdapterAirContext {
            to_pc: None,
            reads: (prev_data, read_data).into(),
            writes: [write_data].into(),
            instruction: LoadStoreInstruction {
                is_valid: is_valid.clone(),
                opcode: expected_opcode,
                // Every op handled by this chip is a load.
                is_load: is_valid,
                load_shift_amount,
                store_shift_amount: AB::Expr::ZERO,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        Rv32LoadStoreOpcode::CLASS_OFFSET
    }
}
171
/// Execution record for one sign-extending load, written by the executor during
/// execution and expanded into `LoadSignExtendCoreCols` by the trace filler.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct LoadSignExtendCoreRecord<const NUM_CELLS: usize> {
    // true for LOADB, false for LOADH
    pub is_byte: bool,
    // Raw (un-rotated) shift amount; per the helper below, shift = ptr_val % 4.
    pub shift_amount: u8,
    // Limbs read from memory, in original (un-rotated) order.
    pub read_data: [u8; NUM_CELLS],
    // Prior data at the write target; used by the adapter's constraints only.
    pub prev_data: [u8; NUM_CELLS],
}
180
/// Executor for sign-extending loads; wraps an adapter `A` that performs the actual
/// register/memory accesses.
#[derive(Clone, Copy, derive_new::new)]
pub struct LoadSignExtendExecutor<A, const NUM_CELLS: usize, const LIMB_BITS: usize> {
    adapter: A,
}
185
/// Trace filler for sign-extending loads: expands execution records into trace rows and
/// feeds the range checker the counts matching the AIR's range-check interaction.
#[derive(Clone, derive_new::new)]
pub struct LoadSignExtendFiller<
    A = Rv32LoadStoreAdapterFiller,
    const NUM_CELLS: usize = RV32_REGISTER_NUM_LIMBS,
    const LIMB_BITS: usize = RV32_CELL_BITS,
> {
    adapter: A,
    // Shared so the counts added here accumulate into the same range-checker trace the
    // AIR's range_bus interaction is balanced against.
    pub range_checker_chip: SharedVariableRangeCheckerChip,
}
195
impl<F, A, RA, const NUM_CELLS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
    for LoadSignExtendExecutor<A, NUM_CELLS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            // ((prev_data at the write target, read_data limbs), shift amount)
            ReadData = (([u32; NUM_CELLS], [u8; NUM_CELLS]), u8),
            WriteData = [u32; NUM_CELLS],
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (
            A::RecordMut<'buf>,
            &'buf mut LoadSignExtendCoreRecord<NUM_CELLS>,
        ),
    >,
{
    /// Human-readable opcode name for debugging, derived from the local opcode index.
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!(
            "{:?}",
            Rv32LoadStoreOpcode::from_usize(opcode - Rv32LoadStoreOpcode::CLASS_OFFSET)
        )
    }

    /// Executes one LOADB/LOADH instruction: performs the adapter read, records the
    /// data needed later by the trace filler, computes the sign-extended word, writes
    /// it back through the adapter, and advances the PC.
    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let Instruction { opcode, .. } = instruction;

        let local_opcode = Rv32LoadStoreOpcode::from_usize(
            opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
        );

        // Allocate the (adapter, core) record pair in the arena for this row.
        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        // tmp = ((prev_data, read_data), shift_amount) per the ReadData bound above.
        let tmp = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record);

        core_record.is_byte = local_opcode == LOADB;
        // NOTE(review): truncating u32 -> u8; assumes each prev_data limb fits in a
        // byte (limbs appear to be LIMB_BITS-sized) — confirm against the adapter.
        core_record.prev_data = tmp.0 .0.map(|x| x as u8);
        core_record.read_data = tmp.0 .1;
        core_record.shift_amount = tmp.1;

        // Compute the sign-extended word; panics on shifts the circuit does not
        // support (see run_write_data_sign_extend).
        let write_data = run_write_data_sign_extend(
            local_opcode,
            core_record.read_data,
            core_record.shift_amount as usize,
        );

        self.adapter.write(
            state.memory,
            instruction,
            write_data.map(u32::from),
            &mut adapter_record,
        );

        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}
264
265impl<F, A, const NUM_CELLS: usize, const LIMB_BITS: usize> TraceFiller<F>
266    for LoadSignExtendFiller<A, NUM_CELLS, LIMB_BITS>
267where
268    F: PrimeField32,
269    A: 'static + AdapterTraceFiller<F>,
270{
271    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
272        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
273        // LoadSignExtendCoreCols::width() elements
274        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
275        self.adapter.fill_trace_row(mem_helper, adapter_row);
276        // SAFETY: core_row contains a valid LoadSignExtendCoreRecord written by the executor
277        // during trace generation
278        let record: &LoadSignExtendCoreRecord<NUM_CELLS> =
279            unsafe { get_record_from_slice(&mut core_row, ()) };
280
281        let core_row: &mut LoadSignExtendCoreCols<F, NUM_CELLS> = core_row.borrow_mut();
282
283        let shift = record.shift_amount;
284        let most_sig_limb = if record.is_byte {
285            record.read_data[shift as usize]
286        } else {
287            record.read_data[NUM_CELLS / 2 - 1 + shift as usize]
288        };
289
290        let most_sig_bit = most_sig_limb & (1 << 7);
291        self.range_checker_chip
292            .add_count((most_sig_limb - most_sig_bit) as u32, 7);
293
294        core_row.prev_data = record.prev_data.map(F::from_u8);
295        core_row.shifted_read_data = record.read_data.map(F::from_u8);
296        core_row.shifted_read_data.rotate_left((shift & 2) as usize);
297
298        core_row.data_most_sig_bit = F::from_bool(most_sig_bit != 0);
299        core_row.shift_most_sig_bit = F::from_bool(shift & 2 == 2);
300        core_row.opcode_loadh_flag = F::from_bool(!record.is_byte);
301        core_row.opcode_loadb_flag1 = F::from_bool(record.is_byte && ((shift & 1) == 1));
302        core_row.opcode_loadb_flag0 = F::from_bool(record.is_byte && ((shift & 1) == 0));
303    }
304}
305
306// Returns write_data
307#[inline(always)]
308pub(super) fn run_write_data_sign_extend<const NUM_CELLS: usize>(
309    opcode: Rv32LoadStoreOpcode,
310    read_data: [u8; NUM_CELLS],
311    shift: usize,
312) -> [u8; NUM_CELLS] {
313    match (opcode, shift) {
314        (LOADH, 0) | (LOADH, 2) => {
315            let ext = (read_data[NUM_CELLS / 2 - 1 + shift] >> 7) * u8::MAX;
316            array::from_fn(|i| {
317                if i < NUM_CELLS / 2 {
318                    read_data[i + shift]
319                } else {
320                    ext
321                }
322            })
323        }
324        (LOADB, 0) | (LOADB, 1) | (LOADB, 2) | (LOADB, 3) => {
325            let ext = (read_data[shift] >> 7) * u8::MAX;
326            array::from_fn(|i| {
327                if i == 0 {
328                    read_data[i + shift]
329                } else {
330                    ext
331                }
332            })
333        }
334        // Currently the adapter AIR requires `ptr_val` to be aligned to the data size in bytes.
335        // The circuit requires that `shift = ptr_val % 4` so that `ptr_val - shift` is a multiple of 4.
336        // This requirement is non-trivial to remove, because we use it to ensure that `ptr_val - shift + 4 <= 2^pointer_max_bits`.
337        _ => unreachable!(
338            "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}"
339        ),
340    }
341}