openvm_native_circuit/castf/core.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::arch::{
4    AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
5    VmCoreAir, VmCoreChip,
6};
7use openvm_circuit_primitives::var_range::{
8    SharedVariableRangeCheckerChip, VariableRangeCheckerBus,
9};
10use openvm_circuit_primitives_derive::AlignedBorrow;
11use openvm_instructions::{instruction::Instruction, LocalOpcode};
12use openvm_native_compiler::CastfOpcode;
13use openvm_rv32im_circuit::adapters::RV32_REGISTER_NUM_LIMBS;
14use openvm_stark_backend::{
15    interaction::InteractionBuilder,
16    p3_air::BaseAir,
17    p3_field::{Field, FieldAlgebra, PrimeField32},
18    rap::BaseAirWithPublicValues,
19};
20use serde::{Deserialize, Serialize};
21
/// Bit width of every output limb except the most significant one.
pub(crate) const LIMB_BITS: usize = 8;
/// Bit width of the most significant output limb.
///
/// With three `LIMB_BITS` limbs below it, the four range-checked limbs can only
/// represent values below `2^(3 * 8 + 6) = 2^30`.
pub(crate) const FINAL_LIMB_BITS: usize = 6;
26
27#[repr(C)]
28#[derive(AlignedBorrow)]
29pub struct CastFCoreCols<T> {
30    pub in_val: T,
31    pub out_val: [T; RV32_REGISTER_NUM_LIMBS],
32    pub is_valid: T,
33}
34
35#[derive(Copy, Clone, Debug)]
36pub struct CastFCoreAir {
37    pub bus: VariableRangeCheckerBus, // to communicate with the range checker that checks that all limbs are < 2^LIMB_BITS
38}
39
40impl<F: Field> BaseAir<F> for CastFCoreAir {
41    fn width(&self) -> usize {
42        CastFCoreCols::<F>::width()
43    }
44}
45
46impl<F: Field> BaseAirWithPublicValues<F> for CastFCoreAir {}
47
48impl<AB, I> VmCoreAir<AB, I> for CastFCoreAir
49where
50    AB: InteractionBuilder,
51    I: VmAdapterInterface<AB::Expr>,
52    I::Reads: From<[[AB::Expr; 1]; 1]>,
53    I::Writes: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
54    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
55{
56    fn eval(
57        &self,
58        builder: &mut AB,
59        local_core: &[AB::Var],
60        _from_pc: AB::Var,
61    ) -> AdapterAirContext<AB::Expr, I> {
62        let cols: &CastFCoreCols<_> = local_core.borrow();
63
64        builder.assert_bool(cols.is_valid);
65
66        let intermed_val = cols
67            .out_val
68            .iter()
69            .enumerate()
70            .fold(AB::Expr::ZERO, |acc, (i, &limb)| {
71                acc + limb * AB::Expr::from_canonical_u32(1 << (i * LIMB_BITS))
72            });
73
74        for i in 0..4 {
75            self.bus
76                .range_check(
77                    cols.out_val[i],
78                    match i {
79                        0..=2 => LIMB_BITS,
80                        3 => FINAL_LIMB_BITS,
81                        _ => unreachable!(),
82                    },
83                )
84                .eval(builder, cols.is_valid);
85        }
86
87        AdapterAirContext {
88            to_pc: None,
89            reads: [[intermed_val]].into(),
90            writes: [cols.out_val.map(Into::into)].into(),
91            instruction: MinimalInstruction {
92                is_valid: cols.is_valid.into(),
93                opcode: AB::Expr::from_canonical_usize(
94                    CastfOpcode::CASTF.global_opcode().as_usize(),
95                ),
96            }
97            .into(),
98        }
99    }
100
101    fn start_offset(&self) -> usize {
102        CastfOpcode::CLASS_OFFSET
103    }
104}
105
106#[repr(C)]
107#[derive(Debug, Serialize, Deserialize)]
108pub struct CastFRecord<F> {
109    pub in_val: F,
110    pub out_val: [u32; RV32_REGISTER_NUM_LIMBS],
111}
112
113pub struct CastFCoreChip {
114    pub air: CastFCoreAir,
115    pub range_checker_chip: SharedVariableRangeCheckerChip,
116}
117
118impl CastFCoreChip {
119    pub fn new(range_checker_chip: SharedVariableRangeCheckerChip) -> Self {
120        Self {
121            air: CastFCoreAir {
122                bus: range_checker_chip.bus(),
123            },
124            range_checker_chip,
125        }
126    }
127}
128
129impl<F: PrimeField32, I: VmAdapterInterface<F>> VmCoreChip<F, I> for CastFCoreChip
130where
131    I::Reads: Into<[[F; 1]; 1]>,
132    I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>,
133{
134    type Record = CastFRecord<F>;
135    type Air = CastFCoreAir;
136
137    #[allow(clippy::type_complexity)]
138    fn execute_instruction(
139        &self,
140        instruction: &Instruction<F>,
141        _from_pc: u32,
142        reads: I::Reads,
143    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
144        let Instruction { opcode, .. } = instruction;
145
146        assert_eq!(
147            opcode.local_opcode_idx(CastfOpcode::CLASS_OFFSET),
148            CastfOpcode::CASTF as usize
149        );
150
151        let y = reads.into()[0][0];
152        let x = CastF::solve(y.as_canonical_u32());
153
154        let output = AdapterRuntimeContext {
155            to_pc: None,
156            writes: [x.map(F::from_canonical_u32)].into(),
157        };
158
159        let record = CastFRecord {
160            in_val: y,
161            out_val: x,
162        };
163
164        Ok((output, record))
165    }
166
167    fn get_opcode_name(&self, _opcode: usize) -> String {
168        format!("{:?}", CastfOpcode::CASTF)
169    }
170
171    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
172        for (i, limb) in record.out_val.iter().enumerate() {
173            if i == 3 {
174                self.range_checker_chip.add_count(*limb, FINAL_LIMB_BITS);
175            } else {
176                self.range_checker_chip.add_count(*limb, LIMB_BITS);
177            }
178        }
179
180        let cols: &mut CastFCoreCols<F> = row_slice.borrow_mut();
181        cols.in_val = record.in_val;
182        cols.out_val = record.out_val.map(F::from_canonical_u32);
183        cols.is_valid = F::ONE;
184    }
185
186    fn air(&self) -> &Self::Air {
187        &self.air
188    }
189}
190
191pub struct CastF;
192impl CastF {
193    pub(super) fn solve(y: u32) -> [u32; RV32_REGISTER_NUM_LIMBS] {
194        let mut x = [0; 4];
195        for (i, limb) in x.iter_mut().enumerate() {
196            *limb = (y >> (8 * i)) & 0xFF;
197        }
198        x
199    }
200}