1use crate::buf::count::CountBuf;
9use crate::buf::crc::{CrcBuf, CrcBufMut};
10use crate::error::{Error, ErrorKind};
11use aws_smithy_types::config_bag::{Storable, StoreReplace};
12use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
13use aws_smithy_types::str_bytes::StrBytes;
14use aws_smithy_types::DateTime;
15use bytes::{Buf, BufMut};
16use std::error::Error as StdError;
17use std::fmt;
18use std::mem::size_of;
19use std::sync::{mpsc, Mutex};
20
/// Size of the prelude: total length, headers length and prelude CRC, three `u32`s in all.
const PRELUDE_LENGTH_BYTES: u32 = (size_of::<u32>() * 3) as u32;
/// [`PRELUDE_LENGTH_BYTES`] as a `usize`, for buffer sizes and `remaining()` comparisons.
const PRELUDE_LENGTH_BYTES_USIZE: usize = PRELUDE_LENGTH_BYTES as usize;
/// Size of the CRC that terminates every message.
const MESSAGE_CRC_LENGTH_BYTES: u32 = size_of::<u32>() as u32;
/// Longest header name that can be encoded; the name length is written as a single byte.
const MAX_HEADER_NAME_LEN: usize = u8::MAX as usize;
/// Smallest encoded header: one name-length byte plus one value-type byte.
const MIN_HEADER_LEN: usize = 2;
26
// Wire tags identifying the type of each header value. Integer payloads are big-endian.

/// Boolean `true`; the tag alone carries the value.
pub(crate) const TYPE_TRUE: u8 = 0x00;
/// Boolean `false`; the tag alone carries the value.
pub(crate) const TYPE_FALSE: u8 = 0x01;
/// Signed 8-bit integer.
pub(crate) const TYPE_BYTE: u8 = 0x02;
/// Signed 16-bit integer.
pub(crate) const TYPE_INT16: u8 = 0x03;
/// Signed 32-bit integer.
pub(crate) const TYPE_INT32: u8 = 0x04;
/// Signed 64-bit integer.
pub(crate) const TYPE_INT64: u8 = 0x05;
/// Byte array with a `u16` length prefix.
pub(crate) const TYPE_BYTE_ARRAY: u8 = 0x06;
/// UTF-8 string with a `u16` length prefix.
pub(crate) const TYPE_STRING: u8 = 0x07;
/// Timestamp stored as `i64` milliseconds since the Unix epoch.
pub(crate) const TYPE_TIMESTAMP: u8 = 0x08;
/// 16-byte UUID stored as a `u128`.
pub(crate) const TYPE_UUID: u8 = 0x09;
37
/// Boxed, thread-safe error returned by [`SignMessage`] implementations.
///
/// `Box<dyn Trait>` defaults to a `'static` object lifetime, so no explicit bound is needed.
pub type SignMessageError = Box<dyn StdError + Send + Sync>;
39
/// Signs event stream [`Message`]s.
///
/// Methods take `&mut self`, so an implementation may keep state across calls.
pub trait SignMessage: fmt::Debug {
    /// Signs `message` and returns the signed message, or a [`SignMessageError`].
    fn sign(&mut self, message: Message) -> Result<Message, SignMessageError>;

