1use std::io;
2use std::ops::Range;
3
4use super::base::Payload;
5use super::codec::Codec;
6use super::message::PlainMessage;
7use crate::enums::{ContentType, ProtocolVersion};
8use crate::error::{Error, InvalidMessage, PeerMisbehaved};
9use crate::msgs::codec;
10use crate::msgs::message::{MessageError, OpaqueMessage};
11use crate::record_layer::{Decrypted, RecordLayer};
12
13#[derive(Default)]
18pub struct MessageDeframer {
19 last_error: Option<Error>,
23
24 buf: Vec<u8>,
28
29 joining_hs: Option<HandshakePayloadMeta>,
31
32 used: usize,
34}
35
36impl MessageDeframer {
37 pub fn pop(&mut self, record_layer: &mut RecordLayer) -> Result<Option<Deframed>, Error> {
43 if let Some(last_err) = self.last_error.clone() {
44 return Err(last_err);
45 } else if self.used == 0 {
46 return Ok(None);
47 }
48
49 let expected_len = loop {
53 let start = match &self.joining_hs {
54 Some(meta) => {
55 match meta.expected_len {
56 Some(len) if len <= meta.payload.len() => break len,
58 _ if meta.quic => return Ok(None),
60 _ => meta.message.end,
62 }
63 }
64 None => 0,
65 };
66
67 let mut rd = codec::Reader::init(&self.buf[start..self.used]);
71 let m = match OpaqueMessage::read(&mut rd) {
72 Ok(m) => m,
73 Err(msg_err) => {
74 let err_kind = match msg_err {
75 MessageError::TooShortForHeader | MessageError::TooShortForLength => {
76 return Ok(None)
77 }
78 MessageError::InvalidEmptyPayload => InvalidMessage::InvalidEmptyPayload,
79 MessageError::MessageTooLarge => InvalidMessage::MessageTooLarge,
80 MessageError::InvalidContentType => InvalidMessage::InvalidContentType,
81 MessageError::UnknownProtocolVersion => {
82 InvalidMessage::UnknownProtocolVersion
83 }
84 };
85
86 return Err(self.set_err(err_kind));
87 }
88 };
89
90 let end = start + rd.used();
93 if m.typ == ContentType::ChangeCipherSpec && self.joining_hs.is_none() {
94 self.discard(end);
96 return Ok(Some(Deframed {
97 want_close_before_decrypt: false,
98 aligned: true,
99 trial_decryption_finished: false,
100 message: m.into_plain_message(),
101 }));
102 }
103
104 let msg = match record_layer.decrypt_incoming(m) {
106 Ok(Some(decrypted)) => {
107 let Decrypted {
108 want_close_before_decrypt,
109 plaintext,
110 } = decrypted;
111 debug_assert!(!want_close_before_decrypt);
112 plaintext
113 }
114 Ok(None) if self.joining_hs.is_some() => {
117 return Err(self.set_err(
118 PeerMisbehaved::RejectedEarlyDataInterleavedWithHandshakeMessage,
119 ));
120 }
121 Ok(None) => {
122 self.discard(end);
123 continue;
124 }
125 Err(e) => return Err(e),
126 };
127
128 if self.joining_hs.is_some() && msg.typ != ContentType::Handshake {
129 return Err(self.set_err(PeerMisbehaved::MessageInterleavedWithHandshakeMessage));
134 }
135
136 if msg.typ != ContentType::Handshake {
138 let end = start + rd.used();
139 self.discard(end);
140 return Ok(Some(Deframed {
141 want_close_before_decrypt: false,
142 aligned: true,
143 trial_decryption_finished: false,
144 message: msg,
145 }));
146 }
147
148 match self.append_hs(msg.version, &msg.payload.0, end, false)? {
151 HandshakePayloadState::Blocked => return Ok(None),
152 HandshakePayloadState::Complete(len) => break len,
153 HandshakePayloadState::Continue => continue,
154 }
155 };
156
157 let meta = self.joining_hs.as_mut().unwrap(); let message = PlainMessage {
161 typ: ContentType::Handshake,
162 version: meta.version,
163 payload: Payload::new(&self.buf[meta.payload.start..meta.payload.start + expected_len]),
164 };
165
166 if meta.payload.len() > expected_len {
168 meta.payload.start += expected_len;
172 meta.expected_len = payload_size(&self.buf[meta.payload.start..meta.payload.end])?;
173 } else {
174 let end = meta.message.end;
177 self.joining_hs = None;
178 self.discard(end);
179 }
180
181 Ok(Some(Deframed {
182 want_close_before_decrypt: false,
183 aligned: self.joining_hs.is_none(),
184 trial_decryption_finished: true,
185 message,
186 }))
187 }
188
189 fn set_err(&mut self, err: impl Into<Error>) -> Error {
193 let err = err.into();
194 self.last_error = Some(err.clone());
195 err
196 }
197
198 #[cfg(feature = "quic")]
200 pub fn push(&mut self, version: ProtocolVersion, payload: &[u8]) -> Result<(), Error> {
201 if self.used > 0 && self.joining_hs.is_none() {
202 return Err(Error::General(
203 "cannot push QUIC messages into unrelated connection".into(),
204 ));
205 } else if let Err(err) = self.prepare_read() {
206 return Err(Error::General(err.into()));
207 }
208
209 let end = self.used + payload.len();
210 self.append_hs(version, payload, end, true)?;
211 self.used = end;
212 Ok(())
213 }
214
215 fn append_hs(
219 &mut self,
220 version: ProtocolVersion,
221 payload: &[u8],
222 end: usize,
223 quic: bool,
224 ) -> Result<HandshakePayloadState, Error> {
225 let meta = match &mut self.joining_hs {
226 Some(meta) => {
227 debug_assert_eq!(meta.quic, quic);
228
229 let dst = &mut self.buf[meta.payload.end..meta.payload.end + payload.len()];
233 dst.copy_from_slice(payload);
234 meta.message.end = end;
235 meta.payload.end += payload.len();
236
237 if meta.expected_len.is_none() {
239 meta.expected_len =
240 payload_size(&self.buf[meta.payload.start..meta.payload.end])?;
241 }
242
243 meta
244 }
245 None => {
246 let expected_len = payload_size(payload)?;
250 let dst = &mut self.buf[..payload.len()];
251 dst.copy_from_slice(payload);
252 self.joining_hs
253 .insert(HandshakePayloadMeta {
254 message: Range { start: 0, end },
255 payload: Range {
256 start: 0,
257 end: payload.len(),
258 },
259 version,
260 expected_len,
261 quic,
262 })
263 }
264 };
265
266 Ok(match meta.expected_len {
267 Some(len) if len <= meta.payload.len() => HandshakePayloadState::Complete(len),
268 _ => match self.used > meta.message.end {
269 true => HandshakePayloadState::Continue,
270 false => HandshakePayloadState::Blocked,
271 },
272 })
273 }
274
275 #[allow(clippy::comparison_chain)]
277 pub fn read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> {
278 if let Err(err) = self.prepare_read() {
279 return Err(io::Error::new(io::ErrorKind::InvalidData, err));
280 }
281
282 let new_bytes = rd.read(&mut self.buf[self.used..])?;
287 self.used += new_bytes;
288 Ok(new_bytes)
289 }
290
291 fn prepare_read(&mut self) -> Result<(), &'static str> {
293 let allow_max = match self.joining_hs {
301 Some(_) => MAX_HANDSHAKE_SIZE as usize,
302 None => OpaqueMessage::MAX_WIRE_SIZE,
303 };
304
305 if self.used >= allow_max {
306 return Err("message buffer full");
307 }
308
309 let need_capacity = Ord::min(allow_max, self.used + READ_SIZE);
315 if need_capacity > self.buf.len() {
316 self.buf.resize(need_capacity, 0);
317 } else if self.used == 0 || self.buf.len() > allow_max {
318 self.buf.resize(need_capacity, 0);
319 self.buf.shrink_to(need_capacity);
320 }
321
322 Ok(())
323 }
324
325 pub fn has_pending(&self) -> bool {
329 self.used > 0
330 }
331
332 fn discard(&mut self, taken: usize) {
334 #[allow(clippy::comparison_chain)]
335 if taken < self.used {
336 self.buf
350 .copy_within(taken..self.used, 0);
351 self.used -= taken;
352 } else if taken == self.used {
353 self.used = 0;
354 }
355 }
356}
357
/// Outcome of appending handshake bytes in `MessageDeframer::append_hs()`.
enum HandshakePayloadState {
    /// The message is incomplete and no buffered bytes remain past the consumed region;
    /// more input is needed.
    Blocked,
    /// A whole handshake message is buffered; the value is its length in bytes,
    /// including the handshake header.
    Complete(usize),
    /// The message is incomplete, but further buffered bytes remain to be parsed.
    Continue,
}
366
367struct HandshakePayloadMeta {
368 message: Range<usize>,
372 payload: Range<usize>,
374 version: ProtocolVersion,
376 expected_len: Option<usize>,
381 quic: bool,
387}
388
389fn payload_size(buf: &[u8]) -> Result<Option<usize>, Error> {
395 if buf.len() < HEADER_SIZE {
396 return Ok(None);
397 }
398
399 let (header, _) = buf.split_at(HEADER_SIZE);
400 match codec::u24::read_bytes(&header[1..]) {
401 Ok(len) if len.0 > MAX_HANDSHAKE_SIZE => Err(Error::InvalidMessage(
402 InvalidMessage::HandshakePayloadTooLarge,
403 )),
404 Ok(len) => Ok(Some(HEADER_SIZE + usize::from(len))),
405 _ => Ok(None),
406 }
407}
408
409#[derive(Debug)]
410pub struct Deframed {
411 pub want_close_before_decrypt: bool,
412 pub aligned: bool,
413 pub trial_decryption_finished: bool,
414 pub message: PlainMessage,
415}
416
/// Errors specific to deframing.
#[derive(Debug)]
pub enum DeframerError {
    /// A handshake payload advertised a size that is too large.
    HandshakePayloadSizeTooLarge,
}
421
/// Size of a handshake message header: one leading byte plus a three-byte length.
const HEADER_SIZE: usize = 1 + 3;

