// openvm_stark_sdk/config/instrument.rs

1use std::{
2    any::type_name,
3    collections::HashMap,
4    sync::{Arc, Mutex},
5};
6
7use p3_symmetric::{
8    CryptographicHasher, CryptographicPermutation, Permutation, PseudoCompressionFunction,
9};
10use serde::{Deserialize, Serialize};
11
12use super::FriParameters;
13
/// Shared, thread-safe record of input lengths keyed by type name.
///
/// Each key maps to the list of lengths recorded for that type, one entry per
/// recorded call. Cloning the `Arc` shares the same underlying map.
pub type InstrumentCounter = Arc<Mutex<HashMap<String, Vec<usize>>>>;

16/// Wrapper to instrument a type to count function calls.
17/// CAUTION: Performance may be impacted.
18#[derive(Clone, Debug)]
19pub struct Instrumented<T> {
20    pub is_on: bool,
21    pub inner: T,
22    pub input_lens_by_type: InstrumentCounter,
23}
24
25impl<T> Instrumented<T> {
26    pub fn new(inner: T) -> Self {
27        Self {
28            is_on: true,
29            inner,
30            input_lens_by_type: Arc::new(Mutex::new(HashMap::new())),
31        }
32    }
33
34    fn add_len_for_type<A>(&self, len: usize) {
35        if !self.is_on {
36            return;
37        }
38        self.input_lens_by_type
39            .lock()
40            .unwrap()
41            .entry(type_name::<A>().to_string())
42            .and_modify(|lens| lens.push(len))
43            .or_insert(vec![len]);
44    }
45}
46
47impl<T: Clone, P: Permutation<T>> Permutation<T> for Instrumented<P> {
48    fn permute_mut(&self, input: &mut T) {
49        self.add_len_for_type::<T>(1);
50        self.inner.permute_mut(input);
51    }
52    fn permute(&self, input: T) -> T {
53        self.add_len_for_type::<T>(1);
54        self.inner.permute(input)
55    }
56}
57
58impl<T: Clone, P: CryptographicPermutation<T>> CryptographicPermutation<T> for Instrumented<P> {}
59
60// Note: this does not currently need to be used if the implemeation is derived from a CryptographicPermutation:
61// we can instrument the permutation itself
62impl<T, const N: usize, C: PseudoCompressionFunction<T, N>> PseudoCompressionFunction<T, N>
63    for Instrumented<C>
64{
65    fn compress(&self, input: [T; N]) -> T {
66        self.add_len_for_type::<T>(N);
67        self.inner.compress(input)
68    }
69}
70
71impl<Item: Clone, Out, H: CryptographicHasher<Item, Out>> CryptographicHasher<Item, Out>
72    for Instrumented<H>
73{
74    fn hash_iter<I>(&self, input: I) -> Out
75    where
76        I: IntoIterator<Item = Item>,
77    {
78        if self.is_on {
79            let input = input.into_iter().collect::<Vec<_>>();
80            self.add_len_for_type::<(Item, Out)>(input.len());
81            self.inner.hash_iter(input)
82        } else {
83            self.inner.hash_iter(input)
84        }
85    }
86}
87
88#[derive(Clone, Debug, Serialize, Deserialize)]
89pub struct HashStatistics {
90    // pub cryptographic_hasher: usize,
91    // pub pseudo_compression_function: usize,
92    pub permutations: usize,
93}
94
95#[derive(Clone, Debug, Serialize, Deserialize)]
96pub struct StarkHashStatistics<T> {
97    /// Identifier for the hash permutation
98    pub name: String,
99    pub stats: HashStatistics,
100    pub fri_params: FriParameters,
101    pub custom: T,
102}