package pg_test

import (
	"bytes"
	"database/sql/driver"
	"fmt"
	"net"
	"strings"
	"sync"
	"testing"
	"time"

	. "github.com/onsi/ginkgo"
	. "gopkg.in/check.v1"

	"github.com/go-pg/pg"
)

// TestUnixSocket checks that a query can be executed over a Unix domain
// socket instead of TCP. TLS is disabled because it is not used on this
// transport.
//
// NOTE(review): the socket path is hard-coded; presumably it matches the
// distro-default PostgreSQL socket location — adjust if the server listens
// elsewhere.
func TestUnixSocket(t *testing.T) {
	opt := pgOptions()
	opt.Network = "unix"
	opt.Addr = "/var/run/postgresql/.s.PGSQL.5432"
	opt.TLSConfig = nil

	db := pg.Connect(opt)
	defer db.Close()

	_, err := db.Exec("SELECT 'test_unix_socket'")
	if err != nil {
		t.Fatal(err)
	}
}

// TestGocheck hooks the gocheck suites registered with Suite into "go test".
func TestGocheck(t *testing.T) { TestingT(t) }

var _ = Suite(&DBTest{})

// DBTest is a gocheck suite that opens a fresh *pg.DB before every test and
// closes it afterwards.
type DBTest struct {
	db *pg.DB
}

// SetUpTest connects using pgOptions() (defined elsewhere in this package).
func (t *DBTest) SetUpTest(c *C) {
	t.db = pg.Connect(pgOptions())
}

// TearDownTest closes the connection opened by SetUpTest and requires that
// closing succeeds.
func (t *DBTest) TearDownTest(c *C) {
	c.Assert(t.db.Close(), IsNil)
}

// TestQueryOneErrMultiRows: QueryOne must reject a result with more than one row.
func (t *DBTest) TestQueryOneErrMultiRows(c *C) {
	_, err := t.db.QueryOne(pg.Discard, "SELECT generate_series(0, 1)")
	c.Assert(err, Equals, pg.ErrMultiRows)
}

// TestExecOne: ExecOne succeeds and reports one affected row for a single-row result.
func (t *DBTest) TestExecOne(c *C) {
	res, err := t.db.ExecOne("SELECT 'test_exec_one'")
	c.Assert(err, IsNil)
	c.Assert(res.RowsAffected(), Equals, 1)
}

// TestExecOneErrNoRows: ExecOne must return ErrNoRows for an empty result.
func (t *DBTest) TestExecOneErrNoRows(c *C) {
	_, err := t.db.ExecOne("SELECT 1 WHERE 1 != 1")
	c.Assert(err, Equals, pg.ErrNoRows)
}

// TestExecOneErrMultiRows: ExecOne must return ErrMultiRows for a multi-row result.
func (t *DBTest) TestExecOneErrMultiRows(c *C) {
	_, err := t.db.ExecOne("SELECT generate_series(0, 1)")
	c.Assert(err, Equals, pg.ErrMultiRows)
}

// TestScan: pg.Scan decodes a single column into a Go value.
func (t *DBTest) TestScan(c *C) {
	var dst int
	_, err := t.db.QueryOne(pg.Scan(&dst), "SELECT 1")
	c.Assert(err, IsNil)
	c.Assert(dst, Equals, 1)
}

// TestExec checks RowsAffected for a DDL statement (-1) and for a
// single-row INSERT (1).
func (t *DBTest) TestExec(c *C) {
	res, err := t.db.Exec("CREATE TEMP TABLE test(id serial PRIMARY KEY)")
	c.Assert(err, IsNil)
	c.Assert(res.RowsAffected(), Equals, -1)

	res, err = t.db.Exec("INSERT INTO test VALUES (1)")
	c.Assert(err, IsNil)
	c.Assert(res.RowsAffected(), Equals, 1)
}

// TestStatementExec is like TestExec but performs the INSERT through a
// prepared statement.
func (t *DBTest) TestStatementExec(c *C) {
	res, err := t.db.Exec("CREATE TEMP TABLE test(id serial PRIMARY KEY)")
	c.Assert(err, IsNil)
	c.Assert(res.RowsAffected(), Equals, -1)

	stmt, err := t.db.Prepare("INSERT INTO test VALUES($1)")
	c.Assert(err, IsNil)
	defer stmt.Close()

	res, err = stmt.Exec(1)
	c.Assert(err, IsNil)
	c.Assert(res.RowsAffected(), Equals, 1)
}

