h2/frame/headers.rs

1use super::{util, StreamDependency, StreamId};
2use crate::ext::Protocol;
3use crate::frame::{Error, Frame, Head, Kind};
4use crate::hpack::{self, BytesStr};
5
6use http::header::{self, HeaderName, HeaderValue};
7use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
8
9use bytes::{BufMut, Bytes, BytesMut};
10
11use std::fmt;
12use std::io::Cursor;
13
/// Output buffer used when encoding frames: a `BytesMut` wrapped in
/// `bytes::buf::Limit`, so writes are capped. When `remaining_mut` runs out,
/// the header-block encoder stops and hands back a `Continuation` for the rest.
type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;
15
/// HEADERS frame.
///
/// This could be either a request or a response, or a trailers block (see
/// `Headers::trailers`).
#[derive(Eq, PartialEq)]
pub struct Headers {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The stream dependency information, if any. Only populated when a
    /// received frame carried the `PRIORITY` flag.
    stream_dep: Option<StreamDependency>,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: HeadersFlag,
}

/// Raw flag bits of a HEADERS frame (`END_STREAM`, `END_HEADERS`, `PADDED`,
/// `PRIORITY`).
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);

/// PUSH_PROMISE frame: sent on `stream_id` to reserve `promised_id`. It
/// carries the headers of the promised request.
#[derive(Eq, PartialEq)]
pub struct PushPromise {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The ID of the stream being reserved by this PushPromise.
    promised_id: StreamId,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: PushPromiseFlag,
}

/// Raw flag bits of a PUSH_PROMISE frame (`END_HEADERS`, `PADDED`).
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
54
/// The part of an encoded header block that did not fit into the preceding
/// HEADERS or PUSH_PROMISE frame. It is written out as one or more
/// CONTINUATION frames.
#[derive(Debug)]
pub struct Continuation {
    /// Stream ID of continuation frame
    stream_id: StreamId,

    /// The HPACK bytes that still have to be written.
    header_block: EncodingHeaderBlock,
}

/// HTTP/2 pseudo-header fields of a request (`:method`, `:scheme`,
/// `:authority`, `:path`, `:protocol`) or a response (`:status`).
// TODO: These fields shouldn't be `pub`
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Pseudo {
    // Request
    pub method: Option<Method>,
    pub scheme: Option<BytesStr>,
    pub authority: Option<BytesStr>,
    pub path: Option<BytesStr>,
    pub protocol: Option<Protocol>,

    // Response
    pub status: Option<StatusCode>,
}

/// Iterator handed to the HPACK encoder. It yields the pseudo headers first,
/// then the regular header fields.
#[derive(Debug)]
pub struct Iter {
    /// Pseudo headers; set to `None` once all of them have been yielded.
    pseudo: Option<Pseudo>,

    /// Header fields
    fields: header::IntoIter<HeaderValue>,
}
85
/// Decoded (or not-yet-encoded) contents of a HEADERS or PUSH_PROMISE header
/// block.
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock {
    /// The decoded header fields
    fields: HeaderMap,

    /// Precomputed size of the regular header fields (each counted with
    /// `decoded_header_size`), for perf reasons. Pseudo headers are not
    /// included here.
    field_size: usize,

    /// Set to true if decoding went over the max header list size.
    is_over_size: bool,

    /// Pseudo headers, these are broken out as they must be sent as part of the
    /// headers frame.
    pseudo: Pseudo,
}

