halo2_base/safe_types/
bytes.rs

1#![allow(clippy::len_without_is_empty)]
2use crate::{
3    gates::GateInstructions,
4    utils::bit_length,
5    AssignedValue, Context,
6    QuantumCell::{Constant, Existing},
7};
8
9use super::{SafeByte, ScalarField};
10
11use getset::Getters;
12use itertools::Itertools;
13
14/// Represents a variable length byte array in circuit.
15///
16/// Each element is guaranteed to be a byte, given by type [`SafeByte`].
17/// To represent a variable length array, we must know the maximum possible length `MAX_LEN` the array could be -- this is some additional context the user must provide.
18/// Then we right pad the array with 0s to the maximum length (we do **not** constrain that these paddings must be 0s).
19#[derive(Debug, Clone, Getters)]
20pub struct VarLenBytes<F: ScalarField, const MAX_LEN: usize> {
21    /// The byte array, right padded
22    #[getset(get = "pub")]
23    bytes: [SafeByte<F>; MAX_LEN],
24    /// Witness representing the actual length of the byte array. Upon construction, this is range checked to be at most `MAX_LEN`
25    #[getset(get = "pub")]
26    len: AssignedValue<F>,
27}
28
29impl<F: ScalarField, const MAX_LEN: usize> VarLenBytes<F, MAX_LEN> {
30    /// Slightly unsafe constructor: it is not constrained that `len <= MAX_LEN`.
31    pub fn new(bytes: [SafeByte<F>; MAX_LEN], len: AssignedValue<F>) -> Self {
32        assert!(
33            len.value().le(&F::from(MAX_LEN as u64)),
34            "Invalid length which exceeds MAX_LEN {MAX_LEN}",
35        );
36        Self { bytes, len }
37    }
38
39    /// Returns the maximum length of the byte array.
40    pub fn max_len(&self) -> usize {
41        MAX_LEN
42    }
43
44    /// Left pads the variable length byte array with 0s to the `MAX_LEN`.
45    /// Takes a fixed length array `self.bytes` and returns a length `MAX_LEN` array equal to
46    /// `[[0; MAX_LEN - len], self.bytes[..len]].concat()`, i.e., we take `self.bytes[..len]` and
47    /// zero pad it on the left, where `len = self.len`
48    ///
49    /// Assumes `0 < self.len <= MAX_LEN`.
50    ///
51    /// ## Panics
52    /// If `self.len` is not in the range `(0, MAX_LEN]`.
53    pub fn left_pad_to_fixed(
54        &self,
55        ctx: &mut Context<F>,
56        gate: &impl GateInstructions<F>,
57    ) -> FixLenBytes<F, MAX_LEN> {
58        let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, MAX_LEN);
59        FixLenBytes::new(
60            padded.into_iter().map(|b| SafeByte(b)).collect::<Vec<_>>().try_into().unwrap(),
61        )
62    }
63
64    /// Return a copy of the byte array with 0 padding ensured.
65    pub fn ensure_0_padding(&self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) -> Self {
66        let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len);
67        Self::new(bytes.try_into().unwrap(), self.len)
68    }
69}
70
71/// Represents a variable length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time.
72///
73/// Each element is guaranteed to be a byte, given by type [`SafeByte`].
74/// To represent a variable length array, we must know the maximum possible length `MAX_LEN` the array could be -- this is provided when constructing and `bytes.len()` == `MAX_LEN` is enforced.
75/// Then we right pad the array with 0s to the maximum length (we do **not** constrain that these paddings must be 0s).
76#[derive(Debug, Clone, Getters)]
77pub struct VarLenBytesVec<F: ScalarField> {
78    /// The byte array, right padded
79    #[getset(get = "pub")]
80    bytes: Vec<SafeByte<F>>,
81    /// Witness representing the actual length of the byte array. Upon construction, this is range checked to be at most `MAX_LEN`
82    #[getset(get = "pub")]
83    len: AssignedValue<F>,
84}
85
86impl<F: ScalarField> VarLenBytesVec<F> {
87    /// Slightly unsafe constructor: it is not constrained that `len <= max_len`.
88    pub fn new(bytes: Vec<SafeByte<F>>, len: AssignedValue<F>, max_len: usize) -> Self {
89        assert!(
90            len.value().le(&F::from(max_len as u64)),
91            "Invalid length which exceeds MAX_LEN {}",
92            max_len
93        );
94        assert_eq!(bytes.len(), max_len, "bytes is not padded correctly");
95        Self { bytes, len }
96    }
97
98    /// Returns the maximum length of the byte array.
99    pub fn max_len(&self) -> usize {
100        self.bytes.len()
101    }
102
103    /// Left pads the variable length byte array with 0s to the MAX_LEN
104    pub fn left_pad_to_fixed(
105        &self,
106        ctx: &mut Context<F>,
107        gate: &impl GateInstructions<F>,
108    ) -> FixLenBytesVec<F> {
109        let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, self.max_len());
110        FixLenBytesVec::new(padded.into_iter().map(|b| SafeByte(b)).collect_vec(), self.max_len())
111    }
112
113    /// Return a copy of the byte array with 0 padding ensured.
114    pub fn ensure_0_padding(&self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) -> Self {
115        let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len);
116        Self::new(bytes, self.len, self.max_len())
117    }
118}
119
120/// Represents a fixed length byte array in circuit.
121#[derive(Debug, Clone, Getters)]
122pub struct FixLenBytes<F: ScalarField, const LEN: usize> {
123    /// The byte array
124    #[getset(get = "pub")]
125    bytes: [SafeByte<F>; LEN],
126}
127
128impl<F: ScalarField, const LEN: usize> FixLenBytes<F, LEN> {
129    /// Constructor
130    pub fn new(bytes: [SafeByte<F>; LEN]) -> Self {
131        Self { bytes }
132    }
133
134    /// Returns the length of the byte array.
135    pub fn len(&self) -> usize {
136        LEN
137    }
138
139    /// Returns inner array of [SafeByte]s.
140    pub fn into_bytes(self) -> [SafeByte<F>; LEN] {
141        self.bytes
142    }
143}
144
145/// Represents a fixed length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time.
146#[derive(Debug, Clone, Getters)]
147pub struct FixLenBytesVec<F: ScalarField> {
148    /// The byte array
149    #[getset(get = "pub")]
150    bytes: Vec<SafeByte<F>>,
151}
152
153impl<F: ScalarField> FixLenBytesVec<F> {
154    /// Constructor
155    pub fn new(bytes: Vec<SafeByte<F>>, len: usize) -> Self {
156        assert_eq!(bytes.len(), len, "bytes length doesn't match");
157        Self { bytes }
158    }
159
160    /// Returns the length of the byte array.
161    pub fn len(&self) -> usize {
162        self.bytes.len()
163    }
164
165    /// Returns inner array of [SafeByte]s.
166    pub fn into_bytes(self) -> Vec<SafeByte<F>> {
167        self.bytes
168    }
169}
170
171// Represents a fixed length byte array in circuit as a vector, where length must be fixed.
172// Not encouraged to use because `LEN` cannot be verified at compile time.
173// pub type FixLenBytesVec<F> = Vec<SafeByte<F>>;
174
175/// Takes a fixed length array `arr` and returns a length `out_len` array equal to
176/// `[[0; out_len - len], arr[..len]].concat()`, i.e., we take `arr[..len]` and
177/// zero pad it on the left.
178///
179/// Assumes `0 < len <= max_len <= out_len`.
180pub fn left_pad_var_array_to_fixed<F: ScalarField>(
181    ctx: &mut Context<F>,
182    gate: &impl GateInstructions<F>,
183    arr: &[impl AsRef<AssignedValue<F>>],
184    len: AssignedValue<F>,
185    out_len: usize,
186) -> Vec<AssignedValue<F>> {
187    debug_assert!(arr.len() <= out_len);
188    debug_assert!(bit_length(out_len as u64) < F::CAPACITY as usize);
189
190    let mut padded = arr.iter().map(|b| *b.as_ref()).collect_vec();
191    padded.resize(out_len, padded[0]);
192    // We use a barrel shifter to shift `arr` to the right by `out_len - len` bits.
193    let shift = gate.sub(ctx, Constant(F::from(out_len as u64)), len);
194    let shift_bits = gate.num_to_bits(ctx, shift, bit_length(out_len as u64));
195    for (i, shift_bit) in shift_bits.into_iter().enumerate() {
196        let shifted = (0..out_len)
197            .map(|j| if j >= (1 << i) { Existing(padded[j - (1 << i)]) } else { Constant(F::ZERO) })
198            .collect_vec();
199        padded = padded
200            .into_iter()
201            .zip(shifted)
202            .map(|(noshift, shift)| gate.select(ctx, shift, noshift, shift_bit))
203            .collect_vec();
204    }
205    padded
206}
207
208fn ensure_0_padding<F: ScalarField>(
209    ctx: &mut Context<F>,
210    gate: &impl GateInstructions<F>,
211    bytes: &[SafeByte<F>],
212    len: AssignedValue<F>,
213) -> Vec<SafeByte<F>> {
214    let max_len = bytes.len();
215    // Generate a mask array where a[i] = i < len for i = 0..max_len.
216    let idx = gate.dec(ctx, len);
217    let len_indicator = gate.idx_to_indicator(ctx, idx, max_len);
218    // inputs_mask[i] = sum(len_indicator[i..])
219    let mut mask = gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec();
220    mask.reverse();
221
222    bytes
223        .iter()
224        .zip(mask.iter())
225        .map(|(byte, mask)| SafeByte(gate.mul(ctx, byte.0, *mask)))
226        .collect_vec()
227}