    /// Produces a signed message that carries no caller-provided content.
    ///
    /// Returns `None` when the signer has no such message to emit (as [`NoOpSigner`] does).
    /// NOTE(review): presumably used to terminate a stream; confirm against callers.
    fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>>;
}
50
/// Sending half of a [`DeferredSigner`]: delivers the real signer after the
/// `DeferredSigner` has been created.
///
/// The channel sender is held behind a [`Mutex`]. NOTE(review): presumably this is so the
/// type is `Sync` (`mpsc::Sender` was not `Sync` on older toolchains); confirm.
#[derive(Debug)]
#[non_exhaustive]
pub struct DeferredSignerSender(Mutex<mpsc::Sender<Box<dyn SignMessage + Send + Sync>>>);
55
56impl DeferredSignerSender {
57 fn new(tx: mpsc::Sender<Box<dyn SignMessage + Send + Sync>>) -> Self {
59 Self(Mutex::new(tx))
60 }
61
62 pub fn send(
64 &self,
65 signer: Box<dyn SignMessage + Send + Sync>,
66 ) -> Result<(), mpsc::SendError<Box<dyn SignMessage + Send + Sync>>> {
67 self.0.lock().unwrap().send(signer)
68 }
69}
70
// Lets the sender live in a config bag. `StoreReplace` means a newly stored sender
// replaces any previously stored one.
impl Storable for DeferredSignerSender {
    type Storer = StoreReplace<Self>;
}
74
/// A [`SignMessage`] whose concrete signer can be supplied after construction.
///
/// [`DeferredSigner::new`] also returns a [`DeferredSignerSender`] for delivering that signer.
/// The signer is resolved on first use. If nothing has been sent by then, [`NoOpSigner`]
/// is used from that point on.
#[derive(Debug)]
pub struct DeferredSigner {
    // Receiving end of the channel. It is taken (left as `None`) the first time a signer is needed.
    rx: Option<Mutex<mpsc::Receiver<Box<dyn SignMessage + Send + Sync>>>>,
    // The resolved signer. `None` until first use.
    signer: Option<Box<dyn SignMessage + Send + Sync>>,
}
95
96impl DeferredSigner {
97 pub fn new() -> (Self, DeferredSignerSender) {
98 let (tx, rx) = mpsc::channel();
99 (
100 Self {
101 rx: Some(Mutex::new(rx)),
102 signer: None,
103 },
104 DeferredSignerSender::new(tx),
105 )
106 }
107
108 fn acquire(&mut self) -> &mut (dyn SignMessage + Send + Sync) {
109 if self.signer.is_some() {
111 return self.signer.as_mut().unwrap().as_mut();
112 } else {
113 self.signer = Some(
114 self.rx
115 .take()
116 .expect("only taken once")
117 .lock()
118 .unwrap()
119 .try_recv()
120 .ok()
121 .unwrap_or_else(|| Box::new(NoOpSigner {}) as _),
128 );
129 self.acquire()
130 }
131 }
132}
133
134impl SignMessage for DeferredSigner {
135 fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
136 self.acquire().sign(message)
137 }
138
139 fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
140 self.acquire().sign_empty()
141 }
142}
143
144#[derive(Debug)]
145pub struct NoOpSigner {}
146impl SignMessage for NoOpSigner {
147 fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
148 Ok(message)
149 }
150
151 fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
152 None
153 }
154}
155
/// Converts a value into an event stream [`Message`].
pub trait MarshallMessage: fmt::Debug {
    /// The type of value this marshaller accepts.
    type Input;

    /// Encodes `input` as a [`Message`], or returns an [`Error`] if it cannot be encoded.
    fn marshall(&self, input: Self::Input) -> Result<Message, Error>;
}
163
/// What a successfully unmarshalled [`Message`] contained: an event or an error.
#[derive(Debug)]
pub enum UnmarshalledMessage<T, E> {
    /// The message carried an event of type `T`.
    Event(T),
    /// The message carried an error of type `E`.
    Error(E),
}
170
/// Interprets an event stream [`Message`] as either an event or an error value.
pub trait UnmarshallMessage: fmt::Debug {
    /// Event type produced for event messages.
    type Output;
    /// Error type produced for error messages.
    type Error;

