Replace lastStmtSent with pgconn support

pull/483/head
Jack Christensen 2019-04-20 17:12:20 -05:00
parent 35a0f64876
commit b7e56b003a
7 changed files with 47 additions and 78 deletions

18
conn.go
View File

@ -45,8 +45,6 @@ type Conn struct {
causeOfDeath error causeOfDeath error
lastStmtSent bool
doneChan chan struct{} doneChan chan struct{}
closedChan chan error closedChan chan error
@ -392,17 +390,6 @@ func connInfoFromRows(rows Rows, err error) (map[string]pgtype.OID, error) {
return nameOIDs, err return nameOIDs, err
} }
// LastStmtSent returns true if the last call to Query(Ex)/Exec(Ex) attempted to
// send the statement over the wire. Each call to a Query(Ex)/Exec(Ex) resets
// the value to false initially until the statement has been sent. This does
// NOT mean that the statement was successful or even received, it just means
// that a write was attempted and therefore it could have been executed. Calls
// to prepare a statement are ignored, only when the prepared statement is
// attempted to be executed will this return true.
func (c *Conn) LastStmtSent() bool {
return c.lastStmtSent
}
// PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the // PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the
// PostgreSQL connection than pgx exposes. // PostgreSQL connection than pgx exposes.
// //
@ -413,8 +400,6 @@ func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn }
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
// positionally from the sql string as $1, $2, etc. // positionally from the sql string as $1, $2, etc.
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) {
c.lastStmtSent = false
startTime := time.Now() startTime := time.Now()
commandTag, err := c.exec(ctx, sql, arguments...) commandTag, err := c.exec(ctx, sql, arguments...)
@ -462,13 +447,11 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
} }
} }
c.lastStmtSent = true
result := c.pgConn.ExecPrepared(ctx, ps.Name, paramValues, paramFormats, resultFormats).Read() result := c.pgConn.ExecPrepared(ctx, ps.Name, paramValues, paramFormats, resultFormats).Read()
return result.CommandTag, result.Err return result.CommandTag, result.Err
} }
if len(arguments) == 0 { if len(arguments) == 0 {
c.lastStmtSent = true
results, err := c.pgConn.Exec(ctx, sql).ReadAll() results, err := c.pgConn.Exec(ctx, sql).ReadAll()
if err != nil { if err != nil {
return nil, err return nil, err
@ -529,7 +512,6 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
} }
} }
c.lastStmtSent = true
result := c.pgConn.ExecPrepared(ctx, psd.Name, paramValues, paramFormats, resultFormats).Read() result := c.pgConn.ExecPrepared(ctx, psd.Name, paramValues, paramFormats, resultFormats).Read()
return result.CommandTag, result.Err return result.CommandTag, result.Err
} }

View File

