Inital pass at converting stdlib

Multiple tests still failing
This commit is contained in:
Jack Christensen 2019-04-12 16:57:42 -05:00
parent 3901f3ef88
commit b77f901168
4 changed files with 108 additions and 592 deletions

View File

@ -47,7 +47,7 @@ func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
conn *pgx.Conn
)
if conn, err = pgx.Connect(c.ConnConfig); err != nil {
if conn, err = pgx.ConnectConfig(ctx, &c.ConnConfig); err != nil {
return nil, err
}

View File

@ -72,13 +72,13 @@ import (
"context"
"database/sql"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
"net"
"reflect"
"strings"
"sync"
"time"
"github.com/pkg/errors"
@ -99,9 +99,7 @@ var ctxKeyFakeTx ctxKey = 0
var ErrNotPgx = errors.New("not pgx *sql.DB")
func init() {
pgxDriver = &Driver{
configs: make(map[int64]*DriverConfig),
}
pgxDriver = &Driver{}
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
sql.Register("pgx", pgxDriver)
@ -126,97 +124,25 @@ var (
fakeTxConns map[*pgx.Conn]*sql.Tx
)
type Driver struct {
configMutex sync.Mutex
configCount int64
configs map[int64]*DriverConfig
}
type Driver struct{}
func (d *Driver) Open(name string) (driver.Conn, error) {
var (
connConfig pgx.ConnConfig
afterConnect func(*pgx.Conn) error
)
if len(name) >= 9 && name[0] == 0 {
idBuf := []byte(name)[1:9]
id := int64(binary.BigEndian.Uint64(idBuf))
d.configMutex.Lock()
connConfig = d.configs[id].ConnConfig
afterConnect = d.configs[id].AfterConnect
d.configMutex.Unlock()
name = name[9:]
}
parsedConfig, err := pgx.ParseConnectionString(name)
if err != nil {
return nil, err
}
connConfig = connConfig.Merge(parsedConfig)
conn, err := pgx.Connect(connConfig)
connConfig, err := pgx.ParseConfig(name)
if err != nil {
return nil, err
}
if afterConnect != nil {
err = afterConnect(conn)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout
defer cancel()
conn, err := pgx.ConnectConfig(ctx, connConfig)
if err != nil {
return nil, err
}
c := &Conn{conn: conn, driver: d, connConfig: connConfig}
c := &Conn{conn: conn, driver: d, connConfig: *connConfig}
return c, nil
}
type DriverConfig struct {
pgx.ConnConfig
AfterConnect func(*pgx.Conn) error // function to call on every new connection
driver *Driver
id int64
}
// ConnectionString encodes the DriverConfig into the original connection
// string. DriverConfig must be registered before calling ConnectionString.
func (c *DriverConfig) ConnectionString(original string) string {
if c.driver == nil {
panic("DriverConfig must be registered before calling ConnectionString")
}
buf := make([]byte, 9)
binary.BigEndian.PutUint64(buf[1:], uint64(c.id))
buf = append(buf, original...)
return string(buf)
}
func (d *Driver) registerDriverConfig(c *DriverConfig) {
d.configMutex.Lock()
c.driver = d
c.id = d.configCount
d.configs[d.configCount] = c
d.configCount++
d.configMutex.Unlock()
}
func (d *Driver) unregisterDriverConfig(c *DriverConfig) {
d.configMutex.Lock()
delete(d.configs, c.id)
d.configMutex.Unlock()
}
// RegisterDriverConfig registers a DriverConfig for use with Open.
func RegisterDriverConfig(c *DriverConfig) {
pgxDriver.registerDriverConfig(c)
}
// UnregisterDriverConfig removes a DriverConfig registration.
func UnregisterDriverConfig(c *DriverConfig) {
pgxDriver.unregisterDriverConfig(c)
}
type Conn struct {
conn *pgx.Conn
psCount int64 // Counter used for creating unique prepared statement names
@ -247,7 +173,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
}
func (c *Conn) Close() error {
return c.conn.Close()
return c.conn.Close(context.Background())
}
func (c *Conn) Begin() (driver.Tx, error) {
@ -283,23 +209,12 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
pgxOpts.AccessMode = pgx.ReadOnly
}
return c.conn.BeginEx(ctx, &pgxOpts)
}
func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
tx, err := c.conn.BeginEx(ctx, &pgxOpts)
if err != nil {
return nil, err
}
args := valueToInterface(argsV)
commandTag, err := c.conn.Exec(query, args...)
// if we got a network error before we had a chance to send the query, retry
if err != nil && !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
}
return driver.RowsAffected(commandTag.RowsAffected()), err
return wrapTx{tx: tx}, nil
}
func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) {
@ -309,7 +224,7 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
args := namedValueToInterface(argsV)
commandTag, err := c.conn.ExecEx(ctx, query, nil, args...)
commandTag, err := c.conn.Exec(ctx, query, args...)
// if we got a network error before we had a chance to send the query, retry
if err != nil && !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
@ -319,44 +234,16 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
return driver.RowsAffected(commandTag.RowsAffected()), err
}
func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
if !c.connConfig.PreferSimpleProtocol {
ps, err := c.conn.Prepare("", query)
if err != nil {
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPrepared("", argsV)
}
rows, err := c.conn.Query(query, valueToInterface(argsV)...)
if err != nil {
// if we got a network error before we had a chance to send the query, retry
if !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
}
return nil, err
}
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
more := rows.Next()
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
}
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
var rows pgx.Rows
if !c.connConfig.PreferSimpleProtocol {
ps, err := c.conn.PrepareEx(ctx, "", query, nil)
c.conn.Deallocate("stdlibtemp")
ps, err := c.conn.PrepareEx(ctx, "stdlibtemp", query, nil)
if err != nil {
// since PrepareEx failed, we didn't actually get to send the values, so
// we can safely retry
@ -367,10 +254,10 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPreparedContext(ctx, "", argsV)
return c.queryPreparedContext(ctx, "stdlibtemp", argsV)
}
rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...)
rows, err := c.conn.Query(ctx, query, namedValueToInterface(argsV)...)
if err != nil {
// if we got a network error before we had a chance to send the query, retry
if !c.conn.LastStmtSent() {
@ -393,7 +280,7 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er
args := valueToInterface(argsV)
rows, err := c.conn.Query(name, args...)
rows, err := c.conn.Query(context.Background(), name, args...)
if err != nil {
return nil, err
}
@ -408,12 +295,14 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr
args := namedValueToInterface(argsV)
rows, err := c.conn.QueryEx(ctx, name, nil, args...)
rows, err := c.conn.Query(ctx, name, args...)
if err != nil {
return nil, err
}
return &Rows{rows: rows}, nil
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
more := rows.Next()
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
}
func (c *Conn) Ping(ctx context.Context) error {
@ -450,7 +339,7 @@ func (s *Stmt) NumInput() int {
}
func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
return s.conn.Exec(s.ps.Name, argsV)
return nil, errors.New("Stmt.Exec deprecated and not implemented")
}
func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
@ -458,7 +347,7 @@ func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driv
}
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
return s.conn.queryPrepared(s.ps.Name, argsV)
return nil, errors.New("Stmt.Query deprecated and not implemented")
}
func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
@ -466,7 +355,7 @@ func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (dri
}
type Rows struct {
rows *pgx.Rows
rows pgx.Rows
values []interface{}
skipNext bool
skipNextMore bool
@ -605,6 +494,12 @@ func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
return args
}
type wrapTx struct{ tx *pgx.Tx }
func (wtx wrapTx) Commit() error { return wtx.tx.Commit(context.Background()) }
func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(context.Background()) }
type fakeTx struct{}
func (fakeTx) Commit() error { return nil }

