use std::{
    array,
    borrow::{Borrow, BorrowMut},
};

use openvm_circuit::{
    arch::*,
    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
};
use openvm_circuit_primitives::{
    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
    utils::not,
    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
    AlignedBytesBorrow,
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
use openvm_rv32im_transpiler::ShiftOpcode;
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};
use strum::IntoEnumIterator;

#[repr(C)]
#[derive(AlignedBorrow, Clone, Copy, Debug)]
pub struct ShiftCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub a: [T; NUM_LIMBS],
    pub b: [T; NUM_LIMBS],
    pub c: [T; NUM_LIMBS],

    pub opcode_sll_flag: T,
    pub opcode_srl_flag: T,
    pub opcode_sra_flag: T,

    // bit_multiplier = 2^bit_shift, gated by the flag of the matching shift direction
    pub bit_multiplier_left: T,
    pub bit_multiplier_right: T,

    // Sign bit of b, used for SRA
    pub b_sign: T,

    // Boolean columns that are 1 exactly at the index of the bit/limb shift amount
    pub bit_shift_marker: [T; LIMB_BITS],
    pub limb_shift_marker: [T; NUM_LIMBS],

    // Part of each b[i] that gets bit shifted into the neighboring limb
    pub bit_shift_carry: [T; NUM_LIMBS],
}
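
// Worked example of the marker columns, assuming the RV32 configuration
// NUM_LIMBS = 4, LIMB_BITS = 8: a shift amount of c[0] % 32 = 11 decomposes as
// 11 = 1 * 8 + 3, so limb_shift_marker = [0, 1, 0, 0] (limb_shift = 1),
// bit_shift_marker = [0, 0, 0, 1, 0, 0, 0, 0] (bit_shift = 3), and the active
// bit multiplier equals 2^3 = 8.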

/// AIR for the shift core. Interacts with a bitwise-operation lookup bus for
/// byte range checks and sign-bit extraction, and with a variable range checker
/// for the shift-amount decomposition and bit-shift carries.
#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct ShiftCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    pub range_bus: VariableRangeCheckerBus,
    pub offset: usize,
}

impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for ShiftCoreAir<NUM_LIMBS, LIMB_BITS>
{
    fn width(&self) -> usize {
        ShiftCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for ShiftCoreAir<NUM_LIMBS, LIMB_BITS>
{
}

impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for ShiftCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &ShiftCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [
            cols.opcode_sll_flag,
            cols.opcode_srl_flag,
            cols.opcode_sra_flag,
        ];

        // Each opcode flag is boolean and at most one is set; their sum is is_valid.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());

        let a = &cols.a;
        let b = &cols.b;
        let c = &cols.c;
        let right_shift = cols.opcode_srl_flag + cols.opcode_sra_flag;

        // Constrain that bit_shift and the bit multipliers are consistent: exactly
        // one bit_shift_marker[i] is set (on a valid row), and under that marker
        // the multiplier of the active shift direction equals 2^i while the
        // inactive direction's multiplier is 0.
        let mut bit_marker_sum = AB::Expr::ZERO;
        let mut bit_shift = AB::Expr::ZERO;

        for i in 0..LIMB_BITS {
            builder.assert_bool(cols.bit_shift_marker[i]);
            bit_marker_sum += cols.bit_shift_marker[i].into();
            bit_shift += AB::Expr::from_canonical_usize(i) * cols.bit_shift_marker[i];

            let mut when_bit_shift = builder.when(cols.bit_shift_marker[i]);
            when_bit_shift.assert_eq(
                cols.bit_multiplier_left,
                AB::Expr::from_canonical_usize(1 << i) * cols.opcode_sll_flag,
            );
            when_bit_shift.assert_eq(
                cols.bit_multiplier_right,
                AB::Expr::from_canonical_usize(1 << i) * right_shift.clone(),
            );
        }
        builder.when(is_valid.clone()).assert_one(bit_marker_sum);
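        // E.g. with bit_shift = 3 on an SLL row: bit_multiplier_left = 2^3 = 8 and
        // bit_multiplier_right = 0, since right_shift = 0 when only the SLL flag
        // is set.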

        // Constrain that limb_shift is correct: exactly one limb_shift_marker[i]
        // is set (on a valid row), and under that marker every output limb a[j]
        // satisfies the shift identity of the active opcode.
        let mut limb_marker_sum = AB::Expr::ZERO;
        let mut limb_shift = AB::Expr::ZERO;
        for i in 0..NUM_LIMBS {
            builder.assert_bool(cols.limb_shift_marker[i]);
            limb_marker_sum += cols.limb_shift_marker[i].into();
            limb_shift += AB::Expr::from_canonical_usize(i) * cols.limb_shift_marker[i];

            let mut when_limb_shift = builder.when(cols.limb_shift_marker[i]);

            for j in 0..NUM_LIMBS {
                // SLL: the low i limbs are zero; limb j receives b[j - i] scaled by
                // the bit multiplier, plus the carry in from the limb below, minus
                // the overflow carried out to the limb above.
                if j < i {
                    when_limb_shift.assert_zero(a[j] * cols.opcode_sll_flag);
                } else {
                    let expected_a_left = if j - i == 0 {
                        AB::Expr::ZERO
                    } else {
                        cols.bit_shift_carry[j - i - 1].into() * cols.opcode_sll_flag
                    } + b[j - i] * cols.bit_multiplier_left
                        - AB::Expr::from_canonical_usize(1 << LIMB_BITS)
                            * cols.bit_shift_carry[j - i]
                            * cols.opcode_sll_flag;
                    when_limb_shift.assert_eq(a[j] * cols.opcode_sll_flag, expected_a_left);
                }

                // SRL/SRA: the high i limbs are all zeros (SRL, or SRA with
                // non-negative b) or all ones (SRA with negative b); limb j
                // receives the high bits of b[j + i] plus the low bits of
                // b[j + i + 1], or the sign fill at the top boundary.
                if j + i > NUM_LIMBS - 1 {
                    when_limb_shift.assert_eq(
                        a[j] * right_shift.clone(),
                        cols.b_sign * AB::F::from_canonical_usize((1 << LIMB_BITS) - 1),
                    );
                } else {
                    let expected_a_right = if j + i == NUM_LIMBS - 1 {
                        cols.b_sign * (cols.bit_multiplier_right - AB::F::ONE)
                    } else {
                        cols.bit_shift_carry[j + i + 1].into() * right_shift.clone()
                    } * AB::F::from_canonical_usize(1 << LIMB_BITS)
                        + right_shift.clone() * (b[j + i] - cols.bit_shift_carry[j + i]);
                    when_limb_shift.assert_eq(a[j] * cols.bit_multiplier_right, expected_a_right);
                }
            }
        }
        builder.when(is_valid.clone()).assert_one(limb_marker_sum);
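        // Worked instance of the SLL identity above, assuming LIMB_BITS = 8 and
        // bit_shift = 3 (bit_multiplier_left = 8): for b[j - i] = 0xB5 = 181 the
        // carry out is 181 >> 5 = 5, and the constraint reads
        // a[j] = carry_in + 181 * 8 - 256 * 5 = carry_in + ((181 << 3) % 256)
        //      = carry_in + 168.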

        // Check that limb_shift and bit_shift decompose c[0], i.e. that
        // c[0] = limb_shift * LIMB_BITS + bit_shift (mod NUM_LIMBS * LIMB_BITS),
        // by range checking the quotient (c[0] - shift) / (NUM_LIMBS * LIMB_BITS).
        let num_bits = AB::F::from_canonical_usize(NUM_LIMBS * LIMB_BITS);
        self.range_bus
            .range_check(
                (c[0] - limb_shift * AB::F::from_canonical_usize(LIMB_BITS) - bit_shift.clone())
                    * num_bits.inverse(),
                LIMB_BITS - ((NUM_LIMBS * LIMB_BITS) as u32).ilog2() as usize,
            )
            .eval(builder, is_valid.clone());
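        // Example, assuming NUM_LIMBS = 4, LIMB_BITS = 8: c[0] = 75 gives
        // shift = 75 % 32 = 11, i.e. limb_shift = 1 and bit_shift = 3, and the
        // quotient (75 - 8 - 3) / 32 = 2 is range checked to 8 - log2(32) = 3
        // bits, which is exactly what a byte-sized c[0] allows.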

        // Check that b_sign is the sign bit of b[NUM_LIMBS - 1] when
        // opcode_sra_flag is set (and 0 otherwise), via an XOR with the sign-bit
        // mask: XORing with the mask flips exactly the top bit.
        builder.assert_bool(cols.b_sign);
        builder
            .when(not(cols.opcode_sra_flag))
            .assert_zero(cols.b_sign);

        let mask = AB::F::from_canonical_u32(1 << (LIMB_BITS - 1));
        let b_sign_shifted = cols.b_sign * mask;
        self.bitwise_lookup_bus
            .send_xor(
                b[NUM_LIMBS - 1],
                mask,
                b[NUM_LIMBS - 1] + mask - (AB::Expr::from_canonical_u32(2) * b_sign_shifted),
            )
            .eval(builder, cols.opcode_sra_flag);
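        // Example, assuming LIMB_BITS = 8: for b[3] = 0xC5 (sign bit 1, so
        // b_sign = 1) the claimed XOR is 0xC5 + 0x80 - 2 * 0x80 = 0x45, which
        // matches 0xC5 ^ 0x80; with b_sign = 0 the claim would be 0x145 and the
        // lookup would reject it.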

        // Range check each output limb to be a byte (two limbs per lookup).
        for i in 0..(NUM_LIMBS / 2) {
            self.bitwise_lookup_bus
                .send_range(a[i * 2], a[i * 2 + 1])
                .eval(builder, is_valid.clone());
        }

        // Range check each bit-shift carry to be less than 2^bit_shift.
        for carry in cols.bit_shift_carry {
            self.range_bus
                .send(carry, bit_shift.clone())
                .eval(builder, is_valid.clone());
        }
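        // For example, with bit_shift = 3 every carry must lie in [0, 8); this
        // bound is what makes the quotient/remainder split in the limb identities
        // above unique.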

        // Reconstruct the global opcode from the one-hot flags.
        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
            self,
            flags
                .iter()
                .zip(ShiftOpcode::iter())
                .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                    acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
                }),
        );

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [cols.a.map(Into::into)].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}

#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct ShiftCoreRecord<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub b: [u8; NUM_LIMBS],
    pub c: [u8; NUM_LIMBS],
    pub local_opcode: u8,
}
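
// The executor writes this byte-packed record into the record arena during
// preflight execution; `ShiftFiller` later reinterprets those bytes in place
// and expands them into a full `ShiftCoreCols` row.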

#[derive(Clone, Copy)]
pub struct ShiftExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub offset: usize,
}

#[derive(Clone)]
pub struct ShiftFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub offset: usize,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    pub range_checker_chip: SharedVariableRangeCheckerChip,
}

impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> ShiftExecutor<A, NUM_LIMBS, LIMB_BITS> {
    pub fn new(adapter: A, offset: usize) -> Self {
        assert_eq!(NUM_LIMBS % 2, 0, "Number of limbs must be divisible by 2");
        Self { adapter, offset }
    }
}

impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> ShiftFiller<A, NUM_LIMBS, LIMB_BITS> {
    pub fn new(
        adapter: A,
        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
        range_checker_chip: SharedVariableRangeCheckerChip,
        offset: usize,
    ) -> Self {
        assert_eq!(NUM_LIMBS % 2, 0, "Number of limbs must be divisible by 2");
        Self {
            adapter,
            offset,
            bitwise_lookup_chip,
            range_checker_chip,
        }
    }
}

impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            ReadData: Into<[[u8; NUM_LIMBS]; 2]>,
            WriteData: From<[[u8; NUM_LIMBS]; 1]>,
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (
            A::RecordMut<'buf>,
            &'buf mut ShiftCoreRecord<NUM_LIMBS, LIMB_BITS>,
        ),
    >,
{
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!("{:?}", ShiftOpcode::from_usize(opcode - self.offset))
    }

    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let Instruction { opcode, .. } = instruction;

        let local_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset));

        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        // Read the two register operands, compute the shift, and record the
        // inputs so the trace filler can reconstruct the full row later.
        let [rs1, rs2] = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record)
            .into();

        let (output, _, _) = run_shift::<NUM_LIMBS, LIMB_BITS>(local_opcode, &rs1, &rs2);

        core_record.b = rs1;
        core_record.c = rs2;
        core_record.local_opcode = local_opcode as u8;

        self.adapter.write(
            state.memory,
            instruction,
            [output].into(),
            &mut adapter_record,
        );
        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}

impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
    for ShiftFiller<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY: the row is at least A::WIDTH + core width wide.
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY: core_row starts with a valid ShiftCoreRecord written by the
        // executor during preflight execution.
        let record: &ShiftCoreRecord<NUM_LIMBS, LIMB_BITS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };

        let core_row: &mut ShiftCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();

        let opcode = ShiftOpcode::from_usize(record.local_opcode as usize);
        let (a, limb_shift, bit_shift) =
            run_shift::<NUM_LIMBS, LIMB_BITS>(opcode, &record.b, &record.c);

        // Replay the lookups the AIR expects: byte range checks on the output
        // limbs and the shift-amount decomposition check.
        for pair in a.chunks_exact(2) {
            self.bitwise_lookup_chip
                .request_range(pair[0] as u32, pair[1] as u32);
        }

        let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2();
        self.range_checker_chip.add_count(
            ((record.c[0] as usize - bit_shift - limb_shift * LIMB_BITS) >> num_bits_log) as u32,
            LIMB_BITS - num_bits_log as usize,
        );

        // Columns are filled in reverse declaration order so the record bytes at
        // the front of the buffer are not clobbered before the final writes to
        // c, b, and a.
        core_row.bit_shift_carry = if bit_shift == 0 {
            for _ in 0..NUM_LIMBS {
                self.range_checker_chip.add_count(0, 0);
            }
            [F::ZERO; NUM_LIMBS]
        } else {
            array::from_fn(|i| {
                let carry = match opcode {
                    // SLL: the top bit_shift bits of each limb move up a limb.
                    ShiftOpcode::SLL => record.b[i] >> (LIMB_BITS - bit_shift),
                    // SRL/SRA: the bottom bit_shift bits move down a limb.
                    _ => record.b[i] % (1 << bit_shift),
                };
                self.range_checker_chip.add_count(carry as u32, bit_shift);
                F::from_canonical_u8(carry)
            })
        };

        core_row.limb_shift_marker = [F::ZERO; NUM_LIMBS];
        core_row.limb_shift_marker[limb_shift] = F::ONE;
        core_row.bit_shift_marker = [F::ZERO; LIMB_BITS];
        core_row.bit_shift_marker[bit_shift] = F::ONE;

        core_row.b_sign = F::ZERO;
        if opcode == ShiftOpcode::SRA {
            core_row.b_sign = F::from_canonical_u8(record.b[NUM_LIMBS - 1] >> (LIMB_BITS - 1));
            self.bitwise_lookup_chip
                .request_xor(record.b[NUM_LIMBS - 1] as u32, 1 << (LIMB_BITS - 1));
        }

        core_row.bit_multiplier_right = match opcode {
            ShiftOpcode::SLL => F::ZERO,
            _ => F::from_canonical_usize(1 << bit_shift),
        };
        core_row.bit_multiplier_left = match opcode {
            ShiftOpcode::SLL => F::from_canonical_usize(1 << bit_shift),
            _ => F::ZERO,
        };

        core_row.opcode_sra_flag = F::from_bool(opcode == ShiftOpcode::SRA);
        core_row.opcode_srl_flag = F::from_bool(opcode == ShiftOpcode::SRL);
        core_row.opcode_sll_flag = F::from_bool(opcode == ShiftOpcode::SLL);

        core_row.c = record.c.map(F::from_canonical_u8);
        core_row.b = record.b.map(F::from_canonical_u8);
        core_row.a = a.map(F::from_canonical_u8);
    }
}

/// Shifts `x` by the amount encoded in `y` and returns
/// `(result, limb_shift, bit_shift)`.
#[inline(always)]
pub(super) fn run_shift<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    opcode: ShiftOpcode,
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> ([u8; NUM_LIMBS], usize, usize) {
    match opcode {
        ShiftOpcode::SLL => run_shift_left::<NUM_LIMBS, LIMB_BITS>(x, y),
        ShiftOpcode::SRL => run_shift_right::<NUM_LIMBS, LIMB_BITS>(x, y, true),
        ShiftOpcode::SRA => run_shift_right::<NUM_LIMBS, LIMB_BITS>(x, y, false),
    }
}

#[inline(always)]
fn run_shift_left<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> ([u8; NUM_LIMBS], usize, usize) {
    let mut result = [0u8; NUM_LIMBS];

    let (limb_shift, bit_shift) = get_shift::<NUM_LIMBS, LIMB_BITS>(y);

    // Each output limb takes the bit-shifted limb limb_shift positions below it,
    // plus the bits carried out of the limb below that.
    for i in limb_shift..NUM_LIMBS {
        result[i] = if i > limb_shift {
            (((x[i - limb_shift] as u16) << bit_shift)
                | ((x[i - limb_shift - 1] as u16) >> (LIMB_BITS - bit_shift)))
                % (1u16 << LIMB_BITS)
        } else {
            ((x[i - limb_shift] as u16) << bit_shift) % (1u16 << LIMB_BITS)
        } as u8;
    }
    (result, limb_shift, bit_shift)
}

#[inline(always)]
fn run_shift_right<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
    logical: bool,
) -> ([u8; NUM_LIMBS], usize, usize) {
    // Arithmetic shifts fill vacated positions with copies of the sign bit.
    let fill = if logical {
        0
    } else {
        (((1u16 << LIMB_BITS) - 1) as u8) * (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1))
    };
    let mut result = [fill; NUM_LIMBS];

    let (limb_shift, bit_shift) = get_shift::<NUM_LIMBS, LIMB_BITS>(y);

    // Each output limb takes the bit-shifted limb limb_shift positions above it,
    // plus the low bits of the limb above that (or the fill at the top boundary).
    for i in 0..(NUM_LIMBS - limb_shift) {
        let res = if i + limb_shift + 1 < NUM_LIMBS {
            (((x[i + limb_shift] >> bit_shift) as u16)
                | ((x[i + limb_shift + 1] as u16) << (LIMB_BITS - bit_shift)))
                % (1u16 << LIMB_BITS)
        } else {
            (((x[i + limb_shift] >> bit_shift) as u16)
                | ((fill as u16) << (LIMB_BITS - bit_shift)))
                % (1u16 << LIMB_BITS)
        };
        result[i] = res as u8;
    }
    (result, limb_shift, bit_shift)
}

#[inline(always)]
fn get_shift<const NUM_LIMBS: usize, const LIMB_BITS: usize>(y: &[u8]) -> (usize, usize) {
    // We assume `NUM_LIMBS * LIMB_BITS <= 2^LIMB_BITS`, so the shift amount is
    // determined entirely by y[0]; the remaining limbs of y are ignored.
    debug_assert!(NUM_LIMBS * LIMB_BITS <= (1 << LIMB_BITS));
    let shift = (y[0] as usize) % (NUM_LIMBS * LIMB_BITS);
    (shift / LIMB_BITS, shift % LIMB_BITS)
}
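
#[cfg(test)]
mod shift_sanity_tests {
    // A minimal sanity-check sketch for `run_shift`, assuming the standard RV32
    // configuration NUM_LIMBS = 4, LIMB_BITS = 8 (shift amounts taken mod 32
    // from y[0]). The module name and test values here are illustrative, not
    // part of the original test suite.
    use super::*;

    #[test]
    fn run_shift_matches_u32_semantics() {
        // x encodes 0xDEADFACE in little-endian limbs; shift amount 11 splits
        // into limb_shift = 1, bit_shift = 3.
        let x: [u8; 4] = [0xCE, 0xFA, 0xAD, 0xDE];
        let y: [u8; 4] = [11, 0, 0, 0];

        // 0xDEADFACE << 11 == 0x6FD67000 (mod 2^32)
        assert_eq!(
            run_shift::<4, 8>(ShiftOpcode::SLL, &x, &y),
            ([0x00, 0x70, 0xD6, 0x6F], 1, 3)
        );
        // 0xDEADFACE >> 11 == 0x001BD5BF (logical)
        assert_eq!(
            run_shift::<4, 8>(ShiftOpcode::SRL, &x, &y),
            ([0xBF, 0xD5, 0x1B, 0x00], 1, 3)
        );
        // (0xDEADFACE as i32) >> 11 == 0xFFFBD5BF as u32 (arithmetic)
        assert_eq!(
            run_shift::<4, 8>(ShiftOpcode::SRA, &x, &y),
            ([0xBF, 0xD5, 0xFB, 0xFF], 1, 3)
        );
    }
}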