use std::{
    borrow::Borrow,
    cmp::{max, min},
};

use crate::{
    gates::{
        flex_gate::GateInstructions,
        range::{RangeChip, RangeInstructions},
    },
    utils::ScalarField,
    AssignedValue, Context,
    QuantumCell::{Constant, Witness},
};

use itertools::Itertools;
17
18mod bytes;
19mod primitives;
20
21pub use bytes::*;
22pub use primitives::*;
23
24#[cfg(test)]
26pub mod tests;
27
/// Shorthand for an owned list of assigned cells, used both for raw inputs to the
/// conversion functions and for the storage inside `SafeType`.
type RawAssignedValues<F> = Vec<AssignedValue<F>>;
29
/// Width of one byte in bits; derived from `u8` rather than written as a literal.
const BITS_PER_BYTE: usize = u8::BITS as usize;
31
/// A value of `TOTAL_BITS` bits stored as `VALUE_LENGTH` assigned cells, each cell
/// holding up to `BYTES_PER_ELE` bytes.
///
/// The private constructor only checks the number of cells. Any range constraints are
/// added by whoever builds the value (e.g. `SafeTypeChip::raw_bytes_to`), while
/// `SafeTypeChip::unsafe_to_safe_type` adds none.
#[derive(Clone, Debug)]
pub struct SafeType<F: ScalarField, const BYTES_PER_ELE: usize, const TOTAL_BITS: usize> {
    // The assigned cells making up the value; the constructor enforces length `VALUE_LENGTH`.
    value: RawAssignedValues<F>,
}
47
48impl<F: ScalarField, const BYTES_PER_ELE: usize, const TOTAL_BITS: usize>
49 SafeType<F, BYTES_PER_ELE, TOTAL_BITS>
50{
51 pub const BYTES_PER_ELE: usize = BYTES_PER_ELE;
53 pub const TOTAL_BITS: usize = TOTAL_BITS;
55 pub const VALUE_LENGTH: usize = TOTAL_BITS.div_ceil(BYTES_PER_ELE * BITS_PER_BYTE);
57
58 pub fn bits_per_ele() -> usize {
60 min(TOTAL_BITS, BYTES_PER_ELE * BITS_PER_BYTE)
61 }
62
63 fn new(raw_values: RawAssignedValues<F>) -> Self {
65 assert!(raw_values.len() == Self::VALUE_LENGTH, "Invalid raw values length");
66 Self { value: raw_values }
67 }
68
69 pub fn value(&self) -> &[AssignedValue<F>] {
71 &self.value
72 }
73}
74
75impl<F: ScalarField, const BYTES_PER_ELE: usize, const TOTAL_BITS: usize> AsRef<[AssignedValue<F>]>
76 for SafeType<F, BYTES_PER_ELE, TOTAL_BITS>
77{
78 fn as_ref(&self) -> &[AssignedValue<F>] {
79 self.value()
80 }
81}
82
83impl<F: ScalarField, const TOTAL_BITS: usize> TryFrom<Vec<SafeByte<F>>>
84 for SafeType<F, 1, TOTAL_BITS>
85{
86 type Error = String;
87
88 fn try_from(value: Vec<SafeByte<F>>) -> Result<Self, Self::Error> {
89 if value.len() * 8 != TOTAL_BITS {
90 return Err("Invalid length".to_owned());
91 }
92 Ok(Self::new(value.into_iter().map(|b| b.0).collect::<Vec<_>>()))
93 }
94}
95
/// A 160-bit (20-byte) address value, stored one byte per cell.
pub type SafeAddress<F> = SafeType<F, 1, 160>;
/// A 256-bit (32-byte) value, stored one byte per cell.
pub type SafeBytes32<F> = SafeType<F, 1, 256>;
100
/// Chip that turns raw assigned cells into this module's typed wrappers.
///
/// The `raw_*`, `assert_*` and `load_*` methods add range constraints; the `unsafe_*`
/// constructors add none and leave correctness to the caller.
pub struct SafeTypeChip<'a, F: ScalarField> {
    // Range chip used for range checks; its inner gate is used for arithmetic.
    range_chip: &'a RangeChip<F>,
}
105
106impl<'a, F: ScalarField> SafeTypeChip<'a, F> {
107 pub fn new(range_chip: &'a RangeChip<F>) -> Self {
109 Self { range_chip }
110 }
111
112 pub fn raw_bytes_to<const BYTES_PER_ELE: usize, const TOTAL_BITS: usize>(
116 &self,
117 ctx: &mut Context<F>,
118 inputs: RawAssignedValues<F>,
119 ) -> SafeType<F, BYTES_PER_ELE, TOTAL_BITS> {
120 let element_bits = SafeType::<F, BYTES_PER_ELE, TOTAL_BITS>::bits_per_ele();
121 let bits = TOTAL_BITS;
122 assert!(
123 inputs.len() * BITS_PER_BYTE == max(bits, BITS_PER_BYTE),
124 "number of bits doesn't match"
125 );
126 self.add_bytes_constraints(ctx, &inputs, bits);
127 if bits == 1 || element_bits == BITS_PER_BYTE {
129 return SafeType::<F, BYTES_PER_ELE, TOTAL_BITS>::new(inputs);
130 };
131
132 let byte_base = (0..BYTES_PER_ELE)
133 .map(|i| Witness(self.range_chip.gate.pow_of_two[i * BITS_PER_BYTE]))
134 .collect::<Vec<_>>();
135 let value = inputs
136 .chunks(BYTES_PER_ELE)
137 .map(|chunk| {
138 self.range_chip.gate.inner_product(
139 ctx,
140 chunk.to_vec(),
141 byte_base[..chunk.len()].to_vec(),
142 )
143 })
144 .collect::<Vec<_>>();
145 SafeType::<F, BYTES_PER_ELE, TOTAL_BITS>::new(value)
146 }
147
148 pub fn unsafe_to_safe_type<const BYTES_PER_ELE: usize, const TOTAL_BITS: usize>(
151 inputs: RawAssignedValues<F>,
152 ) -> SafeType<F, BYTES_PER_ELE, TOTAL_BITS> {
153 assert_eq!(inputs.len(), SafeType::<F, BYTES_PER_ELE, TOTAL_BITS>::VALUE_LENGTH);
154 SafeType::<F, BYTES_PER_ELE, TOTAL_BITS>::new(inputs)
155 }
156
157 pub fn assert_bool(&self, ctx: &mut Context<F>, input: AssignedValue<F>) -> SafeBool<F> {
159 self.range_chip.gate().assert_bit(ctx, input);
160 SafeBool(input)
161 }
162
163 pub fn load_bool(&self, ctx: &mut Context<F>, input: bool) -> SafeBool<F> {
165 let input = ctx.load_witness(F::from(input));
166 self.assert_bool(ctx, input)
167 }
168
169 pub fn unsafe_to_bool(input: AssignedValue<F>) -> SafeBool<F> {
172 SafeBool(input)
173 }
174
175 pub fn assert_byte(&self, ctx: &mut Context<F>, input: AssignedValue<F>) -> SafeByte<F> {
177 self.range_chip.range_check(ctx, input, BITS_PER_BYTE);
178 SafeByte(input)
179 }
180
181 pub fn load_byte(&self, ctx: &mut Context<F>, input: u8) -> SafeByte<F> {
183 let input = ctx.load_witness(F::from(input as u64));
184 self.assert_byte(ctx, input)
185 }
186
187 pub fn unsafe_to_byte(input: AssignedValue<F>) -> SafeByte<F> {
190 SafeByte(input)
191 }
192
193 pub fn unsafe_to_var_len_bytes<const MAX_LEN: usize>(
196 inputs: [AssignedValue<F>; MAX_LEN],
197 len: AssignedValue<F>,
198 ) -> VarLenBytes<F, MAX_LEN> {
199 VarLenBytes::<F, MAX_LEN>::new(inputs.map(|input| Self::unsafe_to_byte(input)), len)
200 }
201
202 pub fn unsafe_to_var_len_bytes_vec(
205 inputs: RawAssignedValues<F>,
206 len: AssignedValue<F>,
207 max_len: usize,
208 ) -> VarLenBytesVec<F> {
209 VarLenBytesVec::<F>::new(
210 inputs.iter().map(|input| Self::unsafe_to_byte(*input)).collect_vec(),
211 len,
212 max_len,
213 )
214 }
215
216 pub fn unsafe_to_fix_len_bytes<const MAX_LEN: usize>(
219 inputs: [AssignedValue<F>; MAX_LEN],
220 ) -> FixLenBytes<F, MAX_LEN> {
221 FixLenBytes::<F, MAX_LEN>::new(inputs.map(|input| Self::unsafe_to_byte(input)))
222 }
223
224 pub fn unsafe_to_fix_len_bytes_vec(
227 inputs: RawAssignedValues<F>,
228 len: usize,
229 ) -> FixLenBytesVec<F> {
230 FixLenBytesVec::<F>::new(
231 inputs.into_iter().map(|input| Self::unsafe_to_byte(input)).collect_vec(),
232 len,
233 )
234 }
235
236 pub fn raw_to_var_len_bytes<const MAX_LEN: usize>(
246 &self,
247 ctx: &mut Context<F>,
248 inputs: [AssignedValue<F>; MAX_LEN],
249 len: AssignedValue<F>,
250 ) -> VarLenBytes<F, MAX_LEN> {
251 self.range_chip.check_less_than_safe(ctx, len, MAX_LEN as u64 + 1);
252 VarLenBytes::<F, MAX_LEN>::new(inputs.map(|input| self.assert_byte(ctx, input)), len)
253 }
254
255 pub fn raw_to_var_len_bytes_vec(
265 &self,
266 ctx: &mut Context<F>,
267 inputs: RawAssignedValues<F>,
268 len: AssignedValue<F>,
269 max_len: usize,
270 ) -> VarLenBytesVec<F> {
271 self.range_chip.check_less_than_safe(ctx, len, max_len as u64 + 1);
272 VarLenBytesVec::<F>::new(
273 inputs.iter().map(|input| self.assert_byte(ctx, *input)).collect_vec(),
274 len,
275 max_len,
276 )
277 }
278
279 pub fn raw_to_fix_len_bytes<const LEN: usize>(
284 &self,
285 ctx: &mut Context<F>,
286 inputs: [AssignedValue<F>; LEN],
287 ) -> FixLenBytes<F, LEN> {
288 FixLenBytes::<F, LEN>::new(inputs.map(|input| self.assert_byte(ctx, input)))
289 }
290
291 pub fn raw_to_fix_len_bytes_vec(
296 &self,
297 ctx: &mut Context<F>,
298 inputs: RawAssignedValues<F>,
299 len: usize,
300 ) -> FixLenBytesVec<F> {
301 FixLenBytesVec::<F>::new(
302 inputs.into_iter().map(|input| self.assert_byte(ctx, input)).collect_vec(),
303 len,
304 )
305 }
306
307 fn add_bytes_constraints(
309 &self,
310 ctx: &mut Context<F>,
311 inputs: &RawAssignedValues<F>,
312 bits: usize,
313 ) {
314 let mut bits_left = bits;
315 for input in inputs {
316 let num_bit = min(bits_left, BITS_PER_BYTE);
317 self.range_chip.range_check(ctx, *input, num_bit);
318 bits_left -= num_bit;
319 }
320 }
321
322 }