aws_smithy_async/future/
pagination_stream.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Provides types to support stream-like operations for paginators.
7
8use crate::future::pagination_stream::collect::sealed::Collectable;
9use std::future::Future;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13pub mod collect;
14pub mod fn_stream;
15use fn_stream::FnStream;
16
17/// Stream specifically made to support paginators.
18///
19/// `PaginationStream` provides two primary mechanisms for accessing stream of data.
20/// 1. With [`.next()`](PaginationStream::next) (or [`try_next()`](PaginationStream::try_next)):
21///
22/// ```no_run
23/// # async fn docs() {
24/// # use aws_smithy_async::future::pagination_stream::PaginationStream;
25/// # fn operation_to_yield_paginator<T>() -> PaginationStream<T> {
26/// #     todo!()
27/// # }
28/// # struct Page;
29/// let mut stream: PaginationStream<Page> = operation_to_yield_paginator();
30/// while let Some(page) = stream.next().await {
31///     // process `page`
32/// }
33/// # }
34/// ```
35/// 2. With [`.collect()`](PaginationStream::collect) (or [`try_collect()`](PaginationStream::try_collect)):
36///
37/// ```no_run
38/// # async fn docs() {
39/// # use aws_smithy_async::future::pagination_stream::PaginationStream;
40/// # fn operation_to_yield_paginator<T>() -> PaginationStream<T> {
41/// #     todo!()
42/// # }
43/// # struct Page;
44/// let mut stream: PaginationStream<Page> = operation_to_yield_paginator();
45/// let result = stream.collect::<Vec<Page>>().await;
46/// # }
47/// ```
48///
49/// [`PaginationStream`] is implemented in terms of [`FnStream`], but the latter is meant to be
50/// used internally and not by external users.
51#[derive(Debug)]
52pub struct PaginationStream<Item>(FnStream<Item>);
53
54impl<Item> PaginationStream<Item> {
55    /// Creates a `PaginationStream` from the given [`FnStream`].
56    pub fn new(stream: FnStream<Item>) -> Self {
57        Self(stream)
58    }
59
60    /// Consumes and returns the next `Item` from this stream.
61    pub async fn next(&mut self) -> Option<Item> {
62        self.0.next().await
63    }
64
65    /// Poll an item from the stream
66    pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Item>> {
67        Pin::new(&mut self.0).poll_next(cx)
68    }
69
70    /// Consumes this stream and gathers elements into a collection.
71    pub async fn collect<T: Collectable<Item>>(self) -> T {
72        self.0.collect().await
73    }
74}
75
76impl<T, E> PaginationStream<Result<T, E>> {
77    /// Yields the next item in the stream or returns an error if an error is encountered.
78    pub async fn try_next(&mut self) -> Result<Option<T>, E> {
79        self.next().await.transpose()
80    }
81
82    /// Convenience method for `.collect::<Result<Vec<_>, _>()`.
83    pub async fn try_collect(self) -> Result<Vec<T>, E> {
84        self.collect::<Result<Vec<T>, E>>().await
85    }
86}
87
88/// Utility wrapper to flatten paginated results
89///
90/// When flattening paginated results, it's most convenient to produce an iterator where the `Result`
91/// is present in each item. This provides `items()` which can wrap an stream of `Result<Page, Err>`
92/// and produce a stream of `Result<Item, Err>`.
93#[derive(Debug)]
94pub struct TryFlatMap<Page, Err>(PaginationStream<Result<Page, Err>>);
95
96impl<Page, Err> TryFlatMap<Page, Err> {
97    /// Creates a `TryFlatMap` that wraps the input.
98    pub fn new(stream: PaginationStream<Result<Page, Err>>) -> Self {
99        Self(stream)
100    }
101
102    /// Produces a new [`PaginationStream`] by mapping this stream with `map` then flattening the result.
103    pub fn flat_map<M, Item, Iter>(mut self, map: M) -> PaginationStream<Result<Item, Err>>
104    where
105        Page: Send + 'static,
106        Err: Send + 'static,
107        M: Fn(Page) -> Iter + Send + 'static,
108        Item: Send + 'static,
109        Iter: IntoIterator<Item = Item> + Send,
110        <Iter as IntoIterator>::IntoIter: Send,
111    {
112        PaginationStream::new(FnStream::new(|tx| {
113            Box::pin(async move {
114                while let Some(page) = self.0.next().await {
115                    match page {
116                        Ok(page) => {
117                            let mapped = map(page);
118                            for item in mapped.into_iter() {
119                                let _ = tx.send(Ok(item)).await;
120                            }
121                        }
122                        Err(e) => {
123                            let _ = tx.send(Err(e)).await;
124                            break;
125                        }
126                    }
127                }
128            }) as Pin<Box<dyn Future<Output = ()> + Send>>
129        }))
130    }
131}
132
#[cfg(test)]
mod test {
    use crate::future::pagination_stream::{FnStream, PaginationStream, TryFlatMap};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    /// basic test of FnStream functionality
    ///
    /// The generator interleaves sends with sleeps; every value it sends must reach the reader,
    /// in order, before the stream ends.
    #[tokio::test]
    async fn fn_stream_returns_results() {
        // Pause tokio's clock so the `sleep`s below auto-advance instead of really waiting.
        tokio::time::pause();
        let mut stream = FnStream::new(|tx| {
            Box::pin(async move {
                tx.send("1").await.expect("failed to send");
                tokio::time::sleep(Duration::from_secs(1)).await;
                tokio::time::sleep(Duration::from_secs(1)).await;
                tx.send("2").await.expect("failed to send");
                tokio::time::sleep(Duration::from_secs(1)).await;
                tx.send("3").await.expect("failed to send");
            })
        });
        let mut out = vec![];
        while let Some(value) = stream.next().await {
            out.push(value);
        }
        assert_eq!(vec!["1", "2", "3"], out);
    }

