Diffstat (limited to 'modules/caddyhttp/reverseproxy/upstreams.go')
-rw-r--r--  modules/caddyhttp/reverseproxy/upstreams.go  97
1 file changed, 70 insertions, 27 deletions
diff --git a/modules/caddyhttp/reverseproxy/upstreams.go b/modules/caddyhttp/reverseproxy/upstreams.go
index 7a90016..2d21a5c 100644
--- a/modules/caddyhttp/reverseproxy/upstreams.go
+++ b/modules/caddyhttp/reverseproxy/upstreams.go
@@ -8,12 +8,12 @@ import (
"net"
"net/http"
"strconv"
- "strings"
"sync"
"time"
- "github.com/caddyserver/caddy/v2"
"go.uber.org/zap"
+
+ "github.com/caddyserver/caddy/v2"
)
func init() {
@@ -114,7 +114,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
cached := srvs[suAddr]
srvsMu.RUnlock()
if cached.isFresh() {
- return cached.upstreams, nil
+ return allNew(cached.upstreams), nil
}
// otherwise, obtain a write-lock to update the cached value
@@ -126,7 +126,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
// have refreshed it in the meantime before we re-obtained our lock
cached = srvs[suAddr]
if cached.isFresh() {
- return cached.upstreams, nil
+ return allNew(cached.upstreams), nil
}
su.logger.Debug("refreshing SRV upstreams",
@@ -145,7 +145,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
su.logger.Warn("SRV records filtered", zap.Error(err))
}
- upstreams := make([]*Upstream, len(records))
+ upstreams := make([]Upstream, len(records))
for i, rec := range records {
su.logger.Debug("discovered SRV record",
zap.String("target", rec.Target),
@@ -153,7 +153,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
zap.Uint16("priority", rec.Priority),
zap.Uint16("weight", rec.Weight))
addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port)))
- upstreams[i] = &Upstream{Dial: addr}
+ upstreams[i] = Upstream{Dial: addr}
}
// before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full
@@ -170,7 +170,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
upstreams: upstreams,
}
- return upstreams, nil
+ return allNew(upstreams), nil
}
func (su SRVUpstreams) String() string {
@@ -206,13 +206,18 @@ func (SRVUpstreams) formattedAddr(service, proto, name string) string {
type srvLookup struct {
srvUpstreams SRVUpstreams
freshness time.Time
- upstreams []*Upstream
+ upstreams []Upstream
}
func (sl srvLookup) isFresh() bool {
return time.Since(sl.freshness) < time.Duration(sl.srvUpstreams.Refresh)
}
+type IPVersions struct {
+ IPv4 *bool `json:"ipv4,omitempty"`
+ IPv6 *bool `json:"ipv6,omitempty"`
+}
+
// AUpstreams provides upstreams from A/AAAA lookups.
// Results are cached and refreshed at the configured
// refresh interval.
@@ -240,7 +245,14 @@ type AUpstreams struct {
// A negative value disables this.
FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"`
+ // The IP versions to resolve for. By default, both
+ // "ipv4" and "ipv6" will be enabled, which
+ // correspond to A and AAAA records respectively.
+ Versions *IPVersions `json:"versions,omitempty"`
+
resolver *net.Resolver
+
+ logger *zap.Logger
}
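// Illustrative sketch, not part of the patch: the new Versions field uses
// *bool so an unset (nil) value can be told apart from an explicit false;
// the lookup code below treats nil as enabled, so both A and AAAA records
// are resolved unless a version is switched off. For example (hypothetical
// values), resolving A records only:
//
//	ipv6 := false
//	au := AUpstreams{
//		Name:     "example.com",
//		Port:     "8080",
//		Versions: &IPVersions{IPv6: &ipv6},
//	}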
// CaddyModule returns the Caddy module information.
@@ -251,7 +263,8 @@ func (AUpstreams) CaddyModule() caddy.ModuleInfo {
}
}
-func (au *AUpstreams) Provision(_ caddy.Context) error {
+func (au *AUpstreams) Provision(ctx caddy.Context) error {
+ au.logger = ctx.Logger()
if au.Refresh == 0 {
au.Refresh = caddy.Duration(time.Minute)
}
@@ -286,14 +299,36 @@ func (au *AUpstreams) Provision(_ caddy.Context) error {
func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
- auStr := repl.ReplaceAll(au.String(), "")
+
+ resolveIpv4 := au.Versions == nil || au.Versions.IPv4 == nil || *au.Versions.IPv4
+ resolveIpv6 := au.Versions == nil || au.Versions.IPv6 == nil || *au.Versions.IPv6
+
+ // Map ipVersion early, so we can use it as part of the cache key.
+ // This should be fairly inexpensive and comes with the upside of
+ // allowing the same dynamic upstream (name + port combination)
+ // to be used multiple times with different IP versions.
+ //
+ // It also forces a cache miss if a previously cached dynamic
+ // upstream changes its IP version, e.g. after a config reload,
+ // while keeping the cache invalidation as simple as it currently is.
+ var ipVersion string
+ switch {
+ case resolveIpv4 && !resolveIpv6:
+ ipVersion = "ip4"
+ case !resolveIpv4 && resolveIpv6:
+ ipVersion = "ip6"
+ default:
+ ipVersion = "ip"
+ }
+
+ auStr := repl.ReplaceAll(au.String()+ipVersion, "")
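// Note, not part of the patch: "ip4", "ip6", and "ip" are the network values
// net.Resolver.LookupIP accepts below, so the same string both selects the
// record types to resolve and suffixes the cache key. Two upstreams for
// "example.com:8080" that differ only in enabled IP versions therefore cache
// under distinct keys, e.g. "example.com:8080ip4" vs "example.com:8080ip6".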
// first, use a cheap read-lock to return a cached result quickly
aAaaaMu.RLock()
cached := aAaaa[auStr]
aAaaaMu.RUnlock()
if cached.isFresh() {
- return cached.upstreams, nil
+ return allNew(cached.upstreams), nil
}
// otherwise, obtain a write-lock to update the cached value
@@ -305,26 +340,33 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
// have refreshed it in the meantime before we re-obtained our lock
cached = aAaaa[auStr]
if cached.isFresh() {
- return cached.upstreams, nil
+ return allNew(cached.upstreams), nil
}
name := repl.ReplaceAll(au.Name, "")
port := repl.ReplaceAll(au.Port, "")
- ips, err := au.resolver.LookupIPAddr(r.Context(), name)
+ au.logger.Debug("refreshing A upstreams",
+ zap.String("version", ipVersion),
+ zap.String("name", name),
+ zap.String("port", port))
+
+ ips, err := au.resolver.LookupIP(r.Context(), ipVersion, name)
if err != nil {
return nil, err
}
- upstreams := make([]*Upstream, len(ips))
+ upstreams := make([]Upstream, len(ips))
for i, ip := range ips {
- upstreams[i] = &Upstream{
+ au.logger.Debug("discovered A record",
+ zap.String("ip", ip.String()))
+ upstreams[i] = Upstream{
Dial: net.JoinHostPort(ip.String(), port),
}
}
// before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full
- if cached.freshness.IsZero() && len(srvs) >= 100 {
+ if cached.freshness.IsZero() && len(aAaaa) >= 100 {
for randomKey := range aAaaa {
delete(aAaaa, randomKey)
break
@@ -337,7 +379,7 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
upstreams: upstreams,
}
- return upstreams, nil
+ return allNew(upstreams), nil
}
func (au AUpstreams) String() string { return net.JoinHostPort(au.Name, au.Port) }
@@ -345,7 +387,7 @@ func (au AUpstreams) String() string { return net.JoinHostPort(au.Name, au.Port)
type aLookup struct {
aUpstreams AUpstreams
freshness time.Time
- upstreams []*Upstream
+ upstreams []Upstream
}
func (al aLookup) isFresh() bool {
@@ -439,16 +481,9 @@ type UpstreamResolver struct {
// and ensures they're ready to be used.
func (u *UpstreamResolver) ParseAddresses() error {
for _, v := range u.Addresses {
- addr, err := caddy.ParseNetworkAddress(v)
+ addr, err := caddy.ParseNetworkAddressWithDefaults(v, "udp", 53)
if err != nil {
- // If a port wasn't specified for the resolver,
- // try defaulting to 53 and parse again
- if strings.Contains(err.Error(), "missing port in address") {
- addr, err = caddy.ParseNetworkAddress(v + ":53")
- }
- if err != nil {
- return err
- }
+ return err
}
if addr.PortRangeSize() != 1 {
return fmt.Errorf("resolver address must have exactly one address; cannot call %v", addr)
@@ -458,6 +493,14 @@ func (u *UpstreamResolver) ParseAddresses() error {
return nil
}
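// Illustrative sketch, not part of the patch: ParseNetworkAddressWithDefaults
// fills in the network and port when the configured resolver address omits
// them, which is what makes the old "missing port in address" retry above
// unnecessary. Assuming a bare IP in the config, both of these parse to the
// same single-port UDP address:
//
//	a1, _ := caddy.ParseNetworkAddressWithDefaults("127.0.0.1", "udp", 53)
//	a2, _ := caddy.ParseNetworkAddressWithDefaults("udp/127.0.0.1:53", "udp", 53)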
+func allNew(upstreams []Upstream) []*Upstream {
+ results := make([]*Upstream, len(upstreams))
+ for i := range upstreams {
+ results[i] = &Upstream{Dial: upstreams[i].Dial}
+ }
+ return results
+}
+
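// Note, not part of the patch: the SRV and A/AAAA caches now store plain
// []Upstream values, and every cache hit or refresh goes through allNew,
// which hands the caller freshly allocated *Upstream copies built from the
// cached Dial addresses. Presumably this keeps long-lived cache entries from
// sharing mutable Upstream pointers across requests, which the previous
// []*Upstream caching allowed.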
var (
srvs = make(map[string]srvLookup)
srvsMu sync.RWMutex