Implement pgx.Conn.Exec in terms of pgconn.PgConn.Exec

pull/483/head
Jack Christensen 2019-01-02 12:32:36 -06:00
parent 23cbe89dfd
commit 12857ad05b
9 changed files with 194 additions and 31 deletions

179
conn.go
View File

@ -2,6 +2,8 @@ package pgx
import (
"context"
"database/sql/driver"
"fmt"
"net"
"reflect"
"strings"
@ -673,7 +675,7 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
// Listen establishes a PostgreSQL listen/notify to channel
func (c *Conn) Listen(channel string) error {
_, err := c.Exec("listen " + quoteIdentifier(channel))
_, err := c.Exec(context.TODO(), "listen "+quoteIdentifier(channel))
if err != nil {
return err
}
@ -685,7 +687,7 @@ func (c *Conn) Listen(channel string) error {
// Unlisten unsubscribes from a listen channel
func (c *Conn) Unlisten(channel string) error {
_, err := c.Exec("unlisten " + quoteIdentifier(channel))
_, err := c.Exec(context.TODO(), "unlisten "+quoteIdentifier(channel))
if err != nil {
return err
}
@ -837,12 +839,6 @@ func fatalWriteErr(bytesWritten int, err error) bool {
return !(is && netErr.Timeout())
}
// Exec executes sql. sql can be either a prepared statement name or an SQL string.
// arguments should be referenced positionally from the sql string as $1, $2, etc.
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
return c.ExecEx(context.Background(), sql, nil, arguments...)
}
// Processes messages that are not exclusive to one context such as
// authentication or query response. The response to these messages is the same
// regardless of when they occur. It also ignores messages that are only
@ -1351,3 +1347,170 @@ func connInfoFromRows(rows *Rows, err error) (map[string]pgtype.OID, error) {
func (c *Conn) LastStmtSent() bool {
return c.lastStmtSent
}
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
// positionally from the sql string as $1, $2, etc.
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) {
c.lastStmtSent = false
err := c.waitForPreviousCancelQuery(ctx)
if err != nil {
return nil, err
}
if err := c.lock(); err != nil {
return nil, err
}
defer c.unlock()
if err := c.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
startTime := time.Now()
commandTag, err := c.exec(ctx, sql, arguments...)
if err != nil {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
}
return commandTag, err
}
if c.shouldLog(LogLevelInfo) {
endTime := time.Now()
c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
}
return commandTag, err
}
func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
if len(arguments) == 0 {
c.lastStmtSent = true
result, err := c.pgConn.Exec(ctx, sql)
if err != nil {
return nil, err
}
return result.CommandTag, nil
} else {
psd, err := c.pgConn.Prepare(ctx, "", sql, nil)
if err != nil {
return nil, err
}
ps := &PreparedStatement{
Name: psd.Name,
SQL: psd.SQL,
ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)),
FieldDescriptions: make([]FieldDescription, len(psd.Fields)),
}
for i := range ps.ParameterOIDs {
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
}
for i := range ps.FieldDescriptions {
c.pgconnFieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
}
arguments, err = convertDriverValuers(arguments)
if err != nil {
return nil, err
}
paramFormats := make([]int16, len(arguments))
paramValues := make([][]byte, len(arguments))
for i := range arguments {
paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, ps.ParameterOIDs[i], arguments[i])
paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, ps.ParameterOIDs[i], arguments[i])
if err != nil {
return nil, err
}
}
resultFormats := make([]int16, len(ps.FieldDescriptions))
for i := range resultFormats {
resultFormats[i] = ps.FieldDescriptions[i].FormatCode
}
c.lastStmtSent = true
result, err := c.pgConn.ExecPrepared(ctx, psd.Name, paramValues, paramFormats, resultFormats)
if err != nil {
return nil, err
}
return result.CommandTag, nil
}
}
func newencodePreparedStatementArgument(ci *pgtype.ConnInfo, oid pgtype.OID, arg interface{}) ([]byte, error) {
if arg == nil {
return nil, nil
}
switch arg := arg.(type) {
case pgtype.BinaryEncoder:
return arg.EncodeBinary(ci, nil)
case pgtype.TextEncoder:
return arg.EncodeText(ci, nil)
case string:
return []byte(arg), nil
}
refVal := reflect.ValueOf(arg)
if refVal.Kind() == reflect.Ptr {
if refVal.IsNil() {
return nil, nil
}
arg = refVal.Elem().Interface()
return newencodePreparedStatementArgument(ci, oid, arg)
}
if dt, ok := ci.DataTypeForOID(oid); ok {
value := dt.Value
err := value.Set(arg)
if err != nil {
{
if arg, ok := arg.(driver.Valuer); ok {
v, err := callValuerValue(arg)
if err != nil {
return nil, err
}
return newencodePreparedStatementArgument(ci, oid, v)
}
}
return nil, err
}
return value.(pgtype.BinaryEncoder).EncodeBinary(ci, nil)
}
if strippedArg, ok := stripNamedType(&refVal); ok {
return newencodePreparedStatementArgument(ci, oid, strippedArg)
}
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
}
// pgconnFieldDescriptionToPgxFieldDescription copies and converts the data from a pgproto3.FieldDescription to a
// FieldDescription.
func (c *Conn) pgconnFieldDescriptionToPgxFieldDescription(src *pgconn.FieldDescription, dst *FieldDescription) {
dst.Name = src.Name
dst.Table = pgtype.OID(src.TableOID)
dst.AttributeNumber = src.TableAttributeNumber
dst.DataType = pgtype.OID(src.DataTypeOID)
dst.DataTypeSize = src.DataTypeSize
dst.Modifier = src.TypeModifier
if dt, ok := c.ConnInfo.DataTypeForOID(dst.DataType); ok {
dst.DataTypeName = dt.Name
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
dst.FormatCode = BinaryFormatCode
} else {
dst.FormatCode = TextFormatCode
}
}
}

