#![allow(clippy::len_without_is_empty)]
use crate::{
gates::GateInstructions,
utils::bit_length,
AssignedValue, Context,
QuantumCell::{Constant, Existing},
};
use super::{SafeByte, SafeType, ScalarField};
use getset::Getters;
use itertools::Itertools;
/// A byte array of variable logical length, witnessed in-circuit.
///
/// `bytes` always holds exactly `MAX_LEN` [SafeByte]s; only the first `len` of them are
/// the logical content, the remaining slots are padding (this struct itself places no
/// constraint on the padding — see `ensure_0_padding` / `left_pad_to_fixed`).
#[derive(Debug, Clone, Getters)]
pub struct VarLenBytes<F: ScalarField, const MAX_LEN: usize> {
    /// The byte array, fixed at `MAX_LEN` slots regardless of the logical length.
    #[getset(get = "pub")]
    bytes: [SafeByte<F>; MAX_LEN],
    /// Witness for the logical length; checked off-circuit in `new` to be at most `MAX_LEN`.
    #[getset(get = "pub")]
    len: AssignedValue<F>,
}
impl<F: ScalarField, const MAX_LEN: usize> VarLenBytes<F, MAX_LEN> {
pub fn new(bytes: [SafeByte<F>; MAX_LEN], len: AssignedValue<F>) -> Self {
assert!(
len.value().le(&F::from(MAX_LEN as u64)),
"Invalid length which exceeds MAX_LEN {MAX_LEN}",
);
Self { bytes, len }
}
pub fn max_len(&self) -> usize {
MAX_LEN
}
pub fn left_pad_to_fixed(
&self,
ctx: &mut Context<F>,
gate: &impl GateInstructions<F>,
) -> FixLenBytes<F, MAX_LEN> {
let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, MAX_LEN);
FixLenBytes::new(
padded.into_iter().map(|b| SafeByte(b)).collect::<Vec<_>>().try_into().unwrap(),
)
}
pub fn ensure_0_padding(&self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) -> Self {
let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len);
Self::new(bytes.try_into().unwrap(), self.len)
}
}
/// A byte vector of variable logical length, witnessed in-circuit.
///
/// Like [VarLenBytes] but with the maximum length fixed at runtime rather than at
/// compile time: `bytes` is always padded to the `max_len` passed to `new`, and only
/// the first `len` elements are the logical content.
#[derive(Debug, Clone, Getters)]
pub struct VarLenBytesVec<F: ScalarField> {
    /// The byte vector, padded to `max_len` elements (enforced by `new`).
    #[getset(get = "pub")]
    bytes: Vec<SafeByte<F>>,
    /// Witness for the logical length; checked off-circuit in `new` to be at most `max_len`.
    #[getset(get = "pub")]
    len: AssignedValue<F>,
}
impl<F: ScalarField> VarLenBytesVec<F> {
pub fn new(bytes: Vec<SafeByte<F>>, len: AssignedValue<F>, max_len: usize) -> Self {
assert!(
len.value().le(&F::from(max_len as u64)),
"Invalid length which exceeds MAX_LEN {}",
max_len
);
assert_eq!(bytes.len(), max_len, "bytes is not padded correctly");
Self { bytes, len }
}
pub fn max_len(&self) -> usize {
self.bytes.len()
}
pub fn left_pad_to_fixed(
&self,
ctx: &mut Context<F>,
gate: &impl GateInstructions<F>,
) -> FixLenBytesVec<F> {
let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, self.max_len());
FixLenBytesVec::new(padded.into_iter().map(|b| SafeByte(b)).collect_vec(), self.max_len())
}
pub fn ensure_0_padding(&self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) -> Self {
let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len);
Self::new(bytes, self.len, self.max_len())
}
}
/// A byte array whose length is fixed at compile time to `LEN`.
#[derive(Debug, Clone, Getters)]
pub struct FixLenBytes<F: ScalarField, const LEN: usize> {
    /// The fixed-length byte array.
    #[getset(get = "pub")]
    bytes: [SafeByte<F>; LEN],
}
impl<F: ScalarField, const LEN: usize> FixLenBytes<F, LEN> {
    /// Wraps a fixed-size array of already-constrained bytes.
    pub fn new(bytes: [SafeByte<F>; LEN]) -> Self {
        Self { bytes }
    }

    /// Returns the length of the byte array (the `LEN` const generic).
    pub fn len(&self) -> usize {
        LEN
    }

    /// Consumes `self`, returning the inner byte array.
    pub fn into_bytes(self) -> [SafeByte<F>; LEN] {
        self.bytes
    }
}
/// A byte vector whose length is fixed at runtime when constructed.
#[derive(Debug, Clone, Getters)]
pub struct FixLenBytesVec<F: ScalarField> {
    /// The fixed-length byte vector (length checked against the declared `len` in `new`).
    #[getset(get = "pub")]
    bytes: Vec<SafeByte<F>>,
}
impl<F: ScalarField> FixLenBytesVec<F> {
    /// Wraps a vector of already-constrained bytes.
    ///
    /// Panics if `bytes.len() != len`.
    pub fn new(bytes: Vec<SafeByte<F>>, len: usize) -> Self {
        assert_eq!(bytes.len(), len, "bytes length doesn't match");
        Self { bytes }
    }

