From b7e56b003a29a11fc5ee1a04962d39f21536c4da Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 17:12:20 -0500 Subject: [PATCH] Replace lastStmtSent with pgconn support --- conn.go | 18 --------------- conn_test.go | 64 +++++++++++++++++++-------------------------------- go.mod | 5 ++-- go.sum | 6 +++++ query.go | 2 -- query_test.go | 22 ++++++++---------- stdlib/sql.go | 8 ++++--- 7 files changed, 47 insertions(+), 78 deletions(-) diff --git a/conn.go b/conn.go index bb094f05..a633b349 100644 --- a/conn.go +++ b/conn.go @@ -45,8 +45,6 @@ type Conn struct { causeOfDeath error - lastStmtSent bool - doneChan chan struct{} closedChan chan error @@ -392,17 +390,6 @@ func connInfoFromRows(rows Rows, err error) (map[string]pgtype.OID, error) { return nameOIDs, err } -// LastStmtSent returns true if the last call to Query(Ex)/Exec(Ex) attempted to -// send the statement over the wire. Each call to a Query(Ex)/Exec(Ex) resets -// the value to false initially until the statement has been sent. This does -// NOT mean that the statement was successful or even received, it just means -// that a write was attempted and therefore it could have been executed. Calls -// to prepare a statement are ignored, only when the prepared statement is -// attempted to be executed will this return true. -func (c *Conn) LastStmtSent() bool { - return c.lastStmtSent -} - // PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the // PostgreSQL connection than pgx exposes. // @@ -413,8 +400,6 @@ func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn } // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced // positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { - c.lastStmtSent = false - startTime := time.Now() commandTag, err := c.exec(ctx, sql, arguments...) @@ -462,13 +447,11 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) ( } } - c.lastStmtSent = true result := c.pgConn.ExecPrepared(ctx, ps.Name, paramValues, paramFormats, resultFormats).Read() return result.CommandTag, result.Err } if len(arguments) == 0 { - c.lastStmtSent = true results, err := c.pgConn.Exec(ctx, sql).ReadAll() if err != nil { return nil, err @@ -529,7 +512,6 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) ( } } - c.lastStmtSent = true result := c.pgConn.ExecPrepared(ctx, psd.Name, paramValues, paramFormats, resultFormats).Read() return result.CommandTag, result.Err } diff --git a/conn_test.go b/conn_test.go index ba4f038d..28f78744 100644 --- a/conn_test.go +++ b/conn_test.go @@ -12,6 +12,7 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" "github.com/stretchr/testify/require" + errors "golang.org/x/xerrors" ) func TestCrateDBConnect(t *testing.T) { @@ -122,18 +123,12 @@ func TestExecFailure(t *testing.T) { if _, err := conn.Exec(context.Background(), "selct;"); err == nil { t.Fatal("Expected SQL syntax error") } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } rows, _ := conn.Query(context.Background(), "select 1") rows.Close() if rows.Err() != nil { t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err()) } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } } func TestExecFailureWithArguments(t *testing.T) { @@ -142,11 +137,12 @@ func TestExecFailureWithArguments(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - if _, err := conn.Exec(context.Background(), "selct $1;", 1); err == nil { + _, err := conn.Exec(context.Background(), "selct $1;", 1) + if err == nil { t.Fatal("Expected SQL syntax error") } - if conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return false") + if errors.Is(err, pgconn.ErrNoBytesSent) { + t.Error("Expected bytes to be sent to server") } } @@ -164,10 +160,10 @@ func TestExecContextWithoutCancelation(t *testing.T) { t.Fatal(err) } if string(commandTag) != "CREATE TABLE" { - t.Fatalf("Unexpected results from ExecEx: %v", commandTag) + t.Fatalf("Unexpected results from Exec: %v", commandTag) } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") + if errors.Is(err, pgconn.ErrNoBytesSent) { + t.Error("Expected bytes to be sent to server") } } @@ -180,11 +176,12 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - if _, err := conn.Exec(ctx, "selct;"); err == nil { + _, err := conn.Exec(ctx, "selct;") + if err == nil { t.Fatal("Expected SQL syntax error") } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") + if errors.Is(err, pgconn.ErrNoBytesSent) { + t.Error("Expected bytes to be sent to server") } rows, _ := conn.Query(context.Background(), "select 1") @@ -192,8 +189,8 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) { if rows.Err() != nil { t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err()) } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") + if errors.Is(err, pgconn.ErrNoBytesSent) { + t.Error("Expected bytes to be sent to server") } } @@ -206,26 +203,27 @@ func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - if _, err := conn.Exec(ctx, "selct $1;", 1); err == nil { + _, err := conn.Exec(ctx, "selct $1;", 1) + if err == nil { t.Fatal("Expected SQL syntax error") } - if conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return false") + if errors.Is(err, pgconn.ErrNoBytesSent) { + t.Error("Expected bytes to be sent to server") } } func TestExecFailureCloseBefore(t *testing.T) { - t.Skip("TODO: LastStmtSent needs to be ported / rewritten for pgconn") t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) closeConn(t, conn) - if _, err := conn.Exec(context.Background(), "select 1"); err == nil { + _, err := conn.Exec(context.Background(), "select 1") + if err == nil { t.Fatal("Expected network error") } - if conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return false") + if !errors.Is(err, pgconn.ErrNoBytesSent) { + t.Error("Expected no bytes to be sent to server") } } @@ -261,20 +259,6 @@ func TestExecExtendedProtocol(t *testing.T) { ensureConnValid(t, conn) } -func TestExecExFailureCloseBefore(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - closeConn(t, conn) - - if _, err := conn.Exec(context.Background(), "select 1", nil); err == nil { - t.Fatal("Expected network error") - } - if conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return false") - } -} - func TestPrepare(t *testing.T) { t.Parallel() @@ -478,7 +462,7 @@ func TestCatchSimultaneousConnectionQueries(t *testing.T) { defer rows1.Close() _, err = conn.Query(context.Background(), "select generate_series(1,$1)", 10) - if err != pgconn.ErrConnBusy { + if !errors.Is(err, pgconn.ErrConnBusy) { t.Fatalf("conn.Query should have failed with pgconn.ErrConnBusy, but it was %v", err) } } @@ -496,7 +480,7 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) { defer rows.Close() _, err = conn.Exec(context.Background(), "create temporary table foo(spice timestamp[])") - if err != pgconn.ErrConnBusy { + if !errors.Is(err, pgconn.ErrConnBusy) { t.Fatalf("conn.Exec should have failed with pgconn.ErrConnBusy, but it was %v", err) } } diff --git a/go.mod b/go.mod index 4a50dea2..9d0f0632 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,10 @@ go 1.12 require ( github.com/cockroachdb/apd v1.1.0 - github.com/jackc/pgconn v0.0.0-20190420161109-39e6ff5766bd + github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3 github.com/jackc/pgio v1.0.0 github.com/jackc/pgproto3 v1.1.0 - github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf + github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b github.com/kr/pretty v0.1.0 // indirect github.com/lib/pq v1.0.0 @@ -20,5 +20,6 @@ require ( go.uber.org/atomic v1.3.2 // indirect go.uber.org/multierr v1.1.0 // indirect go.uber.org/zap v1.9.1 + golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) diff --git a/go.sum b/go.sum index 913ccbdf..d87851c7 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ github.com/jackc/pgconn v0.0.0-20190419211655-3710e52a9a12 h1:PzGjcOqGl6npHTDt8y github.com/jackc/pgconn v0.0.0-20190419211655-3710e52a9a12/go.mod h1:UsnoyBN75lNxOeZXUT70J9xAvZffv2fxrxCrIPIH/Rk= github.com/jackc/pgconn v0.0.0-20190420161109-39e6ff5766bd h1:eSKDWtHcm6H/vELPrs6fh7bch3wBc2vUvqVnHw17+5c= github.com/jackc/pgconn v0.0.0-20190420161109-39e6ff5766bd/go.mod h1:UsnoyBN75lNxOeZXUT70J9xAvZffv2fxrxCrIPIH/Rk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3 h1:ZFYpB74Kq8xE9gmfxCmXD6QxZ27ja+j3HwGFc+YurhQ= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -26,6 +28,8 @@ github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf h1:wI8d/uq9/RfZOe6bKOpC4Skd4VgkTIGZqxmHu6IQGb8= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b h1:cIcUpcEP55F/QuZWEtXyqHoWk+IV4TBiLjtBkeq/Q1c= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= @@ -70,5 +74,7 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqY golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/query.go b/query.go index 5f4c9420..7cf654ed 100644 --- a/query.go +++ b/query.go @@ -286,7 +286,6 @@ type QueryResultFormats []int16 // Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is // allowed to ignore the error returned from Query and handle it in Rows. func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { - c.lastStmtSent = false // rows = c.getRows(sql, args) var resultFormats QueryResultFormats @@ -369,7 +368,6 @@ optionLoop: } } - c.lastStmtSent = true rows.resultReader = c.pgConn.ExecPrepared(ctx, ps.Name, paramValues, paramFormats, resultFormats) return rows, rows.err diff --git a/query_test.go b/query_test.go index 0e6a6070..c366c04e 100644 --- a/query_test.go +++ b/query_test.go @@ -18,6 +18,7 @@ import ( satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" + errors "golang.org/x/xerrors" ) func TestConnQueryScan(t *testing.T) { @@ -285,8 +286,8 @@ func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) { if err != nil { t.Fatalf("conn.Query failed: %v", err) } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") + if errors.Is(err, pgconn.ErrNoBytesSent) { + t.Error("Expected bytes to be sent to server") } rows.Close() @@ -436,8 +437,8 @@ func TestQueryEncodeError(t *testing.T) { if err != nil { t.Errorf("conn.Query failure: %v", err) } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") + if errors.Is(err, pgconn.ErrNoBytesSent) { + t.Error("Expected bytes to be sent to server") } defer rows.Close() @@ -1158,9 +1159,6 @@ func TestQueryContextSuccess(t *testing.T) { if err != nil { t.Fatal(err) } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } var result, rowCount int for rows.Next() { @@ -1239,9 +1237,6 @@ func TestQueryRowContextSuccess(t *testing.T) { if result != 42 { t.Fatalf("Expected result 42, got %d", result) } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } ensureConnValid(t, conn) } @@ -1270,10 +1265,11 @@ func TestQueryCloseBefore(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) closeConn(t, conn) - if _, err := conn.Query(context.Background(), "select 1"); err == nil { + _, err := conn.Query(context.Background(), "select 1") + if err == nil { t.Fatal("Expected network error") } - if conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return false") + if !errors.Is(err, pgconn.ErrNoBytesSent) { + t.Error("Expected bytes to be sent to server") } } diff --git a/stdlib/sql.go b/stdlib/sql.go index 1bfbfcfc..93507cbd 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -80,8 +80,9 @@ import ( "sync" "time" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" + "github.com/jackc/pgconn" "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" ) @@ -226,8 +227,9 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam commandTag, err := c.conn.Exec(ctx, query, args...) // if we got a network error before we had a chance to send the query, retry - if err != nil && !c.conn.LastStmtSent() { - if _, is := err.(net.Error); is { + if err != nil { + var netErr net.Error + if is := errors.As(err, &netErr); is && errors.Is(err, pgconn.ErrNoBytesSent) { return nil, driver.ErrBadConn } }