// openvm_sha256_circuit/sha256_chip/trace.rs
1use std::{array, borrow::BorrowMut, sync::Arc};
2
3use openvm_circuit_primitives::utils::next_power_of_two_or_zero;
4use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
5use openvm_rv32im_circuit::adapters::compose;
6use openvm_sha256_air::{
7 get_flag_pt_array, limbs_into_u32, Sha256Air, SHA256_BLOCK_WORDS, SHA256_BUFFER_SIZE, SHA256_H,
8 SHA256_HASH_WORDS, SHA256_ROWS_PER_BLOCK, SHA256_WORD_U8S,
9};
10use openvm_stark_backend::{
11 config::{StarkGenericConfig, Val},
12 p3_air::BaseAir,
13 p3_field::{FieldAlgebra, PrimeField32},
14 p3_matrix::dense::RowMajorMatrix,
15 p3_maybe_rayon::prelude::*,
16 prover::types::AirProofInput,
17 rap::get_air_name,
18 AirRef, Chip, ChipUsageGetter,
19};
20
21use super::{
22 Sha256VmChip, Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH,
23 SHA256VM_DIGEST_WIDTH, SHA256VM_ROUND_WIDTH,
24};
25use crate::{
26 sha256_chip::{PaddingFlags, SHA256_READ_SIZE},
27 SHA256_BLOCK_CELLS,
28};
29
impl<SC: StarkGenericConfig> Chip<SC> for Sha256VmChip<Val<SC>>
where
    Val<SC>: PrimeField32,
{
    fn air(&self) -> AirRef<SC> {
        Arc::new(self.air.clone())
    }

    /// Builds the full trace matrix for this chip from the recorded SHA-256 executions.
    ///
    /// Layout: each SHA-256 block occupies `SHA256_ROWS_PER_BLOCK` consecutive rows —
    /// rows 0..16 are round rows (`Sha256VmRoundCols`) and the final row is the digest
    /// row (`Sha256VmDigestCols`). The trace is padded with default rows up to the next
    /// power of two.
    fn generate_air_proof_input(self) -> AirProofInput<SC> {
        let non_padded_height = self.current_trace_height();
        let height = next_power_of_two_or_zero(non_padded_height);
        let width = self.trace_width();
        let mut values = Val::<SC>::zero_vec(height * width);
        // No records at all: emit an empty (zero-height) trace.
        if height == 0 {
            return AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width));
        }
        let records = self.records;
        let offline_memory = self.offline_memory.lock().unwrap();
        let memory_aux_cols_factory = offline_memory.aux_cols_factory();

        // Scale factor used to range-check the most-significant pointer limb against
        // `ptr_max_bits` via the bitwise lookup (limb * shift must fit the lookup range).
        let mem_ptr_shift: u32 =
            1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.air.ptr_max_bits);

        // First pass (sequential): compute the per-block hasher states. Each block's
        // starting hash depends on the previous block of the same message, so this
        // cannot be parallelized.
        let mut states = Vec::with_capacity(height.div_ceil(SHA256_ROWS_PER_BLOCK));
        let mut global_block_idx = 0;
        for (record_idx, record) in records.iter().enumerate() {
            let dst_read = offline_memory.record_by_id(record.dst_read);
            let src_read = offline_memory.record_by_id(record.src_read);
            let len_read = offline_memory.record_by_id(record.len_read);

            // Range-check the top limb of the dst and src pointers.
            self.bitwise_lookup_chip.request_range(
                dst_read
                    .data_at(RV32_REGISTER_NUM_LIMBS - 1)
                    .as_canonical_u32()
                    * mem_ptr_shift,
                src_read
                    .data_at(RV32_REGISTER_NUM_LIMBS - 1)
                    .as_canonical_u32()
                    * mem_ptr_shift,
            );
            let len = compose(len_read.data_slice().try_into().unwrap());
            // `state` points at the previous block's state within `states` (None for
            // the first block of a message).
            let mut state = &None;
            for (i, input_message) in record.input_message.iter().enumerate() {
                let input_message = input_message
                    .iter()
                    .flatten()
                    .copied()
                    .collect::<Vec<_>>()
                    .try_into()
                    .unwrap();
                states.push(Some(Self::generate_state(
                    state,
                    input_message,
                    record_idx,
                    len,
                    i == record.input_records.len() - 1,
                )));
                state = &states[global_block_idx];
                global_block_idx += 1;
            }
        }
        // Pad with `None` states so the parallel pass below also fills the dummy rows.
        states.extend(std::iter::repeat_n(
            None,
            (height - non_padded_height).div_ceil(SHA256_ROWS_PER_BLOCK),
        ));

        // Second pass (parallel): fill each block's rows from its precomputed state.
        values
            .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK)
            .zip(states.into_par_iter().enumerate())
            .for_each(|(block, (global_block_idx, state))| {
                if let Some(state) = state {
                    // Whether padding started in an earlier block of this message.
                    let mut has_padding_occurred =
                        state.local_block_idx * SHA256_BLOCK_CELLS > state.message_len as usize;
                    // Unpadded message bytes remaining at the start of this block.
                    let message_left = if has_padding_occurred {
                        0
                    } else {
                        state.message_len as usize - state.local_block_idx * SHA256_BLOCK_CELLS
                    };
                    let is_last_block = state.is_last_block;
                    // Raw (unpadded) input bytes of this block as field elements,
                    // split into the 4 per-row read buffers.
                    let buffer: [[Val<SC>; SHA256_BUFFER_SIZE]; 4] = array::from_fn(|j| {
                        array::from_fn(|k| {
                            Val::<SC>::from_canonical_u8(
                                state.block_input_message[j * SHA256_BUFFER_SIZE + k],
                            )
                        })
                    });

                    // Repack the padded byte stream into big-endian u32 words for the subair.
                    let padded_message: [u32; SHA256_BLOCK_WORDS] = array::from_fn(|j| {
                        limbs_into_u32::<RV32_REGISTER_NUM_LIMBS>(array::from_fn(|k| {
                            state.block_padded_message[(j + 1) * SHA256_WORD_U8S - k - 1] as u32
                        }))
                    });

                    // Let the SHA-256 subair fill its inner columns for this block.
                    self.air.sha256_subair.generate_block_trace::<Val<SC>>(
                        block,
                        width,
                        SHA256VM_CONTROL_WIDTH,
                        &padded_message,
                        self.bitwise_lookup_chip.clone(),
                        &state.hash,
                        is_last_block,
                        global_block_idx as u32 + 1,
                        state.local_block_idx as u32,
                        &buffer,
                    );

                    // Memory read records for the 4 reads performed during this block.
                    let block_reads = records[state.message_idx].input_records
                        [state.local_block_idx]
                        .map(|record_id| offline_memory.record_by_id(record_id));

                    let mut read_ptr = block_reads[0].pointer;
                    let mut cur_timestamp = Val::<SC>::from_canonical_u32(block_reads[0].timestamp);

                    let read_size = Val::<SC>::from_canonical_usize(SHA256_READ_SIZE);
                    for row in 0..SHA256_ROWS_PER_BLOCK {
                        let row_slice = &mut block[row * width..(row + 1) * width];
                        if row < 16 {
                            // Round row: fill the VM control columns on top of the
                            // subair-generated inner columns.
                            let cols: &mut Sha256VmRoundCols<Val<SC>> =
                                row_slice[..SHA256VM_ROUND_WIDTH].borrow_mut();
                            cols.control.len = Val::<SC>::from_canonical_u32(state.message_len);
                            cols.control.read_ptr = read_ptr;
                            cols.control.cur_timestamp = cur_timestamp;
                            if row < 4 {
                                // The first 4 rows each perform one memory read of
                                // SHA256_READ_SIZE cells; advance pointer/timestamp.
                                read_ptr += read_size;
                                cur_timestamp += Val::<SC>::ONE;
                                memory_aux_cols_factory
                                    .generate_read_aux(block_reads[row], &mut cols.read_aux);

                                if (row + 1) * SHA256_READ_SIZE <= message_left {
                                    // This read is entirely message bytes.
                                    cols.control.pad_flags = get_flag_pt_array(
                                        &self.air.padding_encoder,
                                        PaddingFlags::NotPadding as usize,
                                    )
                                    .map(Val::<SC>::from_canonical_u32);
                                } else if !has_padding_occurred {
                                    // Padding starts within this read; `len` is how many
                                    // message bytes precede the first padding byte.
                                    has_padding_occurred = true;
                                    let len = message_left - row * SHA256_READ_SIZE;
                                    cols.control.pad_flags = get_flag_pt_array(
                                        &self.air.padding_encoder,
                                        if row == 3 && is_last_block {
                                            PaddingFlags::FirstPadding0_LastRow
                                        } else {
                                            PaddingFlags::FirstPadding0
                                        } as usize
                                            + len,
                                    )
                                    .map(Val::<SC>::from_canonical_u32);
                                } else {
                                    // This read is entirely padding.
                                    cols.control.pad_flags = get_flag_pt_array(
                                        &self.air.padding_encoder,
                                        if row == 3 && is_last_block {
                                            PaddingFlags::EntirePaddingLastRow
                                        } else {
                                            PaddingFlags::EntirePadding
                                        } as usize,
                                    )
                                    .map(Val::<SC>::from_canonical_u32);
                                }
                            } else {
                                // Rows 4..16 do not read memory.
                                cols.control.pad_flags = get_flag_pt_array(
                                    &self.air.padding_encoder,
                                    PaddingFlags::NotConsidered as usize,
                                )
                                .map(Val::<SC>::from_canonical_u32);
                            }
                            cols.control.padding_occurred =
                                Val::<SC>::from_bool(has_padding_occurred);
                        } else {
                            // Digest row.
                            if is_last_block {
                                // padding_occurred resets after the last block of a message.
                                has_padding_occurred = false;
                            }
                            let cols: &mut Sha256VmDigestCols<Val<SC>> =
                                row_slice[..SHA256VM_DIGEST_WIDTH].borrow_mut();
                            cols.control.len = Val::<SC>::from_canonical_u32(state.message_len);
                            cols.control.read_ptr = read_ptr;
                            cols.control.cur_timestamp = cur_timestamp;
                            cols.control.pad_flags = get_flag_pt_array(
                                &self.air.padding_encoder,
                                PaddingFlags::NotConsidered as usize,
                            )
                            .map(Val::<SC>::from_canonical_u32);
                            if is_last_block {
                                // Only the last block's digest row carries the
                                // instruction-level register reads and the digest write.
                                let record = &records[state.message_idx];
                                let dst_read = offline_memory.record_by_id(record.dst_read);
                                let src_read = offline_memory.record_by_id(record.src_read);
                                let len_read = offline_memory.record_by_id(record.len_read);
                                let digest_write = offline_memory.record_by_id(record.digest_write);
                                cols.from_state = record.from_state;
                                cols.rd_ptr = dst_read.pointer;
                                cols.rs1_ptr = src_read.pointer;
                                cols.rs2_ptr = len_read.pointer;
                                cols.dst_ptr.copy_from_slice(dst_read.data_slice());
                                cols.src_ptr.copy_from_slice(src_read.data_slice());
                                cols.len_data.copy_from_slice(len_read.data_slice());
                                memory_aux_cols_factory
                                    .generate_read_aux(dst_read, &mut cols.register_reads_aux[0]);
                                memory_aux_cols_factory
                                    .generate_read_aux(src_read, &mut cols.register_reads_aux[1]);
                                memory_aux_cols_factory
                                    .generate_read_aux(len_read, &mut cols.register_reads_aux[2]);
                                memory_aux_cols_factory
                                    .generate_write_aux(digest_write, &mut cols.writes_aux);
                            }
                            cols.control.padding_occurred =
                                Val::<SC>::from_bool(has_padding_occurred);
                        }
                    }
                } else {
                    // Dummy (padding) block: fill every row with the subair's default row.
                    block.par_chunks_mut(width).for_each(|row| {
                        let cols: &mut Sha256VmRoundCols<Val<SC>> = row.borrow_mut();
                        self.air.sha256_subair.generate_default_row(&mut cols.inner);
                    })
                }
            });

        // Third pass: fill subair cells that depend on the *next* block's first row
        // (hence the `values[width..]` offset — each chunk starts one row late).
        values[width..]
            .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK)
            .take(non_padded_height / SHA256_ROWS_PER_BLOCK)
            .for_each(|chunk| {
                self.air
                    .sha256_subair
                    .generate_missing_cells(chunk, width, SHA256VM_CONTROL_WIDTH);
            });

        AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width))
    }
}
265
266impl<F: PrimeField32> ChipUsageGetter for Sha256VmChip<F> {
267 fn air_name(&self) -> String {
268 get_air_name(&self.air)
269 }
270 fn current_trace_height(&self) -> usize {
271 self.records.iter().fold(0, |acc, record| {
272 acc + record.input_records.len() * SHA256_ROWS_PER_BLOCK
273 })
274 }
275
276 fn trace_width(&self) -> usize {
277 BaseAir::<F>::width(&self.air)
278 }
279}
280
/// Hasher state captured at the start of processing one 512-bit SHA-256 block.
#[derive(Debug, Clone)]
struct Sha256State {
    // Working hash entering this block: SHA256_H for a message's first block,
    // otherwise the compression of the previous block's state.
    hash: [u32; SHA256_HASH_WORDS],
    // Index of this block within its message (0-based).
    local_block_idx: usize,
    // Total message length in bytes (from the len register read).
    message_len: u32,
    // Raw input bytes of this block, before padding.
    block_input_message: [u8; SHA256_BLOCK_CELLS],
    // Input bytes with SHA-256 padding applied where the message ends.
    block_padded_message: [u8; SHA256_BLOCK_CELLS],
    // Index of the owning record in `self.records`.
    message_idx: usize,
    // True iff this is the final block of its message.
    is_last_block: bool,
}
292
impl<F: PrimeField32> Sha256VmChip<F> {
    /// Computes the hasher state entering a block: the running hash (compressing the
    /// previous block if any) and the block's padded byte stream.
    ///
    /// `prev_state` is `None` for the first block of a message. Padding follows
    /// SHA-256: a single high bit after the message, zeros, then the bit-length in
    /// the trailing bytes (the length fits 4 bytes since `message_len` is `u32`;
    /// the higher-order length bytes land in the zero branch).
    fn generate_state(
        prev_state: &Option<Sha256State>,
        block_input_message: [u8; SHA256_BLOCK_CELLS],
        message_idx: usize,
        message_len: u32,
        is_last_block: bool,
    ) -> Sha256State {
        let local_block_idx = if let Some(prev_state) = prev_state {
            prev_state.local_block_idx + 1
        } else {
            0
        };
        // True if the message ended strictly before this block began.
        let has_padding_occurred = local_block_idx * SHA256_BLOCK_CELLS > message_len as usize;
        // Unpadded message bytes remaining at the start of this block.
        let message_left = if has_padding_occurred {
            0
        } else {
            message_len as usize - local_block_idx * SHA256_BLOCK_CELLS
        };

        let padded_message_bytes: [u8; SHA256_BLOCK_CELLS] = array::from_fn(|j| {
            if j < message_left {
                // Still inside the message: copy the input byte.
                block_input_message[j]
            } else if j == message_left && !has_padding_occurred {
                // First padding byte: single 1 bit then zeros (0x80 for 8-bit cells).
                1 << (RV32_CELL_BITS - 1)
            } else if !is_last_block || j < SHA256_BLOCK_CELLS - 4 {
                // Zero padding (includes the high-order length bytes, which are
                // always zero because message_len is a u32).
                0u8
            } else {
                // Last 4 bytes of the last block: message bit-length, big-endian.
                let shift_amount = (SHA256_BLOCK_CELLS - j - 1) * RV32_CELL_BITS;
                ((message_len * RV32_CELL_BITS as u32)
                    .checked_shr(shift_amount as u32)
                    .unwrap_or(0)
                    & ((1 << RV32_CELL_BITS) - 1)) as u8
            }
        });

        if let Some(prev_state) = prev_state {
            // Advance the hash by compressing the previous block.
            Sha256State {
                hash: Sha256Air::get_block_hash(&prev_state.hash, prev_state.block_padded_message),
                local_block_idx,
                message_len,
                block_input_message,
                block_padded_message: padded_message_bytes,
                message_idx,
                is_last_block,
            }
        } else {
            // First block of a message starts from the SHA-256 initial constants.
            Sha256State {
                hash: SHA256_H,
                local_block_idx: 0,
                message_len,
                block_input_message,
                block_padded_message: padded_message_bytes,
                message_idx,
                is_last_block,
            }
        }
    }
}
351}