// Copyright 2015 Matthew Holt and The Caddy Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package caddyhttp import ( "bytes" "context" "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/rsa" "crypto/sha256" "crypto/tls" "crypto/x509" "encoding/asn1" "encoding/base64" "encoding/pem" "fmt" "io" "net" "net/http" "net/netip" "net/textproto" "net/url" "path" "strconv" "strings" "time" "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/modules/caddytls" "github.com/google/uuid" ) // NewTestReplacer creates a replacer for an http.Request // for use in tests that are not in this package func NewTestReplacer(req *http.Request) *caddy.Replacer { repl := caddy.NewReplacer() ctx := context.WithValue(req.Context(), caddy.ReplacerCtxKey, repl) *req = *req.WithContext(ctx) addHTTPVarsToReplacer(repl, req, nil) return repl } func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.ResponseWriter) { SetVar(req.Context(), "start_time", time.Now()) SetVar(req.Context(), "uuid", new(requestID)) httpVars := func(key string) (any, bool) { if req != nil { // query string parameters if strings.HasPrefix(key, reqURIQueryReplPrefix) { vals := req.URL.Query()[key[len(reqURIQueryReplPrefix):]] // always return true, since the query param might // be present only in some requests return strings.Join(vals, ","), true } // request header fields if strings.HasPrefix(key, reqHeaderReplPrefix) { field := key[len(reqHeaderReplPrefix):] vals := req.Header[textproto.CanonicalMIMEHeaderKey(field)] // always return true, since the header field might // be present only in some requests return strings.Join(vals, ","), true } // cookies if strings.HasPrefix(key, reqCookieReplPrefix) { name := key[len(reqCookieReplPrefix):] for _, cookie := range req.Cookies() { if strings.EqualFold(name, cookie.Name) { // always return true, since the cookie might // be present only in some requests return cookie.Value, true } } } // http.request.tls.* if strings.HasPrefix(key, reqTLSReplPrefix) { return getReqTLSReplacement(req, key) } switch key { case "http.request.method": return req.Method, true case "http.request.scheme": if req.TLS != nil { return "https", true } return "http", true case "http.request.proto": return req.Proto, true case "http.request.host": host, _, err := net.SplitHostPort(req.Host) if err != nil { return req.Host, true // OK; there probably was no port } return host, true case "http.request.port": _, port, _ := net.SplitHostPort(req.Host) if portNum, err := strconv.Atoi(port); err == nil { return portNum, true } return port, true case "http.request.hostport": return req.Host, true case "http.request.remote": return req.RemoteAddr, true case "http.request.remote.host": host, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { return req.RemoteAddr, true } return host, true case "http.request.remote.port": _, port, _ := net.SplitHostPort(req.RemoteAddr) if portNum, err := strconv.Atoi(port); err == nil { return portNum, true } return port, true // current URI, including any internal rewrites case "http.request.uri": return req.URL.RequestURI(), true case "http.request.uri.path": return req.URL.Path, true case "http.request.uri.path.file": _, file := path.Split(req.URL.Path) return file, true case "http.request.uri.path.dir": dir, _ := path.Split(req.URL.Path) return dir, true case "http.request.uri.path.file.base": return strings.TrimSuffix(path.Base(req.URL.Path), path.Ext(req.URL.Path)), true case "http.request.uri.path.file.ext": return path.Ext(req.URL.Path), true case "http.request.uri.query": return req.URL.RawQuery, true case "http.request.duration": start := GetVar(req.Context(), "start_time").(time.Time) return time.Since(start), true case "http.request.duration_ms": start := GetVar(req.Context(), "start_time").(time.Time) return time.Since(start).Seconds() * 1e3, true // multiply seconds to preserve decimal (see #4666) case "http.request.uuid": id := GetVar(req.Context(), "uuid").(*requestID) return id.String(), true case "http.request.body": if req.Body == nil { return "", true } // normally net/http will close the body for us, but since we // are replacing it with a fake one, we have to ensure we close // the real body ourselves when we're done defer req.Body.Close() // read the request body into a buffer (can't pool because we // don't know its lifetime and would have to make a copy anyway) buf := new(bytes.Buffer) _, _ = io.Copy(buf, req.Body) // can't handle error, so just ignore it req.Body = io.NopCloser(buf) // replace real body with buffered data return buf.String(), true // original request, before any internal changes case "http.request.orig_method": or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request) return or.Method, true case "http.request.orig_uri": or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request) return or.RequestURI, true case "http.request.orig_uri.path": or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request) return or.URL.Path, true case "http.request.orig_uri.path.file": or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request) _, file := path.Split(or.URL.Path) return file, true case "http.request.orig_uri.path.dir": or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request) dir, _ := path.Split(or.URL.Path) return dir, true case "http.request.orig_uri.query": or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request) return or.URL.RawQuery, true } // remote IP range/prefix (e.g. keep top 24 bits of 1.2.3.4 => "1.2.3.0/24") // syntax: "/V4,V6" where V4 = IPv4 bits, and V6 = IPv6 bits; if no comma, then same bit length used for both // (EXPERIMENTAL) if strings.HasPrefix(key, "http.request.remote.host/") { host, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { host = req.RemoteAddr // assume no port, I guess? } addr, err := netip.ParseAddr(host) if err != nil { return host, true // not an IP address } // extract the bits from the end of the placeholder (start after "/") then split on "," bitsBoth := key[strings.Index(key, "/")+1:] ipv4BitsStr, ipv6BitsStr, cutOK := strings.Cut(bitsBoth, ",") bitsStr := ipv4BitsStr if addr.Is6() && cutOK { bitsStr = ipv6BitsStr } // convert to integer then compute prefix bits, err := strconv.Atoi(bitsStr) if err != nil { return "", true } prefix, err := addr.Prefix(bits) if err != nil { return "", true } return prefix.String(), true } // hostname labels if strings.HasPrefix(key, reqHostLabelsReplPrefix) { idxStr := key[len(reqHostLabelsReplPrefix):] idx, err := strconv.Atoi(idxStr) if err != nil || idx < 0 { return "", false } reqHost, _, err := net.SplitHostPort(req.Host) if err != nil { reqHost = req.Host // OK; assume there was no port } hostLabels := strings.Split(reqHost, ".") if idx >= len(hostLabels) { return "", true } return hostLabels[len(hostLabels)-idx-1], true } // path parts if strings.HasPrefix(key, reqURIPathReplPrefix) { idxStr := key[len(reqURIPathReplPrefix):] idx, err := strconv.Atoi(idxStr) if err != nil { return "", false } pathParts := strings.Split(req.URL.Path, "/") if len(pathParts) > 0 && pathParts[0] == "" { pathParts = pathParts[1:] } if idx < 0 { return "", false } if idx >= len(pathParts) { return "", true } return pathParts[idx], true } // orig uri path parts if strings.HasPrefix(key, reqOrigURIPathReplPrefix) { idxStr := key[len(reqOrigURIPathReplPrefix):] idx, err := strconv.Atoi(idxStr) if err != nil { return "", false } or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request) pathParts := strings.Split(or.URL.Path, "/") if len(pathParts) > 0 && pathParts[0] == "" { pathParts = pathParts[1:] } if idx < 0 { return "", false } if idx >= len(pathParts) { return "", true } return pathParts[idx], true } // middleware variables if strings.HasPrefix(key, varsReplPrefix) { varName := key[len(varsReplPrefix):] raw := GetVar(req.Context(), varName) // variables can be dynamic, so always return true // even when it may not be set; treat as empty then return raw, true } } if w != nil { // response header fields if strings.HasPrefix(key, respHeaderReplPrefix) { field := key[len(respHeaderReplPrefix):] vals := w.Header()[textproto.CanonicalMIMEHeaderKey(field)] // always return true, since the header field might // be present only in some responses return strings.Join(vals, ","), true } } switch { case key == "http.shutting_down": server := req.Context().Value(ServerCtxKey).(*Server) server.shutdownAtMu.RLock() defer server.shutdownAtMu.RUnlock() return !server.shutdownAt.IsZero(), true case key == "http.time_until_shutdown": server := req.Context().Value(ServerCtxKey).(*Server) server.shutdownAtMu.RLock() defer server.shutdownAtMu.RUnlock() if server.shutdownAt.IsZero() { return nil, true } return time.Until(server.shutdownAt), true } return nil, false } repl.Map(httpVars) } func getReqTLSReplacement(req *http.Request, key string) (any, bool) { if req == nil || req.TLS == nil { return nil, false } if len(key) < len(reqTLSReplPrefix) { return nil, false } field := strings.ToLower(key[len(reqTLSReplPrefix):]) if strings.HasPrefix(field, "client.") { cert := getTLSPeerCert(req.TLS) if cert == nil { return nil, false } // subject alternate names (SANs) if strings.HasPrefix(field, "client.san.") { field = field[len("client.san."):] var fieldName string var fieldValue any switch { case strings.HasPrefix(field, "dns_names"): fieldName = "dns_names" fieldValue = cert.DNSNames case strings.HasPrefix(field, "emails"): fieldName = "emails" fieldValue = cert.EmailAddresses case strings.HasPrefix(field, "ips"): fieldName = "ips" fieldValue = cert.IPAddresses case strings.HasPrefix(field, "uris"): fieldName = "uris" fieldValue = cert.URIs default: return nil, false } field = field[len(fieldName):] // if no index was specified, return the whole list if field == "" { return fieldValue, true } if len(field) < 2 || field[0] != '.' { return nil, false } field = field[1:] // trim '.' between field name and index // get the numeric index idx, err := strconv.Atoi(field) if err != nil || idx < 0 { return nil, false } // access the indexed element and return it switch v := fieldValue.(type) { case []string: if idx >= len(v) { return nil, true } return v[idx], true case []net.IP: if idx >= len(v) { return nil, true } return v[idx], true case []*url.URL: if idx >= len(v) { return nil, true } return v[idx], true } } switch field { case "client.fingerprint": return fmt.Sprintf("%x", sha256.Sum256(cert.Raw)), true case "client.public_key", "client.public_key_sha256": if cert.PublicKey == nil { return nil, true } pubKeyBytes, err := marshalPublicKey(cert.PublicKey) if err != nil { return nil, true } if strings.HasSuffix(field, "_sha256") { return fmt.Sprintf("%x", sha256.Sum256(pubKeyBytes)), true } return fmt.Sprintf("%x", pubKeyBytes), true case "client.issuer": return cert.Issuer, true case "client.serial": return cert.SerialNumber, true case "client.subject": return cert.Subject, true case "client.certificate_pem": block := pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw} return pem.EncodeToMemory(&block), true case "client.certificate_der_base64": return base64.StdEncoding.EncodeToString(cert.Raw), true default: return nil, false } } switch field { case "version": return caddytls.ProtocolName(req.TLS.Version), true case "cipher_suite": return tls.CipherSuiteName(req.TLS.CipherSuite), true case "resumed": return req.TLS.DidResume, true case "proto": return req.TLS.NegotiatedProtocol, true case "proto_mutual": // req.TLS.NegotiatedProtocolIsMutual is deprecated - it's always true. return true, true case "server_name": return req.TLS.ServerName, true } return nil, false } // marshalPublicKey returns the byte encoding of pubKey. func marshalPublicKey(pubKey any) ([]byte, error) { switch key := pubKey.(type) { case *rsa.PublicKey: return asn1.Marshal(key) case *ecdsa.PublicKey: return elliptic.Marshal(key.Curve, key.X, key.Y), nil case ed25519.PublicKey: return key, nil } return nil, fmt.Errorf("unrecognized public key type: %T", pubKey) } // getTLSPeerCert retrieves the first peer certificate from a TLS session. // Returns nil if no peer cert is in use. func getTLSPeerCert(cs *tls.ConnectionState) *x509.Certificate { if len(cs.PeerCertificates) == 0 { return nil } return cs.PeerCertificates[0] } type requestID struct { value string } // Lazy generates UUID string or return cached value if present func (rid *requestID) String() string { if rid.value == "" { if id, err := uuid.NewRandom(); err == nil { rid.value = id.String() } } return rid.value } const ( reqCookieReplPrefix = "http.request.cookie." reqHeaderReplPrefix = "http.request.header." reqHostLabelsReplPrefix = "http.request.host.labels." reqTLSReplPrefix = "http.request.tls." reqURIPathReplPrefix = "http.request.uri.path." reqURIQueryReplPrefix = "http.request.uri.query." respHeaderReplPrefix = "http.response.header." varsReplPrefix = "http.vars." reqOrigURIPathReplPrefix = "http.request.orig_uri.path." )