diff --git a/src/main/java/io/netty/incubator/codec/http3/Http3Headers.java b/src/main/java/io/netty/incubator/codec/http3/Http3Headers.java index 0a5711d..4546971 100644 --- a/src/main/java/io/netty/incubator/codec/http3/Http3Headers.java +++ b/src/main/java/io/netty/incubator/codec/http3/Http3Headers.java @@ -31,33 +31,35 @@ enum PseudoHeaderName { /** * {@code :method}. */ - METHOD(":method", true), + METHOD(":method", true, 0x1), /** * {@code :scheme}. */ - SCHEME(":scheme", true), + SCHEME(":scheme", true, 0x2), /** * {@code :authority}. */ - AUTHORITY(":authority", true), + AUTHORITY(":authority", true, 0x4), /** * {@code :path}. */ - PATH(":path", true), + PATH(":path", true, 0x8), /** * {@code :status}. */ - STATUS(":status", false); + STATUS(":status", false, 0x10); private static final char PSEUDO_HEADER_PREFIX = ':'; private static final byte PSEUDO_HEADER_PREFIX_BYTE = (byte) PSEUDO_HEADER_PREFIX; private final AsciiString value; private final boolean requestOnly; + // The position of the bit in the flag indicates the type of the header field + private final int flag; private static final CharSequenceMap PSEUDO_HEADERS = new CharSequenceMap(); static { @@ -66,9 +68,10 @@ enum PseudoHeaderName { } } - PseudoHeaderName(String value, boolean requestOnly) { + PseudoHeaderName(String value, boolean requestOnly, int flag) { this.value = AsciiString.cached(value); this.requestOnly = requestOnly; + this.flag = flag; } public AsciiString value() { @@ -120,6 +123,10 @@ public static PseudoHeaderName getPseudoHeader(CharSequence name) { public boolean isRequestOnly() { return requestOnly; } + + public int getFlag() { + return flag; + } } /** diff --git a/src/main/java/io/netty/incubator/codec/http3/Http3HeadersSink.java b/src/main/java/io/netty/incubator/codec/http3/Http3HeadersSink.java index b875373..1446e1b 100644 --- a/src/main/java/io/netty/incubator/codec/http3/Http3HeadersSink.java +++ b/src/main/java/io/netty/incubator/codec/http3/Http3HeadersSink.java @@ -15,10 +15,16 @@ */ package io.netty.incubator.codec.http3; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpMethod; import java.util.function.BiConsumer; +import static io.netty.incubator.codec.http3.Http3Headers.PseudoHeaderName.AUTHORITY; +import static io.netty.incubator.codec.http3.Http3Headers.PseudoHeaderName.METHOD; +import static io.netty.incubator.codec.http3.Http3Headers.PseudoHeaderName.PATH; +import static io.netty.incubator.codec.http3.Http3Headers.PseudoHeaderName.SCHEME; +import static io.netty.incubator.codec.http3.Http3Headers.PseudoHeaderName.STATUS; import static io.netty.incubator.codec.http3.Http3Headers.PseudoHeaderName.getPseudoHeader; import static io.netty.incubator.codec.http3.Http3Headers.PseudoHeaderName.hasPseudoHeaderFormat; @@ -36,7 +42,7 @@ final class Http3HeadersSink implements BiConsumer { private Http3HeadersValidationException validationException; private HeaderType previousType; private boolean request; - private int pseudoHeadersCount; + private int receivedPseudoHeaders; Http3HeadersSink(Http3Headers headers, long maxHeaderListSize, boolean validate, boolean trailer) { this.headers = headers; @@ -58,7 +64,7 @@ void finish() throws Http3HeadersValidationException, Http3Exception { } if (validate) { if (trailer) { - if (pseudoHeadersCount != 0) { + if (receivedPseudoHeaders != 0) { // Trailers must not have pseudo headers. throw new Http3HeadersValidationException("Pseudo-header(s) included in trailers."); } @@ -69,16 +75,12 @@ void finish() throws Http3HeadersValidationException, Http3Exception { if (request) { CharSequence method = headers.method(); // fast-path - if (pseudoHeadersCount < 2) { - // There can't be any duplicates for pseudy header names. - throw new Http3HeadersValidationException("Not all mandatory pseudo-headers included."); - } if (HttpMethod.CONNECT.asciiName().contentEqualsIgnoreCase(method)) { // For CONNECT we must only include: // - :method // - :authority - if (pseudoHeadersCount != 2 || headers.authority() == null) { - // There can't be any duplicates for pseudy header names. + final int requiredPseudoHeaders = METHOD.getFlag() | AUTHORITY.getFlag(); + if (receivedPseudoHeaders != requiredPseudoHeaders) { throw new Http3HeadersValidationException("Not all mandatory pseudo-headers included."); } } else if (HttpMethod.OPTIONS.asciiName().contentEqualsIgnoreCase(method)) { @@ -90,36 +92,43 @@ void finish() throws Http3HeadersValidationException, Http3Exception { // - :scheme // - :authority // - :path - if (pseudoHeadersCount != 4 && - // - :method - // - :scheme - // - :path - !(pseudoHeadersCount == 3 && headers.authority() == null && - "*".contentEquals(headers.path()))) { + final int requiredPseudoHeaders = METHOD.getFlag() | SCHEME.getFlag() | PATH.getFlag(); + if ((receivedPseudoHeaders & requiredPseudoHeaders) != requiredPseudoHeaders || + (!authorityOrHostHeaderReceived() && !"*".contentEquals(headers.path()))) { throw new Http3HeadersValidationException("Not all mandatory pseudo-headers included."); } } else { - // For requests we must include: + // For other requests we must include: // - :method // - :scheme // - :authority // - :path - if (pseudoHeadersCount != 4) { - // There can't be any duplicates for pseudy header names. + final int requiredPseudoHeaders = METHOD.getFlag() | SCHEME.getFlag() | PATH.getFlag(); + if ((receivedPseudoHeaders & requiredPseudoHeaders) != requiredPseudoHeaders || + !authorityOrHostHeaderReceived()) { throw new Http3HeadersValidationException("Not all mandatory pseudo-headers included."); } } } else { // For responses we must include: // - :status - if (pseudoHeadersCount != 1) { - // There can't be any duplicates for pseudy header names. + if (receivedPseudoHeaders != STATUS.getFlag()) { throw new Http3HeadersValidationException("Not all mandatory pseudo-headers included."); } } } } + /** + * Find host header field in case the :authority pseudo header is not specified. + * See: + * https://www.rfc-editor.org/rfc/rfc9110#section-7.2 + */ + private boolean authorityOrHostHeaderReceived() { + return (receivedPseudoHeaders & AUTHORITY.getFlag()) == AUTHORITY.getFlag() || + headers.contains(HttpHeaderNames.HOST); + } + @Override public void accept(CharSequence name, CharSequence value) { headersLength += QpackHeaderField.sizeOf(name, value); @@ -154,19 +163,15 @@ private void validate(Http3Headers headers, CharSequence name) { throw new Http3HeadersValidationException( String.format("Invalid HTTP/3 pseudo-header '%s' encountered.", name)); } - - final HeaderType currentHeaderType = pseudoHeader.isRequestOnly() ? - HeaderType.REQUEST_PSEUDO_HEADER : HeaderType.RESPONSE_PSEUDO_HEADER; - if (previousType != null && currentHeaderType != previousType) { - throw new Http3HeadersValidationException("Mix of request and response pseudo-headers."); - } - - if (headers.contains(name)) { + if ((receivedPseudoHeaders & pseudoHeader.getFlag()) != 0) { // There can't be any duplicates for pseudy header names. throw new Http3HeadersValidationException( String.format("Pseudo-header field '%s' exists already.", name)); } - pseudoHeadersCount++; + receivedPseudoHeaders |= pseudoHeader.getFlag(); + + final HeaderType currentHeaderType = pseudoHeader.isRequestOnly() ? + HeaderType.REQUEST_PSEUDO_HEADER : HeaderType.RESPONSE_PSEUDO_HEADER; request = pseudoHeader.isRequestOnly(); previousType = currentHeaderType; } else { diff --git a/src/test/java/io/netty/incubator/codec/http3/Http3HeadersSinkTest.java b/src/test/java/io/netty/incubator/codec/http3/Http3HeadersSinkTest.java index 80969cf..56753bd 100644 --- a/src/test/java/io/netty/incubator/codec/http3/Http3HeadersSinkTest.java +++ b/src/test/java/io/netty/incubator/codec/http3/Http3HeadersSinkTest.java @@ -16,6 +16,8 @@ package io.netty.incubator.codec.http3; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.util.AsciiString; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -143,7 +145,27 @@ public void testAuthorityNotRequiredForOptionsWildcard() throws Http3Exception { } @Test - public void testAuthorityRequiredForOptionsNonWildcard() throws Http3Exception { + public void testOptionsNonWildcardWithAuthority() throws Http3Exception { + Http3HeadersSink sink = new Http3HeadersSink(new DefaultHttp3Headers(), 512, true, false); + sink.accept(Http3Headers.PseudoHeaderName.METHOD.value(), "OPTIONS"); + sink.accept(Http3Headers.PseudoHeaderName.PATH.value(), "/something"); + sink.accept(Http3Headers.PseudoHeaderName.SCHEME.value(), "https"); + sink.accept(Http3Headers.PseudoHeaderName.AUTHORITY.value(), "example.com:4433"); + sink.finish(); + } + + @Test + public void testOptionsNonWildcardWithHost() throws Http3Exception { + Http3HeadersSink sink = new Http3HeadersSink(new DefaultHttp3Headers(), 512, true, false); + sink.accept(Http3Headers.PseudoHeaderName.METHOD.value(), "OPTIONS"); + sink.accept(Http3Headers.PseudoHeaderName.PATH.value(), "/something"); + sink.accept(Http3Headers.PseudoHeaderName.SCHEME.value(), "https"); + sink.accept(new AsciiString(HttpHeaderNames.HOST), "example.com:4433"); + sink.finish(); + } + + @Test + public void testAuthorityOrHostRequiredForOptionsNonWildcard() throws Http3Exception { Http3HeadersSink sink = new Http3HeadersSink(new DefaultHttp3Headers(), 512, true, false); sink.accept(Http3Headers.PseudoHeaderName.METHOD.value(), "OPTIONS"); sink.accept(Http3Headers.PseudoHeaderName.PATH.value(), "/something"); @@ -151,6 +173,16 @@ public void testAuthorityRequiredForOptionsNonWildcard() throws Http3Exception { assertThrows(Http3HeadersValidationException.class, () -> sink.finish()); } + @Test + public void testHostExistsInsteadOfAuthority() throws Http3Exception { + Http3HeadersSink sink = new Http3HeadersSink(new DefaultHttp3Headers(), 512, true, false); + sink.accept(Http3Headers.PseudoHeaderName.METHOD.value(), "GET"); + sink.accept(Http3Headers.PseudoHeaderName.PATH.value(), "/"); + sink.accept(Http3Headers.PseudoHeaderName.SCHEME.value(), "https"); + sink.accept(new AsciiString(HttpHeaderNames.HOST), "example.com:4433"); + sink.finish(); + } + private static void addMandatoryPseudoHeaders(Http3HeadersSink sink, boolean req) { if (req) { sink.accept(Http3Headers.PseudoHeaderName.METHOD.value(), "GET");