From d8d87a378f37d31cfe6502cc66ac3c95fc799489 Mon Sep 17 00:00:00 2001 From: WeidiDeng Date: Tue, 11 Apr 2023 01:05:02 +0800 Subject: caddyhttp: Serve http2 when listener wrapper doesn't return *tls.Conn (#4929) * Serve http2 when listener wrapper doesn't return *tls.Conn * close conn when h2server serveConn returns * merge from upstream * rebase from latest * run New and Closed ConnState hook for h2 conns * go fmt * fix lint * Add comments * reorder import --- modules/caddyhttp/app.go | 42 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 5 deletions(-) (limited to 'modules/caddyhttp/app.go') diff --git a/modules/caddyhttp/app.go b/modules/caddyhttp/app.go index ceb62f4..53b5782 100644 --- a/modules/caddyhttp/app.go +++ b/modules/caddyhttp/app.go @@ -357,6 +357,14 @@ func (app *App) Start() error { MaxHeaderBytes: srv.MaxHeaderBytes, Handler: srv, ErrorLog: serverLogger, + ConnContext: func(ctx context.Context, c net.Conn) context.Context { + return context.WithValue(ctx, ConnCtxKey, c) + }, + } + h2server := &http2.Server{ + NewWriteScheduler: func() http2.WriteScheduler { + return http2.NewPriorityWriteScheduler(nil) + }, } // disable HTTP/2, which we enabled by default during provisioning @@ -378,6 +386,9 @@ func (app *App) Start() error { } } } + } else { + //nolint:errcheck + http2.ConfigureServer(srv.server, h2server) } // this TLS config is used by the std lib to choose the actual TLS config for connections @@ -387,9 +398,6 @@ func (app *App) Start() error { // enable H2C if configured if srv.protocol("h2c") { - h2server := &http2.Server{ - IdleTimeout: time.Duration(srv.IdleTimeout), - } srv.server.Handler = h2c.NewHandler(srv, h2server) } @@ -456,6 +464,17 @@ func (app *App) Start() error { ln = srv.listenerWrappers[i].WrapListener(ln) } + // handle http2 if use tls listener wrapper + if useTLS { + http2lnWrapper := &http2Listener{ + Listener: ln, + server: srv.server, + h2server: h2server, + } + srv.h2listeners = append(srv.h2listeners, http2lnWrapper) + ln = http2lnWrapper + } + // if binding to port 0, the OS chooses a port for us; // but the user won't know the port unless we print it if !listenAddr.IsUnixNetwork() && listenAddr.StartPort == 0 && listenAddr.EndPort == 0 { @@ -585,12 +604,25 @@ func (app *App) Stop() error { zap.Strings("addresses", server.Listen)) } } + stopH2Listener := func(server *Server) { + defer finishedShutdown.Done() + startedShutdown.Done() + + for i, s := range server.h2listeners { + if err := s.Shutdown(ctx); err != nil { + app.logger.Error("http2 listener shutdown", + zap.Error(err), + zap.Int("index", i)) + } + } + } for _, server := range app.Servers { - startedShutdown.Add(2) - finishedShutdown.Add(2) + startedShutdown.Add(3) + finishedShutdown.Add(3) go stopServer(server) go stopH3Server(server) + go stopH2Listener(server) } // block until all the goroutines have been run by the scheduler; -- cgit v1.2.3