mirror of https://github.com/jackc/pgx.git
Partial conversion of pgx to use pgconn
parent
e3d431d0df
commit
d3a2c1c107
186
batch.go
186
batch.go
|
@ -6,6 +6,7 @@ import (
|
|||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type batchItem struct {
|
||||
|
@ -26,6 +27,8 @@ type Batch struct {
|
|||
ctx context.Context
|
||||
err error
|
||||
inTx bool
|
||||
|
||||
mrr *pgconn.MultiResultReader
|
||||
}
|
||||
|
||||
// BeginBatch returns a *Batch query for c.
|
||||
|
@ -56,10 +59,8 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt
|
|||
})
|
||||
}
|
||||
|
||||
// Send sends all queued queries to the server at once.
|
||||
// If the batch is created from a conn Object then All queries are wrapped
|
||||
// in a transaction. The transaction can optionally be configured with
|
||||
// txOptions. The context is in effect until the Batch is closed.
|
||||
// Send sends all queued queries to the server at once. All queries are run in an implicit transaction unless explicit
|
||||
// transaction control statements are executed.
|
||||
//
|
||||
// Warning: Send writes all queued queries before reading any results. This can
|
||||
// cause a deadlock if an excessive number of queries are queued. It is highly
|
||||
|
@ -78,7 +79,7 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt
|
|||
// able to finish sending the responses.
|
||||
//
|
||||
// See https://github.com/jackc/pgx/issues/374.
|
||||
func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
|
||||
func (b *Batch) Send(ctx context.Context) error {
|
||||
if b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
|
@ -94,112 +95,62 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
|
|||
return err
|
||||
}
|
||||
|
||||
buf := b.conn.wbuf
|
||||
if !b.inTx {
|
||||
buf = appendQuery(buf, txOptions.beginSQL())
|
||||
}
|
||||
|
||||
err = b.conn.initContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
for _, bi := range b.items {
|
||||
var psName string
|
||||
var psParameterOIDs []pgtype.OID
|
||||
var parameterOIDs []pgtype.OID
|
||||
ps := b.conn.preparedStatements[bi.query]
|
||||
|
||||
if ps, ok := b.conn.preparedStatements[bi.query]; ok {
|
||||
psName = ps.Name
|
||||
psParameterOIDs = ps.ParameterOIDs
|
||||
if ps != nil {
|
||||
parameterOIDs = ps.ParameterOIDs
|
||||
} else {
|
||||
psParameterOIDs = bi.parameterOIDs
|
||||
buf = appendParse(buf, "", bi.query, psParameterOIDs)
|
||||
parameterOIDs = bi.parameterOIDs
|
||||
}
|
||||
|
||||
var err error
|
||||
buf, err = appendBind(buf, "", psName, b.conn.ConnInfo, psParameterOIDs, bi.arguments, bi.resultFormatCodes)
|
||||
args, err := convertDriverValuers(bi.arguments)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf = appendDescribe(buf, 'P', "")
|
||||
buf = appendExecute(buf, "", 0)
|
||||
}
|
||||
|
||||
buf = appendSync(buf)
|
||||
b.conn.pendingReadyForQueryCount++
|
||||
|
||||
if !b.inTx {
|
||||
buf = appendQuery(buf, "commit")
|
||||
b.conn.pendingReadyForQueryCount++
|
||||
}
|
||||
|
||||
n, err := b.conn.pgConn.Conn().Write(buf)
|
||||
if err != nil {
|
||||
if fatalWriteErr(n, err) {
|
||||
b.conn.die(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for !b.inTx {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
return nil
|
||||
default:
|
||||
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||
paramFormats := make([]int16, len(args))
|
||||
paramValues := make([][]byte, len(args))
|
||||
for i := range args {
|
||||
paramFormats[i] = chooseParameterFormatCode(b.conn.ConnInfo, parameterOIDs[i], args[i])
|
||||
paramValues[i], err = newencodePreparedStatementArgument(b.conn.ConnInfo, parameterOIDs[i], args[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if ps != nil {
|
||||
batch.ExecPrepared(ps.Name, paramValues, paramFormats, bi.resultFormatCodes)
|
||||
} else {
|
||||
oids := make([]uint32, len(parameterOIDs))
|
||||
for i := 0; i < len(parameterOIDs); i++ {
|
||||
oids[i] = uint32(parameterOIDs[i])
|
||||
}
|
||||
batch.ExecParams(bi.query, paramValues, oids, paramFormats, bi.resultFormatCodes)
|
||||
}
|
||||
}
|
||||
|
||||
b.mrr = b.conn.pgConn.ExecBatch(ctx, batch)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with Exec.
|
||||
func (b *Batch) ExecResults() (pgconn.CommandTag, error) {
|
||||
if b.err != nil {
|
||||
return "", b.err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-b.ctx.Done():
|
||||
b.die(b.ctx.Err())
|
||||
return "", b.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if err := b.ensureCommandComplete(); err != nil {
|
||||
b.die(err)
|
||||
if !b.mrr.NextResult() {
|
||||
err := b.mrr.Close()
|
||||
if err == nil {
|
||||
err = errors.New("no result")
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
b.resultsRead++
|
||||
|
||||
b.pendingCommandComplete = true
|
||||
|
||||
for {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CommandComplete:
|
||||
b.pendingCommandComplete = false
|
||||
return pgconn.CommandTag(msg.CommandTag), nil
|
||||
default:
|
||||
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
return b.mrr.ResultReader().Close()
|
||||
}
|
||||
|
||||
// QueryResults reads the results from the next query in the batch as if the
|
||||
|
@ -207,38 +158,16 @@ func (b *Batch) ExecResults() (pgconn.CommandTag, error) {
|
|||
func (b *Batch) QueryResults() (*Rows, error) {
|
||||
rows := b.conn.getRows("batch query", nil)
|
||||
|
||||
if b.err != nil {
|
||||
rows.fatal(b.err)
|
||||
return rows, b.err
|
||||
if !b.mrr.NextResult() {
|
||||
rows.err = b.mrr.Close()
|
||||
if rows.err == nil {
|
||||
rows.err = errors.New("no result")
|
||||
}
|
||||
rows.closed = true
|
||||
return rows, rows.err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-b.ctx.Done():
|
||||
b.die(b.ctx.Err())
|
||||
rows.fatal(b.err)
|
||||
return rows, b.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if err := b.ensureCommandComplete(); err != nil {
|
||||
b.die(err)
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
b.resultsRead++
|
||||
|
||||
b.pendingCommandComplete = true
|
||||
|
||||
fieldDescriptions, err := b.conn.readUntilRowDescription()
|
||||
if err != nil {
|
||||
b.die(err)
|
||||
rows.fatal(b.err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
rows.batch = b
|
||||
rows.fields = fieldDescriptions
|
||||
rows.resultReader = b.mrr.ResultReader()
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
|
@ -254,28 +183,7 @@ func (b *Batch) QueryRowResults() *Row {
|
|||
// operation may have made it impossible to resyncronize the connection with the
|
||||
// server. In this case the underlying connection will have been closed.
|
||||
func (b *Batch) Close() (err error) {
|
||||
if b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = b.conn.termContext(err)
|
||||
if b.conn != nil && b.connPool != nil {
|
||||
b.connPool.Release(b.conn)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := b.resultsRead; i < len(b.items); i++ {
|
||||
if _, err = b.ExecResults(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err = b.conn.ensureConnectionReadyForQuery(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return b.mrr.Close()
|
||||
}
|
||||
|
||||
func (b *Batch) die(err error) {
|
||||
|
|
111
batch_test.go
111
batch_test.go
|
@ -17,10 +17,10 @@ func TestConnBeginBatch(t *testing.T) {
|
|||
defer closeConn(t, conn)
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
|
@ -50,7 +50,7 @@ func TestConnBeginBatch(t *testing.T) {
|
|||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
err := batch.Send(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -71,6 +71,14 @@ func TestConnBeginBatch(t *testing.T) {
|
|||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||
}
|
||||
|
||||
ct, err = batch.ExecResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if ct.RowsAffected() != 1 {
|
||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||
}
|
||||
|
||||
rows, err := batch.QueryResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
@ -173,7 +181,7 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
|
|||
)
|
||||
}
|
||||
|
||||
err = batch.Send(context.Background(), nil)
|
||||
err = batch.Send(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -233,7 +241,7 @@ func TestConnBeginBatchContextCancelBeforeExecResults(t *testing.T) {
|
|||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
|
||||
err := batch.Send(ctx, nil)
|
||||
err := batch.Send(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -245,9 +253,7 @@ func TestConnBeginBatchContextCancelBeforeExecResults(t *testing.T) {
|
|||
t.Errorf("err => %v, want %v", err, context.Canceled)
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
t.Error("conn should be dead, but was alive")
|
||||
}
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchContextCancelBeforeQueryResults(t *testing.T) {
|
||||
|
@ -269,21 +275,26 @@ func TestConnBeginBatchContextCancelBeforeQueryResults(t *testing.T) {
|
|||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
|
||||
err := batch.Send(ctx, nil)
|
||||
err := batch.Send(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cancelFn()
|
||||
|
||||
_, err = batch.QueryResults()
|
||||
if err != context.Canceled {
|
||||
t.Errorf("err => %v, want %v", err, context.Canceled)
|
||||
rows, err := batch.QueryResults()
|
||||
|
||||
if rows.Next() {
|
||||
t.Error("unexpected row")
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
t.Error("conn should be dead, but was alive")
|
||||
if rows.Err() != context.Canceled {
|
||||
t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled)
|
||||
}
|
||||
|
||||
batch.Close()
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) {
|
||||
|
@ -305,7 +316,7 @@ func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) {
|
|||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
|
||||
err := batch.Send(ctx, nil)
|
||||
err := batch.Send(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -317,9 +328,7 @@ func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) {
|
|||
t.Errorf("err => %v, want %v", err, context.Canceled)
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
t.Error("conn should be dead, but was alive")
|
||||
}
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
|
||||
|
@ -340,7 +349,7 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
|
|||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
err := batch.Send(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -411,7 +420,7 @@ func TestConnBeginBatchQueryError(t *testing.T) {
|
|||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
err := batch.Send(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -440,9 +449,7 @@ func TestConnBeginBatchQueryError(t *testing.T) {
|
|||
t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
t.Error("conn should be dead, but was alive")
|
||||
}
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
|
||||
|
@ -458,7 +465,7 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
|
|||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
err := batch.Send(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -474,9 +481,7 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
|
|||
t.Error("Expected error")
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
t.Error("conn should be dead, but was alive")
|
||||
}
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchQueryRowInsert(t *testing.T) {
|
||||
|
@ -486,10 +491,10 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) {
|
|||
defer closeConn(t, conn)
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
|
@ -504,7 +509,7 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) {
|
|||
nil,
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
err := batch.Send(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -535,10 +540,10 @@ func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
|
|||
defer closeConn(t, conn)
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
|
@ -553,7 +558,7 @@ func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
|
|||
nil,
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
err := batch.Send(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -584,15 +589,15 @@ func TestTxBeginBatch(t *testing.T) {
|
|||
defer closeConn(t, conn)
|
||||
|
||||
sql := `create temporary table ledger1(
|
||||
id serial primary key,
|
||||
description varchar not null
|
||||
);`
|
||||
id serial primary key,
|
||||
description varchar not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
sql = `create temporary table ledger2(
|
||||
id int primary key,
|
||||
amount int not null
|
||||
);`
|
||||
id int primary key,
|
||||
amount int not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
tx, _ := conn.Begin()
|
||||
|
@ -603,7 +608,7 @@ func TestTxBeginBatch(t *testing.T) {
|
|||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
err := batch.Send(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -627,7 +632,7 @@ func TestTxBeginBatch(t *testing.T) {
|
|||
nil,
|
||||
)
|
||||
|
||||
err = batch.Send(context.Background(), nil)
|
||||
err = batch.Send(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -639,8 +644,8 @@ func TestTxBeginBatch(t *testing.T) {
|
|||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||
}
|
||||
|
||||
var amout int
|
||||
err = batch.QueryRowResults().Scan(&amout)
|
||||
var amount int
|
||||
err = batch.QueryRowResults().Scan(&amount)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
@ -669,9 +674,9 @@ func TestTxBeginBatchRollback(t *testing.T) {
|
|||
defer closeConn(t, conn)
|
||||
|
||||
sql := `create temporary table ledger1(
|
||||
id serial primary key,
|
||||
description varchar not null
|
||||
);`
|
||||
id serial primary key,
|
||||
description varchar not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
tx, _ := conn.Begin()
|
||||
|
@ -682,7 +687,7 @@ func TestTxBeginBatchRollback(t *testing.T) {
|
|||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
err := batch.Send(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -668,7 +668,7 @@ func BenchmarkMultipleQueriesBatch(b *testing.B) {
|
|||
)
|
||||
}
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
err := batch.Send(context.Background())
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
|
142
conn.go
142
conn.go
|
@ -13,7 +13,6 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
@ -69,7 +68,6 @@ type Conn struct {
|
|||
config *ConnConfig // config used when establishing this connection
|
||||
preparedStatements map[string]*PreparedStatement
|
||||
channels map[string]struct{}
|
||||
notifications []*Notification
|
||||
logger Logger
|
||||
logLevel int
|
||||
fp *fastpath
|
||||
|
@ -105,13 +103,6 @@ type PrepareExOptions struct {
|
|||
ParameterOIDs []pgtype.OID
|
||||
}
|
||||
|
||||
// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system
|
||||
type Notification struct {
|
||||
PID uint32 // backend pid that sent the notification
|
||||
Channel string // channel from which notification was received
|
||||
Payload string
|
||||
}
|
||||
|
||||
// Identifier a PostgreSQL identifier or name. Identifiers can be composed of
|
||||
// multiple parts such as ["schema", "table"] or ["table", "column"].
|
||||
type Identifier []string
|
||||
|
@ -501,17 +492,6 @@ func (c *Conn) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExO
|
|||
return nil, err
|
||||
}
|
||||
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ps, err = c.prepareEx(name, sql, opts)
|
||||
err = c.termContext(err)
|
||||
return ps, err
|
||||
}
|
||||
|
||||
func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
||||
if name != "" {
|
||||
if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
|
||||
return ps, nil
|
||||
|
@ -562,7 +542,9 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||
c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
|
||||
}
|
||||
|
||||
c.preparedStatements[name] = ps
|
||||
if name != "" {
|
||||
c.preparedStatements[name] = ps
|
||||
}
|
||||
|
||||
return ps, nil
|
||||
}
|
||||
|
@ -593,42 +575,8 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
|
|||
|
||||
delete(c.preparedStatements, name)
|
||||
|
||||
// close
|
||||
buf := c.wbuf
|
||||
buf = append(buf, 'C')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, 'S')
|
||||
buf = append(buf, name...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
// flush
|
||||
buf = append(buf, 'H')
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
_, err = c.pgConn.Conn().Write(buf)
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg.(type) {
|
||||
case *pgproto3.CloseComplete:
|
||||
return nil
|
||||
default:
|
||||
err = c.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err = c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll()
|
||||
return err
|
||||
}
|
||||
|
||||
// Listen establishes a PostgreSQL listen/notify to channel
|
||||
|
@ -654,64 +602,10 @@ func (c *Conn) Unlisten(channel string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// WaitForNotification waits for a PostgreSQL notification.
|
||||
func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notification, err error) {
|
||||
// Return already received notification immediately
|
||||
if len(c.notifications) > 0 {
|
||||
notification := c.notifications[0]
|
||||
c.notifications = c.notifications[1:]
|
||||
return notification, nil
|
||||
}
|
||||
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
}()
|
||||
|
||||
if err = c.lock(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
|
||||
err = unlockErr
|
||||
}
|
||||
}()
|
||||
|
||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = c.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(c.notifications) > 0 {
|
||||
notification := c.notifications[0]
|
||||
c.notifications = c.notifications[1:]
|
||||
return notification, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) IsAlive() bool {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
return c.status >= connStatusIdle
|
||||
return c.pgConn.IsAlive() && c.status >= connStatusIdle
|
||||
}
|
||||
|
||||
func (c *Conn) CauseOfDeath() error {
|
||||
|
@ -807,8 +701,6 @@ func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) {
|
|||
switch msg := msg.(type) {
|
||||
case *pgproto3.ErrorResponse:
|
||||
return c.rxErrorResponse(msg)
|
||||
case *pgproto3.NotificationResponse:
|
||||
c.rxNotificationResponse(msg)
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
}
|
||||
|
@ -886,14 +778,6 @@ func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgty
|
|||
return parameters
|
||||
}
|
||||
|
||||
func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) {
|
||||
n := new(Notification)
|
||||
n.PID = msg.PID
|
||||
n.Channel = msg.Channel
|
||||
n.Payload = msg.Payload
|
||||
c.notifications = append(c.notifications, n)
|
||||
}
|
||||
|
||||
func (c *Conn) die(err error) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
@ -1238,7 +1122,13 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
|
|||
|
||||
resultFormats := make([]int16, len(ps.FieldDescriptions))
|
||||
for i := range resultFormats {
|
||||
resultFormats[i] = ps.FieldDescriptions[i].FormatCode
|
||||
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
|
||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||
resultFormats[i] = BinaryFormatCode
|
||||
} else {
|
||||
resultFormats[i] = TextFormatCode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.lastStmtSent = true
|
||||
|
@ -1307,13 +1197,9 @@ func (c *Conn) pgproto3FieldDescriptionToPgxFieldDescription(src *pgproto3.Field
|
|||
dst.DataType = pgtype.OID(src.DataTypeOID)
|
||||
dst.DataTypeSize = src.DataTypeSize
|
||||
dst.Modifier = src.TypeModifier
|
||||
dst.FormatCode = src.Format
|
||||
|
||||
if dt, ok := c.ConnInfo.DataTypeForOID(dst.DataType); ok {
|
||||
dst.DataTypeName = dt.Name
|
||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||
dst.FormatCode = BinaryFormatCode
|
||||
} else {
|
||||
dst.FormatCode = TextFormatCode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
24
conn_pool.go
24
conn_pool.go
|
@ -2,7 +2,6 @@ package pgx
|
|||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -204,7 +203,6 @@ func (p *ConnPool) Release(conn *Conn) {
|
|||
}
|
||||
conn.channels = make(map[string]struct{})
|
||||
}
|
||||
conn.notifications = nil
|
||||
|
||||
p.cond.L.Lock()
|
||||
|
||||
|
@ -544,28 +542,6 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C
|
|||
return c.CopyFrom(tableName, columnNames, rowSrc)
|
||||
}
|
||||
|
||||
// CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection
|
||||
func (p *ConnPool) CopyFromReader(r io.Reader, sql string) (pgconn.CommandTag, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
return c.CopyFromReader(r, sql)
|
||||
}
|
||||
|
||||
// CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection
|
||||
func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) (pgconn.CommandTag, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
return c.CopyToWriter(w, sql, args...)
|
||||
}
|
||||
|
||||
// BeginBatch acquires a connection and begins a batch on that connection. When
|
||||
// *Batch is finished, the connection is released automatically.
|
||||
func (p *ConnPool) BeginBatch() *Batch {
|
||||
|
|
|
@ -931,142 +931,146 @@ func TestConnPoolPrepareDeallocatePrepare(t *testing.T) {
|
|||
func TestConnPoolPrepareWhenConnIsAlreadyAcquired(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pool := createConnPool(t, 2)
|
||||
defer pool.Close()
|
||||
t.Skip("TODO")
|
||||
|
||||
testPreparedStatement := func(db queryRower, desc string) {
|
||||
var s string
|
||||
err := db.QueryRow("test", "hello").Scan(&s)
|
||||
if err != nil {
|
||||
t.Fatalf("%s. Executing prepared statement failed: %v", desc, err)
|
||||
}
|
||||
// pool := createConnPool(t, 2)
|
||||
// defer pool.Close()
|
||||
|
||||
if s != "hello" {
|
||||
t.Fatalf("%s. Prepared statement did not return expected value: %v", desc, s)
|
||||
}
|
||||
}
|
||||
// testPreparedStatement := func(db queryRower, desc string) {
|
||||
// var s string
|
||||
// err := db.QueryRow("test", "hello").Scan(&s)
|
||||
// if err != nil {
|
||||
// t.Fatalf("%s. Executing prepared statement failed: %v", desc, err)
|
||||
// }
|
||||
|
||||
newReleaseOnce := func(c *pgx.Conn) func() {
|
||||
var once sync.Once
|
||||
return func() {
|
||||
once.Do(func() { pool.Release(c) })
|
||||
}
|
||||
}
|
||||
// if s != "hello" {
|
||||
// t.Fatalf("%s. Prepared statement did not return expected value: %v", desc, s)
|
||||
// }
|
||||
// }
|
||||
|
||||
c1, err := pool.Acquire()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to acquire connection: %v", err)
|
||||
}
|
||||
c1Release := newReleaseOnce(c1)
|
||||
defer c1Release()
|
||||
// newReleaseOnce := func(c *pgx.Conn) func() {
|
||||
// var once sync.Once
|
||||
// return func() {
|
||||
// once.Do(func() { pool.Release(c) })
|
||||
// }
|
||||
// }
|
||||
|
||||
_, err = pool.Prepare("test", "select $1::varchar")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to prepare statement: %v", err)
|
||||
}
|
||||
// c1, err := pool.Acquire()
|
||||
// if err != nil {
|
||||
// t.Fatalf("Unable to acquire connection: %v", err)
|
||||
// }
|
||||
// c1Release := newReleaseOnce(c1)
|
||||
// defer c1Release()
|
||||
|
||||
testPreparedStatement(pool, "pool")
|
||||
// _, err = pool.Prepare("test", "select $1::varchar")
|
||||
// if err != nil {
|
||||
// t.Fatalf("Unable to prepare statement: %v", err)
|
||||
// }
|
||||
|
||||
c1Release()
|
||||
// testPreparedStatement(pool, "pool")
|
||||
|
||||
c2, err := pool.Acquire()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to acquire connection: %v", err)
|
||||
}
|
||||
c2Release := newReleaseOnce(c2)
|
||||
defer c2Release()
|
||||
// c1Release()
|
||||
|
||||
// This conn will not be available and will be connection at this point
|
||||
c3, err := pool.Acquire()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to acquire connection: %v", err)
|
||||
}
|
||||
c3Release := newReleaseOnce(c3)
|
||||
defer c3Release()
|
||||
// c2, err := pool.Acquire()
|
||||
// if err != nil {
|
||||
// t.Fatalf("Unable to acquire connection: %v", err)
|
||||
// }
|
||||
// c2Release := newReleaseOnce(c2)
|
||||
// defer c2Release()
|
||||
|
||||
testPreparedStatement(c2, "c2")
|
||||
testPreparedStatement(c3, "c3")
|
||||
// // This conn will not be available and will be connection at this point
|
||||
// c3, err := pool.Acquire()
|
||||
// if err != nil {
|
||||
// t.Fatalf("Unable to acquire connection: %v", err)
|
||||
// }
|
||||
// c3Release := newReleaseOnce(c3)
|
||||
// defer c3Release()
|
||||
|
||||
c2Release()
|
||||
c3Release()
|
||||
// testPreparedStatement(c2, "c2")
|
||||
// testPreparedStatement(c3, "c3")
|
||||
|
||||
err = pool.Deallocate("test")
|
||||
if err != nil {
|
||||
t.Errorf("Deallocate failed: %v", err)
|
||||
}
|
||||
// c2Release()
|
||||
// c3Release()
|
||||
|
||||
var s string
|
||||
err = pool.QueryRow("test", "hello").Scan(&s)
|
||||
if err, ok := err.(*pgconn.PgError); !(ok && err.Code == "42601") {
|
||||
t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err)
|
||||
}
|
||||
// err = pool.Deallocate("test")
|
||||
// if err != nil {
|
||||
// t.Errorf("Deallocate failed: %v", err)
|
||||
// }
|
||||
|
||||
// var s string
|
||||
// err = pool.QueryRow("test", "hello").Scan(&s)
|
||||
// if err, ok := err.(*pgconn.PgError); !(ok && err.Code == "42601") {
|
||||
// t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err)
|
||||
// }
|
||||
}
|
||||
|
||||
func TestConnPoolBeginBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pool := createConnPool(t, 2)
|
||||
defer pool.Close()
|
||||
t.Skip("TODO")
|
||||
|
||||
batch := pool.BeginBatch()
|
||||
batch.Queue("select n from generate_series(0,5) n",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
batch.Queue("select n from generate_series(0,5) n",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
// pool := createConnPool(t, 2)
|
||||
// defer pool.Close()
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// batch := pool.BeginBatch()
|
||||
// batch.Queue("select n from generate_series(0,5) n",
|
||||
// nil,
|
||||
// nil,
|
||||
// []int16{pgx.BinaryFormatCode},
|
||||
// )
|
||||
// batch.Queue("select n from generate_series(0,5) n",
|
||||
// nil,
|
||||
// nil,
|
||||
// []int16{pgx.BinaryFormatCode},
|
||||
// )
|
||||
|
||||
rows, err := batch.QueryResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
// err := batch.Send(context.Background(), nil)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
|
||||
for i := 0; rows.Next(); i++ {
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if n != i {
|
||||
t.Errorf("n => %v, want %v", n, i)
|
||||
}
|
||||
}
|
||||
// rows, err := batch.QueryResults()
|
||||
// if err != nil {
|
||||
// t.Error(err)
|
||||
// }
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Error(rows.Err())
|
||||
}
|
||||
// for i := 0; rows.Next(); i++ {
|
||||
// var n int
|
||||
// if err := rows.Scan(&n); err != nil {
|
||||
// t.Error(err)
|
||||
// }
|
||||
// if n != i {
|
||||
// t.Errorf("n => %v, want %v", n, i)
|
||||
// }
|
||||
// }
|
||||
|
||||
rows, err = batch.QueryResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
// if rows.Err() != nil {
|
||||
// t.Error(rows.Err())
|
||||
// }
|
||||
|
||||
for i := 0; rows.Next(); i++ {
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if n != i {
|
||||
t.Errorf("n => %v, want %v", n, i)
|
||||
}
|
||||
}
|
||||
// rows, err = batch.QueryResults()
|
||||
// if err != nil {
|
||||
// t.Error(err)
|
||||
// }
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Error(rows.Err())
|
||||
}
|
||||
// for i := 0; rows.Next(); i++ {
|
||||
// var n int
|
||||
// if err := rows.Scan(&n); err != nil {
|
||||
// t.Error(err)
|
||||
// }
|
||||
// if n != i {
|
||||
// t.Errorf("n => %v, want %v", n, i)
|
||||
// }
|
||||
// }
|
||||
|
||||
err = batch.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// if rows.Err() != nil {
|
||||
// t.Error(rows.Err())
|
||||
// }
|
||||
|
||||
// err = batch.Close()
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
}
|
||||
|
||||
func TestConnPoolBeginEx(t *testing.T) {
|
||||
|
|
230
conn_test.go
230
conn_test.go
|
@ -551,234 +551,6 @@ func TestPrepareEx(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestListenNotify(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
listener := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, listener)
|
||||
|
||||
if err := listener.Listen("chat"); err != nil {
|
||||
t.Fatalf("Unable to start listening: %v", err)
|
||||
}
|
||||
|
||||
notifier := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, notifier)
|
||||
|
||||
mustExec(t, notifier, "notify chat")
|
||||
|
||||
// when notification is waiting on the socket to be read
|
||||
notification, err := listener.WaitForNotification(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
if notification.Channel != "chat" {
|
||||
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
|
||||
}
|
||||
|
||||
// when notification has already been read during previous query
|
||||
mustExec(t, notifier, "notify chat")
|
||||
rows, _ := listener.Query("select 1")
|
||||
rows.Close()
|
||||
if rows.Err() != nil {
|
||||
t.Fatalf("Unexpected error on Query: %v", rows.Err())
|
||||
}
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
cancelFn()
|
||||
notification, err = listener.WaitForNotification(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
if notification.Channel != "chat" {
|
||||
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
|
||||
}
|
||||
|
||||
// when timeout occurs
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||
defer cancel()
|
||||
notification, err = listener.WaitForNotification(ctx)
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
|
||||
}
|
||||
if notification != nil {
|
||||
t.Errorf("WaitForNotification returned an unexpected notification: %v", notification)
|
||||
}
|
||||
|
||||
// listener can listen again after a timeout
|
||||
mustExec(t, notifier, "notify chat")
|
||||
notification, err = listener.WaitForNotification(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
if notification.Channel != "chat" {
|
||||
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnlistenSpecificChannel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
listener := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, listener)
|
||||
|
||||
if err := listener.Listen("unlisten_test"); err != nil {
|
||||
t.Fatalf("Unable to start listening: %v", err)
|
||||
}
|
||||
|
||||
notifier := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, notifier)
|
||||
|
||||
mustExec(t, notifier, "notify unlisten_test")
|
||||
|
||||
// when notification is waiting on the socket to be read
|
||||
notification, err := listener.WaitForNotification(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
if notification.Channel != "unlisten_test" {
|
||||
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
|
||||
}
|
||||
|
||||
err = listener.Unlisten("unlisten_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on Unlisten: %v", err)
|
||||
}
|
||||
|
||||
// when notification has already been read during previous query
|
||||
mustExec(t, notifier, "notify unlisten_test")
|
||||
rows, _ := listener.Query("select 1")
|
||||
rows.Close()
|
||||
if rows.Err() != nil {
|
||||
t.Fatalf("Unexpected error on Query: %v", rows.Err())
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
notification, err = listener.WaitForNotification(ctx)
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
listenerDone := make(chan bool)
|
||||
go func() {
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
defer func() {
|
||||
listenerDone <- true
|
||||
}()
|
||||
|
||||
if err := conn.Listen("busysafe"); err != nil {
|
||||
t.Fatalf("Unable to start listening: %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < 5000; i++ {
|
||||
var sum int32
|
||||
var rowCount int32
|
||||
|
||||
rows, err := conn.Query("select generate_series(1,$1)", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: %v", err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var n int32
|
||||
rows.Scan(&n)
|
||||
sum += n
|
||||
rowCount++
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Fatalf("conn.Query failed: %v", err)
|
||||
}
|
||||
|
||||
if sum != 5050 {
|
||||
t.Fatalf("Wrong rows sum: %v", sum)
|
||||
}
|
||||
|
||||
if rowCount != 100 {
|
||||
t.Fatalf("Wrong number of rows: %v", rowCount)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Microsecond)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
for i := 0; i < 100000; i++ {
|
||||
mustExec(t, conn, "notify busysafe, 'hello'")
|
||||
time.Sleep(1 * time.Microsecond)
|
||||
}
|
||||
}()
|
||||
|
||||
<-listenerDone
|
||||
}
|
||||
|
||||
func TestListenNotifySelfNotification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
if err := conn.Listen("self"); err != nil {
|
||||
t.Fatalf("Unable to start listening: %v", err)
|
||||
}
|
||||
|
||||
// Notify self and WaitForNotification immediately
|
||||
mustExec(t, conn, "notify self")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
notification, err := conn.WaitForNotification(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
if notification.Channel != "self" {
|
||||
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
|
||||
}
|
||||
|
||||
// Notify self and do something else before WaitForNotification
|
||||
mustExec(t, conn, "notify self")
|
||||
|
||||
rows, _ := conn.Query("select 1")
|
||||
rows.Close()
|
||||
if rows.Err() != nil {
|
||||
t.Fatalf("Unexpected error on Query: %v", rows.Err())
|
||||
}
|
||||
|
||||
ctx, cncl := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cncl()
|
||||
notification, err = conn.WaitForNotification(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
if notification.Channel != "self" {
|
||||
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenUnlistenSpecialCharacters(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
chanName := "special characters !@#{$%^&*()}"
|
||||
if err := conn.Listen(chanName); err != nil {
|
||||
t.Fatalf("Unable to start listening: %v", err)
|
||||
}
|
||||
|
||||
if err := conn.Unlisten(chanName); err != nil {
|
||||
t.Fatalf("Unable to stop listening: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFatalRxError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -830,7 +602,7 @@ func TestFatalTxError(t *testing.T) {
|
|||
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
|
||||
}
|
||||
|
||||
_, err = conn.Query("select 1")
|
||||
err = conn.QueryRow("select 1").Scan(nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error but none occurred")
|
||||
}
|
||||
|
|
259
copy_from.go
259
copy_from.go
|
@ -2,12 +2,11 @@ package pgx
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
@ -58,39 +57,6 @@ type copyFrom struct {
|
|||
readerErrChan chan error
|
||||
}
|
||||
|
||||
func (ct *copyFrom) readUntilReadyForQuery() {
|
||||
for {
|
||||
msg, err := ct.conn.rxMsg()
|
||||
if err != nil {
|
||||
ct.readerErrChan <- err
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
ct.conn.rxReadyForQuery(msg)
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
case *pgproto3.CommandComplete:
|
||||
case *pgproto3.ErrorResponse:
|
||||
ct.readerErrChan <- ct.conn.rxErrorResponse(msg)
|
||||
default:
|
||||
err = ct.conn.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
ct.readerErrChan <- ct.conn.processContextFreeMsg(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ct *copyFrom) waitForReaderDone() error {
|
||||
var err error
|
||||
for err = range ct.readerErrChan {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (ct *copyFrom) run() (int, error) {
|
||||
quotedTableName := ct.tableName.Sanitize()
|
||||
cbuf := &bytes.Buffer{}
|
||||
|
@ -107,163 +73,74 @@ func (ct *copyFrom) run() (int, error) {
|
|||
return 0, err
|
||||
}
|
||||
|
||||
err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
r, w := io.Pipe()
|
||||
|
||||
err = ct.conn.readUntilCopyInResponse()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
go func() {
|
||||
// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
|
||||
buf := ct.conn.wbuf
|
||||
|
||||
panicked := true
|
||||
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
|
||||
go ct.readUntilReadyForQuery()
|
||||
defer ct.waitForReaderDone()
|
||||
defer func() {
|
||||
if panicked {
|
||||
ct.conn.die(errors.New("panic while in copy from"))
|
||||
moreRows := true
|
||||
for moreRows {
|
||||
var err error
|
||||
moreRows, buf, err = ct.buildCopyBuf(buf, ps)
|
||||
if err != nil {
|
||||
w.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
|
||||
if ct.rowSrc.Err() != nil {
|
||||
w.CloseWithError(ct.rowSrc.Err())
|
||||
return
|
||||
}
|
||||
|
||||
if len(buf) > 0 {
|
||||
_, err = w.Write(buf)
|
||||
if err != nil {
|
||||
w.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
buf = buf[:0]
|
||||
}
|
||||
|
||||
w.Close()
|
||||
}()
|
||||
|
||||
buf := ct.conn.wbuf
|
||||
buf = append(buf, copyData)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
commandTag, err := ct.conn.pgConn.CopyFrom(context.TODO(), r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||
|
||||
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
|
||||
var sentCount int
|
||||
|
||||
moreRows := true
|
||||
for moreRows {
|
||||
select {
|
||||
case err = <-ct.readerErrChan:
|
||||
panicked = false
|
||||
return 0, err
|
||||
default:
|
||||
}
|
||||
|
||||
var addedRows int
|
||||
var err error
|
||||
moreRows, buf, addedRows, err = ct.buildCopyBuf(buf, ps)
|
||||
if err != nil {
|
||||
panicked = false
|
||||
ct.cancelCopyIn()
|
||||
return 0, err
|
||||
}
|
||||
sentCount += addedRows
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
_, err = ct.conn.pgConn.Conn().Write(buf)
|
||||
if err != nil {
|
||||
panicked = false
|
||||
ct.conn.die(err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Directly manipulate wbuf to reset to reuse the same buffer
|
||||
buf = buf[0:5]
|
||||
|
||||
}
|
||||
|
||||
if ct.rowSrc.Err() != nil {
|
||||
panicked = false
|
||||
ct.cancelCopyIn()
|
||||
return 0, ct.rowSrc.Err()
|
||||
}
|
||||
|
||||
buf = pgio.AppendInt16(buf, -1) // terminate the copy stream
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
buf = append(buf, copyDone)
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
_, err = ct.conn.pgConn.Conn().Write(buf)
|
||||
if err != nil {
|
||||
panicked = false
|
||||
ct.conn.die(err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = ct.waitForReaderDone()
|
||||
if err != nil {
|
||||
panicked = false
|
||||
return 0, err
|
||||
}
|
||||
|
||||
panicked = false
|
||||
return sentCount, nil
|
||||
return int(commandTag.RowsAffected()), err
|
||||
}
|
||||
|
||||
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, int, error) {
|
||||
var rowCount int
|
||||
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, error) {
|
||||
|
||||
for ct.rowSrc.Next() {
|
||||
values, err := ct.rowSrc.Values()
|
||||
if err != nil {
|
||||
return false, nil, 0, err
|
||||
return false, nil, err
|
||||
}
|
||||
if len(values) != len(ct.columnNames) {
|
||||
return false, nil, 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
||||
return false, nil, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
||||
}
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
||||
for i, val := range values {
|
||||
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
|
||||
if err != nil {
|
||||
return false, nil, 0, err
|
||||
return false, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
rowCount++
|
||||
|
||||
if len(buf) > 65536 {
|
||||
return true, buf, rowCount, nil
|
||||
return true, buf, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, buf, rowCount, nil
|
||||
}
|
||||
|
||||
func (c *Conn) readUntilCopyInResponse() error {
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CopyInResponse:
|
||||
return nil
|
||||
default:
|
||||
err = c.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ct *copyFrom) cancelCopyIn() error {
|
||||
buf := ct.conn.wbuf
|
||||
buf = append(buf, copyFail)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, "client error: abort"...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
_, err := ct.conn.pgConn.Conn().Write(buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return false, buf, nil
|
||||
}
|
||||
|
||||
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
|
||||
|
@ -283,57 +160,3 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF
|
|||
|
||||
return ct.run()
|
||||
}
|
||||
|
||||
// CopyFromReader uses the PostgreSQL textual format of the copy protocol
|
||||
func (c *Conn) CopyFromReader(r io.Reader, sql string) (pgconn.CommandTag, error) {
|
||||
if err := c.sendSimpleQuery(sql); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := c.readUntilCopyInResponse(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
buf := c.wbuf
|
||||
|
||||
buf = append(buf, copyData)
|
||||
sp := len(buf)
|
||||
for {
|
||||
n, err := r.Read(buf[5:cap(buf)])
|
||||
if err == io.EOF && n == 0 {
|
||||
break
|
||||
}
|
||||
buf = buf[0 : n+5]
|
||||
pgio.SetInt32(buf[sp:], int32(n+4))
|
||||
|
||||
if _, err := c.pgConn.Conn().Write(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
buf = buf[:0]
|
||||
buf = append(buf, copyDone)
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
if _, err := c.pgConn.Conn().Write(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
return "", err
|
||||
case *pgproto3.CommandComplete:
|
||||
return pgconn.CommandTag(msg.CommandTag), nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return "", c.rxErrorResponse(msg)
|
||||
default:
|
||||
return "", c.processContextFreeMsg(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,8 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -38,8 +33,8 @@ func TestConnCopyFromSmall(t *testing.T) {
|
|||
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
|
||||
{nil, nil, nil, nil, nil, nil, nil},
|
||||
}
|
||||
inputReader := strings.NewReader("0\t1\t2\tabc\tefg\t2000-01-01\t" + tzedTime.Format(time.RFC3339Nano) + "\n" +
|
||||
"\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
|
||||
// inputReader := strings.NewReader("0\t1\t2\tabc\tefg\t2000-01-01\t" + tzedTime.Format(time.RFC3339Nano) + "\n" +
|
||||
// "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
|
||||
|
||||
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
|
||||
if err != nil {
|
||||
|
@ -73,36 +68,38 @@ func TestConnCopyFromSmall(t *testing.T) {
|
|||
|
||||
mustExec(t, conn, "truncate foo")
|
||||
|
||||
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFromReader: %v", err)
|
||||
}
|
||||
copyCount = int(res.RowsAffected())
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
// TODO
|
||||
|
||||
rows, err = conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
// res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
|
||||
// if err != nil {
|
||||
// t.Errorf("Unexpected error for CopyFromReader: %v", err)
|
||||
// }
|
||||
// copyCount = int(res.RowsAffected())
|
||||
// if copyCount != len(inputRows) {
|
||||
// t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
// }
|
||||
|
||||
outputRows = make([][]interface{}, 0)
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
// rows, err = conn.Query("select * from foo")
|
||||
// if err != nil {
|
||||
// t.Errorf("Unexpected error for Query: %v", err)
|
||||
// }
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
// outputRows = make([][]interface{}, 0)
|
||||
// for rows.Next() {
|
||||
// row, err := rows.Values()
|
||||
// if err != nil {
|
||||
// t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
// }
|
||||
// outputRows = append(outputRows, row)
|
||||
// }
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||
}
|
||||
// if rows.Err() != nil {
|
||||
// t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
// }
|
||||
|
||||
// if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
// t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||
// }
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
@ -166,36 +163,38 @@ func TestConnCopyFromLarge(t *testing.T) {
|
|||
|
||||
mustExec(t, conn, "truncate foo")
|
||||
|
||||
res, err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFromReader: %v", err)
|
||||
}
|
||||
copyCount = int(res.RowsAffected())
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
// TODO
|
||||
|
||||
rows, err = conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
// res, err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin")
|
||||
// if err != nil {
|
||||
// t.Errorf("Unexpected error for CopyFromReader: %v", err)
|
||||
// }
|
||||
// copyCount = int(res.RowsAffected())
|
||||
// if copyCount != len(inputRows) {
|
||||
// t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
// }
|
||||
|
||||
outputRows = make([][]interface{}, 0)
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
// rows, err = conn.Query("select * from foo")
|
||||
// if err != nil {
|
||||
// t.Errorf("Unexpected error for Query: %v", err)
|
||||
// }
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
// outputRows = make([][]interface{}, 0)
|
||||
// for rows.Next() {
|
||||
// row, err := rows.Values()
|
||||
// if err != nil {
|
||||
// t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
// }
|
||||
// outputRows = append(outputRows, row)
|
||||
// }
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal")
|
||||
}
|
||||
// if rows.Err() != nil {
|
||||
// t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
// }
|
||||
|
||||
// if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
// t.Errorf("Input rows and output rows do not equal")
|
||||
// }
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
@ -221,7 +220,7 @@ func TestConnCopyFromJSON(t *testing.T) {
|
|||
{map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}},
|
||||
{nil, nil},
|
||||
}
|
||||
inputReader := strings.NewReader("{\"foo\":\"bar\"}\t{\"bar\":\"quz\"}\n\\N\t\\N\n")
|
||||
// inputReader := strings.NewReader("{\"foo\":\"bar\"}\t{\"bar\":\"quz\"}\n\\N\t\\N\n")
|
||||
|
||||
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
|
||||
if err != nil {
|
||||
|
@ -255,36 +254,38 @@ func TestConnCopyFromJSON(t *testing.T) {
|
|||
|
||||
mustExec(t, conn, "truncate foo")
|
||||
|
||||
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
copyCount = int(res.RowsAffected())
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
// TODO
|
||||
|
||||
rows, err = conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
// res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
|
||||
// if err != nil {
|
||||
// t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
// }
|
||||
// copyCount = int(res.RowsAffected())
|
||||
// if copyCount != len(inputRows) {
|
||||
// t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
// }
|
||||
|
||||
outputRows = make([][]interface{}, 0)
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
// rows, err = conn.Query("select * from foo")
|
||||
// if err != nil {
|
||||
// t.Errorf("Unexpected error for Query: %v", err)
|
||||
// }
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
// outputRows = make([][]interface{}, 0)
|
||||
// for rows.Next() {
|
||||
// row, err := rows.Values()
|
||||
// if err != nil {
|
||||
// t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
// }
|
||||
// outputRows = append(outputRows, row)
|
||||
// }
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||
}
|
||||
// if rows.Err() != nil {
|
||||
// t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
// }
|
||||
|
||||
// if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
// t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||
// }
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
@ -327,7 +328,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
|
|||
{int32(2), nil}, // this row should trigger a failure
|
||||
{int32(3), "def"},
|
||||
}
|
||||
inputReader := strings.NewReader("1\tabc\n2\t\\N\n3\tdef\n")
|
||||
// inputReader := strings.NewReader("1\tabc\n2\t\\N\n3\tdef\n")
|
||||
|
||||
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
|
||||
if err == nil {
|
||||
|
@ -364,39 +365,41 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
|
|||
|
||||
mustExec(t, conn, "truncate foo")
|
||||
|
||||
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyFromReader return error, but it did not")
|
||||
}
|
||||
if _, ok := err.(*pgconn.PgError); !ok {
|
||||
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
|
||||
}
|
||||
copyCount = int(res.RowsAffected())
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
// TODO
|
||||
|
||||
rows, err = conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
// res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
|
||||
// if err == nil {
|
||||
// t.Errorf("Expected CopyFromReader return error, but it did not")
|
||||
// }
|
||||
// if _, ok := err.(*pgconn.PgError); !ok {
|
||||
// t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
|
||||
// }
|
||||
// copyCount = int(res.RowsAffected())
|
||||
// if copyCount != 0 {
|
||||
// t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
|
||||
// }
|
||||
|
||||
outputRows = make([][]interface{}, 0)
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
// rows, err = conn.Query("select * from foo")
|
||||
// if err != nil {
|
||||
// t.Errorf("Unexpected error for Query: %v", err)
|
||||
// }
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
// outputRows = make([][]interface{}, 0)
|
||||
// for rows.Next() {
|
||||
// row, err := rows.Values()
|
||||
// if err != nil {
|
||||
// t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
// }
|
||||
// outputRows = append(outputRows, row)
|
||||
// }
|
||||
|
||||
if len(outputRows) != 0 {
|
||||
t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||
}
|
||||
// if rows.Err() != nil {
|
||||
// t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
// }
|
||||
|
||||
// if len(outputRows) != 0 {
|
||||
// t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||
// }
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
@ -513,7 +516,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
|
|||
}
|
||||
|
||||
if len(outputRows) != 0 {
|
||||
t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||
t.Errorf("Expected 0 rows, but got %v", len(outputRows))
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
|
@ -578,192 +581,3 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
|
|||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
type nextPanicSource struct {
|
||||
}
|
||||
|
||||
func (cfs *nextPanicSource) Next() bool {
|
||||
panic("crash")
|
||||
}
|
||||
|
||||
func (cfs *nextPanicSource) Values() ([]interface{}, error) {
|
||||
return []interface{}{nil}, nil // should never get here
|
||||
}
|
||||
|
||||
func (cfs *nextPanicSource) Err() error {
|
||||
return nil // should never gets here
|
||||
}
|
||||
|
||||
func TestConnCopyFromCopyFromSourceNextPanic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a bytea not null
|
||||
)`)
|
||||
|
||||
caughtPanic := false
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
if x := recover(); x != nil {
|
||||
caughtPanic = true
|
||||
}
|
||||
}()
|
||||
|
||||
conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &nextPanicSource{})
|
||||
}()
|
||||
|
||||
if !caughtPanic {
|
||||
t.Error("expected panic but did not")
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
t.Error("panic should have killed conn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnCopyFromReaderQueryError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
inputReader := strings.NewReader("")
|
||||
|
||||
res, err := conn.CopyFromReader(inputReader, "cropy foo from stdin")
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyFromReader return error, but it did not")
|
||||
}
|
||||
|
||||
if _, ok := err.(*pgconn.PgError); !ok {
|
||||
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyFromReaderNoTableError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
inputReader := strings.NewReader("")
|
||||
|
||||
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyFromReader return error, but it did not")
|
||||
}
|
||||
|
||||
if _, ok := err.(*pgconn.PgError); !ok {
|
||||
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyFromGzipReader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int4,
|
||||
b varchar
|
||||
)`)
|
||||
|
||||
f, err := ioutil.TempFile("", "*")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for ioutil.TempFile: %v", err)
|
||||
}
|
||||
|
||||
gw := gzip.NewWriter(f)
|
||||
|
||||
inputRows := [][]interface{}{}
|
||||
for i := 0; i < 1000; i++ {
|
||||
val := strconv.Itoa(i * i)
|
||||
inputRows = append(inputRows, []interface{}{int32(i), val})
|
||||
_, err = gw.Write([]byte(fmt.Sprintf("%d,\"%s\"\n", i, val)))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for gw.Write: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = gw.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for gw.Close: %v", err)
|
||||
}
|
||||
|
||||
_, err = f.Seek(0, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for f.Seek: %v", err)
|
||||
}
|
||||
|
||||
gr, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for gzip.NewReader: %v", err)
|
||||
}
|
||||
|
||||
res, err := conn.CopyFromReader(gr, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFromReader: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyFromReader to return 1000 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
err = gr.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for gr.Close: %v", err)
|
||||
}
|
||||
|
||||
err = f.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for f.Close: %v", err)
|
||||
}
|
||||
|
||||
err = os.Remove(f.Name())
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for os.Remove: %v", err)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
|
64
copy_to.go
64
copy_to.go
|
@ -1,64 +0,0 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
)
|
||||
|
||||
func (c *Conn) readUntilCopyOutResponse() error {
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CopyOutResponse:
|
||||
return nil
|
||||
default:
|
||||
err = c.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) (pgconn.CommandTag, error) {
|
||||
if err := c.sendSimpleQuery(sql, args...); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := c.readUntilCopyOutResponse(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CopyDone:
|
||||
break
|
||||
case *pgproto3.CopyData:
|
||||
_, err := w.Write(msg.Data)
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
return "", err
|
||||
}
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
return "", nil
|
||||
case *pgproto3.CommandComplete:
|
||||
return pgconn.CommandTag(msg.CommandTag), nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return "", c.rxErrorResponse(msg)
|
||||
default:
|
||||
return "", c.processContextFreeMsg(msg)
|
||||
}
|
||||
}
|
||||
}
|
116
copy_to_test.go
116
copy_to_test.go
|
@ -1,116 +0,0 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
)
|
||||
|
||||
func TestConnCopyToWriterSmall(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int2,
|
||||
b int4,
|
||||
c int8,
|
||||
d varchar,
|
||||
e text,
|
||||
f date,
|
||||
g json
|
||||
)`)
|
||||
mustExec(t, conn, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`)
|
||||
mustExec(t, conn, `insert into foo values (null, null, null, null, null, null, null)`)
|
||||
|
||||
inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" +
|
||||
"\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
|
||||
|
||||
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
|
||||
|
||||
res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyToWriter: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != 2 {
|
||||
t.Errorf("Expected CopyToWriter to return 2 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
|
||||
t.Errorf("Input rows and output rows do not equal:\n%q\n%q", string(inputBytes), string(outputWriter.Bytes()))
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToWriterLarge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int2,
|
||||
b int4,
|
||||
c int8,
|
||||
d varchar,
|
||||
e text,
|
||||
f date,
|
||||
g json,
|
||||
h bytea
|
||||
)`)
|
||||
inputBytes := make([]byte, 0)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
mustExec(t, conn, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`)
|
||||
inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...)
|
||||
}
|
||||
|
||||
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
|
||||
|
||||
res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != 1000 {
|
||||
t.Errorf("Expected CopyToWriter to return 1 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
|
||||
t.Errorf("Input rows and output rows do not equal")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToWriterQueryError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
outputWriter := bytes.NewBuffer(make([]byte, 0))
|
||||
|
||||
res, err := conn.CopyToWriter(outputWriter, "cropy foo to stdout")
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyToWriter return error, but it did not")
|
||||
}
|
||||
|
||||
if _, ok := err.(*pgconn.PgError); !ok {
|
||||
t.Errorf("Expected CopyToWriter return pgx.PgError, but instead it returned: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyToWriter to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
|
@ -398,6 +398,12 @@ func (pgConn *PgConn) hardClose() error {
|
|||
return pgConn.conn.Close()
|
||||
}
|
||||
|
||||
// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect deatch of
|
||||
// underlying connection.
|
||||
func (pgConn *PgConn) IsAlive() bool {
|
||||
return !pgConn.closed
|
||||
}
|
||||
|
||||
// writeAll writes the entire buffer. The connection is hard closed on a partial write or a non-temporary error.
|
||||
func (pgConn *PgConn) writeAll(buf []byte) error {
|
||||
n, err := pgConn.conn.Write(buf)
|
||||
|
|
269
query.go
269
query.go
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/pgx/internal/sanitize"
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
@ -56,6 +57,8 @@ type Rows struct {
|
|||
args []interface{}
|
||||
unlockConn bool
|
||||
closed bool
|
||||
|
||||
resultReader *pgconn.ResultReader
|
||||
}
|
||||
|
||||
func (rows *Rows) FieldDescriptions() []FieldDescription {
|
||||
|
@ -76,7 +79,12 @@ func (rows *Rows) Close() {
|
|||
|
||||
rows.closed = true
|
||||
|
||||
rows.err = rows.conn.termContext(rows.err)
|
||||
if rows.resultReader != nil {
|
||||
_, closeErr := rows.resultReader.Close()
|
||||
if rows.err == nil {
|
||||
rows.err = closeErr
|
||||
}
|
||||
}
|
||||
|
||||
if rows.err == nil {
|
||||
if rows.conn.shouldLog(LogLevelInfo) {
|
||||
|
@ -119,50 +127,21 @@ func (rows *Rows) Next() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
rows.rowCount++
|
||||
rows.columnIdx = 0
|
||||
|
||||
for {
|
||||
msg, err := rows.conn.rxMsg()
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return false
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.RowDescription:
|
||||
rows.fields = rows.conn.rxRowDescription(msg)
|
||||
for i := range rows.fields {
|
||||
if dt, ok := rows.conn.ConnInfo.DataTypeForOID(rows.fields[i].DataType); ok {
|
||||
rows.fields[i].DataTypeName = dt.Name
|
||||
rows.fields[i].FormatCode = TextFormatCode
|
||||
} else {
|
||||
rows.fatal(errors.Errorf("unknown oid: %d", rows.fields[i].DataType))
|
||||
return false
|
||||
}
|
||||
}
|
||||
case *pgproto3.DataRow:
|
||||
if len(msg.Values) != len(rows.fields) {
|
||||
rows.fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), len(msg.Values))))
|
||||
return false
|
||||
}
|
||||
|
||||
rows.values = msg.Values
|
||||
return true
|
||||
case *pgproto3.CommandComplete:
|
||||
if rows.batch != nil {
|
||||
rows.batch.pendingCommandComplete = false
|
||||
}
|
||||
rows.Close()
|
||||
return false
|
||||
|
||||
default:
|
||||
err = rows.conn.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return false
|
||||
if rows.resultReader.NextRow() {
|
||||
if rows.fields == nil {
|
||||
rrFieldDescriptions := rows.resultReader.FieldDescriptions()
|
||||
rows.fields = make([]FieldDescription, len(rrFieldDescriptions))
|
||||
for i := range rrFieldDescriptions {
|
||||
rows.conn.pgproto3FieldDescriptionToPgxFieldDescription(&rrFieldDescriptions[i], &rows.fields[i])
|
||||
}
|
||||
}
|
||||
rows.rowCount++
|
||||
rows.columnIdx = 0
|
||||
rows.values = rows.resultReader.Values()
|
||||
return true
|
||||
} else {
|
||||
rows.Close()
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -181,15 +160,6 @@ func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) {
|
|||
return buf, fd, true
|
||||
}
|
||||
|
||||
type scanArgError struct {
|
||||
col int
|
||||
err error
|
||||
}
|
||||
|
||||
func (e scanArgError) Error() string {
|
||||
return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
|
||||
}
|
||||
|
||||
// Scan reads the values from the current row into dest values positionally.
|
||||
// dest can include pointers to core types, values implementing the Scanner
|
||||
// interface, []byte, and nil. []byte will skip the decoding process and directly
|
||||
|
@ -326,6 +296,15 @@ func (rows *Rows) Values() ([]interface{}, error) {
|
|||
return values, rows.Err()
|
||||
}
|
||||
|
||||
type scanArgError struct {
|
||||
col int
|
||||
err error
|
||||
}
|
||||
|
||||
func (e scanArgError) Error() string {
|
||||
return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
|
||||
}
|
||||
|
||||
// Query executes sql with args. If there is an error the returned *Rows will
|
||||
// be returned in an error state. So it is allowed to ignore the error returned
|
||||
// from Query and handle it in *Rows.
|
||||
|
@ -369,7 +348,14 @@ type QueryExOptions struct {
|
|||
|
||||
func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
|
||||
c.lastStmtSent = false
|
||||
rows = c.getRows(sql, args)
|
||||
// rows = c.getRows(sql, args)
|
||||
|
||||
rows = &Rows{
|
||||
conn: c,
|
||||
startTime: time.Now(),
|
||||
sql: sql,
|
||||
args: args,
|
||||
}
|
||||
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
|
@ -388,85 +374,128 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||
}
|
||||
rows.unlockConn = true
|
||||
|
||||
err = c.initContext(ctx)
|
||||
// err = c.initContext(ctx)
|
||||
// if err != nil {
|
||||
// rows.fatal(err)
|
||||
// return rows, rows.err
|
||||
// }
|
||||
|
||||
// if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
|
||||
// c.lastStmtSent = true
|
||||
// err = c.sanitizeAndSendSimpleQuery(sql, args...)
|
||||
// if err != nil {
|
||||
// rows.fatal(err)
|
||||
// return rows, err
|
||||
// }
|
||||
|
||||
// return rows, nil
|
||||
// }
|
||||
|
||||
// if options != nil && len(options.ParameterOIDs) > 0 {
|
||||
|
||||
// buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args)
|
||||
// if err != nil {
|
||||
// rows.fatal(err)
|
||||
// return rows, err
|
||||
// }
|
||||
|
||||
// buf = appendSync(buf)
|
||||
|
||||
// n, err := c.pgConn.Conn().Write(buf)
|
||||
// c.lastStmtSent = true
|
||||
// if err != nil && fatalWriteErr(n, err) {
|
||||
// rows.fatal(err)
|
||||
// c.die(err)
|
||||
// return rows, err
|
||||
// }
|
||||
// c.pendingReadyForQueryCount++
|
||||
|
||||
// fieldDescriptions, err := c.readUntilRowDescription()
|
||||
// if err != nil {
|
||||
// rows.fatal(err)
|
||||
// return rows, err
|
||||
// }
|
||||
|
||||
// if len(options.ResultFormatCodes) == 0 {
|
||||
// for i := range fieldDescriptions {
|
||||
// fieldDescriptions[i].FormatCode = TextFormatCode
|
||||
// }
|
||||
// } else if len(options.ResultFormatCodes) == 1 {
|
||||
// fc := options.ResultFormatCodes[0]
|
||||
// for i := range fieldDescriptions {
|
||||
// fieldDescriptions[i].FormatCode = fc
|
||||
// }
|
||||
// } else {
|
||||
// for i := range options.ResultFormatCodes {
|
||||
// fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i]
|
||||
// }
|
||||
// }
|
||||
|
||||
// rows.sql = sql
|
||||
// rows.fields = fieldDescriptions
|
||||
// return rows, nil
|
||||
// }
|
||||
|
||||
ps, ok := c.preparedStatements[sql]
|
||||
if !ok {
|
||||
psd, err := c.pgConn.Prepare(ctx, "", sql, nil)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, rows.err
|
||||
}
|
||||
|
||||
if len(psd.ParamOIDs) != len(args) {
|
||||
rows.fatal(errors.Errorf("expected %d arguments, got %d", len(psd.ParamOIDs), len(args)))
|
||||
return rows, rows.err
|
||||
}
|
||||
|
||||
ps = &PreparedStatement{
|
||||
Name: psd.Name,
|
||||
SQL: psd.SQL,
|
||||
ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)),
|
||||
FieldDescriptions: make([]FieldDescription, len(psd.Fields)),
|
||||
}
|
||||
|
||||
for i := range ps.ParameterOIDs {
|
||||
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
||||
}
|
||||
for i := range ps.FieldDescriptions {
|
||||
c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
|
||||
}
|
||||
}
|
||||
rows.sql = ps.SQL
|
||||
|
||||
args, err = convertDriverValuers(args)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, rows.err
|
||||
}
|
||||
|
||||
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
|
||||
c.lastStmtSent = true
|
||||
err = c.sanitizeAndSendSimpleQuery(sql, args...)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
if options != nil && len(options.ParameterOIDs) > 0 {
|
||||
|
||||
buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
buf = appendSync(buf)
|
||||
|
||||
n, err := c.pgConn.Conn().Write(buf)
|
||||
c.lastStmtSent = true
|
||||
if err != nil && fatalWriteErr(n, err) {
|
||||
rows.fatal(err)
|
||||
c.die(err)
|
||||
return rows, err
|
||||
}
|
||||
c.pendingReadyForQueryCount++
|
||||
|
||||
fieldDescriptions, err := c.readUntilRowDescription()
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
if len(options.ResultFormatCodes) == 0 {
|
||||
for i := range fieldDescriptions {
|
||||
fieldDescriptions[i].FormatCode = TextFormatCode
|
||||
}
|
||||
} else if len(options.ResultFormatCodes) == 1 {
|
||||
fc := options.ResultFormatCodes[0]
|
||||
for i := range fieldDescriptions {
|
||||
fieldDescriptions[i].FormatCode = fc
|
||||
}
|
||||
} else {
|
||||
for i := range options.ResultFormatCodes {
|
||||
fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i]
|
||||
}
|
||||
}
|
||||
|
||||
rows.sql = sql
|
||||
rows.fields = fieldDescriptions
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
ps, ok := c.preparedStatements[sql]
|
||||
if !ok {
|
||||
var err error
|
||||
ps, err = c.prepareEx("", sql, nil)
|
||||
paramFormats := make([]int16, len(args))
|
||||
paramValues := make([][]byte, len(args))
|
||||
for i := range args {
|
||||
paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, ps.ParameterOIDs[i], args[i])
|
||||
paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, ps.ParameterOIDs[i], args[i])
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, rows.err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
resultFormats := make([]int16, len(ps.FieldDescriptions))
|
||||
for i := range resultFormats {
|
||||
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
|
||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||
resultFormats[i] = BinaryFormatCode
|
||||
} else {
|
||||
resultFormats[i] = TextFormatCode
|
||||
}
|
||||
}
|
||||
}
|
||||
rows.sql = ps.SQL
|
||||
rows.fields = ps.FieldDescriptions
|
||||
|
||||
c.lastStmtSent = true
|
||||
err = c.sendPreparedQuery(ps, args...)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
}
|
||||
rows.resultReader = c.pgConn.ExecPrepared(ctx, ps.Name, paramValues, paramFormats, resultFormats)
|
||||
|
||||
return rows, rows.err
|
||||
}
|
||||
|
|
|
@ -887,10 +887,10 @@ func TestQueryRowErrors(t *testing.T) {
|
|||
scanArgs []interface{}
|
||||
err string
|
||||
}{
|
||||
{"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`},
|
||||
{"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"},
|
||||
// {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`},
|
||||
// {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"},
|
||||
{"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot decode"},
|
||||
{"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Point"},
|
||||
// {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Point"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
|
|
594
stress_test.go
594
stress_test.go
|
@ -1,361 +1,361 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
// import (
|
||||
// "context"
|
||||
// "fmt"
|
||||
// "math/rand"
|
||||
// "os"
|
||||
// "strconv"
|
||||
// "testing"
|
||||
// "time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
// "github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/fake"
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
)
|
||||
// "github.com/jackc/fake"
|
||||
// "github.com/jackc/pgx"
|
||||
// "github.com/jackc/pgx/pgconn"
|
||||
// )
|
||||
|
||||
type execer interface {
|
||||
Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
|
||||
}
|
||||
type queryer interface {
|
||||
Query(sql string, args ...interface{}) (*pgx.Rows, error)
|
||||
}
|
||||
type queryRower interface {
|
||||
QueryRow(sql string, args ...interface{}) *pgx.Row
|
||||
}
|
||||
// type execer interface {
|
||||
// Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
|
||||
// }
|
||||
// type queryer interface {
|
||||
// Query(sql string, args ...interface{}) (*pgx.Rows, error)
|
||||
// }
|
||||
// type queryRower interface {
|
||||
// QueryRow(sql string, args ...interface{}) *pgx.Row
|
||||
// }
|
||||
|
||||
func TestStressConnPool(t *testing.T) {
|
||||
t.Parallel()
|
||||
// func TestStressConnPool(t *testing.T) {
|
||||
// t.Parallel()
|
||||
|
||||
maxConnections := 8
|
||||
pool := createConnPool(t, maxConnections)
|
||||
defer pool.Close()
|
||||
// maxConnections := 8
|
||||
// pool := createConnPool(t, maxConnections)
|
||||
// defer pool.Close()
|
||||
|
||||
setupStressDB(t, pool)
|
||||
// setupStressDB(t, pool)
|
||||
|
||||
actions := []struct {
|
||||
name string
|
||||
fn func(*pgx.ConnPool, int) error
|
||||
}{
|
||||
{"insertUnprepared", func(p *pgx.ConnPool, n int) error { return insertUnprepared(p, n) }},
|
||||
{"queryRowWithoutParams", func(p *pgx.ConnPool, n int) error { return queryRowWithoutParams(p, n) }},
|
||||
{"query", func(p *pgx.ConnPool, n int) error { return queryCloseEarly(p, n) }},
|
||||
{"queryCloseEarly", func(p *pgx.ConnPool, n int) error { return query(p, n) }},
|
||||
{"queryErrorWhileReturningRows", func(p *pgx.ConnPool, n int) error { return queryErrorWhileReturningRows(p, n) }},
|
||||
{"txInsertRollback", txInsertRollback},
|
||||
{"txInsertCommit", txInsertCommit},
|
||||
{"txMultipleQueries", txMultipleQueries},
|
||||
{"notify", notify},
|
||||
{"listenAndPoolUnlistens", listenAndPoolUnlistens},
|
||||
{"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }},
|
||||
{"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate},
|
||||
{"canceledQueryExContext", canceledQueryExContext},
|
||||
{"canceledExecExContext", canceledExecExContext},
|
||||
}
|
||||
// actions := []struct {
|
||||
// name string
|
||||
// fn func(*pgx.ConnPool, int) error
|
||||
// }{
|
||||
// {"insertUnprepared", func(p *pgx.ConnPool, n int) error { return insertUnprepared(p, n) }},
|
||||
// {"queryRowWithoutParams", func(p *pgx.ConnPool, n int) error { return queryRowWithoutParams(p, n) }},
|
||||
// {"query", func(p *pgx.ConnPool, n int) error { return queryCloseEarly(p, n) }},
|
||||
// {"queryCloseEarly", func(p *pgx.ConnPool, n int) error { return query(p, n) }},
|
||||
// {"queryErrorWhileReturningRows", func(p *pgx.ConnPool, n int) error { return queryErrorWhileReturningRows(p, n) }},
|
||||
// {"txInsertRollback", txInsertRollback},
|
||||
// {"txInsertCommit", txInsertCommit},
|
||||
// {"txMultipleQueries", txMultipleQueries},
|
||||
// {"notify", notify},
|
||||
// {"listenAndPoolUnlistens", listenAndPoolUnlistens},
|
||||
// {"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }},
|
||||
// {"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate},
|
||||
// {"canceledQueryExContext", canceledQueryExContext},
|
||||
// {"canceledExecExContext", canceledExecExContext},
|
||||
// }
|
||||
|
||||
actionCount := 1000
|
||||
if s := os.Getenv("STRESS_FACTOR"); s != "" {
|
||||
stressFactor, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse STRESS_FACTOR: %v", s)
|
||||
}
|
||||
actionCount *= int(stressFactor)
|
||||
}
|
||||
// actionCount := 1000
|
||||
// if s := os.Getenv("STRESS_FACTOR"); s != "" {
|
||||
// stressFactor, err := strconv.ParseInt(s, 10, 64)
|
||||
// if err != nil {
|
||||
// t.Fatalf("failed to parse STRESS_FACTOR: %v", s)
|
||||
// }
|
||||
// actionCount *= int(stressFactor)
|
||||
// }
|
||||
|
||||
workerCount := 16
|
||||
// workerCount := 16
|
||||
|
||||
workChan := make(chan int)
|
||||
doneChan := make(chan struct{})
|
||||
errChan := make(chan error)
|
||||
// workChan := make(chan int)
|
||||
// doneChan := make(chan struct{})
|
||||
// errChan := make(chan error)
|
||||
|
||||
work := func() {
|
||||
for n := range workChan {
|
||||
action := actions[rand.Intn(len(actions))]
|
||||
err := action.fn(pool, n)
|
||||
if err != nil {
|
||||
errChan <- errors.Errorf("%s: %v", action.name, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
doneChan <- struct{}{}
|
||||
}
|
||||
// work := func() {
|
||||
// for n := range workChan {
|
||||
// action := actions[rand.Intn(len(actions))]
|
||||
// err := action.fn(pool, n)
|
||||
// if err != nil {
|
||||
// errChan <- errors.Errorf("%s: %v", action.name, err)
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
// doneChan <- struct{}{}
|
||||
// }
|
||||
|
||||
for i := 0; i < workerCount; i++ {
|
||||
go work()
|
||||
}
|
||||
// for i := 0; i < workerCount; i++ {
|
||||
// go work()
|
||||
// }
|
||||
|
||||
for i := 0; i < actionCount; i++ {
|
||||
select {
|
||||
case workChan <- i:
|
||||
case err := <-errChan:
|
||||
close(workChan)
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
close(workChan)
|
||||
// for i := 0; i < actionCount; i++ {
|
||||
// select {
|
||||
// case workChan <- i:
|
||||
// case err := <-errChan:
|
||||
// close(workChan)
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// }
|
||||
// close(workChan)
|
||||
|
||||
for i := 0; i < workerCount; i++ {
|
||||
<-doneChan
|
||||
}
|
||||
}
|
||||
// for i := 0; i < workerCount; i++ {
|
||||
// <-doneChan
|
||||
// }
|
||||
// }
|
||||
|
||||
func setupStressDB(t *testing.T, pool *pgx.ConnPool) {
|
||||
_, err := pool.Exec(context.Background(), `
|
||||
drop table if exists widgets;
|
||||
create table widgets(
|
||||
id serial primary key,
|
||||
name varchar not null,
|
||||
description text,
|
||||
creation_time timestamptz
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
// func setupStressDB(t *testing.T, pool *pgx.ConnPool) {
|
||||
// _, err := pool.Exec(context.Background(), `
|
||||
// drop table if exists widgets;
|
||||
// create table widgets(
|
||||
// id serial primary key,
|
||||
// name varchar not null,
|
||||
// description text,
|
||||
// creation_time timestamptz
|
||||
// );
|
||||
// `)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// }
|
||||
|
||||
func insertUnprepared(e execer, actionNum int) error {
|
||||
sql := `
|
||||
insert into widgets(name, description, creation_time)
|
||||
values($1, $2, $3)`
|
||||
// func insertUnprepared(e execer, actionNum int) error {
|
||||
// sql := `
|
||||
// insert into widgets(name, description, creation_time)
|
||||
// values($1, $2, $3)`
|
||||
|
||||
_, err := e.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||
return err
|
||||
}
|
||||
// _, err := e.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||
// return err
|
||||
// }
|
||||
|
||||
func queryRowWithoutParams(qr queryRower, actionNum int) error {
|
||||
var id int32
|
||||
var name, description string
|
||||
var creationTime time.Time
|
||||
// func queryRowWithoutParams(qr queryRower, actionNum int) error {
|
||||
// var id int32
|
||||
// var name, description string
|
||||
// var creationTime time.Time
|
||||
|
||||
sql := `select * from widgets order by random() limit 1`
|
||||
// sql := `select * from widgets order by random() limit 1`
|
||||
|
||||
err := qr.QueryRow(sql).Scan(&id, &name, &description, &creationTime)
|
||||
if err == pgx.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
// err := qr.QueryRow(sql).Scan(&id, &name, &description, &creationTime)
|
||||
// if err == pgx.ErrNoRows {
|
||||
// return nil
|
||||
// }
|
||||
// return err
|
||||
// }
|
||||
|
||||
func query(q queryer, actionNum int) error {
|
||||
sql := `select * from widgets order by random() limit $1`
|
||||
// func query(q queryer, actionNum int) error {
|
||||
// sql := `select * from widgets order by random() limit $1`
|
||||
|
||||
rows, err := q.Query(sql, 10)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
// rows, err := q.Query(sql, 10)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var id int32
|
||||
var name, description string
|
||||
var creationTime time.Time
|
||||
rows.Scan(&id, &name, &description, &creationTime)
|
||||
}
|
||||
// for rows.Next() {
|
||||
// var id int32
|
||||
// var name, description string
|
||||
// var creationTime time.Time
|
||||
// rows.Scan(&id, &name, &description, &creationTime)
|
||||
// }
|
||||
|
||||
return rows.Err()
|
||||
}
|
||||
// return rows.Err()
|
||||
// }
|
||||
|
||||
func queryCloseEarly(q queryer, actionNum int) error {
|
||||
sql := `select * from generate_series(1,$1)`
|
||||
// func queryCloseEarly(q queryer, actionNum int) error {
|
||||
// sql := `select * from generate_series(1,$1)`
|
||||
|
||||
rows, err := q.Query(sql, 100)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
// rows, err := q.Query(sql, 100)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer rows.Close()
|
||||
|
||||
for i := 0; i < 10 && rows.Next(); i++ {
|
||||
var n int32
|
||||
rows.Scan(&n)
|
||||
}
|
||||
rows.Close()
|
||||
// for i := 0; i < 10 && rows.Next(); i++ {
|
||||
// var n int32
|
||||
// rows.Scan(&n)
|
||||
// }
|
||||
// rows.Close()
|
||||
|
||||
return rows.Err()
|
||||
}
|
||||
// return rows.Err()
|
||||
// }
|
||||
|
||||
func queryErrorWhileReturningRows(q queryer, actionNum int) error {
|
||||
// This query should divide by 0 within the first number of rows
|
||||
sql := `select 42 / (random() * 20)::integer from generate_series(1,100000)`
|
||||
// func queryErrorWhileReturningRows(q queryer, actionNum int) error {
|
||||
// // This query should divide by 0 within the first number of rows
|
||||
// sql := `select 42 / (random() * 20)::integer from generate_series(1,100000)`
|
||||
|
||||
rows, err := q.Query(sql)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
// rows, err := q.Query(sql)
|
||||
// if err != nil {
|
||||
// return nil
|
||||
// }
|
||||
// defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var n int32
|
||||
rows.Scan(&n)
|
||||
}
|
||||
// for rows.Next() {
|
||||
// var n int32
|
||||
// rows.Scan(&n)
|
||||
// }
|
||||
|
||||
if _, ok := rows.Err().(*pgconn.PgError); ok {
|
||||
return nil
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
// if _, ok := rows.Err().(*pgconn.PgError); ok {
|
||||
// return nil
|
||||
// }
|
||||
// return rows.Err()
|
||||
// }
|
||||
|
||||
func notify(pool *pgx.ConnPool, actionNum int) error {
|
||||
_, err := pool.Exec(context.Background(), "notify stress")
|
||||
return err
|
||||
}
|
||||
// func notify(pool *pgx.ConnPool, actionNum int) error {
|
||||
// _, err := pool.Exec(context.Background(), "notify stress")
|
||||
// return err
|
||||
// }
|
||||
|
||||
func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error {
|
||||
conn, err := pool.Acquire()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer pool.Release(conn)
|
||||
// func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error {
|
||||
// conn, err := pool.Acquire()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer pool.Release(conn)
|
||||
|
||||
err = conn.Listen("stress")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// err = conn.Listen("stress")
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_, err = conn.WaitForNotification(ctx)
|
||||
if err == context.DeadlineExceeded {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
// defer cancel()
|
||||
// _, err = conn.WaitForNotification(ctx)
|
||||
// if err == context.DeadlineExceeded {
|
||||
// return nil
|
||||
// }
|
||||
// return err
|
||||
// }
|
||||
|
||||
func poolPrepareUseAndDeallocate(pool *pgx.ConnPool, actionNum int) error {
|
||||
psName := fmt.Sprintf("poolPreparedStatement%d", actionNum)
|
||||
// func poolPrepareUseAndDeallocate(pool *pgx.ConnPool, actionNum int) error {
|
||||
// psName := fmt.Sprintf("poolPreparedStatement%d", actionNum)
|
||||
|
||||
_, err := pool.Prepare(psName, "select $1::text")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// _, err := pool.Prepare(psName, "select $1::text")
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
var s string
|
||||
err = pool.QueryRow(psName, "hello").Scan(&s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// var s string
|
||||
// err = pool.QueryRow(psName, "hello").Scan(&s)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
if s != "hello" {
|
||||
return errors.Errorf("Prepared statement did not return expected value: %v", s)
|
||||
}
|
||||
// if s != "hello" {
|
||||
// return errors.Errorf("Prepared statement did not return expected value: %v", s)
|
||||
// }
|
||||
|
||||
return pool.Deallocate(psName)
|
||||
}
|
||||
// return pool.Deallocate(psName)
|
||||
// }
|
||||
|
||||
func txInsertRollback(pool *pgx.ConnPool, actionNum int) error {
|
||||
tx, err := pool.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// func txInsertRollback(pool *pgx.ConnPool, actionNum int) error {
|
||||
// tx, err := pool.Begin()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
sql := `
|
||||
insert into widgets(name, description, creation_time)
|
||||
values($1, $2, $3)`
|
||||
// sql := `
|
||||
// insert into widgets(name, description, creation_time)
|
||||
// values($1, $2, $3)`
|
||||
|
||||
_, err = tx.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// _, err = tx.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
return tx.Rollback()
|
||||
}
|
||||
// return tx.Rollback()
|
||||
// }
|
||||
|
||||
func txInsertCommit(pool *pgx.ConnPool, actionNum int) error {
|
||||
tx, err := pool.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// func txInsertCommit(pool *pgx.ConnPool, actionNum int) error {
|
||||
// tx, err := pool.Begin()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
sql := `
|
||||
insert into widgets(name, description, creation_time)
|
||||
values($1, $2, $3)`
|
||||
// sql := `
|
||||
// insert into widgets(name, description, creation_time)
|
||||
// values($1, $2, $3)`
|
||||
|
||||
_, err = tx.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
// _, err = tx.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||
// if err != nil {
|
||||
// tx.Rollback()
|
||||
// return err
|
||||
// }
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
// return tx.Commit()
|
||||
// }
|
||||
|
||||
func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error {
|
||||
tx, err := pool.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error {
|
||||
// tx, err := pool.Begin()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer tx.Rollback()
|
||||
|
||||
errExpectedTxDeath := errors.New("Expected tx death")
|
||||
// errExpectedTxDeath := errors.New("Expected tx death")
|
||||
|
||||
actions := []struct {
|
||||
name string
|
||||
fn func() error
|
||||
}{
|
||||
{"insertUnprepared", func() error { return insertUnprepared(tx, actionNum) }},
|
||||
{"queryRowWithoutParams", func() error { return queryRowWithoutParams(tx, actionNum) }},
|
||||
{"query", func() error { return query(tx, actionNum) }},
|
||||
{"queryCloseEarly", func() error { return queryCloseEarly(tx, actionNum) }},
|
||||
{"queryErrorWhileReturningRows", func() error {
|
||||
err := queryErrorWhileReturningRows(tx, actionNum)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errExpectedTxDeath
|
||||
}},
|
||||
}
|
||||
// actions := []struct {
|
||||
// name string
|
||||
// fn func() error
|
||||
// }{
|
||||
// {"insertUnprepared", func() error { return insertUnprepared(tx, actionNum) }},
|
||||
// {"queryRowWithoutParams", func() error { return queryRowWithoutParams(tx, actionNum) }},
|
||||
// {"query", func() error { return query(tx, actionNum) }},
|
||||
// {"queryCloseEarly", func() error { return queryCloseEarly(tx, actionNum) }},
|
||||
// {"queryErrorWhileReturningRows", func() error {
|
||||
// err := queryErrorWhileReturningRows(tx, actionNum)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// return errExpectedTxDeath
|
||||
// }},
|
||||
// }
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
action := actions[rand.Intn(len(actions))]
|
||||
err := action.fn()
|
||||
if err == errExpectedTxDeath {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// for i := 0; i < 20; i++ {
|
||||
// action := actions[rand.Intn(len(actions))]
|
||||
// err := action.fn()
|
||||
// if err == errExpectedTxDeath {
|
||||
// return nil
|
||||
// } else if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
// return tx.Commit()
|
||||
// }
|
||||
|
||||
func canceledQueryExContext(pool *pgx.ConnPool, actionNum int) error {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
|
||||
cancelFunc()
|
||||
}()
|
||||
// func canceledQueryExContext(pool *pgx.ConnPool, actionNum int) error {
|
||||
// ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
// go func() {
|
||||
// time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
|
||||
// cancelFunc()
|
||||
// }()
|
||||
|
||||
rows, err := pool.QueryEx(ctx, "select pg_sleep(2)", nil)
|
||||
if err == context.Canceled {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return errors.Errorf("Only allowed error is context.Canceled, got %v", err)
|
||||
}
|
||||
// rows, err := pool.QueryEx(ctx, "select pg_sleep(2)", nil)
|
||||
// if err == context.Canceled {
|
||||
// return nil
|
||||
// } else if err != nil {
|
||||
// return errors.Errorf("Only allowed error is context.Canceled, got %v", err)
|
||||
// }
|
||||
|
||||
for rows.Next() {
|
||||
return errors.New("should never receive row")
|
||||
}
|
||||
// for rows.Next() {
|
||||
// return errors.New("should never receive row")
|
||||
// }
|
||||
|
||||
if rows.Err() != context.Canceled {
|
||||
return errors.Errorf("Expected context.Canceled error, got %v", rows.Err())
|
||||
}
|
||||
// if rows.Err() != context.Canceled {
|
||||
// return errors.Errorf("Expected context.Canceled error, got %v", rows.Err())
|
||||
// }
|
||||
|
||||
return nil
|
||||
}
|
||||
// return nil
|
||||
// }
|
||||
|
||||
func canceledExecExContext(pool *pgx.ConnPool, actionNum int) error {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
|
||||
cancelFunc()
|
||||
}()
|
||||
// func canceledExecExContext(pool *pgx.ConnPool, actionNum int) error {
|
||||
// ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
// go func() {
|
||||
// time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
|
||||
// cancelFunc()
|
||||
// }()
|
||||
|
||||
_, err := pool.Exec(ctx, "select pg_sleep(2)")
|
||||
if err != context.Canceled {
|
||||
return errors.Errorf("Expected context.Canceled error, got %v", err)
|
||||
}
|
||||
// _, err := pool.Exec(ctx, "select pg_sleep(2)")
|
||||
// if err != context.Canceled {
|
||||
// return errors.Errorf("Expected context.Canceled error, got %v", err)
|
||||
// }
|
||||
|
||||
return nil
|
||||
}
|
||||
// return nil
|
||||
// }
|
||||
|
|
19
tx.go
19
tx.go
|
@ -4,7 +4,6 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
|
@ -231,24 +230,6 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr
|
|||
return tx.conn.CopyFrom(tableName, columnNames, rowSrc)
|
||||
}
|
||||
|
||||
// CopyFromReader delegates to the underlying *Conn
|
||||
func (tx *Tx) CopyFromReader(r io.Reader, sql string) (commandTag pgconn.CommandTag, err error) {
|
||||
if tx.status != TxStatusInProgress {
|
||||
return "", ErrTxClosed
|
||||
}
|
||||
|
||||
return tx.conn.CopyFromReader(r, sql)
|
||||
}
|
||||
|
||||
// CopyToWriter delegates to the underlying *Conn
|
||||
func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
if tx.status != TxStatusInProgress {
|
||||
return "", ErrTxClosed
|
||||
}
|
||||
|
||||
return tx.conn.CopyToWriter(w, sql, args...)
|
||||
}
|
||||
|
||||
// Status returns the status of the transaction from the set of
|
||||
// pgx.TxStatus* constants.
|
||||
func (tx *Tx) Status() int8 {
|
||||
|
|
31
v4.md
31
v4.md
|
@ -28,10 +28,6 @@ Potential Changes:
|
|||
* Consider strongly typed query parameters in style of zerolog (chained functions instead of varargs)
|
||||
* Consider buffered query select where entire result set is received and parsed successfully or call returns error
|
||||
|
||||
Minor Potential Changes:
|
||||
|
||||
* Change PgError error implementation to pointer method
|
||||
|
||||
## Changes
|
||||
|
||||
* `pgconn.PgConn` now contains core PostgreSQL connection functionality.
|
||||
|
@ -43,7 +39,34 @@ Minor Potential Changes:
|
|||
* Connect method now takes context and connection string.
|
||||
* ConnectConfig takes context and config object.
|
||||
* `RuntimeParams` `pgx.Conn`. Server reported status can now be queried with the `ParameterStatus` method. The rename aligns with the PostgreSQL protocol and standard libpq naming. Access via a method instead of direct access to the map protects against outside modification.
|
||||
* LISTEN / NOTIFY functionality moved to pgconn.
|
||||
* COPY TO functionality moved to pgconn.
|
||||
* COPY FROM functionality moved to pgconn.
|
||||
|
||||
## New Features
|
||||
|
||||
* Specifying multiple hosts for connecting to HA systems.
|
||||
|
||||
|
||||
## Transaction idea
|
||||
|
||||
Problem: Using original connection or pool outside of tx object
|
||||
|
||||
|
||||
|
||||
tx = pool.Begin()
|
||||
tx.Query(...)
|
||||
pool.Query(...) // <- Possible to accidentally do stuff outside of tx
|
||||
|
||||
Solution: Common interface for basic queries and atomicity
|
||||
|
||||
var querier Querier
|
||||
querier = pool
|
||||
|
||||
querier = querier.Begin() <- tx implements querier
|
||||
querier.Query(...)
|
||||
querier = querier.Commit()
|
||||
|
||||
-- tx implements begin, commit, and rollback as save points
|
||||
-- conn implements begin as create tx (what about commit and rollback? No-op?)
|
||||
-- pool implements begin as?
|
||||
|
|
Loading…
Reference in New Issue