openvm_rv32im_circuit/shift/core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7    AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
8    VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::{
11    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
12    utils::not,
13    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
14};
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_instructions::{instruction::Instruction, LocalOpcode};
17use openvm_rv32im_transpiler::ShiftOpcode;
18use openvm_stark_backend::{
19    interaction::InteractionBuilder,
20    p3_air::{AirBuilder, BaseAir},
21    p3_field::{Field, FieldAlgebra, PrimeField32},
22    rap::BaseAirWithPublicValues,
23};
24use serde::{de::DeserializeOwned, Deserialize, Serialize};
25use serde_big_array::BigArray;
26use strum::IntoEnumIterator;
27
/// Trace columns for one row of the shift core: an SLL/SRL/SRA of `b` by the amount
/// in `c`, producing `a`, over `NUM_LIMBS` limbs of `LIMB_BITS` bits each.
/// The row slice is reinterpreted as this `#[repr(C)]` struct, so field order is the
/// column order.
#[repr(C)]
#[derive(AlignedBorrow, Clone, Copy, Debug)]
pub struct ShiftCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Result limbs (written back through the adapter).
    pub a: [T; NUM_LIMBS],
    /// Limbs of the value being shifted (first read operand).
    pub b: [T; NUM_LIMBS],
    /// Limbs of the shift-amount operand (second read operand); only `c[0]` takes part
    /// in the shift-amount constraint.
    pub c: [T; NUM_LIMBS],

    /// Opcode selectors. Each is boolean and their sum is boolean, so at most one is set;
    /// the sum is the row's `is_valid`.
    pub opcode_sll_flag: T,
    pub opcode_srl_flag: T,
    pub opcode_sra_flag: T,

    /// `2^bit_shift` for SLL, 0 otherwise.
    pub bit_multiplier_left: T,
    /// `2^bit_shift` for SRL/SRA, 0 otherwise.
    pub bit_multiplier_right: T,

    /// Sign (top bit) of `b[NUM_LIMBS - 1]` for SRA; constrained to 0 for other opcodes.
    pub b_sign: T,

    /// One-hot encodings of the bit part (`< LIMB_BITS`) and limb part (`< NUM_LIMBS`)
    /// of the shift amount: exactly one entry of each is 1 on a valid row.
    pub bit_shift_marker: [T; LIMB_BITS],
    pub limb_shift_marker: [T; NUM_LIMBS],

    /// Bits of each `b[i]` that move into the adjacent limb: the top `bit_shift` bits
    /// (into the next-higher limb) for SLL, the low `bit_shift` bits (into the next-lower
    /// limb) for SRL/SRA. Each is range-checked to `bit_shift` bits.
    pub bit_shift_carry: [T; NUM_LIMBS],
}
53
54#[derive(Copy, Clone, Debug)]
55pub struct ShiftCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
56    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
57    pub range_bus: VariableRangeCheckerBus,
58    pub offset: usize,
59}
60
61impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
62    for ShiftCoreAir<NUM_LIMBS, LIMB_BITS>
63{
64    fn width(&self) -> usize {
65        ShiftCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
66    }
67}
68impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
69    for ShiftCoreAir<NUM_LIMBS, LIMB_BITS>
70{
71}
72
/// Constraints of the shift core. Each valid row encodes one SLL/SRL/SRA of `b` by the
/// amount in `c` with result `a`. The shift amount is split into one-hot limb and bit
/// markers, and the per-limb relations are enforced under the selected markers.
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for ShiftCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Adds this row's constraints and lookup interactions to `builder`. Returns the
    /// adapter context: reads `[b, c]`, writes `[a]`, no explicit next pc, plus the
    /// instruction's validity flag and global opcode.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &ShiftCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [
            cols.opcode_sll_flag,
            cols.opcode_srl_flag,
            cols.opcode_sra_flag,
        ];

        // Every flag is boolean and so is their sum, hence at most one flag is set.
        // `is_valid` is 1 exactly when some opcode flag is set.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());

        let a = &cols.a;
        let b = &cols.b;
        let c = &cols.c;
        // 1 for SRL or SRA, 0 for SLL or when no flag is set.
        let right_shift = cols.opcode_srl_flag + cols.opcode_sra_flag;

        // Constrain that bit_shift, bit_multiplier are correct, i.e. that bit_multiplier =
        // 1 << bit_shift. Because the sum of all bit_shift_marker[i] is constrained to be
        // 1, bit_shift is guaranteed to be in range.
        let mut bit_marker_sum = AB::Expr::ZERO;
        let mut bit_shift = AB::Expr::ZERO;

        for i in 0..LIMB_BITS {
            builder.assert_bool(cols.bit_shift_marker[i]);
            bit_marker_sum += cols.bit_shift_marker[i].into();
            bit_shift += AB::Expr::from_canonical_usize(i) * cols.bit_shift_marker[i];

            // Under marker i, the multiplier of the active direction equals 2^i and the
            // other direction's multiplier is 0 (its flag factor is 0).
            let mut when_bit_shift = builder.when(cols.bit_shift_marker[i]);
            when_bit_shift.assert_eq(
                cols.bit_multiplier_left,
                AB::Expr::from_canonical_usize(1 << i) * cols.opcode_sll_flag,
            );
            when_bit_shift.assert_eq(
                cols.bit_multiplier_right,
                AB::Expr::from_canonical_usize(1 << i) * right_shift.clone(),
            );
        }
        builder.when(is_valid.clone()).assert_one(bit_marker_sum);

        // Check that a = b <</>> shift limb by limb, where
        // shift = limb_shift * LIMB_BITS + bit_shift is read off the one-hot markers.
        // The range check further below ties this shift to c[0].
        let mut limb_marker_sum = AB::Expr::ZERO;
        let mut limb_shift = AB::Expr::ZERO;
        for i in 0..NUM_LIMBS {
            builder.assert_bool(cols.limb_shift_marker[i]);
            limb_marker_sum += cols.limb_shift_marker[i].into();
            limb_shift += AB::Expr::from_canonical_usize(i) * cols.limb_shift_marker[i];

            // Everything below only binds when the row's limb shift is i.
            let mut when_limb_shift = builder.when(cols.limb_shift_marker[i]);

            for j in 0..NUM_LIMBS {
                // SLL constraints (every term is gated by the SLL flag): limbs below the
                // limb shift are zero; otherwise
                //   a[j] = b[j - i] * 2^bit_shift - 2^LIMB_BITS * carry[j - i]
                //          + carry[j - i - 1]   (bits shifted in from the limb below).
                if j < i {
                    when_limb_shift.assert_zero(a[j] * cols.opcode_sll_flag);
                } else {
                    let expected_a_left = if j - i == 0 {
                        AB::Expr::ZERO
                    } else {
                        cols.bit_shift_carry[j - i - 1].into() * cols.opcode_sll_flag
                    } + b[j - i] * cols.bit_multiplier_left
                        - AB::Expr::from_canonical_usize(1 << LIMB_BITS)
                            * cols.bit_shift_carry[j - i]
                            * cols.opcode_sll_flag;
                    when_limb_shift.assert_eq(a[j] * cols.opcode_sll_flag, expected_a_left);
                }

                // SRL and SRA constraints. Combining with above would require an additional column.
                // Limbs whose source index j + i lies past the top of b hold the sign fill
                // (2^LIMB_BITS - 1) * b_sign. Otherwise
                //   a[j] * 2^bit_shift = b[j + i] - carry[j + i] + 2^LIMB_BITS * high,
                // where `high` is carry[j + i + 1], or for the top source limb the sign
                // fill b_sign * (2^bit_shift - 1).
                if j + i > NUM_LIMBS - 1 {
                    when_limb_shift.assert_eq(
                        a[j] * right_shift.clone(),
                        cols.b_sign * AB::F::from_canonical_usize((1 << LIMB_BITS) - 1),
                    );
                } else {
                    let expected_a_right = if j + i == NUM_LIMBS - 1 {
                        cols.b_sign * (cols.bit_multiplier_right - AB::F::ONE)
                    } else {
                        cols.bit_shift_carry[j + i + 1].into() * right_shift.clone()
                    } * AB::F::from_canonical_usize(1 << LIMB_BITS)
                        + right_shift.clone() * (b[j + i] - cols.bit_shift_carry[j + i]);
                    when_limb_shift.assert_eq(a[j] * cols.bit_multiplier_right, expected_a_right);
                }
            }
        }
        builder.when(is_valid.clone()).assert_one(limb_marker_sum);

        // Check that bit_shift and limb_shift are correct.
        // The quotient (c[0] - limb_shift * LIMB_BITS - bit_shift) / (NUM_LIMBS * LIMB_BITS)
        // must fit in LIMB_BITS - log2(NUM_LIMBS * LIMB_BITS) bits. Together with
        // shift < NUM_LIMBS * LIMB_BITS, this pins shift == c[0] mod (NUM_LIMBS * LIMB_BITS).
        // NOTE(review): this argument assumes NUM_LIMBS * LIMB_BITS is a power of two and
        // c[0] < 2^LIMB_BITS; neither is enforced in this block — confirm upstream.
        let num_bits = AB::F::from_canonical_usize(NUM_LIMBS * LIMB_BITS);
        self.range_bus
            .range_check(
                (c[0] - limb_shift * AB::F::from_canonical_usize(LIMB_BITS) - bit_shift.clone())
                    * num_bits.inverse(),
                LIMB_BITS - ((NUM_LIMBS * LIMB_BITS) as u32).ilog2() as usize,
            )
            .eval(builder, is_valid.clone());

        // Check that b_sign is the top bit of b[NUM_LIMBS - 1] using XOR.
        // b_sign is boolean and forced to 0 unless the opcode is SRA. For SRA, with
        // mask = 2^(LIMB_BITS - 1), the lookup asserts
        //   b[NUM_LIMBS - 1] XOR mask == b[NUM_LIMBS - 1] + mask - 2 * b_sign * mask,
        // which holds exactly when b_sign equals that top bit.
        builder.assert_bool(cols.b_sign);
        builder
            .when(not(cols.opcode_sra_flag))
            .assert_zero(cols.b_sign);

        let mask = AB::F::from_canonical_u32(1 << (LIMB_BITS - 1));
        let b_sign_shifted = cols.b_sign * mask;
        self.bitwise_lookup_bus
            .send_xor(
                b[NUM_LIMBS - 1],
                mask,
                b[NUM_LIMBS - 1] + mask - (AB::Expr::from_canonical_u32(2) * b_sign_shifted),
            )
            .eval(builder, cols.opcode_sra_flag);

        // Range-check the result limbs in pairs on the bitwise lookup bus (presumably to
        // LIMB_BITS bits each). Assumes NUM_LIMBS is even; an odd last limb would be skipped.
        for i in 0..(NUM_LIMBS / 2) {
            self.bitwise_lookup_bus
                .send_range(a[i * 2], a[i * 2 + 1])
                .eval(builder, is_valid.clone());
        }

        // Each carry must fit in `bit_shift` bits (variable range checker).
        for carry in cols.bit_shift_carry {
            self.range_bus
                .send(carry, bit_shift.clone())
                .eval(builder, is_valid.clone());
        }

        // Local opcode = sum over k of flags[k] * (k-th ShiftOpcode), then mapped to the
        // global opcode via expr_to_global_expr (presumably adding `start_offset`).
        // NOTE(review): relies on ShiftOpcode::iter() yielding SLL, SRL, SRA in the same
        // order as `flags` — confirm if the enum is ever reordered.
        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
            self,
            flags
                .iter()
                .zip(ShiftOpcode::iter())
                .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                    acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
                }),
        );

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [cols.a.map(Into::into)].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    /// Start of this chip's opcodes in the global opcode space.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
239
/// Per-instruction record produced by `execute_instruction` and consumed by
/// `generate_trace_row`.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "T: Serialize + DeserializeOwned")]
pub struct ShiftCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Result limbs.
    #[serde(with = "BigArray")]
    pub a: [T; NUM_LIMBS],
    /// Limbs of the shifted value, as read.
    #[serde(with = "BigArray")]
    pub b: [T; NUM_LIMBS],
    /// Limbs of the shift-amount operand, as read.
    #[serde(with = "BigArray")]
    pub c: [T; NUM_LIMBS],
    /// Top bit of `b[NUM_LIMBS - 1]` for SRA; 0 for SLL/SRL.
    pub b_sign: T,
    /// Bits of each limb of `b` that cross into the adjacent limb (top `bit_shift` bits for
    /// SLL, low `bit_shift` bits for SRL/SRA), kept as plain integers.
    #[serde(with = "BigArray")]
    pub bit_shift_carry: [u32; NUM_LIMBS],
    /// Bit part of the shift amount, in `0..LIMB_BITS`.
    pub bit_shift: usize,
    /// Limb part of the shift amount, in `0..NUM_LIMBS`.
    pub limb_shift: usize,
    /// The shift that was executed.
    pub opcode: ShiftOpcode,
}
257
258pub struct ShiftCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
259    pub air: ShiftCoreAir<NUM_LIMBS, LIMB_BITS>,
260    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
261    pub range_checker_chip: SharedVariableRangeCheckerChip,
262}
263
264impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> ShiftCoreChip<NUM_LIMBS, LIMB_BITS> {
265    pub fn new(
266        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
267        range_checker_chip: SharedVariableRangeCheckerChip,
268        offset: usize,
269    ) -> Self {
270        assert_eq!(NUM_LIMBS % 2, 0, "Number of limbs must be divisible by 2");
271        Self {
272            air: ShiftCoreAir {
273                bitwise_lookup_bus: bitwise_lookup_chip.bus(),
274                range_bus: range_checker_chip.bus(),
275                offset,
276            },
277            bitwise_lookup_chip,
278            range_checker_chip,
279        }
280    }
281}
282
283impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_LIMBS: usize, const LIMB_BITS: usize>
284    VmCoreChip<F, I> for ShiftCoreChip<NUM_LIMBS, LIMB_BITS>
285where
286    I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
287    I::Writes: From<[[F; NUM_LIMBS]; 1]>,
288{
289    type Record = ShiftCoreRecord<F, NUM_LIMBS, LIMB_BITS>;
290    type Air = ShiftCoreAir<NUM_LIMBS, LIMB_BITS>;
291
292    #[allow(clippy::type_complexity)]
293    fn execute_instruction(
294        &self,
295        instruction: &Instruction<F>,
296        _from_pc: u32,
297        reads: I::Reads,
298    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
299        let Instruction { opcode, .. } = instruction;
300        let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));
301
302        let data: [[F; NUM_LIMBS]; 2] = reads.into();
303        let b = data[0].map(|x| x.as_canonical_u32());
304        let c = data[1].map(|y| y.as_canonical_u32());
305        let (a, limb_shift, bit_shift) = run_shift::<NUM_LIMBS, LIMB_BITS>(shift_opcode, &b, &c);
306
307        let bit_shift_carry = array::from_fn(|i| match shift_opcode {
308            ShiftOpcode::SLL => b[i] >> (LIMB_BITS - bit_shift),
309            _ => b[i] % (1 << bit_shift),
310        });
311
312        let mut b_sign = 0;
313        if shift_opcode == ShiftOpcode::SRA {
314            b_sign = b[NUM_LIMBS - 1] >> (LIMB_BITS - 1);
315            self.bitwise_lookup_chip
316                .request_xor(b[NUM_LIMBS - 1], 1 << (LIMB_BITS - 1));
317        }
318
319        for i in 0..(NUM_LIMBS / 2) {
320            self.bitwise_lookup_chip
321                .request_range(a[i * 2], a[i * 2 + 1]);
322        }
323
324        let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]);
325        let record = ShiftCoreRecord {
326            opcode: shift_opcode,
327            a: a.map(F::from_canonical_u32),
328            b: data[0],
329            c: data[1],
330            bit_shift_carry,
331            bit_shift,
332            limb_shift,
333            b_sign: F::from_canonical_u32(b_sign),
334        };
335
336        Ok((output, record))
337    }
338
339    fn get_opcode_name(&self, opcode: usize) -> String {
340        format!("{:?}", ShiftOpcode::from_usize(opcode - self.air.offset))
341    }
342
343    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
344        for carry_val in record.bit_shift_carry {
345            self.range_checker_chip
346                .add_count(carry_val, record.bit_shift);
347        }
348
349        let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2();
350        self.range_checker_chip.add_count(
351            (((record.c[0].as_canonical_u32() as usize)
352                - record.bit_shift
353                - record.limb_shift * LIMB_BITS)
354                >> num_bits_log) as u32,
355            LIMB_BITS - num_bits_log as usize,
356        );
357
358        let row_slice: &mut ShiftCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut();
359        row_slice.a = record.a;
360        row_slice.b = record.b;
361        row_slice.c = record.c;
362        row_slice.bit_multiplier_left = match record.opcode {
363            ShiftOpcode::SLL => F::from_canonical_usize(1 << record.bit_shift),
364            _ => F::ZERO,
365        };
366        row_slice.bit_multiplier_right = match record.opcode {
367            ShiftOpcode::SLL => F::ZERO,
368            _ => F::from_canonical_usize(1 << record.bit_shift),
369        };
370        row_slice.b_sign = record.b_sign;
371        row_slice.bit_shift_marker = array::from_fn(|i| F::from_bool(i == record.bit_shift));
372        row_slice.limb_shift_marker = array::from_fn(|i| F::from_bool(i == record.limb_shift));
373        row_slice.bit_shift_carry = record.bit_shift_carry.map(F::from_canonical_u32);
374        row_slice.opcode_sll_flag = F::from_bool(record.opcode == ShiftOpcode::SLL);
375        row_slice.opcode_srl_flag = F::from_bool(record.opcode == ShiftOpcode::SRL);
376        row_slice.opcode_sra_flag = F::from_bool(record.opcode == ShiftOpcode::SRA);
377    }
378
379    fn air(&self) -> &Self::Air {
380        &self.air
381    }
382}
383
384pub(super) fn run_shift<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
385    opcode: ShiftOpcode,
386    x: &[u32; NUM_LIMBS],
387    y: &[u32; NUM_LIMBS],
388) -> ([u32; NUM_LIMBS], usize, usize) {
389    match opcode {
390        ShiftOpcode::SLL => run_shift_left::<NUM_LIMBS, LIMB_BITS>(x, y),
391        ShiftOpcode::SRL => run_shift_right::<NUM_LIMBS, LIMB_BITS>(x, y, true),
392        ShiftOpcode::SRA => run_shift_right::<NUM_LIMBS, LIMB_BITS>(x, y, false),
393    }
394}
395
396fn run_shift_left<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
397    x: &[u32; NUM_LIMBS],
398    y: &[u32; NUM_LIMBS],
399) -> ([u32; NUM_LIMBS], usize, usize) {
400    let mut result = [0u32; NUM_LIMBS];
401
402    let (limb_shift, bit_shift) = get_shift::<NUM_LIMBS, LIMB_BITS>(y);
403
404    for i in limb_shift..NUM_LIMBS {
405        result[i] = if i > limb_shift {
406            ((x[i - limb_shift] << bit_shift) + (x[i - limb_shift - 1] >> (LIMB_BITS - bit_shift)))
407                % (1 << LIMB_BITS)
408        } else {
409            (x[i - limb_shift] << bit_shift) % (1 << LIMB_BITS)
410        };
411    }
412    (result, limb_shift, bit_shift)
413}
414
415fn run_shift_right<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
416    x: &[u32; NUM_LIMBS],
417    y: &[u32; NUM_LIMBS],
418    logical: bool,
419) -> ([u32; NUM_LIMBS], usize, usize) {
420    let fill = if logical {
421        0
422    } else {
423        ((1 << LIMB_BITS) - 1) * (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1))
424    };
425    let mut result = [fill; NUM_LIMBS];
426
427    let (limb_shift, bit_shift) = get_shift::<NUM_LIMBS, LIMB_BITS>(y);
428
429    for i in 0..(NUM_LIMBS - limb_shift) {
430        result[i] = if i + limb_shift + 1 < NUM_LIMBS {
431            ((x[i + limb_shift] >> bit_shift) + (x[i + limb_shift + 1] << (LIMB_BITS - bit_shift)))
432                % (1 << LIMB_BITS)
433        } else {
434            ((x[i + limb_shift] >> bit_shift) + (fill << (LIMB_BITS - bit_shift)))
435                % (1 << LIMB_BITS)
436        }
437    }
438    (result, limb_shift, bit_shift)
439}
440
/// Splits the shift amount encoded in the limbs `y` into `(limb_shift, bit_shift)`.
///
/// The effective shift is `y mod (NUM_LIMBS * LIMB_BITS)`. `limb_shift` counts whole
/// limbs and `bit_shift` (`< LIMB_BITS`) counts the remaining bits.
///
/// Only `y[0]` is read. That is exact because `NUM_LIMBS * LIMB_BITS` is required to be a
/// power of two no larger than `2^LIMB_BITS`: it then divides `2^LIMB_BITS`, so the higher
/// limbs of `y` contribute nothing modulo it. The requirement is checked in debug builds.
///
/// # Panics
///
/// Panics if `y` is empty.
fn get_shift<const NUM_LIMBS: usize, const LIMB_BITS: usize>(y: &[u32]) -> (usize, usize) {
    let num_bits = NUM_LIMBS * LIMB_BITS;
    debug_assert!(
        num_bits.is_power_of_two() && num_bits <= 1 << LIMB_BITS,
        "NUM_LIMBS * LIMB_BITS must be a power of two no larger than 2^LIMB_BITS"
    );
    let shift = (y[0] as usize) % num_bits;
    (shift / LIMB_BITS, shift % LIMB_BITS)
}