diff options
-rw-r--r-- | modules/caddyhttp/celmatcher.go | 32 | ||||
-rw-r--r-- | modules/caddyhttp/replacer.go | 54 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/reverseproxy.go | 7 | ||||
-rw-r--r-- | modules/caddyhttp/rewrite/rewrite.go | 17 | ||||
-rw-r--r-- | modules/caddyhttp/server.go | 11 | ||||
-rw-r--r-- | replacer.go | 85 | ||||
-rw-r--r-- | replacer_test.go | 87 |
7 files changed, 184 insertions, 109 deletions
diff --git a/modules/caddyhttp/celmatcher.go b/modules/caddyhttp/celmatcher.go index a78bd9c..84565e4 100644 --- a/modules/caddyhttp/celmatcher.go +++ b/modules/caddyhttp/celmatcher.go @@ -86,7 +86,7 @@ func (m *MatchExpression) Provision(_ caddy.Context) error { decls.NewFunction(placeholderFuncName, decls.NewOverload(placeholderFuncName+"_httpRequest_string", []*exprpb.Type{httpRequestObjectType, decls.String}, - decls.String)), + decls.Any)), ), cel.CustomTypeAdapter(celHTTPRequestTypeAdapter{}), ext.Strings(), @@ -210,7 +210,35 @@ func caddyPlaceholderFunc(lhs, rhs ref.Val) ref.Val { repl := celReq.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) val, _ := repl.Get(string(phStr)) - return types.String(val) + // TODO: this is... kinda awful and underwhelming, how can we expand CEL's type system more easily? + switch v := val.(type) { + case string: + return types.String(v) + case fmt.Stringer: + return types.String(v.String()) + case error: + return types.NewErr(v.Error()) + case int: + return types.Int(v) + case int32: + return types.Int(v) + case int64: + return types.Int(v) + case uint: + return types.Int(v) + case uint32: + return types.Int(v) + case uint64: + return types.Int(v) + case float32: + return types.Double(v) + case float64: + return types.Double(v) + case bool: + return types.Bool(v) + default: + return types.String(fmt.Sprintf("%+v", v)) + } } // Interface guards diff --git a/modules/caddyhttp/replacer.go b/modules/caddyhttp/replacer.go index c9c7522..f55bb0a 100644 --- a/modules/caddyhttp/replacer.go +++ b/modules/caddyhttp/replacer.go @@ -31,7 +31,7 @@ import ( ) func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.ResponseWriter) { - httpVars := func(key string) (string, bool) { + httpVars := func(key string) (interface{}, bool) { if req != nil { // query string parameters if strings.HasPrefix(key, reqURIQueryReplPrefix) { @@ -62,7 +62,7 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo } } - // http.request.tls. + // http.request.tls.* if strings.HasPrefix(key, reqTLSReplPrefix) { return getReqTLSReplacement(req, key) } @@ -182,21 +182,10 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo if strings.HasPrefix(key, varsReplPrefix) { varName := key[len(varsReplPrefix):] tbl := req.Context().Value(VarsCtxKey).(map[string]interface{}) - raw, ok := tbl[varName] - if !ok { - // variables can be dynamic, so always return true - // even when it may not be set; treat as empty - return "", true - } - // do our best to convert it to a string efficiently - switch val := raw.(type) { - case string: - return val, true - case fmt.Stringer: - return val.String(), true - default: - return fmt.Sprintf("%s", val), true - } + raw, _ := tbl[varName] + // variables can be dynamic, so always return true + // even when it may not be set; treat as empty then + return raw, true } } @@ -211,19 +200,19 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo } } - return "", false + return nil, false } repl.Map(httpVars) } -func getReqTLSReplacement(req *http.Request, key string) (string, bool) { +func getReqTLSReplacement(req *http.Request, key string) (interface{}, bool) { if req == nil || req.TLS == nil { - return "", false + return nil, false } if len(key) < len(reqTLSReplPrefix) { - return "", false + return nil, false } field := strings.ToLower(key[len(reqTLSReplPrefix):]) @@ -231,20 +220,20 @@ func getReqTLSReplacement(req *http.Request, key string) (string, bool) { if strings.HasPrefix(field, "client.") { cert := getTLSPeerCert(req.TLS) if cert == nil { - return "", false + return nil, false } switch field { case "client.fingerprint": return fmt.Sprintf("%x", sha256.Sum256(cert.Raw)), true case "client.issuer": - return cert.Issuer.String(), true + return cert.Issuer, true case "client.serial": - return fmt.Sprintf("%x", cert.SerialNumber), true + return cert.SerialNumber, true case "client.subject": - return cert.Subject.String(), true + return cert.Subject, true default: - return "", false + return nil, false } } @@ -254,22 +243,15 @@ func getReqTLSReplacement(req *http.Request, key string) (string, bool) { case "cipher_suite": return tls.CipherSuiteName(req.TLS.CipherSuite), true case "resumed": - if req.TLS.DidResume { - return "true", true - } - return "false", true + return req.TLS.DidResume, true case "proto": return req.TLS.NegotiatedProtocol, true case "proto_mutual": - if req.TLS.NegotiatedProtocolIsMutual { - return "true", true - } - return "false", true + return req.TLS.NegotiatedProtocolIsMutual, true case "server_name": return req.TLS.ServerName, true - default: - return "", false } + return nil, false } // getTLSPeerCert retrieves the first peer certificate from a TLS session. diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 4ac50ac..6d0d441 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -24,7 +24,6 @@ import ( "net" "net/http" "regexp" - "strconv" "strings" "sync" "time" @@ -328,9 +327,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht repl.Set("http.reverse_proxy.upstream.hostport", dialInfo.Address) repl.Set("http.reverse_proxy.upstream.host", dialInfo.Host) repl.Set("http.reverse_proxy.upstream.port", dialInfo.Port) - repl.Set("http.reverse_proxy.upstream.requests", strconv.Itoa(upstream.Host.NumRequests())) - repl.Set("http.reverse_proxy.upstream.max_requests", strconv.Itoa(upstream.MaxRequests)) - repl.Set("http.reverse_proxy.upstream.fails", strconv.Itoa(upstream.Host.Fails())) + repl.Set("http.reverse_proxy.upstream.requests", upstream.Host.NumRequests()) + repl.Set("http.reverse_proxy.upstream.max_requests", upstream.MaxRequests) + repl.Set("http.reverse_proxy.upstream.fails", upstream.Host.Fails()) // mutate request headers according to this upstream; // because we're in a retry loop, we have to copy diff --git a/modules/caddyhttp/rewrite/rewrite.go b/modules/caddyhttp/rewrite/rewrite.go index ad05486..3ba63c4 100644 --- a/modules/caddyhttp/rewrite/rewrite.go +++ b/modules/caddyhttp/rewrite/rewrite.go @@ -15,8 +15,10 @@ package rewrite import ( + "fmt" "net/http" "net/url" + "strconv" "strings" "github.com/caddyserver/caddy/v2" @@ -208,11 +210,22 @@ func buildQueryString(qs string, repl *caddy.Replacer) string { // consume the component and write the result comp := qs[:end] - comp, _ = repl.ReplaceFunc(comp, func(name, val string) (string, error) { + comp, _ = repl.ReplaceFunc(comp, func(name string, val interface{}) (interface{}, error) { if name == "http.request.uri.query" && wroteVal { return val, nil // already escaped } - return url.QueryEscape(val), nil + var valStr string + switch v := val.(type) { + case string: + valStr = v + case fmt.Stringer: + valStr = v.String() + case int: + valStr = strconv.Itoa(v) + default: + valStr = fmt.Sprintf("%+v", v) + } + return url.QueryEscape(valStr), nil }) if end < len(qs) { end++ // consume delimiter diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index c7780b0..72a67a7 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -21,7 +21,6 @@ import ( "net" "net/http" "net/url" - "strconv" "strings" "time" @@ -166,9 +165,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer func() { latency := time.Since(start) - repl.Set("http.response.status", strconv.Itoa(wrec.Status())) - repl.Set("http.response.size", strconv.Itoa(wrec.Size())) - repl.Set("http.response.latency", latency.String()) + repl.Set("http.response.status", wrec.Status()) + repl.Set("http.response.size", wrec.Size()) + repl.Set("http.response.latency", latency) logger := accLog if s.Logs != nil && s.Logs.LoggerNames != nil { @@ -360,9 +359,9 @@ func (*HTTPErrorConfig) WithError(r *http.Request, err error) *http.Request { // add error values to the replacer repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) - repl.Set("http.error", err.Error()) + repl.Set("http.error", err) if handlerErr, ok := err.(HandlerError); ok { - repl.Set("http.error.status_code", strconv.Itoa(handlerErr.StatusCode)) + repl.Set("http.error.status_code", handlerErr.StatusCode) repl.Set("http.error.status_text", http.StatusText(handlerErr.StatusCode)) repl.Set("http.error.trace", handlerErr.Trace) repl.Set("http.error.id", handlerErr.ID) diff --git a/replacer.go b/replacer.go index 4ff578c..eac2080 100644 --- a/replacer.go +++ b/replacer.go @@ -27,7 +27,7 @@ import ( // NewReplacer returns a new Replacer. func NewReplacer() *Replacer { rep := &Replacer{ - static: make(map[string]string), + static: make(map[string]interface{}), } rep.providers = []ReplacerFunc{ globalDefaultReplacements, @@ -41,7 +41,7 @@ func NewReplacer() *Replacer { // use NewReplacer to make one. type Replacer struct { providers []ReplacerFunc - static map[string]string + static map[string]interface{} } // Map adds mapFunc to the list of value providers. @@ -51,19 +51,19 @@ func (r *Replacer) Map(mapFunc ReplacerFunc) { } // Set sets a custom variable to a static value. -func (r *Replacer) Set(variable, value string) { +func (r *Replacer) Set(variable string, value interface{}) { r.static[variable] = value } // Get gets a value from the replacer. It returns // the value and whether the variable was known. -func (r *Replacer) Get(variable string) (string, bool) { +func (r *Replacer) Get(variable string) (interface{}, bool) { for _, mapFunc := range r.providers { if val, ok := mapFunc(variable); ok { return val, true } } - return "", false + return nil, false } // Delete removes a variable with a static value @@ -73,9 +73,9 @@ func (r *Replacer) Delete(variable string) { } // fromStatic provides values from r.static. -func (r *Replacer) fromStatic(key string) (val string, ok bool) { - val, ok = r.static[key] - return +func (r *Replacer) fromStatic(key string) (interface{}, bool) { + val, ok := r.static[key] + return val, ok } // ReplaceOrErr is like ReplaceAll, but any placeholders @@ -102,10 +102,9 @@ func (r *Replacer) ReplaceAll(input, empty string) string { return out } -// ReplaceFunc calls ReplaceAll efficiently replaces placeholders in input with -// their values. All placeholders are replaced in the output -// whether they are recognized or not. Values that are empty -// string will be substituted with empty. +// ReplaceFunc is the same as ReplaceAll, but calls f for every +// replacement to be made, in case f wants to change or inspect +// the replacement. func (r *Replacer) ReplaceFunc(input string, f ReplacementFunc) (string, error) { return r.replace(input, "", true, false, false, f) } @@ -125,7 +124,7 @@ func (r *Replacer) replace(input, empty string, // iterate the input to find each placeholder var lastWriteCursor int - + scan: for i := 0; i < len(input); i++ { @@ -169,9 +168,8 @@ scan: return "", fmt.Errorf("unrecognized placeholder %s%s%s", string(phOpen), key, string(phClose)) } else if !treatUnknownAsEmpty { - // if treatUnknownAsEmpty is true, we'll - // handle an empty val later; so only - // continue otherwise + // if treatUnknownAsEmpty is true, we'll handle an empty + // val later; so only continue otherwise lastWriteCursor = i continue } @@ -186,9 +184,12 @@ scan: } } + // convert val to a string as efficiently as possible + valStr := toString(val) + // write the value; if it's empty, either return // an error or write a default value - if val == "" { + if valStr == "" { if errOnEmpty { return "", fmt.Errorf("evaluated placeholder %s%s%s is empty", string(phOpen), key, string(phClose)) @@ -196,7 +197,7 @@ scan: sb.WriteString(empty) } } else { - sb.WriteString(val) + sb.WriteString(valStr) } // advance cursor to end of placeholder @@ -210,14 +211,54 @@ scan: return sb.String(), nil } +func toString(val interface{}) string { + switch v := val.(type) { + case nil: + return "" + case string: + return v + case fmt.Stringer: + return v.String() + case byte: + return string(v) + case []byte: + return string(v) + case []rune: + return string(v) + case int: + return strconv.Itoa(v) + case int32: + return strconv.Itoa(int(v)) + case int64: + return strconv.Itoa(int(v)) + case uint: + return strconv.Itoa(int(v)) + case uint32: + return strconv.Itoa(int(v)) + case uint64: + return strconv.Itoa(int(v)) + case float32: + return strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + return strconv.FormatFloat(v, 'f', -1, 64) + case bool: + if v { + return "true" + } + return "false" + default: + return fmt.Sprintf("%+v", v) + } +} + // ReplacerFunc is a function that returns a replacement // for the given key along with true if the function is able // to service that key (even if the value is blank). If the // function does not recognize the key, false should be // returned. -type ReplacerFunc func(key string) (val string, ok bool) +type ReplacerFunc func(key string) (interface{}, bool) -func globalDefaultReplacements(key string) (string, bool) { +func globalDefaultReplacements(key string) (interface{}, bool) { // check environment variable const envPrefix = "env." if strings.HasPrefix(key, envPrefix) { @@ -241,7 +282,7 @@ func globalDefaultReplacements(key string) (string, bool) { return strconv.Itoa(nowFunc().Year()), true } - return "", false + return nil, false } // ReplacementFunc is a function that is called when a @@ -250,7 +291,7 @@ func globalDefaultReplacements(key string) (string, bool) { // will be the replacement, and returns the value that // will actually be the replacement, or an error. Note // that errors are sometimes ignored by replacers. -type ReplacementFunc func(variable, val string) (string, error) +type ReplacementFunc func(variable string, val interface{}) (interface{}, error) // nowFunc is a variable so tests can change it // in order to obtain a deterministic time. diff --git a/replacer_test.go b/replacer_test.go index a48917a..d6ac033 100644 --- a/replacer_test.go +++ b/replacer_test.go @@ -173,41 +173,12 @@ func TestReplacer(t *testing.T) { } } -func BenchmarkReplacer(b *testing.B) { - type testCase struct { - name, input, empty string - } - - rep := testReplacer() - - for _, bm := range []testCase{ - { - name: "no placeholder", - input: `simple string`, - }, - { - name: "placeholder", - input: `{"json": "object"}`, - }, - { - name: "escaped placeholder", - input: `\{"json": \{"nested": "{bar}"\}\}`, - }, - } { - b.Run(bm.name, func(b *testing.B) { - for i := 0; i < b.N; i++ { - rep.ReplaceAll(bm.input, bm.empty) - } - }) - } -} - func TestReplacerSet(t *testing.T) { rep := testReplacer() for _, tc := range []struct { variable string - value string + value interface{} }{ { variable: "test1", @@ -218,6 +189,10 @@ func TestReplacerSet(t *testing.T) { value: "123", }, { + variable: "numbers", + value: 123.456, + }, + { variable: "äöü", value: "öö_äü", }, @@ -252,7 +227,7 @@ func TestReplacerSet(t *testing.T) { // test if all keys are still there (by length) length := len(rep.static) - if len(rep.static) != 7 { + if len(rep.static) != 8 { t.Errorf("Expected length '%v' got '%v'", 7, length) } } @@ -261,7 +236,7 @@ func TestReplacerReplaceKnown(t *testing.T) { rep := Replacer{ providers: []ReplacerFunc{ // split our possible vars to two functions (to test if both functions are called) - func(key string) (val string, ok bool) { + func(key string) (val interface{}, ok bool) { switch key { case "test1": return "val1", true @@ -275,7 +250,7 @@ func TestReplacerReplaceKnown(t *testing.T) { return "NOOO", false } }, - func(key string) (val string, ok bool) { + func(key string) (val interface{}, ok bool) { switch key { case "1": return "test-123", true @@ -331,7 +306,7 @@ func TestReplacerReplaceKnown(t *testing.T) { func TestReplacerDelete(t *testing.T) { rep := Replacer{ - static: map[string]string{ + static: map[string]interface{}{ "key1": "val1", "key2": "val2", "key3": "val3", @@ -366,10 +341,10 @@ func TestReplacerMap(t *testing.T) { rep := testReplacer() for i, tc := range []ReplacerFunc{ - func(key string) (val string, ok bool) { + func(key string) (val interface{}, ok bool) { return "", false }, - func(key string) (val string, ok bool) { + func(key string) (val interface{}, ok bool) { return "", false }, } { @@ -434,12 +409,50 @@ func TestReplacerNew(t *testing.T) { } } } +} +func BenchmarkReplacer(b *testing.B) { + type testCase struct { + name, input, empty string + } + + rep := testReplacer() + rep.Set("str", "a string") + rep.Set("int", 123.456) + + for _, bm := range []testCase{ + { + name: "no placeholder", + input: `simple string`, + }, + { + name: "string replacement", + input: `str={str}`, + }, + { + name: "int replacement", + input: `int={int}`, + }, + { + name: "placeholder", + input: `{"json": "object"}`, + }, + { + name: "escaped placeholder", + input: `\{"json": \{"nested": "{bar}"\}\}`, + }, + } { + b.Run(bm.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + rep.ReplaceAll(bm.input, bm.empty) + } + }) + } } func testReplacer() Replacer { return Replacer{ providers: make([]ReplacerFunc, 0), - static: make(map[string]string), + static: make(map[string]interface{}), } } |