@ -12,6 +12,7 @@ import (
"github.com/jackc/pgx" "github.com/jackc/pgx"
"github.com/jackc/pgx/pgtype" "github.com/jackc/pgx/pgtype"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
errors "golang.org/x/xerrors"
) )
func TestCrateDBConnect(t *testing.T) { func TestCrateDBConnect(t *testing.T) {
@ -122,18 +123,12 @@ func TestExecFailure(t *testing.T) {
if _, err := conn.Exec(context.Background(), "selct;"); err == nil { if _, err := conn.Exec(context.Background(), "selct;"); err == nil {
t.Fatal("Expected SQL syntax error") t.Fatal("Expected SQL syntax error")
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
rows, _ := conn.Query(context.Background(), "select 1") rows, _ := conn.Query(context.Background(), "select 1")
rows.Close() rows.Close()
if rows.Err() != nil { if rows.Err() != nil {
t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err()) t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err())
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
} }
func TestExecFailureWithArguments(t *testing.T) { func TestExecFailureWithArguments(t *testing.T) {
@ -142,11 +137,12 @@ func TestExecFailureWithArguments(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn) defer closeConn(t, conn)
if _, err := conn.Exec(context.Background(), "selct $1;", 1); err == nil { _, err := conn.Exec(context.Background(), "selct $1;", 1)
if err == nil {
t.Fatal("Expected SQL syntax error") t.Fatal("Expected SQL syntax error")
} }
if conn.LastStmtSent() { if errors.Is(err, pgconn.ErrNoBytesSent) {
t.Error("Expected LastStmtSent to return false") t.Error("Expected bytes to be sent to server")
} }
} }
@ -164,10 +160,10 @@ func TestExecContextWithoutCancelation(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if string(commandTag) != "CREATE TABLE" { if string(commandTag) != "CREATE TABLE" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag) t.Fatalf("Unexpected results from Exec: %v", commandTag)
} }
if !conn.LastStmtSent() { if errors.Is(err, pgconn.ErrNoBytesSent) {
t.Error("Expected LastStmtSent to return true") t.Error("Expected bytes to be sent to server")
} }
} }
@ -180,11 +176,12 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
if _, err := conn.Exec(ctx, "selct;"); err == nil { _, err := conn.Exec(ctx, "selct;")
if err == nil {
t.Fatal("Expected SQL syntax error") t.Fatal("Expected SQL syntax error")
} }
if !conn.LastStmtSent() { if errors.Is(err, pgconn.ErrNoBytesSent) {
t.Error("Expected LastStmtSent to return true") t.Error("Expected bytes to be sent to server")
} }
rows, _ := conn.Query(context.Background(), "select 1") rows, _ := conn.Query(context.Background(), "select 1")
@ -192,8 +189,8 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) {
if rows.Err() != nil { if rows.Err() != nil {
t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err()) t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err())
} }
if !conn.LastStmtSent() { if errors.Is(err, pgconn.ErrNoBytesSent) {
t.Error("Expected LastStmtSent to return true") t.Error("Expected bytes to be sent to server")
} }
} }
@ -206,26 +203,27 @@ func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
if _, err := conn.Exec(ctx, "selct $1;", 1); err == nil { _, err := conn.Exec(ctx, "selct $1;", 1)
if err == nil {
t.Fatal("Expected SQL syntax error") t.Fatal("Expected SQL syntax error")
} }
if conn.LastStmtSent() { if errors.Is(err, pgconn.ErrNoBytesSent) {
t.Error("Expected LastStmtSent to return false") t.Error("Expected bytes to be sent to server")
} }
} }
func TestExecFailureCloseBefore(t *testing.T) { func TestExecFailureCloseBefore(t *testing.T) {
t.Skip("TODO: LastStmtSent needs to be ported / rewritten for pgconn")
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
closeConn(t, conn) closeConn(t, conn)
if _, err := conn.Exec(context.Background(), "select 1"); err == nil { _, err := conn.Exec(context.Background(), "select 1")
if err == nil {
t.Fatal("Expected network error") t.Fatal("Expected network error")
} }
if conn.LastStmtSent() { if !errors.Is(err, pgconn.ErrNoBytesSent) {
t.Error("Expected LastStmtSent to return false") t.Error("Expected no bytes to be sent to server")
} }
} }
@ -261,20 +259,6 @@ func TestExecExtendedProtocol(t *testing.T) {
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
func TestExecExFailureCloseBefore(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
closeConn(t, conn)
if _, err := conn.Exec(context.Background(), "select 1", nil); err == nil {
t.Fatal("Expected network error")
}
if conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return false")
}
}
func TestPrepare(t *testing.T) { func TestPrepare(t *testing.T) {
t.Parallel() t.Parallel()
@ -478,7 +462,7 @@ func TestCatchSimultaneousConnectionQueries(t *testing.T) {
defer rows1.Close() defer rows1.Close()
_, err = conn.Query(context.Background(), "select generate_series(1,$1)", 10) _, err = conn.Query(context.Background(), "select generate_series(1,$1)", 10)
if err != pgconn.ErrConnBusy { if !errors.Is(err, pgconn.ErrConnBusy) {
t.Fatalf("conn.Query should have failed with pgconn.ErrConnBusy, but it was %v", err) t.Fatalf("conn.Query should have failed with pgconn.ErrConnBusy, but it was %v", err)
} }
} }
@ -496,7 +480,7 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) {
defer rows.Close() defer rows.Close()
_, err = conn.Exec(context.Background(), "create temporary table foo(spice timestamp[])") _, err = conn.Exec(context.Background(), "create temporary table foo(spice timestamp[])")
if err != pgconn.ErrConnBusy { if !errors.Is(err, pgconn.ErrConnBusy) {
t.Fatalf("conn.Exec should have failed with pgconn.ErrConnBusy, but it was %v", err) t.Fatalf("conn.Exec should have failed with pgconn.ErrConnBusy, but it was %v", err)
} }
} }

