package middleware import ( "compress/flate" "compress/gzip" "io" "io/ioutil" "net/http" "net/http/httptest" "strings" "testing" "github.com/go-chi/chi" ) func testRequestWithAcceptedEncodings(t *testing.T, ts *httptest.Server, method, path string, encodings ...string) (*http.Response, string) { req, err := http.NewRequest(method, ts.URL+path, nil) if err != nil { t.Fatal(err) return nil, "" } if len(encodings) > 0 { encodingsString := strings.Join(encodings, ",") req.Header.Set("Accept-Encoding", encodingsString) } resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatal(err) return nil, "" } respBody := decodeResponseBody(t, resp) defer resp.Body.Close() return resp, respBody } func decodeResponseBody(t *testing.T, resp *http.Response) string { var reader io.ReadCloser switch resp.Header.Get("Content-Encoding") { case "gzip": var err error reader, err = gzip.NewReader(resp.Body) if err != nil { t.Fatal(err) } case "deflate": reader = flate.NewReader(resp.Body) default: reader = resp.Body } respBody, err := ioutil.ReadAll(reader) if err != nil { t.Fatal(err) return "" } reader.Close() return string(respBody) } func TestOldAPI(t *testing.T) { r := chi.NewRouter() r.Use(Compress(5, "text/html")) r.Get("/", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") w.Write([]byte("textstring")) }) ts := httptest.NewServer(r) defer ts.Close() tests := []struct { name string acceptedEncodings []string expectedEncoding string extraCode func() }{ { name: "no expected encodings", acceptedEncodings: nil, expectedEncoding: "", }, { name: "gzip is only encoding", acceptedEncodings: []string{"gzip"}, expectedEncoding: "gzip", }, { name: "gzip is preferred over deflate", acceptedEncodings: []string{"gzip", "deflate"}, expectedEncoding: "gzip", }, { name: "deflate is used", acceptedEncodings: []string{"deflate"}, expectedEncoding: "deflate", }, { name: "deflate is preferred over gzip", acceptedEncodings: []string{"gzip, deflate"}, expectedEncoding: "deflate", extraCode: func() { SetEncoder("deflate", encoderDeflate) }, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { if tc.extraCode != nil { tc.extraCode() } resp, respString := testRequestWithAcceptedEncodings(t, ts, "GET", "/", tc.acceptedEncodings...) if respString != "textstring" { t.Errorf("response text doesn't match; expected:%q, got:%q", "textstring", respString) } if got := resp.Header.Get("Content-Encoding"); got != tc.expectedEncoding { t.Errorf("expected encoding %q but got %q", tc.expectedEncoding, got) } }) } } func TestCompressor(t *testing.T) { r := chi.NewRouter() compressor := NewCompressor(5, "text/html", "text/css") if len(compressor.encoders) != 0 || len(compressor.pooledEncoders) != 2 { t.Errorf("gzip and deflate should be pooled") } compressor.SetEncoder("nop", func(w io.Writer, _ int) io.Writer { return w }) if len(compressor.encoders) != 1 { t.Errorf("nop encoder should be stored in the encoders map") } r.Use(compressor.Handler()) r.Get("/gethtml", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") w.Write([]byte("textstring")) }) r.Get("/getcss", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") w.Write([]byte("textstring")) }) r.Get("/getplain", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") w.Write([]byte("textstring")) }) ts := httptest.NewServer(r) defer ts.Close() tests := []struct { name string path string acceptedEncodings []string expectedEncoding string }{ { name: "no expected encodings due to no accepted encodings", path: "/gethtml", acceptedEncodings: nil, expectedEncoding: "", }, { name: "no expected encodings due to content type", path: "/getplain", acceptedEncodings: nil, expectedEncoding: "", }, { name: "gzip is only encoding", path: "/gethtml", acceptedEncodings: []string{"gzip"}, expectedEncoding: "gzip", }, { name: "gzip is preferred over deflate", path: "/getcss", acceptedEncodings: []string{"gzip", "deflate"}, expectedEncoding: "gzip", }, { name: "deflate is used", path: "/getcss", acceptedEncodings: []string{"deflate"}, expectedEncoding: "deflate", }, { name: "nop is preferred", path: "/getcss", acceptedEncodings: []string{"nop, gzip, deflate"}, expectedEncoding: "nop", }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { resp, respString := testRequestWithAcceptedEncodings(t, ts, "GET", tc.path, tc.acceptedEncodings...) if respString != "textstring" { t.Errorf("response text doesn't match; expected:%q, got:%q", "textstring", respString) } if got := resp.Header.Get("Content-Encoding"); got != tc.expectedEncoding { t.Errorf("expected encoding %q but got %q", tc.expectedEncoding, got) } }) } }