From d55c3b31ebb77df65cc052dbddc137cbe07b297e Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Thu, 11 Jun 2020 16:19:07 -0600 Subject: caddyhttp: Add client cert SAN placeholders --- modules/caddyhttp/replacer.go | 64 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 4 deletions(-) (limited to 'modules/caddyhttp/replacer.go') diff --git a/modules/caddyhttp/replacer.go b/modules/caddyhttp/replacer.go index 37b53d2..b3c7b5d 100644 --- a/modules/caddyhttp/replacer.go +++ b/modules/caddyhttp/replacer.go @@ -28,6 +28,7 @@ import ( "net" "net/http" "net/textproto" + "net/url" "path" "strconv" "strings" @@ -163,7 +164,7 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo if strings.HasPrefix(key, reqHostLabelsReplPrefix) { idxStr := key[len(reqHostLabelsReplPrefix):] idx, err := strconv.Atoi(idxStr) - if err != nil { + if err != nil || idx < 0 { return "", false } reqHost, _, err := net.SplitHostPort(req.Host) @@ -171,9 +172,6 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo reqHost = req.Host // OK; assume there was no port } hostLabels := strings.Split(reqHost, ".") - if idx < 0 { - return "", false - } if idx > len(hostLabels) { return "", true } @@ -245,6 +243,64 @@ func getReqTLSReplacement(req *http.Request, key string) (interface{}, bool) { return nil, false } + // subject alternate names (SANs) + if strings.HasPrefix(field, "client.san.") { + field = field[len("client.san."):] + var fieldName string + var fieldValue interface{} + switch { + case strings.HasPrefix(field, "dns_names"): + fieldName = "dns_names" + fieldValue = cert.DNSNames + case strings.HasPrefix(field, "emails"): + fieldName = "emails" + fieldValue = cert.EmailAddresses + case strings.HasPrefix(field, "ips"): + fieldName = "ips" + fieldValue = cert.IPAddresses + case strings.HasPrefix(field, "uris"): + fieldName = "uris" + fieldValue = cert.URIs + default: + return nil, false + } + field = field[len(fieldName):] + + // if no index was specified, return the whole list + if field == "" { + return fieldValue, true + } + if len(field) < 2 || field[0] != '.' { + return nil, false + } + field = field[1:] // trim '.' between field name and index + + // get the numeric index + idx, err := strconv.Atoi(field) + if err != nil || idx < 0 { + return nil, false + } + + // access the indexed element and return it + switch v := fieldValue.(type) { + case []string: + if idx >= len(v) { + return nil, true + } + return v[idx], true + case []net.IP: + if idx >= len(v) { + return nil, true + } + return v[idx], true + case []*url.URL: + if idx >= len(v) { + return nil, true + } + return v[idx], true + } + } + switch field { case "client.fingerprint": return fmt.Sprintf("%x", sha256.Sum256(cert.Raw)), true -- cgit v1.2.3