Remove ExecEx

pull/483/head
Jack Christensen 2019-01-02 12:52:37 -06:00
parent 12857ad05b
commit 89c3d8af5d
8 changed files with 84 additions and 306 deletions

View File

@ -395,7 +395,7 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) {
for src.Next() { for src.Next() {
values, _ := src.Values() values, _ := src.Values()
if _, err = tx.Exec("insert_t", values...); err != nil { if _, err = tx.Exec(context.Background(), "insert_t", values...); err != nil {
b.Fatalf("Exec unexpectedly failed with: %v", err) b.Fatalf("Exec unexpectedly failed with: %v", err)
} }
} }
@ -457,7 +457,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc
rowsThisInsert++ rowsThisInsert++
if rowsThisInsert == maxRowsPerInsert { if rowsThisInsert == maxRowsPerInsert {
_, err := tx.Exec(sqlBuf.String(), args...) _, err := tx.Exec(context.Background(), sqlBuf.String(), args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -468,7 +468,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc
} }
if rowsThisInsert > 0 { if rowsThisInsert > 0 {
_, err := tx.Exec(sqlBuf.String(), args...) _, err := tx.Exec(context.Background(), sqlBuf.String(), args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }

135
conn.go
View File

@ -1080,139 +1080,10 @@ func (c *Conn) cancelQuery() {
} }
func (c *Conn) Ping(ctx context.Context) error { func (c *Conn) Ping(ctx context.Context) error {
_, err := c.ExecEx(ctx, ";", nil) _, err := c.Exec(ctx, ";", nil)
return err return err
} }
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (pgconn.CommandTag, error) {
c.lastStmtSent = false
err := c.waitForPreviousCancelQuery(ctx)
if err != nil {
return nil, err
}
if err := c.lock(); err != nil {
return nil, err
}
defer c.unlock()
startTime := time.Now()
commandTag, err := c.execEx(ctx, sql, options, arguments...)
if err != nil {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
}
return commandTag, err
}
if c.shouldLog(LogLevelInfo) {
endTime := time.Now()
c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
}
return commandTag, err
}
func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
err = c.initContext(ctx)
if err != nil {
return nil, err
}
defer func() {
err = c.termContext(err)
}()
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
c.lastStmtSent = true
err = c.sanitizeAndSendSimpleQuery(sql, arguments...)
if err != nil {
return nil, err
}
} else if options != nil && len(options.ParameterOIDs) > 0 {
if err := c.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments)
if err != nil {
return nil, err
}
buf = appendSync(buf)
n, err := c.pgConn.Conn().Write(buf)
c.lastStmtSent = true
if err != nil && fatalWriteErr(n, err) {
c.die(err)
return nil, err
}
c.pendingReadyForQueryCount++
} else {
if len(arguments) > 0 {
ps, ok := c.preparedStatements[sql]
if !ok {
var err error
ps, err = c.prepareEx("", sql, nil)
if err != nil {
return nil, err
}
}
c.lastStmtSent = true
err = c.sendPreparedQuery(ps, arguments...)
if err != nil {
return nil, err
}
} else {
c.lastStmtSent = true
if err = c.sendQuery(sql, arguments...); err != nil {
return
}
}
}
var softErr error
for {
msg, err := c.rxMsg()
if err != nil {
return commandTag, err
}
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
return commandTag, softErr
case *pgproto3.CommandComplete:
commandTag = pgconn.CommandTag(msg.CommandTag)
default:
if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
softErr = e
}
}
}
}
func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) {
if len(arguments) != len(options.ParameterOIDs) {
return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs))
}
if len(options.ParameterOIDs) > 65535 {
return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs))
}
buf = appendParse(buf, "", sql, options.ParameterOIDs)
buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOIDs, arguments, nil)
if err != nil {
return nil, err
}
buf = appendExecute(buf, "", 0)
return buf, nil
}
func (c *Conn) initContext(ctx context.Context) error { func (c *Conn) initContext(ctx context.Context) error {
if c.ctxInProgress { if c.ctxInProgress {
return errors.New("ctx already in progress") return errors.New("ctx already in progress")
@ -1399,6 +1270,10 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
return nil, err return nil, err
} }
if len(psd.ParamOIDs) != len(arguments) {
return nil, errors.Errorf("expected %d arguments, got %d", len(psd.ParamOIDs), len(arguments))
}
ps := &PreparedStatement{ ps := &PreparedStatement{
Name: psd.Name, Name: psd.Name,
SQL: psd.SQL, SQL: psd.SQL,

View File

@ -353,24 +353,14 @@ func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
} }
// Exec acquires a connection, delegates the call to that connection, and releases the connection // Exec acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { func (p *ConnPool) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
var c *Conn var c *Conn
if c, err = p.Acquire(); err != nil { if c, err = p.Acquire(); err != nil {
return return
} }
defer p.Release(c) defer p.Release(c)
return c.Exec(context.TODO(), sql, arguments...) return c.Exec(ctx, sql, arguments...)
}
func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
var c *Conn
if c, err = p.Acquire(); err != nil {
return
}
defer p.Release(c)
return c.ExecEx(ctx, sql, options, arguments...)
} }
// Query acquires a connection and delegates the call to that connection. When // Query acquires a connection and delegates the call to that connection. When

