package caddyhttp import ( "context" "crypto/tls" weakrand "math/rand" "net" "net/http" "sync/atomic" "time" "golang.org/x/net/http2" ) // http2Listener wraps the listener to solve the following problems: // 1. server h2 natively without using h2c hack when listener handles tls connection but // don't return *tls.Conn // 2. graceful shutdown. the shutdown logic is copied from stdlib http.Server, it's an extra maintenance burden but // whatever, the shutdown logic maybe extracted to be used with h2c graceful shutdown. http2.Server supports graceful shutdown // sending GO_AWAY frame to connected clients, but doesn't track connection status. It requires explicit call of http2.ConfigureServer type http2Listener struct { cnt uint64 net.Listener server *http.Server h2server *http2.Server } type connectionStateConn interface { net.Conn ConnectionState() tls.ConnectionState } func (h *http2Listener) Accept() (net.Conn, error) { for { conn, err := h.Listener.Accept() if err != nil { return nil, err } if csc, ok := conn.(connectionStateConn); ok { // *tls.Conn will return empty string because it's only populated after handshake is complete if csc.ConnectionState().NegotiatedProtocol == http2.NextProtoTLS { go h.serveHttp2(csc) continue } } return conn, nil } } func (h *http2Listener) serveHttp2(csc connectionStateConn) { atomic.AddUint64(&h.cnt, 1) h.runHook(csc, http.StateNew) defer func() { csc.Close() atomic.AddUint64(&h.cnt, ^uint64(0)) h.runHook(csc, http.StateClosed) }() h.h2server.ServeConn(csc, &http2.ServeConnOpts{ Context: h.server.ConnContext(context.Background(), csc), BaseConfig: h.server, Handler: h.server.Handler, }) } const shutdownPollIntervalMax = 500 * time.Millisecond func (h *http2Listener) Shutdown(ctx context.Context) error { pollIntervalBase := time.Millisecond nextPollInterval := func() time.Duration { // Add 10% jitter. //nolint:gosec interval := pollIntervalBase + time.Duration(weakrand.Intn(int(pollIntervalBase/10))) // Double and clamp for next time. pollIntervalBase *= 2 if pollIntervalBase > shutdownPollIntervalMax { pollIntervalBase = shutdownPollIntervalMax } return interval } timer := time.NewTimer(nextPollInterval()) defer timer.Stop() for { if atomic.LoadUint64(&h.cnt) == 0 { return nil } select { case <-ctx.Done(): return ctx.Err() case <-timer.C: timer.Reset(nextPollInterval()) } } } func (h *http2Listener) runHook(conn net.Conn, state http.ConnState) { if h.server.ConnState != nil { h.server.ConnState(conn, state) } }