    /// `try_next` yields `Ok(Some(_))` for each `Ok` item and surfaces the `Err` the generator sends.
    #[tokio::test]
    async fn fn_stream_try_next() {
        tokio::time::pause();
        let mut stream = FnStream::new(|tx| {
            Box::pin(async move {
                tx.send(Ok(1)).await.unwrap();
                tx.send(Ok(2)).await.unwrap();
                tx.send(Err("err")).await.unwrap();
            })
        });
        let mut out = vec![];
        // Loop until `try_next` returns `Err`. NOTE(review): this relies on the generator ending
        // with an `Err`; with only `Ok`s it would presumably keep getting `Ok(None)` and never exit.
        while let Ok(value) = stream.try_next().await {
            out.push(value);
        }
        assert_eq!(vec![Some(1), Some(2)], out);
    }

    // smithy-rs#1902: there was a bug where we could continue to poll the generator after it
    // had returned Poll::Ready. This test case leaks the tx half so that the channel stays open
    // but the send side generator completes. By calling `poll` multiple times on the resulting future,
    // we can trigger the bug and validate the fix.
    #[tokio::test]
    async fn fn_stream_doesnt_poll_after_done() {
        let mut stream = FnStream::new(|tx| {
            Box::pin(async move {
                assert!(tx.send("blah").await.is_ok());
                // Leaking `tx` keeps the channel open after the generator finishes, so the
                // stream can never observe a closed channel and must report `Pending`.
                Box::leak(Box::new(tx));
            })
        });
        assert_eq!(Some("blah"), stream.next().await);
        let mut test_stream = tokio_test::task::spawn(stream);
        // `tokio_test::task::Spawn::poll_next` can only be invoked when the wrapped
        // type implements the `Stream` trait. Here, `FnStream` does not implement it,
        // so we work around it by using the `enter` method.
        test_stream.enter(|ctx, pin| {
            let polled = pin.poll_next(ctx);
            assert!(polled.is_pending());
        });
        // A second poll after the generator has completed must also be `Pending` (the bug
        // described above would re-poll the finished generator here).
        test_stream.enter(|ctx, pin| {
            let polled = pin.poll_next(ctx);
            assert!(polled.is_pending());
        });
    }