/// Largest handshake body length accepted by `payload_size()`; also the buffer limit
/// while a handshake payload is being joined.
const MAX_HANDSHAKE_SIZE: u32 = 0xffff;

/// Number of bytes of spare room `prepare_read()` aims to provide for each read.
const READ_SIZE: usize = 4096;
430
#[cfg(test)]
mod tests {
    use super::MessageDeframer;
    use crate::msgs::message::{Message, OpaqueMessage};
    use crate::record_layer::RecordLayer;
    use crate::{ContentType, Error, InvalidMessage};

    use std::io;

    const FIRST_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.1.bin");
    const SECOND_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.2.bin");

    const EMPTY_APPLICATIONDATA_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-empty-applicationdata.bin");

    const INVALID_EMPTY_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-invalid-empty.bin");
    const INVALID_CONTENTTYPE_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-invalid-contenttype.bin");
    const INVALID_VERSION_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-invalid-version.bin");
    const INVALID_LENGTH_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-invalid-length.bin");

    /// Feeds `bytes` to the deframer with a single `read()` call.
    fn input_bytes(d: &mut MessageDeframer, bytes: &[u8]) -> io::Result<usize> {
        d.read(&mut io::Cursor::new(bytes))
    }

    /// Feeds `bytes1` followed by `bytes2` to the deframer with a single `read()` call.
    fn input_bytes_concat(
        d: &mut MessageDeframer,
        bytes1: &[u8],
        bytes2: &[u8],
    ) -> io::Result<usize> {
        let joined = [bytes1, bytes2].concat();
        d.read(&mut io::Cursor::new(joined.as_slice()))
    }

