From 90de4adfa79163f1afc2f2388b0300cf89b8bfe9 Mon Sep 17 00:00:00 2001 From: Iurii Krasnoshchok Date: Sat, 16 Dec 2017 02:20:34 +0100 Subject: [PATCH] Add support for bpchar type --- messages.go | 4 +- pgtype/bpchar.go | 68 ++++++++ pgtype/bpchar_array.go | 300 ++++++++++++++++++++++++++++++++++++ pgtype/bpchar_array_test.go | 55 +++++++ pgtype/bpchar_test.go | 51 ++++++ pgtype/pgtype.go | 4 + pgtype/typed_array_gen.sh | 1 + 7 files changed, 481 insertions(+), 2 deletions(-) create mode 100644 pgtype/bpchar.go create mode 100644 pgtype/bpchar_array.go create mode 100644 pgtype/bpchar_array_test.go create mode 100644 pgtype/bpchar_test.go diff --git a/messages.go b/messages.go index 5ffa5c06..97e89295 100644 --- a/messages.go +++ b/messages.go @@ -31,7 +31,7 @@ func (fd FieldDescription) Length() (int64, bool) { switch fd.DataType { case pgtype.TextOID, pgtype.ByteaOID: return math.MaxInt64, true - case pgtype.VarcharOID: + case pgtype.VarcharOID, pgtype.BPCharArrayOID: return int64(fd.Modifier - varHeaderSize), true default: return 0, false @@ -58,7 +58,7 @@ func (fd FieldDescription) Type() reflect.Type { return reflect.TypeOf(int32(0)) case pgtype.Int2OID: return reflect.TypeOf(int16(0)) - case pgtype.VarcharOID, pgtype.TextOID: + case pgtype.VarcharOID, pgtype.BPCharArrayOID, pgtype.TextOID: return reflect.TypeOf("") case pgtype.BoolOID: return reflect.TypeOf(false) diff --git a/pgtype/bpchar.go b/pgtype/bpchar.go new file mode 100644 index 00000000..21263184 --- /dev/null +++ b/pgtype/bpchar.go @@ -0,0 +1,68 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// BPChar is fixed-length, blank padded char type +// character(n), char(n) +type BPChar Text + +// Set converts from src to dst. +func (dst *BPChar) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +// Get returns underlying value +func (dst *BPChar) Get() interface{} { + return (*Text)(dst).Get() +} + +// AssignTo assigns from src to dst. +func (src *BPChar) AssignTo(dst interface{}) error { + if src.Status == Present { + switch v := dst.(type) { + case *rune: + runes := []rune(src.String) + if len(runes) == 1 { + *v = runes[0] + return nil + } + } + } + return (*Text)(src).AssignTo(dst) +} + +func (dst *BPChar) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *BPChar) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (src *BPChar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeText(ci, buf) +} + +func (src *BPChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *BPChar) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *BPChar) Value() (driver.Value, error) { + return (*Text)(src).Value() +} + +func (src *BPChar) MarshalJSON() ([]byte, error) { + return (*Text)(src).MarshalJSON() +} + +func (dst *BPChar) UnmarshalJSON(b []byte) error { + return (*Text)(dst).UnmarshalJSON(b) +} diff --git a/pgtype/bpchar_array.go b/pgtype/bpchar_array.go new file mode 100644 index 00000000..1e6220f7 --- /dev/null +++ b/pgtype/bpchar_array.go @@ -0,0 +1,300 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type BPCharArray struct { + Elements []BPChar + Dimensions []ArrayDimension + Status Status +} + +func (dst *BPCharArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = BPCharArray{Status: Null} + return nil + } + + switch value := src.(type) { + + case []string: + if value == nil { + *dst = BPCharArray{Status: Null} + } else if len(value) == 0 { + *dst = BPCharArray{Status: Present} + } else { + elements := make([]BPChar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BPCharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to BPCharArray", value) + } + + return nil +} + +func (dst *BPCharArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *BPCharArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *BPCharArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = BPCharArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []BPChar + + if len(uta.Elements) > 0 { + elements = make([]BPChar, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem BPChar + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = BPCharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *BPCharArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = BPCharArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = BPCharArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]BPChar, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = BPCharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("bpchar"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "bpchar") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *BPCharArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *BPCharArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/bpchar_array_test.go b/pgtype/bpchar_array_test.go new file mode 100644 index 00000000..e4f2e7eb --- /dev/null +++ b/pgtype/bpchar_array_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestBPCharArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "char(8)[]", []interface{}{ + &pgtype.BPCharArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.BPCharArray{ + Elements: []pgtype.BPChar{ + pgtype.BPChar{String: "foo ", Status: pgtype.Present}, + pgtype.BPChar{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.BPCharArray{Status: pgtype.Null}, + &pgtype.BPCharArray{ + Elements: []pgtype.BPChar{ + pgtype.BPChar{String: "bar ", Status: pgtype.Present}, + pgtype.BPChar{String: "NuLL ", Status: pgtype.Present}, + pgtype.BPChar{String: `wow"quz\`, Status: pgtype.Present}, + pgtype.BPChar{String: "1 ", Status: pgtype.Present}, + pgtype.BPChar{String: "1 ", Status: pgtype.Present}, + pgtype.BPChar{String: "null ", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 3, LowerBound: 1}, + {Length: 2, LowerBound: 1}, + }, + Status: pgtype.Present, + }, + &pgtype.BPCharArray{ + Elements: []pgtype.BPChar{ + pgtype.BPChar{String: " bar ", Status: pgtype.Present}, + pgtype.BPChar{String: " baz ", Status: pgtype.Present}, + pgtype.BPChar{String: " quz ", Status: pgtype.Present}, + pgtype.BPChar{String: "foo ", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} diff --git a/pgtype/bpchar_test.go b/pgtype/bpchar_test.go new file mode 100644 index 00000000..c076ca1b --- /dev/null +++ b/pgtype/bpchar_test.go @@ -0,0 +1,51 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestChar3Transcode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "char(3)", []interface{}{ + &pgtype.BPChar{String: "a ", Status: pgtype.Present}, + &pgtype.BPChar{String: " a ", Status: pgtype.Present}, + &pgtype.BPChar{String: "嗨 ", Status: pgtype.Present}, + &pgtype.BPChar{String: " ", Status: pgtype.Present}, + &pgtype.BPChar{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.BPChar) + b := bb.(pgtype.BPChar) + + return a.Status == b.Status && a.String == b.String + }) +} + +func TestBPCharAssignTo(t *testing.T) { + var ( + str string + run rune + ) + simpleTests := []struct { + src pgtype.BPChar + dst interface{} + expected interface{} + }{ + {src: pgtype.BPChar{String: "simple", Status: pgtype.Present}, dst: &str, expected: "simple"}, + {src: pgtype.BPChar{String: "嗨", Status: pgtype.Present}, dst: &run, expected: '嗨'}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index f7a1a300..2643314e 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -32,6 +32,7 @@ const ( Int4ArrayOID = 1007 TextArrayOID = 1009 ByteaArrayOID = 1001 + BPCharArrayOID = 1014 VarcharArrayOID = 1015 Int8ArrayOID = 1016 Float4ArrayOID = 1021 @@ -39,6 +40,7 @@ const ( ACLItemOID = 1033 ACLItemArrayOID = 1034 InetArrayOID = 1041 + BPCharOID = 1042 VarcharOID = 1043 DateOID = 1082 TimestampOID = 1114 @@ -211,6 +213,7 @@ func init() { nameValues = map[string]Value{ "_aclitem": &ACLItemArray{}, "_bool": &BoolArray{}, + "_bpchar": &BPCharArray{}, "_bytea": &ByteaArray{}, "_cidr": &CIDRArray{}, "_date": &DateArray{}, @@ -230,6 +233,7 @@ func init() { "bit": &Bit{}, "bool": &Bool{}, "box": &Box{}, + "bpchar": &BPChar{}, "bytea": &Bytea{}, "char": &QChar{}, "cid": &CID{}, diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 2a1eab99..4a8211bc 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -11,6 +11,7 @@ erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.I erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' binary_format=true typed_array.go.erb > text_array.go erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' binary_format=true typed_array.go.erb > varchar_array.go +erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string element_type_name=bpchar text_null='NULL' binary_format=true typed_array.go.erb > bpchar_array.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go