    /// Tests that the generator will not advance until demand exists
    #[tokio::test]
    async fn waits_for_reader() {
        // `progress` records the last checkpoint the generator reached, so the test can check
        // how far it ran relative to each `next()` call.
        let progress = Arc::new(Mutex::new(0));
        let mut stream = FnStream::new(|tx| {
            let progress = progress.clone();
            Box::pin(async move {
                *progress.lock().unwrap() = 1;
                tx.send("1").await.expect("failed to send");
                *progress.lock().unwrap() = 2;
                tx.send("2").await.expect("failed to send");
                *progress.lock().unwrap() = 3;
                tx.send("3").await.expect("failed to send");
                *progress.lock().unwrap() = 4;
            })
        });
        // Nothing has been requested yet, so the generator must not have started.
        assert_eq!(*progress.lock().unwrap(), 0);
        stream.next().await.expect("ready");
        // After one item, the generator has reached only the first checkpoint.
        assert_eq!(*progress.lock().unwrap(), 1);

        assert_eq!("2", stream.next().await.expect("ready"));
        assert_eq!(2, *progress.lock().unwrap());

        let _ = stream.next().await.expect("ready");
        assert_eq!(3, *progress.lock().unwrap());
        // Asking past the last item drives the generator to completion.
        assert_eq!(None, stream.next().await);
        assert_eq!(4, *progress.lock().unwrap());
    }

    /// The generator sends `Ok(0)`, `Ok(1)`, then `Err(2)` and stops; the reader keeps only the
    /// `Ok` values seen before the first non-`Ok` item.
    #[tokio::test]
    async fn generator_with_errors() {
        let mut stream = FnStream::new(|tx| {
            Box::pin(async move {
                for i in 0..5 {
                    if i != 2 {
                        if tx.send(Ok(i)).await.is_err() {
                            return;
                        }
                    } else {
                        tx.send(Err(i)).await.unwrap();
                        return;
                    }
                }
            })
        });
        let mut out = vec![];
        while let Some(Ok(value)) = stream.next().await {
            out.push(value);
        }
        assert_eq!(vec![0, 1], out);
    }

    /// Two successful pages are flattened, in order, into a single `Vec` of items.
    #[tokio::test]
    async fn flatten_items_ok() {
        #[derive(Debug)]
        struct Output {
            items: Vec<u8>,
        }
        let stream: FnStream<Result<_, &str>> = FnStream::new(|tx| {
            Box::pin(async move {
                tx.send(Ok(Output {
                    items: vec![1, 2, 3],
                }))
                .await
                .unwrap();
                tx.send(Ok(Output {
                    items: vec![4, 5, 6],
                }))
                .await
                .unwrap();
            })
        });
        assert_eq!(
            Ok(vec![1, 2, 3, 4, 5, 6]),
            TryFlatMap::new(PaginationStream::new(stream))
                .flat_map(|output| output.items.into_iter())
                .try_collect()
                .await,
        );
    }

    /// An `Err` page after a successful page makes `try_collect` return that error.
    #[tokio::test]
    async fn flatten_items_error() {
        #[derive(Debug)]
        struct Output {
            items: Vec<u8>,
        }
        let stream = FnStream::new(|tx| {
            Box::pin(async move {
                tx.send(Ok(Output {
                    items: vec![1, 2, 3],
                }))
                .await
                .unwrap();
                tx.send(Err("bummer")).await.unwrap();
            })
        });
        assert_eq!(
            Err("bummer"),
            TryFlatMap::new(PaginationStream::new(stream))
                .flat_map(|output| output.items.into_iter())
                .try_collect()
                .await
        )
    }
}