// openvm_sha256_circuit/sha256_chip/execution.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
4use openvm_circuit_primitives::AlignedBytesBorrow;
5use openvm_instructions::{
6    instruction::Instruction,
7    program::DEFAULT_PC_STEP,
8    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
9    LocalOpcode,
10};
11use openvm_sha256_air::{get_sha256_num_blocks, SHA256_ROWS_PER_BLOCK};
12use openvm_sha256_transpiler::Rv32Sha256Opcode;
13use openvm_stark_backend::p3_field::PrimeField32;
14
15use super::{sha256_solve, Sha256VmExecutor, SHA256_NUM_READ_ROWS, SHA256_READ_SIZE};
16
/// Operands of a SHA-256 instruction, decoded once by `pre_compute_impl` and stored in the
/// per-instruction pre-compute buffer.
///
/// Each field is an offset into `RV32_REGISTER_AS`; `execute_e12_impl` reads a 4-byte
/// little-endian word from each of these registers.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct ShaPreCompute {
    /// Register holding the destination address (in `RV32_MEMORY_AS`) the digest is written to.
    a: u8,
    /// Register holding the source address (in `RV32_MEMORY_AS`) of the message to hash.
    b: u8,
    /// Register holding the message length in bytes.
    c: u8,
}
24
25impl<F: PrimeField32> InterpreterExecutor<F> for Sha256VmExecutor {
26    #[cfg(feature = "tco")]
27    fn handler<Ctx>(
28        &self,
29        pc: u32,
30        inst: &Instruction<F>,
31        data: &mut [u8],
32    ) -> Result<Handler<F, Ctx>, StaticProgramError>
33    where
34        Ctx: ExecutionCtxTrait,
35    {
36        let data: &mut ShaPreCompute = data.borrow_mut();
37        self.pre_compute_impl(pc, inst, data)?;
38        Ok(execute_e1_handler::<_, _>)
39    }
40
41    fn pre_compute_size(&self) -> usize {
42        size_of::<ShaPreCompute>()
43    }
44
45    #[cfg(not(feature = "tco"))]
46    fn pre_compute<Ctx>(
47        &self,
48        pc: u32,
49        inst: &Instruction<F>,
50        data: &mut [u8],
51    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
52    where
53        Ctx: ExecutionCtxTrait,
54    {
55        let data: &mut ShaPreCompute = data.borrow_mut();
56        self.pre_compute_impl(pc, inst, data)?;
57        Ok(execute_e1_impl::<_, _>)
58    }
59}
60
/// Implements `AotExecutor` for the SHA-256 executor when the `aot` feature is enabled.
/// The body is empty, so every method comes from the trait's provided defaults.
#[cfg(feature = "aot")]
impl<F> AotExecutor<F> for Sha256VmExecutor
where
    F: PrimeField32,
{
}
63
64impl<F: PrimeField32> InterpreterMeteredExecutor<F> for Sha256VmExecutor {
65    fn metered_pre_compute_size(&self) -> usize {
66        size_of::<E2PreCompute<ShaPreCompute>>()
67    }
68
69    #[cfg(not(feature = "tco"))]
70    fn metered_pre_compute<Ctx>(
71        &self,
72        chip_idx: usize,
73        pc: u32,
74        inst: &Instruction<F>,
75        data: &mut [u8],
76    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
77    where
78        Ctx: MeteredExecutionCtxTrait,
79    {
80        let data: &mut E2PreCompute<ShaPreCompute> = data.borrow_mut();
81        data.chip_idx = chip_idx as u32;
82        self.pre_compute_impl(pc, inst, &mut data.data)?;
83        Ok(execute_e2_impl::<_, _>)
84    }
85
86    #[cfg(feature = "tco")]
87    fn metered_handler<Ctx>(
88        &self,
89        chip_idx: usize,
90        pc: u32,
91        inst: &Instruction<F>,
92        data: &mut [u8],
93    ) -> Result<Handler<F, Ctx>, StaticProgramError>
94    where
95        Ctx: MeteredExecutionCtxTrait,
96    {
97        let data: &mut E2PreCompute<ShaPreCompute> = data.borrow_mut();
98        data.chip_idx = chip_idx as u32;
99        self.pre_compute_impl(pc, inst, &mut data.data)?;
100        Ok(execute_e2_handler::<_, _>)
101    }
102}
103
/// Implements `AotMeteredExecutor` for the SHA-256 executor when the `aot` feature is
/// enabled. The body is empty, so every method comes from the trait's provided defaults.
#[cfg(feature = "aot")]
impl<F> AotMeteredExecutor<F> for Sha256VmExecutor
where
    F: PrimeField32,
{
}
106
107#[inline(always)]
108unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_E1: bool>(
109    pre_compute: &ShaPreCompute,
110    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
111) -> u32 {
112    let dst = exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32);
113    let src = exec_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32);
114    let len = exec_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32);
115    let dst_u32 = u32::from_le_bytes(dst);
116    let src_u32 = u32::from_le_bytes(src);
117    let len_u32 = u32::from_le_bytes(len);
118
119    let (output, height) = if IS_E1 {
120        // SAFETY: RV32_MEMORY_AS is memory address space of type u8
121        let message = exec_state.vm_read_slice(RV32_MEMORY_AS, src_u32, len_u32 as usize);
122        let output = sha256_solve(message);
123        (output, 0)
124    } else {
125        let num_blocks = get_sha256_num_blocks(len_u32);
126        let mut message = Vec::with_capacity(len_u32 as usize);
127        for block_idx in 0..num_blocks as usize {
128            // Reads happen on the first 4 rows of each block
129            for row in 0..SHA256_NUM_READ_ROWS {
130                let read_idx = block_idx * SHA256_NUM_READ_ROWS + row;
131                let row_input: [u8; SHA256_READ_SIZE] = exec_state.vm_read(
132                    RV32_MEMORY_AS,
133                    src_u32 + (read_idx * SHA256_READ_SIZE) as u32,
134                );
135                message.extend_from_slice(&row_input);
136            }
137        }
138        let output = sha256_solve(&message[..len_u32 as usize]);
139        let height = num_blocks * SHA256_ROWS_PER_BLOCK as u32;
140        (output, height)
141    };
142    exec_state.vm_write(RV32_MEMORY_AS, dst_u32, &output);
143
144    let pc = exec_state.pc();
145    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
146
147    height
148}
149
150#[create_handler]
151#[inline(always)]
152unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
153    pre_compute: *const u8,
154    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
155) {
156    let pre_compute: &ShaPreCompute =
157        std::slice::from_raw_parts(pre_compute, size_of::<ShaPreCompute>()).borrow();
158    execute_e12_impl::<F, CTX, true>(pre_compute, exec_state);
159}
160
161#[create_handler]
162#[inline(always)]
163unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait>(
164    pre_compute: *const u8,
165    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
166) {
167    let pre_compute: &E2PreCompute<ShaPreCompute> =
168        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<ShaPreCompute>>()).borrow();
169    let height = execute_e12_impl::<F, CTX, false>(&pre_compute.data, exec_state);
170    exec_state
171        .ctx
172        .on_height_change(pre_compute.chip_idx as usize, height);
173}
174
175impl Sha256VmExecutor {
176    fn pre_compute_impl<F: PrimeField32>(
177        &self,
178        pc: u32,
179        inst: &Instruction<F>,
180        data: &mut ShaPreCompute,
181    ) -> Result<(), StaticProgramError> {
182        let Instruction {
183            opcode,
184            a,
185            b,
186            c,
187            d,
188            e,
189            ..
190        } = inst;
191        let e_u32 = e.as_canonical_u32();
192        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
193            return Err(StaticProgramError::InvalidInstruction(pc));
194        }
195        *data = ShaPreCompute {
196            a: a.as_canonical_u32() as u8,
197            b: b.as_canonical_u32() as u8,
198            c: c.as_canonical_u32() as u8,
199        };
200        assert_eq!(&Rv32Sha256Opcode::SHA256.global_opcode(), opcode);
201        Ok(())
202    }
203}