5
go.mod
View File

@ -4,10 +4,10 @@ go 1.12
require ( require (
github.com/cockroachdb/apd v1.1.0 github.com/cockroachdb/apd v1.1.0
github.com/jackc/pgconn v0.0.0-20190420161109-39e6ff5766bd github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3
github.com/jackc/pgio v1.0.0 github.com/jackc/pgio v1.0.0
github.com/jackc/pgproto3 v1.1.0 github.com/jackc/pgproto3 v1.1.0
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b
github.com/kr/pretty v0.1.0 // indirect github.com/kr/pretty v0.1.0 // indirect
github.com/lib/pq v1.0.0 github.com/lib/pq v1.0.0
@ -20,5 +20,6 @@ require (
go.uber.org/atomic v1.3.2 // indirect go.uber.org/atomic v1.3.2 // indirect
go.uber.org/multierr v1.1.0 // indirect go.uber.org/multierr v1.1.0 // indirect
go.uber.org/zap v1.9.1 go.uber.org/zap v1.9.1
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
) )

6
go.sum
View File

@ -16,6 +16,8 @@ github.com/jackc/pgconn v0.0.0-20190419211655-3710e52a9a12 h1:PzGjcOqGl6npHTDt8y
github.com/jackc/pgconn v0.0.0-20190419211655-3710e52a9a12/go.mod h1:UsnoyBN75lNxOeZXUT70J9xAvZffv2fxrxCrIPIH/Rk= github.com/jackc/pgconn v0.0.0-20190419211655-3710e52a9a12/go.mod h1:UsnoyBN75lNxOeZXUT70J9xAvZffv2fxrxCrIPIH/Rk=
github.com/jackc/pgconn v0.0.0-20190420161109-39e6ff5766bd h1:eSKDWtHcm6H/vELPrs6fh7bch3wBc2vUvqVnHw17+5c= github.com/jackc/pgconn v0.0.0-20190420161109-39e6ff5766bd h1:eSKDWtHcm6H/vELPrs6fh7bch3wBc2vUvqVnHw17+5c=
github.com/jackc/pgconn v0.0.0-20190420161109-39e6ff5766bd/go.mod h1:UsnoyBN75lNxOeZXUT70J9xAvZffv2fxrxCrIPIH/Rk= github.com/jackc/pgconn v0.0.0-20190420161109-39e6ff5766bd/go.mod h1:UsnoyBN75lNxOeZXUT70J9xAvZffv2fxrxCrIPIH/Rk=
github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3 h1:ZFYpB74Kq8xE9gmfxCmXD6QxZ27ja+j3HwGFc+YurhQ=
github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA=
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@ -26,6 +28,8 @@ github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf h1:wI8d/uq9/RfZOe6bKOpC4Skd4VgkTIGZqxmHu6IQGb8= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf h1:wI8d/uq9/RfZOe6bKOpC4Skd4VgkTIGZqxmHu6IQGb8=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b h1:cIcUpcEP55F/QuZWEtXyqHoWk+IV4TBiLjtBkeq/Q1c= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b h1:cIcUpcEP55F/QuZWEtXyqHoWk+IV4TBiLjtBkeq/Q1c=
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
@ -70,5 +74,7 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqY
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@ -286,7 +286,6 @@ type QueryResultFormats []int16
// Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is // Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is
// allowed to ignore the error returned from Query and handle it in Rows. // allowed to ignore the error returned from Query and handle it in Rows.
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
c.lastStmtSent = false
// rows = c.getRows(sql, args) // rows = c.getRows(sql, args)
var resultFormats QueryResultFormats var resultFormats QueryResultFormats
@ -369,7 +368,6 @@ optionLoop:
} }
} }
c.lastStmtSent = true
rows.resultReader = c.pgConn.ExecPrepared(ctx, ps.Name, paramValues, paramFormats, resultFormats) rows.resultReader = c.pgConn.ExecPrepared(ctx, ps.Name, paramValues, paramFormats, resultFormats)
return rows, rows.err return rows, rows.err

