// aws_smithy_async/future/pagination_stream.rs
1use crate::future::pagination_stream::collect::sealed::Collectable;
9use std::future::Future;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13pub mod collect;
14pub mod fn_stream;
15use fn_stream::FnStream;
16
/// Stream type used to hand out paginated results.
///
/// A thin newtype around [`FnStream`]. Items are consumed with the inherent
/// methods `next`, `poll_next` and `collect`; when `Item` is a `Result`, the
/// additional `try_next` / `try_collect` helpers are available.
#[derive(Debug)]
pub struct PaginationStream<Item>(FnStream<Item>);
53
54impl<Item> PaginationStream<Item> {
55 pub fn new(stream: FnStream<Item>) -> Self {
57 Self(stream)
58 }
59
60 pub async fn next(&mut self) -> Option<Item> {
62 self.0.next().await
63 }
64
65 pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Item>> {
67 Pin::new(&mut self.0).poll_next(cx)
68 }
69
70 pub async fn collect<T: Collectable<Item>>(self) -> T {
72 self.0.collect().await
73 }
74}
75
76impl<T, E> PaginationStream<Result<T, E>> {
77 pub async fn try_next(&mut self) -> Result<Option<T>, E> {
79 self.next().await.transpose()
80 }
81
82 pub async fn try_collect(self) -> Result<Vec<T>, E> {
84 self.collect::<Result<Vec<T>, E>>().await
85 }
86}
87
/// Helper for flattening a stream of fallible pages into a stream of items.
///
/// Wraps a [`PaginationStream`] whose items are `Result<Page, Err>`. Calling
/// `flat_map` turns it into a stream of individual `Result<Item, Err>` values.
#[derive(Debug)]
pub struct TryFlatMap<Page, Err>(PaginationStream<Result<Page, Err>>);
95
96impl<Page, Err> TryFlatMap<Page, Err> {
97 pub fn new(stream: PaginationStream<Result<Page, Err>>) -> Self {
99 Self(stream)
100 }
101
102 pub fn flat_map<M, Item, Iter>(mut self, map: M) -> PaginationStream<Result<Item, Err>>
104 where
105 Page: Send + 'static,
106 Err: Send + 'static,
107 M: Fn(Page) -> Iter + Send + 'static,
108 Item: Send + 'static,
109 Iter: IntoIterator<Item = Item> + Send,
110 <Iter as IntoIterator>::IntoIter: Send,
111 {
112 PaginationStream::new(FnStream::new(|tx| {
113 Box::pin(async move {
114 while let Some(page) = self.0.next().await {
115 match page {
116 Ok(page) => {
117 let mapped = map(page);
118 for item in mapped.into_iter() {
119 let _ = tx.send(Ok(item)).await;
120 }
121 }
122 Err(e) => {
123 let _ = tx.send(Err(e)).await;
124 break;
125 }
126 }
127 }
128 }) as Pin<Box<dyn Future<Output = ()> + Send>>
129 }))
130 }
131}
132
#[cfg(test)]
mod test {
    use crate::future::pagination_stream::{FnStream, PaginationStream, TryFlatMap};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    // Values sent by the generator reach the consumer in send order, even with
    // sleeps between sends.
    #[tokio::test]
    async fn fn_stream_returns_results() {
        // Pause tokio's clock so the sleeps below don't take wall-clock time.
        tokio::time::pause();
        let mut stream = FnStream::new(|tx| {
            Box::pin(async move {
                tx.send("1").await.expect("failed to send");
                // NOTE(review): two back-to-back sleeps; presumably intentional to
                // make the consumer observe more than one pending poll — confirm.
                tokio::time::sleep(Duration::from_secs(1)).await;
                tokio::time::sleep(Duration::from_secs(1)).await;
                tx.send("2").await.expect("failed to send");
                tokio::time::sleep(Duration::from_secs(1)).await;
                tx.send("3").await.expect("failed to send");
            })
        });
        let mut out = vec![];
        while let Some(value) = stream.next().await {
            out.push(value);
        }
        assert_eq!(vec!["1", "2", "3"], out);
    }

    // `try_next` yields `Ok(Some(..))` for each `Ok` item. The loop ends at the
    // first `Err`, so only the two successful values are collected.
    // NOTE(review): `while let Ok(..)` also matches `Ok(None)`. If a generator
    // finished without sending an `Err`, this loop would keep going for as long
    // as `try_next` returns `Ok(None)`.
    #[tokio::test]
    async fn fn_stream_try_next() {
        tokio::time::pause();
        let mut stream = FnStream::new(|tx| {
            Box::pin(async move {
                tx.send(Ok(1)).await.unwrap();
                tx.send(Ok(2)).await.unwrap();
                tx.send(Err("err")).await.unwrap();
            })
        });
        let mut out = vec![];
        while let Ok(value) = stream.try_next().await {
            out.push(value);
        }
        assert_eq!(vec![Some(1), Some(2)], out);
    }

