openvm_sha256_circuit/sha256_chip/
mod.rs

//! SHA-256 hasher chip. Handles full SHA-256 hashing, including padding, of
//! variable-length inputs read from VM memory.
3use std::{
4    array,
5    cmp::{max, min},
6    sync::{Arc, Mutex},
7};
8
9use openvm_circuit::arch::{
10    ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, SystemPort,
11};
12use openvm_circuit_primitives::{
13    bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder,
14};
15use openvm_instructions::{
16    instruction::Instruction,
17    program::DEFAULT_PC_STEP,
18    riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS},
19    LocalOpcode,
20};
21use openvm_rv32im_circuit::adapters::read_rv32_register;
22use openvm_sha256_air::{Sha256Air, SHA256_BLOCK_BITS};
23use openvm_sha256_transpiler::Rv32Sha256Opcode;
24use openvm_stark_backend::{interaction::BusIndex, p3_field::PrimeField32};
25use serde::{Deserialize, Serialize};
26use sha2::{Digest, Sha256};
27
28mod air;
29mod columns;
30mod trace;
31
32pub use air::*;
33pub use columns::*;
34use openvm_circuit::system::memory::{MemoryController, OfflineMemory, RecordId};
35
36#[cfg(test)]
37mod tests;
38
39// ==== Constants for register/memory adapter ====
40/// Register reads to get dst, src, len
41const SHA256_REGISTER_READS: usize = 3;
42/// Number of cells to read in a single memory access
43const SHA256_READ_SIZE: usize = 16;
44/// Number of cells to write in a single memory access
45const SHA256_WRITE_SIZE: usize = 32;
46/// Number of rv32 cells read in a SHA256 block
47pub const SHA256_BLOCK_CELLS: usize = SHA256_BLOCK_BITS / RV32_CELL_BITS;
48/// Number of rows we will do a read on for each SHA256 block
49pub const SHA256_NUM_READ_ROWS: usize = SHA256_BLOCK_CELLS / SHA256_READ_SIZE;
50pub struct Sha256VmChip<F: PrimeField32> {
51    pub air: Sha256VmAir,
52    /// IO and memory data necessary for each opcode call
53    pub records: Vec<Sha256Record<F>>,
54    pub offline_memory: Arc<Mutex<OfflineMemory<F>>>,
55    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>,
56
57    offset: usize,
58}
59
60#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
61pub struct Sha256Record<F> {
62    pub from_state: ExecutionState<F>,
63    pub dst_read: RecordId,
64    pub src_read: RecordId,
65    pub len_read: RecordId,
66    pub input_records: Vec<[RecordId; SHA256_NUM_READ_ROWS]>,
67    pub input_message: Vec<[[u8; SHA256_READ_SIZE]; SHA256_NUM_READ_ROWS]>,
68    pub digest_write: RecordId,
69}
70
71impl<F: PrimeField32> Sha256VmChip<F> {
72    pub fn new(
73        SystemPort {
74            execution_bus,
75            program_bus,
76            memory_bridge,
77        }: SystemPort,
78        address_bits: usize,
79        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>,
80        self_bus_idx: BusIndex,
81        offset: usize,
82        offline_memory: Arc<Mutex<OfflineMemory<F>>>,
83    ) -> Self {
84        Self {
85            air: Sha256VmAir::new(
86                ExecutionBridge::new(execution_bus, program_bus),
87                memory_bridge,
88                bitwise_lookup_chip.bus(),
89                address_bits,
90                Sha256Air::new(bitwise_lookup_chip.bus(), self_bus_idx),
91                Encoder::new(PaddingFlags::COUNT, 2, false),
92            ),
93            bitwise_lookup_chip,
94            records: Vec::new(),
95            offset,
96            offline_memory,
97        }
98    }
99}
100
101impl<F: PrimeField32> InstructionExecutor<F> for Sha256VmChip<F> {
102    fn execute(
103        &mut self,
104        memory: &mut MemoryController<F>,
105        instruction: &Instruction<F>,
106        from_state: ExecutionState<u32>,
107    ) -> Result<ExecutionState<u32>, ExecutionError> {
108        let &Instruction {
109            opcode,
110            a,
111            b,
112            c,
113            d,
114            e,
115            ..
116        } = instruction;
117        let local_opcode = opcode.local_opcode_idx(self.offset);
118        debug_assert_eq!(local_opcode, Rv32Sha256Opcode::SHA256.local_usize());
119        debug_assert_eq!(d, F::from_canonical_u32(RV32_REGISTER_AS));
120        debug_assert_eq!(e, F::from_canonical_u32(RV32_MEMORY_AS));
121
122        debug_assert_eq!(from_state.timestamp, memory.timestamp());
123
124        let (dst_read, dst) = read_rv32_register(memory, d, a);
125        let (src_read, src) = read_rv32_register(memory, d, b);
126        let (len_read, len) = read_rv32_register(memory, d, c);
127
128        #[cfg(debug_assertions)]
129        {
130            assert!(dst < (1 << self.air.ptr_max_bits));
131            assert!(src < (1 << self.air.ptr_max_bits));
132            assert!(len < (1 << self.air.ptr_max_bits));
133        }
134
135        // need to pad with one 1 bit, 64 bits for the message length and then pad until the length
136        // is divisible by [SHA256_BLOCK_BITS]
137        let num_blocks = ((len << 3) as usize + 1 + 64).div_ceil(SHA256_BLOCK_BITS);
138
139        // we will read [num_blocks] * [SHA256_BLOCK_CELLS] cells but only [len] cells will be used
140        debug_assert!(
141            src as usize + num_blocks * SHA256_BLOCK_CELLS <= (1 << self.air.ptr_max_bits)
142        );
143        let mut hasher = Sha256::new();
144        let mut input_records = Vec::with_capacity(num_blocks * SHA256_NUM_READ_ROWS);
145        let mut input_message = Vec::with_capacity(num_blocks * SHA256_NUM_READ_ROWS);
146        let mut read_ptr = src;
147        for _ in 0..num_blocks {
148            let block_reads_records = array::from_fn(|i| {
149                memory.read(
150                    e,
151                    F::from_canonical_u32(read_ptr + (i * SHA256_READ_SIZE) as u32),
152                )
153            });
154            let block_reads_bytes = array::from_fn(|i| {
155                // we add to the hasher only the bytes that are part of the message
156                let num_reads = min(
157                    SHA256_READ_SIZE,
158                    (max(read_ptr, src + len) - read_ptr) as usize,
159                );
160                let row_input = block_reads_records[i]
161                    .1
162                    .map(|x| x.as_canonical_u32().try_into().unwrap());
163                hasher.update(&row_input[..num_reads]);
164                read_ptr += SHA256_READ_SIZE as u32;
165                row_input
166            });
167            input_records.push(block_reads_records.map(|x| x.0));
168            input_message.push(block_reads_bytes);
169        }
170
171        let mut digest = [0u8; SHA256_WRITE_SIZE];
172        digest.copy_from_slice(hasher.finalize().as_ref());
173        let (digest_write, _) = memory.write(
174            e,
175            F::from_canonical_u32(dst),
176            digest.map(|b| F::from_canonical_u8(b)),
177        );
178
179        self.records.push(Sha256Record {
180            from_state: from_state.map(F::from_canonical_u32),
181            dst_read,
182            src_read,
183            len_read,
184            input_records,
185            input_message,
186            digest_write,
187        });
188
189        Ok(ExecutionState {
190            pc: from_state.pc + DEFAULT_PC_STEP,
191            timestamp: memory.timestamp(),
192        })
193    }
194
195    fn get_opcode_name(&self, _: usize) -> String {
196        "SHA256".to_string()
197    }
198}
199
200pub fn sha256_solve(input_message: &[u8]) -> [u8; SHA256_WRITE_SIZE] {
201    let mut hasher = Sha256::new();
202    hasher.update(input_message);
203    let mut output = [0u8; SHA256_WRITE_SIZE];
204    output.copy_from_slice(hasher.finalize().as_ref());
205    output
206}