package pg_test
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"math"
"net"
"reflect"
"testing"
"time"
"github.com/go-pg/pg"
"github.com/go-pg/pg/orm"
"github.com/go-pg/pg/types"
)
// JSONMap is a JSON object stored as a Go map. It implements sql.Scanner and
// driver.Valuer so it can be used as a custom column type in conversion tests.
type JSONMap map[string]interface{}

// Scan implements sql.Scanner. A nil source (SQL NULL) resets m to nil; a
// []byte or string source is decoded as JSON into m. Any other source type
// yields an error instead of panicking on a failed type assertion.
func (m *JSONMap) Scan(b interface{}) error {
	switch src := b.(type) {
	case nil:
		*m = nil
		return nil
	case []byte:
		return json.Unmarshal(src, m)
	case string:
		return json.Unmarshal([]byte(src), m)
	default:
		return fmt.Errorf("JSONMap.Scan: unsupported source type %T", b)
	}
}

// Value implements driver.Valuer. It returns the JSON encoding of m as a
// string, or the json.Marshal error.
func (m JSONMap) Value() (driver.Value, error) {
	b, err := json.Marshal(m)
	if err != nil {
		return nil, err
	}
	return string(b), nil
}
// Struct is a minimal struct fixture used by the "json" conversion cases in
// conversionTests (scanned from JSON objects such as {"foo": "bar"}).
type Struct struct {
	Foo string
}
// conversionTest describes one round-trip case: src is sent as a query
// parameter, the result column is scanned into dst, and Assert checks the
// outcome.
type conversionTest struct {
	// i is the case's index in the table; set by the test loop and printed by String.
	i int
	// src is the value sent; dst is the scan destination; wanted, when
	// non-nil, is the expected destination value instead of src.
	src, dst, wanted interface{}
	// pgtype is the SQL type used for the ($1::pgtype) cast in the
	// prepared-statement path; an empty value skips that path.
	pgtype string
	// wanterr, if set, is the exact error message the scan must produce.
	wanterr string
	// wantnil expects the destination to end up nil.
	wantnil bool
	// wantzero expects an empty non-nil slice/map or the type's zero value.
	wantzero bool
	// wantnothing only checks that no error occurred.
	wantnothing bool
}
// unwrap strips a *types.Array or *types.Hstore wrapper and returns the value
// it wraps; any other value is returned unchanged.
func unwrap(v interface{}) interface{} {
	switch wrapper := v.(type) {
	case *types.Array:
		return wrapper.Value()
	case *types.Hstore:
		return wrapper.Value()
	default:
		return v
	}
}
// deref follows pointers until it reaches a non-pointer value and returns
// that value. It returns nil for a nil interface or when a nil pointer is
// encountered along the chain.
func deref(vi interface{}) interface{} {
	for rv := reflect.ValueOf(vi); ; rv = rv.Elem() {
		if rv.Kind() == reflect.Ptr {
			continue
		}
		if !rv.IsValid() {
			return nil
		}
		return rv.Interface()
	}
}
// zero returns the zero value of the type that the pointer v points to.
// It panics if v is not a non-nil pointer.
func zero(v interface{}) interface{} {
	pointee := reflect.ValueOf(v).Elem()
	return reflect.New(pointee.Type()).Elem().Interface()
}
// String identifies the case by its table index together with its source and
// destination values, for use in failure messages.
func (test *conversionTest) String() string {
	const layout = "#%d src=%#v dst=%#v"
	return fmt.Sprintf(layout, test.i, test.src, test.dst)
}
// Assert checks the outcome of scanning test.src into test.dst, where err is
// the error returned by the query. Checks are applied in order:
//   - wanterr: err must be non-nil with exactly that message;
//   - otherwise err must be nil, and wantnothing stops here;
//   - wantnil: the unwrapped, indirected destination must be nil;
//   - wantzero: the destination must be an empty non-nil slice/map, or equal
//     to the zero value of the type test.dst points to;
//   - time.Time and []time.Time are compared at second precision (Unix());
//   - everything else is compared with reflect.DeepEqual.
//
// The expected value is test.wanted when set, otherwise the dereferenced src.
func (test *conversionTest) Assert(t *testing.T, err error) {
	t.Helper()

	if test.wanterr != "" {
		if err == nil || err.Error() != test.wanterr {
			t.Fatalf("got error %v, wanted %q (%s)", err, test.wanterr, test)
		}
		return
	}
	if err != nil {
		t.Fatalf("got error %q, wanted nil (%s)", err, test)
	}
	if test.wantnothing {
		return
	}

	dst := reflect.Indirect(reflect.ValueOf(unwrap(test.dst))).Interface()

	if test.wantnil {
		dstValue := reflect.ValueOf(dst)
		switch dstValue.Kind() {
		case reflect.Invalid:
			return
		case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
			// IsNil panics for other kinds, so only call it on nillable ones.
			if dstValue.IsNil() {
				return
			}
		}
		t.Fatalf("got %#v, wanted nil (%s)", dst, test)
	}

	// Remove any intermediate pointers to compare values.
	dst = deref(unwrap(dst))
	src := deref(unwrap(test.src))

	if test.wantzero {
		dstValue := reflect.ValueOf(dst)
		switch dstValue.Kind() {
		case reflect.Slice, reflect.Map:
			if dstValue.IsNil() {
				t.Fatalf("got nil, wanted zero value (%s)", test)
			}
			if n := dstValue.Len(); n != 0 {
				t.Fatalf("got %d items, wanted 0 (%s)", n, test)
			}
		default:
			// DeepEqual instead of != so non-comparable types do not panic.
			want := zero(test.dst)
			if !reflect.DeepEqual(dst, want) {
				t.Fatalf("%#v != %#v (%s)", dst, want, test)
			}
		}
		return
	}

	want := test.wanted
	if want == nil {
		want = src
	}

	if dstTime, isTime := dst.(time.Time); isTime {
		wantTime, wantIsTime := want.(time.Time)
		if !wantIsTime {
			t.Fatalf("got time %#v, wanted %#v (%s)", dstTime, want, test)
		}
		if dstTime.Unix() != wantTime.Unix() {
			t.Fatalf("%#v != %#v (%s)", dstTime, wantTime, test)
		}
		return
	}

	if dstTimes, isTimes := dst.([]time.Time); isTimes {
		wantTimes, wantIsTimes := want.([]time.Time)
		if !wantIsTimes {
			t.Fatalf("got times %#v, wanted %#v (%s)", dstTimes, want, test)
		}
		// Without this check a shorter dst silently passes and a longer one
		// panics with an index out of range.
		if len(dstTimes) != len(wantTimes) {
			t.Fatalf("got %d times, wanted %d (%s)", len(dstTimes), len(wantTimes), test)
		}
		for i := range dstTimes {
			if dstTimes[i].Unix() != wantTimes[i].Unix() {
				t.Fatalf("item %d: %#v != %#v (%s)", i, dstTimes[i], wantTimes[i], test)
			}
		}
		return
	}

	if !reflect.DeepEqual(dst, want) {
		t.Fatalf("%#v != %#v (%s)", dst, want, test)
	}
}
// conversionTests returns a fresh table of conversion cases. It is called
// once per query path so every run gets its own destination values.
func conversionTests() []conversionTest {
	return []conversionTest{
		{src: nil, dst: nil, wanterr: "pg: Scan(nil)"},
		{src: nil, dst: new(uintptr), wanterr: "pg: Scan(unsupported uintptr)"},

		{src: nil, dst: true, pgtype: "bool", wanterr: "pg: Scan(nonsettable bool)"},
		{src: nil, dst: new(*bool), pgtype: "bool", wantnil: true},
		{src: nil, dst: new(bool), pgtype: "bool", wantzero: true},
		{src: true, dst: new(bool), pgtype: "bool"},
		{src: true, dst: new(*bool), pgtype: "bool"},
		{src: 1, dst: new(bool), wanted: true},

		{src: nil, dst: "", pgtype: "text", wanterr: "pg: Scan(nonsettable string)"},
		{src: nil, dst: new(string), pgtype: "text", wantzero: true},
		{src: nil, dst: new(*string), pgtype: "text", wantnil: true},
		{src: "hello world", dst: new(string), pgtype: "text"},
		{src: "hello world", dst: new(*string), pgtype: "text"},
		{src: "'\"\000", dst: new(string), wanted: `'"`, pgtype: "text"},

		{src: nil, dst: []byte(nil), pgtype: "bytea", wanterr: "pg: Scan(nonsettable []uint8)"},
		{src: nil, dst: new([]byte), pgtype: "bytea", wantnil: true},
		{src: []byte("hello world\000"), dst: new([]byte), pgtype: "bytea"},
		{src: []byte{}, dst: new([]byte), pgtype: "bytea", wantzero: true},

		{src: nil, dst: int8(0), pgtype: "smallint", wanterr: "pg: Scan(nonsettable int8)"},
		{src: nil, dst: new(int8), pgtype: "smallint", wantzero: true},
		{src: int8(math.MaxInt8), dst: new(int8), pgtype: "smallint"},
		{src: int8(math.MaxInt8), dst: new(*int8), pgtype: "smallint"},
		{src: int8(math.MinInt8), dst: new(int8), pgtype: "smallint"},

		{src: nil, dst: int16(0), pgtype: "smallint", wanterr: "pg: Scan(nonsettable int16)"},
		{src: nil, dst: new(int16), pgtype: "smallint", wantzero: true},
		{src: int16(math.MaxInt16), dst: new(int16), pgtype: "smallint"},
		{src: int16(math.MaxInt16), dst: new(*int16), pgtype: "smallint"},
		{src: int16(math.MinInt16), dst: new(int16), pgtype: "smallint"},

		{src: nil, dst: int32(0), pgtype: "int", wanterr: "pg: Scan(nonsettable int32)"},
		{src: nil, dst: new(int32), pgtype: "int", wantzero: true},
		{src: int32(math.MaxInt32), dst: new(int32), pgtype: "int"},
		{src: int32(math.MaxInt32), dst: new(*int32), pgtype: "int"},
		{src: int32(math.MinInt32), dst: new(int32), pgtype: "int"},

		{src: nil, dst: int64(0), pgtype: "bigint", wanterr: "pg: Scan(nonsettable int64)"},
		{src: nil, dst: new(int64), pgtype: "bigint", wantzero: true},
		{src: int64(math.MaxInt64), dst: new(int64), pgtype: "bigint"},
		{src: int64(math.MaxInt64), dst: new(*int64), pgtype: "bigint"},
		{src: int64(math.MinInt64), dst: new(int64), pgtype: "bigint"},

		// Platform-independent int bounds (int may be 32 bits wide).
		{src: nil, dst: int(0), pgtype: "bigint", wanterr: "pg: Scan(nonsettable int)"},
		{src: nil, dst: new(int), pgtype: "bigint", wantzero: true},
		{src: int(math.MaxInt32), dst: new(int), pgtype: "bigint"},
		{src: int(math.MaxInt32), dst: new(*int), pgtype: "bigint"},
		{src: int(math.MinInt32), dst: new(int), pgtype: "bigint"},

		{src: nil, dst: uint8(0), pgtype: "smallint", wanterr: "pg: Scan(nonsettable uint8)"},
		{src: nil, dst: new(uint8), pgtype: "smallint", wantzero: true},
		{src: uint8(math.MaxUint8), dst: new(uint8), pgtype: "smallint"},
		{src: uint8(math.MaxUint8), dst: new(*uint8), pgtype: "smallint"},

		{src: nil, dst: uint16(0), pgtype: "smallint", wanterr: "pg: Scan(nonsettable uint16)"},
		{src: nil, dst: new(uint16), pgtype: "smallint", wantzero: true},
		{src: uint16(math.MaxUint16), dst: new(uint16), pgtype: "int"},
		{src: uint16(math.MaxUint16), dst: new(*uint16), pgtype: "int"},

		{src: nil, dst: uint32(0), pgtype: "bigint", wanterr: "pg: Scan(nonsettable uint32)"},
		{src: nil, dst: new(uint32), pgtype: "bigint", wantzero: true},
		{src: uint32(math.MaxUint32), dst: new(uint32), pgtype: "bigint"},
		{src: uint32(math.MaxUint32), dst: new(*uint32), pgtype: "bigint"},

		{src: nil, dst: uint64(0), pgtype: "bigint", wanterr: "pg: Scan(nonsettable uint64)"},
		{src: nil, dst: new(uint64), pgtype: "bigint", wantzero: true},
		{src: uint64(math.MaxUint64), dst: new(uint64)},
		{src: uint64(math.MaxUint64), dst: new(*uint64)},
		{src: uint64(math.MaxUint32), dst: new(uint64), pgtype: "bigint"},

		// Platform-independent uint bounds (uint may be 32 bits wide).
		{src: nil, dst: uint(0), pgtype: "smallint", wanterr: "pg: Scan(nonsettable uint)"},
		{src: nil, dst: new(uint), pgtype: "bigint", wantzero: true},
		{src: uint(math.MaxUint32), dst: new(uint), pgtype: "bigint"},
		{src: uint(math.MaxUint32), dst: new(*uint), pgtype: "bigint"},
		{src: uint(math.MaxUint16), dst: new(uint), pgtype: "int"},

		{src: nil, dst: float32(0), pgtype: "decimal", wanterr: "pg: Scan(nonsettable float32)"},
		{src: nil, dst: new(float32), pgtype: "decimal", wantzero: true},
		{src: float32(math.MaxFloat32), dst: new(float32), pgtype: "decimal"},
		{src: float32(math.MaxFloat32), dst: new(*float32), pgtype: "decimal"},
		{src: float32(math.SmallestNonzeroFloat32), dst: new(float32), pgtype: "decimal"},

		{src: nil, dst: float64(0), pgtype: "decimal", wanterr: "pg: Scan(nonsettable float64)"},
		{src: nil, dst: new(float64), pgtype: "decimal", wantzero: true},
		{src: float64(math.MaxFloat64), dst: new(float64), pgtype: "decimal"},
		{src: float64(math.MaxFloat64), dst: new(*float64), pgtype: "decimal"},
		{src: float64(math.SmallestNonzeroFloat64), dst: new(float64), pgtype: "decimal"},

		{src: nil, dst: []int(nil), pgtype: "jsonb", wanterr: "pg: Scan(nonsettable []int)"},
		{src: nil, dst: new([]int), pgtype: "jsonb", wantnil: true},
		{src: []int(nil), dst: new([]int), pgtype: "jsonb", wantnil: true},
		{src: []int{}, dst: new([]int), pgtype: "jsonb", wantzero: true},
		{src: []int{1, 2, 3}, dst: new([]int), pgtype: "jsonb"},
		{src: [3]int{1, 2, 3}, dst: new([3]int), pgtype: "jsonb"},

		{src: nil, dst: pg.Array([]int(nil)), pgtype: "int[]", wanterr: "pg: Scan(nonsettable []int)"},
		{src: pg.Array([]int(nil)), dst: pg.Array(new([]int)), pgtype: "int[]", wantnil: true},
		{src: pg.Array([]int{}), dst: pg.Array(new([]int)), pgtype: "int[]"},
		{src: pg.Array([]int{1, 2, 3}), dst: pg.Array(new([]int)), pgtype: "int[]"},
		{src: pg.Array(&[3]int{1, 2, 3}), dst: pg.Array(new([3]int)), pgtype: "int[]"},

		{src: nil, dst: pg.Array([]int64(nil)), pgtype: "bigint[]", wanterr: "pg: Scan(nonsettable []int64)"},
		{src: nil, dst: pg.Array(new([]int64)), pgtype: "bigint[]", wantnil: true},
		{src: pg.Array([]int64(nil)), dst: pg.Array(new([]int64)), pgtype: "bigint[]", wantnil: true},
		{src: pg.Array([]int64{}), dst: pg.Array(new([]int64)), pgtype: "bigint[]"},
		{src: pg.Array([]int64{1, 2, 3}), dst: pg.Array(new([]int64)), pgtype: "bigint[]"},

		{src: nil, dst: pg.Array([]float64(nil)), pgtype: "decimal[]", wanterr: "pg: Scan(nonsettable []float64)"},
		{src: nil, dst: pg.Array(new([]float64)), pgtype: "decimal[]", wantnil: true},
		{src: pg.Array([]float64(nil)), dst: pg.Array(new([]float64)), pgtype: "decimal[]", wantnil: true},
		{src: pg.Array([]float64{}), dst: pg.Array(new([]float64)), pgtype: "decimal[]"},
		{src: pg.Array([]float64{1.1, 2.22, 3.333}), dst: pg.Array(new([]float64)), pgtype: "decimal[]"},
		{src: pg.Array([]float64{math.NaN(), math.Inf(+1), math.Inf(-1)}), dst: pg.Array(new([]float64)), pgtype: "float[]", wantnothing: true},

		{src: nil, dst: pg.Array([]string(nil)), pgtype: "text[]", wanterr: "pg: Scan(nonsettable []string)"},
		{src: nil, dst: pg.Array(new([]string)), pgtype: "text[]", wantnil: true},
		{src: pg.Array([]string(nil)), dst: pg.Array(new([]string)), pgtype: "text[]", wantnil: true},
		{src: pg.Array([]string{}), dst: pg.Array(new([]string)), pgtype: "text[]"},
		{src: pg.Array([]string{"one", "two", "three"}), dst: pg.Array(new([]string)), pgtype: "text[]"},
		{src: pg.Array([]string{`'"{}`}), dst: pg.Array(new([]string)), pgtype: "text[]"},

		{src: nil, dst: pg.Array([][]string(nil)), pgtype: "text[][]", wanterr: "pg: Scan(nonsettable [][]string)"},
		{src: nil, dst: pg.Array(new([][]string)), pgtype: "text[][]", wantnil: true},
		{src: pg.Array([][]string(nil)), dst: pg.Array(new([][]string)), pgtype: "text[][]", wantnil: true},
		{src: pg.Array([][]string{}), dst: pg.Array(new([][]string)), pgtype: "text[][]"},
		{src: pg.Array([][]string{{"one", "two"}, {"three", "four"}}), dst: pg.Array(new([][]string)), pgtype: "text[][]"},
		{src: pg.Array([][]string{{`'"\{}`}}), dst: pg.Array(new([][]string)), pgtype: "text[][]"},

		{src: pg.Array([][]byte{[]byte(`'"\{}`)}), dst: pg.Array(new([][]byte)), pgtype: "bytea[]"},

		{src: nil, dst: pg.Hstore(map[string]string(nil)), pgtype: "hstore", wanterr: "pg: Scan(nonsettable map[string]string)"},
		{src: nil, dst: pg.Hstore(new(map[string]string)), pgtype: "hstore", wantnil: true},
		{src: pg.Hstore(map[string]string(nil)), dst: pg.Hstore(new(map[string]string)), pgtype: "hstore", wantnil: true},
		{src: pg.Hstore(map[string]string{}), dst: pg.Hstore(new(map[string]string)), pgtype: "hstore"},
		{src: pg.Hstore(map[string]string{"foo": "bar"}), dst: pg.Hstore(new(map[string]string)), pgtype: "hstore"},
		{src: pg.Hstore(map[string]string{`'"\{}=>`: `'"\{}=>`}), dst: pg.Hstore(new(map[string]string)), pgtype: "hstore"},

		{src: nil, dst: sql.NullBool{}, pgtype: "bool", wanterr: "pg: Scan(nonsettable sql.NullBool)"},
		{src: nil, dst: new(*sql.NullBool), pgtype: "bool", wantnil: true},
		{src: nil, dst: new(sql.NullBool), pgtype: "bool", wanted: sql.NullBool{}},
		{src: &sql.NullBool{}, dst: new(sql.NullBool), pgtype: "bool"},
		{src: &sql.NullBool{Valid: true}, dst: new(sql.NullBool), pgtype: "bool"},
		{src: &sql.NullBool{Valid: true, Bool: true}, dst: new(sql.NullBool), pgtype: "bool"},

		{src: &sql.NullString{}, dst: new(sql.NullString), pgtype: "text"},
		{src: &sql.NullString{Valid: true}, dst: new(sql.NullString), pgtype: "text"},
		{src: &sql.NullString{Valid: true, String: "foo"}, dst: new(sql.NullString), pgtype: "text"},

		{src: &sql.NullInt64{}, dst: new(sql.NullInt64), pgtype: "bigint"},
		{src: &sql.NullInt64{Valid: true}, dst: new(sql.NullInt64), pgtype: "bigint"},
		{src: &sql.NullInt64{Valid: true, Int64: math.MaxInt64}, dst: new(sql.NullInt64), pgtype: "bigint"},

		{src: &sql.NullFloat64{}, dst: new(sql.NullFloat64), pgtype: "decimal"},
		{src: &sql.NullFloat64{Valid: true}, dst: new(sql.NullFloat64), pgtype: "decimal"},
		{src: &sql.NullFloat64{Valid: true, Float64: math.MaxFloat64}, dst: new(sql.NullFloat64), pgtype: "decimal"},

		{src: nil, dst: customStrSlice{}, wanterr: "pg: Scan(nonsettable pg_test.customStrSlice)"},
		{src: nil, dst: new(customStrSlice), wantnil: true},
		{src: nil, dst: new(*customStrSlice), wantnil: true},
		{src: customStrSlice{}, dst: new(customStrSlice), wantzero: true},
		{src: customStrSlice{"one", "two"}, dst: new(customStrSlice)},

		{src: nil, dst: time.Time{}, wanterr: "pg: Scan(nonsettable time.Time)"},
		{src: nil, dst: new(time.Time), pgtype: "timestamptz", wantzero: true},
		{src: nil, dst: new(*time.Time), pgtype: "timestamptz", wantnil: true},
		{src: time.Now(), dst: new(time.Time), pgtype: "timestamptz"},
		{src: time.Now(), dst: new(*time.Time), pgtype: "timestamptz"},
		{src: time.Now().UTC(), dst: new(time.Time), pgtype: "timestamptz"},
		{src: time.Time{}, dst: new(time.Time), pgtype: "timestamptz"},

		{src: nil, dst: pg.Array([]time.Time(nil)), pgtype: "timestamptz[]", wanterr: "pg: Scan(nonsettable []time.Time)"},
		{src: nil, dst: pg.Array(new([]time.Time)), pgtype: "timestamptz[]", wantnil: true},
		{src: pg.Array([]time.Time(nil)), dst: pg.Array(new([]time.Time)), pgtype: "timestamptz[]", wantnil: true},
		{src: pg.Array([]time.Time{}), dst: pg.Array(new([]time.Time)), pgtype: "timestamptz[]"},
		{src: pg.Array([]time.Time{time.Now(), time.Now(), time.Now()}), dst: pg.Array(new([]time.Time)), pgtype: "timestamptz[]"},

		{src: nil, dst: pg.Ints{}, wanterr: "pg: Scan(nonsettable pg.Ints)"},
		{src: 1, dst: new(pg.Ints), wanted: pg.Ints{1}},

		{src: nil, dst: pg.Strings{}, wanterr: "pg: Scan(nonsettable pg.Strings)"},
		{src: "hello", dst: new(pg.Strings), wanted: pg.Strings{"hello"}},

		{src: nil, dst: pg.IntSet{}, wanterr: "pg: Scan(nonsettable pg.IntSet)"},
		{src: 1, dst: new(pg.IntSet), wanted: pg.IntSet{1: struct{}{}}},

		{src: nil, dst: JSONMap{}, pgtype: "json", wanterr: "pg: Scan(nonsettable pg_test.JSONMap)"},
		{src: nil, dst: new(JSONMap), pgtype: "json", wantnil: true},
		{src: nil, dst: new(*JSONMap), pgtype: "json", wantnil: true},
		{src: JSONMap{}, dst: new(JSONMap), pgtype: "json"},
		{src: JSONMap{}, dst: new(*JSONMap), pgtype: "json"},
		{src: JSONMap{"foo": "bar"}, dst: new(JSONMap), pgtype: "json"},
		{src: `{"foo": "bar"}`, dst: new(JSONMap), pgtype: "json", wanted: JSONMap{"foo": "bar"}},

		{src: nil, dst: Struct{}, pgtype: "json", wanterr: "pg: Scan(nonsettable pg_test.Struct)"},
		{src: nil, dst: new(*Struct), pgtype: "json", wantnil: true},
		{src: nil, dst: new(Struct), pgtype: "json", wantzero: true},
		{src: Struct{}, dst: new(Struct), pgtype: "json"},
		{src: Struct{Foo: "bar"}, dst: new(Struct), pgtype: "json"},
		{src: `{"foo": "bar"}`, dst: new(Struct), wanted: Struct{Foo: "bar"}},

		{src: nil, dst: new(net.IP), wanted: net.IP(nil), pgtype: "inet"},
		{src: net.ParseIP("127.0.0.1"), dst: new(net.IP), pgtype: "inet"},
		{src: net.ParseIP("::10.2.3.4"), dst: new(net.IP), pgtype: "inet"},
		{src: net.ParseIP("::ffff:10.4.3.2"), dst: new(net.IP), pgtype: "inet"},

		{src: nil, dst: (*net.IPNet)(nil), pgtype: "cidr", wanterr: "pg: Scan(nonsettable *net.IPNet)"},
		{src: nil, dst: new(net.IPNet), wanted: net.IPNet{}, pgtype: "cidr"},
		{src: nil, dst: mustParseCIDR("192.168.100.128/25"), wanted: net.IPNet{}, pgtype: "cidr"},
		{src: mustParseCIDR("192.168.100.128/25"), dst: new(net.IPNet), pgtype: "cidr"},
		{src: mustParseCIDR("2001:4f8:3:ba::/64"), dst: new(net.IPNet), pgtype: "cidr"},
		{src: mustParseCIDR("2001:4f8:3:ba:2e0:81ff:fe22:d1f1/128"), dst: new(net.IPNet), pgtype: "cidr"},
	}
}
// TestConversion runs every case from conversionTests through three query
// paths:
//   - QueryOne with the value as a "?" parameter;
//   - an ORM ColumnExpr select;
//   - a prepared statement with an explicit ($1::pgtype) cast. Cases with an
//     empty pgtype are skipped on this path.
//
// Each case is its own subtest, so a failing case does not abort the rest.
func TestConversion(t *testing.T) {
	db := pg.Connect(pgOptions())
	defer db.Close()

	// scannerFor uses dst directly when it already implements
	// orm.ColumnScanner and otherwise wraps it with pg.Scan.
	scannerFor := func(dst interface{}) orm.ColumnScanner {
		if s, ok := dst.(orm.ColumnScanner); ok {
			return s
		}
		return pg.Scan(dst)
	}

	t.Run("QueryOne", func(t *testing.T) {
		for i, test := range conversionTests() {
			test.i = i
			t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
				_, err := db.QueryOne(scannerFor(test.dst), "SELECT (?) AS dst", test.src)
				test.Assert(t, err)
			})
		}
	})

	t.Run("ColumnExpr", func(t *testing.T) {
		for i, test := range conversionTests() {
			test.i = i
			t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
				err := db.Model().ColumnExpr("(?) AS dst", test.src).Select(scannerFor(test.dst))
				test.Assert(t, err)
			})
		}
	})

	t.Run("Prepare", func(t *testing.T) {
		for i, test := range conversionTests() {
			test.i = i
			if test.pgtype == "" {
				continue
			}
			t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
				stmt, err := db.Prepare(fmt.Sprintf("SELECT ($1::%s) AS dst", test.pgtype))
				if err != nil {
					t.Fatalf("prepare: %v (%s)", err, test.String())
				}
				// Deferred so the statement is released even when Assert
				// aborts the subtest via t.Fatalf.
				defer func() {
					if err := stmt.Close(); err != nil {
						t.Errorf("close: %v (%s)", err, test.String())
					}
				}()
				_, err = stmt.QueryOne(scannerFor(test.dst), test.src)
				test.Assert(t, err)
			})
		}
	})
}
// mustParseCIDR parses s in CIDR notation and returns the network it
// denotes, panicking if s is malformed. It is intended for table literals.
func mustParseCIDR(s string) *net.IPNet {
	_, network, err := net.ParseCIDR(s)
	if err == nil {
		return network
	}
	panic(err)
}