summaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorFrancis Lavoie <lavofr@gmail.com>2023-05-05 17:08:10 -0400
committerGitHub <noreply@github.com>2023-05-05 15:08:10 -0600
commit48598e1f2a370c2440b38f0b77e4d74748111b9a (patch)
treef6938e816d076fd7288e333e58eecddab92ef443 /modules
parentcdce452edc5e9cf9127789a868d01864d3276af5 (diff)
reverseproxy: Add `fallback` for some policies, instead of always random (#5488)
Diffstat (limited to 'modules')
-rw-r--r--modules/caddyhttp/reverseproxy/selectionpolicies.go184
-rw-r--r--modules/caddyhttp/reverseproxy/selectionpolicies_test.go92
2 files changed, 237 insertions, 39 deletions
diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go
index 4184df5..a2985f1 100644
--- a/modules/caddyhttp/reverseproxy/selectionpolicies.go
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go
@@ -18,6 +18,7 @@ import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
+ "encoding/json"
"fmt"
"hash/fnv"
weakrand "math/rand"
@@ -29,6 +30,7 @@ import (
"time"
"github.com/caddyserver/caddy/v2"
+ "github.com/caddyserver/caddy/v2/caddyconfig"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
@@ -372,6 +374,10 @@ func (r *URIHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
type QueryHashSelection struct {
// The query key whose value is to be hashed and used for upstream selection.
Key string `json:"key,omitempty"`
+
+ // The fallback policy to use if the query key is not present. Defaults to `random`.
+ FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"`
+ fallback Selector
}
// CaddyModule returns the Caddy module information.
@@ -382,12 +388,24 @@ func (QueryHashSelection) CaddyModule() caddy.ModuleInfo {
}
}
-// Select returns an available host, if any.
-func (s QueryHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
+// Provision sets up the module.
+func (s *QueryHashSelection) Provision(ctx caddy.Context) error {
if s.Key == "" {
- return nil
+ return fmt.Errorf("query key is required")
+ }
+ if s.FallbackRaw == nil {
+ s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil)
}
+ mod, err := ctx.LoadModule(s, "FallbackRaw")
+ if err != nil {
+ return fmt.Errorf("loading fallback selection policy: %s", err)
+ }
+ s.fallback = mod.(Selector)
+ return nil
+}
+// Select returns an available host, if any.
+func (s QueryHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
// Since the query may have multiple values for the same key,
// we'll join them to avoid a problem where the user can control
// the upstream that the request goes to by sending multiple values
@@ -397,7 +415,7 @@ func (s QueryHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.
// different request, because the order of the values is significant.
vals := strings.Join(req.URL.Query()[s.Key], ",")
if vals == "" {
- return RandomSelection{}.Select(pool, req, nil)
+ return s.fallback.Select(pool, req, nil)
}
return hostByHashing(pool, vals)
}
@@ -410,6 +428,24 @@ func (s *QueryHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
s.Key = d.Val()
}
+ for nesting := d.Nesting(); d.NextBlock(nesting); {
+ switch d.Val() {
+ case "fallback":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if s.FallbackRaw != nil {
+ return d.Err("fallback selection policy already specified")
+ }
+ mod, err := loadFallbackPolicy(d)
+ if err != nil {
+ return err
+ }
+ s.FallbackRaw = mod
+ default:
+ return d.Errf("unrecognized option '%s'", d.Val())
+ }
+ }
return nil
}
@@ -418,6 +454,10 @@ func (s *QueryHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
type HeaderHashSelection struct {
// The HTTP header field whose value is to be hashed and used for upstream selection.
Field string `json:"field,omitempty"`
+
+ // The fallback policy to use if the header is not present. Defaults to `random`.
+ FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"`
+ fallback Selector
}
// CaddyModule returns the Caddy module information.
@@ -428,12 +468,24 @@ func (HeaderHashSelection) CaddyModule() caddy.ModuleInfo {
}
}
-// Select returns an available host, if any.
-func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
+// Provision sets up the module.
+func (s *HeaderHashSelection) Provision(ctx caddy.Context) error {
if s.Field == "" {
- return nil
+ return fmt.Errorf("header field is required")
+ }
+ if s.FallbackRaw == nil {
+ s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil)
}
+ mod, err := ctx.LoadModule(s, "FallbackRaw")
+ if err != nil {
+ return fmt.Errorf("loading fallback selection policy: %s", err)
+ }
+ s.fallback = mod.(Selector)
+ return nil
+}
+// Select returns an available host, if any.
+func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
// The Host header should be obtained from the req.Host field
// since net/http removes it from the header map.
if s.Field == "Host" && req.Host != "" {
@@ -442,7 +494,7 @@ func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http
val := req.Header.Get(s.Field)
if val == "" {
- return RandomSelection{}.Select(pool, req, nil)
+ return s.fallback.Select(pool, req, nil)
}
return hostByHashing(pool, val)
}
@@ -455,6 +507,24 @@ func (s *HeaderHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
s.Field = d.Val()
}
+ for nesting := d.Nesting(); d.NextBlock(nesting); {
+ switch d.Val() {
+ case "fallback":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if s.FallbackRaw != nil {
+ return d.Err("fallback selection policy already specified")
+ }
+ mod, err := loadFallbackPolicy(d)
+ if err != nil {
+ return err
+ }
+ s.FallbackRaw = mod
+ default:
+ return d.Errf("unrecognized option '%s'", d.Val())
+ }
+ }
return nil
}
@@ -465,6 +535,10 @@ type CookieHashSelection struct {
Name string `json:"name,omitempty"`
// Secret to hash (Hmac256) chosen upstream in cookie
Secret string `json:"secret,omitempty"`
+
+ // The fallback policy to use if the cookie is not present. Defaults to `random`.
+ FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"`
+ fallback Selector
}
// CaddyModule returns the Caddy module information.
@@ -475,15 +549,48 @@ func (CookieHashSelection) CaddyModule() caddy.ModuleInfo {
}
}
-// Select returns an available host, if any.
-func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http.ResponseWriter) *Upstream {
+// Provision sets up the module.
+func (s *CookieHashSelection) Provision(ctx caddy.Context) error {
if s.Name == "" {
s.Name = "lb"
}
+ if s.FallbackRaw == nil {
+ s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil)
+ }
+ mod, err := ctx.LoadModule(s, "FallbackRaw")
+ if err != nil {
+ return fmt.Errorf("loading fallback selection policy: %s", err)
+ }
+ s.fallback = mod.(Selector)
+ return nil
+}
+
+// Select returns an available host, if any.
+func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http.ResponseWriter) *Upstream {
+ // selects a new Host using the fallback policy (typically random)
+ // and write a sticky session cookie to the response.
+ selectNewHost := func() *Upstream {
+ upstream := s.fallback.Select(pool, req, w)
+ if upstream == nil {
+ return nil
+ }
+ sha, err := hashCookie(s.Secret, upstream.Dial)
+ if err != nil {
+ return upstream
+ }
+ http.SetCookie(w, &http.Cookie{
+ Name: s.Name,
+ Value: sha,
+ Path: "/",
+ Secure: false,
+ })
+ return upstream
+ }
+
cookie, err := req.Cookie(s.Name)
- // If there's no cookie, select new random host
+ // If there's no cookie, select a host using the fallback policy
if err != nil || cookie == nil {
- return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name)
+ return selectNewHost()
}
// If the cookie is present, loop over the available upstreams until we find a match
cookieValue := cookie.Value
@@ -496,13 +603,15 @@ func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http
return upstream
}
}
- // If there is no matching host, select new random host
- return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name)
+ // If there is no matching host, select a host using the fallback policy
+ return selectNewHost()
}
// UnmarshalCaddyfile sets up the module from Caddyfile tokens. Syntax:
//
-// lb_policy cookie [<name> [<secret>]]
+// lb_policy cookie [<name> [<secret>]] {
+// fallback <policy>
+// }
//
// By default name is `lb`
func (s *CookieHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
@@ -517,22 +626,25 @@ func (s *CookieHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
default:
return d.ArgErr()
}
- return nil
-}
-
-// Select a new Host randomly and add a sticky session cookie
-func selectNewHostWithCookieHashSelection(pool []*Upstream, w http.ResponseWriter, cookieSecret string, cookieName string) *Upstream {
- randomHost := selectRandomHost(pool)
-
- if randomHost != nil {
- // Hash (HMAC with some key for privacy) the upstream.Dial string as the cookie value
- sha, err := hashCookie(cookieSecret, randomHost.Dial)
- if err == nil {
- // write the cookie.
- http.SetCookie(w, &http.Cookie{Name: cookieName, Value: sha, Path: "/", Secure: false})
+ for nesting := d.Nesting(); d.NextBlock(nesting); {
+ switch d.Val() {
+ case "fallback":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if s.FallbackRaw != nil {
+ return d.Err("fallback selection policy already specified")
+ }
+ mod, err := loadFallbackPolicy(d)
+ if err != nil {
+ return err
+ }
+ s.FallbackRaw = mod
+ default:
+ return d.Errf("unrecognized option '%s'", d.Val())
}
}
- return randomHost
+ return nil
}
// hashCookie hashes (HMAC 256) some data with the secret
@@ -627,6 +739,20 @@ func hash(s string) uint32 {
return h.Sum32()
}
+func loadFallbackPolicy(d *caddyfile.Dispenser) (json.RawMessage, error) {
+ name := d.Val()
+ modID := "http.reverse_proxy.selection_policies." + name
+ unm, err := caddyfile.UnmarshalModule(d, modID)
+ if err != nil {
+ return nil, err
+ }
+ sel, ok := unm.(Selector)
+ if !ok {
+ return nil, d.Errf("module %s (%T) is not a reverseproxy.Selector", modID, unm)
+ }
+ return caddyconfig.JSONModuleObject(sel, "policy", name, nil), nil
+}
+
// Interface guards
var (
_ Selector = (*RandomSelection)(nil)
diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
index d2b7b3d..93dcb77 100644
--- a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
@@ -20,6 +20,8 @@ import (
"net/http/httptest"
"testing"
+ "github.com/caddyserver/caddy/v2"
+ "github.com/caddyserver/caddy/v2/caddyconfig"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
@@ -33,7 +35,7 @@ func testPool() UpstreamPool {
func TestRoundRobinPolicy(t *testing.T) {
pool := testPool()
- rrPolicy := new(RoundRobinSelection)
+ rrPolicy := RoundRobinSelection{}
req, _ := http.NewRequest("GET", "/", nil)
h := rrPolicy.Select(pool, req, nil)
@@ -74,7 +76,7 @@ func TestRoundRobinPolicy(t *testing.T) {
func TestLeastConnPolicy(t *testing.T) {
pool := testPool()
- lcPolicy := new(LeastConnSelection)
+ lcPolicy := LeastConnSelection{}
req, _ := http.NewRequest("GET", "/", nil)
pool[0].countRequest(10)
@@ -92,7 +94,7 @@ func TestLeastConnPolicy(t *testing.T) {
func TestIPHashPolicy(t *testing.T) {
pool := testPool()
- ipHash := new(IPHashSelection)
+ ipHash := IPHashSelection{}
req, _ := http.NewRequest("GET", "/", nil)
// We should be able to predict where every request is routed.
@@ -234,7 +236,7 @@ func TestIPHashPolicy(t *testing.T) {
func TestClientIPHashPolicy(t *testing.T) {
pool := testPool()
- ipHash := new(ClientIPHashSelection)
+ ipHash := ClientIPHashSelection{}
req, _ := http.NewRequest("GET", "/", nil)
req = req.WithContext(context.WithValue(req.Context(), caddyhttp.VarsCtxKey, make(map[string]any)))
@@ -377,7 +379,7 @@ func TestClientIPHashPolicy(t *testing.T) {
func TestFirstPolicy(t *testing.T) {
pool := testPool()
- firstPolicy := new(FirstSelection)
+ firstPolicy := FirstSelection{}
req := httptest.NewRequest(http.MethodGet, "/", nil)
h := firstPolicy.Select(pool, req, nil)
@@ -393,8 +395,15 @@ func TestFirstPolicy(t *testing.T) {
}
func TestQueryHashPolicy(t *testing.T) {
- pool := testPool()
+ ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
+ defer cancel()
queryPolicy := QueryHashSelection{Key: "foo"}
+ if err := queryPolicy.Provision(ctx); err != nil {
+ t.Errorf("Provision error: %v", err)
+ t.FailNow()
+ }
+
+ pool := testPool()
request := httptest.NewRequest(http.MethodGet, "/?foo=1", nil)
h := queryPolicy.Select(pool, request, nil)
@@ -463,7 +472,7 @@ func TestQueryHashPolicy(t *testing.T) {
func TestURIHashPolicy(t *testing.T) {
pool := testPool()
- uriPolicy := new(URIHashSelection)
+ uriPolicy := URIHashSelection{}
request := httptest.NewRequest(http.MethodGet, "/test", nil)
h := uriPolicy.Select(pool, request, nil)
@@ -552,8 +561,7 @@ func TestRandomChoicePolicy(t *testing.T) {
pool[2].countRequest(30)
request := httptest.NewRequest(http.MethodGet, "/test", nil)
- randomChoicePolicy := new(RandomChoiceSelection)
- randomChoicePolicy.Choose = 2
+ randomChoicePolicy := RandomChoiceSelection{Choose: 2}
h := randomChoicePolicy.Select(pool, request, nil)
@@ -568,6 +576,14 @@ func TestRandomChoicePolicy(t *testing.T) {
}
func TestCookieHashPolicy(t *testing.T) {
+ ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
+ defer cancel()
+ cookieHashPolicy := CookieHashSelection{}
+ if err := cookieHashPolicy.Provision(ctx); err != nil {
+ t.Errorf("Provision error: %v", err)
+ t.FailNow()
+ }
+
pool := testPool()
pool[0].Dial = "localhost:8080"
pool[1].Dial = "localhost:8081"
@@ -577,7 +593,7 @@ func TestCookieHashPolicy(t *testing.T) {
pool[2].setHealthy(false)
request := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
- cookieHashPolicy := new(CookieHashSelection)
+
h := cookieHashPolicy.Select(pool, request, w)
cookieServer1 := w.Result().Cookies()[0]
if cookieServer1 == nil {
@@ -614,3 +630,59 @@ func TestCookieHashPolicy(t *testing.T) {
t.Error("Expected cookieHashPolicy to set a new cookie.")
}
}
+
+func TestCookieHashPolicyWithFirstFallback(t *testing.T) {
+ ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
+ defer cancel()
+ cookieHashPolicy := CookieHashSelection{
+ FallbackRaw: caddyconfig.JSONModuleObject(FirstSelection{}, "policy", "first", nil),
+ }
+ if err := cookieHashPolicy.Provision(ctx); err != nil {
+ t.Errorf("Provision error: %v", err)
+ t.FailNow()
+ }
+
+ pool := testPool()
+ pool[0].Dial = "localhost:8080"
+ pool[1].Dial = "localhost:8081"
+ pool[2].Dial = "localhost:8082"
+ pool[0].setHealthy(true)
+ pool[1].setHealthy(true)
+ pool[2].setHealthy(true)
+ request := httptest.NewRequest(http.MethodGet, "/test", nil)
+ w := httptest.NewRecorder()
+
+ h := cookieHashPolicy.Select(pool, request, w)
+ cookieServer1 := w.Result().Cookies()[0]
+ if cookieServer1 == nil {
+ t.Fatal("cookieHashPolicy should set a cookie")
+ }
+ if cookieServer1.Name != "lb" {
+ t.Error("cookieHashPolicy should set a cookie with name lb")
+ }
+ if h != pool[0] {
+ t.Errorf("Expected cookieHashPolicy host to be the first only available host, got %s", h)
+ }
+ request = httptest.NewRequest(http.MethodGet, "/test", nil)
+ w = httptest.NewRecorder()
+ request.AddCookie(cookieServer1)
+ h = cookieHashPolicy.Select(pool, request, w)
+ if h != pool[0] {
+ t.Errorf("Expected cookieHashPolicy host to stick to the first host (matching cookie), got %s", h)
+ }
+ s := w.Result().Cookies()
+ if len(s) != 0 {
+ t.Error("Expected cookieHashPolicy to not set a new cookie.")
+ }
+ pool[0].setHealthy(false)
+ request = httptest.NewRequest(http.MethodGet, "/test", nil)
+ w = httptest.NewRecorder()
+ request.AddCookie(cookieServer1)
+ h = cookieHashPolicy.Select(pool, request, w)
+ if h != pool[1] {
+ t.Errorf("Expected cookieHashPolicy to select the next first available host, got %s", h)
+ }
+ if w.Result().Cookies() == nil {
+ t.Error("Expected cookieHashPolicy to set a new cookie.")
+ }
+}