diff options
author | Francis Lavoie <lavofr@gmail.com> | 2022-08-12 15:11:13 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-12 13:11:13 -0600 |
commit | 922d9f5c251a27bc9f76c4b74bde151cc03cc9b3 (patch) | |
tree | 727bb655e958f1c67455e3752f86fef06b6114fd /modules | |
parent | 91ab0e60669b84b7c09189a06de0d6a9771bf950 (diff) |
reverseproxy: Fix H2C dialer using new stdlib `DialTLSContext` (#4951)
Diffstat (limited to 'modules')
-rw-r--r-- | modules/caddyhttp/reverseproxy/httptransport.go | 96 |
1 files changed, 49 insertions, 47 deletions
diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index ef72b88..e9c7ddd 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -128,28 +128,6 @@ func (h *HTTPTransport) Provision(ctx caddy.Context) error { } h.Transport = rt - // if h2c is enabled, configure its transport (std lib http.Transport - // does not "HTTP/2 over cleartext TCP") - if sliceContains(h.Versions, "h2c") { - // crafting our own http2.Transport doesn't allow us to utilize - // most of the customizations/preferences on the http.Transport, - // because, for some reason, only http2.ConfigureTransport() - // is allowed to set the unexported field that refers to a base - // http.Transport config; oh well - h2t := &http2.Transport{ - // kind of a hack, but for plaintext/H2C requests, pretend to dial TLS - DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { - // TODO: no context, thus potentially wrong dial info - return net.Dial(network, addr) - }, - AllowHTTP: true, - } - if h.Compression != nil { - h2t.DisableCompression = !*h.Compression - } - h.h2cTransport = h2t - } - return nil } @@ -194,35 +172,38 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e } } - rt := &http.Transport{ - DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - // the proper dialing information should be embedded into the request's context - if dialInfo, ok := GetDialInfo(ctx); ok { - network = dialInfo.Network - address = dialInfo.Address - } + // Set up the dialer to pull the correct information from the context + dialContext := func(ctx context.Context, network, address string) (net.Conn, error) { + // the proper dialing information should be embedded into the request's context + if dialInfo, ok := GetDialInfo(ctx); ok { + network = dialInfo.Network + address = dialInfo.Address + } - conn, err := dialer.DialContext(ctx, network, address) - if err != nil { - // identify this error as one that occurred during - // dialing, which can be important when trying to - // decide whether to retry a request - return nil, DialError{err} - } + conn, err := dialer.DialContext(ctx, network, address) + if err != nil { + // identify this error as one that occurred during + // dialing, which can be important when trying to + // decide whether to retry a request + return nil, DialError{err} + } - // if read/write timeouts are configured and this is a TCP connection, enforce the timeouts - // by wrapping the connection with our own type - if tcpConn, ok := conn.(*net.TCPConn); ok && (h.ReadTimeout > 0 || h.WriteTimeout > 0) { - conn = &tcpRWTimeoutConn{ - TCPConn: tcpConn, - readTimeout: time.Duration(h.ReadTimeout), - writeTimeout: time.Duration(h.WriteTimeout), - logger: caddyCtx.Logger(h), - } + // if read/write timeouts are configured and this is a TCP connection, enforce the timeouts + // by wrapping the connection with our own type + if tcpConn, ok := conn.(*net.TCPConn); ok && (h.ReadTimeout > 0 || h.WriteTimeout > 0) { + conn = &tcpRWTimeoutConn{ + TCPConn: tcpConn, + readTimeout: time.Duration(h.ReadTimeout), + writeTimeout: time.Duration(h.WriteTimeout), + logger: caddyCtx.Logger(h), } + } - return conn, nil - }, + return conn, nil + } + + rt := &http.Transport{ + DialContext: dialContext, MaxConnsPerHost: h.MaxConnsPerHost, ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout), ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout), @@ -260,6 +241,27 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e } } + // if h2c is enabled, configure its transport (std lib http.Transport + // does not "HTTP/2 over cleartext TCP") + if sliceContains(h.Versions, "h2c") { + // crafting our own http2.Transport doesn't allow us to utilize + // most of the customizations/preferences on the http.Transport, + // because, for some reason, only http2.ConfigureTransport() + // is allowed to set the unexported field that refers to a base + // http.Transport config; oh well + h2t := &http2.Transport{ + // kind of a hack, but for plaintext/H2C requests, pretend to dial TLS + DialTLSContext: func(ctx context.Context, network, address string, _ *tls.Config) (net.Conn, error) { + return dialContext(ctx, network, address) + }, + AllowHTTP: true, + } + if h.Compression != nil { + h2t.DisableCompression = !*h.Compression + } + h.h2cTransport = h2t + } + return rt, nil } |