From e40bbecb16d196d2d700a9484e53c11b64dfe8d9 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Tue, 7 May 2019 09:56:13 -0600 Subject: Rough implementation of auto HTTP->HTTPS redirects Also added GracePeriod for server shutdowns --- modules/caddyhttp/caddyhttp.go | 169 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 158 insertions(+), 11 deletions(-) (limited to 'modules/caddyhttp/caddyhttp.go') diff --git a/modules/caddyhttp/caddyhttp.go b/modules/caddyhttp/caddyhttp.go index 0731fea..e309053 100644 --- a/modules/caddyhttp/caddyhttp.go +++ b/modules/caddyhttp/caddyhttp.go @@ -30,13 +30,15 @@ func init() { } type httpModuleConfig struct { - Servers map[string]*httpServerConfig `json:"servers"` + HTTPPort int `json:"http_port"` + HTTPSPort int `json:"https_port"` + GracePeriod caddy2.Duration `json:"grace_period"` + Servers map[string]*httpServerConfig `json:"servers"` servers []*http.Server } func (hc *httpModuleConfig) Provision() error { - // TODO: Either prevent overlapping listeners on different servers, or combine them into one for _, srv := range hc.Servers { err := srv.Routes.setup() if err != nil { @@ -51,6 +53,27 @@ func (hc *httpModuleConfig) Provision() error { return nil } +func (hc *httpModuleConfig) Validate() error { + // each server must use distinct listener addresses + lnAddrs := make(map[string]string) + for srvName, srv := range hc.Servers { + for _, addr := range srv.Listen { + netw, expanded, err := parseListenAddr(addr) + if err != nil { + return fmt.Errorf("invalid listener address '%s': %v", addr, err) + } + for _, a := range expanded { + if sn, ok := lnAddrs[netw+a]; ok { + return fmt.Errorf("listener address repeated: %s (already claimed by server '%s')", a, sn) + } + lnAddrs[netw+a] = srvName + } + } + } + + return nil +} + func (hc *httpModuleConfig) Start(handle caddy2.Handle) error { err := hc.automaticHTTPS(handle) if err != nil { @@ -83,7 +106,12 @@ func (hc *httpModuleConfig) Start(handle caddy2.Handle) error { } // enable TLS - if len(srv.TLSConnPolicies) > 0 { + httpPort := hc.HTTPPort + if httpPort == 0 { + httpPort = DefaultHTTPPort + } + _, port, _ := net.SplitHostPort(addr) + if len(srv.TLSConnPolicies) > 0 && port != strconv.Itoa(httpPort) { tlsCfg, err := srv.TLSConnPolicies.TLSConfig(handle) if err != nil { return fmt.Errorf("%s/%s: making TLS configuration: %v", network, addr, err) @@ -100,9 +128,16 @@ func (hc *httpModuleConfig) Start(handle caddy2.Handle) error { return nil } +// Stop gracefully shuts down the HTTP server. func (hc *httpModuleConfig) Stop() error { + ctx := context.Background() + if hc.GracePeriod > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(hc.GracePeriod)) + defer cancel() + } for _, s := range hc.servers { - err := s.Shutdown(context.Background()) // TODO + err := s.Shutdown(ctx) if err != nil { return err } @@ -117,6 +152,9 @@ func (hc *httpModuleConfig) automaticHTTPS(handle caddy2.Handle) error { } tlsApp := tlsAppIface.(*caddytls.TLS) + lnAddrMap := make(map[string]struct{}) + var redirRoutes routeList + for srvName, srv := range hc.Servers { srv.tlsApp = tlsApp @@ -157,13 +195,93 @@ func (hc *httpModuleConfig) automaticHTTPS(handle caddy2.Handle) error { {ALPN: defaultALPN}, } - // TODO: create HTTP->HTTPS redirects + if srv.DisableAutoHTTPSRedir { + continue + } + + // create HTTP->HTTPS redirects + for _, addr := range srv.Listen { + netw, host, port, err := splitListenAddr(addr) + if err != nil { + return fmt.Errorf("%s: invalid listener address: %v", srvName, addr) + } + httpRedirLnAddr := joinListenAddr(netw, host, strconv.Itoa(hc.HTTPPort)) + lnAddrMap[httpRedirLnAddr] = struct{}{} + + if parts := strings.SplitN(port, "-", 2); len(parts) == 2 { + port = parts[0] + } + redirTo := "https://{request.host}" + + httpsPort := hc.HTTPSPort + if httpsPort == 0 { + httpsPort = DefaultHTTPSPort + } + if port != strconv.Itoa(httpsPort) { + redirTo += ":" + port + } + redirTo += "{request.uri}" + + redirRoutes = append(redirRoutes, serverRoute{ + matchers: []RouteMatcher{ + matchProtocol("http"), + matchHost(domains), + }, + responder: Static{ + StatusCode: http.StatusTemporaryRedirect, // TODO: use permanent redirect instead + Headers: http.Header{ + "Location": []string{redirTo}, + "Connection": []string{"close"}, + }, + Close: true, + }, + }) + } + } + } + + if len(lnAddrMap) > 0 { + var lnAddrs []string + mapLoop: + for addr := range lnAddrMap { + netw, addrs, err := parseListenAddr(addr) + if err != nil { + continue + } + for _, a := range addrs { + if hc.listenerTaken(netw, a) { + continue mapLoop + } + } + lnAddrs = append(lnAddrs, addr) + } + hc.Servers["auto_https_redirects"] = &httpServerConfig{ + Listen: lnAddrs, + Routes: redirRoutes, + DisableAutoHTTPS: true, } } return nil } +func (hc *httpModuleConfig) listenerTaken(network, address string) bool { + for _, srv := range hc.Servers { + for _, addr := range srv.Listen { + netw, addrs, err := parseListenAddr(addr) + if err != nil || netw != network { + continue + } + for _, a := range addrs { + if a == address { + return true + } + } + } + } + return false +} + var defaultALPN = []string{"h2", "http/1.1"} type httpServerConfig struct { @@ -204,6 +322,7 @@ func (s httpServerConfig) ServeHTTP(w http.ResponseWriter, r *http.Request) { // it can be accessed by error handlers c := context.WithValue(r.Context(), ErrorCtxKey, err) r = r.WithContext(c) + // TODO: add error values to Replacer if len(s.Errors.Routes) == 0 { // TODO: implement a default error handler? @@ -284,13 +403,11 @@ func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) error { var emptyHandler HandlerFunc = func(w http.ResponseWriter, r *http.Request) error { return nil } func parseListenAddr(a string) (network string, addrs []string, err error) { - network = "tcp" - if idx := strings.Index(a, "/"); idx >= 0 { - network = strings.ToLower(strings.TrimSpace(a[:idx])) - a = a[idx+1:] - } var host, port string - host, port, err = net.SplitHostPort(a) + network, host, port, err = splitListenAddr(a) + if network == "" { + network = "tcp" + } if err != nil { return } @@ -317,6 +434,27 @@ func parseListenAddr(a string) (network string, addrs []string, err error) { return } +func splitListenAddr(a string) (network, host, port string, err error) { + if idx := strings.Index(a, "/"); idx >= 0 { + network = strings.ToLower(strings.TrimSpace(a[:idx])) + a = a[idx+1:] + } + host, port, err = net.SplitHostPort(a) + return +} + +func joinListenAddr(network, host, port string) string { + var a string + if network != "" { + a = network + "/" + } + a += host + if port != "" { + a += ":" + port + } + return a +} + type middlewareResponseWriter struct { *ResponseWriterWrapper allowWrites bool @@ -336,7 +474,16 @@ func (mrw middlewareResponseWriter) Write(b []byte) (int, error) { return mrw.ResponseWriterWrapper.Write(b) } +// ReplacerCtxKey is the context key for the request's replacer. const ReplacerCtxKey caddy2.CtxKey = "replacer" +const ( + // DefaultHTTPPort is the default port for HTTP. + DefaultHTTPPort = 80 + + // DefaultHTTPSPort is the default port for HTTPS. + DefaultHTTPSPort = 443 +) + // Interface guards var _ HTTPInterfaces = middlewareResponseWriter{} -- cgit v1.2.3