// TestLargeWriteRead round-trips a 1,000,000-byte payload through the server.
func (t *DBTest) TestLargeWriteRead(c *C) {
	src := bytes.Repeat([]byte{0x1}, 1e6)
	var dst []byte
	_, err := t.db.QueryOne(pg.Scan(&dst), "SELECT ?", src)
	c.Assert(err, IsNil)
	c.Assert(dst, DeepEquals, src)
}

// TestIntegrityError raises unique_violation server-side and checks that the
// returned error is a pg.Error reporting an integrity violation.
func (t *DBTest) TestIntegrityError(c *C) {
	_, err := t.db.Exec("DO $$BEGIN RAISE unique_violation USING MESSAGE='foo'; END$$;")
	// Check for nil and the concrete type first so an unexpected result fails
	// the test with a message instead of panicking on the type assertion.
	c.Assert(err, NotNil)
	pgErr, ok := err.(pg.Error)
	c.Assert(ok, Equals, true, Commentf("expected pg.Error, got %T: %v", err, err))
	c.Assert(pgErr.IntegrityViolation(), Equals, true)
}

// customStrSlice stores a []string as a single newline-joined value. It
// implements driver.Valuer and sql.Scanner to exercise custom types on structs.
type customStrSlice []string

// Value joins the elements with "\n".
func (s customStrSlice) Value() (driver.Value, error) {
	return strings.Join(s, "\n"), nil
}

// Scan splits a newline-joined value back into a slice. A NULL yields a nil
// slice and an empty value yields an empty non-nil slice. Both []byte and
// string sources are accepted; any other type is reported as an error
// rather than panicking.
func (s *customStrSlice) Scan(v interface{}) error {
	var str string
	switch v := v.(type) {
	case nil:
		*s = nil
		return nil
	case []byte:
		str = string(v)
	case string:
		str = v
	default:
		return fmt.Errorf("customStrSlice: unsupported Scan type %T", v)
	}

	// strings.Split("", "\n") would return [""], so handle empty explicitly.
	if str == "" {
		*s = []string{}
		return nil
	}
	*s = strings.Split(str, "\n")
	return nil
}

// TestScannerValueOnStruct round-trips a Valuer/Scanner type through a
// struct field.
func (t *DBTest) TestScannerValueOnStruct(c *C) {
	src := customStrSlice{"foo", "bar"}
	dst := struct{ Dst customStrSlice }{}
	_, err := t.db.QueryOne(&dst, "SELECT ? AS dst", src)
	c.Assert(err, IsNil)
	c.Assert(dst.Dst, DeepEquals, src)
}

//------------------------------------------------------------------------------

// badConnError is a net.Error that reports itself as neither a timeout nor
// temporary.
type badConnError string

func (e badConnError) Error() string   { return string(e) }
func (e badConnError) Timeout() bool   { return false }
func (e badConnError) Temporary() bool { return false }

// badConn is a net.Conn whose Read and Write always fail. Each call first
// sleeps for the configured delay (if non-zero). It then returns
// readErr/writeErr when set, or badConnError("bad connection") otherwise.
type badConn struct {
	net.TCPConn

	readDelay, writeDelay time.Duration
	readErr, writeErr     error
}

var _ net.Conn = &badConn{}

// Read always fails; see badConn.
func (cn *badConn) Read([]byte) (int, error) {
	if cn.readDelay != 0 {
		time.Sleep(cn.readDelay)
	}
	if cn.readErr != nil {
		return 0, cn.readErr
	}
	return 0, badConnError("bad connection")
}

// Write always fails; see badConn.
func (cn *badConn) Write([]byte) (int, error) {
	if cn.writeDelay != 0 {
		time.Sleep(cn.writeDelay)
	}
	if cn.writeErr != nil {
		return 0, cn.writeErr
	}
	return 0, badConnError("bad connection")
}

// performAsync starts n goroutines per callback. The i-th goroutine for each
// callback is called with i in [0, n). It returns a WaitGroup the caller
// can wait on. Panics in callbacks are reported to Ginkgo via GinkgoRecover.
func performAsync(n int, cbs ...func(int)) *sync.WaitGroup {
	var wg sync.WaitGroup
	for _, cb := range cbs {
		for i := 0; i < n; i++ {
			wg.Add(1)
			go func(cb func(int), i int) {
				// Deferred calls run LIFO, so GinkgoRecover runs before
				// wg.Done. A panic is therefore recorded before Wait can
				// return.
				defer wg.Done()
				defer GinkgoRecover()
				cb(i)
			}(cb, i)
		}
	}
	return &wg
}

// perform is the blocking form of performAsync: it waits for every callback
// invocation to finish.
func perform(n int, cbs ...func(int)) {
	wg := performAsync(n, cbs...)
	wg.Wait()
}