summaryrefslogtreecommitdiff
path: root/modules/caddyhttp/reverseproxy/streaming.go
diff options
context:
space:
mode:
authorMatthew Holt <mholt@users.noreply.github.com>2019-09-03 16:56:09 -0600
committerMatthew Holt <mholt@users.noreply.github.com>2019-09-03 16:56:09 -0600
commit652460e03e11a037d9f86b09b3546c9e42733d2d (patch)
tree9b1042f3fc880f87ca25949618fea45e6960080a /modules/caddyhttp/reverseproxy/streaming.go
parent4a1e1649bc985e9658d326ed433a101d7d79ae30 (diff)
Some cleanup and godoc
Diffstat (limited to 'modules/caddyhttp/reverseproxy/streaming.go')
-rw-r--r--modules/caddyhttp/reverseproxy/streaming.go223
1 files changed, 223 insertions, 0 deletions
diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go
new file mode 100644
index 0000000..a3b711b
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/streaming.go
@@ -0,0 +1,223 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Most of the code in this file was initially borrowed from the Go
+// standard library, which has this copyright notice:
+// Copyright 2011 The Go Authors.
+
+package reverseproxy
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "sync"
+ "time"
+)
+
+func (h Handler) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
+ reqUpType := upgradeType(req.Header)
+ resUpType := upgradeType(res.Header)
+ if reqUpType != resUpType {
+ // TODO: figure out our own error handling
+ // p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
+ return
+ }
+
+ copyHeader(res.Header, rw.Header())
+
+ hj, ok := rw.(http.Hijacker)
+ if !ok {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
+ return
+ }
+ backConn, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
+ return
+ }
+ defer backConn.Close()
+ conn, brw, err := hj.Hijack()
+ if err != nil {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
+ return
+ }
+ defer conn.Close()
+ res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
+ if err := res.Write(brw); err != nil {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
+ return
+ }
+ if err := brw.Flush(); err != nil {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
+ return
+ }
+ errc := make(chan error, 1)
+ spc := switchProtocolCopier{user: conn, backend: backConn}
+ go spc.copyToBackend(errc)
+ go spc.copyFromBackend(errc)
+ <-errc
+ return
+}
+
+// flushInterval returns the p.FlushInterval value, conditionally
+// overriding its value for a specific request/response.
+func (h Handler) flushInterval(req *http.Request, res *http.Response) time.Duration {
+ resCT := res.Header.Get("Content-Type")
+
+ // For Server-Sent Events responses, flush immediately.
+ // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream
+ if resCT == "text/event-stream" {
+ return -1 // negative means immediately
+ }
+
+ // TODO: more specific cases? e.g. res.ContentLength == -1? (this TODO is from the std lib)
+ return time.Duration(h.FlushInterval)
+}
+
+func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
+ if flushInterval != 0 {
+ if wf, ok := dst.(writeFlusher); ok {
+ mlw := &maxLatencyWriter{
+ dst: wf,
+ latency: flushInterval,
+ }
+ defer mlw.stop()
+ dst = mlw
+ }
+ }
+
+ // TODO: Figure out how we want to do this... using custom buffer pool type seems unnecessary
+ // or maybe it is, depending on how we want to handle errors,
+ // see: https://github.com/golang/go/issues/21814
+ // buf := bufPool.Get().(*bytes.Buffer)
+ // buf.Reset()
+ // defer bufPool.Put(buf)
+ // _, err := io.CopyBuffer(dst, src, )
+ var buf []byte
+ // if h.BufferPool != nil {
+ // buf = h.BufferPool.Get()
+ // defer h.BufferPool.Put(buf)
+ // }
+ _, err := h.copyBuffer(dst, src, buf)
+ return err
+}
+
+// copyBuffer returns any write errors or non-EOF read errors, and the amount
+// of bytes written.
+func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
+ if len(buf) == 0 {
+ buf = make([]byte, 32*1024)
+ }
+ var written int64
+ for {
+ nr, rerr := src.Read(buf)
+ if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
+ // TODO: this could be useful to know (indeed, it revealed an error in our
+ // fastcgi PoC earlier; but it's this single error report here that necessitates
+ // a function separate from io.CopyBuffer, since io.CopyBuffer does not distinguish
+ // between read or write errors; in a reverse proxy situation, write errors are not
+ // something we need to report to the client, but read errors are a problem on our
+ // end for sure. so we need to decide what we want.)
+ // p.logf("copyBuffer: ReverseProxy read error during body copy: %v", rerr)
+ }
+ if nr > 0 {
+ nw, werr := dst.Write(buf[:nr])
+ if nw > 0 {
+ written += int64(nw)
+ }
+ if werr != nil {
+ return written, werr
+ }
+ if nr != nw {
+ return written, io.ErrShortWrite
+ }
+ }
+ if rerr != nil {
+ if rerr == io.EOF {
+ rerr = nil
+ }
+ return written, rerr
+ }
+ }
+}
+
+type writeFlusher interface {
+ io.Writer
+ http.Flusher
+}
+
+type maxLatencyWriter struct {
+ dst writeFlusher
+ latency time.Duration // non-zero; negative means to flush immediately
+
+ mu sync.Mutex // protects t, flushPending, and dst.Flush
+ t *time.Timer
+ flushPending bool
+}
+
+func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ n, err = m.dst.Write(p)
+ if m.latency < 0 {
+ m.dst.Flush()
+ return
+ }
+ if m.flushPending {
+ return
+ }
+ if m.t == nil {
+ m.t = time.AfterFunc(m.latency, m.delayedFlush)
+ } else {
+ m.t.Reset(m.latency)
+ }
+ m.flushPending = true
+ return
+}
+
+func (m *maxLatencyWriter) delayedFlush() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
+ return
+ }
+ m.dst.Flush()
+ m.flushPending = false
+}
+
+func (m *maxLatencyWriter) stop() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.flushPending = false
+ if m.t != nil {
+ m.t.Stop()
+ }
+}
+
+// switchProtocolCopier exists so goroutines proxying data back and
+// forth have nice names in stacks.
+type switchProtocolCopier struct {
+ user, backend io.ReadWriter
+}
+
+func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
+ _, err := io.Copy(c.user, c.backend)
+ errc <- err
+}
+
+func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
+ _, err := io.Copy(c.backend, c.user)
+ errc <- err
+}