summaryrefslogtreecommitdiff
path: root/caddytest/caddytest.go
diff options
context:
space:
mode:
authorMark Sargent <99003+sarge@users.noreply.github.com>2020-03-14 06:32:53 +1300
committerGitHub <noreply@github.com>2020-03-13 11:32:53 -0600
commitc4472363579e26485fdfc01f4fd79b86d44acd86 (patch)
tree9a59e09cd79f9d1c5efbca8723bfffc2d21afbf2 /caddytest/caddytest.go
parent5a19db5dc2db7c02d0f99630a07a64cacb7f7b44 (diff)
caddyhttp: Fix default SNI for default conn policy (#3141)
* add integration tests * removed SNI test * remove integration test condition * minor edit * fix sni when using static certificates Co-authored-by: Matt Holt <mholt@users.noreply.github.com>
Diffstat (limited to 'caddytest/caddytest.go')
-rw-r--r--caddytest/caddytest.go272
1 files changed, 272 insertions, 0 deletions
diff --git a/caddytest/caddytest.go b/caddytest/caddytest.go
new file mode 100644
index 0000000..04b65ba
--- /dev/null
+++ b/caddytest/caddytest.go
@@ -0,0 +1,272 @@
+package caddytest
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+ "os"
+ "path"
+ "regexp"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+)
+
+// Defaults store any configuration required to make the tests run
+type Defaults struct {
+ // Port we expect caddy to listening on
+ AdminPort int
+ // Certificates we expect to be loaded before attempting to run the tests
+ Certifcates []string
+}
+
+// Default testing values
+var Default = Defaults{
+ AdminPort: 2019,
+ Certifcates: []string{"/caddy.localhost.crt", "/caddy.localhost.key"},
+}
+
+var (
+ matchKey = regexp.MustCompile(`(/[\w\d\.]+\.key)`)
+ matchCert = regexp.MustCompile(`(/[\w\d\.]+\.crt)`)
+)
+
+type configLoadError struct {
+ Response string
+}
+
+func (e configLoadError) Error() string { return e.Response }
+
+// InitServer this will configure the server with a configurion of a specific
+// type. The configType must be either "json" or the adapter type.
+func InitServer(t *testing.T, rawConfig string, configType string) {
+ if err := initServer(t, rawConfig, configType); errors.Is(err, &configLoadError{}) {
+ t.Logf("failed to load config: %s", err)
+ t.Fail()
+ }
+}
+
+// InitServer this will configure the server with a configurion of a specific
+// type. The configType must be either "json" or the adapter type.
+func initServer(t *testing.T, rawConfig string, configType string) error {
+
+ err := validateTestPrerequisites()
+ if err != nil {
+ t.Skipf("skipping tests as failed integration prerequisites. %s", err)
+ return nil
+ }
+
+ t.Cleanup(func() {
+ if t.Failed() {
+ res, err := http.Get(fmt.Sprintf("http://localhost:%d/config/", Default.AdminPort))
+ if err != nil {
+ t.Log("unable to read the current config")
+ }
+ defer res.Body.Close()
+ body, err := ioutil.ReadAll(res.Body)
+
+ var out bytes.Buffer
+ json.Indent(&out, body, "", " ")
+ t.Logf("----------- failed with config -----------\n%s", out.String())
+ }
+ })
+
+ rawConfig = prependCaddyFilePath(rawConfig)
+ client := &http.Client{
+ Timeout: time.Second * 2,
+ }
+ req, err := http.NewRequest("POST", fmt.Sprintf("http://localhost:%d/load", Default.AdminPort), strings.NewReader(rawConfig))
+ if err != nil {
+ t.Errorf("failed to create request. %s", err)
+ return err
+ }
+
+ if configType == "json" {
+ req.Header.Add("Content-Type", "application/json")
+ } else {
+ req.Header.Add("Content-Type", "text/"+configType)
+ }
+
+ res, err := client.Do(req)
+ if err != nil {
+ t.Errorf("unable to contact caddy server. %s", err)
+ return err
+ }
+ defer res.Body.Close()
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Errorf("unable to read response. %s", err)
+ return err
+ }
+
+ if res.StatusCode != 200 {
+ return configLoadError{Response: string(body)}
+ }
+
+ return nil
+}
+
+var hasValidated bool
+var arePrerequisitesValid bool
+
+func validateTestPrerequisites() error {
+
+ if hasValidated {
+ if !arePrerequisitesValid {
+ return errors.New("caddy integration prerequisites failed. see first error")
+ }
+ return nil
+ }
+
+ hasValidated = true
+ arePrerequisitesValid = false
+
+ // check certificates are found
+ for _, certName := range Default.Certifcates {
+ if _, err := os.Stat(getIntegrationDir() + certName); os.IsNotExist(err) {
+ return fmt.Errorf("caddy integration test certificates (%s) not found", certName)
+ }
+ }
+
+ // assert that caddy is running
+ client := &http.Client{
+ Timeout: time.Second * 2,
+ }
+ _, err := client.Get(fmt.Sprintf("http://localhost:%d/config/", Default.AdminPort))
+ if err != nil {
+ return errors.New("caddy integration test caddy server not running. Expected to be listening on localhost:2019")
+ }
+
+ arePrerequisitesValid = true
+ return nil
+}
+
+func getIntegrationDir() string {
+
+ _, filename, _, ok := runtime.Caller(1)
+ if !ok {
+ panic("unable to determine the current file path")
+ }
+
+ return path.Dir(filename)
+}
+
+// use the convention to replace /[certificatename].[crt|key] with the full path
+// this helps reduce the noise in test configurations and also allow this
+// to run in any path
+func prependCaddyFilePath(rawConfig string) string {
+ r := matchKey.ReplaceAllString(rawConfig, getIntegrationDir()+"$1")
+ r = matchCert.ReplaceAllString(r, getIntegrationDir()+"$1")
+ return r
+}
+
+// creates a testing transport that forces call dialing connections to happen locally
+func createTestingTransport() *http.Transport {
+
+ dialer := net.Dialer{
+ Timeout: 5 * time.Second,
+ KeepAlive: 5 * time.Second,
+ DualStack: true,
+ }
+
+ dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) {
+ parts := strings.Split(addr, ":")
+ destAddr := fmt.Sprintf("127.0.0.1:%s", parts[1])
+ log.Printf("caddytest: redirecting the dialer from %s to %s", addr, destAddr)
+ return dialer.DialContext(ctx, network, destAddr)
+ }
+
+ return &http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ DialContext: dialContext,
+ ForceAttemptHTTP2: true,
+ MaxIdleConns: 100,
+ IdleConnTimeout: 90 * time.Second,
+ TLSHandshakeTimeout: 5 * time.Second,
+ ExpectContinueTimeout: 1 * time.Second,
+ TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
+ }
+}
+
+// AssertLoadError will load a config and expect an error
+func AssertLoadError(t *testing.T, rawConfig string, configType string, expectedError string) {
+ err := initServer(t, rawConfig, configType)
+ if !strings.Contains(err.Error(), expectedError) {
+ t.Errorf("expected error \"%s\" but got \"%s\"", expectedError, err.Error())
+ }
+}
+
+// AssertGetResponse request a URI and assert the status code and the body contains a string
+func AssertGetResponse(t *testing.T, requestURI string, statusCode int, expectedBody string) (*http.Response, string) {
+ resp, body := AssertGetResponseBody(t, requestURI, statusCode)
+ if !strings.Contains(body, expectedBody) {
+ t.Errorf("expected response body \"%s\" but got \"%s\"", expectedBody, body)
+ }
+ return resp, string(body)
+}
+
+// AssertGetResponseBody request a URI and assert the status code matches
+func AssertGetResponseBody(t *testing.T, requestURI string, expectedStatusCode int) (*http.Response, string) {
+
+ client := &http.Client{
+ Transport: createTestingTransport(),
+ }
+
+ resp, err := client.Get(requestURI)
+ if err != nil {
+ t.Errorf("failed to call server %s", err)
+ return nil, ""
+ }
+
+ defer resp.Body.Close()
+
+ if expectedStatusCode != resp.StatusCode {
+ t.Errorf("expected status code: %d but got %d", expectedStatusCode, resp.StatusCode)
+ }
+
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("unable to read the response body %s", err)
+ return nil, ""
+ }
+
+ return resp, string(body)
+}
+
+// AssertRedirect makes a request and asserts the redirection happens
+func AssertRedirect(t *testing.T, requestURI string, expectedToLocation string, expectedStatusCode int) *http.Response {
+
+ redirectPolicyFunc := func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse
+ }
+
+ client := &http.Client{
+ CheckRedirect: redirectPolicyFunc,
+ Transport: createTestingTransport(),
+ }
+
+ resp, err := client.Get(requestURI)
+ if err != nil {
+ t.Errorf("failed to call server %s", err)
+ return nil
+ }
+
+ if expectedStatusCode != resp.StatusCode {
+ t.Errorf("expected status code: %d but got %d", expectedStatusCode, resp.StatusCode)
+ }
+
+ loc, err := resp.Location()
+ if expectedToLocation != loc.String() {
+ t.Errorf("expected location: \"%s\" but got \"%s\"", expectedToLocation, loc.String())
+ }
+
+ return resp
+}