summaryrefslogtreecommitdiff
path: root/modules/caddyhttp/templates
diff options
context:
space:
mode:
Diffstat (limited to 'modules/caddyhttp/templates')
-rw-r--r--modules/caddyhttp/templates/tplcontext.go20
1 files changed, 17 insertions, 3 deletions
diff --git a/modules/caddyhttp/templates/tplcontext.go b/modules/caddyhttp/templates/tplcontext.go
index 40d1370..e3909b2 100644
--- a/modules/caddyhttp/templates/tplcontext.go
+++ b/modules/caddyhttp/templates/tplcontext.go
@@ -22,6 +22,7 @@ import (
"net"
"net/http"
"path"
+ "strconv"
"strings"
"sync"
@@ -79,8 +80,18 @@ func (c templateContext) Include(filename string, args ...interface{}) (template
// are NOT escaped, so you should only include trusted resources.
// If it is not trusted, be sure to use escaping functions yourself.
func (c templateContext) HTTPInclude(uri string) (template.HTML, error) {
- if c.Req.Header.Get(recursionPreventionHeader) == "1" {
- return "", fmt.Errorf("virtual request cycle")
+ // prevent virtual request loops by counting how many levels
+ // deep we are; and if we get too deep, return an error
+ recursionCount := 1
+ if numStr := c.Req.Header.Get(recursionPreventionHeader); numStr != "" {
+ num, err := strconv.Atoi(numStr)
+ if err != nil {
+ return "", fmt.Errorf("parsing %s: %v", recursionPreventionHeader, err)
+ }
+ if num >= 3 {
+ return "", fmt.Errorf("virtual request cycle")
+ }
+ recursionCount = num + 1
}
buf := bufPool.Get().(*bytes.Buffer)
@@ -91,7 +102,10 @@ func (c templateContext) HTTPInclude(uri string) (template.HTML, error) {
if err != nil {
return "", err
}
- virtReq.Header.Set(recursionPreventionHeader, "1")
+ virtReq.Host = c.Req.Host
+ virtReq.Header = c.Req.Header.Clone()
+ virtReq.Trailer = c.Req.Trailer.Clone()
+ virtReq.Header.Set(recursionPreventionHeader, strconv.Itoa(recursionCount))
vrw := &virtualResponseWriter{body: buf, header: make(http.Header)}
server := c.Req.Context().Value(caddyhttp.ServerCtxKey).(http.Handler)