use std::{
array,
borrow::{Borrow, BorrowMut},
sync::Arc,
};
use openvm_circuit::arch::{
AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
VmCoreAir, VmCoreChip,
};
use openvm_circuit_primitives::{
bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip},
utils::not,
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{instruction::Instruction, UsizeOpcode};
use openvm_rv32im_transpiler::BaseAluOpcode;
use openvm_stark_backend::{
interaction::InteractionBuilder,
p3_air::{AirBuilder, BaseAir},
p3_field::{AbstractField, Field, PrimeField32},
rap::BaseAirWithPublicValues,
};
use strum::IntoEnumIterator;
/// Trace columns of the base ALU core, one row per executed instruction.
///
/// `a`, `b` and `c` are limb arrays of length `NUM_LIMBS`. In the AIR below,
/// `b` and `c` are handed to the adapter as the two reads and `a` as the single
/// write, i.e. `a` is the result and `b`, `c` are the operands.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct BaseAluCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
/// Result limbs (passed to the adapter as the write).
pub a: [T; NUM_LIMBS],
/// First operand limbs (first read).
pub b: [T; NUM_LIMBS],
/// Second operand limbs (second read).
pub c: [T; NUM_LIMBS],
// One selector per opcode. The AIR constrains each selector to be boolean and
// their sum to be boolean, so at most one of them is set in any row.
pub opcode_add_flag: T,
pub opcode_sub_flag: T,
pub opcode_xor_flag: T,
pub opcode_or_flag: T,
pub opcode_and_flag: T,
}
/// AIR description of the base ALU core.
#[derive(Copy, Clone, Debug)]
pub struct BaseAluCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
/// Bus on which the per-limb XOR lookups are sent during `eval`.
pub bus: BitwiseOperationLookupBus,
/// Added to the local opcode index to form the opcode exposed to the adapter.
offset: usize,
}
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
{
/// The core occupies exactly the columns of [`BaseAluCoreCols`].
fn width(&self) -> usize {
BaseAluCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
}
}
// Nothing is overridden here, so the trait's default behaviour applies.
// NOTE(review): presumably the default means "no public values" — confirm
// against the `BaseAirWithPublicValues` definition.
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
{
}
// Constraint logic of the base ALU core (ADD, SUB, XOR, OR, AND on limb arrays).
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
where
AB: InteractionBuilder,
I: VmAdapterInterface<AB::Expr>,
I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
/// Emits the core constraints for one row and returns the context consumed by
/// the adapter.
///
/// `local_core` is reinterpreted as [`BaseAluCoreCols`]. The returned context
/// exposes `b` and `c` as reads, `a` as the write, no `to_pc`, and an
/// instruction made of `is_valid` (sum of the opcode flags) and the expected
/// opcode (selected local opcode plus `self.offset`). `_from_pc` is unused.
fn eval(
&self,
builder: &mut AB,
local_core: &[AB::Var],
_from_pc: AB::Var,
) -> AdapterAirContext<AB::Expr, I> {
let cols: &BaseAluCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
// The order of this array matters: it is zipped with `BaseAluOpcode::iter()`
// further down to build the expected opcode.
// NOTE(review): assumes the enum iterates as ADD, SUB, XOR, OR, AND — confirm.
let flags = [
cols.opcode_add_flag,
cols.opcode_sub_flag,
cols.opcode_xor_flag,
cols.opcode_or_flag,
cols.opcode_and_flag,
];
// Each flag is boolean and so is their sum, hence at most one flag is set.
// `is_valid` is 0 when no flag is set (presumably padding rows) and 1 otherwise.
let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
builder.assert_bool(flag);
acc + flag.into()
});
builder.assert_bool(is_valid.clone());
let a = &cols.a;
let b = &cols.b;
let c = &cols.c;
// ADD: carry_add[i] = (b[i] + c[i] + carry_add[i - 1] - a[i]) / 2^LIMB_BITS.
// SUB: carry_sub[i] = (a[i] + c[i] + carry_sub[i - 1] - b[i]) / 2^LIMB_BITS.
// The carries are expressions derived from the columns, not witness columns.
// Forcing them to be boolean ties `a` to `b + c` (resp. `b - c`) limb by limb
// modulo 2^LIMB_BITS, provided the limbs themselves are in [0, 2^LIMB_BITS).
let mut carry_add: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
let mut carry_sub: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
// Division by 2^LIMB_BITS is done by multiplying with its field inverse.
let carry_divide = AB::F::from_canonical_usize(1 << LIMB_BITS).inverse();
for i in 0..NUM_LIMBS {
// ADD and SUB get separate carry chains, each gated by its own flag. Every
// carry stays linear in the columns.
carry_add[i] = AB::Expr::from(carry_divide)
* (b[i] + c[i] - a[i]
+ if i > 0 {
carry_add[i - 1].clone()
} else {
AB::Expr::ZERO
});
builder
.when(cols.opcode_add_flag)
.assert_bool(carry_add[i].clone());
carry_sub[i] = AB::Expr::from(carry_divide)
* (a[i] + c[i] - b[i]
+ if i > 0 {
carry_sub[i - 1].clone()
} else {
AB::Expr::ZERO
});
builder
.when(cols.opcode_sub_flag)
.assert_bool(carry_sub[i].clone());
}
// One XOR lookup per limb, with multiplicity `is_valid`.
// - `bitwise` is 1 for XOR/OR/AND and 0 otherwise.
// - Non-bitwise rows send (a[i], a[i], 0).
//   NOTE(review): presumably this serves as the range check of a[i] — confirm
//   against the lookup chip.
// - Bitwise rows send (b[i], c[i], z) where z is b[i] ^ c[i] rewritten in terms
//   of the claimed result a[i]:
//     XOR: z = a
//     OR:  z = 2a - b - c   (since b | c = (b + c + (b ^ c)) / 2)
//     AND: z = b + c - 2a   (since b & c = (b + c - (b ^ c)) / 2)
let bitwise = cols.opcode_xor_flag + cols.opcode_or_flag + cols.opcode_and_flag;
for i in 0..NUM_LIMBS {
let x = not::<AB::Expr>(bitwise.clone()) * a[i] + bitwise.clone() * b[i];
let y = not::<AB::Expr>(bitwise.clone()) * a[i] + bitwise.clone() * c[i];
let x_xor_y = cols.opcode_xor_flag * a[i]
+ cols.opcode_or_flag * ((AB::Expr::from_canonical_u32(2) * a[i]) - b[i] - c[i])
+ cols.opcode_and_flag * (b[i] + c[i] - (AB::Expr::from_canonical_u32(2) * a[i]));
self.bus
.send_xor(x, y, x_xor_y)
.eval(builder, is_valid.clone());
}
// Sum of flag * local opcode value selects the active local opcode (0 when no
// flag is set); the chip's offset is then added.
let expected_opcode = flags.iter().zip(BaseAluOpcode::iter()).fold(
AB::Expr::ZERO,
|acc, (flag, local_opcode)| {
acc + (*flag).into() * AB::Expr::from_canonical_u8(local_opcode as u8)
},
) + AB::Expr::from_canonical_usize(self.offset);
AdapterAirContext {
to_pc: None,
reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
writes: [cols.a.map(Into::into)].into(),
instruction: MinimalInstruction {
is_valid,
opcode: expected_opcode,
}
.into(),
}
}
}
/// Execution record of a single ALU instruction: produced by
/// `execute_instruction` and later turned into a trace row by
/// `generate_trace_row`.
#[derive(Clone, Debug)]
pub struct BaseAluCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
/// Local opcode (already stripped of the chip's opcode offset).
pub opcode: BaseAluOpcode,
/// Result limbs.
pub a: [T; NUM_LIMBS],
/// First operand limbs, exactly as read.
pub b: [T; NUM_LIMBS],
/// Second operand limbs, exactly as read.
pub c: [T; NUM_LIMBS],
}
/// Runtime side of the base ALU core: owns the AIR description and a shared
/// handle to the bitwise lookup chip that receives the XOR requests.
#[derive(Debug)]
pub struct BaseAluCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
/// AIR for this chip; also carries the opcode offset.
pub air: BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>,
/// Shared lookup chip; its bus is the one stored in `air`.
pub bitwise_lookup_chip: Arc<BitwiseOperationLookupChip<LIMB_BITS>>,
}
impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAluCoreChip<NUM_LIMBS, LIMB_BITS> {
    /// Builds a chip around `bitwise_lookup_chip`.
    ///
    /// The AIR is wired to the lookup chip's own bus, and `offset` is stored in
    /// the AIR as the opcode offset of this chip.
    pub fn new(
        bitwise_lookup_chip: Arc<BitwiseOperationLookupChip<LIMB_BITS>>,
        offset: usize,
    ) -> Self {
        // Grab the bus first so the `Arc` itself can be moved into the chip.
        let bus = bitwise_lookup_chip.bus();
        let air = BaseAluCoreAir { bus, offset };
        Self {
            air,
            bitwise_lookup_chip,
        }
    }
}
impl<F, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreChip<F, I>
for BaseAluCoreChip<NUM_LIMBS, LIMB_BITS>
where
F: PrimeField32,
I: VmAdapterInterface<F>,
I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
I::Writes: From<[[F; NUM_LIMBS]; 1]>,
{
type Record = BaseAluCoreRecord<F, NUM_LIMBS, LIMB_BITS>;
type Air = BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>;
#[allow(clippy::type_complexity)]
fn execute_instruction(
&self,
instruction: &Instruction<F>,
_from_pc: u32,
reads: I::Reads,
) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
let Instruction { opcode, .. } = instruction;
let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));
let data: [[F; NUM_LIMBS]; 2] = reads.into();
let b = data[0].map(|x| x.as_canonical_u32());
let c = data[1].map(|y| y.as_canonical_u32());
let a = run_alu::<NUM_LIMBS, LIMB_BITS>(local_opcode, &b, &c);
let output = AdapterRuntimeContext {
to_pc: None,
writes: [a.map(F::from_canonical_u32)].into(),
};
if local_opcode == BaseAluOpcode::ADD || local_opcode == BaseAluOpcode::SUB {
for a_val in a {
self.bitwise_lookup_chip.request_xor(a_val, a_val);
}
} else {
for (b_val, c_val) in b.iter().zip(c.iter()) {
self.bitwise_lookup_chip.request_xor(*b_val, *c_val);
}
}
let record = Self::Record {
opcode: local_opcode,
a: a.map(F::from_canonical_u32),
b: data[0],
c: data[1],
};
Ok((output, record))
}
fn get_opcode_name(&self, opcode: usize) -> String {
format!("{:?}", BaseAluOpcode::from_usize(opcode - self.air.offset))
}
fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
let row_slice: &mut BaseAluCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut();
row_slice.a = record.a;
row_slice.b = record.b;
row_slice.c = record.c;
row_slice.opcode_add_flag = F::from_bool(record.opcode == BaseAluOpcode::ADD);
row_slice.opcode_sub_flag = F::from_bool(record.opcode == BaseAluOpcode::SUB);
row_slice.opcode_xor_flag = F::from_bool(record.opcode == BaseAluOpcode::XOR);
row_slice.opcode_or_flag = F::from_bool(record.opcode == BaseAluOpcode::OR);
row_slice.opcode_and_flag = F::from_bool(record.opcode == BaseAluOpcode::AND);
}
fn air(&self) -> &Self::Air {
&self.air
}
}
pub(super) fn run_alu<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
opcode: BaseAluOpcode,
x: &[u32; NUM_LIMBS],
y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
match opcode {
BaseAluOpcode::ADD => run_add::<NUM_LIMBS, LIMB_BITS>(x, y),
BaseAluOpcode::SUB => run_subtract::<NUM_LIMBS, LIMB_BITS>(x, y),
BaseAluOpcode::XOR => run_xor::<NUM_LIMBS, LIMB_BITS>(x, y),
BaseAluOpcode::OR => run_or::<NUM_LIMBS, LIMB_BITS>(x, y),
BaseAluOpcode::AND => run_and::<NUM_LIMBS, LIMB_BITS>(x, y),
}
}
/// Limb-wise addition of `x` and `y`.
///
/// Limbs are processed from index 0 upwards; the carry out of each limb is fed
/// into the next one and every result limb is reduced to `LIMB_BITS` bits. The
/// carry out of the last limb is dropped.
fn run_add<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let mut sum = [0u32; NUM_LIMBS];
    let mut carry_in = 0u32;
    for (slot, (&lhs, &rhs)) in sum.iter_mut().zip(x.iter().zip(y.iter())) {
        let total = lhs + rhs + carry_in;
        carry_in = total >> LIMB_BITS;
        *slot = total & ((1u32 << LIMB_BITS) - 1);
    }
    sum
}
/// Limb-wise subtraction `x - y`.
///
/// Limbs are processed from index 0 upwards. Whenever a limb of `x` is smaller
/// than the amount to subtract (the limb of `y` plus the incoming borrow),
/// `2^LIMB_BITS` is added and a borrow of one is passed to the next limb. The
/// borrow out of the last limb is dropped.
fn run_subtract<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let mut diff = [0u32; NUM_LIMBS];
    let mut borrow_in = 0u32;
    for (slot, (&minuend, &subtrahend)) in diff.iter_mut().zip(x.iter().zip(y.iter())) {
        let to_remove = subtrahend + borrow_in;
        let needs_borrow = minuend < to_remove;
        *slot = if needs_borrow {
            minuend + (1u32 << LIMB_BITS) - to_remove
        } else {
            minuend - to_remove
        };
        borrow_in = u32::from(needs_borrow);
    }
    diff
}
/// Limb-wise bitwise XOR of `x` and `y`. `LIMB_BITS` is not used.
fn run_xor<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let mut result = *x;
    for (limb, &other) in result.iter_mut().zip(y) {
        *limb ^= other;
    }
    result
}
/// Limb-wise bitwise OR of `x` and `y`. `LIMB_BITS` is not used.
fn run_or<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let mut result = *x;
    for (limb, &other) in result.iter_mut().zip(y) {
        *limb |= other;
    }
    result
}
/// Limb-wise bitwise AND of `x` and `y`. `LIMB_BITS` is not used.
fn run_and<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let mut result = *x;
    for (limb, &other) in result.iter_mut().zip(y) {
        *limb &= other;
    }
    result
}