Newer
Older
pokemon-go-trade / vendor / golang.org / x / net / websocket / websocket_test.go
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package websocket

import (
	"bytes"
	"crypto/rand"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"reflect"
	"runtime"
	"strings"
	"sync"
	"testing"
	"time"
)

// Shared state for the test server used by the tests in this file.
var (
	// serverAddr is the listener address of the shared test server; it is
	// assigned by startServer.
	serverAddr string
	// once is used as once.Do(startServer) so the server is started a
	// single time no matter which test runs first.
	once sync.Once
)

// echoServer is a websocket handler that writes everything it reads from the
// connection straight back to it, closing the connection when copying ends.
func echoServer(ws *Conn) {
	defer ws.Close()
	// The copy result is not inspected; the handler just returns once
	// io.Copy stops.
	_, _ = io.Copy(ws, ws)
}

// Count is the message exchanged with countServer through the JSON codec.
type Count struct {
	// S is the text payload; countServer replaces it with S repeated N times.
	S string
	// N is the counter; countServer increments it for every message received.
	N int
}

// countServer is a websocket handler that answers each JSON-encoded Count it
// receives: it bumps N by one, sets S to the received S repeated N times and
// sends the result back. It returns (closing the connection) on the first
// receive or send error.
func countServer(ws *Conn) {
	defer ws.Close()
	for {
		var msg Count
		if err := JSON.Receive(ws, &msg); err != nil {
			return
		}
		msg.N++
		msg.S = strings.Repeat(msg.S, msg.N)
		if err := JSON.Send(ws, msg); err != nil {
			return
		}
	}
}

// testCtrlAndDataHandler embeds hybiFrameHandler and adds WritePing, so the
// tests in this file can interleave ping control frames with data frames.
type testCtrlAndDataHandler struct {
	hybiFrameHandler
}

// WritePing sends b as the payload of a single ping frame while holding the
// connection's write lock. It returns the byte count and error reported by
// the frame writer's Write, or (0, err) if no frame writer could be created.
func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) {
	conn := h.hybiFrameHandler.conn
	conn.wio.Lock()
	defer conn.wio.Unlock()

	fw, err := conn.frameWriterFactory.NewFrameWriter(PingFrame)
	if err != nil {
		return 0, err
	}
	// Deferred calls run last-in-first-out, so the frame writer is closed
	// before the write lock is released.
	defer fw.Close()
	return fw.Write(b)
}

// ctrlAndDataServer is a websocket handler that echoes data frames back to the
// client while a background goroutine keeps sending ping and unsolicited pong
// control frames on the same connection.
func ctrlAndDataServer(ws *Conn) {
	defer ws.Close()
	handler := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
	ws.frameHandler = handler

	// Control-frame sender: alternates between empty and non-empty payloads
	// and stops at the first write error.
	go func() {
		for seq := 0; ; seq++ {
			var payload []byte
			if seq%2 == 1 { // with or without payload
				payload = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", seq))
			}
			_, err := handler.WritePing(payload)
			if err == nil {
				_, err = handler.WritePong(payload) // unsolicited pong
			}
			if err != nil {
				return
			}
			time.Sleep(10 * time.Millisecond)
		}
	}()

	// Data echo loop: runs until a read or write fails.
	buf := make([]byte, 128)
	for {
		n, err := ws.Read(buf)
		if err != nil {
			return
		}
		if _, err = ws.Write(buf[:n]); err != nil {
			return
		}
	}
}

// subProtocolHandshake is a server handshake hook that accepts a request only
// if "chat" is among the offered subprotocols. On success it narrows
// config.Protocol to just "chat"; otherwise it returns ErrBadWebSocketProtocol.
func subProtocolHandshake(config *Config, req *http.Request) error {
	for i := range config.Protocol {
		if name := config.Protocol[i]; name == "chat" {
			config.Protocol = []string{name}
			return nil
		}
	}
	return ErrBadWebSocketProtocol
}