View File

@ -6,12 +6,12 @@ import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"math"
"reflect"
"testing"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgmock"
@ -127,33 +127,6 @@ func TestNormalLifeCycle(t *testing.T) {
ensureConnValid(t, db)
}
func TestOpenWithDriverConfigAfterConnect(t *testing.T) {
driverConfig := stdlib.DriverConfig{
AfterConnect: func(c *pgx.Conn) error {
_, err := c.Exec("create temporary sequence pgx")
return err
},
}
stdlib.RegisterDriverConfig(&driverConfig)
defer stdlib.UnregisterDriverConfig(&driverConfig)
db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
defer closeDB(t, db)
var n int64
err = db.QueryRow("select nextval('pgx')").Scan(&n)
if err != nil {
t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
}
if n != 1 {
t.Fatalf("n => %d, want %d", n, 1)
}
}
func TestStmtExec(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)
@ -330,44 +303,6 @@ func (l *testLogger) Log(lvl pgx.LogLevel, msg string, data map[string]interface
l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
}
func TestConnQueryLog(t *testing.T) {
logger := &testLogger{}
driverConfig := stdlib.DriverConfig{
ConnConfig: pgx.ConnConfig{
Host: "127.0.0.1",
User: "pgx_md5",
Password: "secret",
Database: "pgx_test",
Logger: logger,
},
}
stdlib.RegisterDriverConfig(&driverConfig)
defer stdlib.UnregisterDriverConfig(&driverConfig)
db, err := sql.Open("pgx", driverConfig.ConnectionString(""))
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
defer closeDB(t, db)
var n int64
err = db.QueryRow("select 1").Scan(&n)
if err != nil {
t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
}
l := logger.logs[len(logger.logs)-1]
if l.msg != "Query" {
t.Errorf("Expected to log Query, but got %v", l)
}
if l.data["sql"] != "select 1" {
t.Errorf("Expected to log Query with sql 'select 1', but got %v", l)
}
}
func TestConnQueryNull(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)
@ -430,8 +365,8 @@ func TestConnQueryFailure(t *testing.T) {
defer closeDB(t, db)
_, err := db.Query("select 'foo")
if _, ok := err.(pgx.PgError); !ok {
t.Fatalf("Expected db.Query to return pgx.PgError, but instead received: %v", err)
if _, ok := err.(*pgconn.PgError); !ok {
t.Fatalf("Expected db.Query to return pgconn.PgError, but instead received: %v", err)
}
ensureConnValid(t, db)
@ -723,7 +658,7 @@ func TestBeginTxContextCancel(t *testing.T) {
var n int
err = db.QueryRow("select count(*) from t").Scan(&n)
if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "42P01" {
if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" {
t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err)
}
@ -747,52 +682,6 @@ func acceptStandardPgxConn(backend *pgproto3.Backend) error {
return typeScript.Run(backend)
}
func TestBeginTxContextCancelWithDeadConn(t *testing.T) {
script := &pgmock.Script{
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
}
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
script.Steps = append(script.Steps,
pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}),
pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}),
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}),
)
server, err := pgmock.NewServer(script)
if err != nil {
t.Fatal(err)
}
errChan := make(chan error)
go func() {
errChan <- server.ServeOne()
}()
db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
defer closeDB(t, db)
ctx, cancelFn := context.WithCancel(context.Background())
tx, err := db.BeginTx(ctx, nil)
if err != nil {
t.Fatalf("BeginTx failed: %v", err)
}
cancelFn()
err = tx.Commit()
if err != context.Canceled && err != sql.ErrTxDone {
t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone)
}
if err := <-errChan; err != nil {
t.Fatalf("mock server err: %v", err)
}
}
func TestAcquireConn(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)
@ -807,7 +696,7 @@ func TestAcquireConn(t *testing.T) {
}
var n int32
err = conn.QueryRow("select 1").Scan(&n)
err = conn.QueryRow(context.Background(), "select 1").Scan(&n)
if err != nil {
t.Errorf("%d. QueryRow failed: %v", i, err)
}
@ -843,46 +732,6 @@ func TestConnPingContextSuccess(t *testing.T) {
ensureConnValid(t, db)
}
func TestConnPingContextCancel(t *testing.T) {
script := &pgmock.Script{
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
}
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
script.Steps = append(script.Steps,
pgmock.ExpectMessage(&pgproto3.Query{String: ";"}),
pgmock.WaitForClose(),
)
server, err := pgmock.NewServer(script)
if err != nil {
t.Fatal(err)
}
defer server.Close()
errChan := make(chan error, 1)
go func() {
errChan <- server.ServeOne()
}()
db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
defer closeDB(t, db)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
err = db.PingContext(ctx)
if err != context.DeadlineExceeded {
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
}
if err := <-errChan; err != nil {
t.Errorf("mock server err: %v", err)
}
}
func TestConnPrepareContextSuccess(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)
@ -896,48 +745,6 @@ func TestConnPrepareContextSuccess(t *testing.T) {
ensureConnValid(t, db)
}
func TestConnPrepareContextCancel(t *testing.T) {
script := &pgmock.Script{
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
}
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
script.Steps = append(script.Steps,
pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select now()"}),
pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}),
pgmock.ExpectMessage(&pgproto3.Sync{}),
pgmock.WaitForClose(),
)
server, err := pgmock.NewServer(script)
if err != nil {
t.Fatal(err)
}
defer server.Close()
errChan := make(chan error)
go func() {
errChan <- server.ServeOne()
}()
db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
defer closeDB(t, db)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, err = db.PrepareContext(ctx, "select now()")
if err != context.DeadlineExceeded {
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
}
if err := <-errChan; err != nil {
t.Errorf("mock server err: %v", err)
}
}
func TestConnExecContextSuccess(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)
@ -950,46 +757,6 @@ func TestConnExecContextSuccess(t *testing.T) {
ensureConnValid(t, db)
}
func TestConnExecContextCancel(t *testing.T) {
script := &pgmock.Script{
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
}
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
script.Steps = append(script.Steps,
pgmock.ExpectMessage(&pgproto3.Query{String: "create temporary table exec_context_test(id serial primary key)"}),
pgmock.WaitForClose(),
)
server, err := pgmock.NewServer(script)
if err != nil {
t.Fatal(err)
}
defer server.Close()
errChan := make(chan error)
go func() {
errChan <- server.ServeOne()
}()
db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
defer closeDB(t, db)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, err = db.ExecContext(ctx, "create temporary table exec_context_test(id serial primary key)")
if err != context.DeadlineExceeded {
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
}
if err := <-errChan; err != nil {
t.Errorf("mock server err: %v", err)
}
}
func TestConnExecContextFailureRetry(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)
@ -1000,7 +767,7 @@ func TestConnExecContextFailureRetry(t *testing.T) {
if err != nil {
t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err)
}
conn.Close()
conn.Close(context.Background())
stdlib.ReleaseConn(db, conn)
}
conn, err := db.Conn(context.Background())
@ -1035,77 +802,6 @@ func TestConnQueryContextSuccess(t *testing.T) {
ensureConnValid(t, db)
}
func TestConnQueryContextCancel(t *testing.T) {
script := &pgmock.Script{
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
}
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
script.Steps = append(script.Steps,
pgmock.ExpectMessage(&pgproto3.Parse{Query: "select * from generate_series(1,10) n"}),
pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S'}),
pgmock.ExpectMessage(&pgproto3.Sync{}),
pgmock.SendMessage(&pgproto3.ParseComplete{}),
pgmock.SendMessage(&pgproto3.ParameterDescription{}),
pgmock.SendMessage(&pgproto3.RowDescription{
Fields: []pgproto3.FieldDescription{
{
Name: "n",
DataTypeOID: 23,
DataTypeSize: 4,
TypeModifier: -1,
},
},
}),
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
pgmock.ExpectMessage(&pgproto3.Bind{ResultFormatCodes: []int16{1}}),
pgmock.ExpectMessage(&pgproto3.Execute{}),
pgmock.ExpectMessage(&pgproto3.Sync{}),
pgmock.SendMessage(&pgproto3.BindComplete{}),
pgmock.WaitForClose(),
)
server, err := pgmock.NewServer(script)
if err != nil {
t.Fatal(err)
}
defer server.Close()
errChan := make(chan error)
go func() {
errChan <- server.ServeOne()
}()
db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
defer db.Close()
ctx, cancelFn := context.WithCancel(context.Background())
rows, err := db.QueryContext(ctx, "select * from generate_series(1,10) n")
if err != nil {
t.Fatalf("db.QueryContext failed: %v", err)
}
cancelFn()
for rows.Next() {
t.Fatalf("no rows should ever be received")
}
if rows.Err() != context.Canceled {
t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled)
}
if err := <-errChan; err != nil {
t.Errorf("mock server err: %v", err)
}
}
func TestConnQueryContextFailureRetry(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)
@ -1116,7 +812,7 @@ func TestConnQueryContextFailureRetry(t *testing.T) {
if err != nil {
t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err)
}
conn.Close()
conn.Close(context.Background())
stdlib.ReleaseConn(db, conn)
}
conn, err := db.Conn(context.Background())
@ -1233,83 +929,6 @@ func TestStmtQueryContextSuccess(t *testing.T) {
ensureConnValid(t, db)
}
func TestStmtQueryContextCancel(t *testing.T) {
script := &pgmock.Script{
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
}
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
script.Steps = append(script.Steps,
pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select * from generate_series(1, $1::int4) n"}),
pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}),
pgmock.ExpectMessage(&pgproto3.Sync{}),
pgmock.SendMessage(&pgproto3.ParseComplete{}),
pgmock.SendMessage(&pgproto3.ParameterDescription{ParameterOIDs: []uint32{23}}),
pgmock.SendMessage(&pgproto3.RowDescription{
Fields: []pgproto3.FieldDescription{
{
Name: "n",
DataTypeOID: 23,
DataTypeSize: 4,
TypeModifier: -1,
},
},
}),
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
pgmock.ExpectMessage(&pgproto3.Bind{PreparedStatement: "pgx_0", ParameterFormatCodes: []int16{1}, Parameters: [][]uint8{{0x0, 0x0, 0x0, 0x2a}}, ResultFormatCodes: []int16{1}}),
pgmock.ExpectMessage(&pgproto3.Execute{}),
pgmock.ExpectMessage(&pgproto3.Sync{}),
pgmock.SendMessage(&pgproto3.BindComplete{}),
pgmock.WaitForClose(),
)
server, err := pgmock.NewServer(script)
if err != nil {
t.Fatal(err)
}
defer server.Close()
errChan := make(chan error)
go func() {
errChan <- server.ServeOne()
}()
db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
// defer closeDB(t, db) // mock DB doesn't close correctly yet
stmt, err := db.Prepare("select * from generate_series(1, $1::int4) n")
if err != nil {
t.Fatal(err)
}
// defer stmt.Close()
ctx, cancelFn := context.WithCancel(context.Background())
rows, err := stmt.QueryContext(ctx, 42)
if err != nil {
t.Fatalf("stmt.QueryContext failed: %v", err)
}
cancelFn()
for rows.Next() {
t.Fatalf("no rows should ever be received")
}
if rows.Err() != context.Canceled {
t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled)
}
if err := <-errChan; err != nil {
t.Errorf("mock server err: %v", err)
}
}
func TestRowsColumnTypes(t *testing.T) {
columnTypesTests := []struct {
Name string
@ -1436,84 +1055,86 @@ func TestRowsColumnTypes(t *testing.T) {
}
func TestSimpleQueryLifeCycle(t *testing.T) {
driverConfig := stdlib.DriverConfig{
ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true},
}
// TODO - need to use new method of establishing connection with pgx specific configuration
stdlib.RegisterDriverConfig(&driverConfig)
defer stdlib.UnregisterDriverConfig(&driverConfig)
// driverConfig := stdlib.DriverConfig{
// ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true},
// }
db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
defer closeDB(t, db)
// stdlib.RegisterDriverConfig(&driverConfig)
// defer stdlib.UnregisterDriverConfig(&driverConfig)
rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
if err != nil {
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
}
// db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
// if err != nil {
// t.Fatalf("sql.Open failed: %v", err)
// }
// defer closeDB(t, db)
rowCount := int64(0)
// rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
// if err != nil {
// t.Fatalf("stmt.Query unexpectedly failed: %v", err)
// }
for rows.Next() {
rowCount++
var (
s string
n int64
)
// rowCount := int64(0)
if err := rows.Scan(&s, &n); err != nil {
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
}
// for rows.Next() {
// rowCount++
// var (
// s string
// n int64
// )
if s != "foo" {
t.Errorf(`Expected "foo", received "%v"`, s)
}
// if err := rows.Scan(&s, &n); err != nil {
// t.Fatalf("rows.Scan unexpectedly failed: %v", err)
// }
if n != rowCount {
t.Errorf("Expected %d, received %d", rowCount, n)
}
}
// if s != "foo" {
// t.Errorf(`Expected "foo", received "%v"`, s)
// }
if err = rows.Err(); err != nil {
t.Fatalf("rows.Err unexpectedly is: %v", err)
}
// if n != rowCount {
// t.Errorf("Expected %d, received %d", rowCount, n)
// }
// }
if rowCount != 10 {
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
}
// if err = rows.Err(); err != nil {
// t.Fatalf("rows.Err unexpectedly is: %v", err)
// }
err = rows.Close()
if err != nil {
t.Fatalf("rows.Close unexpectedly failed: %v", err)
}
// if rowCount != 10 {
// t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
// }
rows, err = db.Query("select 1 where false")
if err != nil {
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
}
// err = rows.Close()
// if err != nil {
// t.Fatalf("rows.Close unexpectedly failed: %v", err)
// }
rowCount = int64(0)
// rows, err = db.Query("select 1 where false")
// if err != nil {
// t.Fatalf("stmt.Query unexpectedly failed: %v", err)
// }
for rows.Next() {
rowCount++
}
// rowCount = int64(0)
if err = rows.Err(); err != nil {
t.Fatalf("rows.Err unexpectedly is: %v", err)
}
// for rows.Next() {
// rowCount++
// }
if rowCount != 0 {
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
}
// if err = rows.Err(); err != nil {
// t.Fatalf("rows.Err unexpectedly is: %v", err)
// }
err = rows.Close()
if err != nil {
t.Fatalf("rows.Close unexpectedly failed: %v", err)
}
// if rowCount != 0 {
// t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
// }
ensureConnValid(t, db)
// err = rows.Close()
// if err != nil {
// t.Fatalf("rows.Close unexpectedly failed: %v", err)
// }
// ensureConnValid(t, db)
}
// https://github.com/jackc/pgx/issues/409

View File

@ -11,10 +11,10 @@ import (
)
func openDB(t *testing.T) *sql.DB {
config, err := pgx.ParseConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")
config, err := pgx.ParseConfig("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")
if err != nil {
t.Fatalf("pgx.ParseConnectionString failed: %v", err)
}
return stdlib.OpenDB(config)
return stdlib.OpenDB(*config)
}