From bcf4931a7e4c54c639f8b3d805d5dffb9aba6df7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Jan 2022 17:56:07 -0600 Subject: [PATCH] Convert "char" to Codec --- pgtype/pgtype.go | 2 +- pgtype/qchar.go | 208 +++++++++++++++++++++---------------------- pgtype/qchar_test.go | 145 +++--------------------------- 3 files changed, 113 insertions(+), 242 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5072d061..476450a2 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -304,7 +304,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) ci.RegisterDataType(DataType{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) - ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) + ci.RegisterDataType(DataType{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) ci.RegisterDataType(DataType{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) ci.RegisterDataType(DataType{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) ci.RegisterDataType(DataType{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) diff --git a/pgtype/qchar.go b/pgtype/qchar.go index e56bf142..28c91110 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -1,145 +1,141 @@ package pgtype import ( + "database/sql/driver" "fmt" "math" - "strconv" ) -// QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the C +// QCharCodec is for PostgreSQL's special 8-bit-only "char" type more akin to the C // language's char type, or Go's byte type. (Note that the name in PostgreSQL // itself is "char", in double-quotes, and not char.) It gets used a lot in // PostgreSQL's system tables to hold a single ASCII character value (eg // pg_class.relkind). It is named Qchar for quoted char to disambiguate from SQL // standard type char. -// -// Not all possible values of QChar are representable in the text format. -// Therefore, QChar does not implement TextEncoder and TextDecoder. In -// addition, database/sql Scanner and database/sql/driver Value are not -// implemented. -type QChar struct { - Int int8 - Valid bool +type QCharCodec struct{} + +func (QCharCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode } -func (dst *QChar) Set(src interface{}) error { - if src == nil { - *dst = QChar{} - return nil - } +func (QCharCodec) PreferredFormat() int16 { + return BinaryFormatCode +} - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) +func (QCharCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch value.(type) { + case byte: + return encodePlanQcharCodecByte{} + case rune: + return encodePlanQcharCodecRune{} } } - switch value := src.(type) { - case int8: - *dst = QChar{Int: value, Valid: true} - case uint8: - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case int16: - if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case uint16: - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case int32: - if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case uint32: - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case int64: - if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case uint64: - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case int: - if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case uint: - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case string: - num, err := strconv.ParseInt(value, 10, 8) - if err != nil { - return err - } - *dst = QChar{Int: int8(num), Valid: true} - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to QChar", value) - } - return nil } -func (dst QChar) Get() interface{} { - if !dst.Valid { - return nil +type encodePlanQcharCodecByte struct{} + +func (encodePlanQcharCodecByte) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + b := value.(byte) + buf = append(buf, b) + return buf, nil +} + +type encodePlanQcharCodecRune struct{} + +func (encodePlanQcharCodecRune) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + r := value.(rune) + if r > math.MaxUint8 { + return nil, fmt.Errorf(`%v cannot be encoded to "char"`, r) } - return dst.Int + b := byte(r) + buf = append(buf, b) + return buf, nil } -func (src *QChar) AssignTo(dst interface{}) error { - return int64AssignTo(int64(src.Int), src.Valid, dst) +func (QCharCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch target.(type) { + case *byte: + return scanPlanQcharCodecByte{} + case *rune: + return scanPlanQcharCodecRune{} + } + } + + return nil } -func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { +type scanPlanQcharCodecByte struct{} + +func (scanPlanQcharCodecByte) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { - *dst = QChar{} - return nil + return fmt.Errorf("cannot scan null into %T", dst) } - if len(src) != 1 { + if len(src) > 1 { return fmt.Errorf(`invalid length for "char": %v`, len(src)) } - *dst = QChar{Int: int8(src[0]), Valid: true} + b := dst.(*byte) + // In the text format the zero value is returned as a zero byte value instead of 0 + if len(src) == 0 { + *b = 0 + } else { + *b = src[0] + } + return nil } -func (src QChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { +type scanPlanQcharCodecRune struct{} + +func (scanPlanQcharCodecRune) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) > 1 { + return fmt.Errorf(`invalid length for "char": %v`, len(src)) + } + + r := dst.(*rune) + // In the text format the zero value is returned as a zero byte value instead of 0 + if len(src) == 0 { + *r = 0 + } else { + *r = rune(src[0]) + } + + return nil +} + +func (c QCharCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { return nil, nil } - return append(buf, byte(src.Int)), nil + var r rune + err := codecScan(c, ci, oid, format, src, &r) + if err != nil { + return nil, err + } + return string(r), nil +} + +func (c QCharCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var r rune + err := codecScan(c, ci, oid, format, src, &r) + if err != nil { + return nil, err + } + return r, nil } diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go index cb9b6786..ec555eb2 100644 --- a/pgtype/qchar_test.go +++ b/pgtype/qchar_test.go @@ -2,142 +2,17 @@ package pgtype_test import ( "math" - "reflect" "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestQCharTranscode(t *testing.T) { - testutil.TestPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ - &pgtype.QChar{Int: math.MinInt8, Valid: true}, - &pgtype.QChar{Int: -1, Valid: true}, - &pgtype.QChar{Int: 0, Valid: true}, - &pgtype.QChar{Int: 1, Valid: true}, - &pgtype.QChar{Int: math.MaxInt8, Valid: true}, - &pgtype.QChar{Int: 0}, - }, func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - }) -} - -func TestQCharSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.QChar - }{ - {source: int8(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: int16(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: int32(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: int64(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: int8(-1), result: pgtype.QChar{Int: -1, Valid: true}}, - {source: int16(-1), result: pgtype.QChar{Int: -1, Valid: true}}, - {source: int32(-1), result: pgtype.QChar{Int: -1, Valid: true}}, - {source: int64(-1), result: pgtype.QChar{Int: -1, Valid: true}}, - {source: uint8(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: uint16(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: uint32(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: uint64(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: "1", result: pgtype.QChar{Int: 1, Valid: true}}, - {source: _int8(1), result: pgtype.QChar{Int: 1, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.QChar - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestQCharAssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - - simpleTests := []struct { - src pgtype.QChar - dst interface{} - expected interface{} - }{ - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i8, expected: int8(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i16, expected: int16(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i32, expected: int32(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i64, expected: int64(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i, expected: int(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui, expected: uint(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.QChar{Int: 0}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.QChar{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.QChar - dst interface{} - expected interface{} - }{ - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.QChar - dst interface{} - }{ - {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui8}, - {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui16}, - {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui32}, - {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui64}, - {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui}, - {src: pgtype.QChar{Int: 0}, dst: &i16}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } +func TestQcharTranscode(t *testing.T) { + var tests []PgxTranscodeTestCase + for i := 0; i <= math.MaxUint8; i++ { + tests = append(tests, PgxTranscodeTestCase{rune(i), new(rune), isExpectedEq(rune(i))}) + tests = append(tests, PgxTranscodeTestCase{byte(i), new(byte), isExpectedEq(byte(i))}) + } + tests = append(tests, PgxTranscodeTestCase{nil, new(*rune), isExpectedEq((*rune)(nil))}) + tests = append(tests, PgxTranscodeTestCase{nil, new(*byte), isExpectedEq((*byte)(nil))}) + + testPgxCodec(t, `"char"`, tests) }