mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
Inital pass at converting stdlib
Multiple tests still failing
This commit is contained in:
parent
3901f3ef88
commit
b77f901168
@ -47,7 +47,7 @@ func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
conn *pgx.Conn
|
||||
)
|
||||
|
||||
if conn, err = pgx.Connect(c.ConnConfig); err != nil {
|
||||
if conn, err = pgx.ConnectConfig(ctx, &c.ConnConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
177
stdlib/sql.go
177
stdlib/sql.go
@ -72,13 +72,13 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
@ -99,9 +99,7 @@ var ctxKeyFakeTx ctxKey = 0
|
||||
var ErrNotPgx = errors.New("not pgx *sql.DB")
|
||||
|
||||
func init() {
|
||||
pgxDriver = &Driver{
|
||||
configs: make(map[int64]*DriverConfig),
|
||||
}
|
||||
pgxDriver = &Driver{}
|
||||
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
|
||||
sql.Register("pgx", pgxDriver)
|
||||
|
||||
@ -126,97 +124,25 @@ var (
|
||||
fakeTxConns map[*pgx.Conn]*sql.Tx
|
||||
)
|
||||
|
||||
type Driver struct {
|
||||
configMutex sync.Mutex
|
||||
configCount int64
|
||||
configs map[int64]*DriverConfig
|
||||
}
|
||||
type Driver struct{}
|
||||
|
||||
func (d *Driver) Open(name string) (driver.Conn, error) {
|
||||
var (
|
||||
connConfig pgx.ConnConfig
|
||||
afterConnect func(*pgx.Conn) error
|
||||
)
|
||||
|
||||
if len(name) >= 9 && name[0] == 0 {
|
||||
idBuf := []byte(name)[1:9]
|
||||
id := int64(binary.BigEndian.Uint64(idBuf))
|
||||
d.configMutex.Lock()
|
||||
connConfig = d.configs[id].ConnConfig
|
||||
afterConnect = d.configs[id].AfterConnect
|
||||
d.configMutex.Unlock()
|
||||
name = name[9:]
|
||||
}
|
||||
|
||||
parsedConfig, err := pgx.ParseConnectionString(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connConfig = connConfig.Merge(parsedConfig)
|
||||
|
||||
conn, err := pgx.Connect(connConfig)
|
||||
connConfig, err := pgx.ParseConfig(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if afterConnect != nil {
|
||||
err = afterConnect(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout
|
||||
defer cancel()
|
||||
conn, err := pgx.ConnectConfig(ctx, connConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &Conn{conn: conn, driver: d, connConfig: connConfig}
|
||||
c := &Conn{conn: conn, driver: d, connConfig: *connConfig}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
type DriverConfig struct {
|
||||
pgx.ConnConfig
|
||||
AfterConnect func(*pgx.Conn) error // function to call on every new connection
|
||||
driver *Driver
|
||||
id int64
|
||||
}
|
||||
|
||||
// ConnectionString encodes the DriverConfig into the original connection
|
||||
// string. DriverConfig must be registered before calling ConnectionString.
|
||||
func (c *DriverConfig) ConnectionString(original string) string {
|
||||
if c.driver == nil {
|
||||
panic("DriverConfig must be registered before calling ConnectionString")
|
||||
}
|
||||
|
||||
buf := make([]byte, 9)
|
||||
binary.BigEndian.PutUint64(buf[1:], uint64(c.id))
|
||||
buf = append(buf, original...)
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func (d *Driver) registerDriverConfig(c *DriverConfig) {
|
||||
d.configMutex.Lock()
|
||||
|
||||
c.driver = d
|
||||
c.id = d.configCount
|
||||
d.configs[d.configCount] = c
|
||||
d.configCount++
|
||||
|
||||
d.configMutex.Unlock()
|
||||
}
|
||||
|
||||
func (d *Driver) unregisterDriverConfig(c *DriverConfig) {
|
||||
d.configMutex.Lock()
|
||||
delete(d.configs, c.id)
|
||||
d.configMutex.Unlock()
|
||||
}
|
||||
|
||||
// RegisterDriverConfig registers a DriverConfig for use with Open.
|
||||
func RegisterDriverConfig(c *DriverConfig) {
|
||||
pgxDriver.registerDriverConfig(c)
|
||||
}
|
||||
|
||||
// UnregisterDriverConfig removes a DriverConfig registration.
|
||||
func UnregisterDriverConfig(c *DriverConfig) {
|
||||
pgxDriver.unregisterDriverConfig(c)
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
conn *pgx.Conn
|
||||
psCount int64 // Counter used for creating unique prepared statement names
|
||||
@ -247,7 +173,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.conn.Close()
|
||||
return c.conn.Close(context.Background())
|
||||
}
|
||||
|
||||
func (c *Conn) Begin() (driver.Tx, error) {
|
||||
@ -283,23 +209,12 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
|
||||
pgxOpts.AccessMode = pgx.ReadOnly
|
||||
}
|
||||
|
||||
return c.conn.BeginEx(ctx, &pgxOpts)
|
||||
}
|
||||
|
||||
func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
|
||||
if !c.conn.IsAlive() {
|
||||
return nil, driver.ErrBadConn
|
||||
tx, err := c.conn.BeginEx(ctx, &pgxOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
args := valueToInterface(argsV)
|
||||
commandTag, err := c.conn.Exec(query, args...)
|
||||
// if we got a network error before we had a chance to send the query, retry
|
||||
if err != nil && !c.conn.LastStmtSent() {
|
||||
if _, is := err.(net.Error); is {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
}
|
||||
return driver.RowsAffected(commandTag.RowsAffected()), err
|
||||
return wrapTx{tx: tx}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) {
|
||||
@ -309,7 +224,7 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
|
||||
|
||||
args := namedValueToInterface(argsV)
|
||||
|
||||
commandTag, err := c.conn.ExecEx(ctx, query, nil, args...)
|
||||
commandTag, err := c.conn.Exec(ctx, query, args...)
|
||||
// if we got a network error before we had a chance to send the query, retry
|
||||
if err != nil && !c.conn.LastStmtSent() {
|
||||
if _, is := err.(net.Error); is {
|
||||
@ -319,44 +234,16 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
|
||||
return driver.RowsAffected(commandTag.RowsAffected()), err
|
||||
}
|
||||
|
||||
func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
|
||||
if !c.conn.IsAlive() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
if !c.connConfig.PreferSimpleProtocol {
|
||||
ps, err := c.conn.Prepare("", query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
return c.queryPrepared("", argsV)
|
||||
}
|
||||
|
||||
rows, err := c.conn.Query(query, valueToInterface(argsV)...)
|
||||
if err != nil {
|
||||
// if we got a network error before we had a chance to send the query, retry
|
||||
if !c.conn.LastStmtSent() {
|
||||
if _, is := err.(net.Error); is {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
|
||||
more := rows.Next()
|
||||
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
|
||||
if !c.conn.IsAlive() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
var rows pgx.Rows
|
||||
|
||||
if !c.connConfig.PreferSimpleProtocol {
|
||||
ps, err := c.conn.PrepareEx(ctx, "", query, nil)
|
||||
c.conn.Deallocate("stdlibtemp")
|
||||
ps, err := c.conn.PrepareEx(ctx, "stdlibtemp", query, nil)
|
||||
if err != nil {
|
||||
// since PrepareEx failed, we didn't actually get to send the values, so
|
||||
// we can safely retry
|
||||
@ -367,10 +254,10 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
return c.queryPreparedContext(ctx, "", argsV)
|
||||
return c.queryPreparedContext(ctx, "stdlibtemp", argsV)
|
||||
}
|
||||
|
||||
rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...)
|
||||
rows, err := c.conn.Query(ctx, query, namedValueToInterface(argsV)...)
|
||||
if err != nil {
|
||||
// if we got a network error before we had a chance to send the query, retry
|
||||
if !c.conn.LastStmtSent() {
|
||||
@ -393,7 +280,7 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er
|
||||
|
||||
args := valueToInterface(argsV)
|
||||
|
||||
rows, err := c.conn.Query(name, args...)
|
||||
rows, err := c.conn.Query(context.Background(), name, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -408,12 +295,14 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr
|
||||
|
||||
args := namedValueToInterface(argsV)
|
||||
|
||||
rows, err := c.conn.QueryEx(ctx, name, nil, args...)
|
||||
rows, err := c.conn.Query(ctx, name, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Rows{rows: rows}, nil
|
||||
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
|
||||
more := rows.Next()
|
||||
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Ping(ctx context.Context) error {
|
||||
@ -450,7 +339,7 @@ func (s *Stmt) NumInput() int {
|
||||
}
|
||||
|
||||
func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
|
||||
return s.conn.Exec(s.ps.Name, argsV)
|
||||
return nil, errors.New("Stmt.Exec deprecated and not implemented")
|
||||
}
|
||||
|
||||
func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
|
||||
@ -458,7 +347,7 @@ func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driv
|
||||
}
|
||||
|
||||
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
||||
return s.conn.queryPrepared(s.ps.Name, argsV)
|
||||
return nil, errors.New("Stmt.Query deprecated and not implemented")
|
||||
}
|
||||
|
||||
func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
|
||||
@ -466,7 +355,7 @@ func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (dri
|
||||
}
|
||||
|
||||
type Rows struct {
|
||||
rows *pgx.Rows
|
||||
rows pgx.Rows
|
||||
values []interface{}
|
||||
skipNext bool
|
||||
skipNextMore bool
|
||||
@ -605,6 +494,12 @@ func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
|
||||
return args
|
||||
}
|
||||
|
||||
type wrapTx struct{ tx *pgx.Tx }
|
||||
|
||||
func (wtx wrapTx) Commit() error { return wtx.tx.Commit(context.Background()) }
|
||||
|
||||
func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(context.Background()) }
|
||||
|
||||
type fakeTx struct{}
|
||||
|
||||
func (fakeTx) Commit() error { return nil }
|
||||
|
@ -6,12 +6,12 @@ import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgproto3"
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgmock"
|
||||
@ -127,33 +127,6 @@ func TestNormalLifeCycle(t *testing.T) {
|
||||
ensureConnValid(t, db)
|
||||
}
|
||||
|
||||
func TestOpenWithDriverConfigAfterConnect(t *testing.T) {
|
||||
driverConfig := stdlib.DriverConfig{
|
||||
AfterConnect: func(c *pgx.Conn) error {
|
||||
_, err := c.Exec("create temporary sequence pgx")
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
stdlib.RegisterDriverConfig(&driverConfig)
|
||||
defer stdlib.UnregisterDriverConfig(&driverConfig)
|
||||
|
||||
db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
defer closeDB(t, db)
|
||||
|
||||
var n int64
|
||||
err = db.QueryRow("select nextval('pgx')").Scan(&n)
|
||||
if err != nil {
|
||||
t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Fatalf("n => %d, want %d", n, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStmtExec(t *testing.T) {
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
@ -330,44 +303,6 @@ func (l *testLogger) Log(lvl pgx.LogLevel, msg string, data map[string]interface
|
||||
l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
|
||||
}
|
||||
|
||||
func TestConnQueryLog(t *testing.T) {
|
||||
logger := &testLogger{}
|
||||
|
||||
driverConfig := stdlib.DriverConfig{
|
||||
ConnConfig: pgx.ConnConfig{
|
||||
Host: "127.0.0.1",
|
||||
User: "pgx_md5",
|
||||
Password: "secret",
|
||||
Database: "pgx_test",
|
||||
Logger: logger,
|
||||
},
|
||||
}
|
||||
|
||||
stdlib.RegisterDriverConfig(&driverConfig)
|
||||
defer stdlib.UnregisterDriverConfig(&driverConfig)
|
||||
|
||||
db, err := sql.Open("pgx", driverConfig.ConnectionString(""))
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
defer closeDB(t, db)
|
||||
|
||||
var n int64
|
||||
err = db.QueryRow("select 1").Scan(&n)
|
||||
if err != nil {
|
||||
t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
l := logger.logs[len(logger.logs)-1]
|
||||
if l.msg != "Query" {
|
||||
t.Errorf("Expected to log Query, but got %v", l)
|
||||
}
|
||||
|
||||
if l.data["sql"] != "select 1" {
|
||||
t.Errorf("Expected to log Query with sql 'select 1', but got %v", l)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnQueryNull(t *testing.T) {
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
@ -430,8 +365,8 @@ func TestConnQueryFailure(t *testing.T) {
|
||||
defer closeDB(t, db)
|
||||
|
||||
_, err := db.Query("select 'foo")
|
||||
if _, ok := err.(pgx.PgError); !ok {
|
||||
t.Fatalf("Expected db.Query to return pgx.PgError, but instead received: %v", err)
|
||||
if _, ok := err.(*pgconn.PgError); !ok {
|
||||
t.Fatalf("Expected db.Query to return pgconn.PgError, but instead received: %v", err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, db)
|
||||
@ -723,7 +658,7 @@ func TestBeginTxContextCancel(t *testing.T) {
|
||||
|
||||
var n int
|
||||
err = db.QueryRow("select count(*) from t").Scan(&n)
|
||||
if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "42P01" {
|
||||
if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" {
|
||||
t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err)
|
||||
}
|
||||
|
||||
@ -747,52 +682,6 @@ func acceptStandardPgxConn(backend *pgproto3.Backend) error {
|
||||
return typeScript.Run(backend)
|
||||
}
|
||||
|
||||
func TestBeginTxContextCancelWithDeadConn(t *testing.T) {
|
||||
script := &pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
|
||||
script.Steps = append(script.Steps,
|
||||
pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}),
|
||||
pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}),
|
||||
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}),
|
||||
)
|
||||
|
||||
server, err := pgmock.NewServer(script)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
errChan <- server.ServeOne()
|
||||
}()
|
||||
|
||||
db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
defer closeDB(t, db)
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("BeginTx failed: %v", err)
|
||||
}
|
||||
|
||||
cancelFn()
|
||||
|
||||
err = tx.Commit()
|
||||
if err != context.Canceled && err != sql.ErrTxDone {
|
||||
t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone)
|
||||
}
|
||||
|
||||
if err := <-errChan; err != nil {
|
||||
t.Fatalf("mock server err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcquireConn(t *testing.T) {
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
@ -807,7 +696,7 @@ func TestAcquireConn(t *testing.T) {
|
||||
}
|
||||
|
||||
var n int32
|
||||
err = conn.QueryRow("select 1").Scan(&n)
|
||||
err = conn.QueryRow(context.Background(), "select 1").Scan(&n)
|
||||
if err != nil {
|
||||
t.Errorf("%d. QueryRow failed: %v", i, err)
|
||||
}
|
||||
@ -843,46 +732,6 @@ func TestConnPingContextSuccess(t *testing.T) {
|
||||
ensureConnValid(t, db)
|
||||
}
|
||||
|
||||
func TestConnPingContextCancel(t *testing.T) {
|
||||
script := &pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
|
||||
script.Steps = append(script.Steps,
|
||||
pgmock.ExpectMessage(&pgproto3.Query{String: ";"}),
|
||||
pgmock.WaitForClose(),
|
||||
)
|
||||
|
||||
server, err := pgmock.NewServer(script)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
errChan <- server.ServeOne()
|
||||
}()
|
||||
|
||||
db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
defer closeDB(t, db)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err = db.PingContext(ctx)
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
if err := <-errChan; err != nil {
|
||||
t.Errorf("mock server err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnPrepareContextSuccess(t *testing.T) {
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
@ -896,48 +745,6 @@ func TestConnPrepareContextSuccess(t *testing.T) {
|
||||
ensureConnValid(t, db)
|
||||
}
|
||||
|
||||
func TestConnPrepareContextCancel(t *testing.T) {
|
||||
script := &pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
|
||||
script.Steps = append(script.Steps,
|
||||
pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select now()"}),
|
||||
pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}),
|
||||
pgmock.ExpectMessage(&pgproto3.Sync{}),
|
||||
pgmock.WaitForClose(),
|
||||
)
|
||||
|
||||
server, err := pgmock.NewServer(script)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
errChan <- server.ServeOne()
|
||||
}()
|
||||
|
||||
db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
defer closeDB(t, db)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = db.PrepareContext(ctx, "select now()")
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
if err := <-errChan; err != nil {
|
||||
t.Errorf("mock server err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnExecContextSuccess(t *testing.T) {
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
@ -950,46 +757,6 @@ func TestConnExecContextSuccess(t *testing.T) {
|
||||
ensureConnValid(t, db)
|
||||
}
|
||||
|
||||
func TestConnExecContextCancel(t *testing.T) {
|
||||
script := &pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
|
||||
script.Steps = append(script.Steps,
|
||||
pgmock.ExpectMessage(&pgproto3.Query{String: "create temporary table exec_context_test(id serial primary key)"}),
|
||||
pgmock.WaitForClose(),
|
||||
)
|
||||
|
||||
server, err := pgmock.NewServer(script)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
errChan <- server.ServeOne()
|
||||
}()
|
||||
|
||||
db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
defer closeDB(t, db)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = db.ExecContext(ctx, "create temporary table exec_context_test(id serial primary key)")
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
if err := <-errChan; err != nil {
|
||||
t.Errorf("mock server err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnExecContextFailureRetry(t *testing.T) {
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
@ -1000,7 +767,7 @@ func TestConnExecContextFailureRetry(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err)
|
||||
}
|
||||
conn.Close()
|
||||
conn.Close(context.Background())
|
||||
stdlib.ReleaseConn(db, conn)
|
||||
}
|
||||
conn, err := db.Conn(context.Background())
|
||||
@ -1035,77 +802,6 @@ func TestConnQueryContextSuccess(t *testing.T) {
|
||||
ensureConnValid(t, db)
|
||||
}
|
||||
|
||||
func TestConnQueryContextCancel(t *testing.T) {
|
||||
script := &pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
|
||||
script.Steps = append(script.Steps,
|
||||
pgmock.ExpectMessage(&pgproto3.Parse{Query: "select * from generate_series(1,10) n"}),
|
||||
pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S'}),
|
||||
pgmock.ExpectMessage(&pgproto3.Sync{}),
|
||||
|
||||
pgmock.SendMessage(&pgproto3.ParseComplete{}),
|
||||
pgmock.SendMessage(&pgproto3.ParameterDescription{}),
|
||||
pgmock.SendMessage(&pgproto3.RowDescription{
|
||||
Fields: []pgproto3.FieldDescription{
|
||||
{
|
||||
Name: "n",
|
||||
DataTypeOID: 23,
|
||||
DataTypeSize: 4,
|
||||
TypeModifier: -1,
|
||||
},
|
||||
},
|
||||
}),
|
||||
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
||||
|
||||
pgmock.ExpectMessage(&pgproto3.Bind{ResultFormatCodes: []int16{1}}),
|
||||
pgmock.ExpectMessage(&pgproto3.Execute{}),
|
||||
pgmock.ExpectMessage(&pgproto3.Sync{}),
|
||||
|
||||
pgmock.SendMessage(&pgproto3.BindComplete{}),
|
||||
pgmock.WaitForClose(),
|
||||
)
|
||||
|
||||
server, err := pgmock.NewServer(script)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
errChan <- server.ServeOne()
|
||||
}()
|
||||
|
||||
db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
|
||||
rows, err := db.QueryContext(ctx, "select * from generate_series(1,10) n")
|
||||
if err != nil {
|
||||
t.Fatalf("db.QueryContext failed: %v", err)
|
||||
}
|
||||
|
||||
cancelFn()
|
||||
|
||||
for rows.Next() {
|
||||
t.Fatalf("no rows should ever be received")
|
||||
}
|
||||
|
||||
if rows.Err() != context.Canceled {
|
||||
t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled)
|
||||
}
|
||||
|
||||
if err := <-errChan; err != nil {
|
||||
t.Errorf("mock server err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnQueryContextFailureRetry(t *testing.T) {
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
@ -1116,7 +812,7 @@ func TestConnQueryContextFailureRetry(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err)
|
||||
}
|
||||
conn.Close()
|
||||
conn.Close(context.Background())
|
||||
stdlib.ReleaseConn(db, conn)
|
||||
}
|
||||
conn, err := db.Conn(context.Background())
|
||||
@ -1233,83 +929,6 @@ func TestStmtQueryContextSuccess(t *testing.T) {
|
||||
ensureConnValid(t, db)
|
||||
}
|
||||
|
||||
func TestStmtQueryContextCancel(t *testing.T) {
|
||||
script := &pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
|
||||
script.Steps = append(script.Steps,
|
||||
pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select * from generate_series(1, $1::int4) n"}),
|
||||
pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}),
|
||||
pgmock.ExpectMessage(&pgproto3.Sync{}),
|
||||
|
||||
pgmock.SendMessage(&pgproto3.ParseComplete{}),
|
||||
pgmock.SendMessage(&pgproto3.ParameterDescription{ParameterOIDs: []uint32{23}}),
|
||||
pgmock.SendMessage(&pgproto3.RowDescription{
|
||||
Fields: []pgproto3.FieldDescription{
|
||||
{
|
||||
Name: "n",
|
||||
DataTypeOID: 23,
|
||||
DataTypeSize: 4,
|
||||
TypeModifier: -1,
|
||||
},
|
||||
},
|
||||
}),
|
||||
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
||||
|
||||
pgmock.ExpectMessage(&pgproto3.Bind{PreparedStatement: "pgx_0", ParameterFormatCodes: []int16{1}, Parameters: [][]uint8{{0x0, 0x0, 0x0, 0x2a}}, ResultFormatCodes: []int16{1}}),
|
||||
pgmock.ExpectMessage(&pgproto3.Execute{}),
|
||||
pgmock.ExpectMessage(&pgproto3.Sync{}),
|
||||
|
||||
pgmock.SendMessage(&pgproto3.BindComplete{}),
|
||||
pgmock.WaitForClose(),
|
||||
)
|
||||
|
||||
server, err := pgmock.NewServer(script)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
errChan <- server.ServeOne()
|
||||
}()
|
||||
|
||||
db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
// defer closeDB(t, db) // mock DB doesn't close correctly yet
|
||||
|
||||
stmt, err := db.Prepare("select * from generate_series(1, $1::int4) n")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// defer stmt.Close()
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, 42)
|
||||
if err != nil {
|
||||
t.Fatalf("stmt.QueryContext failed: %v", err)
|
||||
}
|
||||
|
||||
cancelFn()
|
||||
|
||||
for rows.Next() {
|
||||
t.Fatalf("no rows should ever be received")
|
||||
}
|
||||
|
||||
if rows.Err() != context.Canceled {
|
||||
t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled)
|
||||
}
|
||||
|
||||
if err := <-errChan; err != nil {
|
||||
t.Errorf("mock server err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRowsColumnTypes(t *testing.T) {
|
||||
columnTypesTests := []struct {
|
||||
Name string
|
||||
@ -1436,84 +1055,86 @@ func TestRowsColumnTypes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSimpleQueryLifeCycle(t *testing.T) {
|
||||
driverConfig := stdlib.DriverConfig{
|
||||
ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true},
|
||||
}
|
||||
// TODO - need to use new method of establishing connection with pgx specific configuration
|
||||
|
||||
stdlib.RegisterDriverConfig(&driverConfig)
|
||||
defer stdlib.UnregisterDriverConfig(&driverConfig)
|
||||
// driverConfig := stdlib.DriverConfig{
|
||||
// ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true},
|
||||
// }
|
||||
|
||||
db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
defer closeDB(t, db)
|
||||
// stdlib.RegisterDriverConfig(&driverConfig)
|
||||
// defer stdlib.UnregisterDriverConfig(&driverConfig)
|
||||
|
||||
rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||
}
|
||||
// db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
|
||||
// if err != nil {
|
||||
// t.Fatalf("sql.Open failed: %v", err)
|
||||
// }
|
||||
// defer closeDB(t, db)
|
||||
|
||||
rowCount := int64(0)
|
||||
// rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
|
||||
// if err != nil {
|
||||
// t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||
// }
|
||||
|
||||
for rows.Next() {
|
||||
rowCount++
|
||||
var (
|
||||
s string
|
||||
n int64
|
||||
)
|
||||
// rowCount := int64(0)
|
||||
|
||||
if err := rows.Scan(&s, &n); err != nil {
|
||||
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
|
||||
}
|
||||
// for rows.Next() {
|
||||
// rowCount++
|
||||
// var (
|
||||
// s string
|
||||
// n int64
|
||||
// )
|
||||
|
||||
if s != "foo" {
|
||||
t.Errorf(`Expected "foo", received "%v"`, s)
|
||||
}
|
||||
// if err := rows.Scan(&s, &n); err != nil {
|
||||
// t.Fatalf("rows.Scan unexpectedly failed: %v", err)
|
||||
// }
|
||||
|
||||
if n != rowCount {
|
||||
t.Errorf("Expected %d, received %d", rowCount, n)
|
||||
}
|
||||
}
|
||||
// if s != "foo" {
|
||||
// t.Errorf(`Expected "foo", received "%v"`, s)
|
||||
// }
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||
}
|
||||
// if n != rowCount {
|
||||
// t.Errorf("Expected %d, received %d", rowCount, n)
|
||||
// }
|
||||
// }
|
||||
|
||||
if rowCount != 10 {
|
||||
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||
}
|
||||
// if err = rows.Err(); err != nil {
|
||||
// t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||
// }
|
||||
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||
}
|
||||
// if rowCount != 10 {
|
||||
// t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||
// }
|
||||
|
||||
rows, err = db.Query("select 1 where false")
|
||||
if err != nil {
|
||||
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||
}
|
||||
// err = rows.Close()
|
||||
// if err != nil {
|
||||
// t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||
// }
|
||||
|
||||
rowCount = int64(0)
|
||||
// rows, err = db.Query("select 1 where false")
|
||||
// if err != nil {
|
||||
// t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||
// }
|
||||
|
||||
for rows.Next() {
|
||||
rowCount++
|
||||
}
|
||||
// rowCount = int64(0)
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||
}
|
||||
// for rows.Next() {
|
||||
// rowCount++
|
||||
// }
|
||||
|
||||
if rowCount != 0 {
|
||||
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||
}
|
||||
// if err = rows.Err(); err != nil {
|
||||
// t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||
// }
|
||||
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||
}
|
||||
// if rowCount != 0 {
|
||||
// t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||
// }
|
||||
|
||||
ensureConnValid(t, db)
|
||||
// err = rows.Close()
|
||||
// if err != nil {
|
||||
// t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||
// }
|
||||
|
||||
// ensureConnValid(t, db)
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/409
|
||||
|
@ -11,10 +11,10 @@ import (
|
||||
)
|
||||
|
||||
func openDB(t *testing.T) *sql.DB {
|
||||
config, err := pgx.ParseConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")
|
||||
config, err := pgx.ParseConfig("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")
|
||||
if err != nil {
|
||||
t.Fatalf("pgx.ParseConnectionString failed: %v", err)
|
||||
}
|
||||
|
||||
return stdlib.OpenDB(config)
|
||||
return stdlib.OpenDB(*config)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user