    /// Returns the length of the byte vector.
    pub fn len(&self) -> usize {
        self.bytes.len()
    }

    /// Consumes `self`, returning the inner byte vector.
    pub fn into_bytes(self) -> Vec<SafeByte<F>> {
        self.bytes
    }
}
impl<F: ScalarField, const TOTAL_BITS: usize> From<SafeType<F, 1, TOTAL_BITS>>
for FixLenBytes<F, { SafeType::<F, 1, TOTAL_BITS>::VALUE_LENGTH }>
{
fn from(bytes: SafeType<F, 1, TOTAL_BITS>) -> Self {
let bytes = bytes.value.into_iter().map(|b| SafeByte(b)).collect::<Vec<_>>();
Self::new(bytes.try_into().unwrap())
}
}
impl<F: ScalarField, const TOTAL_BITS: usize>
    From<FixLenBytes<F, { SafeType::<F, 1, TOTAL_BITS>::VALUE_LENGTH }>>
    for SafeType<F, 1, TOTAL_BITS>
{
    /// Converts a [FixLenBytes] back into the equivalent byte-granular [SafeType].
    fn from(fixed: FixLenBytes<F, { SafeType::<F, 1, TOTAL_BITS>::VALUE_LENGTH }>) -> Self {
        let unwrapped = fixed.bytes.into_iter().map(|byte| byte.0).collect_vec();
        Self::new(unwrapped)
    }
}
/// Left-pads a variable-length array with zeroes into a fixed array of `out_len` elements,
/// in-circuit, using a barrel shifter of depth `bit_length(out_len)`.
///
/// The first `len` (witness) elements of `arr` are the logical content; the result is that
/// content shifted right by `out_len - len` positions with constant-0 cells filling the
/// vacated leading positions.
///
/// NOTE(review): `padded[0]` below panics if `arr` is empty — callers appear to always pass
/// a non-empty array; confirm before relying on the empty case.
pub fn left_pad_var_array_to_fixed<F: ScalarField>(
    ctx: &mut Context<F>,
    gate: &impl GateInstructions<F>,
    arr: &[impl AsRef<AssignedValue<F>>],
    len: AssignedValue<F>,
    out_len: usize,
) -> Vec<AssignedValue<F>> {
    debug_assert!(arr.len() <= out_len);
    // `shift` is decomposed into bit_length(out_len) bits below; require that to fit within
    // the field's capacity so the bit decomposition is well-defined.
    debug_assert!(bit_length(out_len as u64) < F::CAPACITY as usize);
    // Copy the input and extend it to `out_len` slots. The fill value is an arbitrary
    // placeholder: every padded position is either overwritten by a shifted element or
    // replaced by Constant(F::ZERO) in the shifter loop.
    let mut padded = arr.iter().map(|b| *b.as_ref()).collect_vec();
    padded.resize(out_len, padded[0]);
    // Number of positions to shift right: out_len - len.
    let shift = gate.sub(ctx, Constant(F::from(out_len as u64)), len);
    let shift_bits = gate.num_to_bits(ctx, shift, bit_length(out_len as u64));
    // Barrel shifter: for bit i of `shift`, conditionally shift right by 2^i.
    for (i, shift_bit) in shift_bits.into_iter().enumerate() {
        // Candidate array shifted right by 2^i; the 2^i vacated leading cells become 0.
        let shifted = (0..out_len)
            .map(|j| if j >= (1 << i) { Existing(padded[j - (1 << i)]) } else { Constant(F::ZERO) })
            .collect_vec();
        // Pick the shifted candidate iff this bit of `shift` is set.
        padded = padded
            .into_iter()
            .zip(shifted)
            .map(|(noshift, shift)| gate.select(ctx, shift, noshift, shift_bit))
            .collect_vec();
    }
    padded
}
/// Constrains every byte of `bytes` at index `>= len` to be 0, returning the masked bytes.
///
/// Builds a boolean mask with `mask[i] = 1` iff `i <= len - 1`, then multiplies each byte
/// by its mask bit: bytes in the valid prefix pass through unchanged and the padding is
/// forced to 0.
fn ensure_0_padding<F: ScalarField>(
    ctx: &mut Context<F>,
    gate: &impl GateInstructions<F>,
    bytes: &[SafeByte<F>],
    len: AssignedValue<F>,
) -> Vec<SafeByte<F>> {
    let max_len = bytes.len();
    // One-hot indicator of the last valid index, `len - 1`.
    // NOTE(review): when `len == 0`, `dec` wraps to -1 in the field, the indicator is
    // all-zero, so the mask is all-zero and every byte is zeroed — presumably intended,
    // but confirm against callers.
    let idx = gate.dec(ctx, len);
    let len_indicator = gate.idx_to_indicator(ctx, idx, max_len);
    // Suffix sums of the indicator (partial sums over the reversed indicator, reversed
    // back): mask[i] = sum_{j >= i} indicator[j], i.e. 1 exactly when i <= len - 1.
    let mut mask = gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec();
    mask.reverse();
    bytes
        .iter()
        .zip(mask.iter())
        .map(|(byte, mask)| SafeByte(gate.mul(ctx, byte.0, *mask)))
        .collect_vec()
}