View File

@ -18,6 +18,7 @@ import (
satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" satori "github.com/jackc/pgx/pgtype/ext/satori-uuid"
uuid "github.com/satori/go.uuid" uuid "github.com/satori/go.uuid"
"github.com/shopspring/decimal" "github.com/shopspring/decimal"
errors "golang.org/x/xerrors"
) )
func TestConnQueryScan(t *testing.T) { func TestConnQueryScan(t *testing.T) {
@ -285,8 +286,8 @@ func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("conn.Query failed: %v", err) t.Fatalf("conn.Query failed: %v", err)
} }
if !conn.LastStmtSent() { if errors.Is(err, pgconn.ErrNoBytesSent) {
t.Error("Expected LastStmtSent to return true") t.Error("Expected bytes to be sent to server")
} }
rows.Close() rows.Close()
@ -436,8 +437,8 @@ func TestQueryEncodeError(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("conn.Query failure: %v", err) t.Errorf("conn.Query failure: %v", err)
} }
if !conn.LastStmtSent() { if errors.Is(err, pgconn.ErrNoBytesSent) {
t.Error("Expected LastStmtSent to return true") t.Error("Expected bytes to be sent to server")
} }
defer rows.Close() defer rows.Close()
@ -1158,9 +1159,6 @@ func TestQueryContextSuccess(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
var result, rowCount int var result, rowCount int
for rows.Next() { for rows.Next() {
@ -1239,9 +1237,6 @@ func TestQueryRowContextSuccess(t *testing.T) {
if result != 42 { if result != 42 {
t.Fatalf("Expected result 42, got %d", result) t.Fatalf("Expected result 42, got %d", result)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
@ -1270,10 +1265,11 @@ func TestQueryCloseBefore(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
closeConn(t, conn) closeConn(t, conn)
if _, err := conn.Query(context.Background(), "select 1"); err == nil { _, err := conn.Query(context.Background(), "select 1")
if err == nil {
t.Fatal("Expected network error") t.Fatal("Expected network error")
} }
if conn.LastStmtSent() { if !errors.Is(err, pgconn.ErrNoBytesSent) {
t.Error("Expected LastStmtSent to return false") t.Error("Expected bytes to be sent to server")
} }
} }

View File

@ -80,8 +80,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/pkg/errors" errors "golang.org/x/xerrors"
"github.com/jackc/pgconn"
"github.com/jackc/pgx" "github.com/jackc/pgx"
"github.com/jackc/pgx/pgtype" "github.com/jackc/pgx/pgtype"
) )
@ -226,8 +227,9 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
commandTag, err := c.conn.Exec(ctx, query, args...) commandTag, err := c.conn.Exec(ctx, query, args...)
// if we got a network error before we had a chance to send the query, retry // if we got a network error before we had a chance to send the query, retry
if err != nil && !c.conn.LastStmtSent() { if err != nil {
if _, is := err.(net.Error); is { var netErr net.Error
if is := errors.As(err, &netErr); is && errors.Is(err, pgconn.ErrNoBytesSent) {
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
} }