    /// A reader that scribbles over the destination buffer and then fails, once.
    struct ErrorRead {
        error: Option<io::Error>,
    }

    impl ErrorRead {
        fn new(error: io::Error) -> Self {
            Self { error: Some(error) }
        }
    }

    impl io::Read for ErrorRead {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            // Dirty the destination so the test notices if these bytes were kept.
            buf.iter_mut()
                .enumerate()
                .for_each(|(i, b)| *b = i as u8);

            Err(self.error.take().unwrap())
        }
    }

    /// Makes one `read()` call fail and checks that the failure is reported.
    fn input_error(d: &mut MessageDeframer) {
        let mut rd = ErrorRead::new(io::Error::from(io::ErrorKind::TimedOut));
        d.read(&mut rd)
            .expect_err("error not propagated");
    }

    /// Feeds `bytes` one byte per `read()` call.
    fn input_whole_incremental(d: &mut MessageDeframer, bytes: &[u8]) {
        let before = d.used;

        for single in bytes.chunks(1) {
            assert_len(1, input_bytes(d, single));
            assert!(d.has_pending());
        }

        assert_eq!(before + bytes.len(), d.used);
    }

    /// Asserts that a read succeeded and transferred exactly `want` bytes.
    fn assert_len(want: usize, got: io::Result<usize>) {
        match got {
            Ok(gotval) => assert_eq!(gotval, want),
            Err(_) => panic!("read failed, expected {:?} bytes", want),
        }
    }

    /// Pops one message, checks its content type and that it parses as a `Message`.
    fn pop_expecting(d: &mut MessageDeframer, rl: &mut RecordLayer, typ: ContentType) {
        let m = d.pop(rl).unwrap().unwrap().message;
        assert_eq!(m.typ, typ);
        Message::try_from(m).unwrap();
    }

    fn pop_first(d: &mut MessageDeframer, rl: &mut RecordLayer) {
        pop_expecting(d, rl, ContentType::Handshake);
    }

    fn pop_second(d: &mut MessageDeframer, rl: &mut RecordLayer) {
        pop_expecting(d, rl, ContentType::Alert);
    }

    /// Builds a deframer that has taken in all of `bytes` with one read.
    fn deframer_with(bytes: &[u8]) -> MessageDeframer {
        let mut d = MessageDeframer::default();
        assert_len(bytes.len(), input_bytes(&mut d, bytes));
        d
    }

    /// Asserts that everything was consumed and no error was recorded.
    fn assert_drained(d: &MessageDeframer) {
        assert!(!d.has_pending());
        assert!(d.last_error.is_none());
    }

    #[test]
    fn check_incremental() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        input_whole_incremental(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert_drained(&d);
    }

    #[test]
    fn check_incremental_2() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        input_whole_incremental(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());
        input_whole_incremental(&mut d, SECOND_MESSAGE);
        assert!(d.has_pending());

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert!(d.has_pending());
        pop_second(&mut d, &mut rl);
        assert_drained(&d);
    }

    #[test]
    fn check_whole() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
        assert!(d.has_pending());

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert_drained(&d);
    }

    #[test]
    fn check_whole_2() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
        assert_len(SECOND_MESSAGE.len(), input_bytes(&mut d, SECOND_MESSAGE));

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        pop_second(&mut d, &mut rl);
        assert_drained(&d);
    }

    #[test]
    fn test_two_in_one_read() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        assert_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            input_bytes_concat(&mut d, FIRST_MESSAGE, SECOND_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        pop_second(&mut d, &mut rl);
        assert_drained(&d);
    }

    #[test]
    fn test_two_in_one_read_shortest_first() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        assert_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            input_bytes_concat(&mut d, SECOND_MESSAGE, FIRST_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        pop_second(&mut d, &mut rl);
        pop_first(&mut d, &mut rl);
        assert_drained(&d);
    }

    #[test]
    fn test_incremental_with_nonfatal_read_error() {
        let (head, tail) = FIRST_MESSAGE.split_at(3);

        let mut d = MessageDeframer::default();
        assert_len(3, input_bytes(&mut d, head));
        // A failed read in between must not disturb what is already buffered.
        input_error(&mut d);
        assert_len(FIRST_MESSAGE.len() - 3, input_bytes(&mut d, tail));

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert_drained(&d);
    }

    #[test]
    fn test_invalid_contenttype_errors() {
        let mut d = deframer_with(INVALID_CONTENTTYPE_MESSAGE);

        let mut rl = RecordLayer::new();
        assert_eq!(
            d.pop(&mut rl).unwrap_err(),
            Error::InvalidMessage(InvalidMessage::InvalidContentType)
        );
    }

    #[test]
    fn test_invalid_version_errors() {
        let mut d = deframer_with(INVALID_VERSION_MESSAGE);

        let mut rl = RecordLayer::new();
        assert_eq!(
            d.pop(&mut rl).unwrap_err(),
            Error::InvalidMessage(InvalidMessage::UnknownProtocolVersion)
        );
    }

    #[test]
    fn test_invalid_length_errors() {
        let mut d = deframer_with(INVALID_LENGTH_MESSAGE);

        let mut rl = RecordLayer::new();
        assert_eq!(
            d.pop(&mut rl).unwrap_err(),
            Error::InvalidMessage(InvalidMessage::MessageTooLarge)
        );
    }

    #[test]
    fn test_empty_applicationdata() {
        let mut d = deframer_with(EMPTY_APPLICATIONDATA_MESSAGE);

        let mut rl = RecordLayer::new();
        let m = d.pop(&mut rl).unwrap().unwrap().message;
        assert_eq!(m.typ, ContentType::ApplicationData);
        assert_eq!(m.payload.0.len(), 0);
        assert_drained(&d);
    }

    #[test]
    fn test_invalid_empty_errors() {
        let mut d = deframer_with(INVALID_EMPTY_MESSAGE);

        let mut rl = RecordLayer::new();
        // The error is sticky: a second pop reports the very same failure.
        for _ in 0..2 {
            assert_eq!(
                d.pop(&mut rl).unwrap_err(),
                Error::InvalidMessage(InvalidMessage::InvalidEmptyPayload)
            );
        }
    }

    #[test]
    fn test_limited_buffer() {
        const PAYLOAD_LEN: usize = 16_384;

        // Hand-built record: type byte, two version bytes, 16-bit length, zeroed payload.
        let mut message = Vec::with_capacity(16_389);
        message.push(0x17);
        message.extend(&[0x03, 0x04]);
        message.extend((PAYLOAD_LEN as u16).to_be_bytes());
        message.extend(&[0; PAYLOAD_LEN]);

        let mut d = MessageDeframer::default();

        // Each read is capped at 4096 bytes ...
        for _ in 0..4 {
            assert_len(4096, input_bytes(&mut d, &message));
        }
        // ... the fifth one only gets what is left below the wire-size limit ...
        assert_len(
            OpaqueMessage::MAX_WIRE_SIZE - 16_384,
            input_bytes(&mut d, &message),
        );
        // ... and after that the buffer is full.
        assert!(input_bytes(&mut d, &message).is_err());
    }
}