From e02117cb8a2b0b6dbd3dbb1de4d1569ff63ca617 Mon Sep 17 00:00:00 2001 From: Matt Holt Date: Tue, 24 Mar 2020 10:53:53 -0600 Subject: reverse_proxy: Add support for SRV backends (#3180) * reverse_proxy: Begin SRV lookup support (WIP) * reverse_proxy: Finish adding support for SRV-based backends (#3179) --- modules/caddyhttp/reverseproxy/hosts.go | 73 ++++++++++++++++++++++----------- 1 file changed, 50 insertions(+), 23 deletions(-) (limited to 'modules/caddyhttp/reverseproxy/hosts.go') diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go index 602aab2..a7709ee 100644 --- a/modules/caddyhttp/reverseproxy/hosts.go +++ b/modules/caddyhttp/reverseproxy/hosts.go @@ -17,6 +17,8 @@ package reverseproxy import ( "context" "fmt" + "net" + "net/http" "strconv" "sync/atomic" @@ -63,10 +65,10 @@ type UpstreamPool []*Upstream type Upstream struct { Host `json:"-"` - // The [network address](/docs/json/apps/http/#servers/listen) + // The [network address](/docs/conventions#network-addresses) // to dial to connect to the upstream. Must represent precisely // one socket (i.e. no port ranges). A valid network address - // either has a host and port, or is a unix socket address. + // either has a host and port or is a unix socket address. // // Placeholders may be used to make the upstream dynamic, but be // aware of the health check implications of this: a single @@ -75,6 +77,11 @@ type Upstream struct { // backends is down. Also be aware of open proxy vulnerabilities. Dial string `json:"dial,omitempty"` + // If DNS SRV records are used for service discovery with this + // upstream, specify the DNS name for which to look up SRV + // records here, instead of specifying a dial address. + LookupSRV string `json:"lookup_srv,omitempty"` + // The maximum number of simultaneous requests to allow to // this upstream. If set, overrides the global passive health // check UnhealthyRequestCount value. @@ -118,6 +125,47 @@ func (u *Upstream) Full() bool { return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests } +// fillDialInfo returns a filled DialInfo for upstream u, using the request +// context. If the upstream has a SRV lookup configured, that is done and a +// returned address is chosen; otherwise, the upstream's regular dial address +// field is used. Note that the returned value is not a pointer. +func (u *Upstream) fillDialInfo(r *http.Request) (DialInfo, error) { + repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) + var addr caddy.ParsedAddress + + if u.LookupSRV != "" { + // perform DNS lookup for SRV records and choose one + srvName := repl.ReplaceAll(u.LookupSRV, "") + _, records, err := net.DefaultResolver.LookupSRV(r.Context(), "", "", srvName) + if err != nil { + return DialInfo{}, err + } + addr.Network = "tcp" + addr.Host = records[0].Target + addr.StartPort, addr.EndPort = uint(records[0].Port), uint(records[0].Port) + } else { + // use provided dial address + var err error + dial := repl.ReplaceAll(u.Dial, "") + addr, err = caddy.ParseNetworkAddress(dial) + if err != nil { + return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", u.Dial, dial, err) + } + if numPorts := addr.PortRangeSize(); numPorts != 1 { + return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d", + u.Dial, dial, numPorts) + } + } + + return DialInfo{ + Upstream: u, + Network: addr.Network, + Address: addr.JoinHostPort(0), + Host: addr.Host, + Port: strconv.Itoa(int(addr.StartPort)), + }, nil +} + // upstreamHost is the basic, in-memory representation // of the state of a remote host. It implements the // Host interface. @@ -204,27 +252,6 @@ func (di DialInfo) String() string { return caddy.JoinNetworkAddress(di.Network, di.Host, di.Port) } -// fillDialInfo returns a filled DialInfo for the given upstream, using -// the given Replacer. Note that the returned value is not a pointer. -func fillDialInfo(upstream *Upstream, repl *caddy.Replacer) (DialInfo, error) { - dial := repl.ReplaceAll(upstream.Dial, "") - addr, err := caddy.ParseNetworkAddress(dial) - if err != nil { - return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", upstream.Dial, dial, err) - } - if numPorts := addr.PortRangeSize(); numPorts != 1 { - return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d", - upstream.Dial, dial, numPorts) - } - return DialInfo{ - Upstream: upstream, - Network: addr.Network, - Address: addr.JoinHostPort(0), - Host: addr.Host, - Port: strconv.Itoa(int(addr.StartPort)), - }, nil -} - // GetDialInfo gets the upstream dialing info out of the context, // and returns true if there was a valid value; false otherwise. func GetDialInfo(ctx context.Context) (DialInfo, bool) { -- cgit v1.2.3