diff --git a/stdlib/sql.go b/stdlib/sql.go index 41398879..f4c8fb8c 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -75,7 +75,6 @@ import ( "fmt" "io" "math" - "net" "reflect" "strings" "sync" @@ -227,8 +226,7 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam commandTag, err := c.conn.Exec(ctx, query, args...) // if we got a network error before we had a chance to send the query, retry if err != nil { - var netErr net.Error - if is := errors.As(err, &netErr); is && errors.Is(err, pgconn.ErrNoBytesSent) { + if pgconn.SafeToRetry(err) { return nil, driver.ErrBadConn } } @@ -245,7 +243,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na rows, err := c.conn.Query(ctx, query, args...) if err != nil { - if errors.Is(err, pgconn.ErrNoBytesSent) { + if pgconn.SafeToRetry(err) { return nil, driver.ErrBadConn } return nil, err diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 5e5039c3..f1ffbd01 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -895,8 +895,8 @@ func TestStmtExecContextCancel(t *testing.T) { defer cancel() _, err = stmt.ExecContext(ctx, 42) - if err != context.DeadlineExceeded { - t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + if !pgconn.Timeout(err) { + t.Errorf("expected timeout error, got %v", err) } ensureConnValid(t, db)