openvm_stark_sdk/config/
instrument.rs
1use std::{
2 any::type_name,
3 collections::HashMap,
4 sync::{Arc, Mutex},
5};
6
7use p3_symmetric::{
8 CryptographicHasher, CryptographicPermutation, Permutation, PseudoCompressionFunction,
9};
10use serde::{Deserialize, Serialize};
11
12use super::FriParameters;
13
/// Shared, thread-safe log of recorded input lengths.
///
/// Keys are type names as produced by [`std::any::type_name`]; each value lists
/// the recorded lengths in the order they were pushed.
pub type InstrumentCounter = Arc<Mutex<HashMap<String, Vec<usize>>>>;

/// Wrapper that forwards calls to `inner` and records the input length of each
/// call into `input_lens_by_type`.
///
/// Cloning an `Instrumented` clones the inner `Arc`, so every clone appends to
/// the same counter. Recording can be disabled by setting `is_on` to `false`.
#[derive(Clone, Debug)]
pub struct Instrumented<T> {
    /// When `false`, calls are forwarded to `inner` without recording anything.
    pub is_on: bool,
    /// The wrapped value that performs the actual work.
    pub inner: T,
    /// Recorded input lengths, keyed by the type name chosen by each trait impl.
    pub input_lens_by_type: InstrumentCounter,
}

impl<T> Instrumented<T> {
    /// Wraps `inner` with recording enabled and a fresh, empty counter.
    #[must_use]
    pub fn new(inner: T) -> Self {
        Self {
            is_on: true,
            inner,
            input_lens_by_type: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Appends `len` to the list recorded under `type_name::<A>()`.
    ///
    /// Does nothing when `is_on` is `false`.
    ///
    /// # Panics
    ///
    /// Panics if the counter's mutex has been poisoned.
    fn add_len_for_type<A>(&self, len: usize) {
        if !self.is_on {
            return;
        }
        // `or_default` allocates a new Vec only when the key is absent, instead
        // of constructing one eagerly on every call.
        self.input_lens_by_type
            .lock()
            .expect("instrument counter mutex poisoned")
            .entry(type_name::<A>().to_string())
            .or_default()
            .push(len);
    }
}
46
47impl<T: Clone, P: Permutation<T>> Permutation<T> for Instrumented<P> {
48 fn permute_mut(&self, input: &mut T) {
49 self.add_len_for_type::<T>(1);
50 self.inner.permute_mut(input);
51 }
52 fn permute(&self, input: T) -> T {
53 self.add_len_for_type::<T>(1);
54 self.inner.permute(input)
55 }
56}
57
58impl<T: Clone, P: CryptographicPermutation<T>> CryptographicPermutation<T> for Instrumented<P> {}
59
60impl<T, const N: usize, C: PseudoCompressionFunction<T, N>> PseudoCompressionFunction<T, N>
63 for Instrumented<C>
64{
65 fn compress(&self, input: [T; N]) -> T {
66 self.add_len_for_type::<T>(N);
67 self.inner.compress(input)
68 }
69}
70
71impl<Item: Clone, Out, H: CryptographicHasher<Item, Out>> CryptographicHasher<Item, Out>
72 for Instrumented<H>
73{
74 fn hash_iter<I>(&self, input: I) -> Out
75 where
76 I: IntoIterator<Item = Item>,
77 {
78 if self.is_on {
79 let input = input.into_iter().collect::<Vec<_>>();
80 self.add_len_for_type::<(Item, Out)>(input.len());
81 self.inner.hash_iter(input)
82 } else {
83 self.inner.hash_iter(input)
84 }
85 }
86}
87
88#[derive(Clone, Debug, Serialize, Deserialize)]
89pub struct HashStatistics {
90 pub permutations: usize,
93}
94
95#[derive(Clone, Debug, Serialize, Deserialize)]
96pub struct StarkHashStatistics<T> {
97 pub name: String,
99 pub stats: HashStatistics,
100 pub fri_params: FriParameters,
101 pub custom: T,
102}