summaryrefslogtreecommitdiff
path: root/modules/caddyhttp/reverseproxy/httptransport.go
diff options
context:
space:
mode:
Diffstat (limited to 'modules/caddyhttp/reverseproxy/httptransport.go')
-rw-r--r--modules/caddyhttp/reverseproxy/httptransport.go95
1 files changed, 85 insertions, 10 deletions
diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go
index ec5d2f2..9290f7e 100644
--- a/modules/caddyhttp/reverseproxy/httptransport.go
+++ b/modules/caddyhttp/reverseproxy/httptransport.go
@@ -28,10 +28,13 @@ import (
"strings"
"time"
- "github.com/caddyserver/caddy/v2"
- "github.com/caddyserver/caddy/v2/modules/caddytls"
+ "github.com/mastercactapus/proxyprotocol"
"go.uber.org/zap"
"golang.org/x/net/http2"
+
+ "github.com/caddyserver/caddy/v2"
+ "github.com/caddyserver/caddy/v2/modules/caddyhttp"
+ "github.com/caddyserver/caddy/v2/modules/caddytls"
)
func init() {
@@ -64,6 +67,10 @@ type HTTPTransport struct {
// Maximum number of connections per host. Default: 0 (no limit)
MaxConnsPerHost int `json:"max_conns_per_host,omitempty"`
+ // If non-empty, which PROXY protocol version to send when
+ // connecting to an upstream. Default: off.
+ ProxyProtocol string `json:"proxy_protocol,omitempty"`
+
// How long to wait before timing out trying to connect to
// an upstream. Default: `3s`.
DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`
@@ -172,12 +179,19 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
}
}
- // Set up the dialer to pull the correct information from the context
dialContext := func(ctx context.Context, network, address string) (net.Conn, error) {
- // the proper dialing information should be embedded into the request's context
+ // For unix socket upstreams, we need to recover the dial info from
+ // the request's context, because the Host on the request's URL
+ // will have been modified by directing the request, overwriting
+ // the unix socket filename.
+ // Also, we need to avoid overwriting the address at this point
+ // when not necessary, because http.ProxyFromEnvironment may have
+ // modified the address according to the user's env proxy config.
if dialInfo, ok := GetDialInfo(ctx); ok {
- network = dialInfo.Network
- address = dialInfo.Address
+ if strings.HasPrefix(dialInfo.Network, "unix") {
+ network = dialInfo.Network
+ address = dialInfo.Address
+ }
}
conn, err := dialer.DialContext(ctx, network, address)
@@ -188,8 +202,59 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
return nil, DialError{err}
}
- // if read/write timeouts are configured and this is a TCP connection, enforce the timeouts
- // by wrapping the connection with our own type
+ if h.ProxyProtocol != "" {
+ proxyProtocolInfo, ok := caddyhttp.GetVar(ctx, proxyProtocolInfoVarKey).(ProxyProtocolInfo)
+ if !ok {
+ return nil, fmt.Errorf("failed to get proxy protocol info from context")
+ }
+
+ // The src and dst have to be of the some address family. As we don't know the original
+ // dst address (it's kind of impossible to know) and this address is generelly of very
+ // little interest, we just set it to all zeros.
+ var destIP net.IP
+ switch {
+ case proxyProtocolInfo.AddrPort.Addr().Is4():
+ destIP = net.IPv4zero
+ case proxyProtocolInfo.AddrPort.Addr().Is6():
+ destIP = net.IPv6zero
+ default:
+ return nil, fmt.Errorf("unexpected remote addr type in proxy protocol info")
+ }
+
+ // TODO: We should probably migrate away from net.IP to use netip.Addr,
+ // but due to the upstream dependency, we can't do that yet.
+ switch h.ProxyProtocol {
+ case "v1":
+ header := proxyprotocol.HeaderV1{
+ SrcIP: net.IP(proxyProtocolInfo.AddrPort.Addr().AsSlice()),
+ SrcPort: int(proxyProtocolInfo.AddrPort.Port()),
+ DestIP: destIP,
+ DestPort: 0,
+ }
+ caddyCtx.Logger().Debug("sending proxy protocol header v1", zap.Any("header", header))
+ _, err = header.WriteTo(conn)
+ case "v2":
+ header := proxyprotocol.HeaderV2{
+ Command: proxyprotocol.CmdProxy,
+ Src: &net.TCPAddr{IP: net.IP(proxyProtocolInfo.AddrPort.Addr().AsSlice()), Port: int(proxyProtocolInfo.AddrPort.Port())},
+ Dest: &net.TCPAddr{IP: destIP, Port: 0},
+ }
+ caddyCtx.Logger().Debug("sending proxy protocol header v2", zap.Any("header", header))
+ _, err = header.WriteTo(conn)
+ default:
+ return nil, fmt.Errorf("unexpected proxy protocol version")
+ }
+
+ if err != nil {
+ // identify this error as one that occurred during
+ // dialing, which can be important when trying to
+ // decide whether to retry a request
+ return nil, DialError{err}
+ }
+ }
+
+ // if read/write timeouts are configured and this is a TCP connection,
+ // enforce the timeouts by wrapping the connection with our own type
if tcpConn, ok := conn.(*net.TCPConn); ok && (h.ReadTimeout > 0 || h.WriteTimeout > 0) {
conn = &tcpRWTimeoutConn{
TCPConn: tcpConn,
@@ -203,6 +268,7 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
}
rt := &http.Transport{
+ Proxy: http.ProxyFromEnvironment,
DialContext: dialContext,
MaxConnsPerHost: h.MaxConnsPerHost,
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
@@ -231,6 +297,14 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
rt.IdleConnTimeout = time.Duration(h.KeepAlive.IdleConnTimeout)
}
+ // The proxy protocol header can only be sent once right after opening the connection.
+ // So single connection must not be used for multiple requests, which can potentially
+ // come from different clients.
+ if !rt.DisableKeepAlives && h.ProxyProtocol != "" {
+ caddyCtx.Logger().Warn("disabling keepalives, they are incompatible with using PROXY protocol")
+ rt.DisableKeepAlives = true
+ }
+
if h.Compression != nil {
rt.DisableCompression = !*h.Compression
}
@@ -452,10 +526,11 @@ func (t TLSConfig) MakeTLSClientConfig(ctx caddy.Context) (*tls.Config, error) {
return nil, fmt.Errorf("managing client certificate: %v", err)
}
cfg.GetClientCertificate = func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
- certs := tlsApp.AllMatchingCertificates(t.ClientCertificateAutomate)
+ certs := caddytls.AllMatchingCertificates(t.ClientCertificateAutomate)
var err error
for _, cert := range certs {
- err = cri.SupportsCertificate(&cert.Certificate)
+ certCertificate := cert.Certificate // avoid taking address of iteration variable (gosec warning)
+ err = cri.SupportsCertificate(&certCertificate)
if err == nil {
return &cert.Certificate, nil
}