// openvm_rv32im_circuit/load_sign_extend/core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7    AdapterAirContext, AdapterRuntimeContext, Result, VmAdapterInterface, VmCoreAir, VmCoreChip,
8};
9use openvm_circuit_primitives::{
10    utils::select,
11    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
12};
13use openvm_circuit_primitives_derive::AlignedBorrow;
14use openvm_instructions::{instruction::Instruction, LocalOpcode};
15use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
16use openvm_stark_backend::{
17    interaction::InteractionBuilder,
18    p3_air::BaseAir,
19    p3_field::{Field, FieldAlgebra, PrimeField32},
20    rap::BaseAirWithPublicValues,
21};
22use serde::{de::DeserializeOwned, Deserialize, Serialize};
23use serde_big_array::BigArray;
24
25use crate::adapters::LoadStoreInstruction;
26
/// Core columns of the LoadSignExtend chip, which handles LOADB / LOADH: a byte or halfword is
/// loaded and sign-extended into a full word.
///
/// This chip uses `read_data` to construct `write_data`.
///
/// `prev_data` is not used in any constraint defined by the core AIR; it is forwarded to the
/// adapter, whose constraints use it.
///
/// `shifted_read_data` is `read_data` rotated by `shift_amount & 2` cells, i.e.
/// `shifted_read_data[i] = read_data[(i + (shift_amount & 2)) % NUM_CELLS]`. Rotating first
/// reduces the number of opcode flags needed. From the rotated data, `write_data` can be built
/// as if the shift amount were 0 for LOADH, and 0 or 1 for LOADB.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct LoadSignExtendCoreCols<T, const NUM_CELLS: usize> {
    /// LOADB whose shift amount has low bit 0 (shift 0 or 2).
    /// The chip treats LOADB with an even shift and LOADB with an odd shift as different
    /// instructions, so each gets its own flag.
    pub opcode_loadb_flag0: T,
    /// LOADB whose shift amount has low bit 1 (shift 1 or 3).
    pub opcode_loadb_flag1: T,
    /// LOADH (any supported shift).
    pub opcode_loadh_flag: T,

    /// Bit 1 of the shift amount, `(shift_amount & 2) >> 1`. When set, `shifted_read_data`
    /// is `read_data` rotated by 2 cells.
    pub shift_most_sig_bit: T,
    /// Sign bit of the loaded value (the top bit of its most significant limb).
    /// This bit is extended into all the remaining bits.
    pub data_most_sig_bit: T,

    /// `read_data` rotated by `shift_amount & 2` cells (see the struct-level docs).
    pub shifted_read_data: [T; NUM_CELLS],
    /// Passed through to the adapter as part of the reads; unconstrained by the core AIR.
    pub prev_data: [T; NUM_CELLS],
}
47
/// Per-instruction record built by `execute_instruction` and consumed by `generate_trace_row`
/// to fill one row of [`LoadSignExtendCoreCols`].
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Serialize + DeserializeOwned")]
pub struct LoadSignExtendCoreRecord<F, const NUM_CELLS: usize> {
    /// Read data rotated by `shift_amount & 2` cells.
    #[serde(with = "BigArray")]
    pub shifted_read_data: [F; NUM_CELLS],
    /// First element of the adapter reads, copied unchanged into the trace row.
    #[serde(with = "BigArray")]
    pub prev_data: [F; NUM_CELLS],
    /// Local (class-relative) opcode; LOADB or LOADH for this chip.
    pub opcode: Rv32LoadStoreOpcode,
    /// Byte offset of the access within its word (`ptr_val % 4`), 0..=3.
    pub shift_amount: u32,
    /// Sign bit of the loaded byte / halfword.
    pub most_sig_bit: bool,
}
60
/// AIR for the LoadSignExtend core. It constrains the opcode flags, the sign bit, and the
/// sign-extended `write_data` derived from the (rotated) read data.
#[derive(Debug, Clone)]
pub struct LoadSignExtendCoreAir<const NUM_CELLS: usize, const LIMB_BITS: usize> {
    /// Variable range checker bus, used to constrain `data_most_sig_bit`.
    pub range_bus: VariableRangeCheckerBus,
}
65
66impl<F: Field, const NUM_CELLS: usize, const LIMB_BITS: usize> BaseAir<F>
67    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
68{
69    fn width(&self) -> usize {
70        LoadSignExtendCoreCols::<F, NUM_CELLS>::width()
71    }
72}
73
74impl<F: Field, const NUM_CELLS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
75    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
76{
77}
78
impl<AB, I, const NUM_CELLS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<([AB::Var; NUM_CELLS], [AB::Expr; NUM_CELLS])>,
    I::Writes: From<[[AB::Expr; NUM_CELLS]; 1]>,
    I::ProcessedInstruction: From<LoadStoreInstruction<AB::Expr>>,
{
    /// Constrains one core row and returns the context the adapter constrains further.
    ///
    /// Constraints added here:
    /// - each opcode flag is boolean, and their sum `is_valid` is boolean (at most one flag set);
    /// - `data_most_sig_bit` and `shift_most_sig_bit` are boolean;
    /// - `data_most_sig_bit` matches the top bit of the selected most significant limb, via a
    ///   `LIMB_BITS - 1`-bit range check gated by `is_valid`.
    ///
    /// Returned to the adapter: `(prev_data, read_data)` as reads (with `read_data` un-rotated),
    /// the sign-extended `write_data`, the expected opcode, and the load shift amount.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LoadSignExtendCoreCols<AB::Var, NUM_CELLS> = (*local_core).borrow();
        let LoadSignExtendCoreCols::<AB::Var, NUM_CELLS> {
            shifted_read_data,
            prev_data,
            opcode_loadb_flag0: is_loadb0,
            opcode_loadb_flag1: is_loadb1,
            opcode_loadh_flag: is_loadh,
            data_most_sig_bit,
            shift_most_sig_bit,
        } = *cols;

        let flags = [is_loadb0, is_loadb1, is_loadh];

        // Each flag is boolean; their sum is the row's validity indicator.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag
        });

        // A boolean sum of boolean flags means at most one flag is set.
        builder.assert_bool(is_valid.clone());
        builder.assert_bool(data_most_sig_bit);
        builder.assert_bool(shift_most_sig_bit);

        // Both LOADB flags map to LOADB; the local opcode is offset by the class offset.
        let expected_opcode = (is_loadb0 + is_loadb1) * AB::F::from_canonical_u8(LOADB as u8)
            + is_loadh * AB::F::from_canonical_u8(LOADH as u8)
            + AB::Expr::from_canonical_usize(Rv32LoadStoreOpcode::CLASS_OFFSET);

        // All-ones limb (2^LIMB_BITS - 1) when the sign bit is set, zero otherwise.
        let limb_mask = data_most_sig_bit * AB::Expr::from_canonical_u32((1 << LIMB_BITS) - 1);

        // there are three parts to write_data:
        // - 1st limb is shifted_read_data[1] for an odd-shift LOADB, otherwise shifted_read_data[0]
        // - 2nd to (NUM_CELLS/2)th limbs are read_data if loadh and sign extended if loadb
        // - (NUM_CELLS/2 + 1)th to last limbs are always sign extended limbs
        let write_data: [AB::Expr; NUM_CELLS] = array::from_fn(|i| {
            if i == 0 {
                (is_loadh + is_loadb0) * shifted_read_data[i].into()
                    + is_loadb1 * shifted_read_data[i + 1].into()
            } else if i < NUM_CELLS / 2 {
                shifted_read_data[i] * is_loadh + (is_loadb0 + is_loadb1) * limb_mask.clone()
            } else {
                limb_mask.clone()
            }
        });

        // Constrain that most_sig_bit is correct.
        // The limb holding the sign bit: the loaded byte for LOADB, or the upper byte of the
        // halfword for LOADH (all positions are relative to the rotated data).
        let most_sig_limb = shifted_read_data[0] * is_loadb0
            + shifted_read_data[1] * is_loadb1
            + shifted_read_data[NUM_CELLS / 2 - 1] * is_loadh;

        // Subtracting the claimed sign bit must leave a value in [0, 2^(LIMB_BITS-1)).
        // That pins `data_most_sig_bit` to the limb's top bit, assuming the limb itself is
        // < 2^LIMB_BITS (NOTE(review): presumably guaranteed by memory/adapter — confirm).
        self.range_bus
            .range_check(
                most_sig_limb
                    - data_most_sig_bit * AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)),
                LIMB_BITS - 1,
            )
            .eval(builder, is_valid.clone());

        // Unshift the shifted_read_data to get the original read_data.
        // Because shifted_read_data[i] = read_data[(i + 2 * shift_most_sig_bit) % NUM_CELLS],
        // read_data[i] = shifted_read_data[(i - 2) mod NUM_CELLS] when the bit is set.
        // (`select(c, a, b)` presumably yields `a` when `c == 1` — consistent with the above.)
        let read_data = array::from_fn(|i| {
            select(
                shift_most_sig_bit,
                shifted_read_data[(i + NUM_CELLS - 2) % NUM_CELLS],
                shifted_read_data[i],
            )
        });
        // Full shift amount: bit 1 from shift_most_sig_bit, bit 0 from the odd-shift LOADB flag.
        let load_shift_amount = shift_most_sig_bit * AB::Expr::TWO + is_loadb1;

        AdapterAirContext {
            to_pc: None,
            reads: (prev_data, read_data).into(),
            writes: [write_data].into(),
            instruction: LoadStoreInstruction {
                is_valid: is_valid.clone(),
                opcode: expected_opcode,
                // Every valid row of this chip is a load.
                is_load: is_valid,
                load_shift_amount,
                store_shift_amount: AB::Expr::ZERO,
            }
            .into(),
        }
    }

    /// Global opcode offset of the load/store opcode class.
    fn start_offset(&self) -> usize {
        Rv32LoadStoreOpcode::CLASS_OFFSET
    }
}
179
/// Core chip for LOADB / LOADH with sign extension: executes the instruction and generates
/// trace rows for [`LoadSignExtendCoreAir`].
pub struct LoadSignExtendCoreChip<const NUM_CELLS: usize, const LIMB_BITS: usize> {
    pub air: LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>,
    /// Shared range checker. `execute_instruction` adds one `LIMB_BITS - 1`-bit count per
    /// executed instruction to match the AIR's sign-bit range check.
    pub range_checker_chip: SharedVariableRangeCheckerChip,
}
184
185impl<const NUM_CELLS: usize, const LIMB_BITS: usize> LoadSignExtendCoreChip<NUM_CELLS, LIMB_BITS> {
186    pub fn new(range_checker_chip: SharedVariableRangeCheckerChip) -> Self {
187        Self {
188            air: LoadSignExtendCoreAir::<NUM_CELLS, LIMB_BITS> {
189                range_bus: range_checker_chip.bus(),
190            },
191            range_checker_chip,
192        }
193    }
194}
195
196impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_CELLS: usize, const LIMB_BITS: usize>
197    VmCoreChip<F, I> for LoadSignExtendCoreChip<NUM_CELLS, LIMB_BITS>
198where
199    I::Reads: Into<([[F; NUM_CELLS]; 2], F)>,
200    I::Writes: From<[[F; NUM_CELLS]; 1]>,
201{
202    type Record = LoadSignExtendCoreRecord<F, NUM_CELLS>;
203    type Air = LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>;
204
205    #[allow(clippy::type_complexity)]
206    fn execute_instruction(
207        &self,
208        instruction: &Instruction<F>,
209        _from_pc: u32,
210        reads: I::Reads,
211    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
212        let local_opcode = Rv32LoadStoreOpcode::from_usize(
213            instruction
214                .opcode
215                .local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
216        );
217
218        let (data, shift_amount) = reads.into();
219        let shift_amount = shift_amount.as_canonical_u32();
220        let write_data: [F; NUM_CELLS] = run_write_data_sign_extend::<_, NUM_CELLS, LIMB_BITS>(
221            local_opcode,
222            data[1],
223            data[0],
224            shift_amount,
225        );
226        let output = AdapterRuntimeContext::without_pc([write_data]);
227
228        let most_sig_limb = match local_opcode {
229            LOADB => write_data[0],
230            LOADH => write_data[NUM_CELLS / 2 - 1],
231            _ => unreachable!(),
232        }
233        .as_canonical_u32();
234
235        let most_sig_bit = most_sig_limb & (1 << (LIMB_BITS - 1));
236        self.range_checker_chip
237            .add_count(most_sig_limb - most_sig_bit, LIMB_BITS - 1);
238
239        let read_shift = shift_amount & 2;
240
241        Ok((
242            output,
243            LoadSignExtendCoreRecord {
244                opcode: local_opcode,
245                most_sig_bit: most_sig_bit != 0,
246                prev_data: data[0],
247                shifted_read_data: array::from_fn(|i| {
248                    data[1][(i + read_shift as usize) % NUM_CELLS]
249                }),
250                shift_amount,
251            },
252        ))
253    }
254
255    fn get_opcode_name(&self, opcode: usize) -> String {
256        format!(
257            "{:?}",
258            Rv32LoadStoreOpcode::from_usize(opcode - Rv32LoadStoreOpcode::CLASS_OFFSET)
259        )
260    }
261
262    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
263        let core_cols: &mut LoadSignExtendCoreCols<F, NUM_CELLS> = row_slice.borrow_mut();
264        let opcode = record.opcode;
265        let shift = record.shift_amount;
266        core_cols.opcode_loadb_flag0 = F::from_bool(opcode == LOADB && (shift & 1) == 0);
267        core_cols.opcode_loadb_flag1 = F::from_bool(opcode == LOADB && (shift & 1) == 1);
268        core_cols.opcode_loadh_flag = F::from_bool(opcode == LOADH);
269        core_cols.shift_most_sig_bit = F::from_canonical_u32((shift & 2) >> 1);
270        core_cols.data_most_sig_bit = F::from_bool(record.most_sig_bit);
271        core_cols.prev_data = record.prev_data;
272        core_cols.shifted_read_data = record.shifted_read_data;
273    }
274
275    fn air(&self) -> &Self::Air {
276        &self.air
277    }
278}
279
280pub(super) fn run_write_data_sign_extend<
281    F: PrimeField32,
282    const NUM_CELLS: usize,
283    const LIMB_BITS: usize,
284>(
285    opcode: Rv32LoadStoreOpcode,
286    read_data: [F; NUM_CELLS],
287    _prev_data: [F; NUM_CELLS],
288    shift: u32,
289) -> [F; NUM_CELLS] {
290    let shift = shift as usize;
291    let mut write_data = read_data;
292    match (opcode, shift) {
293        (LOADH, 0) | (LOADH, 2) => {
294            let ext = read_data[NUM_CELLS / 2 - 1 + shift].as_canonical_u32();
295            let ext = (ext >> (LIMB_BITS - 1)) * ((1 << LIMB_BITS) - 1);
296            for cell in write_data.iter_mut().take(NUM_CELLS).skip(NUM_CELLS / 2) {
297                *cell = F::from_canonical_u32(ext);
298            }
299            write_data[0..NUM_CELLS / 2]
300                .copy_from_slice(&read_data[shift..(NUM_CELLS / 2 + shift)]);
301        }
302        (LOADB, 0) | (LOADB, 1) | (LOADB, 2) | (LOADB, 3) => {
303            let ext = read_data[shift].as_canonical_u32();
304            let ext = (ext >> (LIMB_BITS - 1)) * ((1 << LIMB_BITS) - 1);
305            for cell in write_data.iter_mut().take(NUM_CELLS).skip(1) {
306                *cell = F::from_canonical_u32(ext);
307            }
308            write_data[0] = read_data[shift];
309        }
310        // Currently the adapter AIR requires `ptr_val` to be aligned to the data size in bytes.
311        // The circuit requires that `shift = ptr_val % 4` so that `ptr_val - shift` is a multiple of 4.
312        // This requirement is non-trivial to remove, because we use it to ensure that `ptr_val - shift + 4 <= 2^pointer_max_bits`.
313        _ => unreachable!(
314            "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}"
315        ),
316    };
317    write_data
318}