summaryrefslogtreecommitdiff
path: root/modules/caddyhttp/http2listener.go
blob: 51b356a7779c66756cb193ed644a7fac86736306 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
package caddyhttp

import (
	"context"
	"crypto/tls"
	weakrand "math/rand"
	"net"
	"net/http"
	"sync/atomic"
	"time"

	"golang.org/x/net/http2"
)

// http2Listener wraps a listener to solve the following problems:
//
//  1. Serving h2 natively, without the h2c hack, when the wrapped listener
//     handles TLS itself but does not return a *tls.Conn.
//  2. Graceful shutdown. The shutdown logic is copied from the stdlib
//     http.Server; that is an extra maintenance burden, but the logic may
//     later be extracted and reused for h2c graceful shutdown. http2.Server
//     supports graceful shutdown by sending a GO_AWAY frame to connected
//     clients, but it does not track connection status, and it requires an
//     explicit call to http2.ConfigureServer.
type http2Listener struct {
	// cnt is the number of HTTP/2 connections currently tracked as in
	// flight. It is only read and written through sync/atomic.
	// NOTE(review): presumably kept as the first field so it stays 64-bit
	// aligned for atomic access on 32-bit platforms — confirm before reordering.
	cnt uint64
	// Listener is the wrapped listener that connections are accepted from.
	net.Listener
	// server supplies the handler, the connection context and the ConnState
	// hook; h2server is what actually serves the HTTP/2 connections.
	server   *http.Server
	h2server *http2.Server
}

// connectionStateConn is a net.Conn that can also report its TLS connection
// state. Accept uses it to detect connections that have already negotiated
// HTTP/2 even though they are not a *tls.Conn.
type connectionStateConn interface {
	net.Conn
	// ConnectionState returns the TLS state of the connection, including the
	// negotiated application protocol.
	ConnectionState() tls.ConnectionState
}

// Accept returns the next connection that should be handled by the regular
// HTTP server loop.
//
// Connections that expose their TLS state and have already negotiated HTTP/2
// are not returned: each one is served on its own goroutine by h.h2server and
// Accept keeps waiting for the next connection. Every other connection is
// returned unchanged. Errors from the wrapped listener are returned as-is.
func (h *http2Listener) Accept() (net.Conn, error) {
	for {
		conn, err := h.Listener.Accept()
		if err != nil {
			return nil, err
		}

		if csc, ok := conn.(connectionStateConn); ok {
			// *tls.Conn will return empty string because it's only populated after handshake is complete
			if csc.ConnectionState().NegotiatedProtocol == http2.NextProtoTLS {
				// Count the connection before the goroutine starts. If the
				// increment happened inside the goroutine, a concurrent
				// Shutdown could observe a zero count between the `go`
				// statement and the increment and report completion while
				// this connection is about to be served.
				atomic.AddUint64(&h.cnt, 1)
				go h.serveHttp2(csc)
				continue
			}
		}

		return conn, nil
	}
}

// serveHttp2 serves a single HTTP/2 connection with h.h2server and blocks
// until the connection is finished.
//
// The caller must have incremented h.cnt for this connection before calling
// serveHttp2; the counter is decremented here once serving ends. The server's
// ConnState hook, if any, is invoked with http.StateNew before serving and
// http.StateClosed after the connection has been closed.
func (h *http2Listener) serveHttp2(csc connectionStateConn) {
	h.runHook(csc, http.StateNew)
	defer func() {
		// Best-effort close: serving is over either way and there is no
		// caller to report a close error to.
		_ = csc.Close()
		// Adding ^uint64(0) is the sync/atomic idiom for subtracting one.
		atomic.AddUint64(&h.cnt, ^uint64(0))
		h.runHook(csc, http.StateClosed)
	}()

	// ConnContext is an optional field of http.Server; calling it while nil
	// would panic, so fall back to the plain background context.
	ctx := context.Background()
	if h.server.ConnContext != nil {
		ctx = h.server.ConnContext(ctx, csc)
	}

	h.h2server.ServeConn(csc, &http2.ServeConnOpts{
		Context:    ctx,
		BaseConfig: h.server,
		Handler:    h.server.Handler,
	})
}

// shutdownPollIntervalMax caps the delay between two consecutive checks of
// the connection count while Shutdown is waiting.
const shutdownPollIntervalMax = 500 * time.Millisecond

// Shutdown blocks until the connection counter reaches zero or ctx is done.
//
// The counter is checked immediately and then re-checked each time a polling
// timer fires. The polling delay starts at one millisecond plus up to 10%
// random jitter, and its base doubles after every computation until it is
// clamped to shutdownPollIntervalMax.
//
// It returns nil once the counter is zero, or ctx.Err() if ctx is done first.
func (h *http2Listener) Shutdown(ctx context.Context) error {
	backoff := time.Millisecond
	nextDelay := func() time.Duration {
		// Up to 10% random jitter on top of the current backoff.
		//nolint:gosec
		jitter := time.Duration(weakrand.Intn(int(backoff / 10)))
		delay := backoff + jitter

		// Grow exponentially for the next call, capped at the maximum.
		if backoff *= 2; backoff > shutdownPollIntervalMax {
			backoff = shutdownPollIntervalMax
		}
		return delay
	}

	poll := time.NewTimer(nextDelay())
	defer poll.Stop()

	for atomic.LoadUint64(&h.cnt) != 0 {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-poll.C:
			poll.Reset(nextDelay())
		}
	}
	return nil
}

// runHook reports a connection state change to the ConnState callback
// configured on the wrapped http.Server. It does nothing when no callback is
// set.
func (h *http2Listener) runHook(conn net.Conn, state http.ConnState) {
	hook := h.server.ConnState
	if hook == nil {
		return
	}
	hook(conn, state)
}