// openvm_sha256_circuit/sha256_chip/execution.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
4use openvm_circuit_primitives::AlignedBytesBorrow;
5use openvm_instructions::{
6    instruction::Instruction,
7    program::DEFAULT_PC_STEP,
8    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
9    LocalOpcode,
10};
11use openvm_sha256_air::{get_sha256_num_blocks, SHA256_ROWS_PER_BLOCK};
12use openvm_sha256_transpiler::Rv32Sha256Opcode;
13use openvm_stark_backend::p3_field::PrimeField32;
14
15use super::{sha256_solve, Sha256VmExecutor, SHA256_NUM_READ_ROWS, SHA256_READ_SIZE};
16
17#[derive(AlignedBytesBorrow, Clone)]
18#[repr(C)]
19struct ShaPreCompute {
20    a: u8,
21    b: u8,
22    c: u8,
23}
24
25impl<F: PrimeField32> Executor<F> for Sha256VmExecutor {
26    #[cfg(feature = "tco")]
27    fn handler<Ctx>(
28        &self,
29        pc: u32,
30        inst: &Instruction<F>,
31        data: &mut [u8],
32    ) -> Result<Handler<F, Ctx>, StaticProgramError>
33    where
34        Ctx: ExecutionCtxTrait,
35    {
36        let data: &mut ShaPreCompute = data.borrow_mut();
37        self.pre_compute_impl(pc, inst, data)?;
38        Ok(execute_e1_tco_handler::<_, _>)
39    }
40
41    fn pre_compute_size(&self) -> usize {
42        size_of::<ShaPreCompute>()
43    }
44
45    fn pre_compute<Ctx>(
46        &self,
47        pc: u32,
48        inst: &Instruction<F>,
49        data: &mut [u8],
50    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
51    where
52        Ctx: ExecutionCtxTrait,
53    {
54        let data: &mut ShaPreCompute = data.borrow_mut();
55        self.pre_compute_impl(pc, inst, data)?;
56        Ok(execute_e1_impl::<_, _>)
57    }
58}
59impl<F: PrimeField32> MeteredExecutor<F> for Sha256VmExecutor {
60    fn metered_pre_compute_size(&self) -> usize {
61        size_of::<E2PreCompute<ShaPreCompute>>()
62    }
63
64    fn metered_pre_compute<Ctx>(
65        &self,
66        chip_idx: usize,
67        pc: u32,
68        inst: &Instruction<F>,
69        data: &mut [u8],
70    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
71    where
72        Ctx: MeteredExecutionCtxTrait,
73    {
74        let data: &mut E2PreCompute<ShaPreCompute> = data.borrow_mut();
75        data.chip_idx = chip_idx as u32;
76        self.pre_compute_impl(pc, inst, &mut data.data)?;
77        Ok(execute_e2_impl::<_, _>)
78    }
79
80    #[cfg(feature = "tco")]
81    fn metered_handler<Ctx>(
82        &self,
83        chip_idx: usize,
84        pc: u32,
85        inst: &Instruction<F>,
86        data: &mut [u8],
87    ) -> Result<Handler<F, Ctx>, StaticProgramError>
88    where
89        Ctx: MeteredExecutionCtxTrait,
90    {
91        let data: &mut E2PreCompute<ShaPreCompute> = data.borrow_mut();
92        data.chip_idx = chip_idx as u32;
93        self.pre_compute_impl(pc, inst, &mut data.data)?;
94        Ok(execute_e2_tco_handler::<_, _>)
95    }
96}
97
98unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_E1: bool>(
99    pre_compute: &ShaPreCompute,
100    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
101) -> u32 {
102    let dst = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32);
103    let src = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32);
104    let len = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32);
105    let dst_u32 = u32::from_le_bytes(dst);
106    let src_u32 = u32::from_le_bytes(src);
107    let len_u32 = u32::from_le_bytes(len);
108
109    let (output, height) = if IS_E1 {
110        // SAFETY: RV32_MEMORY_AS is memory address space of type u8
111        let message = vm_state.vm_read_slice(RV32_MEMORY_AS, src_u32, len_u32 as usize);
112        let output = sha256_solve(message);
113        (output, 0)
114    } else {
115        let num_blocks = get_sha256_num_blocks(len_u32);
116        let mut message = Vec::with_capacity(len_u32 as usize);
117        for block_idx in 0..num_blocks as usize {
118            // Reads happen on the first 4 rows of each block
119            for row in 0..SHA256_NUM_READ_ROWS {
120                let read_idx = block_idx * SHA256_NUM_READ_ROWS + row;
121                let row_input: [u8; SHA256_READ_SIZE] = vm_state.vm_read(
122                    RV32_MEMORY_AS,
123                    src_u32 + (read_idx * SHA256_READ_SIZE) as u32,
124                );
125                message.extend_from_slice(&row_input);
126            }
127        }
128        let output = sha256_solve(&message[..len_u32 as usize]);
129        let height = num_blocks * SHA256_ROWS_PER_BLOCK as u32;
130        (output, height)
131    };
132    vm_state.vm_write(RV32_MEMORY_AS, dst_u32, &output);
133
134    vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
135    vm_state.instret += 1;
136
137    height
138}
139
140#[create_tco_handler]
141unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
142    pre_compute: &[u8],
143    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
144) {
145    let pre_compute: &ShaPreCompute = pre_compute.borrow();
146    execute_e12_impl::<F, CTX, true>(pre_compute, vm_state);
147}
148#[create_tco_handler]
149unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait>(
150    pre_compute: &[u8],
151    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
152) {
153    let pre_compute: &E2PreCompute<ShaPreCompute> = pre_compute.borrow();
154    let height = execute_e12_impl::<F, CTX, false>(&pre_compute.data, vm_state);
155    vm_state
156        .ctx
157        .on_height_change(pre_compute.chip_idx as usize, height);
158}
159
160impl Sha256VmExecutor {
161    fn pre_compute_impl<F: PrimeField32>(
162        &self,
163        pc: u32,
164        inst: &Instruction<F>,
165        data: &mut ShaPreCompute,
166    ) -> Result<(), StaticProgramError> {
167        let Instruction {
168            opcode,
169            a,
170            b,
171            c,
172            d,
173            e,
174            ..
175        } = inst;
176        let e_u32 = e.as_canonical_u32();
177        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
178            return Err(StaticProgramError::InvalidInstruction(pc));
179        }
180        *data = ShaPreCompute {
181            a: a.as_canonical_u32() as u8,
182            b: b.as_canonical_u32() as u8,
183            c: c.as_canonical_u32() as u8,
184        };
185        assert_eq!(&Rv32Sha256Opcode::SHA256.global_opcode(), opcode);
186        Ok(())
187    }
188}