package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi"
)
// TestStripSlashes verifies that StripSlashes, mounted at the top level of
// the router, lets paths with a trailing slash reach the same handlers as
// their slash-less forms, while unknown paths still reach the NotFound handler.
func TestStripSlashes(t *testing.T) {
	r := chi.NewRouter()

	// This middleware must be mounted at the top level of the router, not at
	// the end handler: by then it would be too late and the request would
	// already have ended up in a 404.
	r.Use(StripSlashes)

	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(404)
		w.Write([]byte("nothing here"))
	})

	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("root"))
	})

	r.Route("/accounts/{accountID}", func(r chi.Router) {
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			accountID := chi.URLParam(r, "accountID")
			w.Write([]byte(accountID))
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	tests := []struct {
		path       string
		wantBody   string
		wantStatus int
	}{
		{"/", "root", http.StatusOK},
		{"//", "root", http.StatusOK},
		{"/accounts/admin", "admin", http.StatusOK},
		{"/accounts/admin/", "admin", http.StatusOK},
		{"/nothing-here", "nothing here", http.StatusNotFound},
	}
	for _, tt := range tests {
		resp, body := testRequest(t, ts, "GET", tt.path, nil)
		// Use an explicit format string: the body must never be interpreted
		// as a printf format (go vet rejects non-constant format strings).
		if body != tt.wantBody || resp.StatusCode != tt.wantStatus {
			t.Errorf("GET %s: got %d %q, want %d %q",
				tt.path, resp.StatusCode, body, tt.wantStatus, tt.wantBody)
		}
	}
}
// TestStripSlashesInRoute verifies that StripSlashes mounted inside a
// sub-router only affects routes of that sub-router. Trailing slashes are
// tolerated under /accounts/{accountID}, but a top-level "/hi/" still falls
// through to the NotFound handler.
func TestStripSlashesInRoute(t *testing.T) {
	r := chi.NewRouter()

	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(404)
		w.Write([]byte("nothing here"))
	})

	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hi"))
	})

	r.Route("/accounts/{accountID}", func(r chi.Router) {
		r.Use(StripSlashes)
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("accounts index"))
		})
		r.Get("/query", func(w http.ResponseWriter, r *http.Request) {
			accountID := chi.URLParam(r, "accountID")
			w.Write([]byte(accountID))
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	tests := []struct {
		path       string
		wantBody   string
		wantStatus int
	}{
		{"/hi", "hi", http.StatusOK},
		// StripSlashes is not mounted at the top level, so this is not found.
		{"/hi/", "nothing here", http.StatusNotFound},
		{"/accounts/admin", "accounts index", http.StatusOK},
		{"/accounts/admin/", "accounts index", http.StatusOK},
		{"/accounts/admin/query", "admin", http.StatusOK},
		{"/accounts/admin/query/", "admin", http.StatusOK},
	}
	for _, tt := range tests {
		resp, body := testRequest(t, ts, "GET", tt.path, nil)
		if body != tt.wantBody || resp.StatusCode != tt.wantStatus {
			t.Errorf("GET %s: got %d %q, want %d %q",
				tt.path, resp.StatusCode, body, tt.wantStatus, tt.wantBody)
		}
	}
}
// TestRedirectSlashes verifies that RedirectSlashes, mounted at the top level
// of the router, answers trailing-slash requests with a 301 to the slash-less
// path (preserving the query string). Once the redirect is followed, the
// request lands on the expected handler.
func TestRedirectSlashes(t *testing.T) {
	r := chi.NewRouter()

	// This middleware must be mounted at the top level of the router, not at
	// the end handler: by then it would be too late and the request would
	// already have ended up in a 404.
	r.Use(RedirectSlashes)

	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(404)
		w.Write([]byte("nothing here"))
	})

	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("root"))
	})

	r.Route("/accounts/{accountID}", func(r chi.Router) {
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			accountID := chi.URLParam(r, "accountID")
			w.Write([]byte(accountID))
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// The testRequest client follows redirects, so the trailing-slash cases
	// observe the final response served at the redirected location.
	// Both body AND status must match; checking only one (or requiring both
	// to be wrong before failing) would let regressions slip through.
	tests := []struct {
		path       string
		wantBody   string
		wantStatus int
	}{
		{"/", "root", http.StatusOK},
		{"//", "root", http.StatusOK},
		{"/accounts/admin", "admin", http.StatusOK},
		{"/accounts/admin/", "admin", http.StatusOK},
		{"/nothing-here", "nothing here", http.StatusNotFound},
	}
	for _, tt := range tests {
		resp, body := testRequest(t, ts, "GET", tt.path, nil)
		if body != tt.wantBody || resp.StatusCode != tt.wantStatus {
			t.Errorf("GET %s: got %d %q, want %d %q",
				tt.path, resp.StatusCode, body, tt.wantStatus, tt.wantBody)
		}
	}

	// Ensure the redirect Location URL is correct and that query params are
	// kept intact upon redirecting a slash.
	redirects := []struct {
		path         string
		wantLocation string
	}{
		{"/accounts/someuser/", "/accounts/someuser"},
		{"/accounts/someuser/?a=1&b=2", "/accounts/someuser?a=1&b=2"},
	}
	for _, tt := range redirects {
		resp, body := testRequestNoRedirect(t, ts, "GET", tt.path, nil)
		if resp.StatusCode != http.StatusMovedPermanently {
			t.Errorf("GET %s: got status %d, want %d (body %q)",
				tt.path, resp.StatusCode, http.StatusMovedPermanently, body)
			continue
		}
		if loc := resp.Header.Get("Location"); loc != tt.wantLocation {
			t.Errorf("GET %s: invalid redirection, got Location %q, want %q",
				tt.path, loc, tt.wantLocation)
		}
	}
}