    /// Unmarshalls `message`.
    ///
    /// The outer `Err` ([`Error`]) signals that the message could not be interpreted.
    /// An error *carried by* the message is returned as `Ok(UnmarshalledMessage::Error(_))`.
    fn unmarshall(
        &self,
        message: &Message,
    ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, Error>;
}
183
/// Reads a fixed-width value of type `$size_typ` from `$buf` with `$read_fn` and wraps it in
/// `HeaderValue::$typ`. Evaluates to `Err(InvalidHeaderValue)` if too few bytes remain.
macro_rules! read_value {
    ($buf:ident, $typ:ident, $size_typ:ident, $read_fn:ident) => {
        if $buf.remaining() < size_of::<$size_typ>() {
            Err(ErrorKind::InvalidHeaderValue.into())
        } else {
            Ok(HeaderValue::$typ($buf.$read_fn()))
        }
    };
}
193
194fn read_header_value_from<B: Buf>(mut buffer: B) -> Result<HeaderValue, Error> {
195 let value_type = buffer.get_u8();
196 match value_type {
197 TYPE_TRUE => Ok(HeaderValue::Bool(true)),
198 TYPE_FALSE => Ok(HeaderValue::Bool(false)),
199 TYPE_BYTE => read_value!(buffer, Byte, i8, get_i8),
200 TYPE_INT16 => read_value!(buffer, Int16, i16, get_i16),
201 TYPE_INT32 => read_value!(buffer, Int32, i32, get_i32),
202 TYPE_INT64 => read_value!(buffer, Int64, i64, get_i64),
203 TYPE_BYTE_ARRAY | TYPE_STRING => {
204 if buffer.remaining() > size_of::<u16>() {
205 let len = buffer.get_u16() as usize;
206 if buffer.remaining() < len {
207 return Err(ErrorKind::InvalidHeaderValue.into());
208 }
209 let bytes = buffer.copy_to_bytes(len);
210 if value_type == TYPE_STRING {
211 Ok(HeaderValue::String(
212 bytes.try_into().map_err(|_| ErrorKind::InvalidUtf8String)?,
213 ))
214 } else {
215 Ok(HeaderValue::ByteArray(bytes))
216 }
217 } else {
218 Err(ErrorKind::InvalidHeaderValue.into())
219 }
220 }
221 TYPE_TIMESTAMP => {
222 if buffer.remaining() >= size_of::<i64>() {
223 let epoch_millis = buffer.get_i64();
224 Ok(HeaderValue::Timestamp(DateTime::from_millis(epoch_millis)))
225 } else {
226 Err(ErrorKind::InvalidHeaderValue.into())
227 }
228 }
229 TYPE_UUID => read_value!(buffer, Uuid, u128, get_u128),
230 _ => Err(ErrorKind::InvalidHeaderValueType(value_type).into()),
231 }
232}
233
234fn write_header_value_to<B: BufMut>(value: &HeaderValue, mut buffer: B) -> Result<(), Error> {
235 use HeaderValue::*;
236 match value {
237 Bool(val) => buffer.put_u8(if *val { TYPE_TRUE } else { TYPE_FALSE }),
238 Byte(val) => {
239 buffer.put_u8(TYPE_BYTE);
240 buffer.put_i8(*val);
241 }
242 Int16(val) => {
243 buffer.put_u8(TYPE_INT16);
244 buffer.put_i16(*val);
245 }
246 Int32(val) => {
247 buffer.put_u8(TYPE_INT32);
248 buffer.put_i32(*val);
249 }
250 Int64(val) => {
251 buffer.put_u8(TYPE_INT64);
252 buffer.put_i64(*val);
253 }
254 ByteArray(val) => {
255 buffer.put_u8(TYPE_BYTE_ARRAY);
256 buffer.put_u16(checked(val.len(), ErrorKind::HeaderValueTooLong.into())?);
257 buffer.put_slice(&val[..]);
258 }
259 String(val) => {
260 buffer.put_u8(TYPE_STRING);
261 buffer.put_u16(checked(
262 val.as_bytes().len(),
263 ErrorKind::HeaderValueTooLong.into(),
264 )?);
265 buffer.put_slice(&val.as_bytes()[..]);
266 }
267 Timestamp(time) => {
268 buffer.put_u8(TYPE_TIMESTAMP);
269 buffer.put_i64(
270 time.to_millis()
271 .map_err(|_| ErrorKind::TimestampValueTooLarge(*time))?,
272 );
273 }
274 Uuid(val) => {
275 buffer.put_u8(TYPE_UUID);
276 buffer.put_u128(*val);
277 }
278 _ => {
279 panic!("matched on unexpected variant in `aws_smithy_types::event_stream::HeaderValue`")
280 }
281 }
282 Ok(())
283}
284
285fn read_header_from<B: Buf>(mut buffer: B) -> Result<(Header, usize), Error> {
287 if buffer.remaining() < MIN_HEADER_LEN {
288 return Err(ErrorKind::InvalidHeadersLength.into());
289 }
290
291 let mut counting_buf = CountBuf::new(&mut buffer);
292 let name_len = counting_buf.get_u8();
293 if name_len as usize >= counting_buf.remaining() {
294 return Err(ErrorKind::InvalidHeaderNameLength.into());
295 }
296
297 let name: StrBytes = counting_buf
298 .copy_to_bytes(name_len as usize)
299 .try_into()
300 .map_err(|_| ErrorKind::InvalidUtf8String)?;
301 let value = read_header_value_from(&mut counting_buf)?;
302 Ok((Header::new(name, value), counting_buf.into_count()))
303}
304
305fn write_header_to<B: BufMut>(header: &Header, mut buffer: B) -> Result<(), Error> {
307 if header.name().as_bytes().len() > MAX_HEADER_NAME_LEN {
308 return Err(ErrorKind::InvalidHeaderNameLength.into());
309 }
310
311 buffer.put_u8(u8::try_from(header.name().as_bytes().len()).expect("bounds check above"));
312 buffer.put_slice(&header.name().as_bytes()[..]);
313 write_header_value_to(header.value(), buffer)
314}
315
316pub fn write_headers_to<B: BufMut>(headers: &[Header], mut buffer: B) -> Result<(), Error> {
318 for header in headers {
319 write_header_to(header, &mut buffer)?;
320 }
321 Ok(())
322}
323
/// Reads and validates the prelude: total length, headers length and prelude CRC.
///
/// Returns `(total_len, header_len)`. The caller is expected to have checked that at
/// least [`PRELUDE_LENGTH_BYTES_USIZE`] bytes are available. The `get_u32` reads here
/// assume that and may panic otherwise.
///
/// # Errors
/// - `InvalidMessageLength` if the declared total length exceeds the available bytes, or is
///   too small to hold the prelude and message CRC.
/// - `PreludeChecksumMismatch` if the computed CRC differs from the encoded one.
/// - `InvalidHeadersLength` if the headers length is impossible for this message.
fn read_prelude_from<B: Buf>(mut buffer: B) -> Result<(u32, u32), Error> {
    // The two length fields are read through the CRC wrapper. The prelude CRC itself is
    // read from `buffer` directly, so it is not part of its own checksum.
    let mut crc_buffer = CrcBuf::new(&mut buffer);

    let total_len = crc_buffer.get_u32();
    // The bytes still available plus the 4 just consumed must cover the declared total.
    if crc_buffer.remaining() + size_of::<u32>() < total_len as usize {
        return Err(ErrorKind::InvalidMessageLength.into());
    }

    let header_len = crc_buffer.get_u32();
    let (expected_crc, prelude_crc) = (crc_buffer.into_crc(), buffer.get_u32());
    if expected_crc != prelude_crc {
        return Err(ErrorKind::PreludeChecksumMismatch(expected_crc, prelude_crc).into());
    }
    // A single header needs at least MIN_HEADER_LEN (2) bytes, so a length of 1 is never
    // valid. Headers must also fit in what remains after the prelude and message CRC.
    if header_len == 1 || header_len > max_header_len(total_len)? {
        return Err(ErrorKind::InvalidHeadersLength.into());
    }
    Ok((total_len, header_len))
}
346
/// Reads one complete event stream message from `buffer`.
///
/// The bytes consumed are, in order:
/// 1. the prelude (total length, headers length, prelude CRC);
/// 2. the headers;
/// 3. the payload;
/// 4. a trailing message CRC.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the buffer is shorter than the prelude or than the declared total length;
/// - either checksum does not match;
/// - a header is malformed or overruns the declared headers length.
pub fn read_message_from<B: Buf>(mut buffer: B) -> Result<Message, Error> {
    if buffer.remaining() < PRELUDE_LENGTH_BYTES_USIZE {
        return Err(ErrorKind::InvalidMessageLength.into());
    }

    // Everything before the trailing message CRC, including the prelude CRC, is read
    // through this wrapper so that it is covered by the message checksum.
    let mut crc_buffer = CrcBuf::new(&mut buffer);
    let (total_len, header_len) = read_prelude_from(&mut crc_buffer)?;

    // Bytes expected after the prelude: headers, payload and message CRC.
    let remaining_len = total_len
        .checked_sub(PRELUDE_LENGTH_BYTES)
        .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?;
    if crc_buffer.remaining() < remaining_len as usize {
        return Err(ErrorKind::InvalidMessageLength.into());
    }

    let mut header_bytes_read = 0;
    let mut headers = Vec::new();
    // Read headers until exactly `header_len` bytes are consumed. A header that runs past
    // that boundary is rejected.
    while header_bytes_read < header_len as usize {
        let (header, bytes_read) = read_header_from(&mut crc_buffer)?;
        header_bytes_read += bytes_read;
        if header_bytes_read > header_len as usize {
            return Err(ErrorKind::InvalidHeaderValue.into());
        }
        headers.push(header);
    }

    let payload_len = payload_len(total_len, header_len)?;
    let payload = crc_buffer.copy_to_bytes(payload_len as usize);

    // The message CRC is read from the underlying buffer so it is excluded from the checksum.
    let expected_crc = crc_buffer.into_crc();
    let message_crc = buffer.get_u32();
    if expected_crc != message_crc {
        return Err(ErrorKind::MessageChecksumMismatch(expected_crc, message_crc).into());
    }

    Ok(Message::new_from_parts(headers, payload))
}
390
391pub fn write_message_to(message: &Message, buffer: &mut dyn BufMut) -> Result<(), Error> {
393 let mut headers = Vec::new();
394 for header in message.headers() {
395 write_header_to(header, &mut headers)?;
396 }
397
398 let headers_len = checked(headers.len(), ErrorKind::HeadersTooLong.into())?;
399 let payload_len = checked(message.payload().len(), ErrorKind::PayloadTooLong.into())?;
400 let message_len = [
401 PRELUDE_LENGTH_BYTES,
402 headers_len,
403 payload_len,
404 MESSAGE_CRC_LENGTH_BYTES,
405 ]
406 .iter()
407 .try_fold(0u32, |acc, v| {
408 acc.checked_add(*v)
409 .ok_or_else(|| Error::from(ErrorKind::MessageTooLong))
410 })?;
411
412 let mut crc_buffer = CrcBufMut::new(buffer);
413 crc_buffer.put_u32(message_len);
414 crc_buffer.put_u32(headers_len);
415 crc_buffer.put_crc();
416 crc_buffer.put(&headers[..]);
417 crc_buffer.put(&message.payload()[..]);
418 crc_buffer.put_crc();
419 Ok(())
420}
421
422fn checked<T: TryFrom<U>, U>(from: U, err: Error) -> Result<T, Error> {
423 T::try_from(from).map_err(|_| err)
424}
425
426fn max_header_len(total_len: u32) -> Result<u32, Error> {
427 total_len
428 .checked_sub(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES)
429 .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))
430}
431
432fn payload_len(total_len: u32, header_len: u32) -> Result<u32, Error> {
433 total_len
434 .checked_sub(
435 header_len
436 .checked_add(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES)
437 .ok_or_else(|| Error::from(ErrorKind::InvalidHeadersLength))?,
438 )
439 .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))
440}
441
// Tests for whole-message decoding (`read_message_from`) and encoding (`write_message_to`).
#[cfg(test)]
mod message_tests {
    use super::read_message_from;
    use crate::error::ErrorKind;
    use crate::frame::{write_message_to, Header, HeaderValue, Message};
    use aws_smithy_types::DateTime;
    use bytes::Bytes;