// subProtoServer is a websocket handler that writes each subprotocol listed in
// the connection's config back to the client, one write per entry.
func subProtoServer(ws *Conn) {
	protocols := ws.Config().Protocol
	for i := 0; i < len(protocols); i++ {
		io.WriteString(ws, protocols[i])
	}
}

func startServer() {
	http.Handle("/echo", Handler(echoServer))
	http.Handle("/count", Handler(countServer))
	http.Handle("/ctrldata", Handler(ctrlAndDataServer))
	subproto := Server{
		Handshake: subProtocolHandshake,
		Handler:   Handler(subProtoServer),
	}
	http.Handle("/subproto", subproto)
	server := httptest.NewServer(nil)
	serverAddr = server.Listener.Addr().String()
	log.Print("Test WebSocket server listening on ", serverAddr)
}

// newConfig builds a client Config targeting ws://<serverAddr><path> with the
// origin "http://localhost". It fails the test immediately if the config
// cannot be constructed, rather than handing a nil *Config to the caller.
func newConfig(t *testing.T, path string) *Config {
	config, err := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
	if err != nil {
		t.Fatalf("NewConfig for path %q: %v", path, err)
	}
	return config
}

// TestEcho sends one message to the /echo endpoint and checks that the same
// bytes come back.
func TestEcho(t *testing.T) {
	once.Do(startServer)

	// Stand-in for websocket.Dial(): raw TCP connect, then client handshake.
	tcp, err := net.Dial("tcp", serverAddr)
	if err != nil {
		t.Fatal("dialing", err)
	}
	conn, err := NewClient(newConfig(t, "/echo"), tcp)
	if err != nil {
		t.Errorf("WebSocket handshake error: %v", err)
		return
	}
	defer conn.Close()

	want := []byte("hello, world\n")
	if _, err := conn.Write(want); err != nil {
		t.Errorf("Write: %v", err)
	}

	buf := make([]byte, 512)
	n, err := conn.Read(buf)
	if err != nil {
		t.Errorf("Read: %v", err)
	}
	if got := buf[:n]; !bytes.Equal(want, got) {
		t.Errorf("Echo: expected %q got %q", want, got)
	}
}

// TestAddr checks the string forms of a client connection's addresses: the
// remote address must look like ws://.../echo and the local address must
// start with http://.
func TestAddr(t *testing.T) {
	once.Do(startServer)

	// Stand-in for websocket.Dial(): raw TCP connect, then client handshake.
	tcp, err := net.Dial("tcp", serverAddr)
	if err != nil {
		t.Fatal("dialing", err)
	}
	conn, err := NewClient(newConfig(t, "/echo"), tcp)
	if err != nil {
		t.Errorf("WebSocket handshake error: %v", err)
		return
	}
	defer conn.Close()

	remote := conn.RemoteAddr().String()
	if !(strings.HasPrefix(remote, "ws://") && strings.HasSuffix(remote, "/echo")) {
		t.Errorf("Bad remote addr: %v", remote)
	}
	if local := conn.LocalAddr().String(); !strings.HasPrefix(local, "http://") {
		t.Errorf("Bad local addr: %v", local)
	}
}

// TestCount exchanges a Count with the /count endpoint twice over the JSON
// codec, feeding each reply back as the next request, and checks the counter
// and string returned after every round trip.
func TestCount(t *testing.T) {
	once.Do(startServer)

	// Stand-in for websocket.Dial(): raw TCP connect, then client handshake.
	tcp, err := net.Dial("tcp", serverAddr)
	if err != nil {
		t.Fatal("dialing", err)
	}
	conn, err := NewClient(newConfig(t, "/count"), tcp)
	if err != nil {
		t.Errorf("WebSocket handshake error: %v", err)
		return
	}
	defer conn.Close()

	// Expected reply after each successive round trip.
	rounds := []Count{
		{S: "hello", N: 1},
		{S: "hellohello", N: 2},
	}

	var count Count
	count.S = "hello"
	for _, want := range rounds {
		if err := JSON.Send(conn, count); err != nil {
			t.Errorf("Write: %v", err)
		}
		if err := JSON.Receive(conn, &count); err != nil {
			t.Errorf("Read: %v", err)
		}
		if count.N != want.N {
			t.Errorf("count: expected %d got %d", want.N, count.N)
		}
		if count.S != want.S {
			t.Errorf("count: expected %q got %q", want.S, count.S)
		}
	}
}

