diff --git a/conn.go b/conn.go index e9db6b19..056b1b77 100644 --- a/conn.go +++ b/conn.go @@ -2,6 +2,8 @@ package pgx import ( "context" + "database/sql/driver" + "fmt" "net" "reflect" "strings" @@ -673,7 +675,7 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { // Listen establishes a PostgreSQL listen/notify to channel func (c *Conn) Listen(channel string) error { - _, err := c.Exec("listen " + quoteIdentifier(channel)) + _, err := c.Exec(context.TODO(), "listen "+quoteIdentifier(channel)) if err != nil { return err } @@ -685,7 +687,7 @@ func (c *Conn) Listen(channel string) error { // Unlisten unsubscribes from a listen channel func (c *Conn) Unlisten(channel string) error { - _, err := c.Exec("unlisten " + quoteIdentifier(channel)) + _, err := c.Exec(context.TODO(), "unlisten "+quoteIdentifier(channel)) if err != nil { return err } @@ -837,12 +839,6 @@ func fatalWriteErr(bytesWritten int, err error) bool { return !(is && netErr.Timeout()) } -// Exec executes sql. sql can be either a prepared statement name or an SQL string. -// arguments should be referenced positionally from the sql string as $1, $2, etc. -func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { - return c.ExecEx(context.Background(), sql, nil, arguments...) -} - // Processes messages that are not exclusive to one context such as // authentication or query response. The response to these messages is the same // regardless of when they occur. It also ignores messages that are only @@ -1351,3 +1347,170 @@ func connInfoFromRows(rows *Rows, err error) (map[string]pgtype.OID, error) { func (c *Conn) LastStmtSent() bool { return c.lastStmtSent } + +// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced +// positionally from the sql string as $1, $2, etc. +func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { + c.lastStmtSent = false + err := c.waitForPreviousCancelQuery(ctx) + if err != nil { + return nil, err + } + + if err := c.lock(); err != nil { + return nil, err + } + defer c.unlock() + + if err := c.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + + startTime := time.Now() + + commandTag, err := c.exec(ctx, sql, arguments...) + if err != nil { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) + } + return commandTag, err + } + + if c.shouldLog(LogLevelInfo) { + endTime := time.Now() + c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) + } + + return commandTag, err +} + +func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { + if len(arguments) == 0 { + c.lastStmtSent = true + result, err := c.pgConn.Exec(ctx, sql) + if err != nil { + return nil, err + } + + return result.CommandTag, nil + } else { + psd, err := c.pgConn.Prepare(ctx, "", sql, nil) + if err != nil { + return nil, err + } + + ps := &PreparedStatement{ + Name: psd.Name, + SQL: psd.SQL, + ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)), + FieldDescriptions: make([]FieldDescription, len(psd.Fields)), + } + + for i := range ps.ParameterOIDs { + ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i]) + } + for i := range ps.FieldDescriptions { + c.pgconnFieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i]) + } + + arguments, err = convertDriverValuers(arguments) + if err != nil { + return nil, err + } + + paramFormats := make([]int16, len(arguments)) + paramValues := make([][]byte, len(arguments)) + for i := range arguments { + paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, ps.ParameterOIDs[i], arguments[i]) + paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, ps.ParameterOIDs[i], arguments[i]) + if err != nil { + return nil, err + } + + } + + resultFormats := make([]int16, len(ps.FieldDescriptions)) + for i := range resultFormats { + resultFormats[i] = ps.FieldDescriptions[i].FormatCode + } + + c.lastStmtSent = true + result, err := c.pgConn.ExecPrepared(ctx, psd.Name, paramValues, paramFormats, resultFormats) + if err != nil { + return nil, err + } + + return result.CommandTag, nil + } + +} + +func newencodePreparedStatementArgument(ci *pgtype.ConnInfo, oid pgtype.OID, arg interface{}) ([]byte, error) { + if arg == nil { + return nil, nil + } + + switch arg := arg.(type) { + case pgtype.BinaryEncoder: + return arg.EncodeBinary(ci, nil) + case pgtype.TextEncoder: + return arg.EncodeText(ci, nil) + case string: + return []byte(arg), nil + } + + refVal := reflect.ValueOf(arg) + + if refVal.Kind() == reflect.Ptr { + if refVal.IsNil() { + return nil, nil + } + arg = refVal.Elem().Interface() + return newencodePreparedStatementArgument(ci, oid, arg) + } + + if dt, ok := ci.DataTypeForOID(oid); ok { + value := dt.Value + err := value.Set(arg) + if err != nil { + { + if arg, ok := arg.(driver.Valuer); ok { + v, err := callValuerValue(arg) + if err != nil { + return nil, err + } + return newencodePreparedStatementArgument(ci, oid, v) + } + } + + return nil, err + } + + return value.(pgtype.BinaryEncoder).EncodeBinary(ci, nil) + } + + if strippedArg, ok := stripNamedType(&refVal); ok { + return newencodePreparedStatementArgument(ci, oid, strippedArg) + } + return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) +} + +// pgconnFieldDescriptionToPgxFieldDescription copies and converts the data from a pgproto3.FieldDescription to a +// FieldDescription. +func (c *Conn) pgconnFieldDescriptionToPgxFieldDescription(src *pgconn.FieldDescription, dst *FieldDescription) { + dst.Name = src.Name + dst.Table = pgtype.OID(src.TableOID) + dst.AttributeNumber = src.TableAttributeNumber + dst.DataType = pgtype.OID(src.DataTypeOID) + dst.DataTypeSize = src.DataTypeSize + dst.Modifier = src.TypeModifier + + if dt, ok := c.ConnInfo.DataTypeForOID(dst.DataType); ok { + dst.DataTypeName = dt.Name + if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { + dst.FormatCode = BinaryFormatCode + } else { + dst.FormatCode = TextFormatCode + } + } +} diff --git a/conn_pool.go b/conn_pool.go index eea043e7..f4943146 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -195,7 +195,7 @@ func (p *ConnPool) Release(conn *Conn) { } if conn.pgConn.TxStatus != 'I' { - conn.Exec("rollback") + conn.Exec(context.TODO(), "rollback") } if len(conn.channels) > 0 { @@ -360,7 +360,7 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag pgconn } defer p.Release(c) - return c.Exec(sql, arguments...) + return c.Exec(context.TODO(), sql, arguments...) } func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { diff --git a/conn_pool_test.go b/conn_pool_test.go index 168a3bc9..7a12bba6 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -345,7 +345,7 @@ func TestPoolReleaseWithTransactions(t *testing.T) { t.Fatalf("Unable to acquire connection: %v", err) } mustExec(t, conn, "begin") - if _, err = conn.Exec("selct"); err == nil { + if _, err = conn.Exec(context.Background(), "selct"); err == nil { t.Fatal("Did not receive expected error") } @@ -449,7 +449,7 @@ func TestPoolReleaseDiscardsDeadConnections(t *testing.T) { } }() - if _, err = c2.Exec("select pg_terminate_backend($1)", c1.PID()); err != nil { + if _, err = c2.Exec(context.Background(), "select pg_terminate_backend($1)", c1.PID()); err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } @@ -695,7 +695,7 @@ func TestConnPoolBeginRetry(t *testing.T) { pool.Release(victimConn) // Terminate connection that was released to pool - if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.PID()); err != nil { + if _, err = killerConn.Exec(context.Background(), "select pg_terminate_backend($1)", victimConn.PID()); err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } diff --git a/conn_test.go b/conn_test.go index b49ae174..e3b71bb9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -146,7 +146,7 @@ func TestExecFailure(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - if _, err := conn.Exec("selct;"); err == nil { + if _, err := conn.Exec(context.Background(), "selct;"); err == nil { t.Fatal("Expected SQL syntax error") } if !conn.LastStmtSent() { @@ -169,7 +169,7 @@ func TestExecFailureWithArguments(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - if _, err := conn.Exec("selct $1;", 1); err == nil { + if _, err := conn.Exec(context.Background(), "selct $1;", 1); err == nil { t.Fatal("Expected SQL syntax error") } if conn.LastStmtSent() { @@ -270,7 +270,7 @@ func TestExecFailureCloseBefore(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) closeConn(t, conn) - if _, err := conn.Exec("select 1"); err == nil { + if _, err := conn.Exec(context.Background(), "select 1"); err == nil { t.Fatal("Expected network error") } if conn.LastStmtSent() { @@ -884,7 +884,7 @@ func TestFatalRxError(t *testing.T) { otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer otherConn.Close() - if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.PID()); err != nil { + if _, err := otherConn.Exec(context.Background(), "select pg_terminate_backend($1)", conn.PID()); err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } @@ -907,7 +907,7 @@ func TestFatalTxError(t *testing.T) { otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer otherConn.Close() - _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.PID()) + _, err := otherConn.Exec(context.Background(), "select pg_terminate_backend($1)", conn.PID()) if err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } @@ -986,7 +986,7 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) { } defer rows.Close() - _, err = conn.Exec("create temporary table foo(spice timestamp[])") + _, err = conn.Exec(context.Background(), "create temporary table foo(spice timestamp[])") if err != pgx.ErrConnBusy { t.Fatalf("conn.Exec should have failed with pgx.ErrConnBusy, but it was %v", err) } @@ -1128,7 +1128,7 @@ func TestConnOnNotice(t *testing.T) { conn := mustConnect(t, connConfig) defer closeConn(t, conn) - _, err := conn.Exec(`do $$ + _, err := conn.Exec(context.Background(), `do $$ begin raise notice 'hello, world'; end$$;`) diff --git a/helper_test.go b/helper_test.go index 188876c4..b181ef31 100644 --- a/helper_test.go +++ b/helper_test.go @@ -55,7 +55,7 @@ func closeReplicationConn(t testing.TB, conn *pgx.ReplicationConn) { func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag) { var err error - if commandTag, err = conn.Exec(sql, arguments...); err != nil { + if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil { t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err) } return diff --git a/query_test.go b/query_test.go index 738e2ba6..048e82e9 100644 --- a/query_test.go +++ b/query_test.go @@ -1083,7 +1083,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes mustExec(t, conn, "create temporary table t(n numeric)") var d *apd.Decimal - commandTag, err := conn.Exec(`insert into t(n) values($1)`, d) + commandTag, err := conn.Exec(context.Background(), `insert into t(n) values($1)`, d) if err != nil { t.Fatal(err) } diff --git a/replication.go b/replication.go index cf4058ab..06768194 100644 --- a/replication.go +++ b/replication.go @@ -439,7 +439,7 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti // Create the replication slot, using the given name and output plugin. func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) { - _, err = rc.c.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin)) + _, err = rc.c.Exec(context.TODO(), fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin)) return } @@ -457,6 +457,6 @@ func (rc *ReplicationConn) CreateReplicationSlotEx(slotName, outputPlugin string // Drop the replication slot for the given name func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) { - _, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName)) + _, err = rc.c.Exec(context.TODO(), fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName)) return } diff --git a/replication_test.go b/replication_test.go index 3ceac156..a98129cd 100644 --- a/replication_test.go +++ b/replication_test.go @@ -52,7 +52,7 @@ func TestSimpleReplicationConnection(t *testing.T) { defer func() { // Ensure replication slot is destroyed, but don't check for errors as it // should have already been destroyed. - conn.Exec("select pg_drop_replication_slot('pgx_test')") + conn.Exec(context.Background(), "select pg_drop_replication_slot('pgx_test')") closeConn(t, conn) }() @@ -74,7 +74,7 @@ func TestSimpleReplicationConnection(t *testing.T) { } // Do a simple change so we can get some wal data - _, err = conn.Exec("create table if not exists replication_test (a integer)") + _, err = conn.Exec(context.Background(), "create table if not exists replication_test (a integer)") if err != nil { t.Fatalf("Failed to create table: %v", err) } @@ -90,7 +90,7 @@ func TestSimpleReplicationConnection(t *testing.T) { for i := 0; i < 5; i++ { var ct pgconn.CommandTag insertedTimes = append(insertedTimes, currentTime) - ct, err = conn.Exec("insert into replication_test(a) values($1)", currentTime) + ct, err = conn.Exec(context.Background(), "insert into replication_test(a) values($1)", currentTime) if err != nil { t.Fatalf("Insert failed: %v", err) } diff --git a/tx_test.go b/tx_test.go index 3a091a16..c79f5132 100644 --- a/tx_test.go +++ b/tx_test.go @@ -26,7 +26,7 @@ func TestTransactionSuccessfulCommit(t *testing.T) { ); ` - if _, err := conn.Exec(createSql); err != nil { + if _, err := conn.Exec(context.Background(), createSql); err != nil { t.Fatalf("Failed to create table: %v", err) } @@ -68,7 +68,7 @@ func TestTxCommitWhenTxBroken(t *testing.T) { ); ` - if _, err := conn.Exec(createSql); err != nil { + if _, err := conn.Exec(context.Background(), createSql); err != nil { t.Fatalf("Failed to create table: %v", err) } @@ -160,7 +160,7 @@ func TestTransactionSuccessfulRollback(t *testing.T) { ); ` - if _, err := conn.Exec(createSql); err != nil { + if _, err := conn.Exec(context.Background(), createSql); err != nil { t.Fatalf("Failed to create table: %v", err) } @@ -227,7 +227,7 @@ func TestBeginExReadOnly(t *testing.T) { } defer tx.Rollback() - _, err = conn.Exec("create table foo(id serial primary key)") + _, err = conn.Exec(context.Background(), "create table foo(id serial primary key)") if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "25006" { t.Errorf("Expected error SQLSTATE 25006, but got %#v", err) }