openvm_rv32im_circuit/shift/
core.rs

use std::{
    array,
    borrow::{Borrow, BorrowMut},
};

use openvm_circuit::{
    arch::*,
    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
};
use openvm_circuit_primitives::{
    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
    utils::not,
    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
    AlignedBytesBorrow,
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
use openvm_rv32im_transpiler::ShiftOpcode;
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};
use strum::IntoEnumIterator;
#[repr(C)]
#[derive(AlignedBorrow, Clone, Copy, Debug)]
pub struct ShiftCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub a: [T; NUM_LIMBS],
    pub b: [T; NUM_LIMBS],
    pub c: [T; NUM_LIMBS],

    pub opcode_sll_flag: T,
    pub opcode_srl_flag: T,
    pub opcode_sra_flag: T,

    // bit_multiplier = 2^bit_shift
    pub bit_multiplier_left: T,
    pub bit_multiplier_right: T,

    // Sign bit of b, used only for SRA
    pub b_sign: T,

    // Boolean columns that are 1 exactly at the index of the bit/limb shift amount
    pub bit_shift_marker: [T; LIMB_BITS],
    pub limb_shift_marker: [T; NUM_LIMBS],

    // Part of each b[i] that gets bit shifted to the next limb
    pub bit_shift_carry: [T; NUM_LIMBS],
}

/// RV32 shift AIR.
/// Note: when the shift amount from the operand is at least the total bit width, only
/// `shift_amount % num_bits` bits are shifted. This matches the RISC-V spec for SLL/SRL/SRA.
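/// For example, with `NUM_LIMBS = 4` and `LIMB_BITS = 8` (RV32), a requested shift of 33
/// shifts by 33 % 32 = 1.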
#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct ShiftCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    pub range_bus: VariableRangeCheckerBus,
    pub offset: usize,
}

impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for ShiftCoreAir<NUM_LIMBS, LIMB_BITS>
{
    fn width(&self) -> usize {
        ShiftCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for ShiftCoreAir<NUM_LIMBS, LIMB_BITS>
{
}

impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for ShiftCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &ShiftCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [
            cols.opcode_sll_flag,
            cols.opcode_srl_flag,
            cols.opcode_sra_flag,
        ];

        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());

        let a = &cols.a;
        let b = &cols.b;
        let c = &cols.c;
        let right_shift = cols.opcode_srl_flag + cols.opcode_sra_flag;

        // Constrain that bit_shift, bit_multiplier are correct, i.e. that bit_multiplier =
        // 1 << bit_shift. Because the sum of all bit_shift_marker[i] is constrained to be
        // 1, bit_shift is guaranteed to be in range.
        let mut bit_marker_sum = AB::Expr::ZERO;
        let mut bit_shift = AB::Expr::ZERO;

        for i in 0..LIMB_BITS {
            builder.assert_bool(cols.bit_shift_marker[i]);
            bit_marker_sum += cols.bit_shift_marker[i].into();
            bit_shift += AB::Expr::from_canonical_usize(i) * cols.bit_shift_marker[i];

            let mut when_bit_shift = builder.when(cols.bit_shift_marker[i]);
            when_bit_shift.assert_eq(
                cols.bit_multiplier_left,
                AB::Expr::from_canonical_usize(1 << i) * cols.opcode_sll_flag,
            );
            when_bit_shift.assert_eq(
                cols.bit_multiplier_right,
                AB::Expr::from_canonical_usize(1 << i) * right_shift.clone(),
            );
        }
        builder.when(is_valid.clone()).assert_one(bit_marker_sum);

        // Check that a = b << c (SLL) or a = b >> c (SRL/SRA) at both the bit and limb level,
        // where the effective shift amount is c % (NUM_LIMBS * LIMB_BITS).
        let mut limb_marker_sum = AB::Expr::ZERO;
        let mut limb_shift = AB::Expr::ZERO;
        for i in 0..NUM_LIMBS {
            builder.assert_bool(cols.limb_shift_marker[i]);
            limb_marker_sum += cols.limb_shift_marker[i].into();
            limb_shift += AB::Expr::from_canonical_usize(i) * cols.limb_shift_marker[i];

            let mut when_limb_shift = builder.when(cols.limb_shift_marker[i]);

            for j in 0..NUM_LIMBS {
                // SLL constraints
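                // For limb shift i: limbs below index i are zero, and for j >= i,
                // a[j] = (b[j - i] << bit_shift) mod 2^LIMB_BITS plus the carry shifted
                // out of the limb below, where for SLL bit_shift_carry[k] holds the top
                // bit_shift bits of b[k].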
                if j < i {
                    when_limb_shift.assert_zero(a[j] * cols.opcode_sll_flag);
                } else {
                    let expected_a_left = if j - i == 0 {
                        AB::Expr::ZERO
                    } else {
                        cols.bit_shift_carry[j - i - 1].into() * cols.opcode_sll_flag
                    } + b[j - i] * cols.bit_multiplier_left
                        - AB::Expr::from_canonical_usize(1 << LIMB_BITS)
                            * cols.bit_shift_carry[j - i]
                            * cols.opcode_sll_flag;
                    when_limb_shift.assert_eq(a[j] * cols.opcode_sll_flag, expected_a_left);
                }

                // SRL and SRA constraints. Combining these with the SLL constraints above
                // would require an additional column.
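                // For limb shift i: limbs with j + i >= NUM_LIMBS hold the sign fill
                // (0 for SRL, all ones for negative SRA); otherwise
                // a[j] * 2^bit_shift = (b[j + i] - bit_shift_carry[j + i]) plus the low
                // bit_shift bits of the limb above (or the sign fill at the top limb),
                // where for SRL/SRA bit_shift_carry[k] holds the low bit_shift bits of b[k].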
                if j + i > NUM_LIMBS - 1 {
                    when_limb_shift.assert_eq(
                        a[j] * right_shift.clone(),
                        cols.b_sign * AB::F::from_canonical_usize((1 << LIMB_BITS) - 1),
                    );
                } else {
                    let expected_a_right = if j + i == NUM_LIMBS - 1 {
                        cols.b_sign * (cols.bit_multiplier_right - AB::F::ONE)
                    } else {
                        cols.bit_shift_carry[j + i + 1].into() * right_shift.clone()
                    } * AB::F::from_canonical_usize(1 << LIMB_BITS)
                        + right_shift.clone() * (b[j + i] - cols.bit_shift_carry[j + i]);
                    when_limb_shift.assert_eq(a[j] * cols.bit_multiplier_right, expected_a_right);
                }
            }
        }
        builder.when(is_valid.clone()).assert_one(limb_marker_sum);

        // Check that bit_shift and limb_shift are correct.
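        // c[0] decomposes as limb_shift * LIMB_BITS + bit_shift + q * (NUM_LIMBS * LIMB_BITS)
        // for some quotient q >= 0; the range check below verifies that q fits in
        // LIMB_BITS - log2(NUM_LIMBS * LIMB_BITS) bits, so the markers match c[0].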
        let num_bits = AB::F::from_canonical_usize(NUM_LIMBS * LIMB_BITS);
        self.range_bus
            .range_check(
                (c[0] - limb_shift * AB::F::from_canonical_usize(LIMB_BITS) - bit_shift.clone())
                    * num_bits.inverse(),
                LIMB_BITS - ((NUM_LIMBS * LIMB_BITS) as u32).ilog2() as usize,
            )
            .eval(builder, is_valid.clone());

        // Check that b_sign equals the top bit of b[NUM_LIMBS - 1] (constrained only for
        // SRA) using an XOR with the sign-bit mask
        builder.assert_bool(cols.b_sign);
        builder
            .when(not(cols.opcode_sra_flag))
            .assert_zero(cols.b_sign);

        let mask = AB::F::from_canonical_u32(1 << (LIMB_BITS - 1));
        let b_sign_shifted = cols.b_sign * mask;
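        // For any mask m: b ^ m = b + m - 2 * (b & m), and b & m = b_sign * m exactly
        // when b_sign is the top bit of b[NUM_LIMBS - 1] and m is the sign-bit mask.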
        self.bitwise_lookup_bus
            .send_xor(
                b[NUM_LIMBS - 1],
                mask,
                b[NUM_LIMBS - 1] + mask - (AB::Expr::from_canonical_u32(2) * b_sign_shifted),
            )
            .eval(builder, cols.opcode_sra_flag);

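        // Range check the output limbs in pairs (one lookup covers two bytes), which is why
        // NUM_LIMBS is required to be even.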
        for i in 0..(NUM_LIMBS / 2) {
            self.bitwise_lookup_bus
                .send_range(a[i * 2], a[i * 2 + 1])
                .eval(builder, is_valid.clone());
        }

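        // Each bit_shift_carry[i] must fit in bit_shift bits for the limb equations above
        // to be sound.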
        for carry in cols.bit_shift_carry {
            self.range_bus
                .send(carry, bit_shift.clone())
                .eval(builder, is_valid.clone());
        }

        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
            self,
            flags
                .iter()
                .zip(ShiftOpcode::iter())
                .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                    acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
                }),
        );

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [cols.a.map(Into::into)].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}

#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct ShiftCoreRecord<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub b: [u8; NUM_LIMBS],
    pub c: [u8; NUM_LIMBS],
    pub local_opcode: u8,
}

#[derive(Clone, Copy)]
pub struct ShiftExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub offset: usize,
}

#[derive(Clone)]
pub struct ShiftFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub offset: usize,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    pub range_checker_chip: SharedVariableRangeCheckerChip,
}

impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> ShiftExecutor<A, NUM_LIMBS, LIMB_BITS> {
    pub fn new(adapter: A, offset: usize) -> Self {
        assert_eq!(NUM_LIMBS % 2, 0, "Number of limbs must be divisible by 2");
        Self { adapter, offset }
    }
}

impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> ShiftFiller<A, NUM_LIMBS, LIMB_BITS> {
    pub fn new(
        adapter: A,
        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
        range_checker_chip: SharedVariableRangeCheckerChip,
        offset: usize,
    ) -> Self {
        assert_eq!(NUM_LIMBS % 2, 0, "Number of limbs must be divisible by 2");
        Self {
            adapter,
            offset,
            bitwise_lookup_chip,
            range_checker_chip,
        }
    }
}

impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            ReadData: Into<[[u8; NUM_LIMBS]; 2]>,
            WriteData: From<[[u8; NUM_LIMBS]; 1]>,
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (
            A::RecordMut<'buf>,
            &'buf mut ShiftCoreRecord<NUM_LIMBS, LIMB_BITS>,
        ),
    >,
{
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!("{:?}", ShiftOpcode::from_usize(opcode - self.offset))
    }

    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let Instruction { opcode, .. } = instruction;

        let local_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset));

        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        let [rs1, rs2] = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record)
            .into();

        let (output, _, _) = run_shift::<NUM_LIMBS, LIMB_BITS>(local_opcode, &rs1, &rs2);

        core_record.b = rs1;
        core_record.c = rs2;
        core_record.local_opcode = local_opcode as u8;

        self.adapter.write(
            state.memory,
            instruction,
            [output].into(),
            &mut adapter_record,
        );
        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}

impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
    for ShiftFiller<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
        // ShiftCoreCols::width() elements
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY: core_row contains a valid ShiftCoreRecord written by the executor
        // during trace generation
        let record: &ShiftCoreRecord<NUM_LIMBS, LIMB_BITS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };

        let core_row: &mut ShiftCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();
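        // Note: the record aliases the start of this row buffer, so the columns below are
        // written in an order that reads each record field before its bytes are overwritten
        // (in particular, a/b/c are written last).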

        let opcode = ShiftOpcode::from_usize(record.local_opcode as usize);
        let (a, limb_shift, bit_shift) =
            run_shift::<NUM_LIMBS, LIMB_BITS>(opcode, &record.b, &record.c);

        for pair in a.chunks_exact(2) {
            self.bitwise_lookup_chip
                .request_range(pair[0] as u32, pair[1] as u32);
        }

        let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2();
        self.range_checker_chip.add_count(
            ((record.c[0] as usize - bit_shift - limb_shift * LIMB_BITS) >> num_bits_log) as u32,
            LIMB_BITS - num_bits_log as usize,
        );

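        // For SLL the carry is the top bit_shift bits of each input limb; for SRL/SRA it is
        // the low bit_shift bits. When bit_shift == 0 the AIR still expects one zero-bit
        // range check per limb, hence the add_count(0, 0) calls.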
        core_row.bit_shift_carry = if bit_shift == 0 {
            for _ in 0..NUM_LIMBS {
                self.range_checker_chip.add_count(0, 0);
            }
            [F::ZERO; NUM_LIMBS]
        } else {
            array::from_fn(|i| {
                let carry = match opcode {
                    ShiftOpcode::SLL => record.b[i] >> (LIMB_BITS - bit_shift),
                    _ => record.b[i] % (1 << bit_shift),
                };
                self.range_checker_chip.add_count(carry as u32, bit_shift);
                F::from_canonical_u8(carry)
            })
        };

        core_row.limb_shift_marker = [F::ZERO; NUM_LIMBS];
        core_row.limb_shift_marker[limb_shift] = F::ONE;
        core_row.bit_shift_marker = [F::ZERO; LIMB_BITS];
        core_row.bit_shift_marker[bit_shift] = F::ONE;

        core_row.b_sign = F::ZERO;
        if opcode == ShiftOpcode::SRA {
            core_row.b_sign = F::from_canonical_u8(record.b[NUM_LIMBS - 1] >> (LIMB_BITS - 1));
            self.bitwise_lookup_chip
                .request_xor(record.b[NUM_LIMBS - 1] as u32, 1 << (LIMB_BITS - 1));
        }

        core_row.bit_multiplier_right = match opcode {
            ShiftOpcode::SLL => F::ZERO,
            _ => F::from_canonical_usize(1 << bit_shift),
        };
        core_row.bit_multiplier_left = match opcode {
            ShiftOpcode::SLL => F::from_canonical_usize(1 << bit_shift),
            _ => F::ZERO,
        };

        core_row.opcode_sra_flag = F::from_bool(opcode == ShiftOpcode::SRA);
        core_row.opcode_srl_flag = F::from_bool(opcode == ShiftOpcode::SRL);
        core_row.opcode_sll_flag = F::from_bool(opcode == ShiftOpcode::SLL);

        core_row.c = record.c.map(F::from_canonical_u8);
        core_row.b = record.b.map(F::from_canonical_u8);
        core_row.a = a.map(F::from_canonical_u8);
    }
}

// Returns (result, limb_shift, bit_shift), where the effective shift amount is
// y[0] % (NUM_LIMBS * LIMB_BITS), limb_shift = amount / LIMB_BITS, and
// bit_shift = amount % LIMB_BITS.
#[inline(always)]
pub(super) fn run_shift<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    opcode: ShiftOpcode,
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> ([u8; NUM_LIMBS], usize, usize) {
    match opcode {
        ShiftOpcode::SLL => run_shift_left::<NUM_LIMBS, LIMB_BITS>(x, y),
        ShiftOpcode::SRL => run_shift_right::<NUM_LIMBS, LIMB_BITS>(x, y, true),
        ShiftOpcode::SRA => run_shift_right::<NUM_LIMBS, LIMB_BITS>(x, y, false),
    }
}

#[inline(always)]
fn run_shift_left<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> ([u8; NUM_LIMBS], usize, usize) {
    let mut result = [0u8; NUM_LIMBS];

    let (limb_shift, bit_shift) = get_shift::<NUM_LIMBS, LIMB_BITS>(y);

    for i in limb_shift..NUM_LIMBS {
        result[i] = if i > limb_shift {
            (((x[i - limb_shift] as u16) << bit_shift)
                | ((x[i - limb_shift - 1] as u16) >> (LIMB_BITS - bit_shift)))
                % (1u16 << LIMB_BITS)
        } else {
            ((x[i - limb_shift] as u16) << bit_shift) % (1u16 << LIMB_BITS)
        } as u8;
    }
    (result, limb_shift, bit_shift)
}

#[inline(always)]
fn run_shift_right<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
    logical: bool,
) -> ([u8; NUM_LIMBS], usize, usize) {
    let fill = if logical {
        0
    } else {
        (((1u16 << LIMB_BITS) - 1) as u8) * (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1))
    };
    let mut result = [fill; NUM_LIMBS];

    let (limb_shift, bit_shift) = get_shift::<NUM_LIMBS, LIMB_BITS>(y);

    for i in 0..(NUM_LIMBS - limb_shift) {
        let res = if i + limb_shift + 1 < NUM_LIMBS {
            (((x[i + limb_shift] >> bit_shift) as u16)
                | ((x[i + limb_shift + 1] as u16) << (LIMB_BITS - bit_shift)))
                % (1u16 << LIMB_BITS)
        } else {
            (((x[i + limb_shift] >> bit_shift) as u16) | ((fill as u16) << (LIMB_BITS - bit_shift)))
                % (1u16 << LIMB_BITS)
        };
        result[i] = res as u8;
    }
    (result, limb_shift, bit_shift)
}

#[inline(always)]
fn get_shift<const NUM_LIMBS: usize, const LIMB_BITS: usize>(y: &[u8]) -> (usize, usize) {
    debug_assert!(NUM_LIMBS * LIMB_BITS <= (1 << LIMB_BITS));
    // We assume `NUM_LIMBS * LIMB_BITS <= 2^LIMB_BITS`, so the shift amount is determined
    // entirely by y[0].
    let shift = (y[0] as usize) % (NUM_LIMBS * LIMB_BITS);
    (shift / LIMB_BITS, shift % LIMB_BITS)
}
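
// A minimal sanity-check sketch for `run_shift`, assuming the RV32 parameters
// NUM_LIMBS = 4 and LIMB_BITS = 8. The expected values below are hand-computed
// illustrations, not taken from the crate's own test suite.
#[cfg(test)]
mod run_shift_sanity {
    use super::*;

    #[test]
    fn run_shift_matches_hand_computed_examples() {
        // 1 << 1 == 2; the shift amount is read from y[0] modulo 32.
        let (res, limb, bit) = run_shift::<4, 8>(ShiftOpcode::SLL, &[1, 0, 0, 0], &[1, 0, 0, 0]);
        assert_eq!((res, limb, bit), ([2, 0, 0, 0], 0, 1));

        // Shifting by 33 is the same as shifting by 33 % 32 = 1.
        let (res, _, _) = run_shift::<4, 8>(ShiftOpcode::SLL, &[1, 0, 0, 0], &[33, 0, 0, 0]);
        assert_eq!(res, [2, 0, 0, 0]);

        // Arithmetic: 0x80000000 >> 31 == -1, i.e. every limb is 0xff.
        let (res, limb, bit) =
            run_shift::<4, 8>(ShiftOpcode::SRA, &[0, 0, 0, 0x80], &[31, 0, 0, 0]);
        assert_eq!((res, limb, bit), ([0xff; 4], 3, 7));

        // Logical: 0x80000000 >> 31 == 1.
        let (res, _, _) = run_shift::<4, 8>(ShiftOpcode::SRL, &[0, 0, 0, 0x80], &[31, 0, 0, 0]);
        assert_eq!(res, [1, 0, 0, 0]);
    }
}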