// openvm_sha256_circuit/sha256_chip/trace.rs

1use std::{array, borrow::BorrowMut, sync::Arc};
2
3use openvm_circuit_primitives::utils::next_power_of_two_or_zero;
4use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
5use openvm_rv32im_circuit::adapters::compose;
6use openvm_sha256_air::{
7    get_flag_pt_array, limbs_into_u32, Sha256Air, SHA256_BLOCK_WORDS, SHA256_BUFFER_SIZE, SHA256_H,
8    SHA256_HASH_WORDS, SHA256_ROWS_PER_BLOCK, SHA256_WORD_U8S,
9};
10use openvm_stark_backend::{
11    config::{StarkGenericConfig, Val},
12    p3_air::BaseAir,
13    p3_field::{FieldAlgebra, PrimeField32},
14    p3_matrix::dense::RowMajorMatrix,
15    p3_maybe_rayon::prelude::*,
16    prover::types::AirProofInput,
17    rap::get_air_name,
18    AirRef, Chip, ChipUsageGetter,
19};
20
21use super::{
22    Sha256VmChip, Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH,
23    SHA256VM_DIGEST_WIDTH, SHA256VM_ROUND_WIDTH,
24};
25use crate::{
26    sha256_chip::{PaddingFlags, SHA256_READ_SIZE},
27    SHA256_BLOCK_CELLS,
28};
29
30impl<SC: StarkGenericConfig> Chip<SC> for Sha256VmChip<Val<SC>>
31where
32    Val<SC>: PrimeField32,
33{
34    fn air(&self) -> AirRef<SC> {
35        Arc::new(self.air.clone())
36    }
37
38    fn generate_air_proof_input(self) -> AirProofInput<SC> {
39        let non_padded_height = self.current_trace_height();
40        let height = next_power_of_two_or_zero(non_padded_height);
41        let width = self.trace_width();
42        let mut values = Val::<SC>::zero_vec(height * width);
43        if height == 0 {
44            return AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width));
45        }
46        let records = self.records;
47        let offline_memory = self.offline_memory.lock().unwrap();
48        let memory_aux_cols_factory = offline_memory.aux_cols_factory();
49
50        let mem_ptr_shift: u32 =
51            1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.air.ptr_max_bits);
52
53        let mut states = Vec::with_capacity(height.div_ceil(SHA256_ROWS_PER_BLOCK));
54        let mut global_block_idx = 0;
55        for (record_idx, record) in records.iter().enumerate() {
56            let dst_read = offline_memory.record_by_id(record.dst_read);
57            let src_read = offline_memory.record_by_id(record.src_read);
58            let len_read = offline_memory.record_by_id(record.len_read);
59
60            self.bitwise_lookup_chip.request_range(
61                dst_read
62                    .data_at(RV32_REGISTER_NUM_LIMBS - 1)
63                    .as_canonical_u32()
64                    * mem_ptr_shift,
65                src_read
66                    .data_at(RV32_REGISTER_NUM_LIMBS - 1)
67                    .as_canonical_u32()
68                    * mem_ptr_shift,
69            );
70            let len = compose(len_read.data_slice().try_into().unwrap());
71            let mut state = &None;
72            for (i, input_message) in record.input_message.iter().enumerate() {
73                let input_message = input_message
74                    .iter()
75                    .flatten()
76                    .copied()
77                    .collect::<Vec<_>>()
78                    .try_into()
79                    .unwrap();
80                states.push(Some(Self::generate_state(
81                    state,
82                    input_message,
83                    record_idx,
84                    len,
85                    i == record.input_records.len() - 1,
86                )));
87                state = &states[global_block_idx];
88                global_block_idx += 1;
89            }
90        }
91        states.extend(std::iter::repeat_n(
92            None,
93            (height - non_padded_height).div_ceil(SHA256_ROWS_PER_BLOCK),
94        ));
95
96        // During the first pass we will fill out most of the matrix
97        // But there are some cells that can't be generated by the first pass so we will do a second
98        // pass over the matrix
99        values
100            .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK)
101            .zip(states.into_par_iter().enumerate())
102            .for_each(|(block, (global_block_idx, state))| {
103                // Fill in a valid block
104                if let Some(state) = state {
105                    let mut has_padding_occurred =
106                        state.local_block_idx * SHA256_BLOCK_CELLS > state.message_len as usize;
107                    let message_left = if has_padding_occurred {
108                        0
109                    } else {
110                        state.message_len as usize - state.local_block_idx * SHA256_BLOCK_CELLS
111                    };
112                    let is_last_block = state.is_last_block;
113                    let buffer: [[Val<SC>; SHA256_BUFFER_SIZE]; 4] = array::from_fn(|j| {
114                        array::from_fn(|k| {
115                            Val::<SC>::from_canonical_u8(
116                                state.block_input_message[j * SHA256_BUFFER_SIZE + k],
117                            )
118                        })
119                    });
120
121                    let padded_message: [u32; SHA256_BLOCK_WORDS] = array::from_fn(|j| {
122                        limbs_into_u32::<RV32_REGISTER_NUM_LIMBS>(array::from_fn(|k| {
123                            state.block_padded_message[(j + 1) * SHA256_WORD_U8S - k - 1] as u32
124                        }))
125                    });
126
127                    self.air.sha256_subair.generate_block_trace::<Val<SC>>(
128                        block,
129                        width,
130                        SHA256VM_CONTROL_WIDTH,
131                        &padded_message,
132                        self.bitwise_lookup_chip.clone(),
133                        &state.hash,
134                        is_last_block,
135                        global_block_idx as u32 + 1,
136                        state.local_block_idx as u32,
137                        &buffer,
138                    );
139
140                    let block_reads = records[state.message_idx].input_records
141                        [state.local_block_idx]
142                        .map(|record_id| offline_memory.record_by_id(record_id));
143
144                    let mut read_ptr = block_reads[0].pointer;
145                    let mut cur_timestamp = Val::<SC>::from_canonical_u32(block_reads[0].timestamp);
146
147                    let read_size = Val::<SC>::from_canonical_usize(SHA256_READ_SIZE);
148                    for row in 0..SHA256_ROWS_PER_BLOCK {
149                        let row_slice = &mut block[row * width..(row + 1) * width];
150                        if row < 16 {
151                            let cols: &mut Sha256VmRoundCols<Val<SC>> =
152                                row_slice[..SHA256VM_ROUND_WIDTH].borrow_mut();
153                            cols.control.len = Val::<SC>::from_canonical_u32(state.message_len);
154                            cols.control.read_ptr = read_ptr;
155                            cols.control.cur_timestamp = cur_timestamp;
156                            if row < 4 {
157                                read_ptr += read_size;
158                                cur_timestamp += Val::<SC>::ONE;
159                                memory_aux_cols_factory
160                                    .generate_read_aux(block_reads[row], &mut cols.read_aux);
161
162                                if (row + 1) * SHA256_READ_SIZE <= message_left {
163                                    cols.control.pad_flags = get_flag_pt_array(
164                                        &self.air.padding_encoder,
165                                        PaddingFlags::NotPadding as usize,
166                                    )
167                                    .map(Val::<SC>::from_canonical_u32);
168                                } else if !has_padding_occurred {
169                                    has_padding_occurred = true;
170                                    let len = message_left - row * SHA256_READ_SIZE;
171                                    cols.control.pad_flags = get_flag_pt_array(
172                                        &self.air.padding_encoder,
173                                        if row == 3 && is_last_block {
174                                            PaddingFlags::FirstPadding0_LastRow
175                                        } else {
176                                            PaddingFlags::FirstPadding0
177                                        } as usize
178                                            + len,
179                                    )
180                                    .map(Val::<SC>::from_canonical_u32);
181                                } else {
182                                    cols.control.pad_flags = get_flag_pt_array(
183                                        &self.air.padding_encoder,
184                                        if row == 3 && is_last_block {
185                                            PaddingFlags::EntirePaddingLastRow
186                                        } else {
187                                            PaddingFlags::EntirePadding
188                                        } as usize,
189                                    )
190                                    .map(Val::<SC>::from_canonical_u32);
191                                }
192                            } else {
193                                cols.control.pad_flags = get_flag_pt_array(
194                                    &self.air.padding_encoder,
195                                    PaddingFlags::NotConsidered as usize,
196                                )
197                                .map(Val::<SC>::from_canonical_u32);
198                            }
199                            cols.control.padding_occurred =
200                                Val::<SC>::from_bool(has_padding_occurred);
201                        } else {
202                            if is_last_block {
203                                has_padding_occurred = false;
204                            }
205                            let cols: &mut Sha256VmDigestCols<Val<SC>> =
206                                row_slice[..SHA256VM_DIGEST_WIDTH].borrow_mut();
207                            cols.control.len = Val::<SC>::from_canonical_u32(state.message_len);
208                            cols.control.read_ptr = read_ptr;
209                            cols.control.cur_timestamp = cur_timestamp;
210                            cols.control.pad_flags = get_flag_pt_array(
211                                &self.air.padding_encoder,
212                                PaddingFlags::NotConsidered as usize,
213                            )
214                            .map(Val::<SC>::from_canonical_u32);
215                            if is_last_block {
216                                let record = &records[state.message_idx];
217                                let dst_read = offline_memory.record_by_id(record.dst_read);
218                                let src_read = offline_memory.record_by_id(record.src_read);
219                                let len_read = offline_memory.record_by_id(record.len_read);
220                                let digest_write = offline_memory.record_by_id(record.digest_write);
221                                cols.from_state = record.from_state;
222                                cols.rd_ptr = dst_read.pointer;
223                                cols.rs1_ptr = src_read.pointer;
224                                cols.rs2_ptr = len_read.pointer;
225                                cols.dst_ptr.copy_from_slice(dst_read.data_slice());
226                                cols.src_ptr.copy_from_slice(src_read.data_slice());
227                                cols.len_data.copy_from_slice(len_read.data_slice());
228                                memory_aux_cols_factory
229                                    .generate_read_aux(dst_read, &mut cols.register_reads_aux[0]);
230                                memory_aux_cols_factory
231                                    .generate_read_aux(src_read, &mut cols.register_reads_aux[1]);
232                                memory_aux_cols_factory
233                                    .generate_read_aux(len_read, &mut cols.register_reads_aux[2]);
234                                memory_aux_cols_factory
235                                    .generate_write_aux(digest_write, &mut cols.writes_aux);
236                            }
237                            cols.control.padding_occurred =
238                                Val::<SC>::from_bool(has_padding_occurred);
239                        }
240                    }
241                }
242                // Fill in the invalid rows
243                else {
244                    block.par_chunks_mut(width).for_each(|row| {
245                        let cols: &mut Sha256VmRoundCols<Val<SC>> = row.borrow_mut();
246                        self.air.sha256_subair.generate_default_row(&mut cols.inner);
247                    })
248                }
249            });
250
251        // Do a second pass over the trace to fill in the missing values
252        // Note, we need to skip the very first row
253        values[width..]
254            .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK)
255            .take(non_padded_height / SHA256_ROWS_PER_BLOCK)
256            .for_each(|chunk| {
257                self.air
258                    .sha256_subair
259                    .generate_missing_cells(chunk, width, SHA256VM_CONTROL_WIDTH);
260            });
261
262        AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width))
263    }
264}
265
266impl<F: PrimeField32> ChipUsageGetter for Sha256VmChip<F> {
267    fn air_name(&self) -> String {
268        get_air_name(&self.air)
269    }
270    fn current_trace_height(&self) -> usize {
271        self.records.iter().fold(0, |acc, record| {
272            acc + record.input_records.len() * SHA256_ROWS_PER_BLOCK
273        })
274    }
275
276    fn trace_width(&self) -> usize {
277        BaseAir::<F>::width(&self.air)
278    }
279}
280
281/// This is the state information that a block will use to generate its trace
282#[derive(Debug, Clone)]
283struct Sha256State {
284    hash: [u32; SHA256_HASH_WORDS],
285    local_block_idx: usize,
286    message_len: u32,
287    block_input_message: [u8; SHA256_BLOCK_CELLS],
288    block_padded_message: [u8; SHA256_BLOCK_CELLS],
289    message_idx: usize,
290    is_last_block: bool,
291}
292
293impl<F: PrimeField32> Sha256VmChip<F> {
294    fn generate_state(
295        prev_state: &Option<Sha256State>,
296        block_input_message: [u8; SHA256_BLOCK_CELLS],
297        message_idx: usize,
298        message_len: u32,
299        is_last_block: bool,
300    ) -> Sha256State {
301        let local_block_idx = if let Some(prev_state) = prev_state {
302            prev_state.local_block_idx + 1
303        } else {
304            0
305        };
306        let has_padding_occurred = local_block_idx * SHA256_BLOCK_CELLS > message_len as usize;
307        let message_left = if has_padding_occurred {
308            0
309        } else {
310            message_len as usize - local_block_idx * SHA256_BLOCK_CELLS
311        };
312
313        let padded_message_bytes: [u8; SHA256_BLOCK_CELLS] = array::from_fn(|j| {
314            if j < message_left {
315                block_input_message[j]
316            } else if j == message_left && !has_padding_occurred {
317                1 << (RV32_CELL_BITS - 1)
318            } else if !is_last_block || j < SHA256_BLOCK_CELLS - 4 {
319                0u8
320            } else {
321                let shift_amount = (SHA256_BLOCK_CELLS - j - 1) * RV32_CELL_BITS;
322                ((message_len * RV32_CELL_BITS as u32)
323                    .checked_shr(shift_amount as u32)
324                    .unwrap_or(0)
325                    & ((1 << RV32_CELL_BITS) - 1)) as u8
326            }
327        });
328
329        if let Some(prev_state) = prev_state {
330            Sha256State {
331                hash: Sha256Air::get_block_hash(&prev_state.hash, prev_state.block_padded_message),
332                local_block_idx,
333                message_len,
334                block_input_message,
335                block_padded_message: padded_message_bytes,
336                message_idx,
337                is_last_block,
338            }
339        } else {
340            Sha256State {
341                hash: SHA256_H,
342                local_block_idx: 0,
343                message_len,
344                block_input_message,
345                block_padded_message: padded_message_bytes,
346                message_idx,
347                is_last_block,
348            }
349        }
350    }
351}