/// A header block after HPACK encoding, waiting to be written into frames.
#[derive(Debug)]
struct EncodingHeaderBlock {
    /// HPACK-encoded bytes that have not been written to a frame yet.
    hpack: Bytes,
}
106
/// HEADERS: this is the last frame the sender will send on the stream.
const END_STREAM: u8 = 0b0000_0001;
/// The header block is complete; no CONTINUATION frames follow.
const END_HEADERS: u8 = 0b0000_0100;
/// The payload starts with a pad-length byte and ends with padding.
const PADDED: u8 = 0b0000_1000;
/// HEADERS: the payload carries a 5-byte stream dependency.
const PRIORITY: u8 = 0b0010_0000;
/// Every flag bit defined for a HEADERS frame.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
112
113// ===== impl Headers =====
114
115impl Headers {
116    /// Create a new HEADERS frame
117    pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
118        Headers {
119            stream_id,
120            stream_dep: None,
121            header_block: HeaderBlock {
122                field_size: calculate_headermap_size(&fields),
123                fields,
124                is_over_size: false,
125                pseudo,
126            },
127            flags: HeadersFlag::default(),
128        }
129    }
130
131    pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
132        let mut flags = HeadersFlag::default();
133        flags.set_end_stream();
134
135        Headers {
136            stream_id,
137            stream_dep: None,
138            header_block: HeaderBlock {
139                field_size: calculate_headermap_size(&fields),
140                fields,
141                is_over_size: false,
142                pseudo: Pseudo::default(),
143            },
144            flags,
145        }
146    }
147
148    /// Loads the header frame but doesn't actually do HPACK decoding.
149    ///
150    /// HPACK decoding is done in the `load_hpack` step.
151    pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
152        let flags = HeadersFlag(head.flag());
153        let mut pad = 0;
154
155        tracing::trace!("loading headers; flags={:?}", flags);
156
157        if head.stream_id().is_zero() {
158            return Err(Error::InvalidStreamId);
159        }
160
161        // Read the padding length
162        if flags.is_padded() {
163            if src.is_empty() {
164                return Err(Error::MalformedMessage);
165            }
166            pad = src[0] as usize;
167
168            // Drop the padding
169            let _ = src.split_to(1);
170        }
171
172        // Read the stream dependency
173        let stream_dep = if flags.is_priority() {
174            if src.len() < 5 {
175                return Err(Error::MalformedMessage);
176            }
177            let stream_dep = StreamDependency::load(&src[..5])?;
178
179            if stream_dep.dependency_id() == head.stream_id() {
180                return Err(Error::InvalidDependencyId);
181            }
182
183            // Drop the next 5 bytes
184            let _ = src.split_to(5);
185
186            Some(stream_dep)
187        } else {
188            None
189        };
190
191        if pad > 0 {
192            if pad > src.len() {
193                return Err(Error::TooMuchPadding);
194            }
195
196            let len = src.len() - pad;
197            src.truncate(len);
198        }
199
200        let headers = Headers {
201            stream_id: head.stream_id(),
202            stream_dep,
203            header_block: HeaderBlock {
204                fields: HeaderMap::new(),
205                field_size: 0,
206                is_over_size: false,
207                pseudo: Pseudo::default(),
208            },
209            flags,
210        };
211
212        Ok((headers, src))
213    }
214
215    pub fn load_hpack(
216        &mut self,
217        src: &mut BytesMut,
218        max_header_list_size: usize,
219        decoder: &mut hpack::Decoder,
220    ) -> Result<(), Error> {
221        self.header_block.load(src, max_header_list_size, decoder)
222    }
223
224    pub fn stream_id(&self) -> StreamId {
225        self.stream_id
226    }
227
228    pub fn is_end_headers(&self) -> bool {
229        self.flags.is_end_headers()
230    }
231
232    pub fn set_end_headers(&mut self) {
233        self.flags.set_end_headers();
234    }
235
236    pub fn is_end_stream(&self) -> bool {
237        self.flags.is_end_stream()
238    }
239
240    pub fn set_end_stream(&mut self) {
241        self.flags.set_end_stream()
242    }
243
244    pub fn is_over_size(&self) -> bool {
245        self.header_block.is_over_size
246    }
247
248    pub fn into_parts(self) -> (Pseudo, HeaderMap) {
249        (self.header_block.pseudo, self.header_block.fields)
250    }
251
252    #[cfg(feature = "unstable")]
253    pub fn pseudo_mut(&mut self) -> &mut Pseudo {
254        &mut self.header_block.pseudo
255    }
256
257    /// Whether it has status 1xx
258    pub(crate) fn is_informational(&self) -> bool {
259        self.header_block.pseudo.is_informational()
260    }
261
262    pub fn fields(&self) -> &HeaderMap {
263        &self.header_block.fields
264    }
265
266    pub fn into_fields(self) -> HeaderMap {
267        self.header_block.fields
268    }
269
270    pub fn encode(
271        self,
272        encoder: &mut hpack::Encoder,
273        dst: &mut EncodeBuf<'_>,
274    ) -> Option<Continuation> {
275        // At this point, the `is_end_headers` flag should always be set
276        debug_assert!(self.flags.is_end_headers());
277
278        // Get the HEADERS frame head
279        let head = self.head();
280
281        self.header_block
282            .into_encoding(encoder)
283            .encode(&head, dst, |_| {})
284    }
285
286    fn head(&self) -> Head {
287        Head::new(Kind::Headers, self.flags.into(), self.stream_id)
288    }
289}
290
291impl<T> From<Headers> for Frame<T> {
292    fn from(src: Headers) -> Self {
293        Frame::Headers(src)
294    }
295}
296
297impl fmt::Debug for Headers {
298    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
299        let mut builder = f.debug_struct("Headers");
300        builder
301            .field("stream_id", &self.stream_id)
302            .field("flags", &self.flags);
303
304        if let Some(ref protocol) = self.header_block.pseudo.protocol {
305            builder.field("protocol", protocol);
306        }
307
308        if let Some(ref dep) = self.stream_dep {
309            builder.field("stream_dep", dep);
310        }
311
312        // `fields` and `pseudo` purposefully not included
313        builder.finish()
314    }
315}
316
317// ===== util =====
318
319#[derive(Debug, PartialEq, Eq)]
320pub struct ParseU64Error;
321
322pub fn parse_u64(src: &[u8]) -> Result<u64, ParseU64Error> {
323    if src.len() > 19 {
324        // At danger for overflow...
325        return Err(ParseU64Error);
326    }
327
328    let mut ret = 0;
329
330    for &d in src {
331        if d < b'0' || d > b'9' {
332            return Err(ParseU64Error);
333        }
334
335        ret *= 10;
336        ret += (d - b'0') as u64;
337    }
338
339    Ok(ret)
340}
341
342// ===== impl PushPromise =====
343
344#[derive(Debug)]
345pub enum PushPromiseHeaderError {
346    InvalidContentLength(Result<u64, ParseU64Error>),
347    NotSafeAndCacheable,
348}
349
350impl PushPromise {
351    pub fn new(
352        stream_id: StreamId,
353        promised_id: StreamId,
354        pseudo: Pseudo,
355        fields: HeaderMap,
356    ) -> Self {
357        PushPromise {
358            flags: PushPromiseFlag::default(),
359            header_block: HeaderBlock {
360                field_size: calculate_headermap_size(&fields),
361                fields,
362                is_over_size: false,
363                pseudo,
364            },
365            promised_id,
366            stream_id,
367        }
368    }
369
370    pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
371        use PushPromiseHeaderError::*;
372        // The spec has some requirements for promised request headers
373        // [https://httpwg.org/specs/rfc7540.html#PushRequests]
374
375        // A promised request "that indicates the presence of a request body
376        // MUST reset the promised stream with a stream error"
377        if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
378            let parsed_length = parse_u64(content_length.as_bytes());
379            if parsed_length != Ok(0) {
380                return Err(InvalidContentLength(parsed_length));
381            }
382        }
383        // "The server MUST include a method in the :method pseudo-header field
384        // that is safe and cacheable"
385        if !Self::safe_and_cacheable(req.method()) {
386            return Err(NotSafeAndCacheable);
387        }
388
389        Ok(())
390    }
391
392    fn safe_and_cacheable(method: &Method) -> bool {
393        // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
394        // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
395        method == Method::GET || method == Method::HEAD
396    }
397
398    pub fn fields(&self) -> &HeaderMap {
399        &self.header_block.fields
400    }
401
402    #[cfg(feature = "unstable")]
403    pub fn into_fields(self) -> HeaderMap {
404        self.header_block.fields
405    }
406
407    /// Loads the push promise frame but doesn't actually do HPACK decoding.
408    ///
409    /// HPACK decoding is done in the `load_hpack` step.
410    pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
411        let flags = PushPromiseFlag(head.flag());
412        let mut pad = 0;
413
414        if head.stream_id().is_zero() {
415            return Err(Error::InvalidStreamId);
416        }
417
418        // Read the padding length
419        if flags.is_padded() {
420            if src.is_empty() {
421                return Err(Error::MalformedMessage);
422            }
423
424            // TODO: Ensure payload is sized correctly
425            pad = src[0] as usize;
426
427            // Drop the padding
428            let _ = src.split_to(1);
429        }
430
431        if src.len() < 5 {
432            return Err(Error::MalformedMessage);
433        }
434
435        let (promised_id, _) = StreamId::parse(&src[..4]);
436        // Drop promised_id bytes
437        let _ = src.split_to(4);
438
439        if pad > 0 {
440            if pad > src.len() {
441                return Err(Error::TooMuchPadding);
442            }
443
444            let len = src.len() - pad;
445            src.truncate(len);
446        }
447
448        let frame = PushPromise {
449            flags,
450            header_block: HeaderBlock {
451                fields: HeaderMap::new(),
452                field_size: 0,
453                is_over_size: false,
454                pseudo: Pseudo::default(),
455            },
456            promised_id,
457            stream_id: head.stream_id(),
458        };
459        Ok((frame, src))
460    }
461
462    pub fn load_hpack(
463        &mut self,
464        src: &mut BytesMut,
465        max_header_list_size: usize,
466        decoder: &mut hpack::Decoder,
467    ) -> Result<(), Error> {
468        self.header_block.load(src, max_header_list_size, decoder)
469    }
470
471    pub fn stream_id(&self) -> StreamId {
472        self.stream_id
473    }
474
475    pub fn promised_id(&self) -> StreamId {
476        self.promised_id
477    }
478
479    pub fn is_end_headers(&self) -> bool {
480        self.flags.is_end_headers()
481    }
482
483    pub fn set_end_headers(&mut self) {
484        self.flags.set_end_headers();
485    }
486
487    pub fn is_over_size(&self) -> bool {
488        self.header_block.is_over_size
489    }
490
491    pub fn encode(
492        self,
493        encoder: &mut hpack::Encoder,
494        dst: &mut EncodeBuf<'_>,
495    ) -> Option<Continuation> {
496        // At this point, the `is_end_headers` flag should always be set
497        debug_assert!(self.flags.is_end_headers());
498
499        let head = self.head();
500        let promised_id = self.promised_id;
501
502        self.header_block
503            .into_encoding(encoder)
504            .encode(&head, dst, |dst| {
505                dst.put_u32(promised_id.into());
506            })
507    }
508
509    fn head(&self) -> Head {
510        Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
511    }
512
513    /// Consume `self`, returning the parts of the frame
514    pub fn into_parts(self) -> (Pseudo, HeaderMap) {
515        (self.header_block.pseudo, self.header_block.fields)
516    }
517}
518
519impl<T> From<PushPromise> for Frame<T> {
520    fn from(src: PushPromise) -> Self {
521        Frame::PushPromise(src)
522    }
523}
524
525impl fmt::Debug for PushPromise {
526    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
527        f.debug_struct("PushPromise")
528            .field("stream_id", &self.stream_id)
529            .field("promised_id", &self.promised_id)
530            .field("flags", &self.flags)
531            // `fields` and `pseudo` purposefully not included
532            .finish()
533    }
534}
535
536// ===== impl Continuation =====
537
538impl Continuation {
539    fn head(&self) -> Head {
540        Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
541    }
542
543    pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
544        // Get the CONTINUATION frame head
545        let head = self.head();
546
547        self.header_block.encode(&head, dst, |_| {})
548    }
549}
550
551// ===== impl Pseudo =====
552
553impl Pseudo {
554    pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
555        let parts = uri::Parts::from(uri);
556
557        let mut path = parts
558            .path_and_query
559            .map(|v| BytesStr::from(v.as_str()))
560            .unwrap_or(BytesStr::from_static(""));
561
562        match method {
563            Method::OPTIONS | Method::CONNECT => {}
564            _ if path.is_empty() => {
565                path = BytesStr::from_static("/");
566            }
567            _ => {}
568        }
569
570        let mut pseudo = Pseudo {
571            method: Some(method),
572            scheme: None,
573            authority: None,
574            path: Some(path).filter(|p| !p.is_empty()),
575            protocol,
576            status: None,
577        };
578
579        // If the URI includes a scheme component, add it to the pseudo headers
580        //
581        // TODO: Scheme must be set...
582        if let Some(scheme) = parts.scheme {
583            pseudo.set_scheme(scheme);
584        }
585
586        // If the URI includes an authority component, add it to the pseudo
587        // headers
588        if let Some(authority) = parts.authority {
589            pseudo.set_authority(BytesStr::from(authority.as_str()));
590        }
591
592        pseudo
593    }
594
595    pub fn response(status: StatusCode) -> Self {
596        Pseudo {
597            method: None,
598            scheme: None,
599            authority: None,
600            path: None,
601            protocol: None,
602            status: Some(status),
603        }
604    }
605
606    #[cfg(feature = "unstable")]
607    pub fn set_status(&mut self, value: StatusCode) {
608        self.status = Some(value);
609    }
610
611    pub fn set_scheme(&mut self, scheme: uri::Scheme) {
612        let bytes_str = match scheme.as_str() {
613            "http" => BytesStr::from_static("http"),
614            "https" => BytesStr::from_static("https"),
615            s => BytesStr::from(s),
616        };
617        self.scheme = Some(bytes_str);
618    }
619
620    #[cfg(feature = "unstable")]
621    pub fn set_protocol(&mut self, protocol: Protocol) {
622        self.protocol = Some(protocol);
623    }
624
625    pub fn set_authority(&mut self, authority: BytesStr) {
626        self.authority = Some(authority);
627    }
628
629    /// Whether it has status 1xx
630    pub(crate) fn is_informational(&self) -> bool {
631        self.status
632            .map_or(false, |status| status.is_informational())
633    }
634}
635
636// ===== impl EncodingHeaderBlock =====
637
638impl EncodingHeaderBlock {
639    fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
640    where
641        F: FnOnce(&mut EncodeBuf<'_>),
642    {
643        let head_pos = dst.get_ref().len();
644
645        // At this point, we don't know how big the h2 frame will be.
646        // So, we write the head with length 0, then write the body, and
647        // finally write the length once we know the size.
648        head.encode(0, dst);
649
650        let payload_pos = dst.get_ref().len();
651
652        f(dst);
653
654        // Now, encode the header payload
655        let continuation = if self.hpack.len() > dst.remaining_mut() {
656            dst.put_slice(&self.hpack.split_to(dst.remaining_mut()));
657
658            Some(Continuation {
659                stream_id: head.stream_id(),
660                header_block: self,
661            })
662        } else {
663            dst.put_slice(&self.hpack);
664
665            None
666        };
667
668        // Compute the header block length
669        let payload_len = (dst.get_ref().len() - payload_pos) as u64;
670
671        // Write the frame length
672        let payload_len_be = payload_len.to_be_bytes();
673        assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
674        (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);
675
676        if continuation.is_some() {
677            // There will be continuation frames, so the `is_end_headers` flag
678            // must be unset
679            debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);
680
681            dst.get_mut()[head_pos + 4] -= END_HEADERS;
682        }
683
684        continuation
685    }
686}
687
688// ===== impl Iter =====
689
690impl Iterator for Iter {
691    type Item = hpack::Header<Option<HeaderName>>;
692
693    fn next(&mut self) -> Option<Self::Item> {
694        use crate::hpack::Header::*;
695
696        if let Some(ref mut pseudo) = self.pseudo {
697            if let Some(method) = pseudo.method.take() {
698                return Some(Method(method));
699            }
700
701            if let Some(scheme) = pseudo.scheme.take() {
702                return Some(Scheme(scheme));
703            }
704
705            if let Some(authority) = pseudo.authority.take() {
706                return Some(Authority(authority));
707            }
708
709            if let Some(path) = pseudo.path.take() {
710                return Some(Path(path));
711            }
712
713            if let Some(protocol) = pseudo.protocol.take() {
714                return Some(Protocol(protocol));
715            }
716
717            if let Some(status) = pseudo.status.take() {
718                return Some(Status(status));
719            }
720        }
721
722        self.pseudo = None;
723
724        self.fields
725            .next()
726            .map(|(name, value)| Field { name, value })
727    }
728}
729
730// ===== impl HeadersFlag =====
731
732impl HeadersFlag {
733    pub fn empty() -> HeadersFlag {
734        HeadersFlag(0)
735    }
736
737    pub fn load(bits: u8) -> HeadersFlag {
738        HeadersFlag(bits & ALL)
739    }
740
741    pub fn is_end_stream(&self) -> bool {
742        self.0 & END_STREAM == END_STREAM
743    }
744
745    pub fn set_end_stream(&mut self) {
746        self.0 |= END_STREAM;
747    }
748
749    pub fn is_end_headers(&self) -> bool {
750        self.0 & END_HEADERS == END_HEADERS
751    }
752
753    pub fn set_end_headers(&mut self) {
754        self.0 |= END_HEADERS;
755    }
756
757    pub fn is_padded(&self) -> bool {
758        self.0 & PADDED == PADDED
759    }
760
761    pub fn is_priority(&self) -> bool {
762        self.0 & PRIORITY == PRIORITY
763    }
764}
765
766impl Default for HeadersFlag {
767    /// Returns a `HeadersFlag` value with `END_HEADERS` set.
768    fn default() -> Self {
769        HeadersFlag(END_HEADERS)
770    }
771}
772
773impl From<HeadersFlag> for u8 {
774    fn from(src: HeadersFlag) -> u8 {
775        src.0
776    }
777}
778
779impl fmt::Debug for HeadersFlag {
780    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
781        util::debug_flags(fmt, self.0)
782            .flag_if(self.is_end_headers(), "END_HEADERS")
783            .flag_if(self.is_end_stream(), "END_STREAM")
784            .flag_if(self.is_padded(), "PADDED")
785            .flag_if(self.is_priority(), "PRIORITY")
786            .finish()
787    }
788}
789
790// ===== impl PushPromiseFlag =====
791
792impl PushPromiseFlag {
793    pub fn empty() -> PushPromiseFlag {
794        PushPromiseFlag(0)
795    }
796
797    pub fn load(bits: u8) -> PushPromiseFlag {
798        PushPromiseFlag(bits & ALL)
799    }
800
801    pub fn is_end_headers(&self) -> bool {
802        self.0 & END_HEADERS == END_HEADERS
803    }
804
805    pub fn set_end_headers(&mut self) {
806        self.0 |= END_HEADERS;
807    }
808
809    pub fn is_padded(&self) -> bool {
810        self.0 & PADDED == PADDED
811    }
812}
813
814impl Default for PushPromiseFlag {
815    /// Returns a `PushPromiseFlag` value with `END_HEADERS` set.
816    fn default() -> Self {
817        PushPromiseFlag(END_HEADERS)
818    }
819}
820
821impl From<PushPromiseFlag> for u8 {
822    fn from(src: PushPromiseFlag) -> u8 {
823        src.0
824    }
825}
826
827impl fmt::Debug for PushPromiseFlag {
828    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
829        util::debug_flags(fmt, self.0)
830            .flag_if(self.is_end_headers(), "END_HEADERS")
831            .flag_if(self.is_padded(), "PADDED")
832            .finish()
833    }
834}
835
836// ===== HeaderBlock =====
837
838impl HeaderBlock {
839    fn load(
840        &mut self,
841        src: &mut BytesMut,
842        max_header_list_size: usize,
843        decoder: &mut hpack::Decoder,
844    ) -> Result<(), Error> {
845        let mut reg = !self.fields.is_empty();
846        let mut malformed = false;
847        let mut headers_size = self.calculate_header_list_size();
848
849        macro_rules! set_pseudo {
850            ($field:ident, $val:expr) => {{
851                if reg {
852                    tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
853                    malformed = true;
854                } else if self.pseudo.$field.is_some() {
855                    tracing::trace!("load_hpack; header malformed -- repeated pseudo");
856                    malformed = true;
857                } else {
858                    let __val = $val;
859                    headers_size +=
860                        decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
861                    if headers_size < max_header_list_size {
862                        self.pseudo.$field = Some(__val);
863                    } else if !self.is_over_size {
864                        tracing::trace!("load_hpack; header list size over max");
865                        self.is_over_size = true;
866                    }
867                }
868            }};
869        }
870
871        let mut cursor = Cursor::new(src);
872
873        // If the header frame is malformed, we still have to continue decoding
874        // the headers. A malformed header frame is a stream level error, but
875        // the hpack state is connection level. In order to maintain correct
876        // state for other streams, the hpack decoding process must complete.
877        let res = decoder.decode(&mut cursor, |header| {
878            use crate::hpack::Header::*;
879
880            match header {
881                Field { name, value } => {
882                    // Connection level header fields are not supported and must
883                    // result in a protocol error.
884
885                    if name == header::CONNECTION
886                        || name == header::TRANSFER_ENCODING
887                        || name == header::UPGRADE
888                        || name == "keep-alive"
889                        || name == "proxy-connection"
890                    {
891                        tracing::trace!("load_hpack; connection level header");
892                        malformed = true;
893                    } else if name == header::TE && value != "trailers" {
894                        tracing::trace!(
895                            "load_hpack; TE header not set to trailers; val={:?}",
896                            value
897                        );
898                        malformed = true;
899                    } else {
900                        reg = true;
901
902                        headers_size += decoded_header_size(name.as_str().len(), value.len());
903                        if headers_size < max_header_list_size {
904                            self.field_size +=
905                                decoded_header_size(name.as_str().len(), value.len());
906                            self.fields.append(name, value);
907                        } else if !self.is_over_size {
908                            tracing::trace!("load_hpack; header list size over max");
909                            self.is_over_size = true;
910                        }
911                    }
912                }
913                Authority(v) => set_pseudo!(authority, v),
914                Method(v) => set_pseudo!(method, v),
915                Scheme(v) => set_pseudo!(scheme, v),
916                Path(v) => set_pseudo!(path, v),
917                Protocol(v) => set_pseudo!(protocol, v),
918                Status(v) => set_pseudo!(status, v),
919            }
920        });
921
922        if let Err(e) = res {
923            tracing::trace!("hpack decoding error; err={:?}", e);
924            return Err(e.into());
925        }
926
927        if malformed {
928            tracing::trace!("malformed message");
929            return Err(Error::MalformedMessage);
930        }
931
932        Ok(())
933    }
934
935    fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock {
936        let mut hpack = BytesMut::new();
937        let headers = Iter {
938            pseudo: Some(self.pseudo),
939            fields: self.fields.into_iter(),
940        };
941
942        encoder.encode(headers, &mut hpack);
943
944        EncodingHeaderBlock {
945            hpack: hpack.freeze(),
946        }
947    }
948
949    /// Calculates the size of the currently decoded header list.
950    ///
951    /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
952    ///
953    /// > The value is based on the uncompressed size of header fields,
954    /// > including the length of the name and value in octets plus an
955    /// > overhead of 32 octets for each header field.
956    fn calculate_header_list_size(&self) -> usize {
957        macro_rules! pseudo_size {
958            ($name:ident) => {{
959                self.pseudo
960                    .$name
961                    .as_ref()
962                    .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
963                    .unwrap_or(0)
964            }};
965        }
966
967        pseudo_size!(method)
968            + pseudo_size!(scheme)
969            + pseudo_size!(status)
970            + pseudo_size!(authority)
971            + pseudo_size!(path)
972            + self.field_size
973    }
974}
975
976fn calculate_headermap_size(map: &HeaderMap) -> usize {
977    map.iter()
978        .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
979        .sum::<usize>()
980}
981
/// Size of one header field as counted against the maximum header list size.
///
/// This is the name length plus the value length, plus 32 octets of
/// per-field overhead.
fn decoded_header_size(name: usize, value: usize) -> usize {
    const PER_FIELD_OVERHEAD: usize = 32;
    name + value + PER_FIELD_OVERHEAD
}
985
#[cfg(test)]
mod test {
    use std::iter::FromIterator;