// TestWithQuery verifies that the client handshake succeeds when the
// location URL carries a query string.
func TestWithQuery(t *testing.T) {
	once.Do(startServer)

	tcp, err := net.Dial("tcp", serverAddr)
	if err != nil {
		t.Fatal("dialing", err)
	}

	cfg := newConfig(t, "/echo")
	location, err := url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
	if err != nil {
		t.Fatal("location url", err)
	}
	cfg.Location = location

	conn, err := NewClient(cfg, tcp)
	if err != nil {
		t.Errorf("WebSocket handshake: %v", err)
		return
	}
	conn.Close()
}

// testWithProtocol connects to the /subproto endpoint offering the given
// subprotocols and returns the first message the server sends back (the
// subprotocol it selected).
//
// A handshake or read failure is returned to the caller instead of failing the
// test, so callers can assert on expected errors. The connection is released
// on every path, including the error ones.
func testWithProtocol(t *testing.T, subproto []string) (string, error) {
	once.Do(startServer)

	client, err := net.Dial("tcp", serverAddr)
	if err != nil {
		t.Fatal("dialing", err)
	}

	config := newConfig(t, "/subproto")
	config.Protocol = subproto

	ws, err := NewClient(config, client)
	if err != nil {
		// No websocket Conn was handed back to take ownership of the raw
		// TCP connection, so close it here to avoid leaking it.
		client.Close()
		return "", err
	}
	defer ws.Close()

	msg := make([]byte, 16)
	n, err := ws.Read(msg)
	if err != nil {
		return "", err
	}
	return string(msg[:n]), nil
}

// TestWithProtocol offers only "chat" and expects the server to echo it back
// as the selected subprotocol.
func TestWithProtocol(t *testing.T) {
	const want = "chat"
	got, err := testWithProtocol(t, []string{"chat"})
	if err != nil {
		t.Errorf("SubProto: unexpected error: %v", err)
	}
	if got != want {
		t.Errorf("SubProto: expected %q, got %q", want, got)
	}
}

// TestWithTwoProtocol offers "test" and "chat" and expects the server to pick
// "chat".
func TestWithTwoProtocol(t *testing.T) {
	const want = "chat"
	offered := []string{"test", "chat"}
	got, err := testWithProtocol(t, offered)
	if err != nil {
		t.Errorf("SubProto: unexpected error: %v", err)
	}
	if got != want {
		t.Errorf("SubProto: expected %q, got %q", want, got)
	}
}

// TestWithBadProtocol offers only an unsupported subprotocol and expects the
// client handshake to fail with ErrBadStatus.
func TestWithBadProtocol(t *testing.T) {
	if _, err := testWithProtocol(t, []string{"test"}); err != ErrBadStatus {
		t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
	}
}

// TestHTTP issues a plain HTTP GET (no websocket upgrade) against the /echo
// endpoint and expects a 400 Bad Request response.
func TestHTTP(t *testing.T) {
	once.Do(startServer)

	// If the client did not send a handshake that matches the protocol
	// specification, the server MUST return an HTTP response with an
	// appropriate error code (such as 400 Bad Request)
	resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
	if err != nil {
		t.Errorf("Get: error %#v", err)
		return
	}
	// http.Get returns a non-nil response with a non-nil Body whenever err is
	// nil, so no nil check is needed; the body must be closed by the caller.
	defer resp.Body.Close()

	// Status codes are integers: print them with %d (the %q verb would
	// render them as quoted character literals).
	if resp.StatusCode != http.StatusBadRequest {
		t.Errorf("Get: expected %d got %d", http.StatusBadRequest, resp.StatusCode)
	}
}

