From 17513d175aa26c6b27966a4a36975b307a9f8dc3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 16:49:58 -0600 Subject: [PATCH] Convert bit and varbit to Codec --- pgtype/bit.go | 45 --------- pgtype/bit_test.go | 25 ----- pgtype/bits.go | 208 ++++++++++++++++++++++++++++++++++++++++++ pgtype/bits_test.go | 65 +++++++++++++ pgtype/pgtype.go | 4 +- pgtype/varbit.go | 123 ------------------------- pgtype/varbit_test.go | 26 ------ 7 files changed, 275 insertions(+), 221 deletions(-) delete mode 100644 pgtype/bit.go delete mode 100644 pgtype/bit_test.go create mode 100644 pgtype/bits.go create mode 100644 pgtype/bits_test.go delete mode 100644 pgtype/varbit.go delete mode 100644 pgtype/varbit_test.go diff --git a/pgtype/bit.go b/pgtype/bit.go deleted file mode 100644 index c1709e6b..00000000 --- a/pgtype/bit.go +++ /dev/null @@ -1,45 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -type Bit Varbit - -func (dst *Bit) Set(src interface{}) error { - return (*Varbit)(dst).Set(src) -} - -func (dst Bit) Get() interface{} { - return (Varbit)(dst).Get() -} - -func (src *Bit) AssignTo(dst interface{}) error { - return (*Varbit)(src).AssignTo(dst) -} - -func (dst *Bit) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Varbit)(dst).DecodeBinary(ci, src) -} - -func (src Bit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Varbit)(src).EncodeBinary(ci, buf) -} - -func (dst *Bit) DecodeText(ci *ConnInfo, src []byte) error { - return (*Varbit)(dst).DecodeText(ci, src) -} - -func (src Bit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Varbit)(src).EncodeText(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Bit) Scan(src interface{}) error { - return (*Varbit)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Bit) Value() (driver.Value, error) { - return (Varbit)(src).Value() -} diff --git a/pgtype/bit_test.go b/pgtype/bit_test.go deleted file mode 100644 index 2f07c3c9..00000000 --- a/pgtype/bit_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestBitTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bit(40)", []interface{}{ - &pgtype.Varbit{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, - &pgtype.Varbit{}, - }) -} - -func TestBitNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select B'111111111'", - Value: &pgtype.Bit{Bytes: []byte{255, 128}, Len: 9, Valid: true}, - }, - }) -} diff --git a/pgtype/bits.go b/pgtype/bits.go new file mode 100644 index 00000000..9b499c35 --- /dev/null +++ b/pgtype/bits.go @@ -0,0 +1,208 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + + "github.com/jackc/pgio" +) + +type BitsScanner interface { + ScanBits(v Bits) error +} + +type BitsValuer interface { + BitsValue() (Bits, error) +} + +// Bits represents the PostgreSQL bit and varbit types. +type Bits struct { + Bytes []byte + Len int32 // Number of bits + Valid bool +} + +func (b *Bits) ScanBits(v Bits) error { + *b = v + return nil +} + +func (b Bits) BitsValue() (Bits, error) { + return b, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Bits) Scan(src interface{}) error { + if src == nil { + *dst = Bits{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToBitsScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bits) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := BitsCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type BitsCodec struct{} + +func (BitsCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (BitsCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (BitsCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(BitsValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanBitsCodecBinary{} + case TextFormatCode: + return encodePlanBitsCodecText{} + } + + return nil +} + +type encodePlanBitsCodecBinary struct{} + +func (encodePlanBitsCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + bits, err := value.(BitsValuer).BitsValue() + if err != nil { + return nil, err + } + + if !bits.Valid { + return nil, nil + } + + buf = pgio.AppendInt32(buf, bits.Len) + return append(buf, bits.Bytes...), nil +} + +type encodePlanBitsCodecText struct{} + +func (encodePlanBitsCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + bits, err := value.(BitsValuer).BitsValue() + if err != nil { + return nil, err + } + + if !bits.Valid { + return nil, nil + } + + for i := int32(0); i < bits.Len; i++ { + byteIdx := i / 8 + bitMask := byte(128 >> byte(i%8)) + char := byte('0') + if bits.Bytes[byteIdx]&bitMask > 0 { + char = '1' + } + buf = append(buf, char) + } + + return buf, nil +} + +func (BitsCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case BitsScanner: + return scanPlanBinaryBitsToBitsScanner{} + } + case TextFormatCode: + switch target.(type) { + case BitsScanner: + return scanPlanTextAnyToBitsScanner{} + } + } + + return nil +} + +func (c BitsCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c BitsCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var box Bits + err := codecScan(c, ci, oid, format, src, &box) + if err != nil { + return nil, err + } + return box, nil +} + +type scanPlanBinaryBitsToBitsScanner struct{} + +func (scanPlanBinaryBitsToBitsScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(BitsScanner) + + if src == nil { + return scanner.ScanBits(Bits{}) + } + + if len(src) < 4 { + return fmt.Errorf("invalid length for bit/varbit: %v", len(src)) + } + + bitLen := int32(binary.BigEndian.Uint32(src)) + rp := 4 + + return scanner.ScanBits(Bits{Bytes: src[rp:], Len: bitLen, Valid: true}) +} + +type scanPlanTextAnyToBitsScanner struct{} + +func (scanPlanTextAnyToBitsScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(BitsScanner) + + if src == nil { + return scanner.ScanBits(Bits{}) + } + + bitLen := len(src) + byteLen := bitLen / 8 + if bitLen%8 > 0 { + byteLen++ + } + buf := make([]byte, byteLen) + + for i, b := range src { + if b == '1' { + byteIdx := i / 8 + bitIdx := uint(i % 8) + buf[byteIdx] = buf[byteIdx] | (128 >> bitIdx) + } + } + + return scanner.ScanBits(Bits{Bytes: buf, Len: int32(bitLen), Valid: true}) +} diff --git a/pgtype/bits_test.go b/pgtype/bits_test.go new file mode 100644 index 00000000..a585ef8b --- /dev/null +++ b/pgtype/bits_test.go @@ -0,0 +1,65 @@ +package pgtype_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" +) + +func isExpectedEqBits(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + ab := a.(pgtype.Bits) + vb := v.(pgtype.Bits) + return bytes.Compare(ab.Bytes, vb.Bytes) == 0 && ab.Len == vb.Len && ab.Valid == vb.Valid + } +} + +func TestBitsCodecBit(t *testing.T) { + testPgxCodec(t, "bit(40)", []PgxTranscodeTestCase{ + { + pgtype.Bits{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}), + }, + { + pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}), + }, + {pgtype.Bits{}, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + {nil, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + }) +} + +func TestBitsCodecVarbit(t *testing.T) { + testPgxCodec(t, "varbit", []PgxTranscodeTestCase{ + { + pgtype.Bits{Bytes: []byte{}, Len: 0, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{}, Len: 0, Valid: true}), + }, + { + pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}), + }, + { + pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Valid: true}), + }, + {pgtype.Bits{}, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + {nil, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + }) +} + +func TestBitsNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select B'111111111'", + Value: &pgtype.Bits{Bytes: []byte{255, 128}, Len: 9, Valid: true}, + }, + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d4001392..dc3fbedd 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -276,7 +276,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) ci.RegisterDataType(DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: VarcharOID}}) ci.RegisterDataType(DataType{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) - ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID}) + ci.RegisterDataType(DataType{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) ci.RegisterDataType(DataType{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) ci.RegisterDataType(DataType{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) @@ -321,7 +321,7 @@ func NewConnInfo() *ConnInfo { // ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) ci.RegisterDataType(DataType{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID}) - ci.RegisterDataType(DataType{Value: &Varbit{}, Name: "varbit", OID: VarbitOID}) + ci.RegisterDataType(DataType{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) ci.RegisterDataType(DataType{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID}) diff --git a/pgtype/varbit.go b/pgtype/varbit.go deleted file mode 100644 index bc6fdac4..00000000 --- a/pgtype/varbit.go +++ /dev/null @@ -1,123 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - - "github.com/jackc/pgio" -) - -type Varbit struct { - Bytes []byte - Len int32 // Number of bits - Valid bool -} - -func (dst *Varbit) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Varbit", src) -} - -func (dst Varbit) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *Varbit) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Varbit) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Varbit{} - return nil - } - - bitLen := len(src) - byteLen := bitLen / 8 - if bitLen%8 > 0 { - byteLen++ - } - buf := make([]byte, byteLen) - - for i, b := range src { - if b == '1' { - byteIdx := i / 8 - bitIdx := uint(i % 8) - buf[byteIdx] = buf[byteIdx] | (128 >> bitIdx) - } - } - - *dst = Varbit{Bytes: buf, Len: int32(bitLen), Valid: true} - return nil -} - -func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Varbit{} - return nil - } - - if len(src) < 4 { - return fmt.Errorf("invalid length for varbit: %v", len(src)) - } - - bitLen := int32(binary.BigEndian.Uint32(src)) - rp := 4 - - *dst = Varbit{Bytes: src[rp:], Len: bitLen, Valid: true} - return nil -} - -func (src Varbit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - for i := int32(0); i < src.Len; i++ { - byteIdx := i / 8 - bitMask := byte(128 >> byte(i%8)) - char := byte('0') - if src.Bytes[byteIdx]&bitMask > 0 { - char = '1' - } - buf = append(buf, char) - } - - return buf, nil -} - -func (src Varbit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = pgio.AppendInt32(buf, src.Len) - return append(buf, src.Bytes...), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Varbit) Scan(src interface{}) error { - if src == nil { - *dst = Varbit{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Varbit) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/varbit_test.go b/pgtype/varbit_test.go deleted file mode 100644 index 031d5fa8..00000000 --- a/pgtype/varbit_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestVarbitTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "varbit", []interface{}{ - &pgtype.Varbit{Bytes: []byte{}, Len: 0, Valid: true}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Valid: true}, - &pgtype.Varbit{}, - }) -} - -func TestVarbitNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select B'111111111'", - Value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Valid: true}, - }, - }) -}