summaryrefslogtreecommitdiff
path: root/modules/caddyhttp
diff options
context:
space:
mode:
authorMatthew Holt <mholt@users.noreply.github.com>2019-10-16 15:18:02 -0600
committerMatthew Holt <mholt@users.noreply.github.com>2019-10-16 15:18:02 -0600
commita458544d9f6e6aaf72aeab0454acfa482880d3d6 (patch)
tree7fa6b1806809ede3f81e9949ec819864f294526a /modules/caddyhttp
parent2f91b44587fe487b6a05c1c7e56833247e4b8c79 (diff)
Minor enhancements/fixes to rewrite directive and template virt req's
Diffstat (limited to 'modules/caddyhttp')
-rw-r--r--modules/caddyhttp/rewrite/caddyfile.go3
-rw-r--r--modules/caddyhttp/templates/tplcontext.go20
2 files changed, 20 insertions, 3 deletions
diff --git a/modules/caddyhttp/rewrite/caddyfile.go b/modules/caddyhttp/rewrite/caddyfile.go
index cb65d43..a1fc874 100644
--- a/modules/caddyhttp/rewrite/caddyfile.go
+++ b/modules/caddyhttp/rewrite/caddyfile.go
@@ -31,6 +31,9 @@ func init() {
func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
var rewr Rewrite
for h.Next() {
+ if !h.NextArg() {
+ return nil, h.ArgErr()
+ }
rewr.URI = h.Val()
if h.NextArg() {
return nil, h.ArgErr()
diff --git a/modules/caddyhttp/templates/tplcontext.go b/modules/caddyhttp/templates/tplcontext.go
index 40d1370..e3909b2 100644
--- a/modules/caddyhttp/templates/tplcontext.go
+++ b/modules/caddyhttp/templates/tplcontext.go
@@ -22,6 +22,7 @@ import (
"net"
"net/http"
"path"
+ "strconv"
"strings"
"sync"
@@ -79,8 +80,18 @@ func (c templateContext) Include(filename string, args ...interface{}) (template
// are NOT escaped, so you should only include trusted resources.
// If it is not trusted, be sure to use escaping functions yourself.
func (c templateContext) HTTPInclude(uri string) (template.HTML, error) {
- if c.Req.Header.Get(recursionPreventionHeader) == "1" {
- return "", fmt.Errorf("virtual request cycle")
+ // prevent virtual request loops by counting how many levels
+ // deep we are; and if we get too deep, return an error
+ recursionCount := 1
+ if numStr := c.Req.Header.Get(recursionPreventionHeader); numStr != "" {
+ num, err := strconv.Atoi(numStr)
+ if err != nil {
+ return "", fmt.Errorf("parsing %s: %v", recursionPreventionHeader, err)
+ }
+ if num >= 3 {
+ return "", fmt.Errorf("virtual request cycle")
+ }
+ recursionCount = num + 1
}
buf := bufPool.Get().(*bytes.Buffer)
@@ -91,7 +102,10 @@ func (c templateContext) HTTPInclude(uri string) (template.HTML, error) {
if err != nil {
return "", err
}
- virtReq.Header.Set(recursionPreventionHeader, "1")
+ virtReq.Host = c.Req.Host
+ virtReq.Header = c.Req.Header.Clone()
+ virtReq.Trailer = c.Req.Trailer.Clone()
+ virtReq.Header.Set(recursionPreventionHeader, strconv.Itoa(recursionCount))
vrw := &virtualResponseWriter{body: buf, header: make(http.Header)}
server := c.Req.Context().Value(caddyhttp.ServerCtxKey).(http.Handler)