diff options
author | Kevin Lin <masknu@users.noreply.github.com> | 2020-07-21 02:14:46 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-07-20 12:14:46 -0600 |
commit | e9b1d7dcb4cbf85da7fb4cf8c411a4f840a98cf1 (patch) | |
tree | c21f3f66b5e1c2a8cf60172f07004a4ee6c261b5 | |
parent | bd9d796e6ed64c713002e3503a8b0012bd4f1460 (diff) |
reverse_proxy: flush HTTP/2 response when ContentLength is unknown (#3561)
* reverse proxy: Support more h2 stream scenarios (#3556)
* reverse proxy: add integration test for better h2 stream (#3556)
* reverse proxy: adjust comments as francislavoie suggests
* link to issue #3556 in the comments
-rw-r--r-- | caddytest/integration/stream_test.go | 201 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/reverseproxy.go | 8 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/streaming.go | 5 |
3 files changed, 214 insertions, 0 deletions
diff --git a/caddytest/integration/stream_test.go b/caddytest/integration/stream_test.go new file mode 100644 index 0000000..c0ab32b --- /dev/null +++ b/caddytest/integration/stream_test.go @@ -0,0 +1,201 @@ +package integration + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httputil" + "net/url" + "strings" + "testing" + "time" + + "github.com/caddyserver/caddy/v2/caddytest" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +// (see https://github.com/caddyserver/caddy/issues/3556 for use case) +func TestH2ToH2CStream(t *testing.T) { + tester := caddytest.NewTester(t) + tester.InitServer(` + { + "apps": { + "http": { + "http_port": 9080, + "https_port": 9443, + "servers": { + "srv0": { + "listen": [ + ":9443" + ], + "routes": [ + { + "handle": [ + { + "handler": "reverse_proxy", + "transport": { + "protocol": "http", + "compression": false, + "versions": [ + "h2c", + "2" + ] + }, + "upstreams": [ + { + "dial": "localhost:54321" + } + ] + } + ], + "match": [ + { + "path": [ + "/tov2ray" + ] + } + ] + } + ], + "tls_connection_policies": [ + { + "certificate_selection": { + "any_tag": ["cert0"] + }, + "default_sni": "a.caddy.localhost" + } + ] + } + } + }, + "tls": { + "certificates": { + "load_files": [ + { + "certificate": "/a.caddy.localhost.crt", + "key": "/a.caddy.localhost.key", + "tags": [ + "cert0" + ] + } + ] + } + }, + "pki": { + "certificate_authorities" : { + "local" : { + "install_trust": false + } + } + } + } + } + `, "json") + + expectedBody := "some data to be echoed" + // start the server + server := testH2ToH2CStreamServeH2C(t) + go server.ListenAndServe() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + defer cancel() + server.Shutdown(ctx) + }() + + r, w := io.Pipe() + req := &http.Request{ + Method: "PUT", + Body: ioutil.NopCloser(r), + URL: &url.URL{ + Scheme: "https", + Host: "127.0.0.1:9443", + Path: "/tov2ray", + }, + Proto: "HTTP/2", + ProtoMajor: 2, + ProtoMinor: 0, + Header: make(http.Header), + } + // Disable any compression method from server. + req.Header.Set("Accept-Encoding", "identity") + + resp := tester.AssertResponseCode(req, 200) + if 200 != resp.StatusCode { + return + } + go func() { + fmt.Fprint(w, expectedBody) + w.Close() + }() + + defer resp.Body.Close() + bytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unable to read the response body %s", err) + } + + body := string(bytes) + + if !strings.Contains(body, expectedBody) { + t.Errorf("requesting \"%s\" expected response body \"%s\" but got \"%s\"", req.RequestURI, expectedBody, body) + } + return +} + +func testH2ToH2CStreamServeH2C(t *testing.T) *http.Server { + h2s := &http2.Server{} + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rstring, err := httputil.DumpRequest(r, false) + if err == nil { + t.Logf("h2c server received req: %s", rstring) + } + // We only accept HTTP/2! + if r.ProtoMajor != 2 { + t.Error("Not a HTTP/2 request, rejected!") + w.WriteHeader(http.StatusInternalServerError) + return + } + + if r.Host != "127.0.0.1:9443" { + t.Errorf("r.Host doesn't match, %v!", r.Host) + w.WriteHeader(http.StatusNotFound) + return + } + + if !strings.HasPrefix(r.URL.Path, "/tov2ray") { + w.WriteHeader(http.StatusNotFound) + return + } + + w.Header().Set("Cache-Control", "no-store") + w.WriteHeader(200) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + buf := make([]byte, 4*1024) + + for { + n, err := r.Body.Read(buf) + if n > 0 { + w.Write(buf[:n]) + } + + if err != nil { + if err == io.EOF { + r.Body.Close() + } + break + } + } + }) + + server := &http.Server{ + Addr: "127.0.0.1:54321", + Handler: h2c.NewHandler(handler, h2s), + } + return server +} diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index bb1453a..0a53db4 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -611,6 +611,14 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, di Dia rw.WriteHeader(res.StatusCode) + // some apps need the response headers before starting to stream content with http2, + // so it's important to explicitly flush the headers to the client before streaming the data. + // (see https://github.com/caddyserver/caddy/issues/3556 for use case) + if req.ProtoMajor == 2 && res.ContentLength == -1 { + if wf, ok := rw.(http.Flusher); ok { + wf.Flush() + } + } err = h.copyResponse(rw, res.Body, h.flushInterval(req, res)) res.Body.Close() // close now, instead of defer, to populate res.Trailer if err != nil { diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 8a7c6f7..105ff32 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -96,6 +96,11 @@ func (h Handler) flushInterval(req *http.Request, res *http.Response) time.Durat return -1 // negative means immediately } + // for h2 and h2c upstream streaming data to client (issue #3556) + if req.ProtoMajor == 2 && res.ContentLength == -1 { + return -1 + } + // TODO: more specific cases? e.g. res.ContentLength == -1? (this TODO is from the std lib) return time.Duration(h.FlushInterval) } |