Use errors instead of golang.org/x/xerrors

pull/975/head
Jack Christensen 2021-03-25 09:55:12 -04:00
parent 80147fd7cc
commit a49f4bb135
15 changed files with 47 additions and 56 deletions

View File

@ -2,9 +2,9 @@ package pgx
import ( import (
"context" "context"
"errors"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
errors "golang.org/x/xerrors"
) )
type batchItem struct { type batchItem struct {

16
conn.go
View File

@ -2,12 +2,12 @@ package pgx
import ( import (
"context" "context"
"errors"
"fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
errors "golang.org/x/xerrors"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgconn/stmtcache"
"github.com/jackc/pgproto3/v2" "github.com/jackc/pgproto3/v2"
@ -140,7 +140,7 @@ func ParseConfig(connString string) (*ConnConfig, error) {
delete(config.RuntimeParams, "statement_cache_capacity") delete(config.RuntimeParams, "statement_cache_capacity")
n, err := strconv.ParseInt(s, 10, 32) n, err := strconv.ParseInt(s, 10, 32)
if err != nil { if err != nil {
return nil, errors.Errorf("cannot parse statement_cache_capacity: %w", err) return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err)
} }
statementCacheCapacity = int(n) statementCacheCapacity = int(n)
} }
@ -153,7 +153,7 @@ func ParseConfig(connString string) (*ConnConfig, error) {
case "describe": case "describe":
statementCacheMode = stmtcache.ModeDescribe statementCacheMode = stmtcache.ModeDescribe
default: default:
return nil, errors.Errorf("invalid statement_cache_mod: %s", s) return nil, fmt.Errorf("invalid statement_cache_mod: %s", s)
} }
} }
@ -169,7 +169,7 @@ func ParseConfig(connString string) (*ConnConfig, error) {
if b, err := strconv.ParseBool(s); err == nil { if b, err := strconv.ParseBool(s); err == nil {
preferSimpleProtocol = b preferSimpleProtocol = b
} else { } else {
return nil, errors.Errorf("invalid prefer_simple_protocol: %v", err) return nil, fmt.Errorf("invalid prefer_simple_protocol: %v", err)
} }
} }
@ -486,7 +486,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i
func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, arguments []interface{}) error { func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, arguments []interface{}) error {
if len(sd.ParamOIDs) != len(arguments) { if len(sd.ParamOIDs) != len(arguments) {
return errors.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(arguments)) return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(arguments))
} }
c.eqb.Reset() c.eqb.Reset()
@ -629,7 +629,7 @@ optionLoop:
} }
} }
if len(sd.ParamOIDs) != len(args) { if len(sd.ParamOIDs) != len(args) {
rows.fatal(errors.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args))) rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)))
return rows, rows.err return rows, rows.err
} }
@ -791,7 +791,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
} }
if len(sd.ParamOIDs) != len(bi.arguments) { if len(sd.ParamOIDs) != len(bi.arguments) {
return &batchResults{ctx: ctx, conn: c, err: errors.Errorf("mismatched param and argument count")} return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
} }
args, err := convertDriverValuers(bi.arguments) args, err := convertDriverValuers(bi.arguments)

View File

@ -9,7 +9,6 @@ import (
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/jackc/pgio" "github.com/jackc/pgio"
errors "golang.org/x/xerrors"
) )
// CopyFromRows returns a CopyFromSource interface over the provided rows slice // CopyFromRows returns a CopyFromSource interface over the provided rows slice
@ -174,7 +173,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
return false, nil, err return false, nil, err
} }
if len(values) != len(ct.columnNames) { if len(values) != len(ct.columnNames) {
return false, nil, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
} }
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))

View File