View File

@ -195,7 +195,7 @@ func (p *ConnPool) Release(conn *Conn) {
}
if conn.pgConn.TxStatus != 'I' {
conn.Exec("rollback")
conn.Exec(context.TODO(), "rollback")
}
if len(conn.channels) > 0 {
@ -360,7 +360,7 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag pgconn
}
defer p.Release(c)
return c.Exec(sql, arguments...)
return c.Exec(context.TODO(), sql, arguments...)
}
func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {

View File

@ -345,7 +345,7 @@ func TestPoolReleaseWithTransactions(t *testing.T) {
t.Fatalf("Unable to acquire connection: %v", err)
}
mustExec(t, conn, "begin")
if _, err = conn.Exec("selct"); err == nil {
if _, err = conn.Exec(context.Background(), "selct"); err == nil {
t.Fatal("Did not receive expected error")
}
@ -449,7 +449,7 @@ func TestPoolReleaseDiscardsDeadConnections(t *testing.T) {
}
}()
if _, err = c2.Exec("select pg_terminate_backend($1)", c1.PID()); err != nil {
if _, err = c2.Exec(context.Background(), "select pg_terminate_backend($1)", c1.PID()); err != nil {
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
}
@ -695,7 +695,7 @@ func TestConnPoolBeginRetry(t *testing.T) {
pool.Release(victimConn)
// Terminate connection that was released to pool
if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.PID()); err != nil {
if _, err = killerConn.Exec(context.Background(), "select pg_terminate_backend($1)", victimConn.PID()); err != nil {
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
}

View File

@ -146,7 +146,7 @@ func TestExecFailure(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
if _, err := conn.Exec("selct;"); err == nil {
if _, err := conn.Exec(context.Background(), "selct;"); err == nil {
t.Fatal("Expected SQL syntax error")
}
if !conn.LastStmtSent() {
@ -169,7 +169,7 @@ func TestExecFailureWithArguments(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
if _, err := conn.Exec("selct $1;", 1); err == nil {
if _, err := conn.Exec(context.Background(), "selct $1;", 1); err == nil {
t.Fatal("Expected SQL syntax error")
}
if conn.LastStmtSent() {
@ -270,7 +270,7 @@ func TestExecFailureCloseBefore(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
closeConn(t, conn)
if _, err := conn.Exec("select 1"); err == nil {
if _, err := conn.Exec(context.Background(), "select 1"); err == nil {
t.Fatal("Expected network error")
}
if conn.LastStmtSent() {
@ -884,7 +884,7 @@ func TestFatalRxError(t *testing.T) {
otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer otherConn.Close()
if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.PID()); err != nil {
if _, err := otherConn.Exec(context.Background(), "select pg_terminate_backend($1)", conn.PID()); err != nil {
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
}
@ -907,7 +907,7 @@ func TestFatalTxError(t *testing.T) {
otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer otherConn.Close()
_, err := otherConn.Exec("select pg_terminate_backend($1)", conn.PID())
_, err := otherConn.Exec(context.Background(), "select pg_terminate_backend($1)", conn.PID())
if err != nil {
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
}
@ -986,7 +986,7 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) {
}
defer rows.Close()
_, err = conn.Exec("create temporary table foo(spice timestamp[])")
_, err = conn.Exec(context.Background(), "create temporary table foo(spice timestamp[])")
if err != pgx.ErrConnBusy {
t.Fatalf("conn.Exec should have failed with pgx.ErrConnBusy, but it was %v", err)
}
@ -1128,7 +1128,7 @@ func TestConnOnNotice(t *testing.T) {
conn := mustConnect(t, connConfig)
defer closeConn(t, conn)
_, err := conn.Exec(`do $$
_, err := conn.Exec(context.Background(), `do $$
begin
raise notice 'hello, world';
end$$;`)

View File

@ -55,7 +55,7 @@ func closeReplicationConn(t testing.TB, conn *pgx.ReplicationConn) {
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag) {
var err error
if commandTag, err = conn.Exec(sql, arguments...); err != nil {
if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil {
t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)
}
return

View File

@ -1083,7 +1083,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes
mustExec(t, conn, "create temporary table t(n numeric)")
var d *apd.Decimal
commandTag, err := conn.Exec(`insert into t(n) values($1)`, d)
commandTag, err := conn.Exec(context.Background(), `insert into t(n) values($1)`, d)
if err != nil {
t.Fatal(err)
}

View File

@ -439,7 +439,7 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
// Create the replication slot, using the given name and output plugin.
func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) {
_, err = rc.c.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin))
_, err = rc.c.Exec(context.TODO(), fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin))
return
}
@ -457,6 +457,6 @@ func (rc *ReplicationConn) CreateReplicationSlotEx(slotName, outputPlugin string
// Drop the replication slot for the given name
func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) {
_, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName))
_, err = rc.c.Exec(context.TODO(), fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName))
return
}

View File

@ -52,7 +52,7 @@ func TestSimpleReplicationConnection(t *testing.T) {
defer func() {
// Ensure replication slot is destroyed, but don't check for errors as it
// should have already been destroyed.
conn.Exec("select pg_drop_replication_slot('pgx_test')")
conn.Exec(context.Background(), "select pg_drop_replication_slot('pgx_test')")
closeConn(t, conn)
}()
@ -74,7 +74,7 @@ func TestSimpleReplicationConnection(t *testing.T) {
}
// Do a simple change so we can get some wal data
_, err = conn.Exec("create table if not exists replication_test (a integer)")
_, err = conn.Exec(context.Background(), "create table if not exists replication_test (a integer)")
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
@ -90,7 +90,7 @@ func TestSimpleReplicationConnection(t *testing.T) {
for i := 0; i < 5; i++ {
var ct pgconn.CommandTag
insertedTimes = append(insertedTimes, currentTime)
ct, err = conn.Exec("insert into replication_test(a) values($1)", currentTime)
ct, err = conn.Exec(context.Background(), "insert into replication_test(a) values($1)", currentTime)
if err != nil {
t.Fatalf("Insert failed: %v", err)
}

View File

@ -26,7 +26,7 @@ func TestTransactionSuccessfulCommit(t *testing.T) {
);
`
if _, err := conn.Exec(createSql); err != nil {
if _, err := conn.Exec(context.Background(), createSql); err != nil {
t.Fatalf("Failed to create table: %v", err)
}
@ -68,7 +68,7 @@ func TestTxCommitWhenTxBroken(t *testing.T) {
);
`
if _, err := conn.Exec(createSql); err != nil {
if _, err := conn.Exec(context.Background(), createSql); err != nil {
t.Fatalf("Failed to create table: %v", err)
}
@ -160,7 +160,7 @@ func TestTransactionSuccessfulRollback(t *testing.T) {
);
`
if _, err := conn.Exec(createSql); err != nil {
if _, err := conn.Exec(context.Background(), createSql); err != nil {
t.Fatalf("Failed to create table: %v", err)
}
@ -227,7 +227,7 @@ func TestBeginExReadOnly(t *testing.T) {
}
defer tx.Rollback()
_, err = conn.Exec("create table foo(id serial primary key)")
_, err = conn.Exec(context.Background(), "create table foo(id serial primary key)")
if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "25006" {
t.Errorf("Expected error SQLSTATE 25006, but got %#v", err)
}