openvm_sha256_circuit/extension/
mod.rs

1use std::{result::Result, sync::Arc};
2
3use derive_more::derive::From;
4use openvm_circuit::{
5    arch::{
6        AirInventory, AirInventoryError, ChipInventory, ChipInventoryError,
7        ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension,
8        VmExecutionExtension, VmProverExtension,
9    },
10    system::memory::SharedMemoryHelper,
11};
12use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor};
13use openvm_circuit_primitives::bitwise_op_lookup::{
14    BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip,
15    SharedBitwiseOperationLookupChip,
16};
17use openvm_instructions::*;
18use openvm_sha256_transpiler::Rv32Sha256Opcode;
19use openvm_stark_backend::{
20    config::{StarkGenericConfig, Val},
21    p3_field::PrimeField32,
22    prover::cpu::{CpuBackend, CpuDevice},
23};
24use openvm_stark_sdk::engine::StarkEngine;
25use serde::{Deserialize, Serialize};
26use strum::IntoEnumIterator;
27
28use crate::*;
29
30cfg_if::cfg_if! {
31    if #[cfg(feature = "cuda")] {
32        mod cuda;
33        pub use self::cuda::*;
34        pub use self::cuda::Sha256GpuProverExt as Sha256ProverExt;
35    } else {
36        pub use self::Sha2CpuProverExt as Sha256ProverExt;
37    }
38}
39
40// =================================== VM Extension Implementation =================================
41#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
42pub struct Sha256;
43
44#[derive(Clone, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)]
45pub enum Sha256Executor {
46    Sha256(Sha256VmExecutor),
47}
48
49impl<F> VmExecutionExtension<F> for Sha256 {
50    type Executor = Sha256Executor;
51
52    fn extend_execution(
53        &self,
54        inventory: &mut ExecutorInventoryBuilder<F, Sha256Executor>,
55    ) -> Result<(), ExecutorInventoryError> {
56        let pointer_max_bits = inventory.pointer_max_bits();
57        let sha256_step = Sha256VmExecutor::new(Rv32Sha256Opcode::CLASS_OFFSET, pointer_max_bits);
58        inventory.add_executor(
59            sha256_step,
60            Rv32Sha256Opcode::iter().map(|x| x.global_opcode()),
61        )?;
62
63        Ok(())
64    }
65}
66
67impl<SC: StarkGenericConfig> VmCircuitExtension<SC> for Sha256 {
68    fn extend_circuit(&self, inventory: &mut AirInventory<SC>) -> Result<(), AirInventoryError> {
69        let pointer_max_bits = inventory.pointer_max_bits();
70
71        let bitwise_lu = {
72            let existing_air = inventory.find_air::<BitwiseOperationLookupAir<8>>().next();
73            if let Some(air) = existing_air {
74                air.bus
75            } else {
76                let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx());
77                let air = BitwiseOperationLookupAir::<8>::new(bus);
78                inventory.add_air(air);
79                air.bus
80            }
81        };
82
83        let sha256 = Sha256VmAir::new(
84            inventory.system().port(),
85            bitwise_lu,
86            pointer_max_bits,
87            inventory.new_bus_idx(),
88        );
89        inventory.add_air(sha256);
90
91        Ok(())
92    }
93}
94
/// Unit type providing the `CpuBackend` prover extension for the [`Sha256`] VM extension.
pub struct Sha2CpuProverExt;
96// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker,
97// BitwiseOperationLookupChip) are specific to CpuBackend.
98impl<E, SC, RA> VmProverExtension<E, RA, Sha256> for Sha2CpuProverExt
99where
100    SC: StarkGenericConfig,
101    E: StarkEngine<SC = SC, PB = CpuBackend<SC>, PD = CpuDevice<SC>>,
102    RA: RowMajorMatrixArena<Val<SC>>,
103    Val<SC>: PrimeField32,
104{
105    fn extend_prover(
106        &self,
107        _: &Sha256,
108        inventory: &mut ChipInventory<SC, RA, CpuBackend<SC>>,
109    ) -> Result<(), ChipInventoryError> {
110        let range_checker = inventory.range_checker()?.clone();
111        let timestamp_max_bits = inventory.timestamp_max_bits();
112        let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
113        let pointer_max_bits = inventory.airs().pointer_max_bits();
114
115        let bitwise_lu = {
116            let existing_chip = inventory
117                .find_chip::<SharedBitwiseOperationLookupChip<8>>()
118                .next();
119            if let Some(chip) = existing_chip {
120                chip.clone()
121            } else {
122                let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?;
123                let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus));
124                inventory.add_periphery_chip(chip.clone());
125                chip
126            }
127        };
128
129        inventory.next_air::<Sha256VmAir>()?;
130        let sha256 = Sha256VmChip::new(
131            Sha256VmFiller::new(bitwise_lu, pointer_max_bits),
132            mem_helper,
133        );
134        inventory.add_executor_chip(sha256);
135
136        Ok(())
137    }
138}