Partial conversion of pgx to use pgconn

pull/483/head
Jack Christensen 2019-01-26 16:46:30 -06:00
parent e3d431d0df
commit d3a2c1c107
17 changed files with 877 additions and 1830 deletions

batch.go

@ -6,6 +6,7 @@ import (
"github.com/jackc/pgx/pgconn"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
"github.com/pkg/errors"
)
type batchItem struct {
@ -26,6 +27,8 @@ type Batch struct {
ctx context.Context
err error
inTx bool
mrr *pgconn.MultiResultReader
}
// BeginBatch returns a *Batch query for c.
@ -56,10 +59,8 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt
})
}
// Send sends all queued queries to the server at once.
// If the batch is created from a conn Object then All queries are wrapped
// in a transaction. The transaction can optionally be configured with
// txOptions. The context is in effect until the Batch is closed.
// Send sends all queued queries to the server at once. All queries are run in an implicit transaction unless explicit
// transaction control statements are executed.
//
// Warning: Send writes all queued queries before reading any results. This can
// cause a deadlock if an excessive number of queries are queued. It is highly
@ -78,7 +79,7 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt
// able to finish sending the responses.
//
// See https://github.com/jackc/pgx/issues/374.
func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
func (b *Batch) Send(ctx context.Context) error {
if b.err != nil {
return b.err
}
@ -94,112 +95,62 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
return err
}
buf := b.conn.wbuf
if !b.inTx {
buf = appendQuery(buf, txOptions.beginSQL())
}
err = b.conn.initContext(ctx)
if err != nil {
return err
}
batch := &pgconn.Batch{}
for _, bi := range b.items {
var psName string
var psParameterOIDs []pgtype.OID
var parameterOIDs []pgtype.OID
ps := b.conn.preparedStatements[bi.query]
if ps, ok := b.conn.preparedStatements[bi.query]; ok {
psName = ps.Name
psParameterOIDs = ps.ParameterOIDs
if ps != nil {
parameterOIDs = ps.ParameterOIDs
} else {
psParameterOIDs = bi.parameterOIDs
buf = appendParse(buf, "", bi.query, psParameterOIDs)
parameterOIDs = bi.parameterOIDs
}
var err error
buf, err = appendBind(buf, "", psName, b.conn.ConnInfo, psParameterOIDs, bi.arguments, bi.resultFormatCodes)
args, err := convertDriverValuers(bi.arguments)
if err != nil {
return err
}
buf = appendDescribe(buf, 'P', "")
buf = appendExecute(buf, "", 0)
}
buf = appendSync(buf)
b.conn.pendingReadyForQueryCount++
if !b.inTx {
buf = appendQuery(buf, "commit")
b.conn.pendingReadyForQueryCount++
}
n, err := b.conn.pgConn.Conn().Write(buf)
if err != nil {
if fatalWriteErr(n, err) {
b.conn.die(err)
}
return err
}
for !b.inTx {
msg, err := b.conn.rxMsg()
if err != nil {
return err
}
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
return nil
default:
if err := b.conn.processContextFreeMsg(msg); err != nil {
paramFormats := make([]int16, len(args))
paramValues := make([][]byte, len(args))
for i := range args {
paramFormats[i] = chooseParameterFormatCode(b.conn.ConnInfo, parameterOIDs[i], args[i])
paramValues[i], err = newencodePreparedStatementArgument(b.conn.ConnInfo, parameterOIDs[i], args[i])
if err != nil {
return err
}
}
if ps != nil {
batch.ExecPrepared(ps.Name, paramValues, paramFormats, bi.resultFormatCodes)
} else {
oids := make([]uint32, len(parameterOIDs))
for i := 0; i < len(parameterOIDs); i++ {
oids[i] = uint32(parameterOIDs[i])
}
batch.ExecParams(bi.query, paramValues, oids, paramFormats, bi.resultFormatCodes)
}
}
b.mrr = b.conn.pgConn.ExecBatch(ctx, batch)
return nil
}
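
For orientation, here is a minimal sketch (not part of this commit) of the pgconn-level flow that Send now delegates to: statements are queued on a pgconn.Batch, sent in a single round trip with ExecBatch, and the results are consumed through the returned MultiResultReader. Only calls that appear in this diff are used; the statements and package name are illustrative.

package example

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/pgconn"
)

// runBatch drives a pgconn.Batch directly, mirroring what Batch.Send and
// Batch.ExecResults do above. pg is an already-established connection; how it
// was created is out of scope for this sketch.
func runBatch(ctx context.Context, pg *pgconn.PgConn) error {
	batch := &pgconn.Batch{}

	// Queue plain-SQL statements; nil parameter values, OIDs, and format
	// codes mean "no parameters, default formats".
	batch.ExecParams("create temporary table t(n int)", nil, nil, nil, nil)
	batch.ExecParams("insert into t values (1), (2), (3)", nil, nil, nil, nil)

	// One round trip for the whole batch; Send stores this reader in b.mrr.
	mrr := pg.ExecBatch(ctx, batch)

	// Read each result in order, as ExecResults does.
	for mrr.NextResult() {
		commandTag, err := mrr.ResultReader().Close()
		if err != nil {
			return err
		}
		fmt.Println("command tag:", commandTag)
	}

	return mrr.Close()
}
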
// ExecResults reads the results from the next query in the batch as if the
// query has been sent with Exec.
func (b *Batch) ExecResults() (pgconn.CommandTag, error) {
if b.err != nil {
return "", b.err
}
select {
case <-b.ctx.Done():
b.die(b.ctx.Err())
return "", b.ctx.Err()
default:
}
if err := b.ensureCommandComplete(); err != nil {
b.die(err)
if !b.mrr.NextResult() {
err := b.mrr.Close()
if err == nil {
err = errors.New("no result")
}
return "", err
}
b.resultsRead++
b.pendingCommandComplete = true
for {
msg, err := b.conn.rxMsg()
if err != nil {
return "", err
}
switch msg := msg.(type) {
case *pgproto3.CommandComplete:
b.pendingCommandComplete = false
return pgconn.CommandTag(msg.CommandTag), nil
default:
if err := b.conn.processContextFreeMsg(msg); err != nil {
return "", err
}
}
}
return b.mrr.ResultReader().Close()
}
// QueryResults reads the results from the next query in the batch as if the
@ -207,38 +158,16 @@ func (b *Batch) ExecResults() (pgconn.CommandTag, error) {
func (b *Batch) QueryResults() (*Rows, error) {
rows := b.conn.getRows("batch query", nil)
if b.err != nil {
rows.fatal(b.err)
return rows, b.err
if !b.mrr.NextResult() {
rows.err = b.mrr.Close()
if rows.err == nil {
rows.err = errors.New("no result")
}
rows.closed = true
return rows, rows.err
}
select {
case <-b.ctx.Done():
b.die(b.ctx.Err())
rows.fatal(b.err)
return rows, b.ctx.Err()
default:
}
if err := b.ensureCommandComplete(); err != nil {
b.die(err)
rows.fatal(err)
return rows, err
}
b.resultsRead++
b.pendingCommandComplete = true
fieldDescriptions, err := b.conn.readUntilRowDescription()
if err != nil {
b.die(err)
rows.fatal(b.err)
return rows, err
}
rows.batch = b
rows.fields = fieldDescriptions
rows.resultReader = b.mrr.ResultReader()
return rows, nil
}
@ -254,28 +183,7 @@ func (b *Batch) QueryRowResults() *Row {
// operation may have made it impossible to resynchronize the connection with the
// server. In this case the underlying connection will have been closed.
func (b *Batch) Close() (err error) {
if b.err != nil {
return b.err
}
defer func() {
err = b.conn.termContext(err)
if b.conn != nil && b.connPool != nil {
b.connPool.Release(b.conn)
}
}()
for i := b.resultsRead; i < len(b.items); i++ {
if _, err = b.ExecResults(); err != nil {
return err
}
}
if err = b.conn.ensureConnectionReadyForQuery(); err != nil {
return err
}
return nil
return b.mrr.Close()
}
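
Read together with the tests in this commit, caller code for the reworked batch API looks roughly like the sketch below. It is illustrative only: the ledger table, its columns, and the values are assumptions, and error handling is compressed.

package example

import (
	"context"

	"github.com/jackc/pgx"
	"github.com/jackc/pgx/pgtype"
)

// ledgerBatch is a usage sketch, not code from this commit. It assumes a
// ledger(description varchar, amount int) table already exists.
func ledgerBatch(conn *pgx.Conn) (sum int, err error) {
	batch := conn.BeginBatch()

	batch.Queue("insert into ledger(description, amount) values($1, $2)",
		[]interface{}{"q1", 1},
		[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
		nil,
	)
	batch.Queue("select sum(amount) from ledger",
		nil,
		nil,
		[]int16{pgx.BinaryFormatCode},
	)

	// Send now takes only a context; pgx no longer wraps the batch in an
	// explicit BEGIN/COMMIT of its own.
	if err = batch.Send(context.Background()); err != nil {
		return 0, err
	}

	if _, err = batch.ExecResults(); err != nil { // command tag of the insert
		return 0, err
	}
	if err = batch.QueryRowResults().Scan(&sum); err != nil { // row of the select
		return 0, err
	}

	return sum, batch.Close()
}
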
func (b *Batch) die(err error) {

batch_test.go

@ -17,10 +17,10 @@ func TestConnBeginBatch(t *testing.T) {
defer closeConn(t, conn)
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
id serial primary key,
description varchar not null,
amount int not null
);`
mustExec(t, conn, sql)
batch := conn.BeginBatch()
@ -50,7 +50,7 @@ func TestConnBeginBatch(t *testing.T) {
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background(), nil)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
@ -71,6 +71,14 @@ func TestConnBeginBatch(t *testing.T) {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
ct, err = batch.ExecResults()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
rows, err := batch.QueryResults()
if err != nil {
t.Error(err)
@ -173,7 +181,7 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
)
}
err = batch.Send(context.Background(), nil)
err = batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
@ -233,7 +241,7 @@ func TestConnBeginBatchContextCancelBeforeExecResults(t *testing.T) {
ctx, cancelFn := context.WithCancel(context.Background())
err := batch.Send(ctx, nil)
err := batch.Send(ctx)
if err != nil {
t.Fatal(err)
}
@ -245,9 +253,7 @@ func TestConnBeginBatchContextCancelBeforeExecResults(t *testing.T) {
t.Errorf("err => %v, want %v", err, context.Canceled)
}
if conn.IsAlive() {
t.Error("conn should be dead, but was alive")
}
ensureConnValid(t, conn)
}
func TestConnBeginBatchContextCancelBeforeQueryResults(t *testing.T) {
@ -269,21 +275,26 @@ func TestConnBeginBatchContextCancelBeforeQueryResults(t *testing.T) {
ctx, cancelFn := context.WithCancel(context.Background())
err := batch.Send(ctx, nil)
err := batch.Send(ctx)
if err != nil {
t.Fatal(err)
}
cancelFn()
_, err = batch.QueryResults()
if err != context.Canceled {
t.Errorf("err => %v, want %v", err, context.Canceled)
rows, err := batch.QueryResults()
if rows.Next() {
t.Error("unexpected row")
}
if conn.IsAlive() {
t.Error("conn should be dead, but was alive")
if rows.Err() != context.Canceled {
t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled)
}
batch.Close()
ensureConnValid(t, conn)
}
func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) {
@ -305,7 +316,7 @@ func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) {
ctx, cancelFn := context.WithCancel(context.Background())
err := batch.Send(ctx, nil)
err := batch.Send(ctx)
if err != nil {
t.Fatal(err)
}
@ -317,9 +328,7 @@ func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) {
t.Errorf("err => %v, want %v", err, context.Canceled)
}
if conn.IsAlive() {
t.Error("conn should be dead, but was alive")
}
ensureConnValid(t, conn)
}
func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
@ -340,7 +349,7 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background(), nil)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
@ -411,7 +420,7 @@ func TestConnBeginBatchQueryError(t *testing.T) {
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background(), nil)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
@ -440,9 +449,7 @@ func TestConnBeginBatchQueryError(t *testing.T) {
t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
}
if conn.IsAlive() {
t.Error("conn should be dead, but was alive")
}
ensureConnValid(t, conn)
}
func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
@ -458,7 +465,7 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background(), nil)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
@ -474,9 +481,7 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
t.Error("Expected error")
}
if conn.IsAlive() {
t.Error("conn should be dead, but was alive")
}
ensureConnValid(t, conn)
}
func TestConnBeginBatchQueryRowInsert(t *testing.T) {
@ -486,10 +491,10 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) {
defer closeConn(t, conn)
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
id serial primary key,
description varchar not null,
amount int not null
);`
mustExec(t, conn, sql)
batch := conn.BeginBatch()
@ -504,7 +509,7 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) {
nil,
)
err := batch.Send(context.Background(), nil)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
@ -535,10 +540,10 @@ func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
defer closeConn(t, conn)
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
id serial primary key,
description varchar not null,
amount int not null
);`
mustExec(t, conn, sql)
batch := conn.BeginBatch()
@ -553,7 +558,7 @@ func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
nil,
)
err := batch.Send(context.Background(), nil)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
@ -584,15 +589,15 @@ func TestTxBeginBatch(t *testing.T) {
defer closeConn(t, conn)
sql := `create temporary table ledger1(
id serial primary key,
description varchar not null
);`
id serial primary key,
description varchar not null
);`
mustExec(t, conn, sql)
sql = `create temporary table ledger2(
id int primary key,
amount int not null
);`
id int primary key,
amount int not null
);`
mustExec(t, conn, sql)
tx, _ := conn.Begin()
@ -603,7 +608,7 @@ func TestTxBeginBatch(t *testing.T) {
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background(), nil)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
@ -627,7 +632,7 @@ func TestTxBeginBatch(t *testing.T) {
nil,
)
err = batch.Send(context.Background(), nil)
err = batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
@ -639,8 +644,8 @@ func TestTxBeginBatch(t *testing.T) {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
var amout int
err = batch.QueryRowResults().Scan(&amout)
var amount int
err = batch.QueryRowResults().Scan(&amount)
if err != nil {
t.Error(err)
}
@ -669,9 +674,9 @@ func TestTxBeginBatchRollback(t *testing.T) {
defer closeConn(t, conn)
sql := `create temporary table ledger1(
id serial primary key,
description varchar not null
);`
id serial primary key,
description varchar not null
);`
mustExec(t, conn, sql)
tx, _ := conn.Begin()
@ -682,7 +687,7 @@ func TestTxBeginBatchRollback(t *testing.T) {
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background(), nil)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}

bench_test.go

@ -668,7 +668,7 @@ func BenchmarkMultipleQueriesBatch(b *testing.B) {
)
}
err := batch.Send(context.Background(), nil)
err := batch.Send(context.Background())
if err != nil {
b.Fatal(err)
}

conn.go

@ -13,7 +13,6 @@ import (
"github.com/pkg/errors"
"github.com/jackc/pgx/pgconn"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
@ -69,7 +68,6 @@ type Conn struct {
config *ConnConfig // config used when establishing this connection
preparedStatements map[string]*PreparedStatement
channels map[string]struct{}
notifications []*Notification
logger Logger
logLevel int
fp *fastpath
@ -105,13 +103,6 @@ type PrepareExOptions struct {
ParameterOIDs []pgtype.OID
}
// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system
type Notification struct {
PID uint32 // backend pid that sent the notification
Channel string // channel from which notification was received
Payload string
}
// Identifier a PostgreSQL identifier or name. Identifiers can be composed of
// multiple parts such as ["schema", "table"] or ["table", "column"].
type Identifier []string
@ -501,17 +492,6 @@ func (c *Conn) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExO
return nil, err
}
err = c.initContext(ctx)
if err != nil {
return nil, err
}
ps, err = c.prepareEx(name, sql, opts)
err = c.termContext(err)
return ps, err
}
func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
if name != "" {
if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
return ps, nil
@ -562,7 +542,9 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
}
c.preparedStatements[name] = ps
if name != "" {
c.preparedStatements[name] = ps
}
return ps, nil
}
@ -593,42 +575,8 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
delete(c.preparedStatements, name)
// close
buf := c.wbuf
buf = append(buf, 'C')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, 'S')
buf = append(buf, name...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
// flush
buf = append(buf, 'H')
buf = pgio.AppendInt32(buf, 4)
_, err = c.pgConn.Conn().Write(buf)
if err != nil {
c.die(err)
return err
}
for {
msg, err := c.rxMsg()
if err != nil {
return err
}
switch msg.(type) {
case *pgproto3.CloseComplete:
return nil
default:
err = c.processContextFreeMsg(msg)
if err != nil {
return err
}
}
}
_, err = c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll()
return err
}
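
At the call site the deallocate change is invisible: Deallocate now ends up issuing a plain "deallocate" statement through pgconn instead of wire-level Close/Flush messages. A short usage sketch (the statement name and query are illustrative; Prepare and Deallocate are the existing pgx.Conn methods):

package example

import (
	"github.com/jackc/pgx"
)

// prepareAndDrop prepares a named statement and then deallocates it, which
// after this change runs `deallocate "get_now"` via pgconn.
func prepareAndDrop(conn *pgx.Conn) error {
	if _, err := conn.Prepare("get_now", "select now()"); err != nil {
		return err
	}
	return conn.Deallocate("get_now")
}
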
// Listen establishes a PostgreSQL listen/notify to channel
@ -654,64 +602,10 @@ func (c *Conn) Unlisten(channel string) error {
return nil
}
// WaitForNotification waits for a PostgreSQL notification.
func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notification, err error) {
// Return already received notification immediately
if len(c.notifications) > 0 {
notification := c.notifications[0]
c.notifications = c.notifications[1:]
return notification, nil
}
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return nil, err
}
err = c.initContext(ctx)
if err != nil {
return nil, err
}
defer func() {
err = c.termContext(err)
}()
if err = c.lock(); err != nil {
return nil, err
}
defer func() {
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
err = unlockErr
}
}()
if err := c.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
for {
msg, err := c.rxMsg()
if err != nil {
return nil, err
}
err = c.processContextFreeMsg(msg)
if err != nil {
return nil, err
}
if len(c.notifications) > 0 {
notification := c.notifications[0]
c.notifications = c.notifications[1:]
return notification, nil
}
}
}
func (c *Conn) IsAlive() bool {
c.mux.Lock()
defer c.mux.Unlock()
return c.status >= connStatusIdle
return c.pgConn.IsAlive() && c.status >= connStatusIdle
}
func (c *Conn) CauseOfDeath() error {
@ -807,8 +701,6 @@ func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) {
switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
return c.rxErrorResponse(msg)
case *pgproto3.NotificationResponse:
c.rxNotificationResponse(msg)
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
}
@ -886,14 +778,6 @@ func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgty
return parameters
}
func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) {
n := new(Notification)
n.PID = msg.PID
n.Channel = msg.Channel
n.Payload = msg.Payload
c.notifications = append(c.notifications, n)
}
func (c *Conn) die(err error) {
c.mux.Lock()
defer c.mux.Unlock()
@ -1238,7 +1122,13 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
resultFormats := make([]int16, len(ps.FieldDescriptions))
for i := range resultFormats {
resultFormats[i] = ps.FieldDescriptions[i].FormatCode
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
resultFormats[i] = BinaryFormatCode
} else {
resultFormats[i] = TextFormatCode
}
}
}
c.lastStmtSent = true
@ -1307,13 +1197,9 @@ func (c *Conn) pgproto3FieldDescriptionToPgxFieldDescription(src *pgproto3.Field
dst.DataType = pgtype.OID(src.DataTypeOID)
dst.DataTypeSize = src.DataTypeSize
dst.Modifier = src.TypeModifier
dst.FormatCode = src.Format
if dt, ok := c.ConnInfo.DataTypeForOID(dst.DataType); ok {
dst.DataTypeName = dt.Name
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
dst.FormatCode = BinaryFormatCode
} else {
dst.FormatCode = TextFormatCode
}
}
}
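
The two hunks above apply the same rule when choosing result formats. A hypothetical helper (the name is not from this commit) expressing that rule: request binary for a column whenever the registered pgtype value for its OID implements pgtype.BinaryDecoder, otherwise fall back to text, including for unknown OIDs.

package example

import (
	"github.com/jackc/pgx"
	"github.com/jackc/pgx/pgtype"
)

// resultFormatFor picks the result format code for a single column OID.
func resultFormatFor(ci *pgtype.ConnInfo, oid pgtype.OID) int16 {
	if dt, ok := ci.DataTypeForOID(oid); ok {
		if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
			return pgx.BinaryFormatCode
		}
	}
	return pgx.TextFormatCode
}
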

conn_pool.go

@ -2,7 +2,6 @@ package pgx
import (
"context"
"io"
"sync"
"time"
@ -204,7 +203,6 @@ func (p *ConnPool) Release(conn *Conn) {
}
conn.channels = make(map[string]struct{})
}
conn.notifications = nil
p.cond.L.Lock()
@ -544,28 +542,6 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C
return c.CopyFrom(tableName, columnNames, rowSrc)
}
// CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyFromReader(r io.Reader, sql string) (pgconn.CommandTag, error) {
c, err := p.Acquire()
if err != nil {
return "", err
}
defer p.Release(c)
return c.CopyFromReader(r, sql)
}
// CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) (pgconn.CommandTag, error) {
c, err := p.Acquire()
if err != nil {
return "", err
}
defer p.Release(c)
return c.CopyToWriter(w, sql, args...)
}
// BeginBatch acquires a connection and begins a batch on that connection. When
// *Batch is finished, the connection is released automatically.
func (p *ConnPool) BeginBatch() *Batch {

conn_pool_test.go

@ -931,142 +931,146 @@ func TestConnPoolPrepareDeallocatePrepare(t *testing.T) {
func TestConnPoolPrepareWhenConnIsAlreadyAcquired(t *testing.T) {
t.Parallel()
pool := createConnPool(t, 2)
defer pool.Close()
t.Skip("TODO")
testPreparedStatement := func(db queryRower, desc string) {
var s string
err := db.QueryRow("test", "hello").Scan(&s)
if err != nil {
t.Fatalf("%s. Executing prepared statement failed: %v", desc, err)
}
// pool := createConnPool(t, 2)
// defer pool.Close()
if s != "hello" {
t.Fatalf("%s. Prepared statement did not return expected value: %v", desc, s)
}
}
// testPreparedStatement := func(db queryRower, desc string) {
// var s string
// err := db.QueryRow("test", "hello").Scan(&s)
// if err != nil {
// t.Fatalf("%s. Executing prepared statement failed: %v", desc, err)
// }
newReleaseOnce := func(c *pgx.Conn) func() {
var once sync.Once
return func() {
once.Do(func() { pool.Release(c) })
}
}
// if s != "hello" {
// t.Fatalf("%s. Prepared statement did not return expected value: %v", desc, s)
// }
// }
c1, err := pool.Acquire()
if err != nil {
t.Fatalf("Unable to acquire connection: %v", err)
}
c1Release := newReleaseOnce(c1)
defer c1Release()
// newReleaseOnce := func(c *pgx.Conn) func() {
// var once sync.Once
// return func() {
// once.Do(func() { pool.Release(c) })
// }
// }
_, err = pool.Prepare("test", "select $1::varchar")
if err != nil {
t.Fatalf("Unable to prepare statement: %v", err)
}
// c1, err := pool.Acquire()
// if err != nil {
// t.Fatalf("Unable to acquire connection: %v", err)
// }
// c1Release := newReleaseOnce(c1)
// defer c1Release()
testPreparedStatement(pool, "pool")
// _, err = pool.Prepare("test", "select $1::varchar")
// if err != nil {
// t.Fatalf("Unable to prepare statement: %v", err)
// }
c1Release()
// testPreparedStatement(pool, "pool")
c2, err := pool.Acquire()
if err != nil {
t.Fatalf("Unable to acquire connection: %v", err)
}
c2Release := newReleaseOnce(c2)
defer c2Release()
// c1Release()
// This conn will not be available and will be connecting at this point
c3, err := pool.Acquire()
if err != nil {
t.Fatalf("Unable to acquire connection: %v", err)
}
c3Release := newReleaseOnce(c3)
defer c3Release()
// c2, err := pool.Acquire()
// if err != nil {
// t.Fatalf("Unable to acquire connection: %v", err)
// }
// c2Release := newReleaseOnce(c2)
// defer c2Release()
testPreparedStatement(c2, "c2")
testPreparedStatement(c3, "c3")
// // This conn will not be available and will be connecting at this point
// c3, err := pool.Acquire()
// if err != nil {
// t.Fatalf("Unable to acquire connection: %v", err)
// }
// c3Release := newReleaseOnce(c3)
// defer c3Release()
c2Release()
c3Release()
// testPreparedStatement(c2, "c2")
// testPreparedStatement(c3, "c3")
err = pool.Deallocate("test")
if err != nil {
t.Errorf("Deallocate failed: %v", err)
}
// c2Release()
// c3Release()
var s string
err = pool.QueryRow("test", "hello").Scan(&s)
if err, ok := err.(*pgconn.PgError); !(ok && err.Code == "42601") {
t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err)
}
// err = pool.Deallocate("test")
// if err != nil {
// t.Errorf("Deallocate failed: %v", err)
// }
// var s string
// err = pool.QueryRow("test", "hello").Scan(&s)
// if err, ok := err.(*pgconn.PgError); !(ok && err.Code == "42601") {
// t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err)
// }
}
func TestConnPoolBeginBatch(t *testing.T) {
t.Parallel()
pool := createConnPool(t, 2)
defer pool.Close()
t.Skip("TODO")
batch := pool.BeginBatch()
batch.Queue("select n from generate_series(0,5) n",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
batch.Queue("select n from generate_series(0,5) n",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
// pool := createConnPool(t, 2)
// defer pool.Close()
err := batch.Send(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
// batch := pool.BeginBatch()
// batch.Queue("select n from generate_series(0,5) n",
// nil,
// nil,
// []int16{pgx.BinaryFormatCode},
// )
// batch.Queue("select n from generate_series(0,5) n",
// nil,
// nil,
// []int16{pgx.BinaryFormatCode},
// )
rows, err := batch.QueryResults()
if err != nil {
t.Error(err)
}
// err := batch.Send(context.Background(), nil)
// if err != nil {
// t.Fatal(err)
// }
for i := 0; rows.Next(); i++ {
var n int
if err := rows.Scan(&n); err != nil {
t.Error(err)
}
if n != i {
t.Errorf("n => %v, want %v", n, i)
}
}
// rows, err := batch.QueryResults()
// if err != nil {
// t.Error(err)
// }
if rows.Err() != nil {
t.Error(rows.Err())
}
// for i := 0; rows.Next(); i++ {
// var n int
// if err := rows.Scan(&n); err != nil {
// t.Error(err)
// }
// if n != i {
// t.Errorf("n => %v, want %v", n, i)
// }
// }
rows, err = batch.QueryResults()
if err != nil {
t.Error(err)
}
// if rows.Err() != nil {
// t.Error(rows.Err())
// }
for i := 0; rows.Next(); i++ {
var n int
if err := rows.Scan(&n); err != nil {
t.Error(err)
}
if n != i {
t.Errorf("n => %v, want %v", n, i)
}
}
// rows, err = batch.QueryResults()
// if err != nil {
// t.Error(err)
// }
if rows.Err() != nil {
t.Error(rows.Err())
}
// for i := 0; rows.Next(); i++ {
// var n int
// if err := rows.Scan(&n); err != nil {
// t.Error(err)
// }
// if n != i {
// t.Errorf("n => %v, want %v", n, i)
// }
// }
err = batch.Close()
if err != nil {
t.Fatal(err)
}
// if rows.Err() != nil {
// t.Error(rows.Err())
// }
// err = batch.Close()
// if err != nil {
// t.Fatal(err)
// }
}
func TestConnPoolBeginEx(t *testing.T) {

conn_test.go

@ -551,234 +551,6 @@ func TestPrepareEx(t *testing.T) {
}
}
func TestListenNotify(t *testing.T) {
t.Parallel()
listener := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, listener)
if err := listener.Listen("chat"); err != nil {
t.Fatalf("Unable to start listening: %v", err)
}
notifier := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, notifier)
mustExec(t, notifier, "notify chat")
// when notification is waiting on the socket to be read
notification, err := listener.WaitForNotification(context.Background())
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
if notification.Channel != "chat" {
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
}
// when notification has already been read during previous query
mustExec(t, notifier, "notify chat")
rows, _ := listener.Query("select 1")
rows.Close()
if rows.Err() != nil {
t.Fatalf("Unexpected error on Query: %v", rows.Err())
}
ctx, cancelFn := context.WithCancel(context.Background())
cancelFn()
notification, err = listener.WaitForNotification(ctx)
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
if notification.Channel != "chat" {
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
}
// when timeout occurs
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
notification, err = listener.WaitForNotification(ctx)
if err != context.DeadlineExceeded {
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
}
if notification != nil {
t.Errorf("WaitForNotification returned an unexpected notification: %v", notification)
}
// listener can listen again after a timeout
mustExec(t, notifier, "notify chat")
notification, err = listener.WaitForNotification(context.Background())
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
if notification.Channel != "chat" {
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
}
}
func TestUnlistenSpecificChannel(t *testing.T) {
t.Parallel()
listener := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, listener)
if err := listener.Listen("unlisten_test"); err != nil {
t.Fatalf("Unable to start listening: %v", err)
}
notifier := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, notifier)
mustExec(t, notifier, "notify unlisten_test")
// when notification is waiting on the socket to be read
notification, err := listener.WaitForNotification(context.Background())
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
if notification.Channel != "unlisten_test" {
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
}
err = listener.Unlisten("unlisten_test")
if err != nil {
t.Fatalf("Unexpected error on Unlisten: %v", err)
}
// when notification has already been read during previous query
mustExec(t, notifier, "notify unlisten_test")
rows, _ := listener.Query("select 1")
rows.Close()
if rows.Err() != nil {
t.Fatalf("Unexpected error on Query: %v", rows.Err())
}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
notification, err = listener.WaitForNotification(ctx)
if err != context.DeadlineExceeded {
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
}
}
func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
t.Parallel()
listenerDone := make(chan bool)
go func() {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
defer func() {
listenerDone <- true
}()
if err := conn.Listen("busysafe"); err != nil {
t.Fatalf("Unable to start listening: %v", err)
}
for i := 0; i < 5000; i++ {
var sum int32
var rowCount int32
rows, err := conn.Query("select generate_series(1,$1)", 100)
if err != nil {
t.Fatalf("conn.Query failed: %v", err)
}
for rows.Next() {
var n int32
rows.Scan(&n)
sum += n
rowCount++
}
if rows.Err() != nil {
t.Fatalf("conn.Query failed: %v", err)
}
if sum != 5050 {
t.Fatalf("Wrong rows sum: %v", sum)
}
if rowCount != 100 {
t.Fatalf("Wrong number of rows: %v", rowCount)
}
time.Sleep(1 * time.Microsecond)
}
}()
go func() {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
for i := 0; i < 100000; i++ {
mustExec(t, conn, "notify busysafe, 'hello'")
time.Sleep(1 * time.Microsecond)
}
}()
<-listenerDone
}
func TestListenNotifySelfNotification(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
if err := conn.Listen("self"); err != nil {
t.Fatalf("Unable to start listening: %v", err)
}
// Notify self and WaitForNotification immediately
mustExec(t, conn, "notify self")
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
notification, err := conn.WaitForNotification(ctx)
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
if notification.Channel != "self" {
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
}
// Notify self and do something else before WaitForNotification
mustExec(t, conn, "notify self")
rows, _ := conn.Query("select 1")
rows.Close()
if rows.Err() != nil {
t.Fatalf("Unexpected error on Query: %v", rows.Err())
}
ctx, cncl := context.WithTimeout(context.Background(), time.Second)
defer cncl()
notification, err = conn.WaitForNotification(ctx)
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
if notification.Channel != "self" {
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
}
}
func TestListenUnlistenSpecialCharacters(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
chanName := "special characters !@#{$%^&*()}"
if err := conn.Listen(chanName); err != nil {
t.Fatalf("Unable to start listening: %v", err)
}
if err := conn.Unlisten(chanName); err != nil {
t.Fatalf("Unable to stop listening: %v", err)
}
}
func TestFatalRxError(t *testing.T) {
t.Parallel()
@ -830,7 +602,7 @@ func TestFatalTxError(t *testing.T) {
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
}
_, err = conn.Query("select 1")
err = conn.QueryRow("select 1").Scan(nil)
if err == nil {
t.Fatal("Expected error but none occurred")
}

copy_from.go

@ -2,12 +2,11 @@ package pgx
import (
"bytes"
"context"
"fmt"
"io"
"github.com/jackc/pgx/pgconn"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
"github.com/pkg/errors"
)
@ -58,39 +57,6 @@ type copyFrom struct {
readerErrChan chan error
}
func (ct *copyFrom) readUntilReadyForQuery() {
for {
msg, err := ct.conn.rxMsg()
if err != nil {
ct.readerErrChan <- err
close(ct.readerErrChan)
return
}
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
ct.conn.rxReadyForQuery(msg)
close(ct.readerErrChan)
return
case *pgproto3.CommandComplete:
case *pgproto3.ErrorResponse:
ct.readerErrChan <- ct.conn.rxErrorResponse(msg)
default:
err = ct.conn.processContextFreeMsg(msg)
if err != nil {
ct.readerErrChan <- ct.conn.processContextFreeMsg(msg)
}
}
}
}
func (ct *copyFrom) waitForReaderDone() error {
var err error
for err = range ct.readerErrChan {
}
return err
}
func (ct *copyFrom) run() (int, error) {
quotedTableName := ct.tableName.Sanitize()
cbuf := &bytes.Buffer{}
@ -107,163 +73,74 @@ func (ct *copyFrom) run() (int, error) {
return 0, err
}
err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
if err != nil {
return 0, err
}
r, w := io.Pipe()
err = ct.conn.readUntilCopyInResponse()
if err != nil {
return 0, err
}
go func() {
// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
buf := ct.conn.wbuf
panicked := true
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
buf = pgio.AppendInt32(buf, 0)
buf = pgio.AppendInt32(buf, 0)
go ct.readUntilReadyForQuery()
defer ct.waitForReaderDone()
defer func() {
if panicked {
ct.conn.die(errors.New("panic while in copy from"))
moreRows := true
for moreRows {
var err error
moreRows, buf, err = ct.buildCopyBuf(buf, ps)
if err != nil {
w.CloseWithError(err)
return
}
if ct.rowSrc.Err() != nil {
w.CloseWithError(ct.rowSrc.Err())
return
}
if len(buf) > 0 {
_, err = w.Write(buf)
if err != nil {
w.Close()
return
}
}
buf = buf[:0]
}
w.Close()
}()
buf := ct.conn.wbuf
buf = append(buf, copyData)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
commandTag, err := ct.conn.pgConn.CopyFrom(context.TODO(), r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
buf = pgio.AppendInt32(buf, 0)
buf = pgio.AppendInt32(buf, 0)
var sentCount int
moreRows := true
for moreRows {
select {
case err = <-ct.readerErrChan:
panicked = false
return 0, err
default:
}
var addedRows int
var err error
moreRows, buf, addedRows, err = ct.buildCopyBuf(buf, ps)
if err != nil {
panicked = false
ct.cancelCopyIn()
return 0, err
}
sentCount += addedRows
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err = ct.conn.pgConn.Conn().Write(buf)
if err != nil {
panicked = false
ct.conn.die(err)
return 0, err
}
// Directly manipulate wbuf to reset to reuse the same buffer
buf = buf[0:5]
}
if ct.rowSrc.Err() != nil {
panicked = false
ct.cancelCopyIn()
return 0, ct.rowSrc.Err()
}
buf = pgio.AppendInt16(buf, -1) // terminate the copy stream
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
buf = append(buf, copyDone)
buf = pgio.AppendInt32(buf, 4)
_, err = ct.conn.pgConn.Conn().Write(buf)
if err != nil {
panicked = false
ct.conn.die(err)
return 0, err
}
err = ct.waitForReaderDone()
if err != nil {
panicked = false
return 0, err
}
panicked = false
return sentCount, nil
return int(commandTag.RowsAffected()), err
}
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, int, error) {
var rowCount int
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, error) {
for ct.rowSrc.Next() {
values, err := ct.rowSrc.Values()
if err != nil {
return false, nil, 0, err
return false, nil, err
}
if len(values) != len(ct.columnNames) {
return false, nil, 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
return false, nil, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
}
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
for i, val := range values {
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
if err != nil {
return false, nil, 0, err
return false, nil, err
}
}
rowCount++
if len(buf) > 65536 {
return true, buf, rowCount, nil
return true, buf, nil
}
}
return false, buf, rowCount, nil
}
func (c *Conn) readUntilCopyInResponse() error {
for {
msg, err := c.rxMsg()
if err != nil {
return err
}
switch msg := msg.(type) {
case *pgproto3.CopyInResponse:
return nil
default:
err = c.processContextFreeMsg(msg)
if err != nil {
return err
}
}
}
}
func (ct *copyFrom) cancelCopyIn() error {
buf := ct.conn.wbuf
buf = append(buf, copyFail)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, "client error: abort"...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err := ct.conn.pgConn.Conn().Write(buf)
if err != nil {
ct.conn.die(err)
return err
}
return nil
return false, buf, nil
}
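
The rewritten run() above streams the binary COPY payload through an io.Pipe: a goroutine encodes rows into the write end while pgconn's CopyFrom consumes the read end, so the whole payload is never buffered at once. A minimal sketch of the same pattern, assuming plain text COPY data and an illustrative table name rather than the binary encoding run() produces:

package example

import (
	"context"
	"io"

	"github.com/jackc/pgx/pgconn"
)

// copyLines feeds text COPY data to pgconn through an io.Pipe.
func copyLines(ctx context.Context, pg *pgconn.PgConn, lines []string) (int, error) {
	r, w := io.Pipe()

	// Producer: write COPY data into the pipe. Any error is propagated through
	// CloseWithError so CopyFrom fails with it, as run() does above.
	go func() {
		for _, line := range lines {
			if _, err := io.WriteString(w, line+"\n"); err != nil {
				w.CloseWithError(err)
				return
			}
		}
		w.Close() // end of COPY data
	}()

	// Consumer: pgconn reads from the pipe and drives the COPY protocol.
	commandTag, err := pg.CopyFrom(ctx, r, "copy foo from stdin")
	if err != nil {
		return 0, err
	}
	return int(commandTag.RowsAffected()), nil
}
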
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
@ -283,57 +160,3 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF
return ct.run()
}
// CopyFromReader uses the PostgreSQL textual format of the copy protocol
func (c *Conn) CopyFromReader(r io.Reader, sql string) (pgconn.CommandTag, error) {
if err := c.sendSimpleQuery(sql); err != nil {
return "", err
}
if err := c.readUntilCopyInResponse(); err != nil {
return "", err
}
buf := c.wbuf
buf = append(buf, copyData)
sp := len(buf)
for {
n, err := r.Read(buf[5:cap(buf)])
if err == io.EOF && n == 0 {
break
}
buf = buf[0 : n+5]
pgio.SetInt32(buf[sp:], int32(n+4))
if _, err := c.pgConn.Conn().Write(buf); err != nil {
return "", err
}
}
buf = buf[:0]
buf = append(buf, copyDone)
buf = pgio.AppendInt32(buf, 4)
if _, err := c.pgConn.Conn().Write(buf); err != nil {
return "", err
}
for {
msg, err := c.rxMsg()
if err != nil {
return "", err
}
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
return "", err
case *pgproto3.CommandComplete:
return pgconn.CommandTag(msg.CommandTag), nil
case *pgproto3.ErrorResponse:
return "", c.rxErrorResponse(msg)
default:
return "", c.processContextFreeMsg(msg)
}
}
}

copy_from_test.go

@ -1,13 +1,8 @@
package pgx_test
import (
"compress/gzip"
"fmt"
"io/ioutil"
"os"
"reflect"
"strconv"
"strings"
"testing"
"time"
@ -38,8 +33,8 @@ func TestConnCopyFromSmall(t *testing.T) {
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
{nil, nil, nil, nil, nil, nil, nil},
}
inputReader := strings.NewReader("0\t1\t2\tabc\tefg\t2000-01-01\t" + tzedTime.Format(time.RFC3339Nano) + "\n" +
"\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
// inputReader := strings.NewReader("0\t1\t2\tabc\tefg\t2000-01-01\t" + tzedTime.Format(time.RFC3339Nano) + "\n" +
// "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
if err != nil {
@ -73,36 +68,38 @@ func TestConnCopyFromSmall(t *testing.T) {
mustExec(t, conn, "truncate foo")
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
if err != nil {
t.Errorf("Unexpected error for CopyFromReader: %v", err)
}
copyCount = int(res.RowsAffected())
if copyCount != len(inputRows) {
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
}
// TODO
rows, err = conn.Query("select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
// res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
// if err != nil {
// t.Errorf("Unexpected error for CopyFromReader: %v", err)
// }
// copyCount = int(res.RowsAffected())
// if copyCount != len(inputRows) {
// t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
// }
outputRows = make([][]interface{}, 0)
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
// rows, err = conn.Query("select * from foo")
// if err != nil {
// t.Errorf("Unexpected error for Query: %v", err)
// }
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
// outputRows = make([][]interface{}, 0)
// for rows.Next() {
// row, err := rows.Values()
// if err != nil {
// t.Errorf("Unexpected error for rows.Values(): %v", err)
// }
// outputRows = append(outputRows, row)
// }
if !reflect.DeepEqual(inputRows, outputRows) {
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
}
// if rows.Err() != nil {
// t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
// }
// if !reflect.DeepEqual(inputRows, outputRows) {
// t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
// }
ensureConnValid(t, conn)
}
@ -166,36 +163,38 @@ func TestConnCopyFromLarge(t *testing.T) {
mustExec(t, conn, "truncate foo")
res, err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin")
if err != nil {
t.Errorf("Unexpected error for CopyFromReader: %v", err)
}
copyCount = int(res.RowsAffected())
if copyCount != len(inputRows) {
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
}
// TODO
rows, err = conn.Query("select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
// res, err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin")
// if err != nil {
// t.Errorf("Unexpected error for CopyFromReader: %v", err)
// }
// copyCount = int(res.RowsAffected())
// if copyCount != len(inputRows) {
// t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
// }
outputRows = make([][]interface{}, 0)
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
// rows, err = conn.Query("select * from foo")
// if err != nil {
// t.Errorf("Unexpected error for Query: %v", err)
// }
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
// outputRows = make([][]interface{}, 0)
// for rows.Next() {
// row, err := rows.Values()
// if err != nil {
// t.Errorf("Unexpected error for rows.Values(): %v", err)
// }
// outputRows = append(outputRows, row)
// }
if !reflect.DeepEqual(inputRows, outputRows) {
t.Errorf("Input rows and output rows do not equal")
}
// if rows.Err() != nil {
// t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
// }
// if !reflect.DeepEqual(inputRows, outputRows) {
// t.Errorf("Input rows and output rows do not equal")
// }
ensureConnValid(t, conn)
}
@ -221,7 +220,7 @@ func TestConnCopyFromJSON(t *testing.T) {
{map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}},
{nil, nil},
}
inputReader := strings.NewReader("{\"foo\":\"bar\"}\t{\"bar\":\"quz\"}\n\\N\t\\N\n")
// inputReader := strings.NewReader("{\"foo\":\"bar\"}\t{\"bar\":\"quz\"}\n\\N\t\\N\n")
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
if err != nil {
@ -255,36 +254,38 @@ func TestConnCopyFromJSON(t *testing.T) {
mustExec(t, conn, "truncate foo")
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
copyCount = int(res.RowsAffected())
if copyCount != len(inputRows) {
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
}
// TODO
rows, err = conn.Query("select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
// res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
// if err != nil {
// t.Errorf("Unexpected error for CopyFrom: %v", err)
// }
// copyCount = int(res.RowsAffected())
// if copyCount != len(inputRows) {
// t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
// }
outputRows = make([][]interface{}, 0)
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
// rows, err = conn.Query("select * from foo")
// if err != nil {
// t.Errorf("Unexpected error for Query: %v", err)
// }
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
// outputRows = make([][]interface{}, 0)
// for rows.Next() {
// row, err := rows.Values()
// if err != nil {
// t.Errorf("Unexpected error for rows.Values(): %v", err)
// }
// outputRows = append(outputRows, row)
// }
if !reflect.DeepEqual(inputRows, outputRows) {
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
}
// if rows.Err() != nil {
// t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
// }
// if !reflect.DeepEqual(inputRows, outputRows) {
// t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
// }
ensureConnValid(t, conn)
}
@ -327,7 +328,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
{int32(2), nil}, // this row should trigger a failure
{int32(3), "def"},
}
inputReader := strings.NewReader("1\tabc\n2\t\\N\n3\tdef\n")
// inputReader := strings.NewReader("1\tabc\n2\t\\N\n3\tdef\n")
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
if err == nil {
@ -364,39 +365,41 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
mustExec(t, conn, "truncate foo")
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
if err == nil {
t.Errorf("Expected CopyFromReader return error, but it did not")
}
if _, ok := err.(*pgconn.PgError); !ok {
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
}
copyCount = int(res.RowsAffected())
if copyCount != 0 {
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
}
// TODO
rows, err = conn.Query("select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
// res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
// if err == nil {
// t.Errorf("Expected CopyFromReader return error, but it did not")
// }
// if _, ok := err.(*pgconn.PgError); !ok {
// t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
// }
// copyCount = int(res.RowsAffected())
// if copyCount != 0 {
// t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
// }
outputRows = make([][]interface{}, 0)
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
// rows, err = conn.Query("select * from foo")
// if err != nil {
// t.Errorf("Unexpected error for Query: %v", err)
// }
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
// outputRows = make([][]interface{}, 0)
// for rows.Next() {
// row, err := rows.Values()
// if err != nil {
// t.Errorf("Unexpected error for rows.Values(): %v", err)
// }
// outputRows = append(outputRows, row)
// }
if len(outputRows) != 0 {
t.Errorf("Expected 0 rows, but got %v", outputRows)
}
// if rows.Err() != nil {
// t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
// }
// if len(outputRows) != 0 {
// t.Errorf("Expected 0 rows, but got %v", outputRows)
// }
ensureConnValid(t, conn)
}
@ -513,7 +516,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
}
if len(outputRows) != 0 {
t.Errorf("Expected 0 rows, but got %v", outputRows)
t.Errorf("Expected 0 rows, but got %v", len(outputRows))
}
ensureConnValid(t, conn)
@ -578,192 +581,3 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
ensureConnValid(t, conn)
}
type nextPanicSource struct {
}
func (cfs *nextPanicSource) Next() bool {
panic("crash")
}
func (cfs *nextPanicSource) Values() ([]interface{}, error) {
return []interface{}{nil}, nil // should never get here
}
func (cfs *nextPanicSource) Err() error {
return nil // should never gets here
}
func TestConnCopyFromCopyFromSourceNextPanic(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a bytea not null
)`)
caughtPanic := false
func() {
defer func() {
if x := recover(); x != nil {
caughtPanic = true
}
}()
conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &nextPanicSource{})
}()
if !caughtPanic {
t.Error("expected panic but did not")
}
if conn.IsAlive() {
t.Error("panic should have killed conn")
}
}
func TestConnCopyFromReaderQueryError(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
inputReader := strings.NewReader("")
res, err := conn.CopyFromReader(inputReader, "cropy foo from stdin")
if err == nil {
t.Errorf("Expected CopyFromReader return error, but it did not")
}
if _, ok := err.(*pgconn.PgError); !ok {
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != 0 {
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
}
ensureConnValid(t, conn)
}
func TestConnCopyFromReaderNoTableError(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
inputReader := strings.NewReader("")
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
if err == nil {
t.Errorf("Expected CopyFromReader return error, but it did not")
}
if _, ok := err.(*pgconn.PgError); !ok {
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != 0 {
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
}
ensureConnValid(t, conn)
}
func TestConnCopyFromGzipReader(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int4,
b varchar
)`)
f, err := ioutil.TempFile("", "*")
if err != nil {
t.Fatalf("Unexpected error for ioutil.TempFile: %v", err)
}
gw := gzip.NewWriter(f)
inputRows := [][]interface{}{}
for i := 0; i < 1000; i++ {
val := strconv.Itoa(i * i)
inputRows = append(inputRows, []interface{}{int32(i), val})
_, err = gw.Write([]byte(fmt.Sprintf("%d,\"%s\"\n", i, val)))
if err != nil {
t.Errorf("Unexpected error for gw.Write: %v", err)
}
}
err = gw.Close()
if err != nil {
t.Fatalf("Unexpected error for gw.Close: %v", err)
}
_, err = f.Seek(0, 0)
if err != nil {
t.Fatalf("Unexpected error for f.Seek: %v", err)
}
gr, err := gzip.NewReader(f)
if err != nil {
t.Fatalf("Unexpected error for gzip.NewReader: %v", err)
}
res, err := conn.CopyFromReader(gr, "COPY foo FROM STDIN WITH (FORMAT csv)")
if err != nil {
t.Errorf("Unexpected error for CopyFromReader: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != len(inputRows) {
t.Errorf("Expected CopyFromReader to return 1000 copied rows, but got %d", copyCount)
}
err = gr.Close()
if err != nil {
t.Errorf("Unexpected error for gr.Close: %v", err)
}
err = f.Close()
if err != nil {
t.Errorf("Unexpected error for f.Close: %v", err)
}
err = os.Remove(f.Name())
if err != nil {
t.Errorf("Unexpected error for os.Remove: %v", err)
}
rows, err := conn.Query("select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
var outputRows [][]interface{}
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
if !reflect.DeepEqual(inputRows, outputRows) {
t.Errorf("Input rows and output rows do not equal")
}
ensureConnValid(t, conn)
}

copy_to.go

@ -1,64 +0,0 @@
package pgx
import (
"io"
"github.com/jackc/pgx/pgconn"
"github.com/jackc/pgx/pgproto3"
)
func (c *Conn) readUntilCopyOutResponse() error {
for {
msg, err := c.rxMsg()
if err != nil {
return err
}
switch msg := msg.(type) {
case *pgproto3.CopyOutResponse:
return nil
default:
err = c.processContextFreeMsg(msg)
if err != nil {
return err
}
}
}
}
func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) (pgconn.CommandTag, error) {
if err := c.sendSimpleQuery(sql, args...); err != nil {
return "", err
}
if err := c.readUntilCopyOutResponse(); err != nil {
return "", err
}
for {
msg, err := c.rxMsg()
if err != nil {
return "", err
}
switch msg := msg.(type) {
case *pgproto3.CopyDone:
break
case *pgproto3.CopyData:
_, err := w.Write(msg.Data)
if err != nil {
c.die(err)
return "", err
}
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
return "", nil
case *pgproto3.CommandComplete:
return pgconn.CommandTag(msg.CommandTag), nil
case *pgproto3.ErrorResponse:
return "", c.rxErrorResponse(msg)
default:
return "", c.processContextFreeMsg(msg)
}
}
}

copy_to_test.go

@ -1,116 +0,0 @@
package pgx_test
import (
"bytes"
"os"
"testing"
"github.com/jackc/pgx/pgconn"
)
func TestConnCopyToWriterSmall(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int2,
b int4,
c int8,
d varchar,
e text,
f date,
g json
)`)
mustExec(t, conn, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`)
mustExec(t, conn, `insert into foo values (null, null, null, null, null, null, null)`)
inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" +
"\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout")
if err != nil {
t.Errorf("Unexpected error for CopyToWriter: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != 2 {
t.Errorf("Expected CopyToWriter to return 2 copied rows, but got %d", copyCount)
}
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
t.Errorf("Input rows and output rows do not equal:\n%q\n%q", string(inputBytes), string(outputWriter.Bytes()))
}
ensureConnValid(t, conn)
}
func TestConnCopyToWriterLarge(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int2,
b int4,
c int8,
d varchar,
e text,
f date,
g json,
h bytea
)`)
inputBytes := make([]byte, 0)
for i := 0; i < 1000; i++ {
mustExec(t, conn, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`)
inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...)
}
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout")
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != 1000 {
t.Errorf("Expected CopyToWriter to return 1 copied rows, but got %d", copyCount)
}
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
t.Errorf("Input rows and output rows do not equal")
}
ensureConnValid(t, conn)
}
func TestConnCopyToWriterQueryError(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
outputWriter := bytes.NewBuffer(make([]byte, 0))
res, err := conn.CopyToWriter(outputWriter, "cropy foo to stdout")
if err == nil {
t.Errorf("Expected CopyToWriter return error, but it did not")
}
if _, ok := err.(*pgconn.PgError); !ok {
t.Errorf("Expected CopyToWriter return pgx.PgError, but instead it returned: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != 0 {
t.Errorf("Expected CopyToWriter to return 0 copied rows, but got %d", copyCount)
}
ensureConnValid(t, conn)
}

pgconn/pgconn.go

@ -398,6 +398,12 @@ func (pgConn *PgConn) hardClose() error {
return pgConn.conn.Close()
}
// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect death
// of the underlying connection.
func (pgConn *PgConn) IsAlive() bool {
return !pgConn.closed
}
// writeAll writes the entire buffer. The connection is hard closed on a partial write or a non-temporary error.
func (pgConn *PgConn) writeAll(buf []byte) error {
n, err := pgConn.conn.Write(buf)

query.go

@ -10,6 +10,7 @@ import (
"github.com/pkg/errors"
"github.com/jackc/pgx/internal/sanitize"
"github.com/jackc/pgx/pgconn"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
@ -56,6 +57,8 @@ type Rows struct {
args []interface{}
unlockConn bool
closed bool
resultReader *pgconn.ResultReader
}
func (rows *Rows) FieldDescriptions() []FieldDescription {
@ -76,7 +79,12 @@ func (rows *Rows) Close() {
rows.closed = true
rows.err = rows.conn.termContext(rows.err)
if rows.resultReader != nil {
_, closeErr := rows.resultReader.Close()
if rows.err == nil {
rows.err = closeErr
}
}
if rows.err == nil {
if rows.conn.shouldLog(LogLevelInfo) {
@ -119,50 +127,21 @@ func (rows *Rows) Next() bool {
return false
}
rows.rowCount++
rows.columnIdx = 0
for {
msg, err := rows.conn.rxMsg()
if err != nil {
rows.fatal(err)
return false
}
switch msg := msg.(type) {
case *pgproto3.RowDescription:
rows.fields = rows.conn.rxRowDescription(msg)
for i := range rows.fields {
if dt, ok := rows.conn.ConnInfo.DataTypeForOID(rows.fields[i].DataType); ok {
rows.fields[i].DataTypeName = dt.Name
rows.fields[i].FormatCode = TextFormatCode
} else {
rows.fatal(errors.Errorf("unknown oid: %d", rows.fields[i].DataType))
return false
}
}
case *pgproto3.DataRow:
if len(msg.Values) != len(rows.fields) {
rows.fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), len(msg.Values))))
return false
}
rows.values = msg.Values
return true
case *pgproto3.CommandComplete:
if rows.batch != nil {
rows.batch.pendingCommandComplete = false
}
rows.Close()
return false
default:
err = rows.conn.processContextFreeMsg(msg)
if err != nil {
rows.fatal(err)
return false
if rows.resultReader.NextRow() {
if rows.fields == nil {
rrFieldDescriptions := rows.resultReader.FieldDescriptions()
rows.fields = make([]FieldDescription, len(rrFieldDescriptions))
for i := range rrFieldDescriptions {
rows.conn.pgproto3FieldDescriptionToPgxFieldDescription(&rrFieldDescriptions[i], &rows.fields[i])
}
}
rows.rowCount++
rows.columnIdx = 0
rows.values = rows.resultReader.Values()
return true
} else {
rows.Close()
return false
}
}
@ -181,15 +160,6 @@ func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) {
return buf, fd, true
}
type scanArgError struct {
col int
err error
}
func (e scanArgError) Error() string {
return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
}
// Scan reads the values from the current row into dest values positionally.
// dest can include pointers to core types, values implementing the Scanner
// interface, []byte, and nil. []byte will skip the decoding process and directly
@ -326,6 +296,15 @@ func (rows *Rows) Values() ([]interface{}, error) {
return values, rows.Err()
}
type scanArgError struct {
col int
err error
}
func (e scanArgError) Error() string {
return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
}
// Query executes sql with args. If there is an error the returned *Rows will
// be returned in an error state. So it is allowed to ignore the error returned
// from Query and handle it in *Rows.
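
A minimal sketch of the deferred error handling pattern this comment describes, assuming a connected *pgx.Conn and an illustrative widgets table (neither is part of the diff above):

// listWidgets ignores the error returned by Query and relies on the *Rows
// error state instead: a failed query yields rows where Next returns false
// and Err reports the failure.
func listWidgets(conn *pgx.Conn) error {
	rows, _ := conn.Query("select id, name from widgets")
	defer rows.Close()

	for rows.Next() {
		var id int32
		var name string
		if err := rows.Scan(&id, &name); err != nil {
			return err
		}
	}
	return rows.Err()
}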
@ -369,7 +348,14 @@ type QueryExOptions struct {
func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
c.lastStmtSent = false
rows = c.getRows(sql, args)
// rows = c.getRows(sql, args)
rows = &Rows{
conn: c,
startTime: time.Now(),
sql: sql,
args: args,
}
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
@ -388,85 +374,128 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
}
rows.unlockConn = true
err = c.initContext(ctx)
// err = c.initContext(ctx)
// if err != nil {
// rows.fatal(err)
// return rows, rows.err
// }
// if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
// c.lastStmtSent = true
// err = c.sanitizeAndSendSimpleQuery(sql, args...)
// if err != nil {
// rows.fatal(err)
// return rows, err
// }
// return rows, nil
// }
// if options != nil && len(options.ParameterOIDs) > 0 {
// buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args)
// if err != nil {
// rows.fatal(err)
// return rows, err
// }
// buf = appendSync(buf)
// n, err := c.pgConn.Conn().Write(buf)
// c.lastStmtSent = true
// if err != nil && fatalWriteErr(n, err) {
// rows.fatal(err)
// c.die(err)
// return rows, err
// }
// c.pendingReadyForQueryCount++
// fieldDescriptions, err := c.readUntilRowDescription()
// if err != nil {
// rows.fatal(err)
// return rows, err
// }
// if len(options.ResultFormatCodes) == 0 {
// for i := range fieldDescriptions {
// fieldDescriptions[i].FormatCode = TextFormatCode
// }
// } else if len(options.ResultFormatCodes) == 1 {
// fc := options.ResultFormatCodes[0]
// for i := range fieldDescriptions {
// fieldDescriptions[i].FormatCode = fc
// }
// } else {
// for i := range options.ResultFormatCodes {
// fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i]
// }
// }
// rows.sql = sql
// rows.fields = fieldDescriptions
// return rows, nil
// }
ps, ok := c.preparedStatements[sql]
if !ok {
psd, err := c.pgConn.Prepare(ctx, "", sql, nil)
if err != nil {
rows.fatal(err)
return rows, rows.err
}
if len(psd.ParamOIDs) != len(args) {
rows.fatal(errors.Errorf("expected %d arguments, got %d", len(psd.ParamOIDs), len(args)))
return rows, rows.err
}
ps = &PreparedStatement{
Name: psd.Name,
SQL: psd.SQL,
ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)),
FieldDescriptions: make([]FieldDescription, len(psd.Fields)),
}
for i := range ps.ParameterOIDs {
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
}
for i := range ps.FieldDescriptions {
c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
}
}
rows.sql = ps.SQL
args, err = convertDriverValuers(args)
if err != nil {
rows.fatal(err)
return rows, rows.err
}
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
c.lastStmtSent = true
err = c.sanitizeAndSendSimpleQuery(sql, args...)
if err != nil {
rows.fatal(err)
return rows, err
}
return rows, nil
}
if options != nil && len(options.ParameterOIDs) > 0 {
buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args)
if err != nil {
rows.fatal(err)
return rows, err
}
buf = appendSync(buf)
n, err := c.pgConn.Conn().Write(buf)
c.lastStmtSent = true
if err != nil && fatalWriteErr(n, err) {
rows.fatal(err)
c.die(err)
return rows, err
}
c.pendingReadyForQueryCount++
fieldDescriptions, err := c.readUntilRowDescription()
if err != nil {
rows.fatal(err)
return rows, err
}
if len(options.ResultFormatCodes) == 0 {
for i := range fieldDescriptions {
fieldDescriptions[i].FormatCode = TextFormatCode
}
} else if len(options.ResultFormatCodes) == 1 {
fc := options.ResultFormatCodes[0]
for i := range fieldDescriptions {
fieldDescriptions[i].FormatCode = fc
}
} else {
for i := range options.ResultFormatCodes {
fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i]
}
}
rows.sql = sql
rows.fields = fieldDescriptions
return rows, nil
}
ps, ok := c.preparedStatements[sql]
if !ok {
var err error
ps, err = c.prepareEx("", sql, nil)
paramFormats := make([]int16, len(args))
paramValues := make([][]byte, len(args))
for i := range args {
paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, ps.ParameterOIDs[i], args[i])
paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, ps.ParameterOIDs[i], args[i])
if err != nil {
rows.fatal(err)
return rows, rows.err
}
}
resultFormats := make([]int16, len(ps.FieldDescriptions))
for i := range resultFormats {
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
resultFormats[i] = BinaryFormatCode
} else {
resultFormats[i] = TextFormatCode
}
}
}
rows.sql = ps.SQL
rows.fields = ps.FieldDescriptions
c.lastStmtSent = true
err = c.sendPreparedQuery(ps, args...)
if err != nil {
rows.fatal(err)
}
rows.resultReader = c.pgConn.ExecPrepared(ctx, ps.Name, paramValues, paramFormats, resultFormats)
return rows, rows.err
}

View File

@ -887,10 +887,10 @@ func TestQueryRowErrors(t *testing.T) {
scanArgs []interface{}
err string
}{
{"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`},
{"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"},
// {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`},
// {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"},
{"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot decode"},
{"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Point"},
// {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Point"},
}
for i, tt := range tests {

View File

@ -1,361 +1,361 @@
package pgx_test
import (
"context"
"fmt"
"math/rand"
"os"
"strconv"
"testing"
"time"
// import (
// "context"
// "fmt"
// "math/rand"
// "os"
// "strconv"
// "testing"
// "time"
"github.com/pkg/errors"
// "github.com/pkg/errors"
"github.com/jackc/fake"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgconn"
)
// "github.com/jackc/fake"
// "github.com/jackc/pgx"
// "github.com/jackc/pgx/pgconn"
// )
type execer interface {
Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
}
type queryer interface {
Query(sql string, args ...interface{}) (*pgx.Rows, error)
}
type queryRower interface {
QueryRow(sql string, args ...interface{}) *pgx.Row
}
// type execer interface {
// Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
// }
// type queryer interface {
// Query(sql string, args ...interface{}) (*pgx.Rows, error)
// }
// type queryRower interface {
// QueryRow(sql string, args ...interface{}) *pgx.Row
// }
func TestStressConnPool(t *testing.T) {
t.Parallel()
// func TestStressConnPool(t *testing.T) {
// t.Parallel()
maxConnections := 8
pool := createConnPool(t, maxConnections)
defer pool.Close()
// maxConnections := 8
// pool := createConnPool(t, maxConnections)
// defer pool.Close()
setupStressDB(t, pool)
// setupStressDB(t, pool)
actions := []struct {
name string
fn func(*pgx.ConnPool, int) error
}{
{"insertUnprepared", func(p *pgx.ConnPool, n int) error { return insertUnprepared(p, n) }},
{"queryRowWithoutParams", func(p *pgx.ConnPool, n int) error { return queryRowWithoutParams(p, n) }},
{"query", func(p *pgx.ConnPool, n int) error { return queryCloseEarly(p, n) }},
{"queryCloseEarly", func(p *pgx.ConnPool, n int) error { return query(p, n) }},
{"queryErrorWhileReturningRows", func(p *pgx.ConnPool, n int) error { return queryErrorWhileReturningRows(p, n) }},
{"txInsertRollback", txInsertRollback},
{"txInsertCommit", txInsertCommit},
{"txMultipleQueries", txMultipleQueries},
{"notify", notify},
{"listenAndPoolUnlistens", listenAndPoolUnlistens},
{"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }},
{"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate},
{"canceledQueryExContext", canceledQueryExContext},
{"canceledExecExContext", canceledExecExContext},
}
// actions := []struct {
// name string
// fn func(*pgx.ConnPool, int) error
// }{
// {"insertUnprepared", func(p *pgx.ConnPool, n int) error { return insertUnprepared(p, n) }},
// {"queryRowWithoutParams", func(p *pgx.ConnPool, n int) error { return queryRowWithoutParams(p, n) }},
// {"query", func(p *pgx.ConnPool, n int) error { return queryCloseEarly(p, n) }},
// {"queryCloseEarly", func(p *pgx.ConnPool, n int) error { return query(p, n) }},
// {"queryErrorWhileReturningRows", func(p *pgx.ConnPool, n int) error { return queryErrorWhileReturningRows(p, n) }},
// {"txInsertRollback", txInsertRollback},
// {"txInsertCommit", txInsertCommit},
// {"txMultipleQueries", txMultipleQueries},
// {"notify", notify},
// {"listenAndPoolUnlistens", listenAndPoolUnlistens},
// {"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }},
// {"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate},
// {"canceledQueryExContext", canceledQueryExContext},
// {"canceledExecExContext", canceledExecExContext},
// }
actionCount := 1000
if s := os.Getenv("STRESS_FACTOR"); s != "" {
stressFactor, err := strconv.ParseInt(s, 10, 64)
if err != nil {
t.Fatalf("failed to parse STRESS_FACTOR: %v", s)
}
actionCount *= int(stressFactor)
}
// actionCount := 1000
// if s := os.Getenv("STRESS_FACTOR"); s != "" {
// stressFactor, err := strconv.ParseInt(s, 10, 64)
// if err != nil {
// t.Fatalf("failed to parse STRESS_FACTOR: %v", s)
// }
// actionCount *= int(stressFactor)
// }
workerCount := 16
// workerCount := 16
workChan := make(chan int)
doneChan := make(chan struct{})
errChan := make(chan error)
// workChan := make(chan int)
// doneChan := make(chan struct{})
// errChan := make(chan error)
work := func() {
for n := range workChan {
action := actions[rand.Intn(len(actions))]
err := action.fn(pool, n)
if err != nil {
errChan <- errors.Errorf("%s: %v", action.name, err)
break
}
}
doneChan <- struct{}{}
}
// work := func() {
// for n := range workChan {
// action := actions[rand.Intn(len(actions))]
// err := action.fn(pool, n)
// if err != nil {
// errChan <- errors.Errorf("%s: %v", action.name, err)
// break
// }
// }
// doneChan <- struct{}{}
// }
for i := 0; i < workerCount; i++ {
go work()
}
// for i := 0; i < workerCount; i++ {
// go work()
// }
for i := 0; i < actionCount; i++ {
select {
case workChan <- i:
case err := <-errChan:
close(workChan)
t.Fatal(err)
}
}
close(workChan)
// for i := 0; i < actionCount; i++ {
// select {
// case workChan <- i:
// case err := <-errChan:
// close(workChan)
// t.Fatal(err)
// }
// }
// close(workChan)
for i := 0; i < workerCount; i++ {
<-doneChan
}
}
// for i := 0; i < workerCount; i++ {
// <-doneChan
// }
// }
func setupStressDB(t *testing.T, pool *pgx.ConnPool) {
_, err := pool.Exec(context.Background(), `
drop table if exists widgets;
create table widgets(
id serial primary key,
name varchar not null,
description text,
creation_time timestamptz
);
`)
if err != nil {
t.Fatal(err)
}
}
// func setupStressDB(t *testing.T, pool *pgx.ConnPool) {
// _, err := pool.Exec(context.Background(), `
// drop table if exists widgets;
// create table widgets(
// id serial primary key,
// name varchar not null,
// description text,
// creation_time timestamptz
// );
// `)
// if err != nil {
// t.Fatal(err)
// }
// }
func insertUnprepared(e execer, actionNum int) error {
sql := `
insert into widgets(name, description, creation_time)
values($1, $2, $3)`
// func insertUnprepared(e execer, actionNum int) error {
// sql := `
// insert into widgets(name, description, creation_time)
// values($1, $2, $3)`
_, err := e.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
return err
}
// _, err := e.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
// return err
// }
func queryRowWithoutParams(qr queryRower, actionNum int) error {
var id int32
var name, description string
var creationTime time.Time
// func queryRowWithoutParams(qr queryRower, actionNum int) error {
// var id int32
// var name, description string
// var creationTime time.Time
sql := `select * from widgets order by random() limit 1`
// sql := `select * from widgets order by random() limit 1`
err := qr.QueryRow(sql).Scan(&id, &name, &description, &creationTime)
if err == pgx.ErrNoRows {
return nil
}
return err
}
// err := qr.QueryRow(sql).Scan(&id, &name, &description, &creationTime)
// if err == pgx.ErrNoRows {
// return nil
// }
// return err
// }
func query(q queryer, actionNum int) error {
sql := `select * from widgets order by random() limit $1`
// func query(q queryer, actionNum int) error {
// sql := `select * from widgets order by random() limit $1`
rows, err := q.Query(sql, 10)
if err != nil {
return err
}
defer rows.Close()
// rows, err := q.Query(sql, 10)
// if err != nil {
// return err
// }
// defer rows.Close()
for rows.Next() {
var id int32
var name, description string
var creationTime time.Time
rows.Scan(&id, &name, &description, &creationTime)
}
// for rows.Next() {
// var id int32
// var name, description string
// var creationTime time.Time
// rows.Scan(&id, &name, &description, &creationTime)
// }
return rows.Err()
}
// return rows.Err()
// }
func queryCloseEarly(q queryer, actionNum int) error {
sql := `select * from generate_series(1,$1)`
// func queryCloseEarly(q queryer, actionNum int) error {
// sql := `select * from generate_series(1,$1)`
rows, err := q.Query(sql, 100)
if err != nil {
return err
}
defer rows.Close()
// rows, err := q.Query(sql, 100)
// if err != nil {
// return err
// }
// defer rows.Close()
for i := 0; i < 10 && rows.Next(); i++ {
var n int32
rows.Scan(&n)
}
rows.Close()
// for i := 0; i < 10 && rows.Next(); i++ {
// var n int32
// rows.Scan(&n)
// }
// rows.Close()
return rows.Err()
}
// return rows.Err()
// }
func queryErrorWhileReturningRows(q queryer, actionNum int) error {
// This query should divide by 0 within the first number of rows
sql := `select 42 / (random() * 20)::integer from generate_series(1,100000)`
// func queryErrorWhileReturningRows(q queryer, actionNum int) error {
// // This query should divide by 0 within the first number of rows
// sql := `select 42 / (random() * 20)::integer from generate_series(1,100000)`
rows, err := q.Query(sql)
if err != nil {
return nil
}
defer rows.Close()
// rows, err := q.Query(sql)
// if err != nil {
// return nil
// }
// defer rows.Close()
for rows.Next() {
var n int32
rows.Scan(&n)
}
// for rows.Next() {
// var n int32
// rows.Scan(&n)
// }
if _, ok := rows.Err().(*pgconn.PgError); ok {
return nil
}
return rows.Err()
}
// if _, ok := rows.Err().(*pgconn.PgError); ok {
// return nil
// }
// return rows.Err()
// }
func notify(pool *pgx.ConnPool, actionNum int) error {
_, err := pool.Exec(context.Background(), "notify stress")
return err
}
// func notify(pool *pgx.ConnPool, actionNum int) error {
// _, err := pool.Exec(context.Background(), "notify stress")
// return err
// }
func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error {
conn, err := pool.Acquire()
if err != nil {
return err
}
defer pool.Release(conn)
// func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error {
// conn, err := pool.Acquire()
// if err != nil {
// return err
// }
// defer pool.Release(conn)
err = conn.Listen("stress")
if err != nil {
return err
}
// err = conn.Listen("stress")
// if err != nil {
// return err
// }
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, err = conn.WaitForNotification(ctx)
if err == context.DeadlineExceeded {
return nil
}
return err
}
// ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
// defer cancel()
// _, err = conn.WaitForNotification(ctx)
// if err == context.DeadlineExceeded {
// return nil
// }
// return err
// }
func poolPrepareUseAndDeallocate(pool *pgx.ConnPool, actionNum int) error {
psName := fmt.Sprintf("poolPreparedStatement%d", actionNum)
// func poolPrepareUseAndDeallocate(pool *pgx.ConnPool, actionNum int) error {
// psName := fmt.Sprintf("poolPreparedStatement%d", actionNum)
_, err := pool.Prepare(psName, "select $1::text")
if err != nil {
return err
}
// _, err := pool.Prepare(psName, "select $1::text")
// if err != nil {
// return err
// }
var s string
err = pool.QueryRow(psName, "hello").Scan(&s)
if err != nil {
return err
}
// var s string
// err = pool.QueryRow(psName, "hello").Scan(&s)
// if err != nil {
// return err
// }
if s != "hello" {
return errors.Errorf("Prepared statement did not return expected value: %v", s)
}
// if s != "hello" {
// return errors.Errorf("Prepared statement did not return expected value: %v", s)
// }
return pool.Deallocate(psName)
}
// return pool.Deallocate(psName)
// }
func txInsertRollback(pool *pgx.ConnPool, actionNum int) error {
tx, err := pool.Begin()
if err != nil {
return err
}
// func txInsertRollback(pool *pgx.ConnPool, actionNum int) error {
// tx, err := pool.Begin()
// if err != nil {
// return err
// }
sql := `
insert into widgets(name, description, creation_time)
values($1, $2, $3)`
// sql := `
// insert into widgets(name, description, creation_time)
// values($1, $2, $3)`
_, err = tx.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
if err != nil {
return err
}
// _, err = tx.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
// if err != nil {
// return err
// }
return tx.Rollback()
}
// return tx.Rollback()
// }
func txInsertCommit(pool *pgx.ConnPool, actionNum int) error {
tx, err := pool.Begin()
if err != nil {
return err
}
// func txInsertCommit(pool *pgx.ConnPool, actionNum int) error {
// tx, err := pool.Begin()
// if err != nil {
// return err
// }
sql := `
insert into widgets(name, description, creation_time)
values($1, $2, $3)`
// sql := `
// insert into widgets(name, description, creation_time)
// values($1, $2, $3)`
_, err = tx.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
if err != nil {
tx.Rollback()
return err
}
// _, err = tx.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
// if err != nil {
// tx.Rollback()
// return err
// }
return tx.Commit()
}
// return tx.Commit()
// }
func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error {
tx, err := pool.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error {
// tx, err := pool.Begin()
// if err != nil {
// return err
// }
// defer tx.Rollback()
errExpectedTxDeath := errors.New("Expected tx death")
// errExpectedTxDeath := errors.New("Expected tx death")
actions := []struct {
name string
fn func() error
}{
{"insertUnprepared", func() error { return insertUnprepared(tx, actionNum) }},
{"queryRowWithoutParams", func() error { return queryRowWithoutParams(tx, actionNum) }},
{"query", func() error { return query(tx, actionNum) }},
{"queryCloseEarly", func() error { return queryCloseEarly(tx, actionNum) }},
{"queryErrorWhileReturningRows", func() error {
err := queryErrorWhileReturningRows(tx, actionNum)
if err != nil {
return err
}
return errExpectedTxDeath
}},
}
// actions := []struct {
// name string
// fn func() error
// }{
// {"insertUnprepared", func() error { return insertUnprepared(tx, actionNum) }},
// {"queryRowWithoutParams", func() error { return queryRowWithoutParams(tx, actionNum) }},
// {"query", func() error { return query(tx, actionNum) }},
// {"queryCloseEarly", func() error { return queryCloseEarly(tx, actionNum) }},
// {"queryErrorWhileReturningRows", func() error {
// err := queryErrorWhileReturningRows(tx, actionNum)
// if err != nil {
// return err
// }
// return errExpectedTxDeath
// }},
// }
for i := 0; i < 20; i++ {
action := actions[rand.Intn(len(actions))]
err := action.fn()
if err == errExpectedTxDeath {
return nil
} else if err != nil {
return err
}
}
// for i := 0; i < 20; i++ {
// action := actions[rand.Intn(len(actions))]
// err := action.fn()
// if err == errExpectedTxDeath {
// return nil
// } else if err != nil {
// return err
// }
// }
return tx.Commit()
}
// return tx.Commit()
// }
func canceledQueryExContext(pool *pgx.ConnPool, actionNum int) error {
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
cancelFunc()
}()
// func canceledQueryExContext(pool *pgx.ConnPool, actionNum int) error {
// ctx, cancelFunc := context.WithCancel(context.Background())
// go func() {
// time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
// cancelFunc()
// }()
rows, err := pool.QueryEx(ctx, "select pg_sleep(2)", nil)
if err == context.Canceled {
return nil
} else if err != nil {
return errors.Errorf("Only allowed error is context.Canceled, got %v", err)
}
// rows, err := pool.QueryEx(ctx, "select pg_sleep(2)", nil)
// if err == context.Canceled {
// return nil
// } else if err != nil {
// return errors.Errorf("Only allowed error is context.Canceled, got %v", err)
// }
for rows.Next() {
return errors.New("should never receive row")
}
// for rows.Next() {
// return errors.New("should never receive row")
// }
if rows.Err() != context.Canceled {
return errors.Errorf("Expected context.Canceled error, got %v", rows.Err())
}
// if rows.Err() != context.Canceled {
// return errors.Errorf("Expected context.Canceled error, got %v", rows.Err())
// }
return nil
}
// return nil
// }
func canceledExecExContext(pool *pgx.ConnPool, actionNum int) error {
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
cancelFunc()
}()
// func canceledExecExContext(pool *pgx.ConnPool, actionNum int) error {
// ctx, cancelFunc := context.WithCancel(context.Background())
// go func() {
// time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
// cancelFunc()
// }()
_, err := pool.Exec(ctx, "select pg_sleep(2)")
if err != context.Canceled {
return errors.Errorf("Expected context.Canceled error, got %v", err)
}
// _, err := pool.Exec(ctx, "select pg_sleep(2)")
// if err != context.Canceled {
// return errors.Errorf("Expected context.Canceled error, got %v", err)
// }
return nil
}
// return nil
// }

19
tx.go
View File

@ -4,7 +4,6 @@ import (
"bytes"
"context"
"fmt"
"io"
"time"
"github.com/jackc/pgx/pgconn"
@ -231,24 +230,6 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr
return tx.conn.CopyFrom(tableName, columnNames, rowSrc)
}
// CopyFromReader delegates to the underlying *Conn
func (tx *Tx) CopyFromReader(r io.Reader, sql string) (commandTag pgconn.CommandTag, err error) {
if tx.status != TxStatusInProgress {
return "", ErrTxClosed
}
return tx.conn.CopyFromReader(r, sql)
}
// CopyToWriter delegates to the underlying *Conn
func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) (commandTag pgconn.CommandTag, err error) {
if tx.status != TxStatusInProgress {
return "", ErrTxClosed
}
return tx.conn.CopyToWriter(w, sql, args...)
}
// Status returns the status of the transaction from the set of
// pgx.TxStatus* constants.
func (tx *Tx) Status() int8 {

31
v4.md
View File

@ -28,10 +28,6 @@ Potential Changes:
* Consider strongly typed query parameters in style of zerolog (chained functions instead of varargs)
* Consider buffered query select where entire result set is received and parsed successfully or call returns error
Minor Potential Changes:
* Change PgError error implementation to pointer method
## Changes
* `pgconn.PgConn` now contains core PostgreSQL connection functionality.
@ -43,7 +39,34 @@ Minor Potential Changes:
* Connect method now takes context and connection string.
* ConnectConfig takes context and config object.
* `RuntimeParams` has been removed from `pgx.Conn`. Server reported status can now be queried with the `ParameterStatus` method. The rename aligns with the PostgreSQL protocol and standard libpq naming. Access via a method instead of direct access to the map protects against outside modification.
* LISTEN / NOTIFY functionality moved to pgconn.
* COPY TO functionality moved to pgconn.
* COPY FROM functionality moved to pgconn.
## New Features
* Specifying multiple hosts for connecting to HA systems.
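
As an illustration only (hostnames are placeholders and the exact syntax is an assumption based on libpq conventions rather than a confirmed pgx form), a multi-host connection string might look like `host=db1.example.com,db2.example.com port=5432,5432 dbname=mydb`, with hosts tried in order until one accepts the connection.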
## Transaction idea
Problem: Using original connection or pool outside of tx object
tx = pool.Begin()
tx.Query(...)
pool.Query(...) // <- Possible to accidentally do stuff outside of tx
Solution: Common interface for basic queries and atomicity
var querier Querier
querier = pool
querier = querier.Begin() <- tx implements querier
querier.Query(...)
querier = querier.Commit()
-- tx implements begin, commit, and rollback as save points
-- conn implements begin as create tx (what about commit and rollback? No-op?)
-- pool implements begin as?
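
A rough Go sketch of the Querier idea above; every name here is illustrative, not a committed API, and the method shapes are borrowed from the execer/queryer/queryRower interfaces used in the stress tests.

package pgxsketch

import (
	"context"

	"github.com/jackc/pgx"
	"github.com/jackc/pgx/pgconn"
)

// Querier is the common interface a pool, a conn, and a tx could share, so code
// written against it cannot accidentally run statements outside a transaction.
type Querier interface {
	Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error)
	Query(sql string, args ...interface{}) (*pgx.Rows, error)
	QueryRow(sql string, args ...interface{}) *pgx.Row

	// Begin returns a Querier scoped to a transaction; on a tx it could be
	// implemented as a savepoint, as the notes above suggest.
	Begin() (Querier, error)
	// Commit and Rollback end the transaction and hand back the parent Querier.
	Commit() (Querier, error)
	Rollback() (Querier, error)
}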