summaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorDimitri Masson <30894448+d-masson@users.noreply.github.com>2020-11-16 20:47:15 +0100
committerGitHub <noreply@github.com>2020-11-16 12:47:15 -0700
commit99b8f44486b766f220a33906d84ac05af942f260 (patch)
tree42fc7d997937e940e444e94dd8ac0e4cf1384757 /modules
parent670b723e3802ac37942dad07dc194539bccce9ff (diff)
reverse_proxy: Fix random_choose selection policy (#3811)
Diffstat (limited to 'modules')
-rw-r--r--modules/caddyhttp/reverseproxy/selectionpolicies.go9
-rw-r--r--modules/caddyhttp/reverseproxy/selectionpolicies_test.go51
2 files changed, 58 insertions, 2 deletions
diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go
index 343140f..2aef63d 100644
--- a/modules/caddyhttp/reverseproxy/selectionpolicies.go
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go
@@ -397,13 +397,18 @@ func leastRequests(upstreams []*Upstream) *Upstream {
return nil
}
var best []*Upstream
- var bestReqs int
+ var bestReqs int = -1
for _, upstream := range upstreams {
+ if upstream == nil {
+ continue
+ }
reqs := upstream.NumRequests()
if reqs == 0 {
return upstream
}
- if reqs <= bestReqs {
+ // If bestReqs was just initialized to -1
+ // we need to append upstream also
+ if reqs <= bestReqs || bestReqs == -1 {
bestReqs = reqs
best = append(best, upstream)
}
diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
index e9939d6..49585da 100644
--- a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
@@ -271,3 +271,54 @@ func TestURIHashPolicy(t *testing.T) {
t.Error("Expected uri policy policy host to be nil.")
}
}
+
+func TestLeastRequests(t *testing.T) {
+ pool := testPool()
+ pool[0].Dial = "localhost:8080"
+ pool[1].Dial = "localhost:8081"
+ pool[2].Dial = "localhost:8082"
+ pool[0].SetHealthy(true)
+ pool[1].SetHealthy(true)
+ pool[2].SetHealthy(true)
+ pool[0].CountRequest(10)
+ pool[1].CountRequest(20)
+ pool[2].CountRequest(30)
+
+ result := leastRequests(pool)
+
+ if result == nil {
+ t.Error("Least request should not return nil")
+ }
+
+ if result != pool[0] {
+ t.Error("Least request should return pool[0]")
+ }
+}
+
+func TestRandomChoicePolicy(t *testing.T) {
+ pool := testPool()
+ pool[0].Dial = "localhost:8080"
+ pool[1].Dial = "localhost:8081"
+ pool[2].Dial = "localhost:8082"
+ pool[0].SetHealthy(false)
+ pool[1].SetHealthy(true)
+ pool[2].SetHealthy(true)
+ pool[0].CountRequest(10)
+ pool[1].CountRequest(20)
+ pool[2].CountRequest(30)
+
+ request := httptest.NewRequest(http.MethodGet, "/test", nil)
+ randomChoicePolicy := new(RandomChoiceSelection)
+ randomChoicePolicy.Choose = 2
+
+ h := randomChoicePolicy.Select(pool, request)
+
+ if h == nil {
+ t.Error("RandomChoicePolicy should not return nil")
+ }
+
+ if h == pool[0] {
+ t.Error("RandomChoicePolicy should not choose pool[0]")
+ }
+
+}