openvm_keccak256_circuit/
execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9    instruction::Instruction,
10    program::DEFAULT_PC_STEP,
11    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
12    LocalOpcode,
13};
14use openvm_keccak256_transpiler::Rv32KeccakOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16use p3_keccak_air::NUM_ROUNDS;
17
18use super::{KeccakVmExecutor, KECCAK_WORD_SIZE};
19use crate::utils::{keccak256, num_keccak_f};
20
/// Operands of a `KECCAK256` instruction, decoded once before execution.
///
/// Every field is a register pointer in `RV32_REGISTER_AS`. At execution time:
/// - the register at `a` holds the destination address;
/// - the register at `b` holds the source address;
/// - the register at `c` holds the message length in bytes.
///
/// `#[repr(C)]` keeps the field order and layout fixed. The pre-compute byte buffers
/// (`data.borrow_mut()` on `&mut [u8]`) presumably depend on this — do not reorder fields.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct KeccakPreCompute {
    /// Register pointer whose value is the destination address for the digest.
    a: u8,
    /// Register pointer whose value is the source address of the message.
    b: u8,
    /// Register pointer whose value is the message length in bytes.
    c: u8,
}
28
29impl KeccakVmExecutor {
30    fn pre_compute_impl<F: PrimeField32>(
31        &self,
32        pc: u32,
33        inst: &Instruction<F>,
34        data: &mut KeccakPreCompute,
35    ) -> Result<(), StaticProgramError> {
36        let Instruction {
37            opcode,
38            a,
39            b,
40            c,
41            d,
42            e,
43            ..
44        } = inst;
45        let e_u32 = e.as_canonical_u32();
46        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
47            return Err(StaticProgramError::InvalidInstruction(pc));
48        }
49        *data = KeccakPreCompute {
50            a: a.as_canonical_u32() as u8,
51            b: b.as_canonical_u32() as u8,
52            c: c.as_canonical_u32() as u8,
53        };
54        assert_eq!(&Rv32KeccakOpcode::KECCAK256.global_opcode(), opcode);
55        Ok(())
56    }
57}
58
59impl<F: PrimeField32> Executor<F> for KeccakVmExecutor {
60    fn pre_compute_size(&self) -> usize {
61        size_of::<KeccakPreCompute>()
62    }
63
64    fn pre_compute<Ctx>(
65        &self,
66        pc: u32,
67        inst: &Instruction<F>,
68        data: &mut [u8],
69    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
70    where
71        Ctx: ExecutionCtxTrait,
72    {
73        let data: &mut KeccakPreCompute = data.borrow_mut();
74        self.pre_compute_impl(pc, inst, data)?;
75        Ok(execute_e1_impl::<_, _>)
76    }
77
78    #[cfg(feature = "tco")]
79    fn handler<Ctx>(
80        &self,
81        pc: u32,
82        inst: &Instruction<F>,
83        data: &mut [u8],
84    ) -> Result<Handler<F, Ctx>, StaticProgramError>
85    where
86        Ctx: ExecutionCtxTrait,
87    {
88        let data: &mut KeccakPreCompute = data.borrow_mut();
89        self.pre_compute_impl(pc, inst, data)?;
90        Ok(execute_e1_tco_handler)
91    }
92}
93
94impl<F: PrimeField32> MeteredExecutor<F> for KeccakVmExecutor {
95    fn metered_pre_compute_size(&self) -> usize {
96        size_of::<E2PreCompute<KeccakPreCompute>>()
97    }
98
99    fn metered_pre_compute<Ctx>(
100        &self,
101        chip_idx: usize,
102        pc: u32,
103        inst: &Instruction<F>,
104        data: &mut [u8],
105    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
106    where
107        Ctx: MeteredExecutionCtxTrait,
108    {
109        let data: &mut E2PreCompute<KeccakPreCompute> = data.borrow_mut();
110        data.chip_idx = chip_idx as u32;
111        self.pre_compute_impl(pc, inst, &mut data.data)?;
112        Ok(execute_e2_impl::<_, _>)
113    }
114
115    #[cfg(feature = "tco")]
116    fn metered_handler<Ctx>(
117        &self,
118        chip_idx: usize,
119        pc: u32,
120        inst: &Instruction<F>,
121        data: &mut [u8],
122    ) -> Result<Handler<F, Ctx>, StaticProgramError>
123    where
124        Ctx: MeteredExecutionCtxTrait,
125    {
126        let data: &mut E2PreCompute<KeccakPreCompute> = data.borrow_mut();
127        data.chip_idx = chip_idx as u32;
128        self.pre_compute_impl(pc, inst, &mut data.data)?;
129        Ok(execute_e2_tco_handler::<_, _>)
130    }
131}
132
/// Shared body of the E1 (pure) and E2 (metered) Keccak-256 executors.
///
/// The function proceeds as follows:
/// 1. Reads the registers named by `pre_compute.a`, `.b` and `.c` from `RV32_REGISTER_AS`.
///    Each is decoded as a little-endian `u32` giving, in order, the destination pointer,
///    the source pointer and the message length.
/// 2. Hashes `len` bytes starting at `src` in `RV32_MEMORY_AS` with `keccak256`.
/// 3. Writes the resulting digest to `dst` in `RV32_MEMORY_AS`.
/// 4. Advances `pc` by `DEFAULT_PC_STEP` and increments `instret`.
///
/// Returns `0` when `IS_E1`. Otherwise it returns `num_keccak_f(len) * NUM_ROUNDS`, which
/// the metered caller in this file passes to `on_height_change` as the chip's height.
///
/// # Safety
/// NOTE(review): the contract is inherited from `vm_read` / `vm_read_slice` / `vm_write`,
/// which are not visible here. Presumably all register and memory accesses must be valid
/// for their address spaces — confirm against `VmExecState`.
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_E1: bool>(
    pre_compute: &KeccakPreCompute,
    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
) -> u32 {
    // Register reads happen in a fixed order (dst, src, len); keep it if refactoring.
    let dst = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32);
    let src = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32);
    let len = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32);
    let dst_u32 = u32::from_le_bytes(dst);
    let src_u32 = u32::from_le_bytes(src);
    let len_u32 = u32::from_le_bytes(len);

    let (output, height) = if IS_E1 {
        // E1: read exactly `len` bytes as one slice; no height is reported.
        // SAFETY: RV32_MEMORY_AS is memory address space of type u8
        let message = vm_state.vm_read_slice(RV32_MEMORY_AS, src_u32, len_u32 as usize);
        let output = keccak256(message);
        (output, 0)
    } else {
        // E2: read the message as ceil(len / KECCAK_WORD_SIZE) word-sized blocks.
        // This may read up to KECCAK_WORD_SIZE - 1 bytes past the message end; only the
        // first `len` bytes are hashed below. NOTE(review): presumably this mirrors how the
        // circuit accesses memory, so metering matches proving — confirm.
        let num_reads = (len_u32 as usize).div_ceil(KECCAK_WORD_SIZE);
        let message: Vec<_> = (0..num_reads)
            .flat_map(|i| {
                // NOTE(review): plain `u32 +` — panics on overflow in debug builds, wraps
                // in release builds.
                vm_state.vm_read::<u8, KECCAK_WORD_SIZE>(
                    RV32_MEMORY_AS,
                    src_u32 + (i * KECCAK_WORD_SIZE) as u32,
                )
            })
            .collect();
        let output = keccak256(&message[..len_u32 as usize]);
        let height = (num_keccak_f(len_u32 as usize) * NUM_ROUNDS) as u32;
        (output, height)
    };
    vm_state.vm_write(RV32_MEMORY_AS, dst_u32, &output);

    vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
    vm_state.instret += 1;

    height
}
171
172#[create_tco_handler]
173unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
174    pre_compute: &[u8],
175    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
176) {
177    let pre_compute: &KeccakPreCompute = pre_compute.borrow();
178    execute_e12_impl::<F, CTX, true>(pre_compute, vm_state);
179}
180
181#[create_tco_handler]
182unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait>(
183    pre_compute: &[u8],
184    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
185) {
186    let pre_compute: &E2PreCompute<KeccakPreCompute> = pre_compute.borrow();
187    let height = execute_e12_impl::<F, CTX, false>(&pre_compute.data, vm_state);
188    vm_state
189        .ctx
190        .on_height_change(pre_compute.chip_idx as usize, height);
191}