parity_scale_codec/
counted_input.rs

1// Copyright 2017-2024 Parity Technologies
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15/// A wrapper for `Input` which tracks the number fo bytes that are successfully read.
16///
17/// If inner `Input` fails to read, the counter is not incremented.
18///
19/// It can count until `u64::MAX - 1` accurately then saturate.
20pub struct CountedInput<'a, I: crate::Input> {
21	input: &'a mut I,
22	counter: u64,
23}
24
25impl<'a, I: crate::Input> CountedInput<'a, I> {
26	/// Create a new `CountedInput` with the given input.
27	pub fn new(input: &'a mut I) -> Self {
28		Self { input, counter: 0 }
29	}
30
31	/// Get the number of bytes successfully read.
32	pub fn count(&self) -> u64 {
33		self.counter
34	}
35}
36
37impl<I: crate::Input> crate::Input for CountedInput<'_, I> {
38	fn remaining_len(&mut self) -> Result<Option<usize>, crate::Error> {
39		self.input.remaining_len()
40	}
41
42	fn read(&mut self, into: &mut [u8]) -> Result<(), crate::Error> {
43		self.input.read(into).inspect(|_r| {
44			self.counter = self.counter.saturating_add(into.len().try_into().unwrap_or(u64::MAX));
45		})
46	}
47
48	fn read_byte(&mut self) -> Result<u8, crate::Error> {
49		self.input.read_byte().inspect(|_r| {
50			self.counter = self.counter.saturating_add(1);
51		})
52	}
53
54	fn ascend_ref(&mut self) {
55		self.input.ascend_ref()
56	}
57
58	fn descend_ref(&mut self) -> Result<(), crate::Error> {
59		self.input.descend_ref()
60	}
61
62	fn on_before_alloc_mem(&mut self, size: usize) -> Result<(), crate::Error> {
63		self.input.on_before_alloc_mem(size)
64	}
65}
66
#[cfg(test)]
mod test {
	use super::*;
	use crate::Input;

	/// Error message these tests expect once the input can no longer satisfy a read.
	const NOT_ENOUGH_DATA: &str = "Not enough data to fill buffer";

	/// Checks the remaining length reported through the wrapper together with its byte count.
	///
	/// `#[track_caller]` makes a failing assertion point at the calling test line.
	#[track_caller]
	fn assert_progress<I: Input>(counted: &mut CountedInput<'_, I>, remaining: usize, count: u64) {
		assert_eq!(counted.remaining_len().unwrap(), Some(remaining));
		assert_eq!(counted.count(), count);
	}

	#[test]
	fn test_counted_input_input_impl() {
		let mut bytes = &[1u8, 2, 3, 4, 5][..];
		let mut counted = CountedInput::new(&mut bytes);
		let mut pair = [0u8; 2];

		assert_progress(&mut counted, 5, 0);

		counted.read_byte().unwrap();
		assert_progress(&mut counted, 4, 1);

		counted.read(&mut pair[..]).unwrap();
		assert_progress(&mut counted, 2, 3);

		// Exercise the forwarded reference-depth hooks between two reads.
		counted.ascend_ref();
		counted.descend_ref().unwrap();

		counted.read(&mut pair[..]).unwrap();
		assert_progress(&mut counted, 0, 5);

		// Reads against the exhausted input fail and must leave the count where it was.
		assert_eq!(counted.read_byte(), Err(NOT_ENOUGH_DATA.into()));
		assert_progress(&mut counted, 0, 5);

		assert_eq!(counted.read(&mut pair[..]), Err(NOT_ENOUGH_DATA.into()));
		assert_progress(&mut counted, 0, 5);
	}

	#[test]
	fn test_counted_input_max_count_read_byte() {
		let last_exact = u64::MAX - 1;

		let mut bytes = &[0u8; 1000][..];
		let mut counted = CountedInput::new(&mut bytes);

		counted.counter = last_exact;
		assert_eq!(counted.count(), last_exact);

		// The first read lands exactly on `u64::MAX`; the second must stay there.
		for _ in 0..2 {
			counted.read_byte().unwrap();
			assert_eq!(counted.count(), u64::MAX);
		}
	}

	#[test]
	fn test_counted_input_max_count_read() {
		let last_exact = u64::MAX - 1;

		let mut bytes = &[0u8; 1000][..];
		let mut counted = CountedInput::new(&mut bytes);

		counted.counter = last_exact;
		assert_eq!(counted.count(), last_exact);

		// A two-byte read already overshoots `u64::MAX`, so both reads must report saturation.
		let mut pair = [0u8; 2];
		for _ in 0..2 {
			counted.read(&mut pair[..]).unwrap();
			assert_eq!(counted.count(), u64::MAX);
		}
	}
}