package pg_test import ( "database/sql" "database/sql/driver" "encoding/json" "fmt" "math" "net" "reflect" "testing" "time" "github.com/go-pg/pg" "github.com/go-pg/pg/orm" "github.com/go-pg/pg/types" ) type JSONMap map[string]interface{} func (m *JSONMap) Scan(b interface{}) error { if b == nil { *m = nil return nil } return json.Unmarshal(b.([]byte), m) } func (m JSONMap) Value() (driver.Value, error) { b, err := json.Marshal(m) if err != nil { return nil, err } return string(b), nil } type Struct struct { Foo string } type conversionTest struct { i int src, dst, wanted interface{} pgtype string wanterr string wantnil bool wantzero bool wantnothing bool } func unwrap(v interface{}) interface{} { if arr, ok := v.(*types.Array); ok { return arr.Value() } if hstore, ok := v.(*types.Hstore); ok { return hstore.Value() } return v } func deref(vi interface{}) interface{} { v := reflect.ValueOf(vi) for v.Kind() == reflect.Ptr { v = v.Elem() } if v.IsValid() { return v.Interface() } return nil } func zero(v interface{}) interface{} { return reflect.Zero(reflect.ValueOf(v).Elem().Type()).Interface() } func (test *conversionTest) String() string { return fmt.Sprintf("#%d src=%#v dst=%#v", test.i, test.src, test.dst) } func (test *conversionTest) Assert(t *testing.T, err error) { if test.wanterr != "" { if err == nil || err.Error() != test.wanterr { t.Fatalf("got error %v, wanted %q (%s)", err, test.wanterr, test) } return } if err != nil { t.Fatalf("got error %q, wanted nil (%s)", err, test) } if test.wantnothing { return } dst := reflect.Indirect(reflect.ValueOf(unwrap(test.dst))).Interface() if test.wantnil { dstValue := reflect.ValueOf(dst) if !dstValue.IsValid() { return } if dstValue.IsNil() { return } t.Fatalf("got %#v, wanted nil (%s)", dst, test) return } // Remove any intermediate pointers to compare values. dst = deref(unwrap(dst)) src := deref(unwrap(test.src)) if test.wantzero { dstValue := reflect.ValueOf(dst) switch dstValue.Kind() { case reflect.Slice, reflect.Map: if dstValue.IsNil() { t.Fatalf("got nil, wanted zero value") } if dstValue.Len() != 0 { t.Fatalf("got %d items, wanted 0", dstValue.Len()) } default: zero := zero(test.dst) if dst != zero { t.Fatalf("%#v != %#v (%s)", dst, zero, test) } } return } if dstTime, ok := dst.(time.Time); ok { srcTime := src.(time.Time) if dstTime.Unix() != srcTime.Unix() { t.Fatalf("%#v != %#v", dstTime, srcTime) } return } if dstTimes, ok := dst.([]time.Time); ok { srcTimes := src.([]time.Time) for i, dstTime := range dstTimes { srcTime := srcTimes[i] if dstTime.Unix() != srcTime.Unix() { t.Fatalf("%#v != %#v", dstTime, srcTime) } } return } wanted := test.wanted if wanted == nil { wanted = src } if !reflect.DeepEqual(dst, wanted) { t.Fatalf("%#v != %#v (%s)", dst, wanted, test) } } func conversionTests() []conversionTest { return []conversionTest{ {src: nil, dst: nil, wanterr: "pg: Scan(nil)"}, {src: nil, dst: new(uintptr), wanterr: "pg: Scan(unsupported uintptr)"}, {src: nil, dst: true, pgtype: "bool", wanterr: "pg: Scan(nonsettable bool)"}, {src: nil, dst: new(*bool), pgtype: "bool", wantnil: true}, {src: nil, dst: new(bool), pgtype: "bool", wantzero: true}, {src: true, dst: new(bool), pgtype: "bool"}, {src: true, dst: new(*bool), pgtype: "bool"}, {src: 1, dst: new(bool), wanted: true}, {src: nil, dst: "", pgtype: "text", wanterr: "pg: Scan(nonsettable string)"}, {src: nil, dst: new(string), pgtype: "text", wantzero: true}, {src: nil, dst: new(*string), pgtype: "text", wantnil: true}, {src: "hello world", dst: new(string), pgtype: "text"}, {src: "hello world", dst: new(*string), pgtype: "text"}, {src: "'\"\000", dst: new(string), wanted: `'"`, pgtype: "text"}, {src: nil, dst: []byte(nil), pgtype: "bytea", wanterr: "pg: Scan(nonsettable []uint8)"}, {src: nil, dst: new([]byte), pgtype: "bytea", wantnil: true}, {src: []byte("hello world\000"), dst: new([]byte), pgtype: "bytea"}, {src: []byte{}, dst: new([]byte), pgtype: "bytea", wantzero: true}, {src: nil, dst: int8(0), pgtype: "smallint", wanterr: "pg: Scan(nonsettable int8)"}, {src: nil, dst: new(int8), pgtype: "smallint", wantzero: true}, {src: int8(math.MaxInt8), dst: new(int8), pgtype: "smallint"}, {src: int8(math.MaxInt8), dst: new(*int8), pgtype: "smallint"}, {src: int8(math.MinInt8), dst: new(int8), pgtype: "smallint"}, {src: nil, dst: int16(0), pgtype: "smallint", wanterr: "pg: Scan(nonsettable int16)"}, {src: nil, dst: new(int16), pgtype: "smallint", wantzero: true}, {src: int16(math.MaxInt16), dst: new(int16), pgtype: "smallint"}, {src: int16(math.MaxInt16), dst: new(*int16), pgtype: "smallint"}, {src: int16(math.MinInt16), dst: new(int16), pgtype: "smallint"}, {src: nil, dst: int32(0), pgtype: "int", wanterr: "pg: Scan(nonsettable int32)"}, {src: nil, dst: new(int32), pgtype: "int", wantzero: true}, {src: int32(math.MaxInt32), dst: new(int32), pgtype: "int"}, {src: int32(math.MaxInt32), dst: new(*int32), pgtype: "int"}, {src: int32(math.MinInt32), dst: new(int32), pgtype: "int"}, {src: nil, dst: int64(0), pgtype: "bigint", wanterr: "pg: Scan(nonsettable int64)"}, {src: nil, dst: new(int64), pgtype: "bigint", wantzero: true}, {src: int64(math.MaxInt64), dst: new(int64), pgtype: "bigint"}, {src: int64(math.MaxInt64), dst: new(*int64), pgtype: "bigint"}, {src: int64(math.MinInt64), dst: new(int64), pgtype: "bigint"}, {src: nil, dst: int(0), pgtype: "bigint", wanterr: "pg: Scan(nonsettable int)"}, {src: nil, dst: new(int), pgtype: "bigint", wantzero: true}, {src: int64(math.MaxInt64), dst: new(int64), pgtype: "bigint"}, {src: int64(math.MaxInt64), dst: new(*int64), pgtype: "bigint"}, {src: int64(math.MinInt32), dst: new(int64), pgtype: "bigint"}, {src: nil, dst: uint8(0), pgtype: "smallint", wanterr: "pg: Scan(nonsettable uint8)"}, {src: nil, dst: new(uint8), pgtype: "smallint", wantzero: true}, {src: uint8(math.MaxUint8), dst: new(uint8), pgtype: "smallint"}, {src: uint8(math.MaxUint8), dst: new(*uint8), pgtype: "smallint"}, {src: nil, dst: uint16(0), pgtype: "smallint", wanterr: "pg: Scan(nonsettable uint16)"}, {src: nil, dst: new(uint16), pgtype: "smallint", wantzero: true}, {src: uint16(math.MaxUint16), dst: new(uint16), pgtype: "int"}, {src: uint16(math.MaxUint16), dst: new(*uint16), pgtype: "int"}, {src: nil, dst: uint32(0), pgtype: "bigint", wanterr: "pg: Scan(nonsettable uint32)"}, {src: nil, dst: new(uint32), pgtype: "bigint", wantzero: true}, {src: uint32(math.MaxUint32), dst: new(uint32), pgtype: "bigint"}, {src: uint32(math.MaxUint32), dst: new(*uint32), pgtype: "bigint"}, {src: nil, dst: uint64(0), pgtype: "bigint", wanterr: "pg: Scan(nonsettable uint64)"}, {src: nil, dst: new(uint64), pgtype: "bigint", wantzero: true}, {src: uint64(math.MaxUint64), dst: new(uint64)}, {src: uint64(math.MaxUint64), dst: new(*uint64)}, {src: uint64(math.MaxUint32), dst: new(uint64), pgtype: "bigint"}, {src: nil, dst: uint(0), pgtype: "smallint", wanterr: "pg: Scan(nonsettable uint)"}, {src: nil, dst: new(uint), pgtype: "bigint", wantzero: true}, {src: uint64(math.MaxUint64), dst: new(uint64)}, {src: uint64(math.MaxUint64), dst: new(*uint64)}, {src: uint64(math.MaxUint32), dst: new(uint64), pgtype: "bigint"}, {src: nil, dst: float32(0), pgtype: "decimal", wanterr: "pg: Scan(nonsettable float32)"}, {src: nil, dst: new(float32), pgtype: "decimal", wantzero: true}, {src: float32(math.MaxFloat32), dst: new(float32), pgtype: "decimal"}, {src: float32(math.MaxFloat32), dst: new(*float32), pgtype: "decimal"}, {src: float32(math.SmallestNonzeroFloat32), dst: new(float32), pgtype: "decimal"}, {src: nil, dst: float64(0), pgtype: "decimal", wanterr: "pg: Scan(nonsettable float64)"}, {src: nil, dst: new(float64), pgtype: "decimal", wantzero: true}, {src: float64(math.MaxFloat64), dst: new(float64), pgtype: "decimal"}, {src: float64(math.MaxFloat64), dst: new(*float64), pgtype: "decimal"}, {src: float64(math.SmallestNonzeroFloat64), dst: new(float64), pgtype: "decimal"}, {src: nil, dst: []int(nil), pgtype: "jsonb", wanterr: "pg: Scan(nonsettable []int)"}, {src: nil, dst: new([]int), pgtype: "jsonb", wantnil: true}, {src: []int(nil), dst: new([]int), pgtype: "jsonb", wantnil: true}, {src: []int{}, dst: new([]int), pgtype: "jsonb", wantzero: true}, {src: []int{1, 2, 3}, dst: new([]int), pgtype: "jsonb"}, {src: [3]int{1, 2, 3}, dst: new([3]int), pgtype: "jsonb"}, {src: nil, dst: pg.Array([]int(nil)), pgtype: "int[]", wanterr: "pg: Scan(nonsettable []int)"}, {src: pg.Array([]int(nil)), dst: pg.Array(new([]int)), pgtype: "int[]", wantnil: true}, {src: pg.Array([]int{}), dst: pg.Array(new([]int)), pgtype: "int[]"}, {src: pg.Array([]int{1, 2, 3}), dst: pg.Array(new([]int)), pgtype: "int[]"}, {src: pg.Array(&[3]int{1, 2, 3}), dst: pg.Array(new([3]int)), pgtype: "int[]"}, {src: nil, dst: pg.Array([]int64(nil)), pgtype: "bigint[]", wanterr: "pg: Scan(nonsettable []int64)"}, {src: nil, dst: pg.Array(new([]int64)), pgtype: "bigint[]", wantnil: true}, {src: pg.Array([]int64(nil)), dst: pg.Array(new([]int64)), pgtype: "bigint[]", wantnil: true}, {src: pg.Array([]int64{}), dst: pg.Array(new([]int64)), pgtype: "bigint[]"}, {src: pg.Array([]int64{1, 2, 3}), dst: pg.Array(new([]int64)), pgtype: "bigint[]"}, {src: nil, dst: pg.Array([]float64(nil)), pgtype: "decimal[]", wanterr: "pg: Scan(nonsettable []float64)"}, {src: nil, dst: pg.Array(new([]float64)), pgtype: "decimal[]", wantnil: true}, {src: pg.Array([]float64(nil)), dst: pg.Array(new([]float64)), pgtype: "decimal[]", wantnil: true}, {src: pg.Array([]float64{}), dst: pg.Array(new([]float64)), pgtype: "decimal[]"}, {src: pg.Array([]float64{1.1, 2.22, 3.333}), dst: pg.Array(new([]float64)), pgtype: "decimal[]"}, {src: pg.Array([]float64{math.NaN(), math.Inf(+1), math.Inf(-1)}), dst: pg.Array(new([]float64)), pgtype: "float[]", wantnothing: true}, {src: nil, dst: pg.Array([]string(nil)), pgtype: "text[]", wanterr: "pg: Scan(nonsettable []string)"}, {src: nil, dst: pg.Array(new([]string)), pgtype: "text[]", wantnil: true}, {src: pg.Array([]string(nil)), dst: pg.Array(new([]string)), pgtype: "text[]", wantnil: true}, {src: pg.Array([]string{}), dst: pg.Array(new([]string)), pgtype: "text[]"}, {src: pg.Array([]string{"one", "two", "three"}), dst: pg.Array(new([]string)), pgtype: "text[]"}, {src: pg.Array([]string{`'"{}`}), dst: pg.Array(new([]string)), pgtype: "text[]"}, {src: nil, dst: pg.Array([][]string(nil)), pgtype: "text[][]", wanterr: "pg: Scan(nonsettable [][]string)"}, {src: nil, dst: pg.Array(new([][]string)), pgtype: "text[][]", wantnil: true}, {src: pg.Array([][]string(nil)), dst: pg.Array(new([]string)), pgtype: "text[][]", wantnil: true}, {src: pg.Array([][]string{}), dst: pg.Array(new([][]string)), pgtype: "text[][]"}, {src: pg.Array([][]string{{"one", "two"}, {"three", "four"}}), dst: pg.Array(new([][]string)), pgtype: "text[][]"}, {src: pg.Array([][]string{{`'"\{}`}}), dst: pg.Array(new([][]string)), pgtype: "text[][]"}, {src: pg.Array([][]byte{[]byte(`'"\{}`)}), dst: pg.Array(new([][]byte)), pgtype: "bytea[]"}, {src: nil, dst: pg.Hstore(map[string]string(nil)), pgtype: "hstore", wanterr: "pg: Scan(nonsettable map[string]string)"}, {src: nil, dst: pg.Hstore(new(map[string]string)), pgtype: "hstore", wantnil: true}, {src: pg.Hstore(map[string]string(nil)), dst: pg.Hstore(new(map[string]string)), pgtype: "hstore", wantnil: true}, {src: pg.Hstore(map[string]string{}), dst: pg.Hstore(new(map[string]string)), pgtype: "hstore"}, {src: pg.Hstore(map[string]string{"foo": "bar"}), dst: pg.Hstore(new(map[string]string)), pgtype: "hstore"}, {src: pg.Hstore(map[string]string{`'"\{}=>`: `'"\{}=>`}), dst: pg.Hstore(new(map[string]string)), pgtype: "hstore"}, {src: nil, dst: sql.NullBool{}, pgtype: "bool", wanterr: "pg: Scan(nonsettable sql.NullBool)"}, {src: nil, dst: new(*sql.NullBool), pgtype: "bool", wantnil: true}, {src: nil, dst: new(sql.NullBool), pgtype: "bool", wanted: sql.NullBool{}}, {src: &sql.NullBool{}, dst: new(sql.NullBool), pgtype: "bool"}, {src: &sql.NullBool{Valid: true}, dst: new(sql.NullBool), pgtype: "bool"}, {src: &sql.NullBool{Valid: true, Bool: true}, dst: new(sql.NullBool), pgtype: "bool"}, {src: &sql.NullString{}, dst: new(sql.NullString), pgtype: "text"}, {src: &sql.NullString{Valid: true}, dst: new(sql.NullString), pgtype: "text"}, {src: &sql.NullString{Valid: true, String: "foo"}, dst: new(sql.NullString), pgtype: "text"}, {src: &sql.NullInt64{}, dst: new(sql.NullInt64), pgtype: "bigint"}, {src: &sql.NullInt64{Valid: true}, dst: new(sql.NullInt64), pgtype: "bigint"}, {src: &sql.NullInt64{Valid: true, Int64: math.MaxInt64}, dst: new(sql.NullInt64), pgtype: "bigint"}, {src: &sql.NullFloat64{}, dst: new(sql.NullFloat64), pgtype: "decimal"}, {src: &sql.NullFloat64{Valid: true}, dst: new(sql.NullFloat64), pgtype: "decimal"}, {src: &sql.NullFloat64{Valid: true, Float64: math.MaxFloat64}, dst: new(sql.NullFloat64), pgtype: "decimal"}, {src: nil, dst: customStrSlice{}, wanterr: "pg: Scan(nonsettable pg_test.customStrSlice)"}, {src: nil, dst: new(customStrSlice), wantnil: true}, {src: nil, dst: new(*customStrSlice), wantnil: true}, {src: customStrSlice{}, dst: new(customStrSlice), wantzero: true}, {src: customStrSlice{"one", "two"}, dst: new(customStrSlice)}, {src: nil, dst: time.Time{}, wanterr: "pg: Scan(nonsettable time.Time)"}, {src: nil, dst: new(time.Time), pgtype: "timestamptz", wantzero: true}, {src: nil, dst: new(*time.Time), pgtype: "timestamptz", wantnil: true}, {src: time.Now(), dst: new(time.Time), pgtype: "timestamptz"}, {src: time.Now(), dst: new(*time.Time), pgtype: "timestamptz"}, {src: time.Now().UTC(), dst: new(time.Time), pgtype: "timestamptz"}, {src: time.Time{}, dst: new(time.Time), pgtype: "timestamptz"}, {src: nil, dst: pg.Array([]time.Time(nil)), pgtype: "timestamptz[]", wanterr: "pg: Scan(nonsettable []time.Time)"}, {src: nil, dst: pg.Array(new([]time.Time)), pgtype: "timestamptz[]", wantnil: true}, {src: pg.Array([]time.Time(nil)), dst: pg.Array(new([]time.Time)), pgtype: "timestamptz[]", wantnil: true}, {src: pg.Array([]time.Time{}), dst: pg.Array(new([]time.Time)), pgtype: "timestamptz[]"}, {src: pg.Array([]time.Time{time.Now(), time.Now(), time.Now()}), dst: pg.Array(new([]time.Time)), pgtype: "timestamptz[]"}, {src: nil, dst: pg.Ints{}, wanterr: "pg: Scan(nonsettable pg.Ints)"}, {src: 1, dst: new(pg.Ints), wanted: pg.Ints{1}}, {src: nil, dst: pg.Strings{}, wanterr: "pg: Scan(nonsettable pg.Strings)"}, {src: "hello", dst: new(pg.Strings), wanted: pg.Strings{"hello"}}, {src: nil, dst: pg.IntSet{}, wanterr: "pg: Scan(nonsettable pg.IntSet)"}, {src: 1, dst: new(pg.IntSet), wanted: pg.IntSet{1: struct{}{}}}, {src: nil, dst: JSONMap{}, pgtype: "json", wanterr: "pg: Scan(nonsettable pg_test.JSONMap)"}, {src: nil, dst: new(JSONMap), pgtype: "json", wantnil: true}, {src: nil, dst: new(*JSONMap), pgtype: "json", wantnil: true}, {src: JSONMap{}, dst: new(JSONMap), pgtype: "json"}, {src: JSONMap{}, dst: new(*JSONMap), pgtype: "json"}, {src: JSONMap{"foo": "bar"}, dst: new(JSONMap), pgtype: "json"}, {src: `{"foo": "bar"}`, dst: new(JSONMap), pgtype: "json", wanted: JSONMap{"foo": "bar"}}, {src: nil, dst: Struct{}, pgtype: "json", wanterr: "pg: Scan(nonsettable pg_test.Struct)"}, {src: nil, dst: new(*Struct), pgtype: "json", wantnil: true}, {src: nil, dst: new(Struct), pgtype: "json", wantzero: true}, {src: Struct{}, dst: new(Struct), pgtype: "json"}, {src: Struct{Foo: "bar"}, dst: new(Struct), pgtype: "json"}, {src: `{"foo": "bar"}`, dst: new(Struct), wanted: Struct{Foo: "bar"}}, {src: nil, dst: new(net.IP), wanted: net.IP(nil), pgtype: "inet"}, {src: net.ParseIP("127.0.0.1"), dst: new(net.IP), pgtype: "inet"}, {src: net.ParseIP("::10.2.3.4"), dst: new(net.IP), pgtype: "inet"}, {src: net.ParseIP("::ffff:10.4.3.2"), dst: new(net.IP), pgtype: "inet"}, {src: nil, dst: (*net.IPNet)(nil), pgtype: "cidr", wanterr: "pg: Scan(nonsettable *net.IPNet)"}, {src: nil, dst: new(net.IPNet), wanted: net.IPNet{}, pgtype: "cidr"}, {src: nil, dst: mustParseCIDR("192.168.100.128/25"), wanted: net.IPNet{}, pgtype: "cidr"}, {src: mustParseCIDR("192.168.100.128/25"), dst: new(net.IPNet), pgtype: "cidr"}, {src: mustParseCIDR("2001:4f8:3:ba::/64"), dst: new(net.IPNet), pgtype: "cidr"}, {src: mustParseCIDR("2001:4f8:3:ba:2e0:81ff:fe22:d1f1/128"), dst: new(net.IPNet), pgtype: "cidr"}, } } func TestConversion(t *testing.T) { db := pg.Connect(pgOptions()) defer db.Close() for i, test := range conversionTests() { test.i = i var scanner orm.ColumnScanner if v, ok := test.dst.(orm.ColumnScanner); ok { scanner = v } else { scanner = pg.Scan(test.dst) } _, err := db.QueryOne(scanner, "SELECT (?) AS dst", test.src) test.Assert(t, err) } for i, test := range conversionTests() { test.i = i var scanner orm.ColumnScanner if v, ok := test.dst.(orm.ColumnScanner); ok { scanner = v } else { scanner = pg.Scan(test.dst) } err := db.Model().ColumnExpr("(?) AS dst", test.src).Select(scanner) test.Assert(t, err) } for i, test := range conversionTests() { test.i = i if test.pgtype == "" { continue } stmt, err := db.Prepare(fmt.Sprintf("SELECT ($1::%s) AS dst", test.pgtype)) if err != nil { t.Fatal(err) } var scanner orm.ColumnScanner if v, ok := test.dst.(orm.ColumnScanner); ok { scanner = v } else { scanner = pg.Scan(test.dst) } _, err = stmt.QueryOne(scanner, test.src) test.Assert(t, err) if err := stmt.Close(); err != nil { t.Fatal(err) } } } func mustParseCIDR(s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { panic(err) } return ipnet }