diff options
Diffstat (limited to 'modules/caddyhttp')
-rw-r--r-- | modules/caddyhttp/app.go | 5 | ||||
-rw-r--r-- | modules/caddyhttp/server.go | 50 |
2 files changed, 54 insertions, 1 deletions
diff --git a/modules/caddyhttp/app.go b/modules/caddyhttp/app.go index 3edc5b2..0253521 100644 --- a/modules/caddyhttp/app.go +++ b/modules/caddyhttp/app.go @@ -178,7 +178,9 @@ func (app *App) Provision(ctx caddy.Context) error { } // prepare each server + oldContext := ctx.Context for srvName, srv := range app.Servers { + ctx.Context = context.WithValue(oldContext, ServerCtxKey, srv) srv.name = srvName srv.tlsApp = app.tlsApp srv.events = eventsAppIface.(*caddyevents.App) @@ -293,7 +295,7 @@ func (app *App) Provision(ctx caddy.Context) error { srv.IdleTimeout = defaultIdleTimeout } } - + ctx.Context = oldContext return nil } @@ -365,6 +367,7 @@ func (app *App) Start() error { // this TLS config is used by the std lib to choose the actual TLS config for connections // by looking through the connection policies to find the first one that matches tlsCfg := srv.TLSConnPolicies.TLSConfig(app.ctx) + srv.configureServer(srv.server) // enable H2C if configured if srv.protocol("h2c") { diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index ca5a594..d9fe077 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -176,6 +176,11 @@ type Server struct { shutdownAt time.Time shutdownAtMu *sync.RWMutex + + // registered callback functions + connStateFuncs []func(net.Conn, http.ConnState) + connContextFuncs []func(ctx context.Context, c net.Conn) context.Context + onShutdownFuncs []func() } // ServeHTTP is the entry point for all HTTP requests. @@ -513,6 +518,51 @@ func (s *Server) serveHTTP3(hostport string, tlsCfg *tls.Config) error { return nil } +// configureServer applies/binds the registered callback functions to the server. +func (s *Server) configureServer(server *http.Server) { + for _, f := range s.connStateFuncs { + if server.ConnState != nil { + baseConnStateFunc := server.ConnState + server.ConnState = func(conn net.Conn, state http.ConnState) { + baseConnStateFunc(conn, state) + f(conn, state) + } + } else { + server.ConnState = f + } + } + + for _, f := range s.connContextFuncs { + if server.ConnContext != nil { + baseConnContextFunc := server.ConnContext + server.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + return f(baseConnContextFunc(ctx, c), c) + } + } else { + server.ConnContext = f + } + } + + for _, f := range s.onShutdownFuncs { + server.RegisterOnShutdown(f) + } +} + +// RegisterConnState registers f to be invoked on s.ConnState. +func (s *Server) RegisterConnState(f func(net.Conn, http.ConnState)) { + s.connStateFuncs = append(s.connStateFuncs, f) +} + +// RegisterConnContext registers f to be invoked as part of s.ConnContext. +func (s *Server) RegisterConnContext(f func(ctx context.Context, c net.Conn) context.Context) { + s.connContextFuncs = append(s.connContextFuncs, f) +} + +// RegisterOnShutdown registers f to be invoked on server shutdown. +func (s *Server) RegisterOnShutdown(f func()) { + s.onShutdownFuncs = append(s.onShutdownFuncs, f) +} + // HTTPErrorConfig determines how to handle errors // from the HTTP handlers. type HTTPErrorConfig struct { |