use std::{
    array,
    borrow::{Borrow, BorrowMut},
};

use openvm_circuit::arch::{
    AdapterAirContext, AdapterRuntimeContext, Result, VmAdapterInterface, VmCoreAir, VmCoreChip,
};
use openvm_circuit_primitives::{
    utils::select,
    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{instruction::Instruction, LocalOpcode};
use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::BaseAir,
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_big_array::BigArray;

use crate::adapters::LoadStoreInstruction;

#[repr(C)]
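/// Core columns for sign-extending loads (LOADB/LOADH).
///
/// `opcode_loadb_flag0` and `opcode_loadb_flag1` select LOADB with byte offset 0 or 1 within
/// the addressed halfword, and `opcode_loadh_flag` selects LOADH. `shift_most_sig_bit` is bit 1
/// of the byte shift, `data_most_sig_bit` is the sign bit of the loaded value, and
/// `shifted_read_data` is the read data rotated left by the aligned part of the shift
/// (0 or 2 limbs).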
#[derive(Debug, Clone, AlignedBorrow)]
pub struct LoadSignExtendCoreCols<T, const NUM_CELLS: usize> {
    pub opcode_loadb_flag0: T,
    pub opcode_loadb_flag1: T,
    pub opcode_loadh_flag: T,

    pub shift_most_sig_bit: T,
    pub data_most_sig_bit: T,

    pub shifted_read_data: [T; NUM_CELLS],
    pub prev_data: [T; NUM_CELLS],
}

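/// Runtime record for a sign-extending load, carrying everything needed to populate a
/// [`LoadSignExtendCoreCols`] trace row.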
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Serialize + DeserializeOwned")]
pub struct LoadSignExtendCoreRecord<F, const NUM_CELLS: usize> {
    #[serde(with = "BigArray")]
    pub shifted_read_data: [F; NUM_CELLS],
    #[serde(with = "BigArray")]
    pub prev_data: [F; NUM_CELLS],
    pub opcode: Rv32LoadStoreOpcode,
    pub shift_amount: u32,
    pub most_sig_bit: bool,
}

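/// Core AIR for LOADB/LOADH. Constrains the sign extension of the loaded byte or halfword and
/// uses the variable range checker bus to prove the claimed sign bit of the loaded data.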
#[derive(Debug, Clone)]
pub struct LoadSignExtendCoreAir<const NUM_CELLS: usize, const LIMB_BITS: usize> {
    pub range_bus: VariableRangeCheckerBus,
}

impl<F: Field, const NUM_CELLS: usize, const LIMB_BITS: usize> BaseAir<F>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
{
    fn width(&self) -> usize {
        LoadSignExtendCoreCols::<F, NUM_CELLS>::width()
    }
}

impl<F: Field, const NUM_CELLS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
{
}

impl<AB, I, const NUM_CELLS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<([AB::Var; NUM_CELLS], [AB::Expr; NUM_CELLS])>,
    I::Writes: From<[[AB::Expr; NUM_CELLS]; 1]>,
    I::ProcessedInstruction: From<LoadStoreInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LoadSignExtendCoreCols<AB::Var, NUM_CELLS> = (*local_core).borrow();
        let LoadSignExtendCoreCols::<AB::Var, NUM_CELLS> {
            shifted_read_data,
            prev_data,
            opcode_loadb_flag0: is_loadb0,
            opcode_loadb_flag1: is_loadb1,
            opcode_loadh_flag: is_loadh,
            data_most_sig_bit,
            shift_most_sig_bit,
        } = *cols;

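        // Each opcode flag must be boolean and at most one may be set; their sum doubles as
        // the row's validity indicator.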
        let flags = [is_loadb0, is_loadb1, is_loadh];

        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag
        });

        builder.assert_bool(is_valid.clone());
        builder.assert_bool(data_most_sig_bit);
        builder.assert_bool(shift_most_sig_bit);

        let expected_opcode = (is_loadb0 + is_loadb1) * AB::F::from_canonical_u8(LOADB as u8)
            + is_loadh * AB::F::from_canonical_u8(LOADH as u8)
            + AB::Expr::from_canonical_usize(Rv32LoadStoreOpcode::CLASS_OFFSET);

        let limb_mask = data_most_sig_bit * AB::Expr::from_canonical_u32((1 << LIMB_BITS) - 1);

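        // Assemble the sign-extended result from the shifted read data: limb 0 takes the
        // addressed byte (index 0 for LOADH and even-offset LOADB, index 1 for odd-offset
        // LOADB); limbs 1..NUM_CELLS / 2 keep the rest of the halfword for LOADH; everything
        // else is filled with `limb_mask`, i.e. 0 or 2^LIMB_BITS - 1 depending on the sign bit.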
        let write_data: [AB::Expr; NUM_CELLS] = array::from_fn(|i| {
            if i == 0 {
                (is_loadh + is_loadb0) * shifted_read_data[i].into()
                    + is_loadb1 * shifted_read_data[i + 1].into()
            } else if i < NUM_CELLS / 2 {
                shifted_read_data[i] * is_loadh + (is_loadb0 + is_loadb1) * limb_mask.clone()
            } else {
                limb_mask.clone()
            }
        });

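        // `most_sig_limb` is the limb carrying the sign of the loaded value. Subtracting
        // data_most_sig_bit * 2^(LIMB_BITS - 1) and range checking the result to
        // LIMB_BITS - 1 bits proves that `data_most_sig_bit` is indeed its top bit.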
        let most_sig_limb = shifted_read_data[0] * is_loadb0
            + shifted_read_data[1] * is_loadb1
            + shifted_read_data[NUM_CELLS / 2 - 1] * is_loadh;

        self.range_bus
            .range_check(
                most_sig_limb
                    - data_most_sig_bit * AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)),
                LIMB_BITS - 1,
            )
            .eval(builder, is_valid.clone());

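        // Recover the unshifted read data expected by the adapter: when `shift_most_sig_bit`
        // is set, `shifted_read_data` is the read data rotated left by two limbs, so rotate
        // right by two to undo it.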
        let read_data = array::from_fn(|i| {
            select(
                shift_most_sig_bit,
                shifted_read_data[(i + NUM_CELLS - 2) % NUM_CELLS],
                shifted_read_data[i],
            )
        });
        let load_shift_amount = shift_most_sig_bit * AB::Expr::TWO + is_loadb1;

        AdapterAirContext {
            to_pc: None,
            reads: (prev_data, read_data).into(),
            writes: [write_data].into(),
            instruction: LoadStoreInstruction {
                is_valid: is_valid.clone(),
                opcode: expected_opcode,
                is_load: is_valid,
                load_shift_amount,
                store_shift_amount: AB::Expr::ZERO,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        Rv32LoadStoreOpcode::CLASS_OFFSET
    }
}

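/// Core chip for sign-extending loads, pairing [`LoadSignExtendCoreAir`] with the shared
/// variable range checker that backs its sign-bit range check.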
pub struct LoadSignExtendCoreChip<const NUM_CELLS: usize, const LIMB_BITS: usize> {
    pub air: LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>,
    pub range_checker_chip: SharedVariableRangeCheckerChip,
}

impl<const NUM_CELLS: usize, const LIMB_BITS: usize> LoadSignExtendCoreChip<NUM_CELLS, LIMB_BITS> {
    pub fn new(range_checker_chip: SharedVariableRangeCheckerChip) -> Self {
        Self {
            air: LoadSignExtendCoreAir::<NUM_CELLS, LIMB_BITS> {
                range_bus: range_checker_chip.bus(),
            },
            range_checker_chip,
        }
    }
}

impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_CELLS: usize, const LIMB_BITS: usize>
    VmCoreChip<F, I> for LoadSignExtendCoreChip<NUM_CELLS, LIMB_BITS>
where
    I::Reads: Into<([[F; NUM_CELLS]; 2], F)>,
    I::Writes: From<[[F; NUM_CELLS]; 1]>,
{
    type Record = LoadSignExtendCoreRecord<F, NUM_CELLS>;
    type Air = LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>;

    #[allow(clippy::type_complexity)]
    fn execute_instruction(
        &self,
        instruction: &Instruction<F>,
        _from_pc: u32,
        reads: I::Reads,
    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
        let local_opcode = Rv32LoadStoreOpcode::from_usize(
            instruction
                .opcode
                .local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
        );

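        // The adapter provides ([prev_data, read_data], shift_amount): the previous contents
        // of the write target, the memory word that was read, and the byte offset of the
        // access within that word.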
        let (data, shift_amount) = reads.into();
        let shift_amount = shift_amount.as_canonical_u32();
        let write_data: [F; NUM_CELLS] = run_write_data_sign_extend::<_, NUM_CELLS, LIMB_BITS>(
            local_opcode,
            data[1],
            data[0],
            shift_amount,
        );
        let output = AdapterRuntimeContext::without_pc([write_data]);

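        // Split the sign-carrying limb of the result into its top bit and the low
        // LIMB_BITS - 1 bits, and log the low part with the shared range checker to match
        // the AIR's range-check interaction.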
        let most_sig_limb = match local_opcode {
            LOADB => write_data[0],
            LOADH => write_data[NUM_CELLS / 2 - 1],
            _ => unreachable!(),
        }
        .as_canonical_u32();

        let most_sig_bit = most_sig_limb & (1 << (LIMB_BITS - 1));
        self.range_checker_chip
            .add_count(most_sig_limb - most_sig_bit, LIMB_BITS - 1);

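        // Only the aligned part of the shift (0 or 2) is used to rotate the read data;
        // the low bit of the shift is instead encoded by the two LOADB flags.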
        let read_shift = shift_amount & 2;

        Ok((
            output,
            LoadSignExtendCoreRecord {
                opcode: local_opcode,
                most_sig_bit: most_sig_bit != 0,
                prev_data: data[0],
                shifted_read_data: array::from_fn(|i| {
                    data[1][(i + read_shift as usize) % NUM_CELLS]
                }),
                shift_amount,
            },
        ))
    }

    fn get_opcode_name(&self, opcode: usize) -> String {
        format!(
            "{:?}",
            Rv32LoadStoreOpcode::from_usize(opcode - Rv32LoadStoreOpcode::CLASS_OFFSET)
        )
    }

    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
        let core_cols: &mut LoadSignExtendCoreCols<F, NUM_CELLS> = row_slice.borrow_mut();
        let opcode = record.opcode;
        let shift = record.shift_amount;
        core_cols.opcode_loadb_flag0 = F::from_bool(opcode == LOADB && (shift & 1) == 0);
        core_cols.opcode_loadb_flag1 = F::from_bool(opcode == LOADB && (shift & 1) == 1);
        core_cols.opcode_loadh_flag = F::from_bool(opcode == LOADH);
        core_cols.shift_most_sig_bit = F::from_canonical_u32((shift & 2) >> 1);
        core_cols.data_most_sig_bit = F::from_bool(record.most_sig_bit);
        core_cols.prev_data = record.prev_data;
        core_cols.shifted_read_data = record.shifted_read_data;
    }

    fn air(&self) -> &Self::Air {
        &self.air
    }
}

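/// Computes the write-back data for LOADB/LOADH: the addressed byte or halfword is copied into
/// the low limbs and the remaining limbs are filled with 0 or 2^LIMB_BITS - 1 according to its
/// sign bit. Panics on (opcode, shift) combinations that would be an unaligned access.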
pub(super) fn run_write_data_sign_extend<
    F: PrimeField32,
    const NUM_CELLS: usize,
    const LIMB_BITS: usize,
>(
    opcode: Rv32LoadStoreOpcode,
    read_data: [F; NUM_CELLS],
    _prev_data: [F; NUM_CELLS],
    shift: u32,
) -> [F; NUM_CELLS] {
    let shift = shift as usize;
    let mut write_data = read_data;
    match (opcode, shift) {
        (LOADH, 0) | (LOADH, 2) => {
            let ext = read_data[NUM_CELLS / 2 - 1 + shift].as_canonical_u32();
            let ext = (ext >> (LIMB_BITS - 1)) * ((1 << LIMB_BITS) - 1);
            for cell in write_data.iter_mut().take(NUM_CELLS).skip(NUM_CELLS / 2) {
                *cell = F::from_canonical_u32(ext);
            }
            write_data[0..NUM_CELLS / 2]
                .copy_from_slice(&read_data[shift..(NUM_CELLS / 2 + shift)]);
        }
        (LOADB, 0) | (LOADB, 1) | (LOADB, 2) | (LOADB, 3) => {
            let ext = read_data[shift].as_canonical_u32();
            let ext = (ext >> (LIMB_BITS - 1)) * ((1 << LIMB_BITS) - 1);
            for cell in write_data.iter_mut().take(NUM_CELLS).skip(1) {
                *cell = F::from_canonical_u32(ext);
            }
            write_data[0] = read_data[shift];
        }
        _ => unreachable!(
            "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}"
        ),
    };
    write_data
}