diff options
-rw-r--r-- | caddytest/integration/reverseproxy_test.go | 54 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/healthchecks.go | 3 |
2 files changed, 56 insertions, 1 deletions
diff --git a/caddytest/integration/reverseproxy_test.go b/caddytest/integration/reverseproxy_test.go index e838d86..e6aff87 100644 --- a/caddytest/integration/reverseproxy_test.go +++ b/caddytest/integration/reverseproxy_test.go @@ -436,3 +436,57 @@ func TestReverseProxyHealthCheckUnixSocket(t *testing.T) { tester.AssertGetResponse("http://localhost:9080/", 200, "Hello, World!") } + +func TestReverseProxyHealthCheckUnixSocketWithoutPort(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + tester := caddytest.NewTester(t) + f, err := ioutil.TempFile("", "*.sock") + if err != nil { + t.Errorf("failed to create TempFile: %s", err) + return + } + // a hack to get a file name within a valid path to use as socket + socketName := f.Name() + os.Remove(f.Name()) + + server := http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if strings.HasPrefix(req.URL.Path, "/health") { + w.Write([]byte("ok")) + return + } + w.Write([]byte("Hello, World!")) + }), + } + + unixListener, err := net.Listen("unix", socketName) + if err != nil { + t.Errorf("failed to listen on the socket: %s", err) + return + } + go server.Serve(unixListener) + t.Cleanup(func() { + server.Close() + }) + runtime.Gosched() // Allow other goroutines to run + + tester.InitServer(fmt.Sprintf(` + { + http_port 9080 + https_port 9443 + } + http://localhost:9080 { + reverse_proxy { + to unix/%s + + health_path /health + health_interval 2s + health_timeout 5s + } + } + `, socketName), "caddyfile") + + tester.AssertGetResponse("http://localhost:9080/", 200, "Hello, World!") +} diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go index 6f65866..8d5bd77 100644 --- a/modules/caddyhttp/reverseproxy/healthchecks.go +++ b/modules/caddyhttp/reverseproxy/healthchecks.go @@ -189,13 +189,14 @@ func (h *Handler) doActiveHealthCheckForAllHosts() { return } hostAddr := addr.JoinHostPort(0) + dialAddr := hostAddr if addr.IsUnixNetwork() { // this will be used as the Host portion of a http.Request URL, and // paths to socket files would produce an error when creating URL, // so use a fake Host value instead; unix sockets are usually local hostAddr = "localhost" } - err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: hostAddr}, hostAddr, upstream.Host) + err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: dialAddr}, hostAddr, upstream.Host) if err != nil { h.HealthChecks.Active.logger.Error("active health check failed", zap.String("address", hostAddr), |