1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
56//! Rendezvous channel implementation
7//!
8//! Rendezvous channels are equivalent to a channel with a 0-sized buffer: A sender cannot send
9//! until there is an active receiver waiting. This implementation uses a Semaphore to record demand
10//! and coordinate with the receiver.
11//!
12//! Rendezvous channels should be used with care—it's inherently easy to deadlock unless they're being
//! used from separate tasks or in a coroutine setup (e.g. [`crate::future::pagination_stream::fn_stream::FnStream`]).
1415use std::future::poll_fn;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18use tokio::sync::Semaphore;
1920/// Create a new rendezvous channel
21///
22/// Rendezvous channels are equivalent to a channel with a 0-sized buffer: A sender cannot send
23/// until this is an active receiver waiting. This implementation uses a semaphore to record demand
24/// and coordinate with the receiver.
25pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
26let (tx, rx) = tokio::sync::mpsc::channel(1);
27let semaphore = Arc::new(Semaphore::new(0));
28 (
29 Sender {
30 semaphore: semaphore.clone(),
31 chan: tx,
32 },
33 Receiver {
34 semaphore,
35 chan: rx,
36 needs_permit: false,
37 },
38 )
39}
4041/// Errors for rendezvous channel
42pub mod error {
43use std::fmt;
44use tokio::sync::mpsc::error::SendError as TokioSendError;
4546/// Error when [crate::future::rendezvous::Sender] fails to send a value to the associated `Receiver`
47#[derive(Debug)]
48pub struct SendError<T> {
49 source: TokioSendError<T>,
50 }
5152impl<T> SendError<T> {
53pub(crate) fn tokio_send_error(source: TokioSendError<T>) -> Self {
54Self { source }
55 }
56 }
5758impl<T> fmt::Display for SendError<T> {
59fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60write!(f, "failed to send value to the receiver")
61 }
62 }
6364impl<T: fmt::Debug + 'static> std::error::Error for SendError<T> {
65fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
66Some(&self.source)
67 }
68 }
69}
7071#[derive(Debug)]
72/// Sender-half of a channel
73pub struct Sender<T> {
74 semaphore: Arc<Semaphore>,
75 chan: tokio::sync::mpsc::Sender<T>,
76}
7778impl<T> Sender<T> {
79/// Send `item` into the channel waiting until there is matching demand
80 ///
81 /// Unlike something like `tokio::sync::mpsc::Channel` where sending a value will be buffered until
82 /// demand exists, a rendezvous sender will wait until matching demand exists before this function will return.
83pub async fn send(&self, item: T) -> Result<(), error::SendError<T>> {
84let result = self.chan.send(item).await;
85// If this is an error, the rx half has been dropped. We will never get demand.
86if result.is_ok() {
87// The key here is that we block _after_ the send until more demand exists
88self.semaphore
89 .acquire()
90 .await
91.expect("semaphore is never closed")
92 .forget();
93 }
94 result.map_err(error::SendError::tokio_send_error)
95 }
96}
9798#[derive(Debug)]
99/// Receiver half of the rendezvous channel
100pub struct Receiver<T> {
101 semaphore: Arc<Semaphore>,
102 chan: tokio::sync::mpsc::Receiver<T>,
103 needs_permit: bool,
104}
105106impl<T> Receiver<T> {
107/// Polls to receive an item from the channel
108pub async fn recv(&mut self) -> Option<T> {
109 poll_fn(|cx| self.poll_recv(cx)).await
110}
111112pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
113// This uses `needs_permit` to track whether this is the first poll since we last returned an item.
114 // If it is, we will grant a permit to the semaphore. Otherwise, we'll just forward the response through.
115let resp = self.chan.poll_recv(cx);
116// If there is no data on the channel, but we are reading, then give a permit so we can load data
117if self.needs_permit && matches!(resp, Poll::Pending) {
118self.needs_permit = false;
119self.semaphore.add_permits(1);
120 }
121122if matches!(resp, Poll::Ready(_)) {
123// we returned an item, no need to provide another permit until we fail to read from the channel again
124self.needs_permit = true;
125 }
126 resp
127 }
128}
#[cfg(test)]
mod test {
    use crate::future::rendezvous::channel;
    use std::sync::{Arc, Mutex};

    /// Each `send` must stay blocked until the receiver has taken the item and asked for more.
    #[tokio::test]
    async fn send_blocks_caller() {
        let (tx, mut rx) = channel::<u8>();
        let stage = Arc::new(Mutex::new(0));
        let sender_stage = Arc::clone(&stage);
        let sender_task = tokio::spawn(async move {
            *sender_stage.lock().unwrap() = 1;
            tx.send(0).await.unwrap();
            *sender_stage.lock().unwrap() = 2;
            tx.send(1).await.unwrap();
            *sender_stage.lock().unwrap() = 3;
        });
        let current_stage = || *stage.lock().unwrap();

        assert_eq!(current_stage(), 0);
        assert_eq!(rx.recv().await, Some(0));
        assert_eq!(current_stage(), 1);
        assert_eq!(rx.recv().await, Some(1));
        assert_eq!(current_stage(), 2);
        assert_eq!(rx.recv().await, None);
        assert_eq!(current_stage(), 3);

        let _ = sender_task.await;
    }

    /// Sending into a channel whose receiver is already gone reports an error.
    #[tokio::test]
    async fn send_errors_when_rx_dropped() {
        let (tx, rx) = channel::<u8>();
        drop(rx);
        tx.send(0).await.expect_err("rx half dropped");
    }
}