openvm_keccak256_circuit/
execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9    instruction::Instruction,
10    program::DEFAULT_PC_STEP,
11    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
12    LocalOpcode,
13};
14use openvm_keccak256_transpiler::Rv32KeccakOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16use p3_keccak_air::NUM_ROUNDS;
17
18use super::{KeccakVmExecutor, KECCAK_WORD_SIZE};
19use crate::utils::{keccak256, num_keccak_f};
20
21#[derive(AlignedBytesBorrow, Clone)]
22#[repr(C)]
23struct KeccakPreCompute {
24    a: u8,
25    b: u8,
26    c: u8,
27}
28
29impl KeccakVmExecutor {
30    fn pre_compute_impl<F: PrimeField32>(
31        &self,
32        pc: u32,
33        inst: &Instruction<F>,
34        data: &mut KeccakPreCompute,
35    ) -> Result<(), StaticProgramError> {
36        let Instruction {
37            opcode,
38            a,
39            b,
40            c,
41            d,
42            e,
43            ..
44        } = inst;
45        let e_u32 = e.as_canonical_u32();
46        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
47            return Err(StaticProgramError::InvalidInstruction(pc));
48        }
49        *data = KeccakPreCompute {
50            a: a.as_canonical_u32() as u8,
51            b: b.as_canonical_u32() as u8,
52            c: c.as_canonical_u32() as u8,
53        };
54        assert_eq!(&Rv32KeccakOpcode::KECCAK256.global_opcode(), opcode);
55        Ok(())
56    }
57}
58
59impl<F: PrimeField32> InterpreterExecutor<F> for KeccakVmExecutor {
60    fn pre_compute_size(&self) -> usize {
61        size_of::<KeccakPreCompute>()
62    }
63
64    #[cfg(not(feature = "tco"))]
65    fn pre_compute<Ctx>(
66        &self,
67        pc: u32,
68        inst: &Instruction<F>,
69        data: &mut [u8],
70    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
71    where
72        Ctx: ExecutionCtxTrait,
73    {
74        let data: &mut KeccakPreCompute = data.borrow_mut();
75        self.pre_compute_impl(pc, inst, data)?;
76        Ok(execute_e1_impl::<_, _>)
77    }
78
79    #[cfg(feature = "tco")]
80    fn handler<Ctx>(
81        &self,
82        pc: u32,
83        inst: &Instruction<F>,
84        data: &mut [u8],
85    ) -> Result<Handler<F, Ctx>, StaticProgramError>
86    where
87        Ctx: ExecutionCtxTrait,
88    {
89        let data: &mut KeccakPreCompute = data.borrow_mut();
90        self.pre_compute_impl(pc, inst, data)?;
91        Ok(execute_e1_handler)
92    }
93}
94
/// Ahead-of-time pure execution for `KECCAK256`.
///
/// The impl body is empty, so every method comes from the trait's defaults.
#[cfg(feature = "aot")]
impl<F> AotExecutor<F> for KeccakVmExecutor
where
    F: PrimeField32,
{
}
97
98impl<F: PrimeField32> InterpreterMeteredExecutor<F> for KeccakVmExecutor {
99    fn metered_pre_compute_size(&self) -> usize {
100        size_of::<E2PreCompute<KeccakPreCompute>>()
101    }
102
103    #[cfg(not(feature = "tco"))]
104    fn metered_pre_compute<Ctx>(
105        &self,
106        chip_idx: usize,
107        pc: u32,
108        inst: &Instruction<F>,
109        data: &mut [u8],
110    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
111    where
112        Ctx: MeteredExecutionCtxTrait,
113    {
114        let data: &mut E2PreCompute<KeccakPreCompute> = data.borrow_mut();
115        data.chip_idx = chip_idx as u32;
116        self.pre_compute_impl(pc, inst, &mut data.data)?;
117        Ok(execute_e2_impl::<_, _>)
118    }
119
120    #[cfg(feature = "tco")]
121    fn metered_handler<Ctx>(
122        &self,
123        chip_idx: usize,
124        pc: u32,
125        inst: &Instruction<F>,
126        data: &mut [u8],
127    ) -> Result<Handler<F, Ctx>, StaticProgramError>
128    where
129        Ctx: MeteredExecutionCtxTrait,
130    {
131        let data: &mut E2PreCompute<KeccakPreCompute> = data.borrow_mut();
132        data.chip_idx = chip_idx as u32;
133        self.pre_compute_impl(pc, inst, &mut data.data)?;
134        Ok(execute_e2_handler::<_, _>)
135    }
136}
/// Ahead-of-time metered execution for `KECCAK256`.
///
/// The impl body is empty, so every method comes from the trait's defaults.
#[cfg(feature = "aot")]
impl<F> AotMeteredExecutor<F> for KeccakVmExecutor
where
    F: PrimeField32,
{
}
139
140#[inline(always)]
141unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_E1: bool>(
142    pre_compute: &KeccakPreCompute,
143    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
144) -> u32 {
145    let dst = exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32);
146    let src = exec_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32);
147    let len = exec_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32);
148    let dst_u32 = u32::from_le_bytes(dst);
149    let src_u32 = u32::from_le_bytes(src);
150    let len_u32 = u32::from_le_bytes(len);
151
152    let (output, height) = if IS_E1 {
153        // SAFETY: RV32_MEMORY_AS is memory address space of type u8
154        let message = exec_state.vm_read_slice(RV32_MEMORY_AS, src_u32, len_u32 as usize);
155        let output = keccak256(message);
156        (output, 0)
157    } else {
158        let num_reads = (len_u32 as usize).div_ceil(KECCAK_WORD_SIZE);
159        let message: Vec<_> = (0..num_reads)
160            .flat_map(|i| {
161                exec_state.vm_read::<u8, KECCAK_WORD_SIZE>(
162                    RV32_MEMORY_AS,
163                    src_u32 + (i * KECCAK_WORD_SIZE) as u32,
164                )
165            })
166            .collect();
167        let output = keccak256(&message[..len_u32 as usize]);
168        let height = (num_keccak_f(len_u32 as usize) * NUM_ROUNDS) as u32;
169        (output, height)
170    };
171    exec_state.vm_write(RV32_MEMORY_AS, dst_u32, &output);
172
173    let pc = exec_state.pc();
174    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
175
176    height
177}
178
179#[create_handler]
180#[inline(always)]
181unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
182    pre_compute: *const u8,
183    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
184) {
185    let pre_compute: &KeccakPreCompute =
186        std::slice::from_raw_parts(pre_compute, size_of::<KeccakPreCompute>()).borrow();
187    execute_e12_impl::<F, CTX, true>(pre_compute, exec_state);
188}
189
190#[create_handler]
191#[inline(always)]
192unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait>(
193    pre_compute: *const u8,
194    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
195) {
196    let pre_compute: &E2PreCompute<KeccakPreCompute> =
197        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<KeccakPreCompute>>())
198            .borrow();
199    let height = execute_e12_impl::<F, CTX, false>(&pre_compute.data, exec_state);
200    exec_state
201        .ctx
202        .on_height_change(pre_compute.chip_idx as usize, height);
203}