// TestTrailingSpaces dials the echo endpoint 30 times in a row, stopping at
// the first failure.
//
// See http://code.google.com/p/go/issues/detail?id=955: the last runs of this
// create keys with trailing spaces that should not be generated by the client.
func TestTrailingSpaces(t *testing.T) {
	once.Do(startServer)
	cfg := newConfig(t, "/echo")
	for attempt := 0; attempt < 30; attempt++ {
		conn, err := DialConfig(cfg)
		if err != nil {
			t.Errorf("Dial #%d failed: %v", attempt, err)
			return
		}
		conn.Close()
	}
}

// TestDialConfigBadVersion dials with an unsupported protocol version and
// expects DialConfig to fail with a *DialError wrapping ErrBadProtocolVersion.
func TestDialConfigBadVersion(t *testing.T) {
	once.Do(startServer)
	config := newConfig(t, "/echo")
	config.Version = 1234

	_, err := DialConfig(config)

	// A nil error or an error of another type must fail the test; silently
	// skipping the check would let a regression pass unnoticed.
	dialerr, ok := err.(*DialError)
	if !ok {
		t.Fatalf("DialError expected, got %#v", err)
	}
	if dialerr.Err != ErrBadProtocolVersion {
		t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
	}
}

// TestDialConfigWithDialer installs a custom net.Dialer whose deadline is
// already in the past and expects DialConfig to fail with a *DialError that
// wraps a *net.OpError reporting a timeout.
func TestDialConfigWithDialer(t *testing.T) {
	once.Do(startServer)

	cfg := newConfig(t, "/echo")
	expired := time.Now().Add(-time.Minute)
	cfg.Dialer = &net.Dialer{Deadline: expired}

	_, err := DialConfig(cfg)

	dialErr, isDialErr := err.(*DialError)
	if !isDialErr {
		t.Fatalf("DialError expected, got %#v", err)
	}
	opErr, isOpErr := dialErr.Err.(*net.OpError)
	if !isOpErr {
		t.Fatalf("net.OpError error expected, got %#v", dialErr.Err)
	}
	if !opErr.Timeout() {
		t.Fatalf("expected timeout error, got %#v", opErr)
	}
}

// TestSmallBuffer echoes one message and reads it back in two steps: first
// into a buffer smaller than the message, then the remainder.
//
// See http://code.google.com/p/go/issues/detail?id=1145: Read should be able
// to handle reading a fragment of a frame.
func TestSmallBuffer(t *testing.T) {
	once.Do(startServer)

	// Stand-in for websocket.Dial(): raw TCP connect, then client handshake.
	tcp, err := net.Dial("tcp", serverAddr)
	if err != nil {
		t.Fatal("dialing", err)
	}
	conn, err := NewClient(newConfig(t, "/echo"), tcp)
	if err != nil {
		t.Errorf("WebSocket handshake error: %v", err)
		return
	}
	defer conn.Close()

	want := []byte("hello, world\n")
	if _, err := conn.Write(want); err != nil {
		t.Errorf("Write: %v", err)
	}

	// First read: the buffer only has room for the leading 8 bytes.
	head := make([]byte, 8)
	if _, err = conn.Read(head); err != nil {
		t.Errorf("Read: %v", err)
	}
	if wantHead := want[:len(head)]; !bytes.Equal(wantHead, head) {
		t.Errorf("Echo: expected %q got %q", wantHead, head)
	}

	// Second read: should deliver whatever is left of the message.
	tail := make([]byte, len(want))
	n, err := conn.Read(tail)
	if err != nil {
		t.Errorf("Read: %v", err)
	}
	tail = tail[:n]
	if wantTail := want[len(head):]; !bytes.Equal(wantTail, tail) {
		t.Errorf("Echo: expected %q got %q", wantTail, tail)
	}
}

// parseAuthorityTests pairs input URLs with the authority string that
// TestParseAuthority expects parseAuthority to return for them.
var parseAuthorityTests = []struct {
	in  *url.URL
	out string
}{
	// ws/wss without an explicit port: a default port is expected.
	{in: &url.URL{Scheme: "ws", Host: "www.google.com"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "wss", Host: "www.google.com"}, out: "www.google.com:443"},
	// ws/wss with an explicit port: the host is expected unchanged.
	{in: &url.URL{Scheme: "ws", Host: "www.google.com:80"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "wss", Host: "www.google.com:443"}, out: "www.google.com:443"},
	// some invalid ones for parseAuthority. parseAuthority doesn't
	// concern itself with the scheme unless it actually knows about it
	{in: &url.URL{Scheme: "http", Host: "www.google.com"}, out: "www.google.com"},
	{in: &url.URL{Scheme: "http", Host: "www.google.com:80"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "asdf", Host: "127.0.0.1"}, out: "127.0.0.1"},
	{in: &url.URL{Scheme: "asdf", Host: "www.google.com"}, out: "www.google.com"},
}

// TestParseAuthority runs parseAuthority over every entry of
// parseAuthorityTests and reports each mismatch.
func TestParseAuthority(t *testing.T) {
	for i := range parseAuthorityTests {
		tc := &parseAuthorityTests[i]
		if got := parseAuthority(tc.in); got != tc.out {
			t.Errorf("got %v; want %v", got, tc.out)
		}
	}
}

// closerConn wraps a net.Conn and records how often Close is invoked on it.
type closerConn struct {
	net.Conn
	// closed is the number of Close calls seen so far.
	closed int
}

// Close bumps the call counter and then closes the wrapped connection,
// returning whatever error that close reports.
func (c *closerConn) Close() error {
	c.closed += 1
	return c.Conn.Close()
}

// TestClose verifies that closing a websocket connection whose deadline has
// already expired returns an error and still closes the underlying
// connection at least once.
func TestClose(t *testing.T) {
	if runtime.GOOS == "plan9" {
		t.Skip("see golang.org/issue/11454")
	}

	once.Do(startServer)

	raw, err := net.Dial("tcp", serverAddr)
	if err != nil {
		t.Fatal("dialing", err)
	}

	// Wrap the raw connection so Close calls can be counted.
	counting := &closerConn{Conn: raw}

	client, err := NewClient(newConfig(t, "/echo"), counting)
	if err != nil {
		t.Fatalf("WebSocket handshake: %v", err)
	}

	// set the deadline to ten minutes ago, which will have expired by the time
	// client.Close sends the close status frame.
	expired := time.Now().Add(-10 * time.Minute)
	raw.SetDeadline(expired)

	err = client.Close()
	if err == nil {
		t.Errorf("ws.Close(): expected error, got %v", err)
	}
	if counting.closed <= 0 {
		t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", counting.closed)
	}
}

// originTests pairs HTTP requests with the origin URL that TestOrigin expects
// Origin to return for them.
var originTests = []struct {
	req    *http.Request
	origin *url.URL
}{
	// Request carrying an Origin header.
	{
		req:    &http.Request{Header: http.Header{"Origin": {"http://www.example.com"}}},
		origin: &url.URL{Scheme: "http", Host: "www.example.com"},
	},
	// Request without any headers: a nil origin is expected.
	{
		req:    new(http.Request),
		origin: nil,
	},
}

// TestOrigin calls Origin with a hybi-13 config for every request in
// originTests and compares the result against the expected origin URL.
func TestOrigin(t *testing.T) {
	conf := newConfig(t, "/echo")
	conf.Version = ProtocolVersionHybi13
	for i, tt := range originTests {
		switch got, err := Origin(conf, tt.req); {
		case err != nil:
			t.Error(err)
		case !reflect.DeepEqual(got, tt.origin):
			t.Errorf("#%d: got origin %v; want %v", i, got, tt.origin)
		}
	}
}

// TestCtrlAndData talks to the /ctrldata endpoint, alternating data frames
// with ping frames from the client side, and checks that each data frame is
// echoed back intact while the server sends its own control frames.
func TestCtrlAndData(t *testing.T) {
	once.Do(startServer)

	tcp, err := net.Dial("tcp", serverAddr)
	if err != nil {
		t.Fatal(err)
	}
	ws, err := NewClient(newConfig(t, "/ctrldata"), tcp)
	if err != nil {
		t.Fatal(err)
	}
	defer ws.Close()

	// Swap in the test frame handler so the client can send pings as well.
	handler := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
	ws.frameHandler = handler

	buf := make([]byte, 128)
	for round := 0; round < 2; round++ {
		data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", round))
		if _, err := ws.Write(data); err != nil {
			t.Fatalf("#%d: %v", round, err)
		}

		var ping []byte
		if round%2 == 1 { // with or without payload
			ping = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", round))
		}
		if _, err := handler.WritePing(ping); err != nil {
			t.Fatalf("#%d: %v", round, err)
		}

		n, err := ws.Read(buf)
		if err != nil {
			t.Fatalf("#%d: %v", round, err)
		}
		if got := buf[:n]; !bytes.Equal(got, data) {
			t.Fatalf("#%d: got %v; want %v", round, got, data)
		}
	}
}

// TestCodec_ReceiveLimited checks that Message.Receive honours
// Conn.MaxPayloadBytes: payloads within the limit arrive intact, a payload
// over the limit yields ErrFrameTooLarge, and the following receive still
// works.
//
// The server-side handler runs on its own goroutine, so it reports failures
// with t.Errorf and returns; t.Fatalf (FailNow) may only be called from the
// goroutine running the test function.
func TestCodec_ReceiveLimited(t *testing.T) {
	const limit = 2048
	var payloads [][]byte
	for _, size := range []int{
		1024,
		2048,
		4096, // receive of this message would be interrupted due to limit
		2048, // this one is to make sure next receive recovers discarding leftovers
	} {
		b := make([]byte, size)
		if _, err := rand.Read(b); err != nil {
			t.Fatalf("generating %d random bytes: %v", size, err)
		}
		payloads = append(payloads, b)
	}
	handlerDone := make(chan struct{})
	limitedHandler := func(ws *Conn) {
		// Signal completion on every exit path so the test never blocks on
		// handlerDone.
		defer close(handlerDone)
		ws.MaxPayloadBytes = limit
		defer ws.Close()
		for i, p := range payloads {
			t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit)
			var recv []byte
			err := Message.Receive(ws, &recv)
			switch err {
			case nil:
			case ErrFrameTooLarge:
				if len(p) <= limit {
					t.Errorf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit)
					return
				}
				// Oversized payload rejected as expected; move on to the next.
				continue
			default:
				t.Errorf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err)
				return
			}
			if len(recv) > limit {
				t.Errorf("received %d bytes of payload having limit at %d", len(recv), limit)
				return
			}
			if !bytes.Equal(p, recv) {
				t.Errorf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p)
				return
			}
		}
	}
	server := httptest.NewServer(Handler(limitedHandler))
	defer server.CloseClientConnections()
	defer server.Close()
	addr := server.Listener.Addr().String()
	ws, err := Dial("ws://"+addr+"/", "", "http://localhost/")
	if err != nil {
		t.Fatal(err)
	}
	defer ws.Close()
	for i, p := range payloads {
		if err := Message.Send(ws, p); err != nil {
			t.Fatalf("payload #%d (size %d): %v", i, len(p), err)
		}
	}
	// Wait for the handler so its assertions are recorded before the test ends.
	<-handlerDone
}