diff --git a/adapter.go b/adapter.go index c311d07..b1c2e46 100644 --- a/adapter.go +++ b/adapter.go @@ -90,11 +90,7 @@ func Adapter(opts ...Option) (func(http.Handler) http.Handler, error) { writerPool.Put(gw) }() - if _, ok := w.(http.CloseNotifier); ok { - w = compressWriterWithCloseNotify{gw} - } else { - w = gw - } + w = extend(gw) h.ServeHTTP(w, r) }) diff --git a/optional.go b/optional.go new file mode 100644 index 0000000..39fb94c --- /dev/null +++ b/optional.go @@ -0,0 +1,133 @@ +package httpcompression + +import ( + "bufio" + "net" + "net/http" +) + +// extend returns a http.ResponseWriter that wraps the compressWriter and that +// dynamically exposes some optional interfaces of http.ResponseWriter. +// Currently the supported optional interfaces are http.Hijacker, http.Pusher, +// and http.CloseNotifier. +// This is obviously an horrible way of doing things, but it's really unavoidable +// without proper language support for interface extension; see +// https://blog.merovius.de/2017/07/30/the-trouble-with-optional-interfaces.html +// for details. +func extend(cw *compressWriter) http.ResponseWriter { + switch cw.ResponseWriter.(type) { + case interface { + http.Hijacker + http.Pusher + http.CloseNotifier + }: + return cwHijackPushCloseNotifier{cw} + case interface { + http.Pusher + http.CloseNotifier + }: + return cwPushCloseNotifier{cw} + case interface { + http.Hijacker + http.Pusher + }: + return cwHijackPusher{cw} + case interface { + http.Hijacker + http.CloseNotifier + }: + return cwHijackCloseNotifier{cw} + case http.CloseNotifier: + return cwCloseNotifier{cw} + case http.Hijacker: + return cwHijacker{cw} + case http.Pusher: + return cwPusher{cw} + default: + return cw + } +} + +type cwHijacker struct{ *compressWriter } + +func (cw cwHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return cw.ResponseWriter.(http.Hijacker).Hijack() +} + +var _ http.Hijacker = cwHijacker{} + +type cwCloseNotifier struct{ *compressWriter } + +func (cw cwCloseNotifier) CloseNotify() <-chan bool { + return cw.ResponseWriter.(http.CloseNotifier).CloseNotify() +} + +var _ http.CloseNotifier = cwCloseNotifier{} + +type cwPusher struct{ *compressWriter } + +func (cw cwPusher) Push(target string, opts *http.PushOptions) error { + return cw.ResponseWriter.(http.Pusher).Push(target, opts) +} + +var _ http.Pusher = cwPusher{} + +type cwHijackCloseNotifier struct{ *compressWriter } + +func (cw cwHijackCloseNotifier) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return cwHijacker{cw.compressWriter}.Hijack() +} +func (cw cwHijackCloseNotifier) CloseNotify() <-chan bool { + return cwCloseNotifier{cw.compressWriter}.CloseNotify() +} + +var ( + _ http.Hijacker = cwHijackCloseNotifier{} + _ http.CloseNotifier = cwHijackCloseNotifier{} +) + +type cwHijackPusher struct{ *compressWriter } + +func (cw cwHijackPusher) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return cwHijacker{cw.compressWriter}.Hijack() +} +func (cw cwHijackPusher) Push(target string, opts *http.PushOptions) error { + return cwPusher{cw.compressWriter}.Push(target, opts) +} + +var ( + _ http.Hijacker = cwHijackPusher{} + _ http.Pusher = cwHijackPusher{} +) + +type cwPushCloseNotifier struct{ *compressWriter } + +func (cw cwPushCloseNotifier) Push(target string, opts *http.PushOptions) error { + return cwPusher{cw.compressWriter}.Push(target, opts) +} +func (cw cwPushCloseNotifier) CloseNotify() <-chan bool { + return cwCloseNotifier{cw.compressWriter}.CloseNotify() +} + +var ( + _ http.Pusher = cwPushCloseNotifier{} + _ http.CloseNotifier = cwPushCloseNotifier{} +) + +type cwHijackPushCloseNotifier struct{ *compressWriter } + +func (cw cwHijackPushCloseNotifier) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return cwHijacker{cw.compressWriter}.Hijack() +} +func (cw cwHijackPushCloseNotifier) Push(target string, opts *http.PushOptions) error { + return cwPusher{cw.compressWriter}.Push(target, opts) +} +func (cw cwHijackPushCloseNotifier) CloseNotify() <-chan bool { + return cwCloseNotifier{cw.compressWriter}.CloseNotify() +} + +var ( + _ http.Hijacker = cwHijackPushCloseNotifier{} + _ http.Pusher = cwHijackPushCloseNotifier{} + _ http.CloseNotifier = cwHijackPushCloseNotifier{} +) diff --git a/response_writer.go b/response_writer.go index 253bc29..3234d1d 100644 --- a/response_writer.go +++ b/response_writer.go @@ -1,10 +1,8 @@ package httpcompression import ( - "bufio" "fmt" "io" - "net" "net/http" "strconv" "sync" @@ -31,21 +29,6 @@ type compressWriter struct { var ( _ io.WriteCloser = &compressWriter{} _ http.Flusher = &compressWriter{} - _ http.Hijacker = &compressWriter{} -) - -type compressWriterWithCloseNotify struct { - *compressWriter -} - -func (w compressWriterWithCloseNotify) CloseNotify() <-chan bool { - return w.ResponseWriter.(http.CloseNotifier).CloseNotify() -} - -var ( - _ io.WriteCloser = compressWriterWithCloseNotify{} - _ http.Flusher = compressWriterWithCloseNotify{} - _ http.Hijacker = compressWriterWithCloseNotify{} ) const maxBuf = 1 << 16 // maximum size of recycled buffer @@ -213,8 +196,13 @@ func (w *compressWriter) Flush() { return } - // Flush the compressor, if supported, - // note: http.ResponseWriter does not implement Flusher, so we need to call ResponseWriter.Flush anyway. + // Flush the compressor, if supported. + // note: http.ResponseWriter does not implement Flusher (http.Flusher does not return an error), + // so we need to later call ResponseWriter.Flush anyway: + // - in case we are bypassing compression, w.w is the parent ResponseWriter, and therefore we skip + // this as the parent ResponseWriter does not implement Flusher. + // - in case we are NOT bypassing compression, w.w is the compressor, and therefore we flush the + // compressor and then we flush the parent ResponseWriter. if fw, ok := w.w.(Flusher); ok { _ = fw.Flush() } @@ -225,15 +213,6 @@ func (w *compressWriter) Flush() { } } -// Hijack implements http.Hijacker. If the underlying ResponseWriter is a -// Hijacker, its Hijack method is returned. Otherwise an error is returned. -func (w *compressWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - if hj, ok := w.ResponseWriter.(http.Hijacker); ok { - return hj.Hijack() - } - return nil, nil, fmt.Errorf("http.Hijacker interface is not supported") -} - func (w *compressWriter) recycleBuffer() { if cap(w.buf) > 0 && cap(w.buf) <= maxBuf { w.pool.Put(w.buf[:0])