    use super::*;
    use crate::frame;
    use crate::hpack::{huffman, Encoder};

    /// Encodes three `hello` fields into a buffer with room for only 8
    /// payload bytes. The HEADERS frame therefore ends partway through the
    /// first field's value, and the returned CONTINUATION frame must finish
    /// that value and carry the remaining fields.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let headers = Headers::new(
            StreamId::ZERO,
            Default::default(),
            HeaderMap::from_iter(vec![
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("world"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("zomg"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("sup"),
                ),
            ]),
        );

        let continuation = headers
            .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
            .unwrap();

        // HEADERS head: length 8, type 1 (HEADERS), flags 0 (END_HEADERS was
        // cleared because a continuation follows), stream 0.
        assert_eq!(17, dst.len());
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        // Literal with incremental indexing and a new name; the name is a
        // 4-byte Huffman string.
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        // Value length prefix (4-byte Huffman string); only its first byte fit.
        assert_eq!(0x80 | 4, dst[15]);

        let mut world = dst[16..17].to_owned();

        dst.clear();

        assert!(continuation
            .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
            .is_none());

        // The rest of the "world" value opens the CONTINUATION payload.
        world.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world));

        // CONTINUATION head: length 15, type 9 (CONTINUATION), flags 4
        // (END_HEADERS), stream 0.
        assert_eq!(24, dst.len());
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // Next is not indexed: a literal whose name is referenced by table
        // index (0x0f 0x2f encodes index 15 + 47 = 62) instead of repeating
        // "hello".
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }

    /// Huffman-decodes `src` into a fresh buffer, panicking on invalid input.
    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }
}