// openvm_native_circuit/castf/core.rs
1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::arch::{
4 AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
5 VmCoreAir, VmCoreChip,
6};
7use openvm_circuit_primitives::var_range::{
8 SharedVariableRangeCheckerChip, VariableRangeCheckerBus,
9};
10use openvm_circuit_primitives_derive::AlignedBorrow;
11use openvm_instructions::{instruction::Instruction, LocalOpcode};
12use openvm_native_compiler::CastfOpcode;
13use openvm_rv32im_circuit::adapters::RV32_REGISTER_NUM_LIMBS;
14use openvm_stark_backend::{
15 interaction::InteractionBuilder,
16 p3_air::BaseAir,
17 p3_field::{Field, FieldAlgebra, PrimeField32},
18 rap::BaseAirWithPublicValues,
19};
20use serde::{Deserialize, Serialize};
21
/// Bit width of each of the low output limbs (limbs 0..=2) produced by CASTF.
pub(crate) const LIMB_BITS: usize = u8::BITS as usize;
/// Bit width of the most-significant output limb. With three `LIMB_BITS`-wide
/// limbs below it, the decomposed value spans 3 * 8 + 6 = 30 bits.
pub(crate) const FINAL_LIMB_BITS: usize = 6;
26
27#[repr(C)]
28#[derive(AlignedBorrow)]
29pub struct CastFCoreCols<T> {
30 pub in_val: T,
31 pub out_val: [T; RV32_REGISTER_NUM_LIMBS],
32 pub is_valid: T,
33}
34
35#[derive(Copy, Clone, Debug)]
36pub struct CastFCoreAir {
37 pub bus: VariableRangeCheckerBus, }
39
40impl<F: Field> BaseAir<F> for CastFCoreAir {
41 fn width(&self) -> usize {
42 CastFCoreCols::<F>::width()
43 }
44}
45
46impl<F: Field> BaseAirWithPublicValues<F> for CastFCoreAir {}
47
48impl<AB, I> VmCoreAir<AB, I> for CastFCoreAir
49where
50 AB: InteractionBuilder,
51 I: VmAdapterInterface<AB::Expr>,
52 I::Reads: From<[[AB::Expr; 1]; 1]>,
53 I::Writes: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
54 I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
55{
56 fn eval(
57 &self,
58 builder: &mut AB,
59 local_core: &[AB::Var],
60 _from_pc: AB::Var,
61 ) -> AdapterAirContext<AB::Expr, I> {
62 let cols: &CastFCoreCols<_> = local_core.borrow();
63
64 builder.assert_bool(cols.is_valid);
65
66 let intermed_val = cols
67 .out_val
68 .iter()
69 .enumerate()
70 .fold(AB::Expr::ZERO, |acc, (i, &limb)| {
71 acc + limb * AB::Expr::from_canonical_u32(1 << (i * LIMB_BITS))
72 });
73
74 for i in 0..4 {
75 self.bus
76 .range_check(
77 cols.out_val[i],
78 match i {
79 0..=2 => LIMB_BITS,
80 3 => FINAL_LIMB_BITS,
81 _ => unreachable!(),
82 },
83 )
84 .eval(builder, cols.is_valid);
85 }
86
87 AdapterAirContext {
88 to_pc: None,
89 reads: [[intermed_val]].into(),
90 writes: [cols.out_val.map(Into::into)].into(),
91 instruction: MinimalInstruction {
92 is_valid: cols.is_valid.into(),
93 opcode: AB::Expr::from_canonical_usize(
94 CastfOpcode::CASTF.global_opcode().as_usize(),
95 ),
96 }
97 .into(),
98 }
99 }
100
101 fn start_offset(&self) -> usize {
102 CastfOpcode::CLASS_OFFSET
103 }
104}
105
106#[repr(C)]
107#[derive(Debug, Serialize, Deserialize)]
108pub struct CastFRecord<F> {
109 pub in_val: F,
110 pub out_val: [u32; RV32_REGISTER_NUM_LIMBS],
111}
112
113pub struct CastFCoreChip {
114 pub air: CastFCoreAir,
115 pub range_checker_chip: SharedVariableRangeCheckerChip,
116}
117
118impl CastFCoreChip {
119 pub fn new(range_checker_chip: SharedVariableRangeCheckerChip) -> Self {
120 Self {
121 air: CastFCoreAir {
122 bus: range_checker_chip.bus(),
123 },
124 range_checker_chip,
125 }
126 }
127}
128
129impl<F: PrimeField32, I: VmAdapterInterface<F>> VmCoreChip<F, I> for CastFCoreChip
130where
131 I::Reads: Into<[[F; 1]; 1]>,
132 I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>,
133{
134 type Record = CastFRecord<F>;
135 type Air = CastFCoreAir;
136
137 #[allow(clippy::type_complexity)]
138 fn execute_instruction(
139 &self,
140 instruction: &Instruction<F>,
141 _from_pc: u32,
142 reads: I::Reads,
143 ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
144 let Instruction { opcode, .. } = instruction;
145
146 assert_eq!(
147 opcode.local_opcode_idx(CastfOpcode::CLASS_OFFSET),
148 CastfOpcode::CASTF as usize
149 );
150
151 let y = reads.into()[0][0];
152 let x = CastF::solve(y.as_canonical_u32());
153
154 let output = AdapterRuntimeContext {
155 to_pc: None,
156 writes: [x.map(F::from_canonical_u32)].into(),
157 };
158
159 let record = CastFRecord {
160 in_val: y,
161 out_val: x,
162 };
163
164 Ok((output, record))
165 }
166
167 fn get_opcode_name(&self, _opcode: usize) -> String {
168 format!("{:?}", CastfOpcode::CASTF)
169 }
170
171 fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
172 for (i, limb) in record.out_val.iter().enumerate() {
173 if i == 3 {
174 self.range_checker_chip.add_count(*limb, FINAL_LIMB_BITS);
175 } else {
176 self.range_checker_chip.add_count(*limb, LIMB_BITS);
177 }
178 }
179
180 let cols: &mut CastFCoreCols<F> = row_slice.borrow_mut();
181 cols.in_val = record.in_val;
182 cols.out_val = record.out_val.map(F::from_canonical_u32);
183 cols.is_valid = F::ONE;
184 }
185
186 fn air(&self) -> &Self::Air {
187 &self.air
188 }
189}
190
191pub struct CastF;
192impl CastF {
193 pub(super) fn solve(y: u32) -> [u32; RV32_REGISTER_NUM_LIMBS] {
194 let mut x = [0; 4];
195 for (i, limb) in x.iter_mut().enumerate() {
196 *limb = (y >> (8 * i)) & 0xFF;
197 }
198 x
199 }
200}