mirror of https://github.com/jackc/pgx.git
Implement pgx.Conn.Exec in terms of pgconn.PgConn.Exec
parent
23cbe89dfd
commit
12857ad05b
179
conn.go
179
conn.go
|
@ -2,6 +2,8 @@ package pgx
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
@ -673,7 +675,7 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
|
|||
|
||||
// Listen establishes a PostgreSQL listen/notify to channel
|
||||
func (c *Conn) Listen(channel string) error {
|
||||
_, err := c.Exec("listen " + quoteIdentifier(channel))
|
||||
_, err := c.Exec(context.TODO(), "listen "+quoteIdentifier(channel))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -685,7 +687,7 @@ func (c *Conn) Listen(channel string) error {
|
|||
|
||||
// Unlisten unsubscribes from a listen channel
|
||||
func (c *Conn) Unlisten(channel string) error {
|
||||
_, err := c.Exec("unlisten " + quoteIdentifier(channel))
|
||||
_, err := c.Exec(context.TODO(), "unlisten "+quoteIdentifier(channel))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -837,12 +839,6 @@ func fatalWriteErr(bytesWritten int, err error) bool {
|
|||
return !(is && netErr.Timeout())
|
||||
}
|
||||
|
||||
// Exec executes sql. sql can be either a prepared statement name or an SQL string.
|
||||
// arguments should be referenced positionally from the sql string as $1, $2, etc.
|
||||
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
return c.ExecEx(context.Background(), sql, nil, arguments...)
|
||||
}
|
||||
|
||||
// Processes messages that are not exclusive to one context such as
|
||||
// authentication or query response. The response to these messages is the same
|
||||
// regardless of when they occur. It also ignores messages that are only
|
||||
|
@ -1351,3 +1347,170 @@ func connInfoFromRows(rows *Rows, err error) (map[string]pgtype.OID, error) {
|
|||
func (c *Conn) LastStmtSent() bool {
|
||||
return c.lastStmtSent
|
||||
}
|
||||
|
||||
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
|
||||
// positionally from the sql string as $1, $2, etc.
|
||||
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) {
|
||||
c.lastStmtSent = false
|
||||
err := c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.lock(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer c.unlock()
|
||||
|
||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
commandTag, err := c.exec(ctx, sql, arguments...)
|
||||
if err != nil {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
|
||||
}
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
endTime := time.Now()
|
||||
c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
|
||||
}
|
||||
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
if len(arguments) == 0 {
|
||||
c.lastStmtSent = true
|
||||
result, err := c.pgConn.Exec(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result.CommandTag, nil
|
||||
} else {
|
||||
psd, err := c.pgConn.Prepare(ctx, "", sql, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ps := &PreparedStatement{
|
||||
Name: psd.Name,
|
||||
SQL: psd.SQL,
|
||||
ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)),
|
||||
FieldDescriptions: make([]FieldDescription, len(psd.Fields)),
|
||||
}
|
||||
|
||||
for i := range ps.ParameterOIDs {
|
||||
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
||||
}
|
||||
for i := range ps.FieldDescriptions {
|
||||
c.pgconnFieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
|
||||
}
|
||||
|
||||
arguments, err = convertDriverValuers(arguments)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramFormats := make([]int16, len(arguments))
|
||||
paramValues := make([][]byte, len(arguments))
|
||||
for i := range arguments {
|
||||
paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, ps.ParameterOIDs[i], arguments[i])
|
||||
paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, ps.ParameterOIDs[i], arguments[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
resultFormats := make([]int16, len(ps.FieldDescriptions))
|
||||
for i := range resultFormats {
|
||||
resultFormats[i] = ps.FieldDescriptions[i].FormatCode
|
||||
}
|
||||
|
||||
c.lastStmtSent = true
|
||||
result, err := c.pgConn.ExecPrepared(ctx, psd.Name, paramValues, paramFormats, resultFormats)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result.CommandTag, nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func newencodePreparedStatementArgument(ci *pgtype.ConnInfo, oid pgtype.OID, arg interface{}) ([]byte, error) {
|
||||
if arg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch arg := arg.(type) {
|
||||
case pgtype.BinaryEncoder:
|
||||
return arg.EncodeBinary(ci, nil)
|
||||
case pgtype.TextEncoder:
|
||||
return arg.EncodeText(ci, nil)
|
||||
case string:
|
||||
return []byte(arg), nil
|
||||
}
|
||||
|
||||
refVal := reflect.ValueOf(arg)
|
||||
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
if refVal.IsNil() {
|
||||
return nil, nil
|
||||
}
|
||||
arg = refVal.Elem().Interface()
|
||||
return newencodePreparedStatementArgument(ci, oid, arg)
|
||||
}
|
||||
|
||||
if dt, ok := ci.DataTypeForOID(oid); ok {
|
||||
value := dt.Value
|
||||
err := value.Set(arg)
|
||||
if err != nil {
|
||||
{
|
||||
if arg, ok := arg.(driver.Valuer); ok {
|
||||
v, err := callValuerValue(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newencodePreparedStatementArgument(ci, oid, v)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return value.(pgtype.BinaryEncoder).EncodeBinary(ci, nil)
|
||||
}
|
||||
|
||||
if strippedArg, ok := stripNamedType(&refVal); ok {
|
||||
return newencodePreparedStatementArgument(ci, oid, strippedArg)
|
||||
}
|
||||
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
|
||||
}
|
||||
|
||||
// pgconnFieldDescriptionToPgxFieldDescription copies and converts the data from a pgproto3.FieldDescription to a
|
||||
// FieldDescription.
|
||||
func (c *Conn) pgconnFieldDescriptionToPgxFieldDescription(src *pgconn.FieldDescription, dst *FieldDescription) {
|
||||
dst.Name = src.Name
|
||||
dst.Table = pgtype.OID(src.TableOID)
|
||||
dst.AttributeNumber = src.TableAttributeNumber
|
||||
dst.DataType = pgtype.OID(src.DataTypeOID)
|
||||
dst.DataTypeSize = src.DataTypeSize
|
||||
dst.Modifier = src.TypeModifier
|
||||
|
||||
if dt, ok := c.ConnInfo.DataTypeForOID(dst.DataType); ok {
|
||||
dst.DataTypeName = dt.Name
|
||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||
dst.FormatCode = BinaryFormatCode
|
||||
} else {
|
||||
dst.FormatCode = TextFormatCode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -195,7 +195,7 @@ func (p *ConnPool) Release(conn *Conn) {
|
|||
}
|
||||
|
||||
if conn.pgConn.TxStatus != 'I' {
|
||||
conn.Exec("rollback")
|
||||
conn.Exec(context.TODO(), "rollback")
|
||||
}
|
||||
|
||||
if len(conn.channels) > 0 {
|
||||
|
@ -360,7 +360,7 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag pgconn
|
|||
}
|
||||
defer p.Release(c)
|
||||
|
||||
return c.Exec(sql, arguments...)
|
||||
return c.Exec(context.TODO(), sql, arguments...)
|
||||
}
|
||||
|
||||
func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
|
|
|
@ -345,7 +345,7 @@ func TestPoolReleaseWithTransactions(t *testing.T) {
|
|||
t.Fatalf("Unable to acquire connection: %v", err)
|
||||
}
|
||||
mustExec(t, conn, "begin")
|
||||
if _, err = conn.Exec("selct"); err == nil {
|
||||
if _, err = conn.Exec(context.Background(), "selct"); err == nil {
|
||||
t.Fatal("Did not receive expected error")
|
||||
}
|
||||
|
||||
|
@ -449,7 +449,7 @@ func TestPoolReleaseDiscardsDeadConnections(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
if _, err = c2.Exec("select pg_terminate_backend($1)", c1.PID()); err != nil {
|
||||
if _, err = c2.Exec(context.Background(), "select pg_terminate_backend($1)", c1.PID()); err != nil {
|
||||
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
|
||||
}
|
||||
|
||||
|
@ -695,7 +695,7 @@ func TestConnPoolBeginRetry(t *testing.T) {
|
|||
pool.Release(victimConn)
|
||||
|
||||
// Terminate connection that was released to pool
|
||||
if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.PID()); err != nil {
|
||||
if _, err = killerConn.Exec(context.Background(), "select pg_terminate_backend($1)", victimConn.PID()); err != nil {
|
||||
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
|
||||
}
|
||||
|
||||
|
|
14
conn_test.go
14
conn_test.go
|
@ -146,7 +146,7 @@ func TestExecFailure(t *testing.T) {
|
|||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
if _, err := conn.Exec("selct;"); err == nil {
|
||||
if _, err := conn.Exec(context.Background(), "selct;"); err == nil {
|
||||
t.Fatal("Expected SQL syntax error")
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
|
@ -169,7 +169,7 @@ func TestExecFailureWithArguments(t *testing.T) {
|
|||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
if _, err := conn.Exec("selct $1;", 1); err == nil {
|
||||
if _, err := conn.Exec(context.Background(), "selct $1;", 1); err == nil {
|
||||
t.Fatal("Expected SQL syntax error")
|
||||
}
|
||||
if conn.LastStmtSent() {
|
||||
|
@ -270,7 +270,7 @@ func TestExecFailureCloseBefore(t *testing.T) {
|
|||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
closeConn(t, conn)
|
||||
|
||||
if _, err := conn.Exec("select 1"); err == nil {
|
||||
if _, err := conn.Exec(context.Background(), "select 1"); err == nil {
|
||||
t.Fatal("Expected network error")
|
||||
}
|
||||
if conn.LastStmtSent() {
|
||||
|
@ -884,7 +884,7 @@ func TestFatalRxError(t *testing.T) {
|
|||
otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer otherConn.Close()
|
||||
|
||||
if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.PID()); err != nil {
|
||||
if _, err := otherConn.Exec(context.Background(), "select pg_terminate_backend($1)", conn.PID()); err != nil {
|
||||
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
|
||||
}
|
||||
|
||||
|
@ -907,7 +907,7 @@ func TestFatalTxError(t *testing.T) {
|
|||
otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer otherConn.Close()
|
||||
|
||||
_, err := otherConn.Exec("select pg_terminate_backend($1)", conn.PID())
|
||||
_, err := otherConn.Exec(context.Background(), "select pg_terminate_backend($1)", conn.PID())
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
|
||||
}
|
||||
|
@ -986,7 +986,7 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) {
|
|||
}
|
||||
defer rows.Close()
|
||||
|
||||
_, err = conn.Exec("create temporary table foo(spice timestamp[])")
|
||||
_, err = conn.Exec(context.Background(), "create temporary table foo(spice timestamp[])")
|
||||
if err != pgx.ErrConnBusy {
|
||||
t.Fatalf("conn.Exec should have failed with pgx.ErrConnBusy, but it was %v", err)
|
||||
}
|
||||
|
@ -1128,7 +1128,7 @@ func TestConnOnNotice(t *testing.T) {
|
|||
conn := mustConnect(t, connConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
_, err := conn.Exec(`do $$
|
||||
_, err := conn.Exec(context.Background(), `do $$
|
||||
begin
|
||||
raise notice 'hello, world';
|
||||
end$$;`)
|
||||
|
|
|
@ -55,7 +55,7 @@ func closeReplicationConn(t testing.TB, conn *pgx.ReplicationConn) {
|
|||
|
||||
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag) {
|
||||
var err error
|
||||
if commandTag, err = conn.Exec(sql, arguments...); err != nil {
|
||||
if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil {
|
||||
t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)
|
||||
}
|
||||
return
|
||||
|
|
|
@ -1083,7 +1083,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes
|
|||
mustExec(t, conn, "create temporary table t(n numeric)")
|
||||
|
||||
var d *apd.Decimal
|
||||
commandTag, err := conn.Exec(`insert into t(n) values($1)`, d)
|
||||
commandTag, err := conn.Exec(context.Background(), `insert into t(n) values($1)`, d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -439,7 +439,7 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
|
|||
|
||||
// Create the replication slot, using the given name and output plugin.
|
||||
func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) {
|
||||
_, err = rc.c.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin))
|
||||
_, err = rc.c.Exec(context.TODO(), fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -457,6 +457,6 @@ func (rc *ReplicationConn) CreateReplicationSlotEx(slotName, outputPlugin string
|
|||
|
||||
// Drop the replication slot for the given name
|
||||
func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) {
|
||||
_, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName))
|
||||
_, err = rc.c.Exec(context.TODO(), fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -52,7 +52,7 @@ func TestSimpleReplicationConnection(t *testing.T) {
|
|||
defer func() {
|
||||
// Ensure replication slot is destroyed, but don't check for errors as it
|
||||
// should have already been destroyed.
|
||||
conn.Exec("select pg_drop_replication_slot('pgx_test')")
|
||||
conn.Exec(context.Background(), "select pg_drop_replication_slot('pgx_test')")
|
||||
closeConn(t, conn)
|
||||
}()
|
||||
|
||||
|
@ -74,7 +74,7 @@ func TestSimpleReplicationConnection(t *testing.T) {
|
|||
}
|
||||
|
||||
// Do a simple change so we can get some wal data
|
||||
_, err = conn.Exec("create table if not exists replication_test (a integer)")
|
||||
_, err = conn.Exec(context.Background(), "create table if not exists replication_test (a integer)")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
@ -90,7 +90,7 @@ func TestSimpleReplicationConnection(t *testing.T) {
|
|||
for i := 0; i < 5; i++ {
|
||||
var ct pgconn.CommandTag
|
||||
insertedTimes = append(insertedTimes, currentTime)
|
||||
ct, err = conn.Exec("insert into replication_test(a) values($1)", currentTime)
|
||||
ct, err = conn.Exec(context.Background(), "insert into replication_test(a) values($1)", currentTime)
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %v", err)
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ func TestTransactionSuccessfulCommit(t *testing.T) {
|
|||
);
|
||||
`
|
||||
|
||||
if _, err := conn.Exec(createSql); err != nil {
|
||||
if _, err := conn.Exec(context.Background(), createSql); err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
|
@ -68,7 +68,7 @@ func TestTxCommitWhenTxBroken(t *testing.T) {
|
|||
);
|
||||
`
|
||||
|
||||
if _, err := conn.Exec(createSql); err != nil {
|
||||
if _, err := conn.Exec(context.Background(), createSql); err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
|
@ -160,7 +160,7 @@ func TestTransactionSuccessfulRollback(t *testing.T) {
|
|||
);
|
||||
`
|
||||
|
||||
if _, err := conn.Exec(createSql); err != nil {
|
||||
if _, err := conn.Exec(context.Background(), createSql); err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
|
@ -227,7 +227,7 @@ func TestBeginExReadOnly(t *testing.T) {
|
|||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = conn.Exec("create table foo(id serial primary key)")
|
||||
_, err = conn.Exec(context.Background(), "create table foo(id serial primary key)")
|
||||
if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "25006" {
|
||||
t.Errorf("Expected error SQLSTATE 25006, but got %#v", err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue