1use super::{DatabaseCommit, DatabaseRef, EmptyDB};
2use crate::primitives::{
3 hash_map::Entry, Account, AccountInfo, Address, Bytecode, HashMap, Log, B256, KECCAK_EMPTY,
4 U256,
5};
6use crate::Database;
7use core::convert::Infallible;
8use std::vec::Vec;
9
10pub type InMemoryDB = CacheDB<EmptyDB>;
12
13#[derive(Debug, Clone)]
21#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
22pub struct CacheDB<ExtDB> {
23 pub accounts: HashMap<Address, DbAccount>,
26 pub contracts: HashMap<B256, Bytecode>,
28 pub logs: Vec<Log>,
30 pub block_hashes: HashMap<U256, B256>,
32 pub db: ExtDB,
36}
37
38impl<ExtDB: Default> Default for CacheDB<ExtDB> {
39 fn default() -> Self {
40 Self::new(ExtDB::default())
41 }
42}
43
44impl<ExtDB> CacheDB<ExtDB> {
45 pub fn new(db: ExtDB) -> Self {
46 let mut contracts = HashMap::default();
47 contracts.insert(KECCAK_EMPTY, Bytecode::default());
48 contracts.insert(B256::ZERO, Bytecode::default());
49 Self {
50 accounts: HashMap::default(),
51 contracts,
52 logs: Vec::default(),
53 block_hashes: HashMap::default(),
54 db,
55 }
56 }
57
58 pub fn insert_contract(&mut self, account: &mut AccountInfo) {
64 if let Some(code) = &account.code {
65 if !code.is_empty() {
66 if account.code_hash == KECCAK_EMPTY {
67 account.code_hash = code.hash_slow();
68 }
69 self.contracts
70 .entry(account.code_hash)
71 .or_insert_with(|| code.clone());
72 }
73 }
74 if account.code_hash.is_zero() {
75 account.code_hash = KECCAK_EMPTY;
76 }
77 }
78
79 pub fn insert_account_info(&mut self, address: Address, mut info: AccountInfo) {
81 self.insert_contract(&mut info);
82 self.accounts.entry(address).or_default().info = info;
83 }
84}
85
86impl<ExtDB: DatabaseRef> CacheDB<ExtDB> {
87 pub fn load_account(&mut self, address: Address) -> Result<&mut DbAccount, ExtDB::Error> {
91 let db = &self.db;
92 match self.accounts.entry(address) {
93 Entry::Occupied(entry) => Ok(entry.into_mut()),
94 Entry::Vacant(entry) => Ok(entry.insert(
95 db.basic_ref(address)?
96 .map(|info| DbAccount {
97 info,
98 ..Default::default()
99 })
100 .unwrap_or_else(DbAccount::new_not_existing),
101 )),
102 }
103 }
104
105 pub fn insert_account_storage(
107 &mut self,
108 address: Address,
109 slot: U256,
110 value: U256,
111 ) -> Result<(), ExtDB::Error> {
112 let account = self.load_account(address)?;
113 account.storage.insert(slot, value);
114 Ok(())
115 }
116
117 pub fn replace_account_storage(
119 &mut self,
120 address: Address,
121 storage: HashMap<U256, U256>,
122 ) -> Result<(), ExtDB::Error> {
123 let account = self.load_account(address)?;
124 account.account_state = AccountState::StorageCleared;
125 account.storage = storage.into_iter().collect();
126 Ok(())
127 }
128}
129
130impl<ExtDB> DatabaseCommit for CacheDB<ExtDB> {
131 fn commit(&mut self, changes: HashMap<Address, Account>) {
132 for (address, mut account) in changes {
133 if !account.is_touched() {
134 continue;
135 }
136 if account.is_selfdestructed() {
137 let db_account = self.accounts.entry(address).or_default();
138 db_account.storage.clear();
139 db_account.account_state = AccountState::NotExisting;
140 db_account.info = AccountInfo::default();
141 continue;
142 }
143 let is_newly_created = account.is_created();
144 self.insert_contract(&mut account.info);
145
146 let db_account = self.accounts.entry(address).or_default();
147 db_account.info = account.info;
148
149 db_account.account_state = if is_newly_created {
150 db_account.storage.clear();
151 AccountState::StorageCleared
152 } else if db_account.account_state.is_storage_cleared() {
153 AccountState::StorageCleared
155 } else {
156 AccountState::Touched
157 };
158 db_account.storage.extend(
159 account
160 .storage
161 .into_iter()
162 .map(|(key, value)| (key, value.present_value())),
163 );
164 }
165 }
166}
167
168impl<ExtDB: DatabaseRef> Database for CacheDB<ExtDB> {
169 type Error = ExtDB::Error;
170
171 fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
172 let basic = match self.accounts.entry(address) {
173 Entry::Occupied(entry) => entry.into_mut(),
174 Entry::Vacant(entry) => entry.insert(
175 self.db
176 .basic_ref(address)?
177 .map(|info| DbAccount {
178 info,
179 ..Default::default()
180 })
181 .unwrap_or_else(DbAccount::new_not_existing),
182 ),
183 };
184 Ok(basic.info())
185 }
186
187 fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
188 match self.contracts.entry(code_hash) {
189 Entry::Occupied(entry) => Ok(entry.get().clone()),
190 Entry::Vacant(entry) => {
191 Ok(entry.insert(self.db.code_by_hash_ref(code_hash)?).clone())
193 }
194 }
195 }
196
197 fn storage(&mut self, address: Address, index: U256) -> Result<U256, Self::Error> {
201 match self.accounts.entry(address) {
202 Entry::Occupied(mut acc_entry) => {
203 let acc_entry = acc_entry.get_mut();
204 match acc_entry.storage.entry(index) {
205 Entry::Occupied(entry) => Ok(*entry.get()),
206 Entry::Vacant(entry) => {
207 if matches!(
208 acc_entry.account_state,
209 AccountState::StorageCleared | AccountState::NotExisting
210 ) {
211 Ok(U256::ZERO)
212 } else {
213 let slot = self.db.storage_ref(address, index)?;
214 entry.insert(slot);
215 Ok(slot)
216 }
217 }
218 }
219 }
220 Entry::Vacant(acc_entry) => {
221 let info = self.db.basic_ref(address)?;
223 let (account, value) = if info.is_some() {
224 let value = self.db.storage_ref(address, index)?;
225 let mut account: DbAccount = info.into();
226 account.storage.insert(index, value);
227 (account, value)
228 } else {
229 (info.into(), U256::ZERO)
230 };
231 acc_entry.insert(account);
232 Ok(value)
233 }
234 }
235 }
236
237 fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
238 match self.block_hashes.entry(U256::from(number)) {
239 Entry::Occupied(entry) => Ok(*entry.get()),
240 Entry::Vacant(entry) => {
241 let hash = self.db.block_hash_ref(number)?;
242 entry.insert(hash);
243 Ok(hash)
244 }
245 }
246 }
247}
248
249impl<ExtDB: DatabaseRef> DatabaseRef for CacheDB<ExtDB> {
250 type Error = ExtDB::Error;
251
252 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
253 match self.accounts.get(&address) {
254 Some(acc) => Ok(acc.info()),
255 None => self.db.basic_ref(address),
256 }
257 }
258
259 fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
260 match self.contracts.get(&code_hash) {
261 Some(entry) => Ok(entry.clone()),
262 None => self.db.code_by_hash_ref(code_hash),
263 }
264 }
265
266 fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
267 match self.accounts.get(&address) {
268 Some(acc_entry) => match acc_entry.storage.get(&index) {
269 Some(entry) => Ok(*entry),
270 None => {
271 if matches!(
272 acc_entry.account_state,
273 AccountState::StorageCleared | AccountState::NotExisting
274 ) {
275 Ok(U256::ZERO)
276 } else {
277 self.db.storage_ref(address, index)
278 }
279 }
280 },
281 None => self.db.storage_ref(address, index),
282 }
283 }
284
285 fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
286 match self.block_hashes.get(&U256::from(number)) {
287 Some(entry) => Ok(*entry),
288 None => self.db.block_hash_ref(number),
289 }
290 }
291}
292
293#[derive(Debug, Clone, Default)]
294#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
295pub struct DbAccount {
296 pub info: AccountInfo,
297 pub account_state: AccountState,
299 pub storage: HashMap<U256, U256>,
301}
302
303impl DbAccount {
304 pub fn new_not_existing() -> Self {
305 Self {
306 account_state: AccountState::NotExisting,
307 ..Default::default()
308 }
309 }
310
311 pub fn info(&self) -> Option<AccountInfo> {
312 if matches!(self.account_state, AccountState::NotExisting) {
313 None
314 } else {
315 Some(self.info.clone())
316 }
317 }
318}
319
320impl From<Option<AccountInfo>> for DbAccount {
321 fn from(from: Option<AccountInfo>) -> Self {
322 from.map(Self::from).unwrap_or_else(Self::new_not_existing)
323 }
324}
325
326impl From<AccountInfo> for DbAccount {
327 fn from(info: AccountInfo) -> Self {
328 Self {
329 info,
330 account_state: AccountState::None,
331 ..Default::default()
332 }
333 }
334}
335
/// Records what [`CacheDB`] knows about an account.
///
/// This decides whether the account is reported as existing, and whether storage slots
/// missing from [`DbAccount::storage`] are answered locally or by the backing database.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum AccountState {
    /// The account does not exist: the backing database returned none, or `commit` applied a
    /// self-destruct. Its info is reported as `None` and missing slots read as zero.
    NotExisting,
    /// `commit` modified the account without clearing its storage; missing slots are still
    /// read from the backing database.
    Touched,
    /// The account's storage was wiped (created in `commit`, or replaced via
    /// `replace_account_storage`); missing slots read as zero.
    StorageCleared,
    /// Default state: the account has not been modified by `commit`; missing slots are read
    /// from the backing database.
    #[default]
    None,
}
351
352impl AccountState {
353 pub fn is_storage_cleared(&self) -> bool {
355 matches!(self, AccountState::StorageCleared)
356 }
357}
358
359#[derive(Debug, Default, Clone)]
363pub struct BenchmarkDB {
364 pub bytecode: Bytecode,
365 pub hash: B256,
366 pub target: Address,
367 pub caller: Address,
368}
369
370impl BenchmarkDB {
371 pub fn new_bytecode(bytecode: Bytecode) -> Self {
373 let hash = bytecode.hash_slow();
374 Self {
375 bytecode,
376 hash,
377 target: Address::ZERO,
378 caller: Address::with_last_byte(1),
379 }
380 }
381
382 pub fn with_caller(self, caller: Address) -> Self {
384 Self { caller, ..self }
385 }
386
387 pub fn with_target(self, target: Address) -> Self {
389 Self { target, ..self }
390 }
391}
392
393impl Database for BenchmarkDB {
394 type Error = Infallible;
395 fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
397 if address == self.target {
398 return Ok(Some(AccountInfo {
399 nonce: 1,
400 balance: U256::from(10000000),
401 code: Some(self.bytecode.clone()),
402 code_hash: self.hash,
403 }));
404 }
405 if address == self.caller {
406 return Ok(Some(AccountInfo {
407 nonce: 0,
408 balance: U256::from(10000000),
409 code: None,
410 code_hash: KECCAK_EMPTY,
411 }));
412 }
413 Ok(None)
414 }
415
416 fn code_by_hash(&mut self, _code_hash: B256) -> Result<Bytecode, Self::Error> {
418 Ok(Bytecode::default())
419 }
420
421 fn storage(&mut self, _address: Address, _index: U256) -> Result<U256, Self::Error> {
423 Ok(U256::default())
424 }
425
426 fn block_hash(&mut self, _number: u64) -> Result<B256, Self::Error> {
428 Ok(B256::default())
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::{CacheDB, EmptyDB};
    use crate::primitives::{db::Database, AccountInfo, Address, HashMap, U256};

    /// Builds a cache over an `EmptyDB` that holds one account, `address`, with `nonce`.
    fn cache_with_account(address: Address, nonce: u64) -> CacheDB<EmptyDB> {
        let mut cache = CacheDB::new(EmptyDB::default());
        cache.insert_account_info(
            address,
            AccountInfo {
                nonce,
                ..Default::default()
            },
        );
        cache
    }

    #[test]
    fn test_insert_account_storage() {
        let account = Address::with_last_byte(42);
        let nonce = 42;
        let (key, value) = (U256::from(123), U256::from(456));

        // Write storage into a second cache layered over the first.
        let mut layered = CacheDB::new(cache_with_account(account, nonce));
        layered
            .insert_account_storage(account, key, value)
            .unwrap();

        assert_eq!(layered.basic(account).unwrap().unwrap().nonce, nonce);
        assert_eq!(layered.storage(account, key), Ok(value));
    }

    #[test]
    fn test_replace_account_storage() {
        let account = Address::with_last_byte(42);
        let nonce = 42;
        let (key0, value0) = (U256::from(123), U256::from(456));
        let (key1, value1) = (U256::from(789), U256::from(999));

        let mut base = cache_with_account(account, nonce);
        base.insert_account_storage(account, key0, value0).unwrap();

        // Replacing storage in the upper layer must hide the lower layer's slot.
        let mut layered = CacheDB::new(base);
        layered
            .replace_account_storage(account, HashMap::from_iter([(key1, value1)]))
            .unwrap();

        assert_eq!(layered.basic(account).unwrap().unwrap().nonce, nonce);
        assert_eq!(layered.storage(account, key0), Ok(U256::ZERO));
        assert_eq!(layered.storage(account, key1), Ok(value1));
    }

    #[cfg(feature = "serde-json")]
    #[test]
    fn test_serialize_deserialize_cachedb() {
        let account = Address::with_last_byte(69);
        let nonce = 420;
        let original = cache_with_account(account, nonce);

        let json = serde_json::to_string(&original).unwrap();
        let restored: CacheDB<EmptyDB> = serde_json::from_str(&json).unwrap();

        assert!(restored.accounts.contains_key(&account));
        assert_eq!(restored.accounts.get(&account).unwrap().info.nonce, nonce);
    }
}