View File

@ -801,7 +801,7 @@ func TestConnPoolQueryConcurrentLoad(t *testing.T) {
t.Error("Select called onDataRow wrong number of times") t.Error("Select called onDataRow wrong number of times")
} }
_, err = pool.Exec("--;") _, err = pool.Exec(context.Background(), "--;")
if err != nil { if err != nil {
t.Fatalf("pool.Exec failed: %v", err) t.Fatalf("pool.Exec failed: %v", err)
} }
@ -841,7 +841,7 @@ func TestConnPoolExec(t *testing.T) {
pool := createConnPool(t, 2) pool := createConnPool(t, 2)
defer pool.Close() defer pool.Close()
results, err := pool.Exec("create temporary table foo(id integer primary key);") results, err := pool.Exec(context.Background(), "create temporary table foo(id integer primary key);")
if err != nil { if err != nil {
t.Fatalf("Unexpected error from pool.Exec: %v", err) t.Fatalf("Unexpected error from pool.Exec: %v", err)
} }
@ -849,7 +849,7 @@ func TestConnPoolExec(t *testing.T) {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
results, err = pool.Exec("insert into foo(id) values($1)", 1) results, err = pool.Exec(context.Background(), "insert into foo(id) values($1)", 1)
if err != nil { if err != nil {
t.Fatalf("Unexpected error from pool.Exec: %v", err) t.Fatalf("Unexpected error from pool.Exec: %v", err)
} }
@ -857,7 +857,7 @@ func TestConnPoolExec(t *testing.T) {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
results, err = pool.Exec("drop table foo;") results, err = pool.Exec(context.Background(), "drop table foo;")
if err != nil { if err != nil {
t.Fatalf("Unexpected error from pool.Exec: %v", err) t.Fatalf("Unexpected error from pool.Exec: %v", err)
} }

View File

@ -177,7 +177,7 @@ func TestExecFailureWithArguments(t *testing.T) {
} }
} }
func TestExecExContextWithoutCancelation(t *testing.T) { func TestExecContextWithoutCancelation(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@ -186,7 +186,7 @@ func TestExecExContextWithoutCancelation(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(id integer primary key);", nil) commandTag, err := conn.Exec(ctx, "create temporary table foo(id integer primary key);")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -198,7 +198,7 @@ func TestExecExContextWithoutCancelation(t *testing.T) {
} }
} }
func TestExecExContextFailureWithoutCancelation(t *testing.T) { func TestExecContextFailureWithoutCancelation(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@ -207,7 +207,7 @@ func TestExecExContextFailureWithoutCancelation(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil { if _, err := conn.Exec(ctx, "selct;"); err == nil {
t.Fatal("Expected SQL syntax error") t.Fatal("Expected SQL syntax error")
} }
if !conn.LastStmtSent() { if !conn.LastStmtSent() {
@ -224,7 +224,7 @@ func TestExecExContextFailureWithoutCancelation(t *testing.T) {
} }
} }
func TestExecExContextFailureWithoutCancelationWithArguments(t *testing.T) { func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@ -233,7 +233,7 @@ func TestExecExContextFailureWithoutCancelationWithArguments(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
if _, err := conn.ExecEx(ctx, "selct $1;", nil, 1); err == nil { if _, err := conn.Exec(ctx, "selct $1;", 1); err == nil {
t.Fatal("Expected SQL syntax error") t.Fatal("Expected SQL syntax error")
} }
if conn.LastStmtSent() { if conn.LastStmtSent() {
@ -241,7 +241,7 @@ func TestExecExContextFailureWithoutCancelationWithArguments(t *testing.T) {
} }
} }
func TestExecExContextCancelationCancelsQuery(t *testing.T) { func TestExecContextCancelationCancelsQuery(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@ -253,7 +253,7 @@ func TestExecExContextCancelationCancelsQuery(t *testing.T) {
cancelFunc() cancelFunc()
}() }()
_, err := conn.ExecEx(ctx, "select pg_sleep(60)", nil) _, err := conn.Exec(ctx, "select pg_sleep(60)")
if err != context.Canceled { if err != context.Canceled {
t.Fatalf("Expected context.Canceled err, got %v", err) t.Fatalf("Expected context.Canceled err, got %v", err)
} }
@ -278,7 +278,7 @@ func TestExecFailureCloseBefore(t *testing.T) {
} }
} }
func TestExecExExtendedProtocol(t *testing.T) { func TestExecExtendedProtocol(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@ -287,18 +287,17 @@ func TestExecExExtendedProtocol(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil) commandTag, err := conn.Exec(ctx, "create temporary table foo(name varchar primary key);")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if string(commandTag) != "CREATE TABLE" { if string(commandTag) != "CREATE TABLE" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag) t.Fatalf("Unexpected results from Exec: %v", commandTag)
} }
commandTag, err = conn.ExecEx( commandTag, err = conn.Exec(
ctx, ctx,
"insert into foo(name) values($1);", "insert into foo(name) values($1);",
nil,
"bar", "bar",
) )
if err != nil { if err != nil {
@ -311,119 +310,42 @@ func TestExecExExtendedProtocol(t *testing.T) {
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
func TestExecExSimpleProtocol(t *testing.T) { func TestExecSimpleProtocol(t *testing.T) {
t.Parallel() t.Skip("TODO when with simple protocol supported in connection")
// t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) // conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn) // defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background()) // ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() // defer cancelFunc()
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil) // commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
if err != nil { // if err != nil {
t.Fatal(err) // t.Fatal(err)
} // }
if string(commandTag) != "CREATE TABLE" { // if string(commandTag) != "CREATE TABLE" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag) // t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
} // }
if !conn.LastStmtSent() { // if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true") // t.Error("Expected LastStmtSent to return true")
} // }
commandTag, err = conn.ExecEx( // commandTag, err = conn.ExecEx(
ctx, // ctx,
"insert into foo(name) values($1);", // "insert into foo(name) values($1);",
&pgx.QueryExOptions{SimpleProtocol: true}, // &pgx.QueryExOptions{SimpleProtocol: true},
"bar'; drop table foo;--", // "bar'; drop table foo;--",
) // )
if err != nil { // if err != nil {
t.Fatal(err) // t.Fatal(err)
} // }
if string(commandTag) != "INSERT 0 1" { // if string(commandTag) != "INSERT 0 1" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag) // t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
} // }
if !conn.LastStmtSent() { // if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true") // t.Error("Expected LastStmtSent to return true")
} // }
}
func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
commandTag, err := conn.ExecEx(
context.Background(),
"insert into foo(name) values($1);",
&pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.VarcharOID}},
"bar'; drop table foo;--",
)
if err != nil {
t.Fatal(err)
}
if string(commandTag) != "INSERT 0 1" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
}
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
}
func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
_, err := conn.ExecEx(
context.Background(),
"insert into foo(name) values($1);",
&pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.Int4OID}},
"bar'; drop table foo;--",
)
if err == nil {
t.Fatal("expected error but got none")
}
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
}
func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
var s string
err := conn.QueryRow("insert into foo(name) values('baz') returning name;").Scan(&s)
if err != nil {
t.Errorf("Executing query failed: %v", err)
}
if s != "baz" {
t.Errorf("Query did not return expected value: %v", s)
}
_, err = conn.ExecEx(
context.Background(),
"insert into foo(name) values($1);",
&pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.Int4OID}},
"bar'; drop table foo;--",
)
if err == nil {
t.Fatal("expected error but got none")
}
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
} }
func TestExecExFailureCloseBefore(t *testing.T) { func TestExecExFailureCloseBefore(t *testing.T) {
@ -432,7 +354,7 @@ func TestExecExFailureCloseBefore(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
closeConn(t, conn) closeConn(t, conn)
if _, err := conn.ExecEx(context.Background(), "select 1", nil); err == nil { if _, err := conn.Exec(context.Background(), "select 1", nil); err == nil {
t.Fatal("Expected network error") t.Fatal("Expected network error")
} }
if conn.LastStmtSent() { if conn.LastStmtSent() {

View File

@ -17,7 +17,7 @@ import (
) )
type execer interface { type execer interface {
Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
} }
type queryer interface { type queryer interface {
Query(sql string, args ...interface{}) (*pgx.Rows, error) Query(sql string, args ...interface{}) (*pgx.Rows, error)
@ -102,7 +102,7 @@ func TestStressConnPool(t *testing.T) {
} }
func setupStressDB(t *testing.T, pool *pgx.ConnPool) { func setupStressDB(t *testing.T, pool *pgx.ConnPool) {
_, err := pool.Exec(` _, err := pool.Exec(context.Background(), `
drop table if exists widgets; drop table if exists widgets;
create table widgets( create table widgets(
id serial primary key, id serial primary key,
@ -121,7 +121,7 @@ func insertUnprepared(e execer, actionNum int) error {
insert into widgets(name, description, creation_time) insert into widgets(name, description, creation_time)
values($1, $2, $3)` values($1, $2, $3)`
_, err := e.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now()) _, err := e.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
return err return err
} }
@ -198,7 +198,7 @@ func queryErrorWhileReturningRows(q queryer, actionNum int) error {
} }
func notify(pool *pgx.ConnPool, actionNum int) error { func notify(pool *pgx.ConnPool, actionNum int) error {
_, err := pool.Exec("notify stress") _, err := pool.Exec(context.Background(), "notify stress")
return err return err
} }
@ -254,7 +254,7 @@ func txInsertRollback(pool *pgx.ConnPool, actionNum int) error {
insert into widgets(name, description, creation_time) insert into widgets(name, description, creation_time)
values($1, $2, $3)` values($1, $2, $3)`
_, err = tx.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now()) _, err = tx.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
if err != nil { if err != nil {
return err return err
} }
@ -272,7 +272,7 @@ func txInsertCommit(pool *pgx.ConnPool, actionNum int) error {
insert into widgets(name, description, creation_time) insert into widgets(name, description, creation_time)
values($1, $2, $3)` values($1, $2, $3)`
_, err = tx.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now()) _, err = tx.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return err
@ -352,7 +352,7 @@ func canceledExecExContext(pool *pgx.ConnPool, actionNum int) error {
cancelFunc() cancelFunc()
}() }()
_, err := pool.ExecEx(ctx, "select pg_sleep(2)", nil) _, err := pool.Exec(ctx, "select pg_sleep(2)")
if err != context.Canceled { if err != context.Canceled {
return errors.Errorf("Expected context.Canceled error, got %v", err) return errors.Errorf("Expected context.Canceled error, got %v", err)
} }

19
tx.go
View File

@ -90,7 +90,7 @@ func (c *Conn) Begin() (*Tx, error) {
// mode. Unlike database/sql, the context only affects the begin command. i.e. // mode. Unlike database/sql, the context only affects the begin command. i.e.
// there is no auto-rollback on context cancelation. // there is no auto-rollback on context cancelation.
func (c *Conn) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) { func (c *Conn) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) {
_, err := c.ExecEx(ctx, txOptions.beginSQL(), nil) _, err := c.Exec(ctx, txOptions.beginSQL())
if err != nil { if err != nil {
// begin should never fail unless there is an underlying connection issue or // begin should never fail unless there is an underlying connection issue or
// a context timeout. In either case, the connection is possibly broken. // a context timeout. In either case, the connection is possibly broken.
@ -123,7 +123,7 @@ func (tx *Tx) CommitEx(ctx context.Context) error {
return ErrTxClosed return ErrTxClosed
} }
commandTag, err := tx.conn.ExecEx(ctx, "commit", nil) commandTag, err := tx.conn.Exec(ctx, "commit")
if err == nil && string(commandTag) == "COMMIT" { if err == nil && string(commandTag) == "COMMIT" {
tx.status = TxStatusCommitSuccess tx.status = TxStatusCommitSuccess
} else if err == nil && string(commandTag) == "ROLLBACK" { } else if err == nil && string(commandTag) == "ROLLBACK" {
@ -159,7 +159,7 @@ func (tx *Tx) RollbackEx(ctx context.Context) error {
return ErrTxClosed return ErrTxClosed
} }
_, tx.err = tx.conn.ExecEx(ctx, "rollback", nil) _, tx.err = tx.conn.Exec(ctx, "rollback")
if tx.err == nil { if tx.err == nil {
tx.status = TxStatusRollbackSuccess tx.status = TxStatusRollbackSuccess
} else { } else {
@ -176,17 +176,8 @@ func (tx *Tx) RollbackEx(ctx context.Context) error {
} }
// Exec delegates to the underlying *Conn // Exec delegates to the underlying *Conn
func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
return tx.ExecEx(context.Background(), sql, nil, arguments...) return tx.conn.Exec(ctx, sql, arguments...)
}
// ExecEx delegates to the underlying *Conn
func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
if tx.status != TxStatusInProgress {
return nil, ErrTxClosed
}
return tx.conn.ExecEx(ctx, sql, options, arguments...)
} }
// Prepare delegates to the underlying *Conn // Prepare delegates to the underlying *Conn

View File

@ -35,7 +35,7 @@ func TestTransactionSuccessfulCommit(t *testing.T) {
t.Fatalf("conn.Begin failed: %v", err) t.Fatalf("conn.Begin failed: %v", err)
} }
_, err = tx.Exec("insert into foo(id) values (1)") _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
if err != nil { if err != nil {
t.Fatalf("tx.Exec failed: %v", err) t.Fatalf("tx.Exec failed: %v", err)
} }
@ -77,12 +77,12 @@ func TestTxCommitWhenTxBroken(t *testing.T) {
t.Fatalf("conn.Begin failed: %v", err) t.Fatalf("conn.Begin failed: %v", err)
} }
if _, err := tx.Exec("insert into foo(id) values (1)"); err != nil { if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
t.Fatalf("tx.Exec failed: %v", err) t.Fatalf("tx.Exec failed: %v", err)
} }
// Purposely break transaction // Purposely break transaction
if _, err := tx.Exec("syntax error"); err == nil { if _, err := tx.Exec(context.Background(), "syntax error"); err == nil {
t.Fatal("Unexpected success") t.Fatal("Unexpected success")
} }
@ -107,12 +107,12 @@ func TestTxCommitSerializationFailure(t *testing.T) {
pool := createConnPool(t, 5) pool := createConnPool(t, 5)
defer pool.Close() defer pool.Close()
pool.Exec(`drop table if exists tx_serializable_sums`) pool.Exec(context.Background(), `drop table if exists tx_serializable_sums`)
_, err := pool.Exec(`create table tx_serializable_sums(num integer);`) _, err := pool.Exec(context.Background(), `create table tx_serializable_sums(num integer);`)
if err != nil { if err != nil {
t.Fatalf("Unable to create temporary table: %v", err) t.Fatalf("Unable to create temporary table: %v", err)
} }
defer pool.Exec(`drop table tx_serializable_sums`) defer pool.Exec(context.Background(), `drop table tx_serializable_sums`)
tx1, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) tx1, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
if err != nil { if err != nil {
@ -126,12 +126,12 @@ func TestTxCommitSerializationFailure(t *testing.T) {
} }
defer tx2.Rollback() defer tx2.Rollback()
_, err = tx1.Exec(`insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`) _, err = tx1.Exec(context.Background(), `insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`)
if err != nil { if err != nil {
t.Fatalf("Exec failed: %v", err) t.Fatalf("Exec failed: %v", err)
} }
_, err = tx2.Exec(`insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`) _, err = tx2.Exec(context.Background(), `insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`)
if err != nil { if err != nil {
t.Fatalf("Exec failed: %v", err) t.Fatalf("Exec failed: %v", err)
} }
@ -169,7 +169,7 @@ func TestTransactionSuccessfulRollback(t *testing.T) {
t.Fatalf("conn.Begin failed: %v", err) t.Fatalf("conn.Begin failed: %v", err)
} }
_, err = tx.Exec("insert into foo(id) values (1)") _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
if err != nil { if err != nil {
t.Fatalf("tx.Exec failed: %v", err) t.Fatalf("tx.Exec failed: %v", err)
} }
@ -373,12 +373,12 @@ func TestTxStatusErrorInTransactions(t *testing.T) {
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status) t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status)
} }
_, err = tx.Exec("savepoint s") _, err = tx.Exec(context.Background(), "savepoint s")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = tx.Exec("syntax error") _, err = tx.Exec(context.Background(), "syntax error")
if err == nil { if err == nil {
t.Fatal("expected an error but did not get one") t.Fatal("expected an error but did not get one")
} }
@ -387,7 +387,7 @@ func TestTxStatusErrorInTransactions(t *testing.T) {
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInFailure, status) t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInFailure, status)
} }
_, err = tx.Exec("rollback to s") _, err = tx.Exec(context.Background(), "rollback to s")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -417,7 +417,7 @@ func TestTxErr(t *testing.T) {
} }
// Purposely break transaction // Purposely break transaction
if _, err := tx.Exec("syntax error"); err == nil { if _, err := tx.Exec(context.Background(), "syntax error"); err == nil {
t.Fatal("Unexpected success") t.Fatal("Unexpected success")
} }