// Copyright 2016 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package http2 import ( "errors" "fmt" "io" "io/ioutil" "net/http" "reflect" "strconv" "sync" "testing" "time" ) func TestServer_Push_Success(t *testing.T) { const ( mainBody = "<html>index page</html>" pushedBody = "<html>pushed page</html>" userAgent = "testagent" cookie = "testcookie" ) var stURL string checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error { if got, want := r.Method, wantMethod; got != want { return fmt.Errorf("promised Req.Method=%q, want %q", got, want) } if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) { return fmt.Errorf("promised Req.Header=%q, want %q", got, want) } if got, want := "https://"+r.Host, stURL; got != want { return fmt.Errorf("promised Req.Host=%q, want %q", got, want) } if r.Body == nil { return fmt.Errorf("nil Body") } if buf, err := ioutil.ReadAll(r.Body); err != nil || len(buf) != 0 { return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err) } return nil } errc := make(chan error, 3) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { switch r.URL.RequestURI() { case "/": // Push "/pushed?get" as a GET request, using an absolute URL. opt := &http.PushOptions{ Header: http.Header{ "User-Agent": {userAgent}, }, } if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil { errc <- fmt.Errorf("error pushing /pushed?get: %v", err) return } // Push "/pushed?head" as a HEAD request, using a path. opt = &http.PushOptions{ Method: "HEAD", Header: http.Header{ "User-Agent": {userAgent}, "Cookie": {cookie}, }, } if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil { errc <- fmt.Errorf("error pushing /pushed?head: %v", err) return } w.Header().Set("Content-Type", "text/html") w.Header().Set("Content-Length", strconv.Itoa(len(mainBody))) w.WriteHeader(200) io.WriteString(w, mainBody) errc <- nil case "/pushed?get": wantH := http.Header{} wantH.Set("User-Agent", userAgent) if err := checkPromisedReq(r, "GET", wantH); err != nil { errc <- fmt.Errorf("/pushed?get: %v", err) return } w.Header().Set("Content-Type", "text/html") w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody))) w.WriteHeader(200) io.WriteString(w, pushedBody) errc <- nil case "/pushed?head": wantH := http.Header{} wantH.Set("User-Agent", userAgent) wantH.Set("Cookie", cookie) if err := checkPromisedReq(r, "HEAD", wantH); err != nil { errc <- fmt.Errorf("/pushed?head: %v", err) return } w.WriteHeader(204) errc <- nil default: errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI()) } }) stURL = st.ts.URL // Send one request, which should push two responses. st.greet() getSlash(st) for k := 0; k < 3; k++ { select { case <-time.After(2 * time.Second): t.Errorf("timeout waiting for handler %d to finish", k) case err := <-errc: if err != nil { t.Fatal(err) } } } checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error { pp, ok := f.(*PushPromiseFrame) if !ok { return fmt.Errorf("got a %T; want *PushPromiseFrame", f) } if !pp.HeadersEnded() { return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame") } if got, want := pp.PromiseID, promiseID; got != want { return fmt.Errorf("got PromiseID %v; want %v", got, want) } gotH := st.decodeHeader(pp.HeaderBlockFragment()) if !reflect.DeepEqual(gotH, wantH) { return fmt.Errorf("got promised headers %v; want %v", gotH, wantH) } return nil } checkHeaders := func(f Frame, wantH [][2]string) error { hf, ok := f.(*HeadersFrame) if !ok { return fmt.Errorf("got a %T; want *HeadersFrame", f) } gotH := st.decodeHeader(hf.HeaderBlockFragment()) if !reflect.DeepEqual(gotH, wantH) { return fmt.Errorf("got response headers %v; want %v", gotH, wantH) } return nil } checkData := func(f Frame, wantData string) error { df, ok := f.(*DataFrame) if !ok { return fmt.Errorf("got a %T; want *DataFrame", f) } if gotData := string(df.Data()); gotData != wantData { return fmt.Errorf("got response data %q; want %q", gotData, wantData) } return nil } // Stream 1 has 2 PUSH_PROMISE + HEADERS + DATA // Stream 2 has HEADERS + DATA // Stream 4 has HEADERS expected := map[uint32][]func(Frame) error{ 1: { func(f Frame) error { return checkPushPromise(f, 2, [][2]string{ {":method", "GET"}, {":scheme", "https"}, {":authority", st.ts.Listener.Addr().String()}, {":path", "/pushed?get"}, {"user-agent", userAgent}, }) }, func(f Frame) error { return checkPushPromise(f, 4, [][2]string{ {":method", "HEAD"}, {":scheme", "https"}, {":authority", st.ts.Listener.Addr().String()}, {":path", "/pushed?head"}, {"cookie", cookie}, {"user-agent", userAgent}, }) }, func(f Frame) error { return checkHeaders(f, [][2]string{ {":status", "200"}, {"content-type", "text/html"}, {"content-length", strconv.Itoa(len(mainBody))}, }) }, func(f Frame) error { return checkData(f, mainBody) }, }, 2: { func(f Frame) error { return checkHeaders(f, [][2]string{ {":status", "200"}, {"content-type", "text/html"}, {"content-length", strconv.Itoa(len(pushedBody))}, }) }, func(f Frame) error { return checkData(f, pushedBody) }, }, 4: { func(f Frame) error { return checkHeaders(f, [][2]string{ {":status", "204"}, }) }, }, } consumed := map[uint32]int{} for k := 0; len(expected) > 0; k++ { f, err := st.readFrame() if err != nil { for id, left := range expected { t.Errorf("stream %d: missing %d frames", id, len(left)) } t.Fatalf("readFrame %d: %v", k, err) } id := f.Header().StreamID label := fmt.Sprintf("stream %d, frame %d", id, consumed[id]) if len(expected[id]) == 0 { t.Fatalf("%s: unexpected frame %#+v", label, f) } check := expected[id][0] expected[id] = expected[id][1:] if len(expected[id]) == 0 { delete(expected, id) } if err := check(f); err != nil { t.Fatalf("%s: %v", label, err) } consumed[id]++ } } func TestServer_Push_SuccessNoRace(t *testing.T) { // Regression test for issue #18326. Ensure the request handler can mutate // pushed request headers without racing with the PUSH_PROMISE write. errc := make(chan error, 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { switch r.URL.RequestURI() { case "/": opt := &http.PushOptions{ Header: http.Header{"User-Agent": {"testagent"}}, } if err := w.(http.Pusher).Push("/pushed", opt); err != nil { errc <- fmt.Errorf("error pushing: %v", err) return } w.WriteHeader(200) errc <- nil case "/pushed": // Update request header, ensure there is no race. r.Header.Set("User-Agent", "newagent") r.Header.Set("Cookie", "cookie") w.WriteHeader(200) errc <- nil default: errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI()) } }) // Send one request, which should push one response. st.greet() getSlash(st) for k := 0; k < 2; k++ { select { case <-time.After(2 * time.Second): t.Errorf("timeout waiting for handler %d to finish", k) case err := <-errc: if err != nil { t.Fatal(err) } } } } func TestServer_Push_RejectRecursivePush(t *testing.T) { // Expect two requests, but might get three if there's a bug and the second push succeeds. errc := make(chan error, 3) handler := func(w http.ResponseWriter, r *http.Request) error { baseURL := "https://" + r.Host switch r.URL.Path { case "/": if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil { return fmt.Errorf("first Push()=%v, want nil", err) } return nil case "/push1": if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want { return fmt.Errorf("Push()=%v, want %v", got, want) } return nil default: return fmt.Errorf("unexpected path: %q", r.URL.Path) } } st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { errc <- handler(w, r) }) defer st.Close() st.greet() getSlash(st) if err := <-errc; err != nil { t.Errorf("First request failed: %v", err) } if err := <-errc; err != nil { t.Errorf("Second request failed: %v", err) } } func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) { // Expect one request, but might get two if there's a bug and the push succeeds. errc := make(chan error, 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { errc <- doPush(w.(http.Pusher), r) }) defer st.Close() st.greet() if err := st.fr.WriteSettings(settings...); err != nil { st.t.Fatalf("WriteSettings: %v", err) } st.wantSettingsAck() getSlash(st) if err := <-errc; err != nil { t.Error(err) } // Should not get a PUSH_PROMISE frame. hf := st.wantHeaders() if !hf.StreamEnded() { t.Error("stream should end after headers") } } func TestServer_Push_RejectIfDisabled(t *testing.T) { testServer_Push_RejectSingleRequest(t, func(p http.Pusher, r *http.Request) error { if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want { return fmt.Errorf("Push()=%v, want %v", got, want) } return nil }, Setting{SettingEnablePush, 0}) } func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) { testServer_Push_RejectSingleRequest(t, func(p http.Pusher, r *http.Request) error { if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want { return fmt.Errorf("Push()=%v, want %v", got, want) } return nil }, Setting{SettingMaxConcurrentStreams, 0}) } func TestServer_Push_RejectWrongScheme(t *testing.T) { testServer_Push_RejectSingleRequest(t, func(p http.Pusher, r *http.Request) error { if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil { return errors.New("Push() should have failed (push target URL is http)") } return nil }) } func TestServer_Push_RejectMissingHost(t *testing.T) { testServer_Push_RejectSingleRequest(t, func(p http.Pusher, r *http.Request) error { if err := p.Push("https:pushed", nil); err == nil { return errors.New("Push() should have failed (push target URL missing host)") } return nil }) } func TestServer_Push_RejectRelativePath(t *testing.T) { testServer_Push_RejectSingleRequest(t, func(p http.Pusher, r *http.Request) error { if err := p.Push("../test", nil); err == nil { return errors.New("Push() should have failed (push target is a relative path)") } return nil }) } func TestServer_Push_RejectForbiddenMethod(t *testing.T) { testServer_Push_RejectSingleRequest(t, func(p http.Pusher, r *http.Request) error { if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil { return errors.New("Push() should have failed (cannot promise a POST)") } return nil }) } func TestServer_Push_RejectForbiddenHeader(t *testing.T) { testServer_Push_RejectSingleRequest(t, func(p http.Pusher, r *http.Request) error { header := http.Header{ "Content-Length": {"10"}, "Content-Encoding": {"gzip"}, "Trailer": {"Foo"}, "Te": {"trailers"}, "Host": {"test.com"}, ":authority": {"test.com"}, } if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil { return errors.New("Push() should have failed (forbidden headers)") } return nil }) } func TestServer_Push_StateTransitions(t *testing.T) { const body = "foo" gotPromise := make(chan bool) finishedPush := make(chan bool) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { switch r.URL.RequestURI() { case "/": if err := w.(http.Pusher).Push("/pushed", nil); err != nil { t.Errorf("Push error: %v", err) } // Don't finish this request until the push finishes so we don't // nondeterministically interleave output frames with the push. <-finishedPush case "/pushed": <-gotPromise } w.Header().Set("Content-Type", "text/html") w.Header().Set("Content-Length", strconv.Itoa(len(body))) w.WriteHeader(200) io.WriteString(w, body) }) defer st.Close() st.greet() if st.stream(2) != nil { t.Fatal("stream 2 should be empty") } if got, want := st.streamState(2), stateIdle; got != want { t.Fatalf("streamState(2)=%v, want %v", got, want) } getSlash(st) // After the PUSH_PROMISE is sent, the stream should be stateHalfClosedRemote. st.wantPushPromise() if got, want := st.streamState(2), stateHalfClosedRemote; got != want { t.Fatalf("streamState(2)=%v, want %v", got, want) } // We stall the HTTP handler for "/pushed" until the above check. If we don't // stall the handler, then the handler might write HEADERS and DATA and finish // the stream before we check st.streamState(2) -- should that happen, we'll // see stateClosed and fail the above check. close(gotPromise) st.wantHeaders() if df := st.wantData(); !df.StreamEnded() { t.Fatal("expected END_STREAM flag on DATA") } if got, want := st.streamState(2), stateClosed; got != want { t.Fatalf("streamState(2)=%v, want %v", got, want) } close(finishedPush) } func TestServer_Push_RejectAfterGoAway(t *testing.T) { var readyOnce sync.Once ready := make(chan struct{}) errc := make(chan error, 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { select { case <-ready: case <-time.After(5 * time.Second): errc <- fmt.Errorf("timeout waiting for GOAWAY to be processed") } if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want { errc <- fmt.Errorf("Push()=%v, want %v", got, want) } errc <- nil }) defer st.Close() st.greet() getSlash(st) // Send GOAWAY and wait for it to be processed. st.fr.WriteGoAway(1, ErrCodeNo, nil) go func() { for { select { case <-ready: return default: } st.sc.serveMsgCh <- func(loopNum int) { if !st.sc.pushEnabled { readyOnce.Do(func() { close(ready) }) } } } }() if err := <-errc; err != nil { t.Error(err) } }