summaryrefslogtreecommitdiff
path: root/admin.go
diff options
context:
space:
mode:
authorMatthew Holt <mholt@users.noreply.github.com>2022-02-15 12:08:12 -0700
committerMatthew Holt <mholt@users.noreply.github.com>2022-02-15 12:08:12 -0700
commit40b54434f3cdb804ef10eee0ba5d8d6c390e93d4 (patch)
treebdc3ff917805a32ed97eeff88941893b71ca2766 /admin.go
parent1d0425b26f1fbb797d7bc10e3740dc031410d01f (diff)
admin: Enforce and refactor origin checking
Using URLs seems a little cleaner and more correct cf: https://caddy.community/t/protect-admin-endpoint/15114 (This used to work. Something must have changed recently.)
Diffstat (limited to 'admin.go')
-rw-r--r--admin.go85
1 files changed, 58 insertions, 27 deletions
diff --git a/admin.go b/admin.go
index 0a7b933..157ae95 100644
--- a/admin.go
+++ b/admin.go
@@ -42,6 +42,7 @@ import (
"github.com/caddyserver/certmagic"
"github.com/prometheus/client_golang/prometheus"
"go.uber.org/zap"
+ "go.uber.org/zap/zapcore"
)
// AdminConfig configures Caddy's API endpoint, which is used
@@ -192,6 +193,7 @@ func (admin AdminConfig) newAdminHandler(addr NetworkAddress, remote bool) admin
} else {
muxWrap.enforceHost = !addr.isWildcardInterface()
muxWrap.allowedOrigins = admin.allowedOrigins(addr)
+ muxWrap.enforceOrigin = admin.EnforceOrigin
}
addRouteWithMetrics := func(pattern string, handlerLabel string, h http.Handler) {
@@ -252,7 +254,7 @@ func (admin AdminConfig) newAdminHandler(addr NetworkAddress, remote bool) admin
// will be used as the default origin. If admin.Origins is
// empty, no origins will be allowed, effectively bricking the
// endpoint for non-unix-socket endpoints, but whatever.
-func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []string {
+func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []*url.URL {
uniqueOrigins := make(map[string]struct{})
for _, o := range admin.Origins {
uniqueOrigins[o] = struct{}{}
@@ -276,8 +278,23 @@ func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []string {
uniqueOrigins[addr.JoinHostPort(0)] = struct{}{}
}
}
- allowed := make([]string, 0, len(uniqueOrigins))
- for origin := range uniqueOrigins {
+ allowed := make([]*url.URL, 0, len(uniqueOrigins))
+ for originStr := range uniqueOrigins {
+ var origin *url.URL
+ if strings.Contains(originStr, "://") {
+ var err error
+ origin, err = url.Parse(originStr)
+ if err != nil {
+ continue
+ }
+ origin.Path = ""
+ origin.RawPath = ""
+ origin.Fragment = ""
+ origin.RawFragment = ""
+ origin.RawQuery = ""
+ } else {
+ origin = &url.URL{Host: originStr}
+ }
allowed = append(allowed, origin)
}
return allowed
@@ -358,7 +375,7 @@ func replaceLocalAdminServer(cfg *Config) error {
adminLogger.Info("admin endpoint started",
zap.String("address", addr.String()),
zap.Bool("enforce_origin", adminConfig.EnforceOrigin),
- zap.Strings("origins", handler.allowedOrigins))
+ zap.Array("origins", loggableURLArray(handler.allowedOrigins)))
if !handler.enforceHost {
adminLogger.Warn("admin endpoint on open interface; host checking disabled",
@@ -650,10 +667,10 @@ type AdminRoute struct {
type adminHandler struct {
mux *http.ServeMux
- // security for local/plaintext) endpoint, on by default
+ // security for local/plaintext endpoint
enforceOrigin bool
enforceHost bool
- allowedOrigins []string
+ allowedOrigins []*url.URL
// security for remote/encrypted endpoint
remoteControl *RemoteAdmin
@@ -779,8 +796,8 @@ func (h adminHandler) handleError(w http.ResponseWriter, r *http.Request, err er
// rebinding attacks.
func (h adminHandler) checkHost(r *http.Request) error {
var allowed bool
- for _, allowedHost := range h.allowedOrigins {
- if r.Host == allowedHost {
+ for _, allowedOrigin := range h.allowedOrigins {
+ if r.Host == allowedOrigin.Host {
allowed = true
break
}
@@ -799,43 +816,45 @@ func (h adminHandler) checkHost(r *http.Request) error {
// sites from issuing requests to our listener. It
// returns the origin that was obtained from r.
func (h adminHandler) checkOrigin(r *http.Request) (string, error) {
- origin := h.getOriginHost(r)
- if origin == "" {
- return origin, APIError{
+ originStr, origin := h.getOrigin(r)
+ if origin == nil {
+ return "", APIError{
HTTPStatus: http.StatusForbidden,
- Err: fmt.Errorf("missing required Origin header"),
+ Err: fmt.Errorf("required Origin header is missing or invalid"),
}
}
if !h.originAllowed(origin) {
- return origin, APIError{
+ return "", APIError{
HTTPStatus: http.StatusForbidden,
- Err: fmt.Errorf("client is not allowed to access from origin %s", origin),
+ Err: fmt.Errorf("client is not allowed to access from origin '%s'", originStr),
}
}
- return origin, nil
+ return origin.String(), nil
}
-func (h adminHandler) getOriginHost(r *http.Request) string {
+func (h adminHandler) getOrigin(r *http.Request) (string, *url.URL) {
origin := r.Header.Get("Origin")
if origin == "" {
origin = r.Header.Get("Referer")
}
originURL, err := url.Parse(origin)
- if err == nil && originURL.Host != "" {
- origin = originURL.Host
- }
- return origin
+ if err != nil {
+ return origin, nil
+ }
+ originURL.Path = ""
+ originURL.RawPath = ""
+ originURL.Fragment = ""
+ originURL.RawFragment = ""
+ originURL.RawQuery = ""
+ return origin, originURL
}
-func (h adminHandler) originAllowed(origin string) bool {
+func (h adminHandler) originAllowed(origin *url.URL) bool {
for _, allowedOrigin := range h.allowedOrigins {
- originCopy := origin
- if !strings.Contains(allowedOrigin, "://") {
- // no scheme specified, so allow both
- originCopy = strings.TrimPrefix(originCopy, "http://")
- originCopy = strings.TrimPrefix(originCopy, "https://")
+ if allowedOrigin.Scheme != "" && origin.Scheme != allowedOrigin.Scheme {
+ continue
}
- if originCopy == allowedOrigin {
+ if origin.Host == allowedOrigin.Host {
return true
}
}
@@ -1189,6 +1208,18 @@ func decodeBase64DERCert(certStr string) (*x509.Certificate, error) {
return x509.ParseCertificate(derBytes)
}
+type loggableURLArray []*url.URL
+
+func (ua loggableURLArray) MarshalLogArray(enc zapcore.ArrayEncoder) error {
+ if ua == nil {
+ return nil
+ }
+ for _, u := range ua {
+ enc.AppendString(u.String())
+ }
+ return nil
+}
+
var (
// DefaultAdminListen is the address for the local admin
// listener, if none is specified at startup.