From eec82c9433368eaf9855ad6ee0a4f36418c6318a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 9 Jan 2022 00:35:49 -0600 Subject: [PATCH] Replace CID, OID, OIDValue, and XID with Uint32 --- pgtype/cid.go | 61 -------- pgtype/cid_test.go | 102 ------------- pgtype/oid.go | 81 ----------- pgtype/oid_value.go | 55 ------- pgtype/oid_value_test.go | 95 ------------ pgtype/pgtype.go | 6 +- pgtype/pguint32.go | 148 ------------------- pgtype/uint32.go | 303 +++++++++++++++++++++++++++++++++++++++ pgtype/uint32_test.go | 19 +++ pgtype/xid.go | 64 --------- pgtype/xid_test.go | 102 ------------- stdlib/sql.go | 24 +--- 12 files changed, 327 insertions(+), 733 deletions(-) delete mode 100644 pgtype/cid.go delete mode 100644 pgtype/cid_test.go delete mode 100644 pgtype/oid.go delete mode 100644 pgtype/oid_value.go delete mode 100644 pgtype/oid_value_test.go delete mode 100644 pgtype/pguint32.go create mode 100644 pgtype/uint32.go create mode 100644 pgtype/uint32_test.go delete mode 100644 pgtype/xid.go delete mode 100644 pgtype/xid_test.go diff --git a/pgtype/cid.go b/pgtype/cid.go deleted file mode 100644 index b944748c..00000000 --- a/pgtype/cid.go +++ /dev/null @@ -1,61 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// CID is PostgreSQL's Command Identifier type. -// -// When one does -// -// select cmin, cmax, * from some_table; -// -// it is the data type of the cmin and cmax hidden system columns. -// -// It is currently implemented as an unsigned four byte integer. -// Its definition can be found in src/include/c.h as CommandId -// in the PostgreSQL sources. -type CID pguint32 - -// Set converts from src to dst. Note that as CID is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *CID) Set(src interface{}) error { - return (*pguint32)(dst).Set(src) -} - -func (dst CID) Get() interface{} { - return (pguint32)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as CID is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *CID) AssignTo(dst interface{}) error { - return (*pguint32)(src).AssignTo(dst) -} - -func (dst *CID) DecodeText(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeText(ci, src) -} - -func (dst *CID) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeBinary(ci, src) -} - -func (src CID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (pguint32)(src).EncodeText(ci, buf) -} - -func (src CID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (pguint32)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *CID) Scan(src interface{}) error { - return (*pguint32)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src CID) Value() (driver.Value, error) { - return (pguint32)(src).Value() -} diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go deleted file mode 100644 index 3d3ad2a5..00000000 --- a/pgtype/cid_test.go +++ /dev/null @@ -1,102 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestCIDTranscode(t *testing.T) { - pgTypeName := "cid" - values := []interface{}{ - &pgtype.CID{Uint: 42, Valid: true}, - &pgtype.CID{}, - } - eqFunc := func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - } - - testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) -} - -func TestCIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.CID - }{ - {source: uint32(1), result: pgtype.CID{Uint: 1, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.CID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestCIDAssignTo(t *testing.T) { - var ui32 uint32 - var pui32 *uint32 - - simpleTests := []struct { - src pgtype.CID - dst interface{} - expected interface{} - }{ - {src: pgtype.CID{Uint: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.CID{}, dst: &pui32, expected: ((*uint32)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.CID - dst interface{} - expected interface{} - }{ - {src: pgtype.CID{Uint: 42, Valid: true}, dst: &pui32, expected: uint32(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.CID - dst interface{} - }{ - {src: pgtype.CID{}, dst: &ui32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/oid.go b/pgtype/oid.go deleted file mode 100644 index 31677e89..00000000 --- a/pgtype/oid.go +++ /dev/null @@ -1,81 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "strconv" - - "github.com/jackc/pgio" -) - -// OID (Object Identifier Type) is, according to -// https://www.postgresql.org/docs/current/static/datatype-oid.html, used -// internally by PostgreSQL as a primary key for various system tables. It is -// currently implemented as an unsigned four-byte integer. Its definition can be -// found in src/include/postgres_ext.h in the PostgreSQL sources. Because it is -// so frequently required to be in a NOT NULL condition OID cannot be NULL. To -// allow for NULL OIDs use OIDValue. -type OID uint32 - -func (dst *OID) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - return fmt.Errorf("cannot decode nil into OID") - } - - n, err := strconv.ParseUint(string(src), 10, 32) - if err != nil { - return err - } - - *dst = OID(n) - return nil -} - -func (dst *OID) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - return fmt.Errorf("cannot decode nil into OID") - } - - if len(src) != 4 { - return fmt.Errorf("invalid length: %v", len(src)) - } - - n := binary.BigEndian.Uint32(src) - *dst = OID(n) - return nil -} - -func (src OID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return append(buf, strconv.FormatUint(uint64(src), 10)...), nil -} - -func (src OID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return pgio.AppendUint32(buf, uint32(src)), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *OID) Scan(src interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan NULL into %T", src) - } - - switch src := src.(type) { - case int64: - *dst = OID(src) - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src OID) Value() (driver.Value, error) { - return int64(src), nil -} diff --git a/pgtype/oid_value.go b/pgtype/oid_value.go deleted file mode 100644 index 5dc9136c..00000000 --- a/pgtype/oid_value.go +++ /dev/null @@ -1,55 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// OIDValue (Object Identifier Type) is, according to -// https://www.postgresql.org/docs/current/static/datatype-OIDValue.html, used -// internally by PostgreSQL as a primary key for various system tables. It is -// currently implemented as an unsigned four-byte integer. Its definition can be -// found in src/include/postgres_ext.h in the PostgreSQL sources. -type OIDValue pguint32 - -// Set converts from src to dst. Note that as OIDValue is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *OIDValue) Set(src interface{}) error { - return (*pguint32)(dst).Set(src) -} - -func (dst OIDValue) Get() interface{} { - return (pguint32)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as OIDValue is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *OIDValue) AssignTo(dst interface{}) error { - return (*pguint32)(src).AssignTo(dst) -} - -func (dst *OIDValue) DecodeText(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeText(ci, src) -} - -func (dst *OIDValue) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeBinary(ci, src) -} - -func (src OIDValue) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (pguint32)(src).EncodeText(ci, buf) -} - -func (src OIDValue) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (pguint32)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *OIDValue) Scan(src interface{}) error { - return (*pguint32)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src OIDValue) Value() (driver.Value, error) { - return (pguint32)(src).Value() -} diff --git a/pgtype/oid_value_test.go b/pgtype/oid_value_test.go deleted file mode 100644 index aecfc149..00000000 --- a/pgtype/oid_value_test.go +++ /dev/null @@ -1,95 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestOIDValueTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "oid", []interface{}{ - &pgtype.OIDValue{Uint: 42, Valid: true}, - &pgtype.OIDValue{}, - }) -} - -func TestOIDValueSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.OIDValue - }{ - {source: uint32(1), result: pgtype.OIDValue{Uint: 1, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.OIDValue - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestOIDValueAssignTo(t *testing.T) { - var ui32 uint32 - var pui32 *uint32 - - simpleTests := []struct { - src pgtype.OIDValue - dst interface{} - expected interface{} - }{ - {src: pgtype.OIDValue{Uint: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.OIDValue{}, dst: &pui32, expected: ((*uint32)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.OIDValue - dst interface{} - expected interface{} - }{ - {src: pgtype.OIDValue{Uint: 42, Valid: true}, dst: &pui32, expected: uint32(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.OIDValue - dst interface{} - }{ - {src: pgtype.OIDValue{}, dst: &ui32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index c9da1322..ef269365 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -286,7 +286,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) - ci.RegisterDataType(DataType{Value: &CID{}, Name: "cid", OID: CIDOID}) + ci.RegisterDataType(DataType{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID}) ci.RegisterDataType(DataType{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) ci.RegisterDataType(DataType{Value: &Date{}, Name: "date", OID: DateOID}) @@ -309,7 +309,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "name", OID: NameOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID}) // ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) - ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID}) + ci.RegisterDataType(DataType{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID}) ci.RegisterDataType(DataType{Name: "point", OID: PointOID, Codec: PointCodec{}}) ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID}) @@ -327,7 +327,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID}) ci.RegisterDataType(DataType{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) ci.RegisterDataType(DataType{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) - ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID}) + ci.RegisterDataType(DataType{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) { ci.RegisterDefaultPgType(value, name) diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go deleted file mode 100644 index e36ebb1f..00000000 --- a/pgtype/pguint32.go +++ /dev/null @@ -1,148 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "math" - "strconv" - - "github.com/jackc/pgio" -) - -// pguint32 is the core type that is used to implement PostgreSQL types such as -// CID and XID. -type pguint32 struct { - Uint uint32 - Valid bool -} - -// Set converts from src to dst. Note that as pguint32 is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *pguint32) Set(src interface{}) error { - switch value := src.(type) { - case int64: - if value < 0 { - return fmt.Errorf("%d is less than minimum value for pguint32", value) - } - if value > math.MaxUint32 { - return fmt.Errorf("%d is greater than maximum value for pguint32", value) - } - *dst = pguint32{Uint: uint32(value), Valid: true} - case uint32: - *dst = pguint32{Uint: value, Valid: true} - default: - return fmt.Errorf("cannot convert %v to pguint32", value) - } - - return nil -} - -func (dst pguint32) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.Uint -} - -// AssignTo assigns from src to dst. Note that as pguint32 is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *pguint32) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *uint32: - if src.Valid { - *v = src.Uint - } else { - return fmt.Errorf("cannot assign %v into %T", src, dst) - } - case **uint32: - if src.Valid { - n := src.Uint - *v = &n - } else { - *v = nil - } - } - - return nil -} - -func (dst *pguint32) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = pguint32{} - return nil - } - - n, err := strconv.ParseUint(string(src), 10, 32) - if err != nil { - return err - } - - *dst = pguint32{Uint: uint32(n), Valid: true} - return nil -} - -func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = pguint32{} - return nil - } - - if len(src) != 4 { - return fmt.Errorf("invalid length: %v", len(src)) - } - - n := binary.BigEndian.Uint32(src) - *dst = pguint32{Uint: n, Valid: true} - return nil -} - -func (src pguint32) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, strconv.FormatUint(uint64(src.Uint), 10)...), nil -} - -func (src pguint32) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return pgio.AppendUint32(buf, src.Uint), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *pguint32) Scan(src interface{}) error { - if src == nil { - *dst = pguint32{} - return nil - } - - switch src := src.(type) { - case uint32: - *dst = pguint32{Uint: src, Valid: true} - return nil - case int64: - *dst = pguint32{Uint: uint32(src), Valid: true} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src pguint32) Value() (driver.Value, error) { - if !src.Valid { - return nil, nil - } - return int64(src.Uint), nil -} diff --git a/pgtype/uint32.go b/pgtype/uint32.go new file mode 100644 index 00000000..ccf39471 --- /dev/null +++ b/pgtype/uint32.go @@ -0,0 +1,303 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Uint32Scanner interface { + ScanUint32(v Uint32) error +} + +type Uint32Valuer interface { + Uint32Value() (Uint32, error) +} + +// Uint32 is the core type that is used to represent PostgreSQL types such as OID, CID, and XID. +type Uint32 struct { + Uint uint32 + Valid bool +} + +func (n *Uint32) ScanUint32(v Uint32) error { + *n = v + return nil +} + +func (n Uint32) Uint32Value() (Uint32, error) { + return n, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Uint32) Scan(src interface{}) error { + if src == nil { + *dst = Uint32{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + un, err := strconv.ParseUint(src, 10, 32) + if err != nil { + return err + } + n = int64(un) + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < 0 { + return fmt.Errorf("%d is less than the minimum value for Uint32", n) + } + if n > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for Uint32", n) + } + + *dst = Uint32{Uint: uint32(n), Valid: true} + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Uint32) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Uint), nil +} + +type Uint32Codec struct{} + +func (Uint32Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Uint32Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Uint32Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case uint32: + return encodePlanUint32CodecBinaryUint32{} + case Uint32Valuer: + return encodePlanUint32CodecBinaryUint32Valuer{} + case Int64Valuer: + return encodePlanUint32CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case uint32: + return encodePlanUint32CodecTextUint32{} + case Int64Valuer: + return encodePlanUint32CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanUint32CodecBinaryUint32 struct{} + +func (encodePlanUint32CodecBinaryUint32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v := value.(uint32) + return pgio.AppendUint32(buf, v), nil +} + +type encodePlanUint32CodecBinaryUint32Valuer struct{} + +func (encodePlanUint32CodecBinaryUint32Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint32Valuer).Uint32Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return pgio.AppendUint32(buf, v.Uint), nil +} + +type encodePlanUint32CodecBinaryInt64Valuer struct{} + +func (encodePlanUint32CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint32", v.Int) + } + if v.Int > math.MaxUint32 { + return nil, fmt.Errorf("%d is greater than maximum value for uint32", v.Int) + } + + return pgio.AppendUint32(buf, uint32(v.Int)), nil +} + +type encodePlanUint32CodecTextUint32 struct{} + +func (encodePlanUint32CodecTextUint32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v := value.(uint32) + return append(buf, strconv.FormatUint(uint64(v), 10)...), nil +} + +type encodePlanUint32CodecTextUint32Valuer struct{} + +func (encodePlanUint32CodecTextUint32Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint32Valuer).Uint32Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return append(buf, strconv.FormatUint(uint64(v.Uint), 10)...), nil +} + +type encodePlanUint32CodecTextInt64Valuer struct{} + +func (encodePlanUint32CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint32", v.Int) + } + if v.Int > math.MaxUint32 { + return nil, fmt.Errorf("%d is greater than maximum value for uint32", v.Int) + } + + return append(buf, strconv.FormatInt(v.Int, 10)...), nil +} + +func (Uint32Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *uint32: + return scanPlanBinaryUint32ToUint32{} + case Uint32Scanner: + return scanPlanBinaryUint32ToUint32Scanner{} + } + case TextFormatCode: + switch target.(type) { + case *uint32: + return scanPlanTextAnyToUint32{} + case Uint32Scanner: + return scanPlanTextAnyToUint32Scanner{} + } + } + + return nil +} + +func (c Uint32Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n uint32 + err := codecScan(c, ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return int64(n), nil +} + +func (c Uint32Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var n uint32 + err := codecScan(c, ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryUint32ToUint32 struct{} + +func (scanPlanBinaryUint32ToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint32: %v", len(src)) + } + + p := (dst).(*uint32) + *p = binary.BigEndian.Uint32(src) + + return nil +} + +type scanPlanBinaryUint32ToUint32Scanner struct{} + +func (scanPlanBinaryUint32ToUint32Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(Uint32Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint32(Uint32{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint32: %v", len(src)) + } + + n := binary.BigEndian.Uint32(src) + + return s.ScanUint32(Uint32{Uint: n, Valid: true}) +} + +type scanPlanTextAnyToUint32Scanner struct{} + +func (scanPlanTextAnyToUint32Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(Uint32Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint32(Uint32{}) + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + return s.ScanUint32(Uint32{Uint: uint32(n), Valid: true}) +} diff --git a/pgtype/uint32_test.go b/pgtype/uint32_test.go new file mode 100644 index 00000000..8e58605d --- /dev/null +++ b/pgtype/uint32_test.go @@ -0,0 +1,19 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgtype" +) + +func TestUint32Codec(t *testing.T) { + testPgxCodec(t, "oid", []PgxTranscodeTestCase{ + { + pgtype.Uint32{Uint: pgtype.TextOID, Valid: true}, + new(pgtype.Uint32), + isExpectedEq(pgtype.Uint32{Uint: pgtype.TextOID, Valid: true}), + }, + {pgtype.Uint32{}, new(pgtype.Uint32), isExpectedEq(pgtype.Uint32{})}, + {nil, new(pgtype.Uint32), isExpectedEq(pgtype.Uint32{})}, + }) +} diff --git a/pgtype/xid.go b/pgtype/xid.go deleted file mode 100644 index f6d6b22d..00000000 --- a/pgtype/xid.go +++ /dev/null @@ -1,64 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// XID is PostgreSQL's Transaction ID type. -// -// In later versions of PostgreSQL, it is the type used for the backend_xid -// and backend_xmin columns of the pg_stat_activity system view. -// -// Also, when one does -// -// select xmin, xmax, * from some_table; -// -// it is the data type of the xmin and xmax hidden system columns. -// -// It is currently implemented as an unsigned four byte integer. -// Its definition can be found in src/include/postgres_ext.h as TransactionId -// in the PostgreSQL sources. -type XID pguint32 - -// Set converts from src to dst. Note that as XID is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *XID) Set(src interface{}) error { - return (*pguint32)(dst).Set(src) -} - -func (dst XID) Get() interface{} { - return (pguint32)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as XID is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *XID) AssignTo(dst interface{}) error { - return (*pguint32)(src).AssignTo(dst) -} - -func (dst *XID) DecodeText(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeText(ci, src) -} - -func (dst *XID) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeBinary(ci, src) -} - -func (src XID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (pguint32)(src).EncodeText(ci, buf) -} - -func (src XID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (pguint32)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *XID) Scan(src interface{}) error { - return (*pguint32)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src XID) Value() (driver.Value, error) { - return (pguint32)(src).Value() -} diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go deleted file mode 100644 index ee11fa41..00000000 --- a/pgtype/xid_test.go +++ /dev/null @@ -1,102 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestXIDTranscode(t *testing.T) { - pgTypeName := "xid" - values := []interface{}{ - &pgtype.XID{Uint: 42, Valid: true}, - &pgtype.XID{}, - } - eqFunc := func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - } - - testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) -} - -func TestXIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.XID - }{ - {source: uint32(1), result: pgtype.XID{Uint: 1, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.XID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestXIDAssignTo(t *testing.T) { - var ui32 uint32 - var pui32 *uint32 - - simpleTests := []struct { - src pgtype.XID - dst interface{} - expected interface{} - }{ - {src: pgtype.XID{Uint: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.XID{}, dst: &pui32, expected: ((*uint32)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.XID - dst interface{} - expected interface{} - }{ - {src: pgtype.XID{Uint: 42, Valid: true}, dst: &pui32, expected: uint32(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.XID - dst interface{} - }{ - {src: pgtype.XID{}, dst: &ui32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/stdlib/sql.go b/stdlib/sql.go index 39b7524a..cbb8544e 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -615,8 +615,8 @@ func (r *Rows) Next(dest []driver.Value) error { err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) return d, err } - case pgtype.CIDOID: - var d pgtype.CID + case pgtype.CIDOID, pgtype.OIDOID, pgtype.XIDOID: + var d pgtype.Uint32 scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) @@ -690,16 +690,6 @@ func (r *Rows) Next(dest []driver.Value) error { } return d.Value() } - case pgtype.OIDOID: - var d pgtype.OIDValue - scanPlan := ci.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) - if err != nil { - return nil, err - } - return d.Value() - } case pgtype.TimestampOID: var d pgtype.Timestamp scanPlan := ci.PlanScan(dataTypeOID, format, &d) @@ -720,16 +710,6 @@ func (r *Rows) Next(dest []driver.Value) error { } return d.Value() } - case pgtype.XIDOID: - var d pgtype.XID - scanPlan := ci.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) - if err != nil { - return nil, err - } - return d.Value() - } default: var d string scanPlan := ci.PlanScan(dataTypeOID, format, &d)