summaryrefslogtreecommitdiff
path: root/modules/caddyhttp/reverseproxy/httptransport.go
diff options
context:
space:
mode:
authorFrancis Lavoie <lavofr@gmail.com>2022-08-12 15:11:13 -0400
committerGitHub <noreply@github.com>2022-08-12 13:11:13 -0600
commit922d9f5c251a27bc9f76c4b74bde151cc03cc9b3 (patch)
tree727bb655e958f1c67455e3752f86fef06b6114fd /modules/caddyhttp/reverseproxy/httptransport.go
parent91ab0e60669b84b7c09189a06de0d6a9771bf950 (diff)
reverseproxy: Fix H2C dialer using new stdlib `DialTLSContext` (#4951)
Diffstat (limited to 'modules/caddyhttp/reverseproxy/httptransport.go')
-rw-r--r--modules/caddyhttp/reverseproxy/httptransport.go96
1 files changed, 49 insertions, 47 deletions
diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go
index ef72b88..e9c7ddd 100644
--- a/modules/caddyhttp/reverseproxy/httptransport.go
+++ b/modules/caddyhttp/reverseproxy/httptransport.go
@@ -128,28 +128,6 @@ func (h *HTTPTransport) Provision(ctx caddy.Context) error {
}
h.Transport = rt
- // if h2c is enabled, configure its transport (std lib http.Transport
- // does not "HTTP/2 over cleartext TCP")
- if sliceContains(h.Versions, "h2c") {
- // crafting our own http2.Transport doesn't allow us to utilize
- // most of the customizations/preferences on the http.Transport,
- // because, for some reason, only http2.ConfigureTransport()
- // is allowed to set the unexported field that refers to a base
- // http.Transport config; oh well
- h2t := &http2.Transport{
- // kind of a hack, but for plaintext/H2C requests, pretend to dial TLS
- DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
- // TODO: no context, thus potentially wrong dial info
- return net.Dial(network, addr)
- },
- AllowHTTP: true,
- }
- if h.Compression != nil {
- h2t.DisableCompression = !*h.Compression
- }
- h.h2cTransport = h2t
- }
-
return nil
}
@@ -194,35 +172,38 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
}
}
- rt := &http.Transport{
- DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
- // the proper dialing information should be embedded into the request's context
- if dialInfo, ok := GetDialInfo(ctx); ok {
- network = dialInfo.Network
- address = dialInfo.Address
- }
+ // Set up the dialer to pull the correct information from the context
+ dialContext := func(ctx context.Context, network, address string) (net.Conn, error) {
+ // the proper dialing information should be embedded into the request's context
+ if dialInfo, ok := GetDialInfo(ctx); ok {
+ network = dialInfo.Network
+ address = dialInfo.Address
+ }
- conn, err := dialer.DialContext(ctx, network, address)
- if err != nil {
- // identify this error as one that occurred during
- // dialing, which can be important when trying to
- // decide whether to retry a request
- return nil, DialError{err}
- }
+ conn, err := dialer.DialContext(ctx, network, address)
+ if err != nil {
+ // identify this error as one that occurred during
+ // dialing, which can be important when trying to
+ // decide whether to retry a request
+ return nil, DialError{err}
+ }
- // if read/write timeouts are configured and this is a TCP connection, enforce the timeouts
- // by wrapping the connection with our own type
- if tcpConn, ok := conn.(*net.TCPConn); ok && (h.ReadTimeout > 0 || h.WriteTimeout > 0) {
- conn = &tcpRWTimeoutConn{
- TCPConn: tcpConn,
- readTimeout: time.Duration(h.ReadTimeout),
- writeTimeout: time.Duration(h.WriteTimeout),
- logger: caddyCtx.Logger(h),
- }
+ // if read/write timeouts are configured and this is a TCP connection, enforce the timeouts
+ // by wrapping the connection with our own type
+ if tcpConn, ok := conn.(*net.TCPConn); ok && (h.ReadTimeout > 0 || h.WriteTimeout > 0) {
+ conn = &tcpRWTimeoutConn{
+ TCPConn: tcpConn,
+ readTimeout: time.Duration(h.ReadTimeout),
+ writeTimeout: time.Duration(h.WriteTimeout),
+ logger: caddyCtx.Logger(h),
}
+ }
- return conn, nil
- },
+ return conn, nil
+ }
+
+ rt := &http.Transport{
+ DialContext: dialContext,
MaxConnsPerHost: h.MaxConnsPerHost,
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout),
@@ -260,6 +241,27 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
}
}
+ // if h2c is enabled, configure its transport (std lib http.Transport
+ // does not "HTTP/2 over cleartext TCP")
+ if sliceContains(h.Versions, "h2c") {
+ // crafting our own http2.Transport doesn't allow us to utilize
+ // most of the customizations/preferences on the http.Transport,
+ // because, for some reason, only http2.ConfigureTransport()
+ // is allowed to set the unexported field that refers to a base
+ // http.Transport config; oh well
+ h2t := &http2.Transport{
+ // kind of a hack, but for plaintext/H2C requests, pretend to dial TLS
+ DialTLSContext: func(ctx context.Context, network, address string, _ *tls.Config) (net.Conn, error) {
+ return dialContext(ctx, network, address)
+ },
+ AllowHTTP: true,
+ }
+ if h.Compression != nil {
+ h2t.DisableCompression = !*h.Compression
+ }
+ h.h2cTransport = h2t
+ }
+
return rt, nil
}