summaryrefslogtreecommitdiff
path: root/modules/caddyhttp/rewrite
diff options
context:
space:
mode:
authorMatthew Holt <mholt@users.noreply.github.com>2020-01-11 11:40:03 -0700
committerMatthew Holt <mholt@users.noreply.github.com>2020-01-11 11:40:03 -0700
commitd876de61e512db7a31a7ae59723d5134048f283e (patch)
tree9f23f2946f82716d94a2de6f9eb6f715d9dbf85b /modules/caddyhttp/rewrite
parent8be1f0ea668492000cdefbd937e0359bdc24bfc1 (diff)
rewrite: Fix query string logic
Diffstat (limited to 'modules/caddyhttp/rewrite')
-rw-r--r--modules/caddyhttp/rewrite/rewrite.go48
-rw-r--r--modules/caddyhttp/rewrite/rewrite_test.go8
2 files changed, 34 insertions, 22 deletions
diff --git a/modules/caddyhttp/rewrite/rewrite.go b/modules/caddyhttp/rewrite/rewrite.go
index d946447..c069db9 100644
--- a/modules/caddyhttp/rewrite/rewrite.go
+++ b/modules/caddyhttp/rewrite/rewrite.go
@@ -171,23 +171,31 @@ func (rewr Rewrite) rewrite(r *http.Request, repl *caddy.Replacer, logger *zap.L
// buildQueryString takes an input query string and
// performs replacements on each component, returning
-// the resulting query string.
+// the resulting query string. This function appends
+// duplicate keys rather than replaces.
func buildQueryString(qs string, repl *caddy.Replacer) string {
var sb strings.Builder
- var wroteKey bool
+
+ // first component must be key, which is the same
+ // as if we just wrote a value in previous iteration
+ wroteVal := true
for len(qs) > 0 {
- // determine the end of this component
+ // determine the end of this component, which will be at
+ // the next equal sign or ampersand, whichever comes first
nextEq, nextAmp := strings.Index(qs, "="), strings.Index(qs, "&")
- end := min(nextEq, nextAmp)
- if end == -1 {
- end = len(qs) // if there is nothing left, go to end of string
+ ampIsNext := nextAmp >= 0 && (nextAmp < nextEq || nextEq < 0)
+ end := len(qs) // assume no delimiter remains...
+ if ampIsNext {
+ end = nextAmp // ...unless ampersand is first...
+ } else if nextEq >= 0 && (nextEq < nextAmp || nextAmp < 0) {
+ end = nextEq // ...or unless equal is first.
}
// consume the component and write the result
comp := qs[:end]
comp, _ = repl.ReplaceFunc(comp, func(name, val string) (string, error) {
- if name == "http.request.uri.query" {
+ if name == "http.request.uri.query" && wroteVal {
return val, nil // already escaped
}
return url.QueryEscape(val), nil
@@ -197,29 +205,25 @@ func buildQueryString(qs string, repl *caddy.Replacer) string {
}
qs = qs[end:]
- if wroteKey {
+ // if previous iteration wrote a value,
+ // that means we are writing a key
+ if wroteVal {
+ if sb.Len() > 0 {
+ sb.WriteRune('&')
+ }
+ } else {
sb.WriteRune('=')
- } else if sb.Len() > 0 {
- sb.WriteRune('&')
}
-
- // remember that we just wrote a key, which is if the next
- // delimiter is an equals sign or if there is no ampersand
- wroteKey = nextEq < nextAmp || nextAmp < 0
-
sb.WriteString(comp)
+
+ // remember for the next iteration that we just wrote a value,
+ // which means the next iteration MUST write a key
+ wroteVal = ampIsNext
}
return sb.String()
}
-func min(a, b int) int {
- if b < a {
- return b
- }
- return a
-}
-
// replacer describes a simple and fast substring replacement.
type replacer struct {
// The substring to find. Supports placeholders.
diff --git a/modules/caddyhttp/rewrite/rewrite_test.go b/modules/caddyhttp/rewrite/rewrite_test.go
index beff499..de82d8d 100644
--- a/modules/caddyhttp/rewrite/rewrite_test.go
+++ b/modules/caddyhttp/rewrite/rewrite_test.go
@@ -139,6 +139,11 @@ func TestRewrite(t *testing.T) {
expect: newRequest(t, "GET", "/foo/bar"),
},
{
+ rule: Rewrite{URI: "?qs={http.request.uri.query}"},
+ input: newRequest(t, "GET", "/foo?a=b&c=d"),
+ expect: newRequest(t, "GET", "/foo?qs=a%3Db%26c%3Dd"),
+ },
+ {
rule: Rewrite{URI: "/foo?{http.request.uri.query}#frag"},
input: newRequest(t, "GET", "/foo/bar?a=b"),
expect: newRequest(t, "GET", "/foo?a=b#frag"),
@@ -216,6 +221,9 @@ func TestRewrite(t *testing.T) {
if expected, actual := tc.expect.URL.RequestURI(), tc.input.URL.RequestURI(); expected != actual {
t.Errorf("Test %d: Expected URL.RequestURI()='%s' but got '%s'", i, expected, actual)
}
+ if expected, actual := tc.expect.URL.Fragment, tc.input.URL.Fragment; expected != actual {
+ t.Errorf("Test %d: Expected URL.Fragment='%s' but got '%s'", i, expected, actual)
+ }
}
}