summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cmd/caddy/main.go (renamed from cmd/caddy2/main.go)1
-rw-r--r--modules/caddyhttp/caddyhttp.go2
-rw-r--r--modules/caddyhttp/encode/encode.go12
-rw-r--r--modules/caddyhttp/routes.go10
-rw-r--r--modules/caddyhttp/server.go13
-rw-r--r--modules/caddyhttp/templates/templates.go146
-rw-r--r--modules/caddyhttp/templates/tplcontext.go413
-rw-r--r--modules/caddyhttp/templates/tplcontext_test.go420
-rw-r--r--modules/caddytls/matchers.go2
9 files changed, 1006 insertions, 13 deletions
diff --git a/cmd/caddy2/main.go b/cmd/caddy/main.go
index c2320ef..463b1b9 100644
--- a/cmd/caddy2/main.go
+++ b/cmd/caddy/main.go
@@ -16,6 +16,7 @@ import (
_ "github.com/caddyserver/caddy/modules/caddyhttp/requestbody"
_ "github.com/caddyserver/caddy/modules/caddyhttp/reverseproxy"
_ "github.com/caddyserver/caddy/modules/caddyhttp/rewrite"
+ _ "github.com/caddyserver/caddy/modules/caddyhttp/templates"
_ "github.com/caddyserver/caddy/modules/caddytls"
_ "github.com/caddyserver/caddy/modules/caddytls/standardstek"
)
diff --git a/modules/caddyhttp/caddyhttp.go b/modules/caddyhttp/caddyhttp.go
index 50c0316..ffd7d0f 100644
--- a/modules/caddyhttp/caddyhttp.go
+++ b/modules/caddyhttp/caddyhttp.go
@@ -33,7 +33,7 @@ func init() {
type App struct {
HTTPPort int `json:"http_port,omitempty"`
HTTPSPort int `json:"https_port,omitempty"`
- GracePeriod caddy.Duration `json:"grace_period,omitempty"`
+ GracePeriod caddy.Duration `json:"grace_period,omitempty"`
Servers map[string]*Server `json:"servers,omitempty"`
servers []*http.Server
diff --git a/modules/caddyhttp/encode/encode.go b/modules/caddyhttp/encode/encode.go
index cf658e3..e20667f 100644
--- a/modules/caddyhttp/encode/encode.go
+++ b/modules/caddyhttp/encode/encode.go
@@ -148,7 +148,7 @@ func (rw *responseWriter) Write(p []byte) (int, error) {
return n, err
}
-// init should be called once we know we are writing an encoded response.
+// init should be called before we write a response, if rw.buf is not nil.
func (rw *responseWriter) init() {
if rw.Header().Get("Content-Encoding") == "" && rw.buf.Len() >= rw.config.MinLength {
rw.w = rw.config.writerPools[rw.encodingName].Get().(Encoder)
@@ -164,7 +164,13 @@ func (rw *responseWriter) init() {
// deallocates any active resources.
func (rw *responseWriter) Close() error {
var err error
- if rw.buf != nil {
+ // only attempt to write the remaining buffered response
+ // if there are any bytes left to write; otherwise, if
+ // the handler above us returned an error without writing
+ // anything, we'd write to the response when we instead
+ // should simply let the error propagate back down; this
+ // is why the check for rw.buf.Len() > 0 is crucial
+ if rw.buf != nil && rw.buf.Len() > 0 {
rw.init()
p := rw.buf.Bytes()
defer func() {
@@ -280,7 +286,7 @@ const defaultMinLength = 512
// Interface guards
var (
- _ caddy.Provisioner = (*Encode)(nil)
+ _ caddy.Provisioner = (*Encode)(nil)
_ caddyhttp.MiddlewareHandler = (*Encode)(nil)
_ caddyhttp.HTTPInterfaces = (*responseWriter)(nil)
)
diff --git a/modules/caddyhttp/routes.go b/modules/caddyhttp/routes.go
index 596149d..8033b91 100644
--- a/modules/caddyhttp/routes.go
+++ b/modules/caddyhttp/routes.go
@@ -186,17 +186,21 @@ type middlewareResponseWriter struct {
func (mrw middlewareResponseWriter) WriteHeader(statusCode int) {
if !mrw.allowWrites {
- panic("WriteHeader: middleware cannot write to the response")
+ // technically, this is not true: middleware can write headers,
+ // but only after the responder handler has returned; either the
+ // responder did nothing with the response (sad face), or the
+ // middleware wrapped the response and deferred the write
+ panic("WriteHeader: middleware cannot write response headers")
}
mrw.ResponseWriterWrapper.WriteHeader(statusCode)
}
func (mrw middlewareResponseWriter) Write(b []byte) (int, error) {
if !mrw.allowWrites {
- panic("Write: middleware cannot write to the response")
+ panic("Write: middleware cannot write to the response before responder")
}
return mrw.ResponseWriterWrapper.Write(b)
}
// Interface guard
-var _ HTTPInterfaces = middlewareResponseWriter{}
+var _ HTTPInterfaces = (*middlewareResponseWriter)(nil)
diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go
index 3de82f6..61b5631 100644
--- a/modules/caddyhttp/server.go
+++ b/modules/caddyhttp/server.go
@@ -15,10 +15,10 @@ import (
// Server is an HTTP server.
type Server struct {
Listen []string `json:"listen,omitempty"`
- ReadTimeout caddy.Duration `json:"read_timeout,omitempty"`
- ReadHeaderTimeout caddy.Duration `json:"read_header_timeout,omitempty"`
- WriteTimeout caddy.Duration `json:"write_timeout,omitempty"`
- IdleTimeout caddy.Duration `json:"idle_timeout,omitempty"`
+ ReadTimeout caddy.Duration `json:"read_timeout,omitempty"`
+ ReadHeaderTimeout caddy.Duration `json:"read_header_timeout,omitempty"`
+ WriteTimeout caddy.Duration `json:"write_timeout,omitempty"`
+ IdleTimeout caddy.Duration `json:"idle_timeout,omitempty"`
MaxHeaderBytes int `json:"max_header_bytes,omitempty"`
Routes RouteList `json:"routes,omitempty"`
Errors *httpErrorConfig `json:"errors,omitempty"`
@@ -40,6 +40,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// set up the context for the request
repl := caddy.NewReplacer()
ctx := context.WithValue(r.Context(), caddy.ReplacerCtxKey, repl)
+ ctx = context.WithValue(ctx, ServerCtxKey, s)
ctx = context.WithValue(ctx, TableCtxKey, make(map[string]interface{})) // TODO: Implement this
r = r.WithContext(ctx)
@@ -126,5 +127,7 @@ type httpErrorConfig struct {
// the logging configuration first.
}
-// TableCtxKey is the context key for the request's variable table.
+const ServerCtxKey caddy.CtxKey = "server"
+
+// TableCtxKey is the context key for the request's variable table. TODO: implement this
const TableCtxKey caddy.CtxKey = "table"
diff --git a/modules/caddyhttp/templates/templates.go b/modules/caddyhttp/templates/templates.go
new file mode 100644
index 0000000..56c3b66
--- /dev/null
+++ b/modules/caddyhttp/templates/templates.go
@@ -0,0 +1,146 @@
+package templates
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "strconv"
+ "text/template"
+
+ "github.com/caddyserver/caddy"
+ "github.com/caddyserver/caddy/modules/caddyhttp"
+)
+
+func init() {
+ caddy.RegisterModule(caddy.Module{
+ Name: "http.middleware.templates",
+ New: func() interface{} { return new(Templates) },
+ })
+}
+
+// Templates is a middleware which execute response bodies as templates.
+type Templates struct {
+ FileRoot string `json:"file_root,omitempty"`
+ Delimiters []string `json:"delimiters,omitempty"`
+}
+
+// Validate ensures t has a valid configuration.
+func (t *Templates) Validate() error {
+ if len(t.Delimiters) != 0 && len(t.Delimiters) != 2 {
+ return fmt.Errorf("delimiters must consist of exactly two elements: opening and closing")
+ }
+ return nil
+}
+
+func (t *Templates) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
+ buf := bufPool.Get().(*bytes.Buffer)
+ buf.Reset()
+ defer bufPool.Put(buf)
+
+ wb := &responseBuffer{
+ ResponseWriterWrapper: &caddyhttp.ResponseWriterWrapper{ResponseWriter: w},
+ buf: buf,
+ }
+
+ err := next.ServeHTTP(wb, r)
+ if err != nil {
+ return err
+ }
+
+ err = t.executeTemplate(wb, r)
+ if err != nil {
+ return err
+ }
+
+ w.Header().Set("Content-Length", strconv.Itoa(wb.buf.Len()))
+ w.Header().Del("Accept-Ranges") // we don't know ranges for dynamically-created content
+ w.Header().Del("Etag") // don't know a way to quickly generate etag for dynamic content
+ w.Header().Del("Last-Modified") // useless for dynamic content since it's always changing
+
+ w.WriteHeader(wb.statusCode)
+ io.Copy(w, wb.buf)
+
+ return nil
+}
+
+// executeTemplate executes the template contianed
+// in wb.buf and replaces it with the results.
+func (t *Templates) executeTemplate(wb *responseBuffer, r *http.Request) error {
+ tpl := template.New(r.URL.Path)
+
+ if len(t.Delimiters) == 2 {
+ tpl.Delims(t.Delimiters[0], t.Delimiters[1])
+ }
+
+ parsedTpl, err := tpl.Parse(wb.buf.String())
+ if err != nil {
+ return caddyhttp.Error(http.StatusInternalServerError, err)
+ }
+
+ var fs http.FileSystem
+ if t.FileRoot != "" {
+ fs = http.Dir(t.FileRoot)
+ }
+ ctx := &templateContext{
+ Root: fs,
+ Req: r,
+ RespHeader: tplWrappedHeader{wb.Header()},
+ }
+
+ wb.buf.Reset() // reuse buffer for output
+ err = parsedTpl.Execute(wb.buf, ctx)
+ if err != nil {
+ return caddyhttp.Error(http.StatusInternalServerError, err)
+ }
+
+ return nil
+}
+
+// responseBuffer buffers the response so that it can be
+// executed as a template.
+type responseBuffer struct {
+ *caddyhttp.ResponseWriterWrapper
+ wroteHeader bool
+ statusCode int
+ buf *bytes.Buffer
+}
+
+func (rb *responseBuffer) WriteHeader(statusCode int) {
+ if rb.wroteHeader {
+ return
+ }
+ rb.statusCode = statusCode
+ rb.wroteHeader = true
+}
+
+func (rb *responseBuffer) Write(data []byte) (int, error) {
+ rb.WriteHeader(http.StatusOK)
+ return rb.buf.Write(data)
+}
+
+// virtualResponseWriter is used in virtualized HTTP requests.
+type virtualResponseWriter struct {
+ status int
+ header http.Header
+ body *bytes.Buffer
+}
+
+func (vrw *virtualResponseWriter) Header() http.Header {
+ return vrw.header
+}
+
+func (vrw *virtualResponseWriter) WriteHeader(statusCode int) {
+ vrw.status = statusCode
+}
+
+func (vrw *virtualResponseWriter) Write(data []byte) (int, error) {
+ return vrw.body.Write(data)
+}
+
+// Interface guards
+var (
+ _ caddy.Validator = (*Templates)(nil)
+ _ caddyhttp.MiddlewareHandler = (*Templates)(nil)
+ _ caddyhttp.HTTPInterfaces = (*responseBuffer)(nil)
+)
diff --git a/modules/caddyhttp/templates/tplcontext.go b/modules/caddyhttp/templates/tplcontext.go
new file mode 100644
index 0000000..7123c42
--- /dev/null
+++ b/modules/caddyhttp/templates/tplcontext.go
@@ -0,0 +1,413 @@
+package templates
+
+import (
+ "bytes"
+ "crypto/rand"
+ "fmt"
+ "io"
+ "log"
+ weakrand "math/rand"
+ "net"
+ "net/http"
+ "path"
+ "strings"
+ "sync"
+ "text/template"
+ "time"
+
+ "os"
+
+ "github.com/caddyserver/caddy/modules/caddyhttp"
+ "gopkg.in/russross/blackfriday.v2"
+)
+
+// templateContext is the templateContext with which HTTP templates are executed.
+type templateContext struct {
+ Root http.FileSystem
+ Req *http.Request
+ Args []interface{} // defined by arguments to .Include
+ RespHeader tplWrappedHeader
+ server http.Handler
+}
+
+// Include returns the contents of filename relative to the site root.
+func (c templateContext) Include(filename string, args ...interface{}) (string, error) {
+ if c.Root == nil {
+ return "", fmt.Errorf("root file system not specified")
+ }
+
+ file, err := c.Root.Open(filename)
+ if err != nil {
+ return "", err
+ }
+ defer file.Close()
+
+ bodyBuf := bufPool.Get().(*bytes.Buffer)
+ bodyBuf.Reset()
+ defer bufPool.Put(bodyBuf)
+
+ _, err = io.Copy(bodyBuf, file)
+ if err != nil {
+ return "", err
+ }
+
+ c.Args = args
+
+ return c.executeTemplate(filename, bodyBuf.Bytes())
+}
+
+// HTTPInclude returns the body of a virtual (lightweight) request
+// to the given URI on the same server.
+func (c templateContext) HTTPInclude(uri string) (string, error) {
+ if c.Req.Header.Get(recursionPreventionHeader) == "1" {
+ return "", fmt.Errorf("virtual include cycle")
+ }
+
+ buf := bufPool.Get().(*bytes.Buffer)
+ buf.Reset()
+ defer bufPool.Put(buf)
+
+ virtReq, err := http.NewRequest("GET", uri, nil)
+ if err != nil {
+ return "", err
+ }
+ virtReq.Header.Set(recursionPreventionHeader, "1")
+
+ vrw := &virtualResponseWriter{body: buf, header: make(http.Header)}
+ server := c.Req.Context().Value(caddyhttp.ServerCtxKey).(http.Handler)
+
+ server.ServeHTTP(vrw, virtReq)
+ if vrw.status >= 400 {
+ return "", fmt.Errorf("http %d", vrw.status)
+ }
+
+ return c.executeTemplate(uri, buf.Bytes())
+}
+
+func (c templateContext) executeTemplate(tplName string, body []byte) (string, error) {
+ tpl, err := template.New(tplName).Parse(string(body))
+ if err != nil {
+ return "", err
+ }
+
+ buf := bufPool.Get().(*bytes.Buffer)
+ buf.Reset()
+ defer bufPool.Put(buf)
+
+ err = tpl.Execute(buf, c)
+ if err != nil {
+ return "", err
+ }
+
+ return buf.String(), nil
+}
+
+// Now returns the current timestamp.
+func (c templateContext) Now() time.Time {
+ return time.Now()
+}
+
+// Cookie gets the value of a cookie with name name.
+func (c templateContext) Cookie(name string) string {
+ cookies := c.Req.Cookies()
+ for _, cookie := range cookies {
+ if cookie.Name == name {
+ return cookie.Value
+ }
+ }
+ return ""
+}
+
+// ReqHeader gets the value of a request header with field name.
+func (c templateContext) ReqHeader(name string) string {
+ return c.Req.Header.Get(name)
+}
+
+// Hostname gets the (remote) hostname of the client making the request.
+func (c templateContext) Hostname() string {
+ ip := c.IP()
+
+ hostnameList, err := net.LookupAddr(ip)
+ if err != nil || len(hostnameList) == 0 {
+ return c.Req.RemoteAddr
+ }
+
+ return hostnameList[0]
+}
+
+// Env gets a map of the environment variables.
+func (c templateContext) Env() map[string]string {
+ osEnv := os.Environ()
+ envVars := make(map[string]string, len(osEnv))
+ for _, env := range osEnv {
+ data := strings.SplitN(env, "=", 2)
+ if len(data) == 2 && len(data[0]) > 0 {
+ envVars[data[0]] = data[1]
+ }
+ }
+ return envVars
+}
+
+// IP gets the (remote) IP address of the client making the request.
+func (c templateContext) IP() string {
+ ip, _, err := net.SplitHostPort(c.Req.RemoteAddr)
+ if err != nil {
+ return c.Req.RemoteAddr
+ }
+ return ip
+}
+
+// Host returns the hostname portion of the Host header
+// from the HTTP request.
+func (c templateContext) Host() (string, error) {
+ host, _, err := net.SplitHostPort(c.Req.Host)
+ if err != nil {
+ if !strings.Contains(c.Req.Host, ":") {
+ // common with sites served on the default port 80
+ return c.Req.Host, nil
+ }
+ return "", err
+ }
+ return host, nil
+}
+
+// Truncate truncates the input string to the given length.
+// If length is negative, it returns that many characters
+// starting from the end of the string. If the absolute value
+// of length is greater than len(input), the whole input is
+// returned.
+func (c templateContext) Truncate(input string, length int) string {
+ if length < 0 && len(input)+length > 0 {
+ return input[len(input)+length:]
+ }
+ if length >= 0 && len(input) > length {
+ return input[:length]
+ }
+ return input
+}
+
+// StripHTML returns s without HTML tags. It is fairly naive
+// but works with most valid HTML inputs.
+func (c templateContext) StripHTML(s string) string {
+ var buf bytes.Buffer
+ var inTag, inQuotes bool
+ var tagStart int
+ for i, ch := range s {
+ if inTag {
+ if ch == '>' && !inQuotes {
+ inTag = false
+ } else if ch == '<' && !inQuotes {
+ // false start
+ buf.WriteString(s[tagStart:i])
+ tagStart = i
+ } else if ch == '"' {
+ inQuotes = !inQuotes
+ }
+ continue
+ }
+ if ch == '<' {
+ inTag = true
+ tagStart = i
+ continue
+ }
+ buf.WriteRune(ch)
+ }
+ if inTag {
+ // false start
+ buf.WriteString(s[tagStart:])
+ }
+ return buf.String()
+}
+
+// Markdown renders the markdown body as HTML.
+func (c templateContext) Markdown(body string) string {
+ return string(blackfriday.Run([]byte(body)))
+}
+
+// Ext returns the suffix beginning at the final dot in the final
+// slash-separated element of the pathStr (or in other words, the
+// file extension).
+func (c templateContext) Ext(pathStr string) string {
+ return path.Ext(pathStr)
+}
+
+// StripExt returns the input string without the extension,
+// which is the suffix starting with the final '.' character
+// but not before the final path separator ('/') character.
+// If there is no extension, the whole input is returned.
+func (c templateContext) StripExt(path string) string {
+ for i := len(path) - 1; i >= 0 && path[i] != '/'; i-- {
+ if path[i] == '.' {
+ return path[:i]
+ }
+ }
+ return path
+}
+
+// Replace replaces instances of find in input with replacement.
+func (c templateContext) Replace(input, find, replacement string) string {
+ return strings.Replace(input, find, replacement, -1)
+}
+
+// HasPrefix returns true if s starts with prefix.
+func (c templateContext) HasPrefix(s, prefix string) bool {
+ return strings.HasPrefix(s, prefix)
+}
+
+// ToLower will convert the given string to lower case.
+func (c templateContext) ToLower(s string) string {
+ return strings.ToLower(s)
+}
+
+// ToUpper will convert the given string to upper case.
+func (c templateContext) ToUpper(s string) string {
+ return strings.ToUpper(s)
+}
+
+// Split is a pass-through to strings.Split. It will split
+// the first argument at each instance of the separator and
+// return a slice of strings.
+func (c templateContext) Split(s string, sep string) []string {
+ return strings.Split(s, sep)
+}
+
+// Join is a pass-through to strings.Join. It will join the
+// first argument slice with the separator in the second
+// argument and return the result.
+func (c templateContext) Join(a []string, sep string) string {
+ return strings.Join(a, sep)
+}
+
+// Slice will convert the given arguments into a slice.
+func (c templateContext) Slice(elems ...interface{}) []interface{} {
+ return elems
+}
+
+// Dict will convert the arguments into a dictionary (map). It expects
+// alternating keys and values of string types. This is useful since you
+// cannot express map literals directly in Go templates.
+func (c templateContext) Dict(values ...interface{}) (map[string]interface{}, error) {
+ if len(values)%2 != 0 {
+ return nil, fmt.Errorf("expected even number of arguments")
+ }
+ dict := make(map[string]interface{}, len(values)/2)
+ for i := 0; i < len(values); i += 2 {
+ key, ok := values[i].(string)
+ if !ok {
+ return nil, fmt.Errorf("argument %d: map keys must be strings", i)
+ }
+ dict[key] = values[i+1]
+ }
+ return dict, nil
+}
+
+// ListFiles reads and returns a slice of names from the given
+// directory relative to the root of c.
+func (c templateContext) ListFiles(name string) ([]string, error) {
+ if c.Root == nil {
+ return nil, fmt.Errorf("root file system not specified")
+ }
+
+ dir, err := c.Root.Open(path.Clean(name))
+ if err != nil {
+ return nil, err
+ }
+ defer dir.Close()
+
+ stat, err := dir.Stat()
+ if err != nil {
+ return nil, err
+ }
+ if !stat.IsDir() {
+ return nil, fmt.Errorf("%v is not a directory", name)
+ }
+
+ dirInfo, err := dir.Readdir(0)
+ if err != nil {
+ return nil, err
+ }
+
+ names := make([]string, len(dirInfo))
+ for i, fileInfo := range dirInfo {
+ names[i] = fileInfo.Name()
+ }
+
+ return names, nil
+}
+
+// RandomString generates a random string of random length given
+// length bounds. Thanks to http://stackoverflow.com/a/35615565/1048862
+// for the clever technique that is fairly fast, secure, and maintains
+// proper distributions over the dictionary.
+func (c templateContext) RandomString(minLen, maxLen int) string {
+ const (
+ letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+ letterIdxBits = 6 // 6 bits to represent 64 possibilities (indexes)
+ letterIdxMask = 1<<letterIdxBits - 1 // all 1-bits, as many as letterIdxBits
+ )
+
+ if minLen < 0 || maxLen < 0 || maxLen < minLen {
+ return ""
+ }
+
+ n := weakrand.Intn(maxLen-minLen+1) + minLen // choose actual length
+
+ // secureRandomBytes returns a number of bytes using crypto/rand.
+ secureRandomBytes := func(numBytes int) []byte {
+ randomBytes := make([]byte, numBytes)
+ if _, err := rand.Read(randomBytes); err != nil {
+ // TODO: what to do with the logs (throughout whole file) (could return as error? might get rendered though...)
+ log.Println("[ERROR] failed to read bytes: ", err)
+ }
+ return randomBytes
+ }
+
+ result := make([]byte, n)
+ bufferSize := int(float64(n) * 1.3)
+ for i, j, randomBytes := 0, 0, []byte{}; i < n; j++ {
+ if j%bufferSize == 0 {
+ randomBytes = secureRandomBytes(bufferSize)
+ }
+ if idx := int(randomBytes[j%n] & letterIdxMask); idx < len(letterBytes) {
+ result[i] = letterBytes[idx]
+ i++
+ }
+ }
+
+ return string(result)
+}
+
+// tplWrappedHeader wraps niladic functions so that they
+// can be used in templates. (Template functions must
+// return a value.)
+type tplWrappedHeader struct{ http.Header }
+
+// Add adds a header field value, appending val to
+// existing values for that field. It returns an
+// empty string.
+func (h tplWrappedHeader) Add(field, val string) string {
+ h.Header.Add(field, val)
+ return ""
+}
+
+// Set sets a header field value, overwriting any
+// other values for that field. It returns an
+// empty string.
+func (h tplWrappedHeader) Set(field, val string) string {
+ h.Header.Set(field, val)
+ return ""
+}
+
+// Del deletes a header field. It returns an empty string.
+func (h tplWrappedHeader) Del(field string) string {
+ h.Header.Del(field)
+ return ""
+}
+
+var bufPool = sync.Pool{
+ New: func() interface{} {
+ return new(bytes.Buffer)
+ },
+}
+
+const recursionPreventionHeader = "Caddy-Templates-Include"
diff --git a/modules/caddyhttp/templates/tplcontext_test.go b/modules/caddyhttp/templates/tplcontext_test.go
new file mode 100644
index 0000000..af4ad4e
--- /dev/null
+++ b/modules/caddyhttp/templates/tplcontext_test.go
@@ -0,0 +1,420 @@
+// Copyright 2015 Light Code Labs, LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package templates
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "os"
+ "path/filepath"
+ "reflect"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestMarkdown(t *testing.T) {
+ context := getContextOrFail(t)
+
+ for i, test := range []struct {
+ body string
+ expect string
+ }{
+ {
+ body: "- str1\n- str2\n",
+ expect: "<ul>\n<li>str1</li>\n<li>str2</li>\n</ul>\n",
+ },
+ } {
+ result := context.Markdown(test.body)
+ if result != test.expect {
+ t.Errorf("Test %d: expected '%s' but got '%s'", i, test.expect, result)
+ }
+ }
+}
+
+func TestCookie(t *testing.T) {
+ for i, test := range []struct {
+ cookie *http.Cookie
+ cookieName string
+ expect string
+ }{
+ {
+ // happy path
+ cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
+ cookieName: "cookieName",
+ expect: "cookieValue",
+ },
+ {
+ // try to get a non-existing cookie
+ cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
+ cookieName: "notExisting",
+ expect: "",
+ },
+ {
+ // partial name match
+ cookie: &http.Cookie{Name: "cookie", Value: "cookieValue"},
+ cookieName: "cook",
+ expect: "",
+ },
+ {
+ // cookie with optional fields
+ cookie: &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120},
+ cookieName: "cookie",
+ expect: "cookieValue",
+ },
+ } {
+ context := getContextOrFail(t)
+ context.Req.AddCookie(test.cookie)
+ actual := context.Cookie(test.cookieName)
+ if actual != test.expect {
+ t.Errorf("Test %d: Expected cookie value '%s' but got '%s' for cookie with name '%s'",
+ i, test.expect, actual, test.cookieName)
+ }
+ }
+}
+
+func TestCookieMultipleCookies(t *testing.T) {
+ context := getContextOrFail(t)
+
+ cookieNameBase, cookieValueBase := "cookieName", "cookieValue"
+
+ for i := 0; i < 10; i++ {
+ context.Req.AddCookie(&http.Cookie{
+ Name: fmt.Sprintf("%s%d", cookieNameBase, i),
+ Value: fmt.Sprintf("%s%d", cookieValueBase, i),
+ })
+ }
+
+ for i := 0; i < 10; i++ {
+ expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i)
+ actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i))
+ if actualCookieVal != expectedCookieVal {
+ t.Errorf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal)
+ }
+ }
+}
+
+func TestEnv(t *testing.T) {
+ context := getContextOrFail(t)
+
+ name := "ENV_TEST_NAME"
+ testValue := "TEST_VALUE"
+ os.Setenv(name, testValue)
+
+ notExisting := "ENV_TEST_NOT_EXISTING"
+ os.Unsetenv(notExisting)
+
+ invalidName := "ENV_TEST_INVALID_NAME"
+ os.Setenv("="+invalidName, testValue)
+
+ env := context.Env()
+ if value := env[name]; value != testValue {
+ t.Errorf("Expected env-variable %s value '%s', found '%s'",
+ name, testValue, value)
+ }
+
+ if value, ok := env[notExisting]; ok {
+ t.Errorf("Expected empty env-variable %s, found '%s'",
+ notExisting, value)
+ }
+
+ for k, v := range env {
+ if strings.Contains(k, invalidName) {
+ t.Errorf("Expected invalid name not to be included in Env %s, found in key '%s'", invalidName, k)
+ }
+ if strings.Contains(v, invalidName) {
+ t.Errorf("Expected invalid name not be be included in Env %s, found in value '%s'", invalidName, v)
+ }
+ }
+
+ os.Unsetenv("=" + invalidName)
+}
+
+func TestIP(t *testing.T) {
+ context := getContextOrFail(t)
+ for i, test := range []struct {
+ inputRemoteAddr string
+ expect string
+ }{
+ {"1.1.1.1:1111", "1.1.1.1"},
+ {"1.1.1.1", "1.1.1.1"},
+ {"[::1]:11", "::1"},
+ {"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
+ {`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
+ } {
+ context.Req.RemoteAddr = test.inputRemoteAddr
+ if actual := context.IP(); actual != test.expect {
+ t.Errorf("Test %d: Expected %s but got %s", i, test.expect, actual)
+ }
+ }
+}
+
+func TestTruncate(t *testing.T) {
+ context := getContextOrFail(t)
+
+ for i, test := range []struct {
+ input string
+ length int
+ expect string
+ }{
+ {
+ input: "string",
+ length: 1,
+ expect: "s",
+ },
+ {
+ input: "string",
+ length: 6,
+ expect: "string",
+ },
+ {
+ input: "string",
+ length: 10,
+ expect: "string",
+ },
+ {
+ input: "string",
+ length: 0,
+ expect: "",
+ },
+ {
+ input: "string",
+ length: -5,
+ expect: "tring",
+ },
+ {
+ input: "string",
+ length: -6,
+ expect: "string",
+ },
+ {
+ input: "string",
+ length: -7,
+ expect: "string",
+ },
+ } {
+ actual := context.Truncate(test.input, test.length)
+ if actual != test.expect {
+ t.Errorf("Test %d: Expected '%s' but got '%s'", i, test.expect, actual)
+ }
+ }
+}
+
+func TestStripHTML(t *testing.T) {
+ context := getContextOrFail(t)
+
+ for i, test := range []struct {
+ input string
+ expect string
+ }{
+ {
+ // no tags
+ input: `h1`,
+ expect: `h1`,
+ },
+ {
+ // happy path
+ input: `<h1>h1</h1>`,
+ expect: `h1`,
+ },
+ {
+ // tag in quotes
+ input: `<h1">">h1</h1>`,
+ expect: `h1`,
+ },
+ {
+ // multiple tags
+ input: `<h1><b>h1</b></h1>`,
+ expect: `h1`,
+ },
+ {
+ // tags not closed
+ input: `<h1`,
+ expect: `<h1`,
+ },
+ {
+ // false start
+ input: `<h1<b>hi`,
+ expect: `<h1hi`,
+ },
+ } {
+ actual := context.StripHTML(test.input)
+ if actual != test.expect {
+ t.Errorf("Test %d: Expected %s, found %s. Input was StripHTML(%s)", i, test.expect, actual, test.input)
+ }
+ }
+}
+
+func TestStripExt(t *testing.T) {
+ context := getContextOrFail(t)
+ tests := []struct {
+ input string
+ expect string
+ }{
+ {
+ input: "",
+ expect: "",
+ },
+ {
+ input: "file.ext",
+ expect: "file",
+ },
+ {
+ input: "file",
+ expect: "file",
+ },
+ {
+ input: "/file",
+ expect: "/file",
+ },
+ {
+ input: "/file.ext",
+ expect: "/file",
+ },
+ {
+ input: "/dir.ext/",
+ expect: "/dir.ext/",
+ },
+ {
+ input: "/dir.ext/file.ext",
+ expect: "/dir.ext/file",
+ },
+ }
+
+ for i, test := range tests {
+ actual := context.StripExt(test.input)
+ if actual != test.expect {
+ t.Errorf("Test %d: Expected %s but got %s", i, test.expect, actual)
+ }
+ }
+}
+
+func TestFileListing(t *testing.T) {
+ for i, test := range []struct {
+ fileNames []string
+ inputBase string
+ shouldErr bool
+ verifyErr func(error) bool
+ }{
+ {
+ // directory and files exist
+ fileNames: []string{"file1", "file2"},
+ shouldErr: false,
+ },
+ {
+ // directory exists, no files
+ fileNames: []string{},
+ shouldErr: false,
+ },
+ {
+ // file or directory does not exist
+ fileNames: nil,
+ inputBase: "doesNotExist",
+ shouldErr: true,
+ verifyErr: os.IsNotExist,
+ },
+ {
+ // directory and files exist, but path to a file
+ fileNames: []string{"file1", "file2"},
+ inputBase: "file1",
+ shouldErr: true,
+ verifyErr: func(err error) bool {
+ return strings.HasSuffix(err.Error(), "is not a directory")
+ },
+ },
+ {
+ // try to escape Context Root
+ fileNames: nil,
+ inputBase: filepath.Join("..", "..", "..", "..", "..", "etc"),
+ shouldErr: true,
+ verifyErr: os.IsNotExist,
+ },
+ } {
+ context := getContextOrFail(t)
+ var dirPath string
+ var err error
+
+ // create files for test case
+ if test.fileNames != nil {
+ dirPath, err = ioutil.TempDir(fmt.Sprintf("%s", context.Root), "caddy_ctxtest")
+ if err != nil {
+ t.Fatalf("Test %d: Expected no error creating directory, got: '%s'", i, err.Error())
+ }
+ for _, name := range test.fileNames {
+ absFilePath := filepath.Join(dirPath, name)
+ if err = ioutil.WriteFile(absFilePath, []byte(""), os.ModePerm); err != nil {
+ os.RemoveAll(dirPath)
+ t.Fatalf("Test %d: Expected no error creating file, got: '%s'", i, err.Error())
+ }
+ }
+ }
+
+ // perform test
+ input := filepath.ToSlash(filepath.Join(filepath.Base(dirPath), test.inputBase))
+ actual, err := context.ListFiles(input)
+ if err != nil {
+ if !test.shouldErr {
+ t.Errorf("Test %d: Expected no error, got: '%s'", i, err)
+ } else if !test.verifyErr(err) {
+ t.Errorf("Test %d: Could not verify error content, got: '%s'", i, err)
+ }
+ } else if test.shouldErr {
+ t.Errorf("Test %d: Expected error but had none", i)
+ } else {
+ numFiles := len(test.fileNames)
+ // reflect.DeepEqual does not consider two empty slices to be equal
+ if numFiles == 0 && len(actual) != 0 {
+ t.Errorf("Test %d: Expected files %v, got: %v",
+ i, test.fileNames, actual)
+ } else {
+ sort.Strings(actual)
+ if numFiles > 0 && !reflect.DeepEqual(test.fileNames, actual) {
+ t.Errorf("Test %d: Expected files %v, got: %v",
+ i, test.fileNames, actual)
+ }
+ }
+ }
+
+ if dirPath != "" {
+ if err := os.RemoveAll(dirPath); err != nil && !os.IsNotExist(err) {
+ t.Fatalf("Test %d: Expected no error removing temporary test directory, got: %v", i, err)
+ }
+ }
+ }
+}
+
+func getContextOrFail(t *testing.T) templateContext {
+ context, err := initTestContext()
+ if err != nil {
+ t.Fatalf("failed to prepare test context: %v", err)
+ }
+ return context
+}
+
+func initTestContext() (templateContext, error) {
+ body := bytes.NewBufferString("request body")
+ request, err := http.NewRequest("GET", "https://example.com/foo/bar", body)
+ if err != nil {
+ return templateContext{}, err
+ }
+ return templateContext{
+ Root: http.Dir(os.TempDir()),
+ Req: request,
+ RespHeader: tplWrappedHeader{make(http.Header)},
+ }, nil
+}
diff --git a/modules/caddytls/matchers.go b/modules/caddytls/matchers.go
index a0d26bb..1f26222 100644
--- a/modules/caddytls/matchers.go
+++ b/modules/caddytls/matchers.go
@@ -28,4 +28,4 @@ func (m MatchServerName) Match(hello *tls.ClientHelloInfo) bool {
}
// Interface guard
-var _ ConnectionMatcher = MatchServerName{}
+var _ ConnectionMatcher = (*MatchServerName)(nil)