1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7 AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
8 VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::{
11 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
12 utils::not,
13 var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
14};
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_instructions::{instruction::Instruction, LocalOpcode};
17use openvm_rv32im_transpiler::ShiftOpcode;
18use openvm_stark_backend::{
19 interaction::InteractionBuilder,
20 p3_air::{AirBuilder, BaseAir},
21 p3_field::{Field, FieldAlgebra, PrimeField32},
22 rap::BaseAirWithPublicValues,
23};
24use serde::{de::DeserializeOwned, Deserialize, Serialize};
25use serde_big_array::BigArray;
26use strum::IntoEnumIterator;
27
/// Trace columns for the shift core AIR.
///
/// Layout is `repr(C)` so the flat `[T]` row can be reinterpreted via
/// `AlignedBorrow`; do not reorder fields.
#[repr(C)]
#[derive(AlignedBorrow, Clone, Copy, Debug)]
pub struct ShiftCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Result limbs of the shift (written back by the adapter).
    pub a: [T; NUM_LIMBS],
    /// Value being shifted; limb `NUM_LIMBS - 1` holds the sign bit.
    pub b: [T; NUM_LIMBS],
    /// Shift-amount operand; only the low bits of `c[0]` select the shift.
    pub c: [T; NUM_LIMBS],

    // Opcode selectors: each boolean, at most one set; their sum is the
    // row's is_valid indicator.
    pub opcode_sll_flag: T,
    pub opcode_srl_flag: T,
    pub opcode_sra_flag: T,

    // 2^bit_shift gated by direction: `left` is nonzero only for SLL,
    // `right` only for SRL/SRA (see the bit_shift_marker constraints).
    pub bit_multiplier_left: T,
    pub bit_multiplier_right: T,

    // Sign bit of `b` for arithmetic right shift; constrained to 0 unless
    // opcode_sra_flag is set.
    pub b_sign: T,

    // One-hot selectors: bit_shift in [0, LIMB_BITS) and
    // limb_shift in [0, NUM_LIMBS).
    pub bit_shift_marker: [T; LIMB_BITS],
    pub limb_shift_marker: [T; NUM_LIMBS],

    // Per-limb carries of the intra-limb bit shift; each is range-checked
    // to bit_shift bits.
    pub bit_shift_carry: [T; NUM_LIMBS],
}
53
/// AIR for the shift core: constrains `a = b <shift> c` for SLL/SRL/SRA over
/// `NUM_LIMBS` limbs of `LIMB_BITS` bits each.
#[derive(Copy, Clone, Debug)]
pub struct ShiftCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Bus used for xor (sign extraction) and paired limb range checks.
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    /// Bus used for variable-width range checks (shift decomposition, carries).
    pub range_bus: VariableRangeCheckerBus,
    /// Global offset of this chip's `ShiftOpcode` class in the opcode space.
    pub offset: usize,
}
60
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for ShiftCoreAir<NUM_LIMBS, LIMB_BITS>
{
    fn width(&self) -> usize {
        // The trace width is exactly the flattened column struct.
        ShiftCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
// This AIR exposes no public values; the default (empty) impl suffices.
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for ShiftCoreAir<NUM_LIMBS, LIMB_BITS>
{
}
72
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for ShiftCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Constrains one shift row: decomposes the shift amount in `c[0]` into
    /// `limb_shift * LIMB_BITS + bit_shift` via one-hot markers, then relates
    /// output limbs `a` to input limbs `b` limb-by-limb with carry columns.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &ShiftCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [
            cols.opcode_sll_flag,
            cols.opcode_srl_flag,
            cols.opcode_sra_flag,
        ];

        // Each flag is boolean and at most one is set; their sum doubles as
        // the row's is_valid indicator (0 on padding rows).
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());

        let a = &cols.a;
        let b = &cols.b;
        let c = &cols.c;
        let right_shift = cols.opcode_srl_flag + cols.opcode_sra_flag;

        // bit_shift_marker is one-hot over [0, LIMB_BITS); on the selected
        // index i, the direction-gated multiplier column must equal 2^i.
        let mut bit_marker_sum = AB::Expr::ZERO;
        let mut bit_shift = AB::Expr::ZERO;

        for i in 0..LIMB_BITS {
            builder.assert_bool(cols.bit_shift_marker[i]);
            bit_marker_sum += cols.bit_shift_marker[i].into();
            bit_shift += AB::Expr::from_canonical_usize(i) * cols.bit_shift_marker[i];

            let mut when_bit_shift = builder.when(cols.bit_shift_marker[i]);
            // bit_multiplier_left = 2^i for SLL, else 0 (and symmetrically
            // bit_multiplier_right for SRL/SRA).
            when_bit_shift.assert_eq(
                cols.bit_multiplier_left,
                AB::Expr::from_canonical_usize(1 << i) * cols.opcode_sll_flag,
            );
            when_bit_shift.assert_eq(
                cols.bit_multiplier_right,
                AB::Expr::from_canonical_usize(1 << i) * right_shift.clone(),
            );
        }
        // Exactly one bit-shift marker is set on a valid row.
        builder.when(is_valid.clone()).assert_one(bit_marker_sum);

        // limb_shift_marker is one-hot over [0, NUM_LIMBS); under the selected
        // whole-limb shift i, constrain every output limb for both directions
        // (the opcode flags / multipliers gate the inactive direction to 0).
        let mut limb_marker_sum = AB::Expr::ZERO;
        let mut limb_shift = AB::Expr::ZERO;
        for i in 0..NUM_LIMBS {
            builder.assert_bool(cols.limb_shift_marker[i]);
            limb_marker_sum += cols.limb_shift_marker[i].into();
            limb_shift += AB::Expr::from_canonical_usize(i) * cols.limb_shift_marker[i];

            let mut when_limb_shift = builder.when(cols.limb_shift_marker[i]);

            for j in 0..NUM_LIMBS {
                // Left shift: the low i limbs of the result are zero.
                if j < i {
                    when_limb_shift.assert_zero(a[j] * cols.opcode_sll_flag);
                } else {
                    // a[j] = (b[j-i] << bit_shift) mod 2^LIMB_BITS, with the
                    // overflow recorded in bit_shift_carry[j-i] and the carry
                    // from the limb below added back in.
                    let expected_a_left = if j - i == 0 {
                        AB::Expr::ZERO
                    } else {
                        cols.bit_shift_carry[j - i - 1].into() * cols.opcode_sll_flag
                    } + b[j - i] * cols.bit_multiplier_left
                        - AB::Expr::from_canonical_usize(1 << LIMB_BITS)
                            * cols.bit_shift_carry[j - i]
                            * cols.opcode_sll_flag;
                    when_limb_shift.assert_eq(a[j] * cols.opcode_sll_flag, expected_a_left);
                }

                // Right shift: vacated high limbs are the sign fill
                // (all-ones for SRA with negative b, zero otherwise).
                if j + i > NUM_LIMBS - 1 {
                    when_limb_shift.assert_eq(
                        a[j] * right_shift.clone(),
                        cols.b_sign * AB::F::from_canonical_usize((1 << LIMB_BITS) - 1),
                    );
                } else {
                    // a[j] = (b[j+i] >> bit_shift) with the bits shifted in
                    // from above: either the next limb's carry, or — for the
                    // top limb — the sign bits (b_sign * (2^bit_shift - 1)).
                    let expected_a_right = if j + i == NUM_LIMBS - 1 {
                        cols.b_sign * (cols.bit_multiplier_right - AB::F::ONE)
                    } else {
                        cols.bit_shift_carry[j + i + 1].into() * right_shift.clone()
                    } * AB::F::from_canonical_usize(1 << LIMB_BITS)
                        + right_shift.clone() * (b[j + i] - cols.bit_shift_carry[j + i]);
                    when_limb_shift.assert_eq(a[j] * cols.bit_multiplier_right, expected_a_right);
                }
            }
        }
        // Exactly one limb-shift marker is set on a valid row.
        builder.when(is_valid.clone()).assert_one(limb_marker_sum);

        // Shift-amount decomposition: c[0] - (limb_shift * LIMB_BITS +
        // bit_shift) must be a multiple of NUM_LIMBS * LIMB_BITS whose
        // quotient fits in the remaining LIMB_BITS - log2(NUM_LIMBS *
        // LIMB_BITS) bits, i.e. the markers encode c[0] mod the bit width.
        let num_bits = AB::F::from_canonical_usize(NUM_LIMBS * LIMB_BITS);
        self.range_bus
            .range_check(
                (c[0] - limb_shift * AB::F::from_canonical_usize(LIMB_BITS) - bit_shift.clone())
                    * num_bits.inverse(),
                LIMB_BITS - ((NUM_LIMBS * LIMB_BITS) as u32).ilog2() as usize,
            )
            .eval(builder, is_valid.clone());

        // b_sign is boolean and may only be set under SRA.
        builder.assert_bool(cols.b_sign);
        builder
            .when(not(cols.opcode_sra_flag))
            .assert_zero(cols.b_sign);

        // For SRA, prove b_sign really is the top bit of the top limb:
        // b[top] xor 2^(LIMB_BITS-1) = b[top] + mask - 2 * (b_sign * mask)
        // holds exactly when b_sign = top bit of b[top].
        let mask = AB::F::from_canonical_u32(1 << (LIMB_BITS - 1));
        let b_sign_shifted = cols.b_sign * mask;
        self.bitwise_lookup_bus
            .send_xor(
                b[NUM_LIMBS - 1],
                mask,
                b[NUM_LIMBS - 1] + mask - (AB::Expr::from_canonical_u32(2) * b_sign_shifted),
            )
            .eval(builder, cols.opcode_sra_flag);

        // Range-check the output limbs in pairs (requires NUM_LIMBS even).
        for i in 0..(NUM_LIMBS / 2) {
            self.bitwise_lookup_bus
                .send_range(a[i * 2], a[i * 2 + 1])
                .eval(builder, is_valid.clone());
        }

        // Each carry holds the bits shifted out of one limb, so it must fit
        // in bit_shift bits (a variable-width range check).
        for carry in cols.bit_shift_carry {
            self.range_bus
                .send(carry, bit_shift.clone())
                .eval(builder, is_valid.clone());
        }

        // Reconstruct the global opcode from the one-hot flags, relying on
        // ShiftOpcode::iter() matching the flag order above.
        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
            self,
            flags
                .iter()
                .zip(ShiftOpcode::iter())
                .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                    acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
                }),
        );

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [cols.a.map(Into::into)].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}
239
/// Record of one executed shift instruction; carries everything
/// `generate_trace_row` needs to fill a `ShiftCoreCols` row.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "T: Serialize + DeserializeOwned")]
pub struct ShiftCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Result limbs.
    #[serde(with = "BigArray")]
    pub a: [T; NUM_LIMBS],
    /// Shifted operand limbs.
    #[serde(with = "BigArray")]
    pub b: [T; NUM_LIMBS],
    /// Shift-amount operand limbs.
    #[serde(with = "BigArray")]
    pub c: [T; NUM_LIMBS],
    /// Sign bit of `b` (nonzero only for SRA).
    pub b_sign: T,
    /// Bits shifted out of each limb by the intra-limb shift.
    #[serde(with = "BigArray")]
    pub bit_shift_carry: [u32; NUM_LIMBS],
    /// Intra-limb shift amount, in [0, LIMB_BITS).
    pub bit_shift: usize,
    /// Whole-limb shift amount, in [0, NUM_LIMBS).
    pub limb_shift: usize,
    pub opcode: ShiftOpcode,
}
257
/// Runtime chip pairing the shift AIR with the shared lookup chips it
/// sends interactions to.
pub struct ShiftCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub air: ShiftCoreAir<NUM_LIMBS, LIMB_BITS>,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    pub range_checker_chip: SharedVariableRangeCheckerChip,
}
263
264impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> ShiftCoreChip<NUM_LIMBS, LIMB_BITS> {
265 pub fn new(
266 bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
267 range_checker_chip: SharedVariableRangeCheckerChip,
268 offset: usize,
269 ) -> Self {
270 assert_eq!(NUM_LIMBS % 2, 0, "Number of limbs must be divisible by 2");
271 Self {
272 air: ShiftCoreAir {
273 bitwise_lookup_bus: bitwise_lookup_chip.bus(),
274 range_bus: range_checker_chip.bus(),
275 offset,
276 },
277 bitwise_lookup_chip,
278 range_checker_chip,
279 }
280 }
281}
282
impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_LIMBS: usize, const LIMB_BITS: usize>
    VmCoreChip<F, I> for ShiftCoreChip<NUM_LIMBS, LIMB_BITS>