    // Decodes `$bytes` and asserts that decoding fails with an error whose kind matches `$err`.
    macro_rules! read_message_expect_err {
        ($bytes:expr, $err:pat) => {
            let result = read_message_from(&mut Bytes::from_static($bytes));
            let result = result.as_ref();
            assert!(result.is_err(), "Expected error, got {:?}", result);
            assert!(
                matches!(result.err().unwrap().kind(), $err),
                "Expected {}, got {:?}",
                stringify!($err),
                result
            );
        };
    }

    // Each fixture is malformed in a specific way, and decoding must report the matching kind.
    #[test]
    fn invalid_messages() {
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_string_value_length"),
            ErrorKind::InvalidHeaderValue
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_string_length_cut_off"),
            ErrorKind::InvalidHeaderValue
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_value_type"),
            ErrorKind::InvalidHeaderValueType(0x60)
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_name_length"),
            ErrorKind::InvalidHeaderNameLength
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_headers_length"),
            ErrorKind::InvalidHeadersLength
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_prelude_checksum"),
            ErrorKind::PreludeChecksumMismatch(0x8BB495FB, 0xDEADBEEF)
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_message_checksum"),
            ErrorKind::MessageChecksumMismatch(0x01a05860, 0xDEADBEEF)
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_name_length_too_long"),
            ErrorKind::InvalidUtf8String
        );
    }

    #[test]
    fn read_message_no_headers() {
        // total length 0x1D (29) = 12 prelude + 0 headers + 13 payload + 4 message CRC;
        // headers length 0; prelude CRC; payload "{'foo':'bar'}"; message CRC.
        let data: &'static [u8] = &[
            0x00, 0x00, 0x00, 0x1D, 0x00, 0x00, 0x00, 0x00, 0xfd, 0x52, 0x8c, 0x5a, 0x7b, 0x27,
            0x66, 0x6f, 0x6f, 0x27, 0x3a, 0x27, 0x62, 0x61, 0x72, 0x27, 0x7d, 0xc3, 0x65, 0x39,
            0x36,
        ];

        let result = read_message_from(&mut Bytes::from_static(data)).unwrap();
        assert_eq!(result.headers(), Vec::new());

        let expected_payload = b"{'foo':'bar'}";
        assert_eq!(expected_payload, result.payload().as_ref());
    }

    #[test]
    fn read_message_one_header() {
        // total length 0x3D (61); headers length 0x20 (32); prelude CRC; then one header:
        // name length 0x0C, "content-type", type 0x07 (string), length 0x0010,
        // "application/json". Then payload "{'foo':'bar'}" and the message CRC.
        let data: &'static [u8] = &[
            0x00, 0x00, 0x00, 0x3D, 0x00, 0x00, 0x00, 0x20, 0x07, 0xFD, 0x83, 0x96, 0x0C, b'c',
            b'o', b'n', b't', b'e', b'n', b't', b'-', b't', b'y', b'p', b'e', 0x07, 0x00, 0x10,
            b'a', b'p', b'p', b'l', b'i', b'c', b'a', b't', b'i', b'o', b'n', b'/', b'j', b's',
            b'o', b'n', 0x7b, 0x27, 0x66, 0x6f, 0x6f, 0x27, 0x3a, 0x27, 0x62, 0x61, 0x72, 0x27,
            0x7d, 0x8D, 0x9C, 0x08, 0xB1,
        ];

        let result = read_message_from(&mut Bytes::from_static(data)).unwrap();
        assert_eq!(
            result.headers(),
            vec![Header::new(
                "content-type",
                HeaderValue::String("application/json".into())
            )]
        );

        let expected_payload = b"{'foo':'bar'}";
        assert_eq!(expected_payload, result.payload().as_ref());
    }

    // The fixture contains one header of every value type plus a payload.
    #[test]
    fn read_all_headers_and_payload() {
        let message = include_bytes!("../test_data/valid_with_all_headers_and_payload");
        let result = read_message_from(&mut Bytes::from_static(message)).unwrap();
        assert_eq!(
            result.headers(),
            vec![
                Header::new("true", HeaderValue::Bool(true)),
                Header::new("false", HeaderValue::Bool(false)),
                Header::new("byte", HeaderValue::Byte(50)),
                Header::new("short", HeaderValue::Int16(20_000)),
                Header::new("int", HeaderValue::Int32(500_000)),
                Header::new("long", HeaderValue::Int64(50_000_000_000)),
                Header::new(
                    "bytes",
                    HeaderValue::ByteArray(Bytes::from(&b"some bytes"[..]))
                ),
                Header::new("str", HeaderValue::String("some str".into())),
                Header::new(
                    "time",
                    HeaderValue::Timestamp(DateTime::from_secs(5_000_000))
                ),
                Header::new(
                    "uuid",
                    HeaderValue::Uuid(0xb79bc914_de21_4e13_b8b2_bc47e85b7f0b)
                ),
            ]
        );

        assert_eq!(b"some payload", result.payload().as_ref());
    }

    // Encoding must reproduce the fixture byte-for-byte, and decoding the encoded bytes must
    // return the same headers and payload.
    #[test]
    fn round_trip_all_headers_payload() {
        let message = Message::new(&b"some payload"[..])
            .add_header(Header::new("true", HeaderValue::Bool(true)))
            .add_header(Header::new("false", HeaderValue::Bool(false)))
            .add_header(Header::new("byte", HeaderValue::Byte(50)))
            .add_header(Header::new("short", HeaderValue::Int16(20_000)))
            .add_header(Header::new("int", HeaderValue::Int32(500_000)))
            .add_header(Header::new("long", HeaderValue::Int64(50_000_000_000)))
            .add_header(Header::new(
                "bytes",
                HeaderValue::ByteArray((&b"some bytes"[..]).into()),
            ))
            .add_header(Header::new("str", HeaderValue::String("some str".into())))
            .add_header(Header::new(
                "time",
                HeaderValue::Timestamp(DateTime::from_secs(5_000_000)),
            ))
            .add_header(Header::new(
                "uuid",
                HeaderValue::Uuid(0xb79bc914_de21_4e13_b8b2_bc47e85b7f0b),
            ));

        let mut actual = Vec::new();
        write_message_to(&message, &mut actual).unwrap();

        let expected = include_bytes!("../test_data/valid_with_all_headers_and_payload").to_vec();
        assert_eq!(expected, actual);

        let result = read_message_from(&mut Bytes::from(actual)).unwrap();
        assert_eq!(message.headers(), result.headers());
        assert_eq!(message.payload().as_ref(), result.payload().as_ref());
    }
}
608
/// Result of one [`MessageFrameDecoder::decode_frame`] call.
#[derive(Debug)]
pub enum DecodedFrame {
    /// Not enough bytes yet for a full message. Append more and call again.
    Incomplete,
    /// A full message was decoded and its bytes consumed from the buffer.
    Complete(Message),
}
617
/// Incremental decoder for a stream of messages whose bytes arrive in arbitrary chunks.
///
/// Keeps a frame's prelude between calls, so the caller can keep appending to the same
/// buffer until the rest of the frame is available.
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct MessageFrameDecoder {
    // Prelude bytes of the frame currently being decoded (valid only when `prelude_read`).
    prelude: [u8; PRELUDE_LENGTH_BYTES_USIZE],
    // Whether `prelude` has been filled from the buffer for the current frame.
    prelude_read: bool,
}
625
626impl MessageFrameDecoder {
627 pub fn new() -> Self {
629 Default::default()
630 }
631
632 fn remaining_bytes_if_frame_available<B: Buf>(
637 &self,
638 buffer: &B,
639 ) -> Result<Option<usize>, Error> {
640 if self.prelude_read {
641 let remaining_len = (&self.prelude[..])
642 .get_u32()
643 .checked_sub(PRELUDE_LENGTH_BYTES)
644 .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?;
645 if buffer.remaining() >= remaining_len as usize {
646 return Ok(Some(remaining_len as usize));
647 }
648 }
649 Ok(None)
650 }
651
652 fn reset(&mut self) {
654 self.prelude_read = false;
655 self.prelude = [0u8; PRELUDE_LENGTH_BYTES_USIZE];
656 }
657
658 pub fn decode_frame<B: Buf>(&mut self, mut buffer: B) -> Result<DecodedFrame, Error> {
667 if !self.prelude_read && buffer.remaining() >= PRELUDE_LENGTH_BYTES_USIZE {
668 buffer.copy_to_slice(&mut self.prelude);
669 self.prelude_read = true;
670 }
671
672 if let Some(remaining_len) = self.remaining_bytes_if_frame_available(&buffer)? {
673 let mut message_buf = (&self.prelude[..]).chain(buffer.take(remaining_len));
674 let result = read_message_from(&mut message_buf).map(DecodedFrame::Complete);
675 self.reset();
676 return result;
677 }
678
679 Ok(DecodedFrame::Incomplete)
680 }
681}
682
// Tests for incremental decoding with `MessageFrameDecoder`.
#[cfg(test)]
mod message_frame_decoder_tests {
    use super::{DecodedFrame, MessageFrameDecoder};
    use crate::frame::read_message_from;
    use bytes::Bytes;
    use bytes_utils::SegmentedBuf;

