From 758269124ef5d6d83cb4e8d8eeb095f5b025250e Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Wed, 11 Sep 2019 18:53:44 -0600 Subject: reverseproxy: Fix host and port on requests; fix Caddyfile parser --- modules/caddyhttp/reverseproxy/caddyfile.go | 188 +++++++++++++------------ modules/caddyhttp/reverseproxy/reverseproxy.go | 34 ++++- 2 files changed, 128 insertions(+), 94 deletions(-) (limited to 'modules/caddyhttp') diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go index 3c02bf1..d8c63b4 100644 --- a/modules/caddyhttp/reverseproxy/caddyfile.go +++ b/modules/caddyhttp/reverseproxy/caddyfile.go @@ -376,108 +376,110 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { // } // func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { - for d.NextBlock(0) { - switch d.Val() { - case "read_buffer": - if !d.NextArg() { - return d.ArgErr() - } - size, err := humanize.ParseBytes(d.Val()) - if err != nil { - return d.Errf("invalid read buffer size '%s': %v", d.Val(), err) - } - h.ReadBufferSize = int(size) + for d.Next() { + for d.NextBlock(0) { + switch d.Val() { + case "read_buffer": + if !d.NextArg() { + return d.ArgErr() + } + size, err := humanize.ParseBytes(d.Val()) + if err != nil { + return d.Errf("invalid read buffer size '%s': %v", d.Val(), err) + } + h.ReadBufferSize = int(size) - case "write_buffer": - if !d.NextArg() { - return d.ArgErr() - } - size, err := humanize.ParseBytes(d.Val()) - if err != nil { - return d.Errf("invalid write buffer size '%s': %v", d.Val(), err) - } - h.WriteBufferSize = int(size) + case "write_buffer": + if !d.NextArg() { + return d.ArgErr() + } + size, err := humanize.ParseBytes(d.Val()) + if err != nil { + return d.Errf("invalid write buffer size '%s': %v", d.Val(), err) + } + h.WriteBufferSize = int(size) - case "dial_timeout": - if !d.NextArg() { - return d.ArgErr() - } - dur, err := time.ParseDuration(d.Val()) - if err != nil { - return d.Errf("bad timeout value '%s': %v", d.Val(), err) - } - h.DialTimeout = caddy.Duration(dur) + case "dial_timeout": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad timeout value '%s': %v", d.Val(), err) + } + h.DialTimeout = caddy.Duration(dur) - case "tls_client_auth": - args := d.RemainingArgs() - if len(args) != 2 { - return d.ArgErr() - } - if h.TLS == nil { - h.TLS = new(TLSConfig) - } - h.TLS.ClientCertificateFile = args[0] - h.TLS.ClientCertificateKeyFile = args[1] + case "tls_client_auth": + args := d.RemainingArgs() + if len(args) != 2 { + return d.ArgErr() + } + if h.TLS == nil { + h.TLS = new(TLSConfig) + } + h.TLS.ClientCertificateFile = args[0] + h.TLS.ClientCertificateKeyFile = args[1] - case "tls": - if h.TLS == nil { - h.TLS = new(TLSConfig) - } + case "tls": + if h.TLS == nil { + h.TLS = new(TLSConfig) + } - case "tls_insecure_skip_verify": - if d.NextArg() { - return d.ArgErr() - } - if h.TLS == nil { - h.TLS = new(TLSConfig) - } - h.TLS.InsecureSkipVerify = true + case "tls_insecure_skip_verify": + if d.NextArg() { + return d.ArgErr() + } + if h.TLS == nil { + h.TLS = new(TLSConfig) + } + h.TLS.InsecureSkipVerify = true - case "tls_timeout": - if !d.NextArg() { - return d.ArgErr() - } - dur, err := time.ParseDuration(d.Val()) - if err != nil { - return d.Errf("bad timeout value '%s': %v", d.Val(), err) - } - if h.TLS == nil { - h.TLS = new(TLSConfig) - } - h.TLS.HandshakeTimeout = caddy.Duration(dur) + case "tls_timeout": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad timeout value '%s': %v", d.Val(), err) + } + if h.TLS == nil { + h.TLS = new(TLSConfig) + } + h.TLS.HandshakeTimeout = caddy.Duration(dur) - case "keepalive": - if !d.NextArg() { - return d.ArgErr() - } - if h.KeepAlive == nil { - h.KeepAlive = new(KeepAlive) - } - if d.Val() == "off" { - var disable bool - h.KeepAlive.Enabled = &disable - } - dur, err := time.ParseDuration(d.Val()) - if err != nil { - return d.Errf("bad duration value '%s': %v", d.Val(), err) - } - h.KeepAlive.IdleConnTimeout = caddy.Duration(dur) + case "keepalive": + if !d.NextArg() { + return d.ArgErr() + } + if h.KeepAlive == nil { + h.KeepAlive = new(KeepAlive) + } + if d.Val() == "off" { + var disable bool + h.KeepAlive.Enabled = &disable + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad duration value '%s': %v", d.Val(), err) + } + h.KeepAlive.IdleConnTimeout = caddy.Duration(dur) - case "keepalive_idle_conns": - if !d.NextArg() { - return d.ArgErr() - } - num, err := strconv.Atoi(d.Val()) - if err != nil { - return d.Errf("bad integer value '%s': %v", d.Val(), err) - } - if h.KeepAlive == nil { - h.KeepAlive = new(KeepAlive) - } - h.KeepAlive.MaxIdleConns = num + case "keepalive_idle_conns": + if !d.NextArg() { + return d.ArgErr() + } + num, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("bad integer value '%s': %v", d.Val(), err) + } + if h.KeepAlive == nil { + h.KeepAlive = new(KeepAlive) + } + h.KeepAlive.MaxIdleConns = num - default: - return d.Errf("unrecognized subdirective %s", d.Val()) + default: + return d.Errf("unrecognized subdirective %s", d.Val()) + } } } return nil diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 5a37613..a82c4e0 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -145,6 +145,26 @@ func (h *Handler) Provision(ctx caddy.Context) error { var allUpstreams []*Upstream for _, upstream := range h.Upstreams { + // if a port was not specified (and the network type uses + // ports), then maybe we can figure out the default port + netw, host, port, err := caddy.SplitNetworkAddress(upstream.Dial) + if err != nil && port == "" && !strings.Contains(netw, "unix") { + if host == "" { + // assume all that was given was the host, no port + host = upstream.Dial + } + // a port was not specified, but we may be able to + // infer it if we know the standard ports on which + // the transport protocol operates + if ht, ok := h.Transport.(*HTTPTransport); ok { + defaultPort := "80" + if ht.TLS != nil { + defaultPort = "443" + } + upstream.Dial = caddy.JoinNetworkAddress(netw, host, defaultPort) + } + } + // upstreams are allowed to map to only a single host, // but an upstream's address may semantically represent // multiple addresses, so make sure to handle each @@ -474,7 +494,19 @@ func (h Handler) tryAgain(start time.Time, proxyErr error) bool { // given upstream host. It must modify ONLY the request URL. func (h Handler) directRequest(req *http.Request, upstream *Upstream) { if req.URL.Host == "" { - req.URL.Host = upstream.dialInfo.Address + // we need a host, so set the upstream's host address + fullHost := upstream.dialInfo.Address + + // but if the port matches the scheme, strip the port because + // it's weird to make a request like http://example.com:80/. + host, port, err := net.SplitHostPort(fullHost) + if err == nil && + (req.URL.Scheme == "http" && port == "80") || + (req.URL.Scheme == "https" && port == "443") { + fullHost = host + } + + req.URL.Host = fullHost } } -- cgit v1.2.3