summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/caddyhttp/celmatcher.go32
-rw-r--r--modules/caddyhttp/replacer.go54
-rw-r--r--modules/caddyhttp/reverseproxy/reverseproxy.go7
-rw-r--r--modules/caddyhttp/rewrite/rewrite.go17
-rw-r--r--modules/caddyhttp/server.go11
-rw-r--r--replacer.go85
-rw-r--r--replacer_test.go87
7 files changed, 184 insertions, 109 deletions
diff --git a/modules/caddyhttp/celmatcher.go b/modules/caddyhttp/celmatcher.go
index a78bd9c..84565e4 100644
--- a/modules/caddyhttp/celmatcher.go
+++ b/modules/caddyhttp/celmatcher.go
@@ -86,7 +86,7 @@ func (m *MatchExpression) Provision(_ caddy.Context) error {
decls.NewFunction(placeholderFuncName,
decls.NewOverload(placeholderFuncName+"_httpRequest_string",
[]*exprpb.Type{httpRequestObjectType, decls.String},
- decls.String)),
+ decls.Any)),
),
cel.CustomTypeAdapter(celHTTPRequestTypeAdapter{}),
ext.Strings(),
@@ -210,7 +210,35 @@ func caddyPlaceholderFunc(lhs, rhs ref.Val) ref.Val {
repl := celReq.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
val, _ := repl.Get(string(phStr))
- return types.String(val)
+ // TODO: this is... kinda awful and underwhelming, how can we expand CEL's type system more easily?
+ switch v := val.(type) {
+ case string:
+ return types.String(v)
+ case fmt.Stringer:
+ return types.String(v.String())
+ case error:
+ return types.NewErr(v.Error())
+ case int:
+ return types.Int(v)
+ case int32:
+ return types.Int(v)
+ case int64:
+ return types.Int(v)
+ case uint:
+ return types.Int(v)
+ case uint32:
+ return types.Int(v)
+ case uint64:
+ return types.Int(v)
+ case float32:
+ return types.Double(v)
+ case float64:
+ return types.Double(v)
+ case bool:
+ return types.Bool(v)
+ default:
+ return types.String(fmt.Sprintf("%+v", v))
+ }
}
// Interface guards
diff --git a/modules/caddyhttp/replacer.go b/modules/caddyhttp/replacer.go
index c9c7522..f55bb0a 100644
--- a/modules/caddyhttp/replacer.go
+++ b/modules/caddyhttp/replacer.go
@@ -31,7 +31,7 @@ import (
)
func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.ResponseWriter) {
- httpVars := func(key string) (string, bool) {
+ httpVars := func(key string) (interface{}, bool) {
if req != nil {
// query string parameters
if strings.HasPrefix(key, reqURIQueryReplPrefix) {
@@ -62,7 +62,7 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo
}
}
- // http.request.tls.
+ // http.request.tls.*
if strings.HasPrefix(key, reqTLSReplPrefix) {
return getReqTLSReplacement(req, key)
}
@@ -182,21 +182,10 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo
if strings.HasPrefix(key, varsReplPrefix) {
varName := key[len(varsReplPrefix):]
tbl := req.Context().Value(VarsCtxKey).(map[string]interface{})
- raw, ok := tbl[varName]
- if !ok {
- // variables can be dynamic, so always return true
- // even when it may not be set; treat as empty
- return "", true
- }
- // do our best to convert it to a string efficiently
- switch val := raw.(type) {
- case string:
- return val, true
- case fmt.Stringer:
- return val.String(), true
- default:
- return fmt.Sprintf("%s", val), true
- }
+ raw, _ := tbl[varName]
+ // variables can be dynamic, so always return true
+ // even when it may not be set; treat as empty then
+ return raw, true
}
}
@@ -211,19 +200,19 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo
}
}
- return "", false
+ return nil, false
}
repl.Map(httpVars)
}
-func getReqTLSReplacement(req *http.Request, key string) (string, bool) {
+func getReqTLSReplacement(req *http.Request, key string) (interface{}, bool) {
if req == nil || req.TLS == nil {
- return "", false
+ return nil, false
}
if len(key) < len(reqTLSReplPrefix) {
- return "", false
+ return nil, false
}
field := strings.ToLower(key[len(reqTLSReplPrefix):])
@@ -231,20 +220,20 @@ func getReqTLSReplacement(req *http.Request, key string) (string, bool) {
if strings.HasPrefix(field, "client.") {
cert := getTLSPeerCert(req.TLS)
if cert == nil {
- return "", false
+ return nil, false
}
switch field {
case "client.fingerprint":
return fmt.Sprintf("%x", sha256.Sum256(cert.Raw)), true
case "client.issuer":
- return cert.Issuer.String(), true
+ return cert.Issuer, true
case "client.serial":
- return fmt.Sprintf("%x", cert.SerialNumber), true
+ return cert.SerialNumber, true
case "client.subject":
- return cert.Subject.String(), true
+ return cert.Subject, true
default:
- return "", false
+ return nil, false
}
}
@@ -254,22 +243,15 @@ func getReqTLSReplacement(req *http.Request, key string) (string, bool) {
case "cipher_suite":
return tls.CipherSuiteName(req.TLS.CipherSuite), true
case "resumed":
- if req.TLS.DidResume {
- return "true", true
- }
- return "false", true
+ return req.TLS.DidResume, true
case "proto":
return req.TLS.NegotiatedProtocol, true
case "proto_mutual":
- if req.TLS.NegotiatedProtocolIsMutual {
- return "true", true
- }
- return "false", true
+ return req.TLS.NegotiatedProtocolIsMutual, true
case "server_name":
return req.TLS.ServerName, true
- default:
- return "", false
}
+ return nil, false
}
// getTLSPeerCert retrieves the first peer certificate from a TLS session.
diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go
index 4ac50ac..6d0d441 100644
--- a/modules/caddyhttp/reverseproxy/reverseproxy.go
+++ b/modules/caddyhttp/reverseproxy/reverseproxy.go
@@ -24,7 +24,6 @@ import (
"net"
"net/http"
"regexp"
- "strconv"
"strings"
"sync"
"time"
@@ -328,9 +327,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
repl.Set("http.reverse_proxy.upstream.hostport", dialInfo.Address)
repl.Set("http.reverse_proxy.upstream.host", dialInfo.Host)
repl.Set("http.reverse_proxy.upstream.port", dialInfo.Port)
- repl.Set("http.reverse_proxy.upstream.requests", strconv.Itoa(upstream.Host.NumRequests()))
- repl.Set("http.reverse_proxy.upstream.max_requests", strconv.Itoa(upstream.MaxRequests))
- repl.Set("http.reverse_proxy.upstream.fails", strconv.Itoa(upstream.Host.Fails()))
+ repl.Set("http.reverse_proxy.upstream.requests", upstream.Host.NumRequests())
+ repl.Set("http.reverse_proxy.upstream.max_requests", upstream.MaxRequests)
+ repl.Set("http.reverse_proxy.upstream.fails", upstream.Host.Fails())
// mutate request headers according to this upstream;
// because we're in a retry loop, we have to copy
diff --git a/modules/caddyhttp/rewrite/rewrite.go b/modules/caddyhttp/rewrite/rewrite.go
index ad05486..3ba63c4 100644
--- a/modules/caddyhttp/rewrite/rewrite.go
+++ b/modules/caddyhttp/rewrite/rewrite.go
@@ -15,8 +15,10 @@
package rewrite
import (
+ "fmt"
"net/http"
"net/url"
+ "strconv"
"strings"
"github.com/caddyserver/caddy/v2"
@@ -208,11 +210,22 @@ func buildQueryString(qs string, repl *caddy.Replacer) string {
// consume the component and write the result
comp := qs[:end]
- comp, _ = repl.ReplaceFunc(comp, func(name, val string) (string, error) {
+ comp, _ = repl.ReplaceFunc(comp, func(name string, val interface{}) (interface{}, error) {
if name == "http.request.uri.query" && wroteVal {
return val, nil // already escaped
}
- return url.QueryEscape(val), nil
+ var valStr string
+ switch v := val.(type) {
+ case string:
+ valStr = v
+ case fmt.Stringer:
+ valStr = v.String()
+ case int:
+ valStr = strconv.Itoa(v)
+ default:
+ valStr = fmt.Sprintf("%+v", v)
+ }
+ return url.QueryEscape(valStr), nil
})
if end < len(qs) {
end++ // consume delimiter
diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go
index c7780b0..72a67a7 100644
--- a/modules/caddyhttp/server.go
+++ b/modules/caddyhttp/server.go
@@ -21,7 +21,6 @@ import (
"net"
"net/http"
"net/url"
- "strconv"
"strings"
"time"
@@ -166,9 +165,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer func() {
latency := time.Since(start)
- repl.Set("http.response.status", strconv.Itoa(wrec.Status()))
- repl.Set("http.response.size", strconv.Itoa(wrec.Size()))
- repl.Set("http.response.latency", latency.String())
+ repl.Set("http.response.status", wrec.Status())
+ repl.Set("http.response.size", wrec.Size())
+ repl.Set("http.response.latency", latency)
logger := accLog
if s.Logs != nil && s.Logs.LoggerNames != nil {
@@ -360,9 +359,9 @@ func (*HTTPErrorConfig) WithError(r *http.Request, err error) *http.Request {
// add error values to the replacer
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
- repl.Set("http.error", err.Error())
+ repl.Set("http.error", err)
if handlerErr, ok := err.(HandlerError); ok {
- repl.Set("http.error.status_code", strconv.Itoa(handlerErr.StatusCode))
+ repl.Set("http.error.status_code", handlerErr.StatusCode)
repl.Set("http.error.status_text", http.StatusText(handlerErr.StatusCode))
repl.Set("http.error.trace", handlerErr.Trace)
repl.Set("http.error.id", handlerErr.ID)
diff --git a/replacer.go b/replacer.go
index 4ff578c..eac2080 100644
--- a/replacer.go
+++ b/replacer.go
@@ -27,7 +27,7 @@ import (
// NewReplacer returns a new Replacer.
func NewReplacer() *Replacer {
rep := &Replacer{
- static: make(map[string]string),
+ static: make(map[string]interface{}),
}
rep.providers = []ReplacerFunc{
globalDefaultReplacements,
@@ -41,7 +41,7 @@ func NewReplacer() *Replacer {
// use NewReplacer to make one.
type Replacer struct {
providers []ReplacerFunc
- static map[string]string
+ static map[string]interface{}
}
// Map adds mapFunc to the list of value providers.
@@ -51,19 +51,19 @@ func (r *Replacer) Map(mapFunc ReplacerFunc) {
}
// Set sets a custom variable to a static value.
-func (r *Replacer) Set(variable, value string) {
+func (r *Replacer) Set(variable string, value interface{}) {
r.static[variable] = value
}
// Get gets a value from the replacer. It returns
// the value and whether the variable was known.
-func (r *Replacer) Get(variable string) (string, bool) {
+func (r *Replacer) Get(variable string) (interface{}, bool) {
for _, mapFunc := range r.providers {
if val, ok := mapFunc(variable); ok {
return val, true
}
}
- return "", false
+ return nil, false
}
// Delete removes a variable with a static value
@@ -73,9 +73,9 @@ func (r *Replacer) Delete(variable string) {
}
// fromStatic provides values from r.static.
-func (r *Replacer) fromStatic(key string) (val string, ok bool) {
- val, ok = r.static[key]
- return
+func (r *Replacer) fromStatic(key string) (interface{}, bool) {
+ val, ok := r.static[key]
+ return val, ok
}
// ReplaceOrErr is like ReplaceAll, but any placeholders
@@ -102,10 +102,9 @@ func (r *Replacer) ReplaceAll(input, empty string) string {
return out
}
-// ReplaceFunc calls ReplaceAll efficiently replaces placeholders in input with
-// their values. All placeholders are replaced in the output
-// whether they are recognized or not. Values that are empty
-// string will be substituted with empty.
+// ReplaceFunc is the same as ReplaceAll, but calls f for every
+// replacement to be made, in case f wants to change or inspect
+// the replacement.
func (r *Replacer) ReplaceFunc(input string, f ReplacementFunc) (string, error) {
return r.replace(input, "", true, false, false, f)
}
@@ -125,7 +124,7 @@ func (r *Replacer) replace(input, empty string,
// iterate the input to find each placeholder
var lastWriteCursor int
-
+
scan:
for i := 0; i < len(input); i++ {
@@ -169,9 +168,8 @@ scan:
return "", fmt.Errorf("unrecognized placeholder %s%s%s",
string(phOpen), key, string(phClose))
} else if !treatUnknownAsEmpty {
- // if treatUnknownAsEmpty is true, we'll
- // handle an empty val later; so only
- // continue otherwise
+ // if treatUnknownAsEmpty is true, we'll handle an empty
+ // val later; so only continue otherwise
lastWriteCursor = i
continue
}
@@ -186,9 +184,12 @@ scan:
}
}
+ // convert val to a string as efficiently as possible
+ valStr := toString(val)
+
// write the value; if it's empty, either return
// an error or write a default value
- if val == "" {
+ if valStr == "" {
if errOnEmpty {
return "", fmt.Errorf("evaluated placeholder %s%s%s is empty",
string(phOpen), key, string(phClose))
@@ -196,7 +197,7 @@ scan:
sb.WriteString(empty)
}
} else {
- sb.WriteString(val)
+ sb.WriteString(valStr)
}
// advance cursor to end of placeholder
@@ -210,14 +211,54 @@ scan:
return sb.String(), nil
}
+func toString(val interface{}) string {
+ switch v := val.(type) {
+ case nil:
+ return ""
+ case string:
+ return v
+ case fmt.Stringer:
+ return v.String()
+ case byte:
+ return string(v)
+ case []byte:
+ return string(v)
+ case []rune:
+ return string(v)
+ case int:
+ return strconv.Itoa(v)
+ case int32:
+ return strconv.Itoa(int(v))
+ case int64:
+ return strconv.Itoa(int(v))
+ case uint:
+ return strconv.Itoa(int(v))
+ case uint32:
+ return strconv.Itoa(int(v))
+ case uint64:
+ return strconv.Itoa(int(v))
+ case float32:
+ return strconv.FormatFloat(float64(v), 'f', -1, 32)
+ case float64:
+ return strconv.FormatFloat(v, 'f', -1, 64)
+ case bool:
+ if v {
+ return "true"
+ }
+ return "false"
+ default:
+ return fmt.Sprintf("%+v", v)
+ }
+}
+
// ReplacerFunc is a function that returns a replacement
// for the given key along with true if the function is able
// to service that key (even if the value is blank). If the
// function does not recognize the key, false should be
// returned.
-type ReplacerFunc func(key string) (val string, ok bool)
+type ReplacerFunc func(key string) (interface{}, bool)
-func globalDefaultReplacements(key string) (string, bool) {
+func globalDefaultReplacements(key string) (interface{}, bool) {
// check environment variable
const envPrefix = "env."
if strings.HasPrefix(key, envPrefix) {
@@ -241,7 +282,7 @@ func globalDefaultReplacements(key string) (string, bool) {
return strconv.Itoa(nowFunc().Year()), true
}
- return "", false
+ return nil, false
}
// ReplacementFunc is a function that is called when a
@@ -250,7 +291,7 @@ func globalDefaultReplacements(key string) (string, bool) {
// will be the replacement, and returns the value that
// will actually be the replacement, or an error. Note
// that errors are sometimes ignored by replacers.
-type ReplacementFunc func(variable, val string) (string, error)
+type ReplacementFunc func(variable string, val interface{}) (interface{}, error)
// nowFunc is a variable so tests can change it
// in order to obtain a deterministic time.
diff --git a/replacer_test.go b/replacer_test.go
index a48917a..d6ac033 100644
--- a/replacer_test.go
+++ b/replacer_test.go
@@ -173,41 +173,12 @@ func TestReplacer(t *testing.T) {
}
}
-func BenchmarkReplacer(b *testing.B) {
- type testCase struct {
- name, input, empty string
- }
-
- rep := testReplacer()
-
- for _, bm := range []testCase{
- {
- name: "no placeholder",
- input: `simple string`,
- },
- {
- name: "placeholder",
- input: `{"json": "object"}`,
- },
- {
- name: "escaped placeholder",
- input: `\{"json": \{"nested": "{bar}"\}\}`,
- },
- } {
- b.Run(bm.name, func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- rep.ReplaceAll(bm.input, bm.empty)
- }
- })
- }
-}
-
func TestReplacerSet(t *testing.T) {
rep := testReplacer()
for _, tc := range []struct {
variable string
- value string
+ value interface{}
}{
{
variable: "test1",
@@ -218,6 +189,10 @@ func TestReplacerSet(t *testing.T) {
value: "123",
},
{
+ variable: "numbers",
+ value: 123.456,
+ },
+ {
variable: "äöü",
value: "öö_äü",
},
@@ -252,7 +227,7 @@ func TestReplacerSet(t *testing.T) {
// test if all keys are still there (by length)
length := len(rep.static)
- if len(rep.static) != 7 {
+ if len(rep.static) != 8 {
t.Errorf("Expected length '%v' got '%v'", 7, length)
}
}
@@ -261,7 +236,7 @@ func TestReplacerReplaceKnown(t *testing.T) {
rep := Replacer{
providers: []ReplacerFunc{
// split our possible vars to two functions (to test if both functions are called)
- func(key string) (val string, ok bool) {
+ func(key string) (val interface{}, ok bool) {
switch key {
case "test1":
return "val1", true
@@ -275,7 +250,7 @@ func TestReplacerReplaceKnown(t *testing.T) {
return "NOOO", false
}
},
- func(key string) (val string, ok bool) {
+ func(key string) (val interface{}, ok bool) {
switch key {
case "1":
return "test-123", true
@@ -331,7 +306,7 @@ func TestReplacerReplaceKnown(t *testing.T) {
func TestReplacerDelete(t *testing.T) {
rep := Replacer{
- static: map[string]string{
+ static: map[string]interface{}{
"key1": "val1",
"key2": "val2",
"key3": "val3",
@@ -366,10 +341,10 @@ func TestReplacerMap(t *testing.T) {
rep := testReplacer()
for i, tc := range []ReplacerFunc{
- func(key string) (val string, ok bool) {
+ func(key string) (val interface{}, ok bool) {
return "", false
},
- func(key string) (val string, ok bool) {
+ func(key string) (val interface{}, ok bool) {
return "", false
},
} {
@@ -434,12 +409,50 @@ func TestReplacerNew(t *testing.T) {
}
}
}
+}
+func BenchmarkReplacer(b *testing.B) {
+ type testCase struct {
+ name, input, empty string
+ }
+
+ rep := testReplacer()
+ rep.Set("str", "a string")
+ rep.Set("int", 123.456)
+
+ for _, bm := range []testCase{
+ {
+ name: "no placeholder",
+ input: `simple string`,
+ },
+ {
+ name: "string replacement",
+ input: `str={str}`,
+ },
+ {
+ name: "int replacement",
+ input: `int={int}`,
+ },
+ {
+ name: "placeholder",
+ input: `{"json": "object"}`,
+ },
+ {
+ name: "escaped placeholder",
+ input: `\{"json": \{"nested": "{bar}"\}\}`,
+ },
+ } {
+ b.Run(bm.name, func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ rep.ReplaceAll(bm.input, bm.empty)
+ }
+ })
+ }
}
func testReplacer() Replacer {
return Replacer{
providers: make([]ReplacerFunc, 0),
- static: make(map[string]string),
+ static: make(map[string]interface{}),
}
}