summaryrefslogtreecommitdiff
path: root/modules/caddyhttp
diff options
context:
space:
mode:
Diffstat (limited to 'modules/caddyhttp')
-rw-r--r--modules/caddyhttp/app.go5
-rw-r--r--modules/caddyhttp/server.go50
2 files changed, 54 insertions, 1 deletions
diff --git a/modules/caddyhttp/app.go b/modules/caddyhttp/app.go
index 3edc5b2..0253521 100644
--- a/modules/caddyhttp/app.go
+++ b/modules/caddyhttp/app.go
@@ -178,7 +178,9 @@ func (app *App) Provision(ctx caddy.Context) error {
}
// prepare each server
+ oldContext := ctx.Context
for srvName, srv := range app.Servers {
+ ctx.Context = context.WithValue(oldContext, ServerCtxKey, srv)
srv.name = srvName
srv.tlsApp = app.tlsApp
srv.events = eventsAppIface.(*caddyevents.App)
@@ -293,7 +295,7 @@ func (app *App) Provision(ctx caddy.Context) error {
srv.IdleTimeout = defaultIdleTimeout
}
}
-
+ ctx.Context = oldContext
return nil
}
@@ -365,6 +367,7 @@ func (app *App) Start() error {
// this TLS config is used by the std lib to choose the actual TLS config for connections
// by looking through the connection policies to find the first one that matches
tlsCfg := srv.TLSConnPolicies.TLSConfig(app.ctx)
+ srv.configureServer(srv.server)
// enable H2C if configured
if srv.protocol("h2c") {
diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go
index ca5a594..d9fe077 100644
--- a/modules/caddyhttp/server.go
+++ b/modules/caddyhttp/server.go
@@ -176,6 +176,11 @@ type Server struct {
shutdownAt time.Time
shutdownAtMu *sync.RWMutex
+
+ // registered callback functions
+ connStateFuncs []func(net.Conn, http.ConnState)
+ connContextFuncs []func(ctx context.Context, c net.Conn) context.Context
+ onShutdownFuncs []func()
}
// ServeHTTP is the entry point for all HTTP requests.
@@ -513,6 +518,51 @@ func (s *Server) serveHTTP3(hostport string, tlsCfg *tls.Config) error {
return nil
}
+// configureServer applies/binds the registered callback functions to the server.
+func (s *Server) configureServer(server *http.Server) {
+ for _, f := range s.connStateFuncs {
+ if server.ConnState != nil {
+ baseConnStateFunc := server.ConnState
+ server.ConnState = func(conn net.Conn, state http.ConnState) {
+ baseConnStateFunc(conn, state)
+ f(conn, state)
+ }
+ } else {
+ server.ConnState = f
+ }
+ }
+
+ for _, f := range s.connContextFuncs {
+ if server.ConnContext != nil {
+ baseConnContextFunc := server.ConnContext
+ server.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
+ return f(baseConnContextFunc(ctx, c), c)
+ }
+ } else {
+ server.ConnContext = f
+ }
+ }
+
+ for _, f := range s.onShutdownFuncs {
+ server.RegisterOnShutdown(f)
+ }
+}
+
+// RegisterConnState registers f to be invoked on s.ConnState.
+func (s *Server) RegisterConnState(f func(net.Conn, http.ConnState)) {
+ s.connStateFuncs = append(s.connStateFuncs, f)
+}
+
+// RegisterConnContext registers f to be invoked as part of s.ConnContext.
+func (s *Server) RegisterConnContext(f func(ctx context.Context, c net.Conn) context.Context) {
+ s.connContextFuncs = append(s.connContextFuncs, f)
+}
+
+// RegisterOnShutdown registers f to be invoked on server shutdown.
+func (s *Server) RegisterOnShutdown(f func()) {
+ s.onShutdownFuncs = append(s.onShutdownFuncs, f)
+}
+
// HTTPErrorConfig determines how to handle errors
// from the HTTP handlers.
type HTTPErrorConfig struct {