package pg_test
import (
	"bytes"
	"database/sql/driver"
	"fmt"
	"net"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/go-pg/pg"
	. "github.com/onsi/ginkgo"
	. "gopkg.in/check.v1"
)
// TestUnixSocket checks that a query succeeds when the client connects
// through the local PostgreSQL Unix domain socket rather than over TCP.
func TestUnixSocket(t *testing.T) {
	options := pgOptions()
	options.Network, options.Addr = "unix", "/var/run/postgresql/.s.PGSQL.5432"
	// TLS is switched off for the socket connection.
	options.TLSConfig = nil

	db := pg.Connect(options)
	defer db.Close()

	if _, err := db.Exec("SELECT 'test_unix_socket'"); err != nil {
		t.Fatal(err)
	}
}
// TestGocheck hands control to gocheck so that every registered Suite in
// this package runs under the standard "go test" runner.
func TestGocheck(t *testing.T) {
	TestingT(t)
}
// Register DBTest with gocheck so TestGocheck picks up its Test* methods.
var _ = Suite(new(DBTest))

// DBTest is a gocheck suite of database tests that share a *pg.DB handle.
type DBTest struct {
	// db is assigned in SetUpTest and closed in TearDownTest.
	db *pg.DB
}
// SetUpTest connects a new database handle, using the shared test options,
// before each test in the suite.
func (t *DBTest) SetUpTest(c *C) {
	options := pgOptions()
	t.db = pg.Connect(options)
}
// TearDownTest closes the handle opened by SetUpTest and fails the test if
// closing reports an error.
func (t *DBTest) TearDownTest(c *C) {
	err := t.db.Close()
	c.Assert(err, IsNil)
}
// TestQueryOneErrMultiRows checks that QueryOne returns pg.ErrMultiRows for
// a query that yields more than one row.
func (t *DBTest) TestQueryOneErrMultiRows(c *C) {
	const twoRows = "SELECT generate_series(0, 1)"
	_, err := t.db.QueryOne(pg.Discard, twoRows)
	c.Assert(err, Equals, pg.ErrMultiRows)
}
// TestExecOne checks that ExecOne accepts a statement producing exactly one
// row and reports one affected row.
func (t *DBTest) TestExecOne(c *C) {
	res, err := t.db.ExecOne("SELECT 'test_exec_one'")
	c.Assert(err, IsNil)

	affected := res.RowsAffected()
	c.Assert(affected, Equals, 1)
}
// TestExecOneErrNoRows checks that ExecOne returns pg.ErrNoRows when the
// statement produces no rows at all.
func (t *DBTest) TestExecOneErrNoRows(c *C) {
	const noRows = "SELECT 1 WHERE 1 != 1"
	_, err := t.db.ExecOne(noRows)
	c.Assert(err, Equals, pg.ErrNoRows)
}
// TestExecOneErrMultiRows checks that ExecOne returns pg.ErrMultiRows when
// the statement produces more than one row.
func (t *DBTest) TestExecOneErrMultiRows(c *C) {
	const twoRows = "SELECT generate_series(0, 1)"
	_, err := t.db.ExecOne(twoRows)
	c.Assert(err, Equals, pg.ErrMultiRows)
}
// TestScan checks that pg.Scan writes a single-column result into a plain
// Go variable.
func (t *DBTest) TestScan(c *C) {
	var got int
	_, err := t.db.QueryOne(pg.Scan(&got), "SELECT 1")
	c.Assert(err, IsNil)
	c.Assert(got, Equals, 1)
}
// TestExec runs a DDL statement followed by an INSERT and checks the
// RowsAffected value reported for each one (-1 for CREATE, 1 for INSERT).
func (t *DBTest) TestExec(c *C) {
	steps := []struct {
		query    string
		affected int
	}{
		{"CREATE TEMP TABLE test(id serial PRIMARY KEY)", -1},
		{"INSERT INTO test VALUES (1)", 1},
	}
	for _, step := range steps {
		res, err := t.db.Exec(step.query)
		c.Assert(err, IsNil)
		c.Assert(res.RowsAffected(), Equals, step.affected)
	}
}
// TestStatementExec checks that a prepared INSERT statement can be executed
// with a bound parameter and reports one affected row.
func (t *DBTest) TestStatementExec(c *C) {
	res, err := t.db.Exec("CREATE TEMP TABLE test(id serial PRIMARY KEY)")
	c.Assert(err, IsNil)
	c.Assert(res.RowsAffected(), Equals, -1)

	insert, err := t.db.Prepare("INSERT INTO test VALUES($1)")
	c.Assert(err, IsNil)
	defer insert.Close()

	res, err = insert.Exec(1)
	c.Assert(err, IsNil)
	c.Assert(res.RowsAffected(), Equals, 1)
}
// TestLargeWriteRead sends a 1e6-byte parameter to the server and checks
// that the value read back is byte-for-byte identical.
func (t *DBTest) TestLargeWriteRead(c *C) {
	const size = 1e6
	want := bytes.Repeat([]byte{0x1}, size)

	var got []byte
	_, err := t.db.QueryOne(pg.Scan(&got), "SELECT ?", want)
	c.Assert(err, IsNil)
	c.Assert(got, DeepEquals, want)
}
// TestIntegrityError raises a unique_violation from a DO block and checks
// that the returned error is a pg.Error classified as an integrity violation.
//
// The error is checked for nil and type-asserted with comma-ok before it is
// used. That way an unexpected success, or an error of another type, fails
// the test with a clear message instead of panicking.
func (t *DBTest) TestIntegrityError(c *C) {
	_, err := t.db.Exec("DO $$BEGIN RAISE unique_violation USING MESSAGE='foo'; END$$;")
	c.Assert(err, NotNil)

	pgErr, ok := err.(pg.Error)
	c.Assert(ok, Equals, true, Commentf("expected pg.Error, got %T: %v", err, err))
	c.Assert(pgErr.IntegrityViolation(), Equals, true)
}
// customStrSlice is a []string that implements driver.Valuer and the
// sql.Scanner method signature. It stores its elements as one
// newline-separated text value. Elements that themselves contain "\n" do not
// round-trip.
type customStrSlice []string