    // The generator sends one item and then leaks its sender, so the channel is
    // never closed even though the generator future has completed. After that,
    // repeated polls must report `Pending`: no item, and no end-of-stream.
    // (Presumably this also guards against re-polling the finished generator
    // future — confirm against `FnStream::poll_next`.)
    #[tokio::test]
    async fn fn_stream_doesnt_poll_after_done() {
        let mut stream = FnStream::new(|tx| {
            Box::pin(async move {
                assert!(tx.send("blah").await.is_ok());
                Box::leak(Box::new(tx));
            })
        });
        assert_eq!(Some("blah"), stream.next().await);
        let mut test_stream = tokio_test::task::spawn(stream);
        test_stream.enter(|ctx, pin| {
            let polled = pin.poll_next(ctx);
            assert!(polled.is_pending());
        });
        test_stream.enter(|ctx, pin| {
            let polled = pin.poll_next(ctx);
            assert!(polled.is_pending());
        });
    }

    // The generator only advances as far as the reader has asked. `progress`
    // records how far the generator body has run, and it is checked after each
    // `next()` call.
    #[tokio::test]
    async fn waits_for_reader() {
        let progress = Arc::new(Mutex::new(0));
        let mut stream = FnStream::new(|tx| {
            let progress = progress.clone();
            Box::pin(async move {
                *progress.lock().unwrap() = 1;
                tx.send("1").await.expect("failed to send");
                *progress.lock().unwrap() = 2;
                tx.send("2").await.expect("failed to send");
                *progress.lock().unwrap() = 3;
                tx.send("3").await.expect("failed to send");
                *progress.lock().unwrap() = 4;
            })
        });
        // Nothing has run until the stream is first polled.
        assert_eq!(*progress.lock().unwrap(), 0);
        stream.next().await.expect("ready");
        assert_eq!(*progress.lock().unwrap(), 1);

        assert_eq!("2", stream.next().await.expect("ready"));
        assert_eq!(2, *progress.lock().unwrap());

        let _ = stream.next().await.expect("ready");
        assert_eq!(3, *progress.lock().unwrap());
        // Draining past the last item lets the generator body run to completion.
        assert_eq!(None, stream.next().await);
        assert_eq!(4, *progress.lock().unwrap());
    }

    // A generator that emits an error at i == 2 and then returns. The consumer
    // stops at the first non-`Ok` item, so only 0 and 1 are collected.
    #[tokio::test]
    async fn generator_with_errors() {
        let mut stream = FnStream::new(|tx| {
            Box::pin(async move {
                for i in 0..5 {
                    if i != 2 {
                        // Stop producing if the receiving side has gone away.
                        if tx.send(Ok(i)).await.is_err() {
                            return;
                        }
                    } else {
                        tx.send(Err(i)).await.unwrap();
                        return;
                    }
                }
            })
        });
        let mut out = vec![];
        while let Some(Ok(value)) = stream.next().await {
            out.push(value);
        }
        assert_eq!(vec![0, 1], out);
    }

    // `TryFlatMap::flat_map` concatenates the items of each successful page in order.
    #[tokio::test]
    async fn flatten_items_ok() {
        #[derive(Debug)]
        struct Output {
            items: Vec<u8>,
        }
        let stream: FnStream<Result<_, &str>> = FnStream::new(|tx| {
            Box::pin(async move {
                tx.send(Ok(Output {
                    items: vec![1, 2, 3],
                }))
                .await
                .unwrap();
                tx.send(Ok(Output {
                    items: vec![4, 5, 6],
                }))
                .await
                .unwrap();
            })
        });
        assert_eq!(
            Ok(vec![1, 2, 3, 4, 5, 6]),
            TryFlatMap::new(PaginationStream::new(stream))
                .flat_map(|output| output.items.into_iter())
                .try_collect()
                .await,
        );
    }

    // An error page after a successful page makes `try_collect` return that error.
    #[tokio::test]
    async fn flatten_items_error() {
        #[derive(Debug)]
        struct Output {
            items: Vec<u8>,
        }
        let stream = FnStream::new(|tx| {
            Box::pin(async move {
                tx.send(Ok(Output {
                    items: vec![1, 2, 3],
                }))
                .await
                .unwrap();
                tx.send(Err("bummer")).await.unwrap();
            })
        });
        assert_eq!(
            Err("bummer"),
            TryFlatMap::new(PaginationStream::new(stream))
                .flat_map(|output| output.items.into_iter())
                .try_collect()
                .await
        )
    }
}