mirror of https://github.com/jackc/pgx.git
Remove ExecEx
parent
12857ad05b
commit
89c3d8af5d
|
@ -395,7 +395,7 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) {
|
||||||
|
|
||||||
for src.Next() {
|
for src.Next() {
|
||||||
values, _ := src.Values()
|
values, _ := src.Values()
|
||||||
if _, err = tx.Exec("insert_t", values...); err != nil {
|
if _, err = tx.Exec(context.Background(), "insert_t", values...); err != nil {
|
||||||
b.Fatalf("Exec unexpectedly failed with: %v", err)
|
b.Fatalf("Exec unexpectedly failed with: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -457,7 +457,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc
|
||||||
rowsThisInsert++
|
rowsThisInsert++
|
||||||
|
|
||||||
if rowsThisInsert == maxRowsPerInsert {
|
if rowsThisInsert == maxRowsPerInsert {
|
||||||
_, err := tx.Exec(sqlBuf.String(), args...)
|
_, err := tx.Exec(context.Background(), sqlBuf.String(), args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -468,7 +468,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc
|
||||||
}
|
}
|
||||||
|
|
||||||
if rowsThisInsert > 0 {
|
if rowsThisInsert > 0 {
|
||||||
_, err := tx.Exec(sqlBuf.String(), args...)
|
_, err := tx.Exec(context.Background(), sqlBuf.String(), args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
135
conn.go
135
conn.go
|
@ -1080,139 +1080,10 @@ func (c *Conn) cancelQuery() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Ping(ctx context.Context) error {
|
func (c *Conn) Ping(ctx context.Context) error {
|
||||||
_, err := c.ExecEx(ctx, ";", nil)
|
_, err := c.Exec(ctx, ";", nil)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (pgconn.CommandTag, error) {
|
|
||||||
c.lastStmtSent = false
|
|
||||||
err := c.waitForPreviousCancelQuery(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.lock(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer c.unlock()
|
|
||||||
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
commandTag, err := c.execEx(ctx, sql, options, arguments...)
|
|
||||||
if err != nil {
|
|
||||||
if c.shouldLog(LogLevelError) {
|
|
||||||
c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
|
|
||||||
}
|
|
||||||
return commandTag, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.shouldLog(LogLevelInfo) {
|
|
||||||
endTime := time.Now()
|
|
||||||
c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
|
|
||||||
}
|
|
||||||
|
|
||||||
return commandTag, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
|
||||||
err = c.initContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = c.termContext(err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
|
|
||||||
c.lastStmtSent = true
|
|
||||||
err = c.sanitizeAndSendSimpleQuery(sql, arguments...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else if options != nil && len(options.ParameterOIDs) > 0 {
|
|
||||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
buf = appendSync(buf)
|
|
||||||
|
|
||||||
n, err := c.pgConn.Conn().Write(buf)
|
|
||||||
c.lastStmtSent = true
|
|
||||||
if err != nil && fatalWriteErr(n, err) {
|
|
||||||
c.die(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.pendingReadyForQueryCount++
|
|
||||||
} else {
|
|
||||||
if len(arguments) > 0 {
|
|
||||||
ps, ok := c.preparedStatements[sql]
|
|
||||||
if !ok {
|
|
||||||
var err error
|
|
||||||
ps, err = c.prepareEx("", sql, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.lastStmtSent = true
|
|
||||||
err = c.sendPreparedQuery(ps, arguments...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
c.lastStmtSent = true
|
|
||||||
if err = c.sendQuery(sql, arguments...); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var softErr error
|
|
||||||
|
|
||||||
for {
|
|
||||||
msg, err := c.rxMsg()
|
|
||||||
if err != nil {
|
|
||||||
return commandTag, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case *pgproto3.ReadyForQuery:
|
|
||||||
c.rxReadyForQuery(msg)
|
|
||||||
return commandTag, softErr
|
|
||||||
case *pgproto3.CommandComplete:
|
|
||||||
commandTag = pgconn.CommandTag(msg.CommandTag)
|
|
||||||
default:
|
|
||||||
if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
|
|
||||||
softErr = e
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) {
|
|
||||||
if len(arguments) != len(options.ParameterOIDs) {
|
|
||||||
return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(options.ParameterOIDs) > 65535 {
|
|
||||||
return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs))
|
|
||||||
}
|
|
||||||
|
|
||||||
buf = appendParse(buf, "", sql, options.ParameterOIDs)
|
|
||||||
buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOIDs, arguments, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
buf = appendExecute(buf, "", 0)
|
|
||||||
|
|
||||||
return buf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) initContext(ctx context.Context) error {
|
func (c *Conn) initContext(ctx context.Context) error {
|
||||||
if c.ctxInProgress {
|
if c.ctxInProgress {
|
||||||
return errors.New("ctx already in progress")
|
return errors.New("ctx already in progress")
|
||||||
|
@ -1399,6 +1270,10 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(psd.ParamOIDs) != len(arguments) {
|
||||||
|
return nil, errors.Errorf("expected %d arguments, got %d", len(psd.ParamOIDs), len(arguments))
|
||||||
|
}
|
||||||
|
|
||||||
ps := &PreparedStatement{
|
ps := &PreparedStatement{
|
||||||
Name: psd.Name,
|
Name: psd.Name,
|
||||||
SQL: psd.SQL,
|
SQL: psd.SQL,
|
||||||
|
|
14
conn_pool.go
14
conn_pool.go
|
@ -353,24 +353,14 @@ func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec acquires a connection, delegates the call to that connection, and releases the connection
|
// Exec acquires a connection, delegates the call to that connection, and releases the connection
|
||||||
func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
func (p *ConnPool) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||||
var c *Conn
|
var c *Conn
|
||||||
if c, err = p.Acquire(); err != nil {
|
if c, err = p.Acquire(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer p.Release(c)
|
defer p.Release(c)
|
||||||
|
|
||||||
return c.Exec(context.TODO(), sql, arguments...)
|
return c.Exec(ctx, sql, arguments...)
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
|
||||||
var c *Conn
|
|
||||||
if c, err = p.Acquire(); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer p.Release(c)
|
|
||||||
|
|
||||||
return c.ExecEx(ctx, sql, options, arguments...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query acquires a connection and delegates the call to that connection. When
|
// Query acquires a connection and delegates the call to that connection. When
|
||||||
|
|
|
@ -801,7 +801,7 @@ func TestConnPoolQueryConcurrentLoad(t *testing.T) {
|
||||||
t.Error("Select called onDataRow wrong number of times")
|
t.Error("Select called onDataRow wrong number of times")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = pool.Exec("--;")
|
_, err = pool.Exec(context.Background(), "--;")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("pool.Exec failed: %v", err)
|
t.Fatalf("pool.Exec failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -841,7 +841,7 @@ func TestConnPoolExec(t *testing.T) {
|
||||||
pool := createConnPool(t, 2)
|
pool := createConnPool(t, 2)
|
||||||
defer pool.Close()
|
defer pool.Close()
|
||||||
|
|
||||||
results, err := pool.Exec("create temporary table foo(id integer primary key);")
|
results, err := pool.Exec(context.Background(), "create temporary table foo(id integer primary key);")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error from pool.Exec: %v", err)
|
t.Fatalf("Unexpected error from pool.Exec: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -849,7 +849,7 @@ func TestConnPoolExec(t *testing.T) {
|
||||||
t.Errorf("Unexpected results from Exec: %v", results)
|
t.Errorf("Unexpected results from Exec: %v", results)
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err = pool.Exec("insert into foo(id) values($1)", 1)
|
results, err = pool.Exec(context.Background(), "insert into foo(id) values($1)", 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error from pool.Exec: %v", err)
|
t.Fatalf("Unexpected error from pool.Exec: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -857,7 +857,7 @@ func TestConnPoolExec(t *testing.T) {
|
||||||
t.Errorf("Unexpected results from Exec: %v", results)
|
t.Errorf("Unexpected results from Exec: %v", results)
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err = pool.Exec("drop table foo;")
|
results, err = pool.Exec(context.Background(), "drop table foo;")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error from pool.Exec: %v", err)
|
t.Fatalf("Unexpected error from pool.Exec: %v", err)
|
||||||
}
|
}
|
||||||
|
|
168
conn_test.go
168
conn_test.go
|
@ -177,7 +177,7 @@ func TestExecFailureWithArguments(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecExContextWithoutCancelation(t *testing.T) {
|
func TestExecContextWithoutCancelation(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
@ -186,7 +186,7 @@ func TestExecExContextWithoutCancelation(t *testing.T) {
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(id integer primary key);", nil)
|
commandTag, err := conn.Exec(ctx, "create temporary table foo(id integer primary key);")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -198,7 +198,7 @@ func TestExecExContextWithoutCancelation(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecExContextFailureWithoutCancelation(t *testing.T) {
|
func TestExecContextFailureWithoutCancelation(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
@ -207,7 +207,7 @@ func TestExecExContextFailureWithoutCancelation(t *testing.T) {
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil {
|
if _, err := conn.Exec(ctx, "selct;"); err == nil {
|
||||||
t.Fatal("Expected SQL syntax error")
|
t.Fatal("Expected SQL syntax error")
|
||||||
}
|
}
|
||||||
if !conn.LastStmtSent() {
|
if !conn.LastStmtSent() {
|
||||||
|
@ -224,7 +224,7 @@ func TestExecExContextFailureWithoutCancelation(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecExContextFailureWithoutCancelationWithArguments(t *testing.T) {
|
func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
@ -233,7 +233,7 @@ func TestExecExContextFailureWithoutCancelationWithArguments(t *testing.T) {
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
if _, err := conn.ExecEx(ctx, "selct $1;", nil, 1); err == nil {
|
if _, err := conn.Exec(ctx, "selct $1;", 1); err == nil {
|
||||||
t.Fatal("Expected SQL syntax error")
|
t.Fatal("Expected SQL syntax error")
|
||||||
}
|
}
|
||||||
if conn.LastStmtSent() {
|
if conn.LastStmtSent() {
|
||||||
|
@ -241,7 +241,7 @@ func TestExecExContextFailureWithoutCancelationWithArguments(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecExContextCancelationCancelsQuery(t *testing.T) {
|
func TestExecContextCancelationCancelsQuery(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
@ -253,7 +253,7 @@ func TestExecExContextCancelationCancelsQuery(t *testing.T) {
|
||||||
cancelFunc()
|
cancelFunc()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err := conn.ExecEx(ctx, "select pg_sleep(60)", nil)
|
_, err := conn.Exec(ctx, "select pg_sleep(60)")
|
||||||
if err != context.Canceled {
|
if err != context.Canceled {
|
||||||
t.Fatalf("Expected context.Canceled err, got %v", err)
|
t.Fatalf("Expected context.Canceled err, got %v", err)
|
||||||
}
|
}
|
||||||
|
@ -278,7 +278,7 @@ func TestExecFailureCloseBefore(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecExExtendedProtocol(t *testing.T) {
|
func TestExecExtendedProtocol(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
@ -287,18 +287,17 @@ func TestExecExExtendedProtocol(t *testing.T) {
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
|
commandTag, err := conn.Exec(ctx, "create temporary table foo(name varchar primary key);")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if string(commandTag) != "CREATE TABLE" {
|
if string(commandTag) != "CREATE TABLE" {
|
||||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
||||||
}
|
}
|
||||||
|
|
||||||
commandTag, err = conn.ExecEx(
|
commandTag, err = conn.Exec(
|
||||||
ctx,
|
ctx,
|
||||||
"insert into foo(name) values($1);",
|
"insert into foo(name) values($1);",
|
||||||
nil,
|
|
||||||
"bar",
|
"bar",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -311,119 +310,42 @@ func TestExecExExtendedProtocol(t *testing.T) {
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecExSimpleProtocol(t *testing.T) {
|
func TestExecSimpleProtocol(t *testing.T) {
|
||||||
t.Parallel()
|
t.Skip("TODO when with simple protocol supported in connection")
|
||||||
|
// t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
// conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
defer closeConn(t, conn)
|
// defer closeConn(t, conn)
|
||||||
|
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
// ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
defer cancelFunc()
|
// defer cancelFunc()
|
||||||
|
|
||||||
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
|
// commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
t.Fatal(err)
|
// t.Fatal(err)
|
||||||
}
|
// }
|
||||||
if string(commandTag) != "CREATE TABLE" {
|
// if string(commandTag) != "CREATE TABLE" {
|
||||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
// t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||||
}
|
// }
|
||||||
if !conn.LastStmtSent() {
|
// if !conn.LastStmtSent() {
|
||||||
t.Error("Expected LastStmtSent to return true")
|
// t.Error("Expected LastStmtSent to return true")
|
||||||
}
|
// }
|
||||||
|
|
||||||
commandTag, err = conn.ExecEx(
|
// commandTag, err = conn.ExecEx(
|
||||||
ctx,
|
// ctx,
|
||||||
"insert into foo(name) values($1);",
|
// "insert into foo(name) values($1);",
|
||||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
// &pgx.QueryExOptions{SimpleProtocol: true},
|
||||||
"bar'; drop table foo;--",
|
// "bar'; drop table foo;--",
|
||||||
)
|
// )
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
t.Fatal(err)
|
// t.Fatal(err)
|
||||||
}
|
// }
|
||||||
if string(commandTag) != "INSERT 0 1" {
|
// if string(commandTag) != "INSERT 0 1" {
|
||||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
// t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||||
}
|
// }
|
||||||
if !conn.LastStmtSent() {
|
// if !conn.LastStmtSent() {
|
||||||
t.Error("Expected LastStmtSent to return true")
|
// t.Error("Expected LastStmtSent to return true")
|
||||||
}
|
// }
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
|
|
||||||
|
|
||||||
commandTag, err := conn.ExecEx(
|
|
||||||
context.Background(),
|
|
||||||
"insert into foo(name) values($1);",
|
|
||||||
&pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.VarcharOID}},
|
|
||||||
"bar'; drop table foo;--",
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if string(commandTag) != "INSERT 0 1" {
|
|
||||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
|
||||||
}
|
|
||||||
if !conn.LastStmtSent() {
|
|
||||||
t.Error("Expected LastStmtSent to return true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
|
|
||||||
|
|
||||||
_, err := conn.ExecEx(
|
|
||||||
context.Background(),
|
|
||||||
"insert into foo(name) values($1);",
|
|
||||||
&pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.Int4OID}},
|
|
||||||
"bar'; drop table foo;--",
|
|
||||||
)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error but got none")
|
|
||||||
}
|
|
||||||
if !conn.LastStmtSent() {
|
|
||||||
t.Error("Expected LastStmtSent to return true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
|
|
||||||
|
|
||||||
var s string
|
|
||||||
err := conn.QueryRow("insert into foo(name) values('baz') returning name;").Scan(&s)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Executing query failed: %v", err)
|
|
||||||
}
|
|
||||||
if s != "baz" {
|
|
||||||
t.Errorf("Query did not return expected value: %v", s)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = conn.ExecEx(
|
|
||||||
context.Background(),
|
|
||||||
"insert into foo(name) values($1);",
|
|
||||||
&pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.Int4OID}},
|
|
||||||
"bar'; drop table foo;--",
|
|
||||||
)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error but got none")
|
|
||||||
}
|
|
||||||
if !conn.LastStmtSent() {
|
|
||||||
t.Error("Expected LastStmtSent to return true")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecExFailureCloseBefore(t *testing.T) {
|
func TestExecExFailureCloseBefore(t *testing.T) {
|
||||||
|
@ -432,7 +354,7 @@ func TestExecExFailureCloseBefore(t *testing.T) {
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
closeConn(t, conn)
|
closeConn(t, conn)
|
||||||
|
|
||||||
if _, err := conn.ExecEx(context.Background(), "select 1", nil); err == nil {
|
if _, err := conn.Exec(context.Background(), "select 1", nil); err == nil {
|
||||||
t.Fatal("Expected network error")
|
t.Fatal("Expected network error")
|
||||||
}
|
}
|
||||||
if conn.LastStmtSent() {
|
if conn.LastStmtSent() {
|
||||||
|
|
|
@ -17,7 +17,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type execer interface {
|
type execer interface {
|
||||||
Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
|
Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
|
||||||
}
|
}
|
||||||
type queryer interface {
|
type queryer interface {
|
||||||
Query(sql string, args ...interface{}) (*pgx.Rows, error)
|
Query(sql string, args ...interface{}) (*pgx.Rows, error)
|
||||||
|
@ -102,7 +102,7 @@ func TestStressConnPool(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupStressDB(t *testing.T, pool *pgx.ConnPool) {
|
func setupStressDB(t *testing.T, pool *pgx.ConnPool) {
|
||||||
_, err := pool.Exec(`
|
_, err := pool.Exec(context.Background(), `
|
||||||
drop table if exists widgets;
|
drop table if exists widgets;
|
||||||
create table widgets(
|
create table widgets(
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
|
@ -121,7 +121,7 @@ func insertUnprepared(e execer, actionNum int) error {
|
||||||
insert into widgets(name, description, creation_time)
|
insert into widgets(name, description, creation_time)
|
||||||
values($1, $2, $3)`
|
values($1, $2, $3)`
|
||||||
|
|
||||||
_, err := e.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now())
|
_, err := e.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,7 +198,7 @@ func queryErrorWhileReturningRows(q queryer, actionNum int) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func notify(pool *pgx.ConnPool, actionNum int) error {
|
func notify(pool *pgx.ConnPool, actionNum int) error {
|
||||||
_, err := pool.Exec("notify stress")
|
_, err := pool.Exec(context.Background(), "notify stress")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -254,7 +254,7 @@ func txInsertRollback(pool *pgx.ConnPool, actionNum int) error {
|
||||||
insert into widgets(name, description, creation_time)
|
insert into widgets(name, description, creation_time)
|
||||||
values($1, $2, $3)`
|
values($1, $2, $3)`
|
||||||
|
|
||||||
_, err = tx.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now())
|
_, err = tx.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -272,7 +272,7 @@ func txInsertCommit(pool *pgx.ConnPool, actionNum int) error {
|
||||||
insert into widgets(name, description, creation_time)
|
insert into widgets(name, description, creation_time)
|
||||||
values($1, $2, $3)`
|
values($1, $2, $3)`
|
||||||
|
|
||||||
_, err = tx.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now())
|
_, err = tx.Exec(context.Background(), sql, fake.ProductName(), fake.Sentences(), time.Now())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return err
|
return err
|
||||||
|
@ -352,7 +352,7 @@ func canceledExecExContext(pool *pgx.ConnPool, actionNum int) error {
|
||||||
cancelFunc()
|
cancelFunc()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err := pool.ExecEx(ctx, "select pg_sleep(2)", nil)
|
_, err := pool.Exec(ctx, "select pg_sleep(2)")
|
||||||
if err != context.Canceled {
|
if err != context.Canceled {
|
||||||
return errors.Errorf("Expected context.Canceled error, got %v", err)
|
return errors.Errorf("Expected context.Canceled error, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
19
tx.go
19
tx.go
|
@ -90,7 +90,7 @@ func (c *Conn) Begin() (*Tx, error) {
|
||||||
// mode. Unlike database/sql, the context only affects the begin command. i.e.
|
// mode. Unlike database/sql, the context only affects the begin command. i.e.
|
||||||
// there is no auto-rollback on context cancelation.
|
// there is no auto-rollback on context cancelation.
|
||||||
func (c *Conn) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) {
|
func (c *Conn) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) {
|
||||||
_, err := c.ExecEx(ctx, txOptions.beginSQL(), nil)
|
_, err := c.Exec(ctx, txOptions.beginSQL())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// begin should never fail unless there is an underlying connection issue or
|
// begin should never fail unless there is an underlying connection issue or
|
||||||
// a context timeout. In either case, the connection is possibly broken.
|
// a context timeout. In either case, the connection is possibly broken.
|
||||||
|
@ -123,7 +123,7 @@ func (tx *Tx) CommitEx(ctx context.Context) error {
|
||||||
return ErrTxClosed
|
return ErrTxClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
commandTag, err := tx.conn.ExecEx(ctx, "commit", nil)
|
commandTag, err := tx.conn.Exec(ctx, "commit")
|
||||||
if err == nil && string(commandTag) == "COMMIT" {
|
if err == nil && string(commandTag) == "COMMIT" {
|
||||||
tx.status = TxStatusCommitSuccess
|
tx.status = TxStatusCommitSuccess
|
||||||
} else if err == nil && string(commandTag) == "ROLLBACK" {
|
} else if err == nil && string(commandTag) == "ROLLBACK" {
|
||||||
|
@ -159,7 +159,7 @@ func (tx *Tx) RollbackEx(ctx context.Context) error {
|
||||||
return ErrTxClosed
|
return ErrTxClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
_, tx.err = tx.conn.ExecEx(ctx, "rollback", nil)
|
_, tx.err = tx.conn.Exec(ctx, "rollback")
|
||||||
if tx.err == nil {
|
if tx.err == nil {
|
||||||
tx.status = TxStatusRollbackSuccess
|
tx.status = TxStatusRollbackSuccess
|
||||||
} else {
|
} else {
|
||||||
|
@ -176,17 +176,8 @@ func (tx *Tx) RollbackEx(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec delegates to the underlying *Conn
|
// Exec delegates to the underlying *Conn
|
||||||
func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||||
return tx.ExecEx(context.Background(), sql, nil, arguments...)
|
return tx.conn.Exec(ctx, sql, arguments...)
|
||||||
}
|
|
||||||
|
|
||||||
// ExecEx delegates to the underlying *Conn
|
|
||||||
func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
|
||||||
if tx.status != TxStatusInProgress {
|
|
||||||
return nil, ErrTxClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx.conn.ExecEx(ctx, sql, options, arguments...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare delegates to the underlying *Conn
|
// Prepare delegates to the underlying *Conn
|
||||||
|
|
26
tx_test.go
26
tx_test.go
|
@ -35,7 +35,7 @@ func TestTransactionSuccessfulCommit(t *testing.T) {
|
||||||
t.Fatalf("conn.Begin failed: %v", err)
|
t.Fatalf("conn.Begin failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec("insert into foo(id) values (1)")
|
_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("tx.Exec failed: %v", err)
|
t.Fatalf("tx.Exec failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -77,12 +77,12 @@ func TestTxCommitWhenTxBroken(t *testing.T) {
|
||||||
t.Fatalf("conn.Begin failed: %v", err)
|
t.Fatalf("conn.Begin failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := tx.Exec("insert into foo(id) values (1)"); err != nil {
|
if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
|
||||||
t.Fatalf("tx.Exec failed: %v", err)
|
t.Fatalf("tx.Exec failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Purposely break transaction
|
// Purposely break transaction
|
||||||
if _, err := tx.Exec("syntax error"); err == nil {
|
if _, err := tx.Exec(context.Background(), "syntax error"); err == nil {
|
||||||
t.Fatal("Unexpected success")
|
t.Fatal("Unexpected success")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,12 +107,12 @@ func TestTxCommitSerializationFailure(t *testing.T) {
|
||||||
pool := createConnPool(t, 5)
|
pool := createConnPool(t, 5)
|
||||||
defer pool.Close()
|
defer pool.Close()
|
||||||
|
|
||||||
pool.Exec(`drop table if exists tx_serializable_sums`)
|
pool.Exec(context.Background(), `drop table if exists tx_serializable_sums`)
|
||||||
_, err := pool.Exec(`create table tx_serializable_sums(num integer);`)
|
_, err := pool.Exec(context.Background(), `create table tx_serializable_sums(num integer);`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to create temporary table: %v", err)
|
t.Fatalf("Unable to create temporary table: %v", err)
|
||||||
}
|
}
|
||||||
defer pool.Exec(`drop table tx_serializable_sums`)
|
defer pool.Exec(context.Background(), `drop table tx_serializable_sums`)
|
||||||
|
|
||||||
tx1, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
|
tx1, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -126,12 +126,12 @@ func TestTxCommitSerializationFailure(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer tx2.Rollback()
|
defer tx2.Rollback()
|
||||||
|
|
||||||
_, err = tx1.Exec(`insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`)
|
_, err = tx1.Exec(context.Background(), `insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Exec failed: %v", err)
|
t.Fatalf("Exec failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx2.Exec(`insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`)
|
_, err = tx2.Exec(context.Background(), `insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Exec failed: %v", err)
|
t.Fatalf("Exec failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -169,7 +169,7 @@ func TestTransactionSuccessfulRollback(t *testing.T) {
|
||||||
t.Fatalf("conn.Begin failed: %v", err)
|
t.Fatalf("conn.Begin failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec("insert into foo(id) values (1)")
|
_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("tx.Exec failed: %v", err)
|
t.Fatalf("tx.Exec failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -373,12 +373,12 @@ func TestTxStatusErrorInTransactions(t *testing.T) {
|
||||||
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status)
|
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec("savepoint s")
|
_, err = tx.Exec(context.Background(), "savepoint s")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec("syntax error")
|
_, err = tx.Exec(context.Background(), "syntax error")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected an error but did not get one")
|
t.Fatal("expected an error but did not get one")
|
||||||
}
|
}
|
||||||
|
@ -387,7 +387,7 @@ func TestTxStatusErrorInTransactions(t *testing.T) {
|
||||||
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInFailure, status)
|
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInFailure, status)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec("rollback to s")
|
_, err = tx.Exec(context.Background(), "rollback to s")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -417,7 +417,7 @@ func TestTxErr(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Purposely break transaction
|
// Purposely break transaction
|
||||||
if _, err := tx.Exec("syntax error"); err == nil {
|
if _, err := tx.Exec(context.Background(), "syntax error"); err == nil {
|
||||||
t.Fatal("Unexpected success")
|
t.Fatal("Unexpected success")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue