Newer
Older
pokemon-go-trade / vendor / golang.org / x / crypto / openpgp / packet / packet_test.go
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package packet

import (
	"bytes"
	"encoding/hex"
	"fmt"
	"golang.org/x/crypto/openpgp/errors"
	"io"
	"io/ioutil"
	"testing"
)

// TestReadFull checks readFull on three inputs: a read that exactly fills the
// destination, a read that runs out of data part-way, and a read from an
// empty source. The last two must report io.ErrUnexpectedEOF.
func TestReadFull(t *testing.T) {
	var out [4]byte

	cases := []struct {
		name    string
		input   string
		size    int
		wantN   int
		wantErr error
	}{
		{"full", "foo", 3, 3, nil},
		{"partial", "foo", 4, 3, io.ErrUnexpectedEOF},
		{"empty", "", 3, 0, io.ErrUnexpectedEOF},
	}

	for _, c := range cases {
		n, err := readFull(bytes.NewBufferString(c.input), out[:c.size])
		if n != c.wantN || err != c.wantErr {
			t.Errorf("%s read failed n:%d err:%s", c.name, n, err)
		}
	}
}

// readerFromHex returns a reader over the bytes encoded by the hex string s.
// It panics if s is not valid hex. The panic message includes the offending
// input and the decoder error so that a malformed test vector is easy to find.
func readerFromHex(s string) io.Reader {
	data, err := hex.DecodeString(s)
	if err != nil {
		panic(fmt.Sprintf("readerFromHex: bad input %q: %v", s, err))
	}
	return bytes.NewBuffer(data)
}

// readLengthTests pairs hex-encoded length fields with the length and partial
// flag readLength should return. When err is non-nil, readLength must instead
// fail with exactly that error.
var readLengthTests = []struct {
	hexInput  string
	length    int64
	isPartial bool
	err       error
}{
	// Missing or truncated input.
	{hexInput: "", err: io.ErrUnexpectedEOF},

	// Single-byte length.
	{hexInput: "1f", length: 31},

	// Two-byte length; a lone first byte is truncated.
	{hexInput: "c0", err: io.ErrUnexpectedEOF},
	{hexInput: "c101", length: 256 + 1 + 192},

	// Partial lengths: the decoded length is 1 << (b & 0x1f).
	{hexInput: "e0", length: 1, isPartial: true},
	{hexInput: "e1", length: 2, isPartial: true},
	{hexInput: "e2", length: 4, isPartial: true},

	// 0xff followed by a 4-byte big-endian length; fewer bytes is truncated.
	{hexInput: "ff", err: io.ErrUnexpectedEOF},
	{hexInput: "ff00", err: io.ErrUnexpectedEOF},
	{hexInput: "ff0000", err: io.ErrUnexpectedEOF},
	{hexInput: "ff000000", err: io.ErrUnexpectedEOF},
	{hexInput: "ff00000000"},
	{hexInput: "ff01020304", length: 16909060},
}

// TestReadLength runs readLength over every readLengthTests entry. It checks
// for the exact expected error when one is listed, and otherwise checks the
// decoded length and partial flag.
func TestReadLength(t *testing.T) {
	for idx, tc := range readLengthTests {
		gotLen, gotPartial, gotErr := readLength(readerFromHex(tc.hexInput))

		switch {
		case tc.err != nil:
			if gotErr != tc.err {
				t.Errorf("%d: expected different error got:%s want:%s", idx, gotErr, tc.err)
			}
		case gotErr != nil:
			t.Errorf("%d: unexpected error: %s", idx, gotErr)
		case gotLen != tc.length || gotPartial != tc.isPartial:
			t.Errorf("%d: bad result got:(%d,%t) want:(%d,%t)", idx, gotLen, gotPartial, tc.length, tc.isPartial)
		}
	}
}

// partialLengthReaderTests feeds hex-encoded partial-length streams to a
// partialLengthReader. When err is non-nil, reading the whole stream must fail
// with exactly that error. Otherwise the bytes read must hex-encode to
// hexOutput.
var partialLengthReaderTests = []struct {
	hexInput  string
	err       error
	hexOutput string
}{
	{hexInput: "e0", err: io.ErrUnexpectedEOF},
	{hexInput: "e001", err: io.ErrUnexpectedEOF},
	{hexInput: "e0010102", hexOutput: "0102"},
	{hexInput: "ff00000000"},
	{hexInput: "e10102e1030400", hexOutput: "01020304"},
	{hexInput: "e101", err: io.ErrUnexpectedEOF},
}

// TestPartialLengthReader reads each partialLengthReaderTests input to the end
// through a partialLengthReader. It checks that the read fails with the exact
// expected error or, when no error is expected, that the bytes read hex-encode
// to the expected output.
func TestPartialLengthReader(t *testing.T) {
	for i, test := range partialLengthReaderTests {
		r := &partialLengthReader{readerFromHex(test.hexInput), 0, true}
		out, err := ioutil.ReadAll(r)
		if test.err != nil {
			if err != test.err {
				// %v rather than %s so a nil error prints as <nil>.
				t.Errorf("%d: expected different error got:%v want:%v", i, err, test.err)
			}
			continue
		}
		if err != nil {
			t.Errorf("%d: unexpected error: %v", i, err)
			continue
		}

		// The actual output goes in the "got" slot and the table value in
		// "want", so failure messages read correctly.
		got := fmt.Sprintf("%x", out)
		if got != test.hexOutput {
			t.Errorf("%d: got:%s want:%s", i, got, test.hexOutput)
		}
	}
}

// readHeaderTests describes inputs to readHeader and the expected outcome:
//   - structuralError: readHeader must return an errors.StructuralError.
//   - unexpectedEOF: reading the header or its contents is expected to stop
//     with io.ErrUnexpectedEOF.
//   - tag, length, hexOutput: otherwise, the decoded tag and length and the
//     hex encoding of the packet contents. A length of -1 marks the entry whose
//     contents run to the end of the input.
var readHeaderTests = []struct {
	hexInput        string
	structuralError bool
	unexpectedEOF   bool
	tag             int
	length          int64
	hexOutput       string
}{
	// Empty input, and a header byte readHeader must reject as structurally
	// invalid.
	{hexInput: ""},
	{hexInput: "7f", structuralError: true},

	// Old format headers
	{hexInput: "80", unexpectedEOF: true},
	{hexInput: "8001", unexpectedEOF: true, length: 1},
	{hexInput: "800102", length: 1, hexOutput: "02"},
	{hexInput: "81000102", length: 1, hexOutput: "02"},
	{hexInput: "820000000102", length: 1, hexOutput: "02"},
	{hexInput: "860000000102", tag: 1, length: 1, hexOutput: "02"},
	{hexInput: "83010203", length: -1, hexOutput: "010203"},

	// New format headers
	{hexInput: "c0", unexpectedEOF: true},
	{hexInput: "c000"},
	{hexInput: "c00102", length: 1, hexOutput: "02"},
	{hexInput: "c0020203", length: 2, hexOutput: "0203"},
	{hexInput: "c00202", unexpectedEOF: true, length: 2},
	{hexInput: "c3020203", tag: 3, length: 2, hexOutput: "0203"},
}

// TestReadHeader parses each readHeaderTests input with readHeader and checks,
// in order: an expected StructuralError, the tolerated EOF conditions, the
// decoded tag and length, and finally the packet contents.
func TestReadHeader(t *testing.T) {
	for idx, tc := range readHeaderTests {
		tag, length, contents, err := readHeader(readerFromHex(tc.hexInput))

		if tc.structuralError {
			if _, ok := err.(errors.StructuralError); !ok {
				t.Errorf("%d: expected StructuralError, got:%s", idx, err)
			}
			continue
		}

		if err != nil {
			// Empty input may end in a clean io.EOF, and entries flagged
			// unexpectedEOF may stop with io.ErrUnexpectedEOF. Any other
			// error is a failure.
			cleanEOF := tc.hexInput == "" && err == io.EOF
			truncated := tc.unexpectedEOF && err == io.ErrUnexpectedEOF
			if !cleanEOF && !truncated {
				t.Errorf("%d: unexpected error from readHeader: %s", idx, err)
			}
			continue
		}

		if int(tag) != tc.tag || length != tc.length {
			t.Errorf("%d: got:(%d,%d) want:(%d,%d)", idx, int(tag), length, tc.tag, tc.length)
			continue
		}

		body, readErr := ioutil.ReadAll(contents)
		switch {
		case readErr != nil:
			if !tc.unexpectedEOF || readErr != io.ErrUnexpectedEOF {
				t.Errorf("%d: unexpected error from contents: %s", idx, readErr)
			}
		case tc.unexpectedEOF:
			t.Errorf("%d: expected ErrUnexpectedEOF from contents but got no error", idx)
		default:
			if got := fmt.Sprintf("%x", body); got != tc.hexOutput {
				t.Errorf("%d: got:%s want:%s", idx, got, tc.hexOutput)
			}
		}
	}
}

// TestSerializeHeader checks that a header written by serializeHeader reads
// back through readHeader with the same tag and body length. The lengths
// include values on both sides of the one-, two- and five-octet length
// encoding boundaries (191/192 and 8383/8384).
func TestSerializeHeader(t *testing.T) {
	tag := packetTypePublicKey
	lengths := []int{0, 1, 2, 64, 191, 192, 193, 8000, 8383, 8384, 8385, 10000}

	for _, length := range lengths {
		buf := bytes.NewBuffer(nil)
		if err := serializeHeader(buf, tag, length); err != nil {
			t.Errorf("length %d, serializeHeader err: %s", length, err)
			continue
		}
		tag2, length2, _, err := readHeader(buf)
		if err != nil {
			// tag2/length2 are meaningless after a read error; don't
			// report spurious mismatches on top of it.
			t.Errorf("length %d, err: %s", length, err)
			continue
		}
		if tag2 != tag {
			t.Errorf("length %d, tag incorrect (got %d, want %d)", length, tag2, tag)
		}
		if int(length2) != length {
			t.Errorf("length %d, length incorrect (got %d)", length, length2)
		}
	}
}

// TestPartialLengths writes maxChunkSize chunks of increasing size (1, 2, ...,
// maxChunkSize bytes) through a partialLengthWriter. It then checks that the
// first emitted length byte announces at least 512 bytes, and that reading the
// stream back through a partialLengthReader reproduces the written byte
// pattern exactly.
func TestPartialLengths(t *testing.T) {
	buf := bytes.NewBuffer(nil)
	w := new(partialLengthWriter)
	w.w = noOpCloser{buf}

	const maxChunkSize = 64

	var b [maxChunkSize]byte
	// n wraps at 256 on purpose: the stream is the repeating pattern
	// 0, 1, ..., 255, 0, 1, ... that the final check compares against.
	var n uint8
	for l := 1; l <= maxChunkSize; l++ {
		for i := 0; i < l; i++ {
			b[i] = n
			n++
		}
		m, err := w.Write(b[:l])
		if m != l {
			t.Errorf("short write got: %d want: %d", m, l)
		}
		if err != nil {
			t.Errorf("error from write: %s", err)
		}
	}
	if err := w.Close(); err != nil {
		t.Fatal(err)
	}

	// The first packet should be at least 512 bytes. The low five bits of
	// the partial length byte give the chunk size as a power of two.
	first, err := buf.ReadByte()
	if err != nil {
		t.Fatal(err)
	}
	if plen := 1 << (first & 0x1f); plen < 512 {
		t.Errorf("first packet too short: got %d want at least %d", plen, 512)
	}
	if err := buf.UnreadByte(); err != nil {
		t.Fatal(err)
	}

	// Total payload is 1 + 2 + ... + maxChunkSize bytes.
	want := (maxChunkSize * (maxChunkSize + 1)) / 2
	copyBuf := bytes.NewBuffer(nil)
	r := &partialLengthReader{buf, 0, true}
	m, err := io.Copy(copyBuf, r)
	if err != nil {
		t.Errorf("error from copy: %s", err)
	}
	// Stop on a length mismatch. The pattern check below indexes copyBytes
	// up to want and would panic with index out of range on a short copy.
	if m != int64(want) {
		t.Fatalf("short copy got: %d want: %d", m, want)
	}

	copyBytes := copyBuf.Bytes()
	for i := 0; i < want; i++ {
		if copyBytes[i] != uint8(i) {
			t.Errorf("bad pattern in copy at %d", i)
			break
		}
	}
}

// TestPartialLengthsShortWrite checks that a 510-byte payload round-trips
// unchanged through a partialLengthWriter and a partialLengthReader. The
// payload is just under 512 bytes, presumably to exercise the case where the
// writer cannot fill a first chunk of that size (TODO confirm against the
// writer implementation).
func TestPartialLengthsShortWrite(t *testing.T) {
	buf := bytes.NewBuffer(nil)
	w := &partialLengthWriter{
		w: noOpCloser{buf},
	}
	data := bytes.Repeat([]byte("a"), 510)
	n, err := w.Write(data)
	if err != nil {
		t.Fatal(err)
	}
	if n != len(data) {
		t.Fatalf("short write got: %d want: %d", n, len(data))
	}
	if err := w.Close(); err != nil {
		t.Fatal(err)
	}
	copyBuf := bytes.NewBuffer(nil)
	r := &partialLengthReader{buf, 0, true}
	if _, err := io.Copy(copyBuf, r); err != nil {
		t.Fatal(err)
	}
	// Report what was actually read back. buf has been drained by the
	// reader, so printing it would always show an empty slice.
	if got := copyBuf.Bytes(); !bytes.Equal(got, data) {
		t.Errorf("got %q want %q", got, data)
	}
}