From ee7c92ec9b57c671c9091ff993b1a24251020c25 Mon Sep 17 00:00:00 2001 From: Francis Lavoie Date: Mon, 14 Nov 2022 11:38:02 -0500 Subject: reverseproxy: Mask the WS close message when we're the client (#5199) * reverseproxy: Mask the WS close message when we're the client * weakrand * Bump golangci-lint version so path ignores work on Windows * gofmt * ugh, gofmt everything, I guess --- .../caddyhttp/reverseproxy/selectionpolicies.go | 3 +- modules/caddyhttp/reverseproxy/streaming.go | 118 ++++++++++++++++++--- 2 files changed, 104 insertions(+), 17 deletions(-) (limited to 'modules/caddyhttp/reverseproxy') diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go index 5fc7136..2de830c 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go @@ -418,7 +418,8 @@ func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http } // UnmarshalCaddyfile sets up the module from Caddyfile tokens. Syntax: -// lb_policy cookie [ []] +// +// lb_policy cookie [ []] // // By default name is `lb` func (s *CookieHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 01d865d..834cb9e 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -20,12 +20,13 @@ package reverseproxy import ( "context" - "encoding/binary" "io" + weakrand "math/rand" "mime" "net/http" "sync" "time" + "unsafe" "go.uber.org/zap" "golang.org/x/net/http/httpguts" @@ -103,16 +104,19 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite // with the backend, are both closed in the event of a server shutdown. This // is done by registering them. We also try to gracefully close connections // we recognize as websockets. - gracefulClose := func(conn io.ReadWriteCloser) func() error { + // We need to make sure the client connection messages (i.e. to upstream) + // are masked, so we need to know whether the connection is considered the + // server or the client side of the proxy. + gracefulClose := func(conn io.ReadWriteCloser, isClient bool) func() error { if isWebsocket(req) { return func() error { - return writeCloseControl(conn) + return writeCloseControl(conn, isClient) } } return nil } - deleteFrontConn := h.registerConnection(conn, gracefulClose(conn)) - deleteBackConn := h.registerConnection(backConn, gracefulClose(backConn)) + deleteFrontConn := h.registerConnection(conn, gracefulClose(conn, false)) + deleteBackConn := h.registerConnection(backConn, gracefulClose(backConn, true)) defer deleteFrontConn() defer deleteBackConn() @@ -248,27 +252,108 @@ func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func // writeCloseControl sends a best-effort Close control message to the given // WebSocket connection. Thanks to @pascaldekloe who provided inspiration // from his simple implementation of this I was able to learn from at: -// github.com/pascaldekloe/websocket. -func writeCloseControl(conn io.Writer) error { +// github.com/pascaldekloe/websocket. Further work for handling masking +// taken from github.com/gorilla/websocket. +func writeCloseControl(conn io.Writer, isClient bool) error { + // Sources: // https://github.com/pascaldekloe/websocket/blob/32050af67a5d/websocket.go#L119 + // https://github.com/gorilla/websocket/blob/v1.5.0/conn.go#L413 + // For now, we're not using a reason. We might later, though. + // The code handling the reason is left in var reason string // max 123 bytes (control frame payload limit is 125; status code takes 2) - const goingAway uint16 = 1001 - // TODO: we might need to ensure we are the exclusive writer by this point (io.Copy is stopped)? - var writeBuf [127]byte const closeMessage = 8 - const finalBit = 1 << 7 - writeBuf[0] = closeMessage | finalBit - writeBuf[1] = byte(len(reason) + 2) - binary.BigEndian.PutUint16(writeBuf[2:4], goingAway) - copy(writeBuf[4:], reason) + const finalBit = 1 << 7 // Frame header byte 0 bits from Section 5.2 of RFC 6455 + const maskBit = 1 << 7 // Frame header byte 1 bits from Section 5.2 of RFC 6455 + const goingAwayUpper uint8 = 1001 >> 8 + const goingAwayLower uint8 = 1001 & 0xff + + b0 := byte(closeMessage) | finalBit + b1 := byte(len(reason) + 2) + if isClient { + b1 |= maskBit + } + + buf := make([]byte, 0, 127) + buf = append(buf, b0, b1) + msgLength := 4 + len(reason) + + // Both branches below append the "going away" code and reason + appendMessage := func(buf []byte) []byte { + buf = append(buf, goingAwayUpper, goingAwayLower) + buf = append(buf, []byte(reason)...) + return buf + } + + // When we're the client, we need to mask the message as per + // https://www.rfc-editor.org/rfc/rfc6455#section-5.3 + if isClient { + key := newMaskKey() + buf = append(buf, key[:]...) + msgLength += len(key) + buf = appendMessage(buf) + maskBytes(key, 0, buf[2+len(key):]) + } else { + buf = appendMessage(buf) + } // simply best-effort, but return error for logging purposes - _, err := conn.Write(writeBuf[:4+len(reason)]) + // TODO: we might need to ensure we are the exclusive writer by this point (io.Copy is stopped)? + _, err := conn.Write(buf[:msgLength]) return err } +// Copied from https://github.com/gorilla/websocket/blob/v1.5.0/mask.go +func maskBytes(key [4]byte, pos int, b []byte) int { + // Mask one byte at a time for small buffers. + if len(b) < 2*wordSize { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 + } + + // Mask one byte at a time to word boundary. + if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { + n = wordSize - n + for i := range b[:n] { + b[i] ^= key[pos&3] + pos++ + } + b = b[n:] + } + + // Create aligned word size key. + var k [wordSize]byte + for i := range k { + k[i] = key[(pos+i)&3] + } + kw := *(*uintptr)(unsafe.Pointer(&k)) + + // Mask one word at a time. + n := (len(b) / wordSize) * wordSize + for i := 0; i < n; i += wordSize { + *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw + } + + // Mask one byte at a time for remaining bytes. + b = b[n:] + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + + return pos & 3 +} + +// Copied from https://github.com/gorilla/websocket/blob/v1.5.0/conn.go#L184 +func newMaskKey() [4]byte { + n := weakrand.Uint32() + return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} +} + // isWebsocket returns true if r looks to be an upgrade request for WebSockets. // It is a fairly naive check. func isWebsocket(r *http.Request) bool { @@ -364,3 +449,4 @@ var streamingBufPool = sync.Pool{ } const defaultBufferSize = 32 * 1024 +const wordSize = int(unsafe.Sizeof(uintptr(0))) -- cgit v1.2.3