1#![allow(clippy::len_without_is_empty)]
2use crate::{
3 gates::GateInstructions,
4 utils::bit_length,
5 AssignedValue, Context,
6 QuantumCell::{Constant, Existing},
7};
8
9use super::{SafeByte, ScalarField};
10
11use getset::Getters;
12use itertools::Itertools;
13
/// A fixed-capacity byte array of `MAX_LEN` [SafeByte]s whose meaningful prefix
/// length is the assigned (witness) value `len`. `new` asserts `len <= MAX_LEN`.
#[derive(Debug, Clone, Getters)]
pub struct VarLenBytes<F: ScalarField, const MAX_LEN: usize> {
    /// The bytes, always padded out to exactly `MAX_LEN` entries; entries at
    /// index `>= len` are padding.
    #[getset(get = "pub")]
    bytes: [SafeByte<F>; MAX_LEN],
    /// Assigned value holding the actual (variable) length of the array.
    #[getset(get = "pub")]
    len: AssignedValue<F>,
}
28
29impl<F: ScalarField, const MAX_LEN: usize> VarLenBytes<F, MAX_LEN> {
30 pub fn new(bytes: [SafeByte<F>; MAX_LEN], len: AssignedValue<F>) -> Self {
32 assert!(
33 len.value().le(&F::from(MAX_LEN as u64)),
34 "Invalid length which exceeds MAX_LEN {MAX_LEN}",
35 );
36 Self { bytes, len }
37 }
38
39 pub fn max_len(&self) -> usize {
41 MAX_LEN
42 }
43
44 pub fn left_pad_to_fixed(
54 &self,
55 ctx: &mut Context<F>,
56 gate: &impl GateInstructions<F>,
57 ) -> FixLenBytes<F, MAX_LEN> {
58 let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, MAX_LEN);
59 FixLenBytes::new(
60 padded.into_iter().map(|b| SafeByte(b)).collect::<Vec<_>>().try_into().unwrap(),
61 )
62 }
63
64 pub fn ensure_0_padding(&self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) -> Self {
66 let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len);
67 Self::new(bytes.try_into().unwrap(), self.len)
68 }
69}
70
/// Heap-allocated analogue of `VarLenBytes`: a `Vec` of [SafeByte]s padded to a
/// runtime `max_len`, whose meaningful prefix length is the assigned value `len`.
#[derive(Debug, Clone, Getters)]
pub struct VarLenBytesVec<F: ScalarField> {
    /// The bytes; `new` asserts the vector is padded to exactly `max_len` entries.
    #[getset(get = "pub")]
    bytes: Vec<SafeByte<F>>,
    /// Assigned value holding the actual (variable) length of the array.
    #[getset(get = "pub")]
    len: AssignedValue<F>,
}
85
86impl<F: ScalarField> VarLenBytesVec<F> {
87 pub fn new(bytes: Vec<SafeByte<F>>, len: AssignedValue<F>, max_len: usize) -> Self {
89 assert!(
90 len.value().le(&F::from(max_len as u64)),
91 "Invalid length which exceeds MAX_LEN {}",
92 max_len
93 );
94 assert_eq!(bytes.len(), max_len, "bytes is not padded correctly");
95 Self { bytes, len }
96 }
97
98 pub fn max_len(&self) -> usize {
100 self.bytes.len()
101 }
102
103 pub fn left_pad_to_fixed(
105 &self,
106 ctx: &mut Context<F>,
107 gate: &impl GateInstructions<F>,
108 ) -> FixLenBytesVec<F> {
109 let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, self.max_len());
110 FixLenBytesVec::new(padded.into_iter().map(|b| SafeByte(b)).collect_vec(), self.max_len())
111 }
112
113 pub fn ensure_0_padding(&self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) -> Self {
115 let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len);
116 Self::new(bytes, self.len, self.max_len())
117 }
118}
119
/// A byte array whose length is fixed at compile time to the const generic `LEN`.
#[derive(Debug, Clone, Getters)]
pub struct FixLenBytes<F: ScalarField, const LEN: usize> {
    /// Exactly `LEN` [SafeByte]s.
    #[getset(get = "pub")]
    bytes: [SafeByte<F>; LEN],
}
127
128impl<F: ScalarField, const LEN: usize> FixLenBytes<F, LEN> {
129 pub fn new(bytes: [SafeByte<F>; LEN]) -> Self {
131 Self { bytes }
132 }
133
134 pub fn len(&self) -> usize {
136 LEN
137 }
138
139 pub fn into_bytes(self) -> [SafeByte<F>; LEN] {
141 self.bytes
142 }
143}
144
/// Heap-allocated analogue of `FixLenBytes`: a byte vector whose length is fixed
/// (checked against the expected length at construction time).
#[derive(Debug, Clone, Getters)]
pub struct FixLenBytesVec<F: ScalarField> {
    /// The bytes; `new` asserts the vector has exactly the expected length.
    #[getset(get = "pub")]
    bytes: Vec<SafeByte<F>>,
}
152
153impl<F: ScalarField> FixLenBytesVec<F> {
154 pub fn new(bytes: Vec<SafeByte<F>>, len: usize) -> Self {
156 assert_eq!(bytes.len(), len, "bytes length doesn't match");
157 Self { bytes }
158 }
159
160 pub fn len(&self) -> usize {
162 self.bytes.len()
163 }
164
165 pub fn into_bytes(self) -> Vec<SafeByte<F>> {
167 self.bytes
168 }
169}
170
171pub fn left_pad_var_array_to_fixed<F: ScalarField>(
181 ctx: &mut Context<F>,
182 gate: &impl GateInstructions<F>,
183 arr: &[impl AsRef<AssignedValue<F>>],
184 len: AssignedValue<F>,
185 out_len: usize,
186) -> Vec<AssignedValue<F>> {
187 debug_assert!(arr.len() <= out_len);
188 debug_assert!(bit_length(out_len as u64) < F::CAPACITY as usize);
189
190 let mut padded = arr.iter().map(|b| *b.as_ref()).collect_vec();
191 padded.resize(out_len, padded[0]);
192 let shift = gate.sub(ctx, Constant(F::from(out_len as u64)), len);
194 let shift_bits = gate.num_to_bits(ctx, shift, bit_length(out_len as u64));
195 for (i, shift_bit) in shift_bits.into_iter().enumerate() {
196 let shifted = (0..out_len)
197 .map(|j| if j >= (1 << i) { Existing(padded[j - (1 << i)]) } else { Constant(F::ZERO) })
198 .collect_vec();
199 padded = padded
200 .into_iter()
201 .zip(shifted)
202 .map(|(noshift, shift)| gate.select(ctx, shift, noshift, shift_bit))
203 .collect_vec();
204 }
205 padded
206}
207
208fn ensure_0_padding<F: ScalarField>(
209 ctx: &mut Context<F>,
210 gate: &impl GateInstructions<F>,
211 bytes: &[SafeByte<F>],
212 len: AssignedValue<F>,
213) -> Vec<SafeByte<F>> {
214 let max_len = bytes.len();
215 let idx = gate.dec(ctx, len);
217 let len_indicator = gate.idx_to_indicator(ctx, idx, max_len);
218 let mut mask = gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec();
220 mask.reverse();
221
222 bytes
223 .iter()
224 .zip(mask.iter())
225 .map(|(byte, mask)| SafeByte(gate.mul(ctx, byte.0, *mask)))
226 .collect_vec()
227}