    // Feeds a message one byte at a time. Every call before the last byte must report
    // `Incomplete`, and the last byte must complete the frame.
    #[test]
    fn single_streaming_message() {
        let message = include_bytes!("../test_data/valid_with_all_headers_and_payload");

        let mut decoder = MessageFrameDecoder::new();
        let mut segmented = SegmentedBuf::new();
        for i in 0..(message.len() - 1) {
            segmented.push(&message[i..(i + 1)]);
            if let DecodedFrame::Complete(_) = decoder.decode_frame(&mut segmented).unwrap() {
                panic!("incomplete frame shouldn't result in message");
            }
        }

        segmented.push(&message[(message.len() - 1)..]);
        match decoder.decode_frame(&mut segmented).unwrap() {
            DecodedFrame::Incomplete => panic!("frame should be complete now"),
            DecodedFrame::Complete(actual) => {
                let expected = read_message_from(&mut Bytes::from_static(message)).unwrap();
                assert_eq!(expected, actual);
            }
        }
    }

    // Streams three back-to-back messages in `chunk_size`-byte pieces and checks that each
    // is decoded exactly once, in order.
    fn multiple_streaming_messages_chunk_size(chunk_size: usize) {
        let message1 = include_bytes!("../test_data/valid_with_all_headers_and_payload");
        let message2 = include_bytes!("../test_data/valid_empty_payload");
        let message3 = include_bytes!("../test_data/valid_no_headers");
        let mut repeated = message1.to_vec();
        repeated.extend_from_slice(message2);
        repeated.extend_from_slice(message3);

        let mut decoder = MessageFrameDecoder::new();
        let mut segmented = SegmentedBuf::new();
        let mut decoded = Vec::new();
        for window in repeated.chunks(chunk_size) {
            segmented.push(window);
            match decoder.decode_frame(&mut segmented).unwrap() {
                DecodedFrame::Incomplete => {}
                DecodedFrame::Complete(message) => {
                    decoded.push(message);
                }
            }
        }

        let expected1 = read_message_from(&mut Bytes::from_static(message1)).unwrap();
        let expected2 = read_message_from(&mut Bytes::from_static(message2)).unwrap();
        let expected3 = read_message_from(&mut Bytes::from_static(message3)).unwrap();
        assert_eq!(3, decoded.len());
        assert_eq!(expected1, decoded[0]);
        assert_eq!(expected2, decoded[1]);
        assert_eq!(expected3, decoded[2]);
    }

