From 48598e1f2a370c2440b38f0b77e4d74748111b9a Mon Sep 17 00:00:00 2001 From: Francis Lavoie Date: Fri, 5 May 2023 17:08:10 -0400 Subject: reverseproxy: Add `fallback` for some policies, instead of always random (#5488) --- .../caddyhttp/reverseproxy/selectionpolicies.go | 184 +++++++++++++++++---- .../reverseproxy/selectionpolicies_test.go | 92 +++++++++-- 2 files changed, 237 insertions(+), 39 deletions(-) diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go index 4184df5..a2985f1 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go @@ -18,6 +18,7 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/hex" + "encoding/json" "fmt" "hash/fnv" weakrand "math/rand" @@ -29,6 +30,7 @@ import ( "time" "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/caddyconfig" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" "github.com/caddyserver/caddy/v2/modules/caddyhttp" ) @@ -372,6 +374,10 @@ func (r *URIHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { type QueryHashSelection struct { // The query key whose value is to be hashed and used for upstream selection. Key string `json:"key,omitempty"` + + // The fallback policy to use if the query key is not present. Defaults to `random`. + FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"` + fallback Selector } // CaddyModule returns the Caddy module information. @@ -382,12 +388,24 @@ func (QueryHashSelection) CaddyModule() caddy.ModuleInfo { } } -// Select returns an available host, if any. -func (s QueryHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream { +// Provision sets up the module. +func (s *QueryHashSelection) Provision(ctx caddy.Context) error { if s.Key == "" { - return nil + return fmt.Errorf("query key is required") + } + if s.FallbackRaw == nil { + s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil) } + mod, err := ctx.LoadModule(s, "FallbackRaw") + if err != nil { + return fmt.Errorf("loading fallback selection policy: %s", err) + } + s.fallback = mod.(Selector) + return nil +} +// Select returns an available host, if any. +func (s QueryHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream { // Since the query may have multiple values for the same key, // we'll join them to avoid a problem where the user can control // the upstream that the request goes to by sending multiple values @@ -397,7 +415,7 @@ func (s QueryHashSelection) Select(pool UpstreamPool, req *http.Request, _ http. // different request, because the order of the values is significant. vals := strings.Join(req.URL.Query()[s.Key], ",") if vals == "" { - return RandomSelection{}.Select(pool, req, nil) + return s.fallback.Select(pool, req, nil) } return hostByHashing(pool, vals) } @@ -410,6 +428,24 @@ func (s *QueryHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } s.Key = d.Val() } + for nesting := d.Nesting(); d.NextBlock(nesting); { + switch d.Val() { + case "fallback": + if !d.NextArg() { + return d.ArgErr() + } + if s.FallbackRaw != nil { + return d.Err("fallback selection policy already specified") + } + mod, err := loadFallbackPolicy(d) + if err != nil { + return err + } + s.FallbackRaw = mod + default: + return d.Errf("unrecognized option '%s'", d.Val()) + } + } return nil } @@ -418,6 +454,10 @@ func (s *QueryHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { type HeaderHashSelection struct { // The HTTP header field whose value is to be hashed and used for upstream selection. Field string `json:"field,omitempty"` + + // The fallback policy to use if the header is not present. Defaults to `random`. + FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"` + fallback Selector } // CaddyModule returns the Caddy module information. @@ -428,12 +468,24 @@ func (HeaderHashSelection) CaddyModule() caddy.ModuleInfo { } } -// Select returns an available host, if any. -func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream { +// Provision sets up the module. +func (s *HeaderHashSelection) Provision(ctx caddy.Context) error { if s.Field == "" { - return nil + return fmt.Errorf("header field is required") + } + if s.FallbackRaw == nil { + s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil) } + mod, err := ctx.LoadModule(s, "FallbackRaw") + if err != nil { + return fmt.Errorf("loading fallback selection policy: %s", err) + } + s.fallback = mod.(Selector) + return nil +} +// Select returns an available host, if any. +func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream { // The Host header should be obtained from the req.Host field // since net/http removes it from the header map. if s.Field == "Host" && req.Host != "" { @@ -442,7 +494,7 @@ func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http val := req.Header.Get(s.Field) if val == "" { - return RandomSelection{}.Select(pool, req, nil) + return s.fallback.Select(pool, req, nil) } return hostByHashing(pool, val) } @@ -455,6 +507,24 @@ func (s *HeaderHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } s.Field = d.Val() } + for nesting := d.Nesting(); d.NextBlock(nesting); { + switch d.Val() { + case "fallback": + if !d.NextArg() { + return d.ArgErr() + } + if s.FallbackRaw != nil { + return d.Err("fallback selection policy already specified") + } + mod, err := loadFallbackPolicy(d) + if err != nil { + return err + } + s.FallbackRaw = mod + default: + return d.Errf("unrecognized option '%s'", d.Val()) + } + } return nil } @@ -465,6 +535,10 @@ type CookieHashSelection struct { Name string `json:"name,omitempty"` // Secret to hash (Hmac256) chosen upstream in cookie Secret string `json:"secret,omitempty"` + + // The fallback policy to use if the cookie is not present. Defaults to `random`. + FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"` + fallback Selector } // CaddyModule returns the Caddy module information. @@ -475,15 +549,48 @@ func (CookieHashSelection) CaddyModule() caddy.ModuleInfo { } } -// Select returns an available host, if any. -func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http.ResponseWriter) *Upstream { +// Provision sets up the module. +func (s *CookieHashSelection) Provision(ctx caddy.Context) error { if s.Name == "" { s.Name = "lb" } + if s.FallbackRaw == nil { + s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil) + } + mod, err := ctx.LoadModule(s, "FallbackRaw") + if err != nil { + return fmt.Errorf("loading fallback selection policy: %s", err) + } + s.fallback = mod.(Selector) + return nil +} + +// Select returns an available host, if any. +func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http.ResponseWriter) *Upstream { + // selects a new Host using the fallback policy (typically random) + // and write a sticky session cookie to the response. + selectNewHost := func() *Upstream { + upstream := s.fallback.Select(pool, req, w) + if upstream == nil { + return nil + } + sha, err := hashCookie(s.Secret, upstream.Dial) + if err != nil { + return upstream + } + http.SetCookie(w, &http.Cookie{ + Name: s.Name, + Value: sha, + Path: "/", + Secure: false, + }) + return upstream + } + cookie, err := req.Cookie(s.Name) - // If there's no cookie, select new random host + // If there's no cookie, select a host using the fallback policy if err != nil || cookie == nil { - return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name) + return selectNewHost() } // If the cookie is present, loop over the available upstreams until we find a match cookieValue := cookie.Value @@ -496,13 +603,15 @@ func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http return upstream } } - // If there is no matching host, select new random host - return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name) + // If there is no matching host, select a host using the fallback policy + return selectNewHost() } // UnmarshalCaddyfile sets up the module from Caddyfile tokens. Syntax: // -// lb_policy cookie [ []] +// lb_policy cookie [ []] { +// fallback +// } // // By default name is `lb` func (s *CookieHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { @@ -517,22 +626,25 @@ func (s *CookieHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { default: return d.ArgErr() } - return nil -} - -// Select a new Host randomly and add a sticky session cookie -func selectNewHostWithCookieHashSelection(pool []*Upstream, w http.ResponseWriter, cookieSecret string, cookieName string) *Upstream { - randomHost := selectRandomHost(pool) - - if randomHost != nil { - // Hash (HMAC with some key for privacy) the upstream.Dial string as the cookie value - sha, err := hashCookie(cookieSecret, randomHost.Dial) - if err == nil { - // write the cookie. - http.SetCookie(w, &http.Cookie{Name: cookieName, Value: sha, Path: "/", Secure: false}) + for nesting := d.Nesting(); d.NextBlock(nesting); { + switch d.Val() { + case "fallback": + if !d.NextArg() { + return d.ArgErr() + } + if s.FallbackRaw != nil { + return d.Err("fallback selection policy already specified") + } + mod, err := loadFallbackPolicy(d) + if err != nil { + return err + } + s.FallbackRaw = mod + default: + return d.Errf("unrecognized option '%s'", d.Val()) } } - return randomHost + return nil } // hashCookie hashes (HMAC 256) some data with the secret @@ -627,6 +739,20 @@ func hash(s string) uint32 { return h.Sum32() } +func loadFallbackPolicy(d *caddyfile.Dispenser) (json.RawMessage, error) { + name := d.Val() + modID := "http.reverse_proxy.selection_policies." + name + unm, err := caddyfile.UnmarshalModule(d, modID) + if err != nil { + return nil, err + } + sel, ok := unm.(Selector) + if !ok { + return nil, d.Errf("module %s (%T) is not a reverseproxy.Selector", modID, unm) + } + return caddyconfig.JSONModuleObject(sel, "policy", name, nil), nil +} + // Interface guards var ( _ Selector = (*RandomSelection)(nil) diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go index d2b7b3d..93dcb77 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go @@ -20,6 +20,8 @@ import ( "net/http/httptest" "testing" + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/caddyconfig" "github.com/caddyserver/caddy/v2/modules/caddyhttp" ) @@ -33,7 +35,7 @@ func testPool() UpstreamPool { func TestRoundRobinPolicy(t *testing.T) { pool := testPool() - rrPolicy := new(RoundRobinSelection) + rrPolicy := RoundRobinSelection{} req, _ := http.NewRequest("GET", "/", nil) h := rrPolicy.Select(pool, req, nil) @@ -74,7 +76,7 @@ func TestRoundRobinPolicy(t *testing.T) { func TestLeastConnPolicy(t *testing.T) { pool := testPool() - lcPolicy := new(LeastConnSelection) + lcPolicy := LeastConnSelection{} req, _ := http.NewRequest("GET", "/", nil) pool[0].countRequest(10) @@ -92,7 +94,7 @@ func TestLeastConnPolicy(t *testing.T) { func TestIPHashPolicy(t *testing.T) { pool := testPool() - ipHash := new(IPHashSelection) + ipHash := IPHashSelection{} req, _ := http.NewRequest("GET", "/", nil) // We should be able to predict where every request is routed. @@ -234,7 +236,7 @@ func TestIPHashPolicy(t *testing.T) { func TestClientIPHashPolicy(t *testing.T) { pool := testPool() - ipHash := new(ClientIPHashSelection) + ipHash := ClientIPHashSelection{} req, _ := http.NewRequest("GET", "/", nil) req = req.WithContext(context.WithValue(req.Context(), caddyhttp.VarsCtxKey, make(map[string]any))) @@ -377,7 +379,7 @@ func TestClientIPHashPolicy(t *testing.T) { func TestFirstPolicy(t *testing.T) { pool := testPool() - firstPolicy := new(FirstSelection) + firstPolicy := FirstSelection{} req := httptest.NewRequest(http.MethodGet, "/", nil) h := firstPolicy.Select(pool, req, nil) @@ -393,8 +395,15 @@ func TestFirstPolicy(t *testing.T) { } func TestQueryHashPolicy(t *testing.T) { - pool := testPool() + ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) + defer cancel() queryPolicy := QueryHashSelection{Key: "foo"} + if err := queryPolicy.Provision(ctx); err != nil { + t.Errorf("Provision error: %v", err) + t.FailNow() + } + + pool := testPool() request := httptest.NewRequest(http.MethodGet, "/?foo=1", nil) h := queryPolicy.Select(pool, request, nil) @@ -463,7 +472,7 @@ func TestQueryHashPolicy(t *testing.T) { func TestURIHashPolicy(t *testing.T) { pool := testPool() - uriPolicy := new(URIHashSelection) + uriPolicy := URIHashSelection{} request := httptest.NewRequest(http.MethodGet, "/test", nil) h := uriPolicy.Select(pool, request, nil) @@ -552,8 +561,7 @@ func TestRandomChoicePolicy(t *testing.T) { pool[2].countRequest(30) request := httptest.NewRequest(http.MethodGet, "/test", nil) - randomChoicePolicy := new(RandomChoiceSelection) - randomChoicePolicy.Choose = 2 + randomChoicePolicy := RandomChoiceSelection{Choose: 2} h := randomChoicePolicy.Select(pool, request, nil) @@ -568,6 +576,14 @@ func TestRandomChoicePolicy(t *testing.T) { } func TestCookieHashPolicy(t *testing.T) { + ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) + defer cancel() + cookieHashPolicy := CookieHashSelection{} + if err := cookieHashPolicy.Provision(ctx); err != nil { + t.Errorf("Provision error: %v", err) + t.FailNow() + } + pool := testPool() pool[0].Dial = "localhost:8080" pool[1].Dial = "localhost:8081" @@ -577,7 +593,7 @@ func TestCookieHashPolicy(t *testing.T) { pool[2].setHealthy(false) request := httptest.NewRequest(http.MethodGet, "/test", nil) w := httptest.NewRecorder() - cookieHashPolicy := new(CookieHashSelection) + h := cookieHashPolicy.Select(pool, request, w) cookieServer1 := w.Result().Cookies()[0] if cookieServer1 == nil { @@ -614,3 +630,59 @@ func TestCookieHashPolicy(t *testing.T) { t.Error("Expected cookieHashPolicy to set a new cookie.") } } + +func TestCookieHashPolicyWithFirstFallback(t *testing.T) { + ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) + defer cancel() + cookieHashPolicy := CookieHashSelection{ + FallbackRaw: caddyconfig.JSONModuleObject(FirstSelection{}, "policy", "first", nil), + } + if err := cookieHashPolicy.Provision(ctx); err != nil { + t.Errorf("Provision error: %v", err) + t.FailNow() + } + + pool := testPool() + pool[0].Dial = "localhost:8080" + pool[1].Dial = "localhost:8081" + pool[2].Dial = "localhost:8082" + pool[0].setHealthy(true) + pool[1].setHealthy(true) + pool[2].setHealthy(true) + request := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + + h := cookieHashPolicy.Select(pool, request, w) + cookieServer1 := w.Result().Cookies()[0] + if cookieServer1 == nil { + t.Fatal("cookieHashPolicy should set a cookie") + } + if cookieServer1.Name != "lb" { + t.Error("cookieHashPolicy should set a cookie with name lb") + } + if h != pool[0] { + t.Errorf("Expected cookieHashPolicy host to be the first only available host, got %s", h) + } + request = httptest.NewRequest(http.MethodGet, "/test", nil) + w = httptest.NewRecorder() + request.AddCookie(cookieServer1) + h = cookieHashPolicy.Select(pool, request, w) + if h != pool[0] { + t.Errorf("Expected cookieHashPolicy host to stick to the first host (matching cookie), got %s", h) + } + s := w.Result().Cookies() + if len(s) != 0 { + t.Error("Expected cookieHashPolicy to not set a new cookie.") + } + pool[0].setHealthy(false) + request = httptest.NewRequest(http.MethodGet, "/test", nil) + w = httptest.NewRecorder() + request.AddCookie(cookieServer1) + h = cookieHashPolicy.Select(pool, request, w) + if h != pool[1] { + t.Errorf("Expected cookieHashPolicy to select the next first available host, got %s", h) + } + if w.Result().Cookies() == nil { + t.Error("Expected cookieHashPolicy to set a new cookie.") + } +} -- cgit v1.2.3