aws_smithy_async/future/
rendezvous.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Rendezvous channel implementation
7//!
8//! Rendezvous channels are equivalent to a channel with a 0-sized buffer: A sender cannot send
9//! until there is an active receiver waiting. This implementation uses a Semaphore to record demand
10//! and coordinate with the receiver.
11//!
//! Rendezvous channels should be used with care—it's inherently easy to deadlock unless they're
//! being used from separate tasks or in a coroutine setup (e.g.
//! [`crate::future::pagination_stream::fn_stream::FnStream`]).
14
15use std::future::poll_fn;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18use tokio::sync::Semaphore;
19
20/// Create a new rendezvous channel
21///
22/// Rendezvous channels are equivalent to a channel with a 0-sized buffer: A sender cannot send
23/// until this is an active receiver waiting. This implementation uses a semaphore to record demand
24/// and coordinate with the receiver.
25pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
26    let (tx, rx) = tokio::sync::mpsc::channel(1);
27    let semaphore = Arc::new(Semaphore::new(0));
28    (
29        Sender {
30            semaphore: semaphore.clone(),
31            chan: tx,
32        },
33        Receiver {
34            semaphore,
35            chan: rx,
36            needs_permit: false,
37        },
38    )
39}
40
41/// Errors for rendezvous channel
42pub mod error {
43    use std::fmt;
44    use tokio::sync::mpsc::error::SendError as TokioSendError;
45
46    /// Error when [crate::future::rendezvous::Sender] fails to send a value to the associated `Receiver`
47    #[derive(Debug)]
48    pub struct SendError<T> {
49        source: TokioSendError<T>,
50    }
51
52    impl<T> SendError<T> {
53        pub(crate) fn tokio_send_error(source: TokioSendError<T>) -> Self {
54            Self { source }
55        }
56    }
57
58    impl<T> fmt::Display for SendError<T> {
59        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60            write!(f, "failed to send value to the receiver")
61        }
62    }
63
64    impl<T: fmt::Debug + 'static> std::error::Error for SendError<T> {
65        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
66            Some(&self.source)
67        }
68    }
69}
70
71#[derive(Debug)]
72/// Sender-half of a channel
73pub struct Sender<T> {
74    semaphore: Arc<Semaphore>,
75    chan: tokio::sync::mpsc::Sender<T>,
76}
77
78impl<T> Sender<T> {
79    /// Send `item` into the channel waiting until there is matching demand
80    ///
81    /// Unlike something like `tokio::sync::mpsc::Channel` where sending a value will be buffered until
82    /// demand exists, a rendezvous sender will wait until matching demand exists before this function will return.
83    pub async fn send(&self, item: T) -> Result<(), error::SendError<T>> {
84        let result = self.chan.send(item).await;
85        // If this is an error, the rx half has been dropped. We will never get demand.
86        if result.is_ok() {
87            // The key here is that we block _after_ the send until more demand exists
88            self.semaphore
89                .acquire()
90                .await
91                .expect("semaphore is never closed")
92                .forget();
93        }
94        result.map_err(error::SendError::tokio_send_error)
95    }
96}
97
98#[derive(Debug)]
99/// Receiver half of the rendezvous channel
100pub struct Receiver<T> {
101    semaphore: Arc<Semaphore>,
102    chan: tokio::sync::mpsc::Receiver<T>,
103    needs_permit: bool,
104}
105
106impl<T> Receiver<T> {
107    /// Polls to receive an item from the channel
108    pub async fn recv(&mut self) -> Option<T> {
109        poll_fn(|cx| self.poll_recv(cx)).await
110    }
111
112    pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
113        // This uses `needs_permit` to track whether this is the first poll since we last returned an item.
114        // If it is, we will grant a permit to the semaphore. Otherwise, we'll just forward the response through.
115        let resp = self.chan.poll_recv(cx);
116        // If there is no data on the channel, but we are reading, then give a permit so we can load data
117        if self.needs_permit && matches!(resp, Poll::Pending) {
118            self.needs_permit = false;
119            self.semaphore.add_permits(1);
120        }
121
122        if matches!(resp, Poll::Ready(_)) {
123            // we returned an item, no need to provide another permit until we fail to read from the channel again
124            self.needs_permit = true;
125        }
126        resp
127    }
128}
129
#[cfg(test)]
mod test {
    use crate::future::rendezvous::channel;
    use std::sync::{Arc, Mutex};

    #[tokio::test]
    async fn send_blocks_caller() {
        let (tx, mut rx) = channel::<u8>();
        // Records how far the sending task has progressed.
        let progress = Arc::new(Mutex::new(0));
        let stage = || *progress.lock().unwrap();

        let sender_progress = Arc::clone(&progress);
        let send = tokio::spawn(async move {
            let set_stage = |value| *sender_progress.lock().unwrap() = value;
            set_stage(1);
            tx.send(0).await.unwrap();
            set_stage(2);
            tx.send(1).await.unwrap();
            set_stage(3);
        });

        assert_eq!(stage(), 0);
        assert_eq!(rx.recv().await, Some(0));
        assert_eq!(stage(), 1);
        assert_eq!(rx.recv().await, Some(1));
        assert_eq!(stage(), 2);
        assert_eq!(rx.recv().await, None);
        assert_eq!(stage(), 3);
        let _ = send.await;
    }

    #[tokio::test]
    async fn send_errors_when_rx_dropped() {
        let (tx, rx) = channel::<u8>();
        drop(rx);
        let outcome = tx.send(0).await;
        assert!(outcome.is_err(), "rx half dropped");
    }
}