where
    I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
    I::Writes: From<[[F; NUM_LIMBS]; 1]>,
{
    type Record = ShiftCoreRecord<F, NUM_LIMBS, LIMB_BITS>;
    type Air = ShiftCoreAir<NUM_LIMBS, LIMB_BITS>;

    /// Executes one shift: computes `a = b <shift> c`, issues the lookup
    /// requests the AIR will verify, and returns the write-back data plus a
    /// record for trace generation.
    #[allow(clippy::type_complexity)]
    fn execute_instruction(
        &self,
        instruction: &Instruction<F>,
        _from_pc: u32,
        reads: I::Reads,
    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
        let Instruction { opcode, .. } = instruction;
        let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));

        let data: [[F; NUM_LIMBS]; 2] = reads.into();
        let b = data[0].map(|x| x.as_canonical_u32());
        let c = data[1].map(|y| y.as_canonical_u32());
        let (a, limb_shift, bit_shift) = run_shift::<NUM_LIMBS, LIMB_BITS>(shift_opcode, &b, &c);

        // For SLL the carry out of a limb is its top `bit_shift` bits; for
        // SRL/SRA it is the low `bit_shift` bits shifted out to the right.
        let bit_shift_carry = array::from_fn(|i| match shift_opcode {
            ShiftOpcode::SLL => b[i] >> (LIMB_BITS - bit_shift),
            _ => b[i] % (1 << bit_shift),
        });

        // Only SRA extracts the sign bit; the xor-with-mask request mirrors
        // the AIR's send_xor sign-extraction constraint.
        let mut b_sign = 0;
        if shift_opcode == ShiftOpcode::SRA {
            b_sign = b[NUM_LIMBS - 1] >> (LIMB_BITS - 1);
            self.bitwise_lookup_chip
                .request_xor(b[NUM_LIMBS - 1], 1 << (LIMB_BITS - 1));
        }

        // Range-check the output limbs in pairs (matches the AIR's send_range
        // loop; NUM_LIMBS is asserted even in the constructor).
        for i in 0..(NUM_LIMBS / 2) {
            self.bitwise_lookup_chip
                .request_range(a[i * 2], a[i * 2 + 1]);
        }

        let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]);
        let record = ShiftCoreRecord {
            opcode: shift_opcode,
            a: a.map(F::from_canonical_u32),
            b: data[0],
            c: data[1],
            bit_shift_carry,
            bit_shift,
            limb_shift,
            b_sign: F::from_canonical_u32(b_sign),
        };

        Ok((output, record))
    }

    fn get_opcode_name(&self, opcode: usize) -> String {
        format!("{:?}", ShiftOpcode::from_usize(opcode - self.air.offset))
    }

    /// Fills one trace row from a record and issues the range-checker counts
    /// that back the AIR's range interactions.
    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
        // Each carry is range-checked to `bit_shift` bits (AIR: range_bus.send).
        for carry_val in record.bit_shift_carry {
            self.range_checker_chip
                .add_count(carry_val, record.bit_shift);
        }

        // Mirror of the AIR's shift-amount decomposition check: the quotient
        // (c[0] - bit_shift - limb_shift * LIMB_BITS) / (NUM_LIMBS * LIMB_BITS)
        // must fit in LIMB_BITS - log2(NUM_LIMBS * LIMB_BITS) bits.
        let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2();
        self.range_checker_chip.add_count(
            (((record.c[0].as_canonical_u32() as usize)
                - record.bit_shift
                - record.limb_shift * LIMB_BITS)
                >> num_bits_log) as u32,
            LIMB_BITS - num_bits_log as usize,
        );

        let row_slice: &mut ShiftCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut();
        row_slice.a = record.a;
        row_slice.b = record.b;
        row_slice.c = record.c;
        // Exactly one of the two multipliers is nonzero, gated by direction.
        row_slice.bit_multiplier_left = match record.opcode {
            ShiftOpcode::SLL => F::from_canonical_usize(1 << record.bit_shift),
            _ => F::ZERO,
        };
        row_slice.bit_multiplier_right = match record.opcode {
            ShiftOpcode::SLL => F::ZERO,
            _ => F::from_canonical_usize(1 << record.bit_shift),
        };
        row_slice.b_sign = record.b_sign;
        // One-hot markers selecting the decomposed shift amounts.
        row_slice.bit_shift_marker = array::from_fn(|i| F::from_bool(i == record.bit_shift));
        row_slice.limb_shift_marker = array::from_fn(|i| F::from_bool(i == record.limb_shift));
        row_slice.bit_shift_carry = record.bit_shift_carry.map(F::from_canonical_u32);
        row_slice.opcode_sll_flag = F::from_bool(record.opcode == ShiftOpcode::SLL);
        row_slice.opcode_srl_flag = F::from_bool(record.opcode == ShiftOpcode::SRL);
        row_slice.opcode_sra_flag = F::from_bool(record.opcode == ShiftOpcode::SRA);
    }

    fn air(&self) -> &Self::Air {
        &self.air
    }
}
383
384pub(super) fn run_shift<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
385 opcode: ShiftOpcode,
386 x: &[u32; NUM_LIMBS],
387 y: &[u32; NUM_LIMBS],
388) -> ([u32; NUM_LIMBS], usize, usize) {
389 match opcode {
390 ShiftOpcode::SLL => run_shift_left::<NUM_LIMBS, LIMB_BITS>(x, y),
391 ShiftOpcode::SRL => run_shift_right::<NUM_LIMBS, LIMB_BITS>(x, y, true),
392 ShiftOpcode::SRA => run_shift_right::<NUM_LIMBS, LIMB_BITS>(x, y, false),
393 }
394}
395
396fn run_shift_left<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
397 x: &[u32; NUM_LIMBS],
398 y: &[u32; NUM_LIMBS],
399) -> ([u32; NUM_LIMBS], usize, usize) {
400 let mut result = [0u32; NUM_LIMBS];
401
402 let (limb_shift, bit_shift) = get_shift::<NUM_LIMBS, LIMB_BITS>(y);
403
404 for i in limb_shift..NUM_LIMBS {
405 result[i] = if i > limb_shift {
406 ((x[i - limb_shift] << bit_shift) + (x[i - limb_shift - 1] >> (LIMB_BITS - bit_shift)))
407 % (1 << LIMB_BITS)
408 } else {
409 (x[i - limb_shift] << bit_shift) % (1 << LIMB_BITS)
410 };
411 }
412 (result, limb_shift, bit_shift)
413}
414
415fn run_shift_right<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
416 x: &[u32; NUM_LIMBS],
417 y: &[u32; NUM_LIMBS],
418 logical: bool,
419) -> ([u32; NUM_LIMBS], usize, usize) {
420 let fill = if logical {
421 0
422 } else {
423 ((1 << LIMB_BITS) - 1) * (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1))
424 };
425 let mut result = [fill; NUM_LIMBS];
426
427 let (limb_shift, bit_shift) = get_shift::<NUM_LIMBS, LIMB_BITS>(y);
428
429 for i in 0..(NUM_LIMBS - limb_shift) {
430 result[i] = if i + limb_shift + 1 < NUM_LIMBS {
431 ((x[i + limb_shift] >> bit_shift) + (x[i + limb_shift + 1] << (LIMB_BITS - bit_shift)))
432 % (1 << LIMB_BITS)
433 } else {
434 ((x[i + limb_shift] >> bit_shift) + (fill << (LIMB_BITS - bit_shift)))
435 % (1 << LIMB_BITS)
436 }
437 }
438 (result, limb_shift, bit_shift)
439}
440
/// Decomposes the shift amount held in `y[0]` into whole-limb and
/// intra-limb parts. Shifts wrap modulo the total bit width, so only the
/// low bits of `y[0]` matter.
fn get_shift<const NUM_LIMBS: usize, const LIMB_BITS: usize>(y: &[u32]) -> (usize, usize) {
    let total_bits = NUM_LIMBS * LIMB_BITS;
    let shift = (y[0] as usize) % total_bits;
    (shift / LIMB_BITS, shift % LIMB_BITS)
}