halo2_base/safe_types/
mod.rs

1use std::{
2    borrow::Borrow,
3    cmp::{max, min},
4};
5
6use crate::{
7    gates::{
8        flex_gate::GateInstructions,
9        range::{RangeChip, RangeInstructions},
10    },
11    utils::ScalarField,
12    AssignedValue, Context,
13    QuantumCell::Witness,
14};
15
16use itertools::Itertools;
17
18mod bytes;
19mod primitives;
20
21pub use bytes::*;
22pub use primitives::*;
23
24/// Unit Tests
25#[cfg(test)]
26pub mod tests;
27
28type RawAssignedValues<F> = Vec<AssignedValue<F>>;
29
30const BITS_PER_BYTE: usize = 8;
31
32/// [`SafeType`]'s goal is to avoid out-of-range undefined behavior.
33/// When building circuits, it's common to use multiple [`AssignedValue<F>`]s to represent
34/// a logical variable. For example, we might want to represent a hash with 32 [`AssignedValue<F>`]
35/// where each [`AssignedValue`] represents 1 byte. However, the range of [`AssignedValue<F>`] is much
36/// larger than 1 byte(0~255). If a circuit takes 32 [`AssignedValue<F>`] as inputs and some of them
37/// are actually greater than 255, there could be some undefined behaviors.
38/// [`SafeType`] gurantees the value range of its owned [`AssignedValue<F>`]. So circuits don't need to
39/// do any extra value checking if they take SafeType as inputs.
40/// - `TOTAL_BITS` is the number of total bits of this type.
41/// - `BYTES_PER_ELE` is the number of bytes of each element.
42#[derive(Clone, Debug)]
43pub struct SafeType<F: ScalarField, const BYTES_PER_ELE: usize, const TOTAL_BITS: usize> {
44    // value is stored in little-endian.
45    value: RawAssignedValues<F>,
46}
47
48impl<F: ScalarField, const BYTES_PER_ELE: usize, const TOTAL_BITS: usize>
49    SafeType<F, BYTES_PER_ELE, TOTAL_BITS>
50{
51    /// Number of bytes of each element.
52    pub const BYTES_PER_ELE: usize = BYTES_PER_ELE;
53    /// Total bits of this type.
54    pub const TOTAL_BITS: usize = TOTAL_BITS;
55    /// Number of elements of this type.
56    pub const VALUE_LENGTH: usize = TOTAL_BITS.div_ceil(BYTES_PER_ELE * BITS_PER_BYTE);
57
58    /// Number of bits of each element.
59    pub fn bits_per_ele() -> usize {
60        min(TOTAL_BITS, BYTES_PER_ELE * BITS_PER_BYTE)
61    }
62
63    // new is private so Safetype can only be constructed by this crate.
64    fn new(raw_values: RawAssignedValues<F>) -> Self {
65        assert!(raw_values.len() == Self::VALUE_LENGTH, "Invalid raw values length");
66        Self { value: raw_values }
67    }
68
69    /// Return values in little-endian.
70    pub fn value(&self) -> &[AssignedValue<F>] {
71        &self.value
72    }
73}
74
75impl<F: ScalarField, const BYTES_PER_ELE: usize, const TOTAL_BITS: usize> AsRef<[AssignedValue<F>]>
76    for SafeType<F, BYTES_PER_ELE, TOTAL_BITS>
77{
78    fn as_ref(&self) -> &[AssignedValue<F>] {
79        self.value()
80    }
81}
82
83impl<F: ScalarField, const TOTAL_BITS: usize> TryFrom<Vec<SafeByte<F>>>
84    for SafeType<F, 1, TOTAL_BITS>
85{
86    type Error = String;
87
88    fn try_from(value: Vec<SafeByte<F>>) -> Result<Self, Self::Error> {
89        if value.len() * 8 != TOTAL_BITS {
90            return Err("Invalid length".to_owned());
91        }
92        Ok(Self::new(value.into_iter().map(|b| b.0).collect::<Vec<_>>()))
93    }
94}
95
96/// SafeType for Address.
97pub type SafeAddress<F> = SafeType<F, 1, 160>;
98/// SafeType for bytes32.
99pub type SafeBytes32<F> = SafeType<F, 1, 256>;
100
101/// Chip for SafeType
102pub struct SafeTypeChip<'a, F: ScalarField> {
103    range_chip: &'a RangeChip<F>,
104}
105
106impl<'a, F: ScalarField> SafeTypeChip<'a, F> {
107    /// Construct a SafeTypeChip.
108    pub fn new(range_chip: &'a RangeChip<F>) -> Self {
109        Self { range_chip }
110    }
111
112    /// Convert a vector of AssignedValue (treated as little-endian) to a SafeType.
113    /// The number of bytes of inputs must equal to the number of bytes of outputs.
114    /// This function also add contraints that a AssignedValue in inputs must be in the range of a byte.
115    pub fn raw_bytes_to<const BYTES_PER_ELE: usize, const TOTAL_BITS: usize>(
116        &self,
117        ctx: &mut Context<F>,
118        inputs: RawAssignedValues<F>,
119    ) -> SafeType<F, BYTES_PER_ELE, TOTAL_BITS> {
120        let element_bits = SafeType::<F, BYTES_PER_ELE, TOTAL_BITS>::bits_per_ele();
121        let bits = TOTAL_BITS;
122        assert!(
123            inputs.len() * BITS_PER_BYTE == max(bits, BITS_PER_BYTE),
124            "number of bits doesn't match"
125        );
126        self.add_bytes_constraints(ctx, &inputs, bits);
127        // inputs is a bool or uint8.
128        if bits == 1 || element_bits == BITS_PER_BYTE {
129            return SafeType::<F, BYTES_PER_ELE, TOTAL_BITS>::new(inputs);
130        };
131
132        let byte_base = (0..BYTES_PER_ELE)
133            .map(|i| Witness(self.range_chip.gate.pow_of_two[i * BITS_PER_BYTE]))
134            .collect::<Vec<_>>();
135        let value = inputs
136            .chunks(BYTES_PER_ELE)
137            .map(|chunk| {
138                self.range_chip.gate.inner_product(
139                    ctx,
140                    chunk.to_vec(),
141                    byte_base[..chunk.len()].to_vec(),
142                )
143            })
144            .collect::<Vec<_>>();
145        SafeType::<F, BYTES_PER_ELE, TOTAL_BITS>::new(value)
146    }
147
148    /// Unsafe method that directly converts `input` to [`SafeType`] **without any checks**.
149    /// This should **only** be used if an external library needs to convert their types to [`SafeType`].
150    pub fn unsafe_to_safe_type<const BYTES_PER_ELE: usize, const TOTAL_BITS: usize>(
151        inputs: RawAssignedValues<F>,
152    ) -> SafeType<F, BYTES_PER_ELE, TOTAL_BITS> {
153        assert_eq!(inputs.len(), SafeType::<F, BYTES_PER_ELE, TOTAL_BITS>::VALUE_LENGTH);
154        SafeType::<F, BYTES_PER_ELE, TOTAL_BITS>::new(inputs)
155    }
156
157    /// Constrains that the `input` is a boolean value (either 0 or 1) and wraps it in [`SafeBool`].
158    pub fn assert_bool(&self, ctx: &mut Context<F>, input: AssignedValue<F>) -> SafeBool<F> {
159        self.range_chip.gate().assert_bit(ctx, input);
160        SafeBool(input)
161    }
162
163    /// Load a boolean value as witness and constrain it is either 0 or 1.
164    pub fn load_bool(&self, ctx: &mut Context<F>, input: bool) -> SafeBool<F> {
165        let input = ctx.load_witness(F::from(input));
166        self.assert_bool(ctx, input)
167    }
168
169    /// Unsafe method that directly converts `input` to [`SafeBool`] **without any checks**.
170    /// This should **only** be used if an external library needs to convert their types to [`SafeBool`].
171    pub fn unsafe_to_bool(input: AssignedValue<F>) -> SafeBool<F> {
172        SafeBool(input)
173    }
174
175    /// Constrains that the `input` is a byte value and wraps it in [`SafeByte`].
176    pub fn assert_byte(&self, ctx: &mut Context<F>, input: AssignedValue<F>) -> SafeByte<F> {
177        self.range_chip.range_check(ctx, input, BITS_PER_BYTE);
178        SafeByte(input)
179    }
180
181    /// Load a boolean value as witness and constrain it is either 0 or 1.
182    pub fn load_byte(&self, ctx: &mut Context<F>, input: u8) -> SafeByte<F> {
183        let input = ctx.load_witness(F::from(input as u64));
184        self.assert_byte(ctx, input)
185    }
186
187    /// Unsafe method that directly converts `input` to [`SafeByte`] **without any checks**.
188    /// This should **only** be used if an external library needs to convert their types to [`SafeByte`].
189    pub fn unsafe_to_byte(input: AssignedValue<F>) -> SafeByte<F> {
190        SafeByte(input)
191    }
192
193    /// Unsafe method that directly converts `inputs` to [`VarLenBytes`] **without any checks**.
194    /// This should **only** be used if an external library needs to convert their types to [`SafeByte`].
195    pub fn unsafe_to_var_len_bytes<const MAX_LEN: usize>(
196        inputs: [AssignedValue<F>; MAX_LEN],
197        len: AssignedValue<F>,
198    ) -> VarLenBytes<F, MAX_LEN> {
199        VarLenBytes::<F, MAX_LEN>::new(inputs.map(|input| Self::unsafe_to_byte(input)), len)
200    }
201
202    /// Unsafe method that directly converts `inputs` to [`VarLenBytesVec`] **without any checks**.
203    /// This should **only** be used if an external library needs to convert their types to [`SafeByte`].
204    pub fn unsafe_to_var_len_bytes_vec(
205        inputs: RawAssignedValues<F>,
206        len: AssignedValue<F>,
207        max_len: usize,
208    ) -> VarLenBytesVec<F> {
209        VarLenBytesVec::<F>::new(
210            inputs.iter().map(|input| Self::unsafe_to_byte(*input)).collect_vec(),
211            len,
212            max_len,
213        )
214    }
215
216    /// Unsafe method that directly converts `inputs` to [`FixLenBytes`] **without any checks**.
217    /// This should **only** be used if an external library needs to convert their types to [`SafeByte`].
218    pub fn unsafe_to_fix_len_bytes<const MAX_LEN: usize>(
219        inputs: [AssignedValue<F>; MAX_LEN],
220    ) -> FixLenBytes<F, MAX_LEN> {
221        FixLenBytes::<F, MAX_LEN>::new(inputs.map(|input| Self::unsafe_to_byte(input)))
222    }
223
224    /// Unsafe method that directly converts `inputs` to [`FixLenBytesVec`] **without any checks**.
225    /// This should **only** be used if an external library needs to convert their types to [`SafeByte`].
226    pub fn unsafe_to_fix_len_bytes_vec(
227        inputs: RawAssignedValues<F>,
228        len: usize,
229    ) -> FixLenBytesVec<F> {
230        FixLenBytesVec::<F>::new(
231            inputs.into_iter().map(|input| Self::unsafe_to_byte(input)).collect_vec(),
232            len,
233        )
234    }
235
236    /// Converts a slice of AssignedValue(treated as little-endian) to VarLenBytes.
237    ///
238    /// * inputs: Slice representing the byte array.
239    /// * len: [`AssignedValue<F>`] witness representing the variable length of the byte array. Constrained to be `<= MAX_LEN`.
240    /// * MAX_LEN: [usize] representing the maximum length of the byte array and the number of elements it must contain.
241    ///
242    /// ## Assumptions
243    /// * `MAX_LEN < u64::MAX` to prevent overflow (but you should never make an array this large)
244    /// * `ceil((MAX_LEN + 1).bits() / lookup_bits) * lookup_bits <= F::CAPACITY` where `lookup_bits = self.range_chip.lookup_bits`
245    pub fn raw_to_var_len_bytes<const MAX_LEN: usize>(
246        &self,
247        ctx: &mut Context<F>,
248        inputs: [AssignedValue<F>; MAX_LEN],
249        len: AssignedValue<F>,
250    ) -> VarLenBytes<F, MAX_LEN> {
251        self.range_chip.check_less_than_safe(ctx, len, MAX_LEN as u64 + 1);
252        VarLenBytes::<F, MAX_LEN>::new(inputs.map(|input| self.assert_byte(ctx, input)), len)
253    }
254
255    /// Converts a vector of AssignedValue to [VarLenBytesVec]. Not encouraged to use because `MAX_LEN` cannot be verified at compile time.
256    ///
257    /// * inputs: Vector representing the byte array, right padded to `max_len`. See [VarLenBytesVec] for details about padding.
258    /// * len: [`AssignedValue<F>`] witness representing the variable length of the byte array. Constrained to be `<= max_len`.
259    /// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. We enforce this to be provided explictly to make sure length of `inputs` is determinstic.
260    ///
261    /// ## Assumptions
262    /// * `max_len < u64::MAX` to prevent overflow (but you should never make an array this large)
263    /// * `ceil((max_len + 1).bits() / lookup_bits) * lookup_bits <= F::CAPACITY` where `lookup_bits = self.range_chip.lookup_bits`
264    pub fn raw_to_var_len_bytes_vec(
265        &self,
266        ctx: &mut Context<F>,
267        inputs: RawAssignedValues<F>,
268        len: AssignedValue<F>,
269        max_len: usize,
270    ) -> VarLenBytesVec<F> {
271        self.range_chip.check_less_than_safe(ctx, len, max_len as u64 + 1);
272        VarLenBytesVec::<F>::new(
273            inputs.iter().map(|input| self.assert_byte(ctx, *input)).collect_vec(),
274            len,
275            max_len,
276        )
277    }
278
279    /// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytes.
280    ///
281    /// * inputs: Slice representing the byte array.
282    /// * LEN: length of the byte array.
283    pub fn raw_to_fix_len_bytes<const LEN: usize>(
284        &self,
285        ctx: &mut Context<F>,
286        inputs: [AssignedValue<F>; LEN],
287    ) -> FixLenBytes<F, LEN> {
288        FixLenBytes::<F, LEN>::new(inputs.map(|input| self.assert_byte(ctx, input)))
289    }
290
291    /// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytesVec.
292    ///
293    /// * inputs: Slice representing the byte array.
294    /// * len: length of the byte array. We enforce this to be provided explictly to make sure length of `inputs` is determinstic.
295    pub fn raw_to_fix_len_bytes_vec(
296        &self,
297        ctx: &mut Context<F>,
298        inputs: RawAssignedValues<F>,
299        len: usize,
300    ) -> FixLenBytesVec<F> {
301        FixLenBytesVec::<F>::new(
302            inputs.into_iter().map(|input| self.assert_byte(ctx, input)).collect_vec(),
303            len,
304        )
305    }
306
307    /// Assumes that `bits <= inputs.len() * 8`.
308    fn add_bytes_constraints(
309        &self,
310        ctx: &mut Context<F>,
311        inputs: &RawAssignedValues<F>,
312        bits: usize,
313    ) {
314        let mut bits_left = bits;
315        for input in inputs {
316            let num_bit = min(bits_left, BITS_PER_BYTE);
317            self.range_chip.range_check(ctx, *input, num_bit);
318            bits_left -= num_bit;
319        }
320    }
321
322    // TODO: Add comparison. e.g. is_less_than(SafeUint8, SafeUint8) -> SafeBool
323    // TODO: Add type castings. e.g. uint256 -> bytes32/uint32 -> uint64
324}