diff options
-rw-r--r-- | modules/caddyhttp/reverseproxy/httptransport.go | 41 |
1 files changed, 39 insertions, 2 deletions
diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index 8bce580..eefc04a 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -25,6 +25,7 @@ import ( "net/http" "os" "reflect" + "strings" "time" "github.com/caddyserver/caddy/v2" @@ -242,9 +243,45 @@ func (h *HTTPTransport) NewTransport(ctx caddy.Context) (*http.Transport, error) return rt, nil } +// replaceTLSServername checks TLS servername to see if it needs replacing +// if it does need replacing, it creates a new cloned HTTPTransport object to avoid any races +// and does the replacing of the TLS servername on that and returns the new object +// if no replacement is necessary it returns the original +func (h *HTTPTransport) replaceTLSServername(repl *caddy.Replacer) *HTTPTransport { + // check whether we have TLS and need to replace the servername in the TLSClientConfig + if h.TLSEnabled() && strings.Contains(h.TLS.ServerName, "{") { + // make a new h, "copy" the parts we don't need to touch, add a new *tls.Config and replace servername + newtransport := &HTTPTransport{ + Resolver: h.Resolver, + TLS: h.TLS, + KeepAlive: h.KeepAlive, + Compression: h.Compression, + MaxConnsPerHost: h.MaxConnsPerHost, + DialTimeout: h.DialTimeout, + FallbackDelay: h.FallbackDelay, + ResponseHeaderTimeout: h.ResponseHeaderTimeout, + ExpectContinueTimeout: h.ExpectContinueTimeout, + MaxResponseHeaderSize: h.MaxResponseHeaderSize, + WriteBufferSize: h.WriteBufferSize, + ReadBufferSize: h.ReadBufferSize, + Versions: h.Versions, + Transport: h.Transport.Clone(), + h2cTransport: h.h2cTransport, + } + newtransport.Transport.TLSClientConfig.ServerName = repl.ReplaceAll(newtransport.Transport.TLSClientConfig.ServerName, "") + return newtransport + } + + return h +} + // RoundTrip implements http.RoundTripper. func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { - h.SetScheme(req) + // Try to replace TLS servername if needed + repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) + transport := h.replaceTLSServername(repl) + + transport.SetScheme(req) // if H2C ("HTTP/2 over cleartext") is enabled and the upstream request is // HTTP without TLS, use the alternate H2C-capable transport instead @@ -252,7 +289,7 @@ func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { return h.h2cTransport.RoundTrip(req) } - return h.Transport.RoundTrip(req) + return transport.Transport.RoundTrip(req) } // SetScheme ensures that the outbound request req |