@ -2,6 +2,7 @@ package pgx_test
import ( import (
"context" "context"
"fmt"
"os" "os"
"reflect" "reflect"
"testing" "testing"
@ -10,7 +11,6 @@ import (
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
errors "golang.org/x/xerrors"
) )
func TestConnCopyFromSmall(t *testing.T) { func TestConnCopyFromSmall(t *testing.T) {
@ -316,7 +316,7 @@ func (cfs *clientFailSource) Next() bool {
func (cfs *clientFailSource) Values() ([]interface{}, error) { func (cfs *clientFailSource) Values() ([]interface{}, error) {
if cfs.count == 3 { if cfs.count == 3 {
cfs.err = errors.Errorf("client error") cfs.err = fmt.Errorf("client error")
return nil, cfs.err return nil, cfs.err
} }
return []interface{}{make([]byte, 100000)}, nil return []interface{}{make([]byte, 100000)}, nil
@ -559,7 +559,7 @@ func (cfs *clientFinalErrSource) Values() ([]interface{}, error) {
} }
func (cfs *clientFinalErrSource) Err() error { func (cfs *clientFinalErrSource) Err() error {
return errors.Errorf("final error") return fmt.Errorf("final error")
} }
func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {

View File

@ -9,7 +9,6 @@ import (
"github.com/jackc/pgtype" "github.com/jackc/pgtype"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
errors "golang.org/x/xerrors"
) )
var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`)
@ -21,7 +20,7 @@ type Point struct {
} }
func (dst *Point) Set(src interface{}) error { func (dst *Point) Set(src interface{}) error {
return errors.Errorf("cannot convert %v to Point", src) return fmt.Errorf("cannot convert %v to Point", src)
} }
func (dst *Point) Get() interface{} { func (dst *Point) Get() interface{} {
@ -36,7 +35,7 @@ func (dst *Point) Get() interface{} {
} }
func (src *Point) AssignTo(dst interface{}) error { func (src *Point) AssignTo(dst interface{}) error {
return errors.Errorf("cannot assign %v to %T", src, dst) return fmt.Errorf("cannot assign %v to %T", src, dst)
} }
func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
@ -48,16 +47,16 @@ func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
s := string(src) s := string(src)
match := pointRegexp.FindStringSubmatch(s) match := pointRegexp.FindStringSubmatch(s)
if match == nil { if match == nil {
return errors.Errorf("Received invalid point: %v", s) return fmt.Errorf("Received invalid point: %v", s)
} }
x, err := strconv.ParseFloat(match[1], 64) x, err := strconv.ParseFloat(match[1], 64)
if err != nil { if err != nil {
return errors.Errorf("Received invalid point: %v", s) return fmt.Errorf("Received invalid point: %v", s)
} }
y, err := strconv.ParseFloat(match[2], 64) y, err := strconv.ParseFloat(match[2], 64)
if err != nil { if err != nil {
return errors.Errorf("Received invalid point: %v", s) return fmt.Errorf("Received invalid point: %v", s)
} }
*dst = Point{X: x, Y: y, Status: pgtype.Present} *dst = Point{X: x, Y: y, Status: pgtype.Present}

3
go.mod
View File

@ -1,6 +1,6 @@
module github.com/jackc/pgx/v4 module github.com/jackc/pgx/v4
go 1.12 go 1.13
require ( require (
github.com/Masterminds/semver/v3 v3.1.1 // indirect github.com/Masterminds/semver/v3 v3.1.1 // indirect
@ -17,6 +17,5 @@ require (
github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.5.1 github.com/stretchr/testify v1.5.1
go.uber.org/zap v1.13.0 go.uber.org/zap v1.13.0
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec
) )

View File

@ -3,12 +3,11 @@ package sanitize
import ( import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"unicode/utf8" "unicode/utf8"
errors "golang.org/x/xerrors"
) )
// Part is either a string or an int. A string is raw SQL. An int is a // Part is either a string or an int. A string is raw SQL. An int is a
@ -31,7 +30,7 @@ func (q *Query) Sanitize(args ...interface{}) (string, error) {
case int: case int:
argIdx := part - 1 argIdx := part - 1
if argIdx >= len(args) { if argIdx >= len(args) {
return "", errors.Errorf("insufficient arguments") return "", fmt.Errorf("insufficient arguments")
} }
arg := args[argIdx] arg := args[argIdx]
switch arg := arg.(type) { switch arg := arg.(type) {
@ -50,18 +49,18 @@ func (q *Query) Sanitize(args ...interface{}) (string, error) {
case time.Time: case time.Time:
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
default: default:
return "", errors.Errorf("invalid arg type: %T", arg) return "", fmt.Errorf("invalid arg type: %T", arg)
} }
argUse[argIdx] = true argUse[argIdx] = true
default: default:
return "", errors.Errorf("invalid Part type: %T", part) return "", fmt.Errorf("invalid Part type: %T", part)
} }
buf.WriteString(str) buf.WriteString(str)
} }
for i, used := range argUse { for i, used := range argUse {
if !used { if !used {
return "", errors.Errorf("unused argument: %d", i) return "", fmt.Errorf("unused argument: %d", i)
} }
} }
return buf.String(), nil return buf.String(), nil

View File

@ -2,9 +2,8 @@ package pgx
import ( import (
"context" "context"
"errors"
"io" "io"
errors "golang.org/x/xerrors"
) )
// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it // LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it

View File

@ -3,9 +3,8 @@ package pgx
import ( import (
"context" "context"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
errors "golang.org/x/xerrors"
) )
// The values for log levels are chosen such that the zero value means that no // The values for log levels are chosen such that the zero value means that no

View File

@ -2,6 +2,7 @@ package pgxpool
import ( import (
"context" "context"
"fmt"
"runtime" "runtime"
"strconv" "strconv"
"sync" "sync"
@ -10,7 +11,6 @@ import (
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
"github.com/jackc/puddle" "github.com/jackc/puddle"
errors "golang.org/x/xerrors"
) )
var defaultMaxConns = int32(4) var defaultMaxConns = int32(4)
@ -266,10 +266,10 @@ func ParseConfig(connString string) (*Config, error) {
delete(connConfig.Config.RuntimeParams, "pool_max_conns") delete(connConfig.Config.RuntimeParams, "pool_max_conns")
n, err := strconv.ParseInt(s, 10, 32) n, err := strconv.ParseInt(s, 10, 32)
if err != nil { if err != nil {
return nil, errors.Errorf("cannot parse pool_max_conns: %w", err) return nil, fmt.Errorf("cannot parse pool_max_conns: %w", err)
} }
if n < 1 { if n < 1 {
return nil, errors.Errorf("pool_max_conns too small: %d", n) return nil, fmt.Errorf("pool_max_conns too small: %d", n)
} }
config.MaxConns = int32(n) config.MaxConns = int32(n)
} else { } else {
@ -283,7 +283,7 @@ func ParseConfig(connString string) (*Config, error) {
delete(connConfig.Config.RuntimeParams, "pool_min_conns") delete(connConfig.Config.RuntimeParams, "pool_min_conns")
n, err := strconv.ParseInt(s, 10, 32) n, err := strconv.ParseInt(s, 10, 32)
if err != nil { if err != nil {
return nil, errors.Errorf("cannot parse pool_min_conns: %w", err) return nil, fmt.Errorf("cannot parse pool_min_conns: %w", err)
} }
config.MinConns = int32(n) config.MinConns = int32(n)
} else { } else {
@ -294,7 +294,7 @@ func ParseConfig(connString string) (*Config, error) {
delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime") delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime")
d, err := time.ParseDuration(s) d, err := time.ParseDuration(s)
if err != nil { if err != nil {
return nil, errors.Errorf("invalid pool_max_conn_lifetime: %w", err) return nil, fmt.Errorf("invalid pool_max_conn_lifetime: %w", err)
} }
config.MaxConnLifetime = d config.MaxConnLifetime = d
} else { } else {
@ -305,7 +305,7 @@ func ParseConfig(connString string) (*Config, error) {
delete(connConfig.Config.RuntimeParams, "pool_max_conn_idle_time") delete(connConfig.Config.RuntimeParams, "pool_max_conn_idle_time")
d, err := time.ParseDuration(s) d, err := time.ParseDuration(s)
if err != nil { if err != nil {
return nil, errors.Errorf("invalid pool_max_conn_idle_time: %w", err) return nil, fmt.Errorf("invalid pool_max_conn_idle_time: %w", err)
} }
config.MaxConnIdleTime = d config.MaxConnIdleTime = d
} else { } else {
@ -316,7 +316,7 @@ func ParseConfig(connString string) (*Config, error) {
delete(connConfig.Config.RuntimeParams, "pool_health_check_period") delete(connConfig.Config.RuntimeParams, "pool_health_check_period")
d, err := time.ParseDuration(s) d, err := time.ParseDuration(s)
if err != nil { if err != nil {
return nil, errors.Errorf("invalid pool_health_check_period: %w", err) return nil, fmt.Errorf("invalid pool_health_check_period: %w", err)
} }
config.HealthCheckPeriod = d config.HealthCheckPeriod = d
} else { } else {

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
@ -22,7 +23,6 @@ import (
"github.com/shopspring/decimal" "github.com/shopspring/decimal"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
errors "golang.org/x/xerrors"
) )
func TestConnQueryScan(t *testing.T) { func TestConnQueryScan(t *testing.T) {

13
rows.go
View File

@ -2,11 +2,10 @@ package pgx
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
errors "golang.org/x/xerrors"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/jackc/pgproto3/v2" "github.com/jackc/pgproto3/v2"
"github.com/jackc/pgtype" "github.com/jackc/pgtype"
@ -197,12 +196,12 @@ func (rows *connRows) Scan(dest ...interface{}) error {
values := rows.values values := rows.values
if len(fieldDescriptions) != len(values) { if len(fieldDescriptions) != len(values) {
err := errors.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
rows.fatal(err) rows.fatal(err)
return err return err
} }
if len(fieldDescriptions) != len(dest) { if len(fieldDescriptions) != len(dest) {
err := errors.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
rows.fatal(err) rows.fatal(err)
return err return err
} }
@ -308,7 +307,7 @@ func (rows *connRows) RawValues() [][]byte {
type ScanArgError struct { type ScanArgError struct {
ColumnIndex int ColumnIndex int
Err error Err error
} }
func (e ScanArgError) Error() string { func (e ScanArgError) Error() string {
@ -327,10 +326,10 @@ func (e ScanArgError) Unwrap() error {
// dest - the destination that values will be decoded into // dest - the destination that values will be decoded into
func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...interface{}) error { func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...interface{}) error {
if len(fieldDescriptions) != len(values) { if len(fieldDescriptions) != len(values) {
return errors.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
} }
if len(fieldDescriptions) != len(dest) { if len(fieldDescriptions) != len(dest) {
return errors.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) return fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
} }
for i, d := range dest { for i, d := range dest {

View File

@ -52,6 +52,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors"
"fmt" "fmt"
"io" "io"
"math" "math"
@ -61,8 +62,6 @@ import (
"sync" "sync"
"time" "time"
errors "golang.org/x/xerrors"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/jackc/pgtype" "github.com/jackc/pgtype"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
@ -308,7 +307,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
case sql.LevelSerializable: case sql.LevelSerializable:
pgxOpts.IsoLevel = pgx.Serializable pgxOpts.IsoLevel = pgx.Serializable
default: default:
return nil, errors.Errorf("unsupported isolation: %v", opts.Isolation) return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation)
} }
if opts.ReadOnly { if opts.ReadOnly {
@ -779,7 +778,7 @@ func ReleaseConn(db *sql.DB, conn *pgx.Conn) error {
fakeTxMutex.Unlock() fakeTxMutex.Unlock()
} else { } else {
fakeTxMutex.Unlock() fakeTxMutex.Unlock()
return errors.Errorf("can't release conn that is not acquired") return fmt.Errorf("can't release conn that is not acquired")
} }
return tx.Rollback() return tx.Rollback()

4
tx.go
View File

@ -3,11 +3,11 @@ package pgx
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
errors "golang.org/x/xerrors"
) )
type TxIsoLevel string type TxIsoLevel string
@ -246,7 +246,7 @@ func (tx *dbTx) Rollback(ctx context.Context) error {
tx.closed = true tx.closed = true
if err != nil { if err != nil {
// A rollback failure leaves the connection in an undefined state // A rollback failure leaves the connection in an undefined state
tx.conn.die(errors.Errorf("rollback failed: %w", err)) tx.conn.die(fmt.Errorf("rollback failed: %w", err))
return err return err
} }

View File

@ -9,7 +9,6 @@ import (
"github.com/jackc/pgio" "github.com/jackc/pgio"
"github.com/jackc/pgtype" "github.com/jackc/pgtype"
errors "golang.org/x/xerrors"
) )
// PostgreSQL format codes // PostgreSQL format codes
@ -103,12 +102,12 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e
return int64(arg), nil return int64(arg), nil
case uint64: case uint64:
if arg > math.MaxInt64 { if arg > math.MaxInt64 {
return nil, errors.Errorf("arg too big for int64: %v", arg) return nil, fmt.Errorf("arg too big for int64: %v", arg)
} }
return int64(arg), nil return int64(arg), nil
case uint: case uint:
if uint64(arg) > math.MaxInt64 { if uint64(arg) > math.MaxInt64 {
return nil, errors.Errorf("arg too big for int64: %v", arg) return nil, fmt.Errorf("arg too big for int64: %v", arg)
} }
return int64(arg), nil return int64(arg), nil
} }