mirror of https://github.com/jackc/pgx.git
Remove ExecEx
parent
12857ad05b
commit
89c3d8af5d
|
@ -395,7 +395,7 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) {
|
|||
|
||||
for src.Next() {
|
||||
values, _ := src.Values()
|
||||
if _, err = tx.Exec("insert_t", values...); err != nil {
|
||||
if _, err = tx.Exec(context.Background(), "insert_t", values...); err != nil {
|
||||
b.Fatalf("Exec unexpectedly failed with: %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -457,7 +457,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc
|
|||
rowsThisInsert++
|
||||
|
||||
if rowsThisInsert == maxRowsPerInsert {
|
||||
_, err := tx.Exec(sqlBuf.String(), args...)
|
||||
_, err := tx.Exec(context.Background(), sqlBuf.String(), args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -468,7 +468,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc
|
|||
}
|
||||
|
||||
if rowsThisInsert > 0 {
|
||||
_, err := tx.Exec(sqlBuf.String(), args...)
|
||||
_, err := tx.Exec(context.Background(), sqlBuf.String(), args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
|
135
conn.go
135
conn.go
|
@ -1080,139 +1080,10 @@ func (c *Conn) cancelQuery() {
|
|||
}
|
||||
|
||||
func (c *Conn) Ping(ctx context.Context) error {
|
||||
_, err := c.ExecEx(ctx, ";", nil)
|
||||
_, err := c.Exec(ctx, ";", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (pgconn.CommandTag, error) {
|
||||
c.lastStmtSent = false
|
||||
err := c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.lock(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer c.unlock()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
commandTag, err := c.execEx(ctx, sql, options, arguments...)
|
||||
if err != nil {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
|
||||
}
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
endTime := time.Now()
|
||||
c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
|
||||
}
|
||||
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
}()
|
||||
|
||||
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
|
||||
c.lastStmtSent = true
|
||||
err = c.sanitizeAndSendSimpleQuery(sql, arguments...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if options != nil && len(options.ParameterOIDs) > 0 {
|
||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf = appendSync(buf)
|
||||
|
||||
n, err := c.pgConn.Conn().Write(buf)
|
||||
c.lastStmtSent = true
|
||||
if err != nil && fatalWriteErr(n, err) {
|
||||
c.die(err)
|
||||
return nil, err
|
||||
}
|
||||
c.pendingReadyForQueryCount++
|
||||
} else {
|
||||
if len(arguments) > 0 {
|
||||
ps, ok := c.preparedStatements[sql]
|
||||
if !ok {
|
||||
var err error
|
||||
ps, err = c.prepareEx("", sql, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
c.lastStmtSent = true
|
||||
err = c.sendPreparedQuery(ps, arguments...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
c.lastStmtSent = true
|
||||
if err = c.sendQuery(sql, arguments...); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var softErr error
|
||||
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
return commandTag, softErr
|
||||
case *pgproto3.CommandComplete:
|
||||
commandTag = pgconn.CommandTag(msg.CommandTag)
|
||||
default:
|
||||
if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
|
||||
softErr = e
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) {
|
||||
if len(arguments) != len(options.ParameterOIDs) {
|
||||
return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs))
|
||||
}
|
||||
|
||||
if len(options.ParameterOIDs) > 65535 {
|
||||
return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs))
|
||||
}
|
||||
|
||||
buf = appendParse(buf, "", sql, options.ParameterOIDs)
|
||||
buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOIDs, arguments, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf = appendExecute(buf, "", 0)
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (c *Conn) initContext(ctx context.Context) error {
|
||||
if c.ctxInProgress {
|
||||
return errors.New("ctx already in progress")
|
||||
|
@ -1399,6 +1270,10 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if len(psd.ParamOIDs) != len(arguments) {
|
||||
return nil, errors.Errorf("expected %d arguments, got %d", len(psd.ParamOIDs), len(arguments))
|
||||
}
|
||||
|
||||
ps := &PreparedStatement{
|
||||
Name: psd.Name,
|
||||
SQL: psd.SQL,
|
||||
|
|
14
conn_pool.go
14
conn_pool.go
|
@ -353,24 +353,14 @@ func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
|
|||
}
|
||||
|
||||
// Exec acquires a connection, delegates the call to that connection, and releases the connection
|
||||
func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
func (p *ConnPool) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
var c *Conn
|
||||
if c, err = p.Acquire(); err != nil {
|
||||
return
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
return c.Exec(context.TODO(), sql, arguments...)
|
||||
}
|
||||
|
||||
func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
var c *Conn
|
||||
if c, err = p.Acquire(); err != nil {
|
||||
return
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
return c.ExecEx(ctx, sql, options, arguments...)
|
||||
return c.Exec(ctx, sql, arguments...)
|
||||
}
|
||||
|
||||
// Query acquires a connection and delegates the call to that connection. When
|
||||
|
|
|
@ -801,7 +801,7 @@ func TestConnPoolQueryConcurrentLoad(t *testing.T) {
|
|||
t.Error("Select called onDataRow wrong number of times")
|
||||
}
|
||||
|
||||
_, err = pool.Exec("--;")
|
||||
_, err = pool.Exec(context.Background(), "--;")
|
||||
if err != nil {
|
||||
t.Fatalf("pool.Exec failed: %v", err)
|
||||
}
|
||||
|
@ -841,7 +841,7 @@ func TestConnPoolExec(t *testing.T) {
|
|||
pool := createConnPool(t, 2)
|
||||
defer pool.Close()
|
||||
|
||||
results, err := pool.Exec("create temporary table foo(id integer primary key);")
|
||||
results, err := pool.Exec(context.Background(), "create temporary table foo(id integer primary key);")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error from pool.Exec: %v", err)
|
||||
}
|
||||
|
@ -849,7 +849,7 @@ func TestConnPoolExec(t *testing.T) {
|
|||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
|
||||
results, err = pool.Exec("insert into foo(id) values($1)", 1)
|
||||
results, err = pool.Exec(context.Background(), "insert into foo(id) values($1)", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error from pool.Exec: %v", err)
|
||||
}
|
||||
|
@ -857,7 +857,7 @@ func TestConnPoolExec(t *testing.T) {
|
|||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
|
||||
results, err = pool.Exec("drop table foo;")
|
||||
results, err = pool.Exec(context.Background(), "drop table foo;")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error from pool.Exec: %v", err)
|
||||
}
|
||||
|
|
168
conn_test.go
168
conn_test.go
|
@ -177,7 +177,7 @@ func TestExecFailureWithArguments(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExecExContextWithoutCancelation(t *testing.T) {
|
||||
func TestExecContextWithoutCancelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
@ -186,7 +186,7 @@ func TestExecExContextWithoutCancelation(t *testing.T) {
|
|||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(id integer primary key);", nil)
|
||||
commandTag, err := conn.Exec(ctx, "create temporary table foo(id integer primary key);")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -198,7 +198,7 @@ func TestExecExContextWithoutCancelation(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExecExContextFailureWithoutCancelation(t *testing.T) {
|
||||
func TestExecContextFailureWithoutCancelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
@ -207,7 +207,7 @@ func TestExecExContextFailureWithoutCancelation(t *testing.T) {
|
|||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil {
|
||||
if _, err := conn.Exec(ctx, "selct;"); err == nil {
|
||||
t.Fatal("Expected SQL syntax error")
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
|
@ -224,7 +224,7 @@ func TestExecExContextFailureWithoutCancelation(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExecExContextFailureWithoutCancelationWithArguments(t *testing.T) {
|
||||
func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
@ -233,7 +233,7 @@ func TestExecExContextFailureWithoutCancelationWithArguments(t *testing.T) {
|
|||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
if _, err := conn.ExecEx(ctx, "selct $1;", nil, 1); err == nil {
|
||||
if _, err := conn.Exec(ctx, "selct $1;", 1); err == nil {
|
||||
t.Fatal("Expected SQL syntax error")
|
||||
}
|
||||
if conn.LastStmtSent() {
|
||||
|
@ -241,7 +241,7 @@ func TestExecExContextFailureWithoutCancelationWithArguments(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExecExContextCancelationCancelsQuery(t *testing.T) {
|
||||
func TestExecContextCancelationCancelsQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
@ -253,7 +253,7 @@ func TestExecExContextCancelationCancelsQuery(t *testing.T) {
|
|||
cancelFunc()
|
||||
}()
|
||||
|
||||
_, err := conn.ExecEx(ctx, "select pg_sleep(60)", nil)
|
||||
_, err := conn.Exec(ctx, "select pg_sleep(60)")
|
||||
if err != context.Canceled {
|
||||
t.Fatalf("Expected context.Canceled err, got %v", err)
|
||||
}
|
||||
|
@ -278,7 +278,7 @@ func TestExecFailureCloseBefore(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExecExExtendedProtocol(t *testing.T) {
|
||||
func TestExecExtendedProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
@ -287,18 +287,17 @@ func TestExecExExtendedProtocol(t *testing.T) {
|
|||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
|
||||
commandTag, err := conn.Exec(ctx, "create temporary table foo(name varchar primary key);")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(commandTag) != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
||||
}
|
||||
|
||||
commandTag, err = conn.ExecEx(
|
||||
commandTag, err = conn.Exec(
|
||||
ctx,
|
||||
"insert into foo(name) values($1);",
|
||||
nil,
|
||||
"bar",
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -311,119 +310,42 @@ func TestExecExExtendedProtocol(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestExecExSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
func TestExecSimpleProtocol(t *testing.T) {
|
||||
t.Skip("TODO when with simple protocol supported in connection")
|
||||
// t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
// conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
// defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
// ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
// defer cancelFunc()
|
||||
|
||||
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(commandTag) != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
// commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// if string(commandTag) != "CREATE TABLE" {
|
||||
// t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
// }
|
||||
// if !conn.LastStmtSent() {
|
||||
// t.Error("Expected LastStmtSent to return true")
|
||||
// }
|
||||
|
||||
commandTag, err = conn.ExecEx(
|
||||
ctx,
|
||||
"insert into foo(name) values($1);",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
"bar'; drop table foo;--",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(commandTag) != "INSERT 0 1" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
|
||||
|
||||
commandTag, err := conn.ExecEx(
|
||||
context.Background(),
|
||||
"insert into foo(name) values($1);",
|
||||
&pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.VarcharOID}},
|
||||
"bar'; drop table foo;--",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(commandTag) != "INSERT 0 1" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
|
||||
|
||||
_, err := conn.ExecEx(
|
||||
context.Background(),
|
||||
"insert into foo(name) values($1);",
|
||||
&pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.Int4OID}},
|
||||
"bar'; drop table foo;--",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected error but got none")
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
|
||||
|
||||
var s string
|
||||
err := conn.QueryRow("insert into foo(name) values('baz') returning name;").Scan(&s)
|
||||
if err != nil {
|
||||
t.Errorf("Executing query failed: %v", err)
|
||||
}
|
||||
if s != "baz" {
|
||||
t.Errorf("Query did not return expected value: %v", s)
|
||||
}
|
||||
|
||||
_, err = conn.ExecEx(
|
||||
context.Background(),
|
||||
"insert into foo(name) values($1);",
|
||||
&pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.Int4OID}},
|
||||
"bar'; drop table foo;--",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected error but got none")
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
// commandTag, err = conn.ExecEx(
|
||||
// ctx,
|
||||
// "insert into foo(name) values($1);",
|
||||
// &pgx.QueryExOptions{SimpleProtocol: true},
|
||||
// "bar'; drop table foo;--",
|
||||
// )
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// if string(commandTag) != "INSERT 0 1" {
|
||||
// t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
// }
|
||||
// if !conn.LastStmtSent() {
|
||||
// t.Error("Expected LastStmtSent to return true")
|
||||
// }
|
||||
}
|
||||
|
||||
func TestExecExFailureCloseBefore(t *testing.T) {
|
||||
|
@ -432,7 +354,7 @@ func TestExecExFailureCloseBefore(t *testing.T) {
|
|||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
closeConn(t, conn)
|
||||
|
||||
if _, err := conn.ExecEx(context.Background(), "select 1", nil); err == nil {
|
||||
if _, err := conn.Exec(context.Background(), "select 1", nil); err == nil {
|
||||
t.Fatal("Expected network error")
|
||||
}
|
||||
if conn.LastStmtSent() {
|
||||
|
|
|
@ -17,7 +17,7 @@ import (
|
|||
)
|
||||
|
||||
type execer interface {
|
||||
Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
|
||||
Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
|
||||
}
|
||||
type queryer interface {
|
||||
Query(sql string, args ...interface{}) (*pgx.Rows, error)
|
||||
|
@ -102,7 +102,7 @@ func TestStressConnPool(t *testing.T) {
|
|||
}
|
||||
|
||||
func setupStressDB(t *testing.T, pool *pgx.ConnPool) {
|
||||
_, err := pool.Exec(`
|
||||
_, err := pool.Exec(context.Background(), `
|
||||
drop table if exists widgets;
|
||||
create table widgets(
|
||||
id serial primary key,
|
||||
|
@ -121,7 +121,7 @@ func insertUnprepared(e execer, actionNum int) error {
|
|||
insert into widgets(name, description, creation_time)
|
||||
values($1, $2, $3)`
|
||||
|
||||
_, err := e.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||
_, err := e.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -198,7 +198,7 @@ func queryErrorWhileReturningRows(q queryer, actionNum int) error {
|
|||
}
|
||||
|
||||
func notify(pool *pgx.ConnPool, actionNum int) error {
|
||||
_, err := pool.Exec("notify stress")
|
||||
_, err := pool.Exec(context.Background(), "notify stress")
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -254,7 +254,7 @@ func txInsertRollback(pool *pgx.ConnPool, actionNum int) error {
|
|||
insert into widgets(name, description, creation_time)
|
||||
values($1, $2, $3)`
|
||||
|
||||
_, err = tx.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||
_, err = tx.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -272,7 +272,7 @@ func txInsertCommit(pool *pgx.ConnPool, actionNum int) error {
|
|||
insert into widgets(name, description, creation_time)
|
||||
values($1, $2, $3)`
|
||||
|
||||
_, err = tx.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||
_, err = tx.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
|
@ -352,7 +352,7 @@ func canceledExecExContext(pool *pgx.ConnPool, actionNum int) error {
|
|||
cancelFunc()
|
||||
}()
|
||||
|
||||
_, err := pool.ExecEx(ctx, "select pg_sleep(2)", nil)
|
||||
_, err := pool.Exec(ctx, "select pg_sleep(2)")
|
||||
if err != context.Canceled {
|
||||
return errors.Errorf("Expected context.Canceled error, got %v", err)
|
||||
}
|
||||
|
|
19
tx.go
19
tx.go
|
@ -90,7 +90,7 @@ func (c *Conn) Begin() (*Tx, error) {
|
|||
// mode. Unlike database/sql, the context only affects the begin command. i.e.
|
||||
// there is no auto-rollback on context cancelation.
|
||||
func (c *Conn) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) {
|
||||
_, err := c.ExecEx(ctx, txOptions.beginSQL(), nil)
|
||||
_, err := c.Exec(ctx, txOptions.beginSQL())
|
||||
if err != nil {
|
||||
// begin should never fail unless there is an underlying connection issue or
|
||||
// a context timeout. In either case, the connection is possibly broken.
|
||||
|
@ -123,7 +123,7 @@ func (tx *Tx) CommitEx(ctx context.Context) error {
|
|||
return ErrTxClosed
|
||||
}
|
||||
|
||||
commandTag, err := tx.conn.ExecEx(ctx, "commit", nil)
|
||||
commandTag, err := tx.conn.Exec(ctx, "commit")
|
||||
if err == nil && string(commandTag) == "COMMIT" {
|
||||
tx.status = TxStatusCommitSuccess
|
||||
} else if err == nil && string(commandTag) == "ROLLBACK" {
|
||||
|
@ -159,7 +159,7 @@ func (tx *Tx) RollbackEx(ctx context.Context) error {
|
|||
return ErrTxClosed
|
||||
}
|
||||
|
||||
_, tx.err = tx.conn.ExecEx(ctx, "rollback", nil)
|
||||
_, tx.err = tx.conn.Exec(ctx, "rollback")
|
||||
if tx.err == nil {
|
||||
tx.status = TxStatusRollbackSuccess
|
||||
} else {
|
||||
|
@ -176,17 +176,8 @@ func (tx *Tx) RollbackEx(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// Exec delegates to the underlying *Conn
|
||||
func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
return tx.ExecEx(context.Background(), sql, nil, arguments...)
|
||||
}
|
||||
|
||||
// ExecEx delegates to the underlying *Conn
|
||||
func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
if tx.status != TxStatusInProgress {
|
||||
return nil, ErrTxClosed
|
||||
}
|
||||
|
||||
return tx.conn.ExecEx(ctx, sql, options, arguments...)
|
||||
func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
return tx.conn.Exec(ctx, sql, arguments...)
|
||||
}
|
||||
|
||||
// Prepare delegates to the underlying *Conn
|
||||
|
|
26
tx_test.go
26
tx_test.go
|
@ -35,7 +35,7 @@ func TestTransactionSuccessfulCommit(t *testing.T) {
|
|||
t.Fatalf("conn.Begin failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec("insert into foo(id) values (1)")
|
||||
_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||
if err != nil {
|
||||
t.Fatalf("tx.Exec failed: %v", err)
|
||||
}
|
||||
|
@ -77,12 +77,12 @@ func TestTxCommitWhenTxBroken(t *testing.T) {
|
|||
t.Fatalf("conn.Begin failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := tx.Exec("insert into foo(id) values (1)"); err != nil {
|
||||
if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
|
||||
t.Fatalf("tx.Exec failed: %v", err)
|
||||
}
|
||||
|
||||
// Purposely break transaction
|
||||
if _, err := tx.Exec("syntax error"); err == nil {
|
||||
if _, err := tx.Exec(context.Background(), "syntax error"); err == nil {
|
||||
t.Fatal("Unexpected success")
|
||||
}
|
||||
|
||||
|
@ -107,12 +107,12 @@ func TestTxCommitSerializationFailure(t *testing.T) {
|
|||
pool := createConnPool(t, 5)
|
||||
defer pool.Close()
|
||||
|
||||
pool.Exec(`drop table if exists tx_serializable_sums`)
|
||||
_, err := pool.Exec(`create table tx_serializable_sums(num integer);`)
|
||||
pool.Exec(context.Background(), `drop table if exists tx_serializable_sums`)
|
||||
_, err := pool.Exec(context.Background(), `create table tx_serializable_sums(num integer);`)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to create temporary table: %v", err)
|
||||
}
|
||||
defer pool.Exec(`drop table tx_serializable_sums`)
|
||||
defer pool.Exec(context.Background(), `drop table tx_serializable_sums`)
|
||||
|
||||
tx1, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
|
||||
if err != nil {
|
||||
|
@ -126,12 +126,12 @@ func TestTxCommitSerializationFailure(t *testing.T) {
|
|||
}
|
||||
defer tx2.Rollback()
|
||||
|
||||
_, err = tx1.Exec(`insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`)
|
||||
_, err = tx1.Exec(context.Background(), `insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`)
|
||||
if err != nil {
|
||||
t.Fatalf("Exec failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = tx2.Exec(`insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`)
|
||||
_, err = tx2.Exec(context.Background(), `insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`)
|
||||
if err != nil {
|
||||
t.Fatalf("Exec failed: %v", err)
|
||||
}
|
||||
|
@ -169,7 +169,7 @@ func TestTransactionSuccessfulRollback(t *testing.T) {
|
|||
t.Fatalf("conn.Begin failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec("insert into foo(id) values (1)")
|
||||
_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||
if err != nil {
|
||||
t.Fatalf("tx.Exec failed: %v", err)
|
||||
}
|
||||
|
@ -373,12 +373,12 @@ func TestTxStatusErrorInTransactions(t *testing.T) {
|
|||
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status)
|
||||
}
|
||||
|
||||
_, err = tx.Exec("savepoint s")
|
||||
_, err = tx.Exec(context.Background(), "savepoint s")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec("syntax error")
|
||||
_, err = tx.Exec(context.Background(), "syntax error")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error but did not get one")
|
||||
}
|
||||
|
@ -387,7 +387,7 @@ func TestTxStatusErrorInTransactions(t *testing.T) {
|
|||
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInFailure, status)
|
||||
}
|
||||
|
||||
_, err = tx.Exec("rollback to s")
|
||||
_, err = tx.Exec(context.Background(), "rollback to s")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -417,7 +417,7 @@ func TestTxErr(t *testing.T) {
|
|||
}
|
||||
|
||||
// Purposely break transaction
|
||||
if _, err := tx.Exec("syntax error"); err == nil {
|
||||
if _, err := tx.Exec(context.Background(), "syntax error"); err == nil {
|
||||
t.Fatal("Unexpected success")
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue