use std::{
    array,
    borrow::{Borrow, BorrowMut},
    cmp::min,
};

use openvm_circuit::{
    arch::*,
    system::memory::{
        offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord},
        online::TracingMemory,
        MemoryAuxColsFactory,
    },
};
use openvm_circuit_primitives::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
    LocalOpcode,
};
use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write};
use openvm_sha256_air::{
    get_flag_pt_array, get_sha256_num_blocks, Sha256FillerHelper, SHA256_BLOCK_BITS, SHA256_H,
    SHA256_ROWS_PER_BLOCK,
};
use openvm_sha256_transpiler::Rv32Sha256Opcode;
use openvm_stark_backend::{
    p3_field::PrimeField32,
    p3_matrix::{dense::RowMajorMatrix, Matrix},
    p3_maybe_rayon::prelude::*,
};

use super::{
    Sha256VmDigestCols, Sha256VmExecutor, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH,
    SHA256VM_DIGEST_WIDTH,
};
use crate::{
    sha256_chip::{PaddingFlags, SHA256_READ_SIZE, SHA256_REGISTER_READS, SHA256_WRITE_SIZE},
    sha256_solve, Sha256VmControlCols, Sha256VmFiller, SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH,
    SHA256_BLOCK_CELLS, SHA256_MAX_MESSAGE_LEN, SHA256_NUM_READ_ROWS,
};

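/// Trace-height metadata for a single SHA-256 instruction: each message block occupies
/// `SHA256_ROWS_PER_BLOCK` rows of the trace.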
#[derive(Clone, Copy)]
pub struct Sha256VmMetadata {
    pub num_blocks: u32,
}

impl MultiRowMetadata for Sha256VmMetadata {
    #[inline(always)]
    fn get_num_rows(&self) -> usize {
        self.num_blocks as usize * SHA256_ROWS_PER_BLOCK
    }
}

pub(crate) type Sha256VmRecordLayout = MultiRowLayout<Sha256VmMetadata>;

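/// Fixed-size portion of an execution record: program state, instruction operands, the
/// resolved pointers and message length, and the memory aux records for the register reads
/// and the digest write.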
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug, Clone)]
pub struct Sha256VmRecordHeader {
    pub from_pc: u32,
    pub timestamp: u32,
    pub rd_ptr: u32,
    pub rs1_ptr: u32,
    pub rs2_ptr: u32,
    pub dst_ptr: u32,
    pub src_ptr: u32,
    pub len: u32,

    pub register_reads_aux: [MemoryReadAuxRecord; SHA256_REGISTER_READS],
    pub write_aux: MemoryWriteBytesAuxRecord<SHA256_WRITE_SIZE>,
}

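/// Mutable view of a record inside the arena: the fixed header, the raw message bytes
/// (`num_blocks * SHA256_BLOCK_CELLS` of them), and one read-aux record per read row.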
pub struct Sha256VmRecordMut<'a> {
    pub inner: &'a mut Sha256VmRecordHeader,
    pub input: &'a mut [u8],
    pub read_aux: &'a mut [MemoryReadAuxRecord],
}

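// Custom borrow that carves a `Sha256VmRecordMut` out of a raw byte buffer laid out as:
// the header bytes, then the input message bytes, then (after alignment padding) the
// `MemoryReadAuxRecord`s for each read row.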
impl<'a> CustomBorrow<'a, Sha256VmRecordMut<'a>, Sha256VmRecordLayout> for [u8] {
    fn custom_borrow(&'a mut self, layout: Sha256VmRecordLayout) -> Sha256VmRecordMut<'a> {
        // Split off the fixed-size header.
        let (header_buf, rest) =
            unsafe { self.split_at_mut_unchecked(size_of::<Sha256VmRecordHeader>()) };

        // Split off the raw message bytes: `SHA256_BLOCK_CELLS` bytes per block.
        let (input, rest) = unsafe {
            rest.split_at_mut_unchecked((layout.metadata.num_blocks as usize) * SHA256_BLOCK_CELLS)
        };

        // The remaining bytes hold the read-aux records, starting at the first aligned offset.
        let (_, read_aux_buf, _) = unsafe { rest.align_to_mut::<MemoryReadAuxRecord>() };
        Sha256VmRecordMut {
            inner: header_buf.borrow_mut(),
            input,
            read_aux: &mut read_aux_buf
                [..(layout.metadata.num_blocks as usize) * SHA256_NUM_READ_ROWS],
        }
    }

    unsafe fn extract_layout(&self) -> Sha256VmRecordLayout {
        let header: &Sha256VmRecordHeader = self.borrow();
        Sha256VmRecordLayout {
            metadata: Sha256VmMetadata {
                num_blocks: get_sha256_num_blocks(header.len),
            },
        }
    }
}

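// Arena footprint of a record: header, message bytes, then the read-aux records at their
// required alignment. This must stay consistent with `custom_borrow` above.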
impl SizedRecord<Sha256VmRecordLayout> for Sha256VmRecordMut<'_> {
    fn size(layout: &Sha256VmRecordLayout) -> usize {
        let mut total_len = size_of::<Sha256VmRecordHeader>();
        total_len += layout.metadata.num_blocks as usize * SHA256_BLOCK_CELLS;
        // Pad up to the alignment of the read-aux records.
        total_len = total_len.next_multiple_of(align_of::<MemoryReadAuxRecord>());
        total_len += layout.metadata.num_blocks as usize
            * SHA256_NUM_READ_ROWS
            * size_of::<MemoryReadAuxRecord>();
        total_len
    }

    fn alignment(_layout: &Sha256VmRecordLayout) -> usize {
        align_of::<Sha256VmRecordHeader>()
    }
}

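// Preflight execution: reads the three operand registers, streams the message out of memory
// in `SHA256_READ_SIZE`-byte rows, computes the digest, and writes it back, recording all
// memory aux data into the arena.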
impl<F, RA> PreflightExecutor<F, RA> for Sha256VmExecutor
where
    F: PrimeField32,
    for<'buf> RA: RecordArena<'buf, Sha256VmRecordLayout, Sha256VmRecordMut<'buf>>,
{
    fn get_opcode_name(&self, _: usize) -> String {
        format!("{:?}", Rv32Sha256Opcode::SHA256)
    }

    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            ..
        } = instruction;
        debug_assert_eq!(*opcode, Rv32Sha256Opcode::SHA256.global_opcode());
        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);

        // Peek at the message length first so we know how many blocks to allocate for.
        let len = read_rv32_register(state.memory.data(), c.as_canonical_u32());

        let num_blocks = get_sha256_num_blocks(len);
        let record = state.ctx.alloc(MultiRowLayout {
            metadata: Sha256VmMetadata { num_blocks },
        });

        record.inner.from_pc = *state.pc;
        record.inner.timestamp = state.memory.timestamp();
        record.inner.rd_ptr = a.as_canonical_u32();
        record.inner.rs1_ptr = b.as_canonical_u32();
        record.inner.rs2_ptr = c.as_canonical_u32();

        record.inner.dst_ptr = u32::from_le_bytes(tracing_read(
            state.memory,
            RV32_REGISTER_AS,
            record.inner.rd_ptr,
            &mut record.inner.register_reads_aux[0].prev_timestamp,
        ));
        record.inner.src_ptr = u32::from_le_bytes(tracing_read(
            state.memory,
            RV32_REGISTER_AS,
            record.inner.rs1_ptr,
            &mut record.inner.register_reads_aux[1].prev_timestamp,
        ));
        record.inner.len = u32::from_le_bytes(tracing_read(
            state.memory,
            RV32_REGISTER_AS,
            record.inner.rs2_ptr,
            &mut record.inner.register_reads_aux[2].prev_timestamp,
        ));

        debug_assert!(
            record.inner.src_ptr as usize + num_blocks as usize * SHA256_BLOCK_CELLS
                <= (1 << self.pointer_max_bits)
        );
        debug_assert!(
            record.inner.dst_ptr as usize + SHA256_WRITE_SIZE <= (1 << self.pointer_max_bits)
        );
        debug_assert!(record.inner.len < SHA256_MAX_MESSAGE_LEN as u32);

        // Read the message in `SHA256_READ_SIZE`-byte rows, `SHA256_NUM_READ_ROWS` per block.
        for block_idx in 0..num_blocks as usize {
            for row in 0..SHA256_NUM_READ_ROWS {
                let read_idx = block_idx * SHA256_NUM_READ_ROWS + row;
                let row_input: [u8; SHA256_READ_SIZE] = tracing_read(
                    state.memory,
                    RV32_MEMORY_AS,
                    record.inner.src_ptr + (read_idx * SHA256_READ_SIZE) as u32,
                    &mut record.read_aux[read_idx].prev_timestamp,
                );
                record.input[read_idx * SHA256_READ_SIZE..(read_idx + 1) * SHA256_READ_SIZE]
                    .copy_from_slice(&row_input);
            }
        }

        // Hash only the actual message bytes and write the digest to `dst_ptr`.
        let output = sha256_solve(&record.input[..len as usize]);
        tracing_write(
            state.memory,
            RV32_MEMORY_AS,
            record.inner.dst_ptr,
            output,
            &mut record.inner.write_aux.prev_timestamp,
            &mut record.inner.write_aux.prev_data,
        );

        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}

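// Trace filling: the used rows are split into one chunk per instruction record (plus a
// trailing chunk of unused rows), each chunk is filled in parallel block by block, and a
// second pass fills the cells that depend on neighboring rows.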
impl<F: PrimeField32> TraceFiller<F> for Sha256VmFiller {
    fn fill_trace(
        &self,
        mem_helper: &MemoryAuxColsFactory<F>,
        trace_matrix: &mut RowMajorMatrix<F>,
        rows_used: usize,
    ) {
        if rows_used == 0 {
            return;
        }

        let mut chunks = Vec::with_capacity(trace_matrix.height() / SHA256_ROWS_PER_BLOCK);
        let mut sizes = Vec::with_capacity(trace_matrix.height() / SHA256_ROWS_PER_BLOCK);
        let mut trace = &mut trace_matrix.values[..];
        let mut num_blocks_so_far = 0;

        // Split the trace into one chunk per record; the final chunk holds the remaining
        // (unused) rows.
        loop {
            if num_blocks_so_far * SHA256_ROWS_PER_BLOCK >= rows_used {
                chunks.push(trace);
                sizes.push((0, num_blocks_so_far));
                break;
            } else {
                // Peek at the record header to recover the message length and thus the
                // number of blocks this record occupies.
                let record: &Sha256VmRecordHeader =
                    unsafe { get_record_from_slice(&mut trace, ()) };
                let num_blocks =
                    ((record.len << 3) as usize + 1 + 64).div_ceil(SHA256_BLOCK_BITS);
                let (chunk, rest) =
                    trace.split_at_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK * num_blocks);
                chunks.push(chunk);
                sizes.push((num_blocks, num_blocks_so_far));
                num_blocks_so_far += num_blocks;
                trace = rest;
            }
        }

        chunks.par_iter_mut().zip(sizes.par_iter()).for_each(
            |(slice, (num_blocks, global_block_offset))| {
                // The trailing chunk contains only unused rows: zero them and fill in
                // default round rows.
                if global_block_offset * SHA256_ROWS_PER_BLOCK >= rows_used {
                    slice.par_chunks_mut(SHA256VM_WIDTH).for_each(|row| {
                        unsafe {
                            std::ptr::write_bytes(
                                row.as_mut_ptr() as *mut u8,
                                0,
                                SHA256VM_WIDTH * size_of::<F>(),
                            );
                        }
                        let cols: &mut Sha256VmRoundCols<F> =
                            row[..SHA256VM_ROUND_WIDTH].borrow_mut();
                        self.inner.generate_default_row(&mut cols.inner);
                    });
                    return;
                }

                let record: Sha256VmRecordMut = unsafe {
                    get_record_from_slice(
                        slice,
                        Sha256VmRecordLayout {
                            metadata: Sha256VmMetadata {
                                num_blocks: *num_blocks as u32,
                            },
                        },
                    )
                };

                // Copy the message out of the arena before the rows are overwritten, then
                // apply the SHA-256 padding (0x80, zeros, big-endian bit length) to a copy.
                let mut input: Vec<u8> = Vec::with_capacity(SHA256_BLOCK_CELLS * num_blocks);
                input.extend_from_slice(record.input);
                let mut padded_input = input.clone();
                let len = record.inner.len as usize;
                let padded_input_len = padded_input.len();
                padded_input[len] = 1 << (RV32_CELL_BITS - 1);
                padded_input[len + 1..padded_input_len - 4].fill(0);
                padded_input[padded_input_len - 4..]
                    .copy_from_slice(&((len as u32) << 3).to_be_bytes());

                // Precompute the hash state entering each block.
                let mut prev_hashes = Vec::with_capacity(*num_blocks);
                prev_hashes.push(SHA256_H);
                for i in 0..*num_blocks - 1 {
                    prev_hashes.push(Sha256FillerHelper::get_block_hash(
                        &prev_hashes[i],
                        padded_input[i * SHA256_BLOCK_CELLS..(i + 1) * SHA256_BLOCK_CELLS]
                            .try_into()
                            .unwrap(),
                    ));
                }
                let mut read_aux_records = Vec::with_capacity(SHA256_NUM_READ_ROWS * num_blocks);
                read_aux_records.extend_from_slice(record.read_aux);
                let vm_record = record.inner.clone();

                slice
                    .par_chunks_exact_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK)
                    .enumerate()
                    .for_each(|(block_idx, block_slice)| {
                        // Zero the block's rows before filling them.
                        unsafe {
                            std::ptr::write_bytes(
                                block_slice.as_mut_ptr() as *mut u8,
                                0,
                                SHA256_ROWS_PER_BLOCK * SHA256VM_WIDTH * size_of::<F>(),
                            );
                        }
                        self.fill_block_trace::<F>(
                            block_slice,
                            &vm_record,
                            &read_aux_records[block_idx * SHA256_NUM_READ_ROWS
                                ..(block_idx + 1) * SHA256_NUM_READ_ROWS],
                            &input[block_idx * SHA256_BLOCK_CELLS
                                ..(block_idx + 1) * SHA256_BLOCK_CELLS],
                            &padded_input[block_idx * SHA256_BLOCK_CELLS
                                ..(block_idx + 1) * SHA256_BLOCK_CELLS],
                            block_idx == *num_blocks - 1,
                            *global_block_offset + block_idx,
                            block_idx,
                            prev_hashes[block_idx],
                            mem_helper,
                        );
                    });
            },
        );

        // Second pass over row windows offset by one row, filling the remaining cells that
        // depend on neighboring rows.
        trace_matrix.values[SHA256VM_WIDTH..]
            .par_chunks_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK)
            .take(rows_used / SHA256_ROWS_PER_BLOCK)
            .for_each(|chunk| {
                self.inner
                    .generate_missing_cells(chunk, SHA256VM_WIDTH, SHA256VM_CONTROL_WIDTH);
            });
    }
}

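// Helper that fills one SHA-256 block's worth of rows: 16 round rows followed by a digest row.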
impl Sha256VmFiller {
    #[allow(clippy::too_many_arguments)]
    fn fill_block_trace<F: PrimeField32>(
        &self,
        block_slice: &mut [F],
        record: &Sha256VmRecordHeader,
        read_aux_records: &[MemoryReadAuxRecord],
        input: &[u8],
        padded_input: &[u8],
        is_last_block: bool,
        global_block_idx: usize,
        local_block_idx: usize,
        prev_hash: [u32; 8],
        mem_helper: &MemoryAuxColsFactory<F>,
    ) {
        debug_assert_eq!(input.len(), SHA256_BLOCK_CELLS);
        debug_assert_eq!(padded_input.len(), SHA256_BLOCK_CELLS);
        debug_assert_eq!(read_aux_records.len(), SHA256_NUM_READ_ROWS);

        // Repack the padded block into big-endian 32-bit words.
        let padded_input = array::from_fn(|i| {
            u32::from_be_bytes(padded_input[i * 4..(i + 1) * 4].try_into().unwrap())
        });

        // Timestamp of this block's first message read.
        let block_start_timestamp = record.timestamp
            + (SHA256_REGISTER_READS + SHA256_NUM_READ_ROWS * local_block_idx) as u32;

        let read_cells = (SHA256_BLOCK_CELLS * local_block_idx) as u32;
        let block_start_read_ptr = record.src_ptr + read_cells;

        // Message bytes that remain unread at the start of this block.
        let message_left = if record.len <= read_cells {
            0
        } else {
            (record.len - read_cells) as usize
        };

        // Index of the first read row of this block that contains padding: -1 if padding
        // already started in an earlier block, or a value past the read rows if no padding
        // starts in this block.
        let first_padding_row = if record.len < read_cells {
            -1
        } else if message_left < SHA256_BLOCK_CELLS {
            (message_left / SHA256_READ_SIZE) as i32
        } else {
            18
        };

        block_slice
            .par_chunks_exact_mut(SHA256VM_WIDTH)
            .enumerate()
            .for_each(|(row_idx, row_slice)| {
                if row_idx == SHA256_ROWS_PER_BLOCK - 1 {
                    // Digest row: holds the instruction-level columns; the register-read and
                    // digest-write aux columns are only filled on the last block.
                    let digest_cols: &mut Sha256VmDigestCols<F> =
                        row_slice[..SHA256VM_DIGEST_WIDTH].borrow_mut();
                    digest_cols.from_state.timestamp = F::from_canonical_u32(record.timestamp);
                    digest_cols.from_state.pc = F::from_canonical_u32(record.from_pc);
                    digest_cols.rd_ptr = F::from_canonical_u32(record.rd_ptr);
                    digest_cols.rs1_ptr = F::from_canonical_u32(record.rs1_ptr);
                    digest_cols.rs2_ptr = F::from_canonical_u32(record.rs2_ptr);
                    digest_cols.dst_ptr = record.dst_ptr.to_le_bytes().map(F::from_canonical_u8);
                    digest_cols.src_ptr = record.src_ptr.to_le_bytes().map(F::from_canonical_u8);
                    digest_cols.len_data = record.len.to_le_bytes().map(F::from_canonical_u8);
                    if is_last_block {
                        digest_cols
                            .register_reads_aux
                            .iter_mut()
                            .zip(record.register_reads_aux.iter())
                            .enumerate()
                            .for_each(|(idx, (cols_read, record_read))| {
                                mem_helper.fill(
                                    record_read.prev_timestamp,
                                    record.timestamp + idx as u32,
                                    cols_read.as_mut(),
                                );
                            });
                        digest_cols
                            .writes_aux
                            .set_prev_data(record.write_aux.prev_data.map(F::from_canonical_u8));
                        // The digest write happens right after this block's message reads.
                        mem_helper.fill(
                            record.write_aux.prev_timestamp,
                            block_start_timestamp + SHA256_NUM_READ_ROWS as u32,
                            digest_cols.writes_aux.as_mut(),
                        );
                        // Range check the most significant limbs of the pointers against
                        // `pointer_max_bits`.
                        let msl_rshift: u32 =
                            ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS) as u32;
                        let msl_lshift: u32 = (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS
                            - self.pointer_max_bits)
                            as u32;
                        self.bitwise_lookup_chip.request_range(
                            (record.dst_ptr >> msl_rshift) << msl_lshift,
                            (record.src_ptr >> msl_rshift) << msl_lshift,
                        );
                    } else {
                        // Not the last block: the memory aux columns stay zeroed.
                        digest_cols.register_reads_aux.iter_mut().for_each(|aux| {
                            mem_helper.fill_zero(aux.as_mut());
                        });
                        digest_cols
                            .writes_aux
                            .set_prev_data([F::ZERO; SHA256_WRITE_SIZE]);
                        mem_helper.fill_zero(digest_cols.writes_aux.as_mut());
                    }
                    digest_cols.inner.flags.is_last_block = F::from_bool(is_last_block);
                    digest_cols.inner.flags.is_digest_row = F::from_bool(true);
                } else {
                    let round_cols: &mut Sha256VmRoundCols<F> =
                        row_slice[..SHA256VM_ROUND_WIDTH].borrow_mut();
                    // The first `SHA256_NUM_READ_ROWS` rows of each block read
                    // `SHA256_READ_SIZE` message bytes into the carry-or-buffer cells.
                    if row_idx < SHA256_NUM_READ_ROWS {
                        round_cols
                            .inner
                            .message_schedule
                            .carry_or_buffer
                            .as_flattened_mut()
                            .iter_mut()
                            .zip(
                                input[row_idx * SHA256_READ_SIZE..(row_idx + 1) * SHA256_READ_SIZE]
                                    .iter(),
                            )
                            .for_each(|(cell, data)| {
                                *cell = F::from_canonical_u8(*data);
                            });
                        mem_helper.fill(
                            read_aux_records[row_idx].prev_timestamp,
                            block_start_timestamp + row_idx as u32,
                            round_cols.read_aux.as_mut(),
                        );
                    } else {
                        mem_helper.fill_zero(round_cols.read_aux.as_mut());
                    }
                }
                // Control columns are filled on every row.
                let control_cols: &mut Sha256VmControlCols<F> =
                    row_slice[..SHA256VM_CONTROL_WIDTH].borrow_mut();
                control_cols.len = F::from_canonical_u32(record.len);
                control_cols.cur_timestamp = F::from_canonical_u32(
                    block_start_timestamp + min(row_idx, SHA256_NUM_READ_ROWS) as u32,
                );
                control_cols.read_ptr = F::from_canonical_u32(
                    block_start_read_ptr
                        + (SHA256_READ_SIZE * min(row_idx, SHA256_NUM_READ_ROWS)) as u32,
                );

                // Padding flags are only meaningful on the read rows.
                if row_idx < SHA256_NUM_READ_ROWS {
                    #[allow(clippy::comparison_chain)]
                    if (row_idx as i32) < first_padding_row {
                        control_cols.pad_flags = get_flag_pt_array(
                            &self.padding_encoder,
                            PaddingFlags::NotPadding as usize,
                        )
                        .map(F::from_canonical_u32);
                    } else if row_idx as i32 == first_padding_row {
                        let len = message_left - row_idx * SHA256_READ_SIZE;
                        control_cols.pad_flags = get_flag_pt_array(
                            &self.padding_encoder,
                            if row_idx == SHA256_NUM_READ_ROWS - 1 && is_last_block {
                                PaddingFlags::FirstPadding0_LastRow
                            } else {
                                PaddingFlags::FirstPadding0
                            } as usize
                                + len,
                        )
                        .map(F::from_canonical_u32);
                    } else {
                        control_cols.pad_flags = get_flag_pt_array(
                            &self.padding_encoder,
                            if row_idx == SHA256_NUM_READ_ROWS - 1 && is_last_block {
                                PaddingFlags::EntirePaddingLastRow
                            } else {
                                PaddingFlags::EntirePadding
                            } as usize,
                        )
                        .map(F::from_canonical_u32);
                    }
                } else {
                    control_cols.pad_flags = get_flag_pt_array(
                        &self.padding_encoder,
                        PaddingFlags::NotConsidered as usize,
                    )
                    .map(F::from_canonical_u32);
                }
                // `padding_occurred` resets on the digest row of the last block.
                if is_last_block && row_idx == SHA256_ROWS_PER_BLOCK - 1 {
                    control_cols.padding_occurred = F::ZERO;
                } else {
                    control_cols.padding_occurred =
                        F::from_bool((row_idx as i32) >= first_padding_row);
                }
            });

        self.inner.generate_block_trace::<F>(
            block_slice,
            SHA256VM_WIDTH,
            SHA256VM_CONTROL_WIDTH,
            &padded_input,
            self.bitwise_lookup_chip.as_ref(),
            &prev_hash,
            is_last_block,
            global_block_idx as u32 + 1,
            local_block_idx as u32,
        );
    }
}