// Value joins the elements with "\n". Both a nil and an empty slice encode
// to "", which Scan decodes as an empty, non-nil slice.
func (s customStrSlice) Value() (driver.Value, error) {
	return strings.Join(s, "\n"), nil
}

// Scan decodes a newline-separated value as produced by Value.
//
// A nil source (SQL NULL) yields a nil slice. An empty source yields an
// empty, non-nil slice. Both []byte and string sources are accepted. Any
// other source type is reported as an error instead of panicking on a failed
// type assertion.
func (s *customStrSlice) Scan(v interface{}) error {
	var text string
	switch src := v.(type) {
	case nil:
		*s = nil
		return nil
	case []byte:
		text = string(src)
	case string:
		text = src
	default:
		return fmt.Errorf("customStrSlice: unsupported Scan source type %T", v)
	}

	if text == "" {
		*s = []string{}
		return nil
	}
	*s = strings.Split(text, "\n")
	return nil
}
// TestScannerValueOnStruct round-trips a Valuer/Scanner type through a query
// and checks that it is decoded into the matching struct field.
func (t *DBTest) TestScannerValueOnStruct(c *C) {
	want := customStrSlice{"foo", "bar"}

	var got struct{ Dst customStrSlice }
	_, err := t.db.QueryOne(&got, "SELECT ? AS dst", want)
	c.Assert(err, IsNil)
	c.Assert(got.Dst, DeepEquals, want)
}
//------------------------------------------------------------------------------
// badConnError is the error returned by badConn. It reports itself as
// neither a timeout nor a temporary failure.
type badConnError string

// Error returns the message the value was created with.
func (e badConnError) Error() string {
	return string(e)
}

// Timeout always reports false.
func (badConnError) Timeout() bool {
	return false
}

// Temporary always reports false.
func (badConnError) Temporary() bool {
	return false
}
// badConn is a net.Conn stub whose Read and Write never transfer data.
//
// Each call first sleeps for the configured delay, if non-zero. It then
// returns the configured error, or badConnError("bad connection") when no
// error is set. All other net.Conn methods come from the embedded, unused
// net.TCPConn.
type badConn struct {
	net.TCPConn

	readDelay, writeDelay time.Duration
	readErr, writeErr     error
}

var _ net.Conn = (*badConn)(nil)

// Read simulates a failing read using readDelay and readErr.
func (cn *badConn) Read([]byte) (int, error) {
	return 0, cn.simulateFailure(cn.readDelay, cn.readErr)
}

// Write simulates a failing write using writeDelay and writeErr.
func (cn *badConn) Write([]byte) (int, error) {
	return 0, cn.simulateFailure(cn.writeDelay, cn.writeErr)
}

// simulateFailure sleeps for delay (when non-zero) and then returns err.
// When err is nil, it returns the default badConnError instead.
func (cn *badConn) simulateFailure(delay time.Duration, err error) error {
	if delay != 0 {
		time.Sleep(delay)
	}
	if err == nil {
		return badConnError("bad connection")
	}
	return err
}
// performAsync starts n goroutines for each callback in cbs. Each goroutine
// is passed its iteration index i in [0, n). The returned WaitGroup is
// released once every started goroutine has finished.
//
// A panic inside a callback is recovered by GinkgoRecover.
func performAsync(n int, cbs ...func(int)) *sync.WaitGroup {
	var wg sync.WaitGroup
	for _, cb := range cbs {
		for i := 0; i < n; i++ {
			wg.Add(1)
			go func(cb func(int), i int) {
				// Deferred calls run last-in-first-out. GinkgoRecover is
				// deferred last so it runs first: a panic is reported before
				// wg.Done can release whoever is waiting on the group.
				defer wg.Done()
				defer GinkgoRecover()
				cb(i)
			}(cb, i)
		}
	}
	return &wg
}
// perform is the blocking form of performAsync. It runs every callback n
// times concurrently and returns only after all goroutines have finished.
func perform(n int, cbs ...func(int)) {
	performAsync(n, cbs...).Wait()
}