    #[test]
    fn multiple_streaming_messages() {
        for chunk_size in 1..=11 {
            println!("chunk size: {}", chunk_size);
            multiple_streaming_messages_chunk_size(chunk_size);
        }
    }
}
751
// Tests for `DeferredSigner`: delegation to a sent signer, and the no-op fallback.
#[cfg(test)]
mod deferred_signer_tests {
    use crate::frame::{DeferredSigner, Header, HeaderValue, Message, SignMessage};
    use bytes::Bytes;

    // Compile-time check that `T` is `Send + Sync`; returns its argument unchanged.
    fn check_send_sync<T: Send + Sync>(value: T) -> T {
        value
    }

    // A message with no headers and an empty payload, used as signing input.
    fn empty_message() -> Message {
        Message::new(Bytes::new())
    }

    #[test]
    fn deferred_signer() {
        // Stamps each signed message with a header counting how many times it has signed.
        #[derive(Default, Debug)]
        struct TestSigner {
            call_num: i32,
        }
        impl SignMessage for TestSigner {
            fn sign(
                &mut self,
                message: Message,
            ) -> Result<Message, crate::frame::SignMessageError> {
                self.call_num += 1;
                let stamp = Header::new("call_num", HeaderValue::Int32(self.call_num));
                Ok(message.add_header(stamp))
            }

            fn sign_empty(&mut self) -> Option<Result<Message, crate::frame::SignMessageError>> {
                None
            }
        }

        let (mut signer, sender) = check_send_sync(DeferredSigner::new());

        sender.send(Box::<TestSigner>::default()).expect("success");

        // The same TestSigner instance must be reused, so the counter keeps increasing.
        for expected_call_num in 1..=2 {
            let signed = signer.sign(empty_message()).expect("success");
            assert_eq!(
                expected_call_num,
                signed.headers()[0].value().as_int32().unwrap()
            );
        }

        assert!(signer.sign_empty().is_none());
    }

    #[test]
    fn deferred_signer_defaults_to_noop_signer() {
        // Nothing is ever sent, so messages must pass through unchanged.
        let (mut signer, _sender) = DeferredSigner::new();
        let signed = signer.sign(empty_message()).unwrap();
        assert_eq!(empty_message(), signed);
        assert!(signer.sign_empty().is_none());
    }
}