From fa1c81fec4413a97bb267b85c19293cff10d5841 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 16:13:05 -0600 Subject: [PATCH] Move ACLItem to pgtype --- aclitem_parse_test.go | 126 ---------------- conn.go | 2 + pgtype/aclitem.go | 104 ++++++++++++++ pgtype/aclitem_test.go | 97 +++++++++++++ pgtype/aclitemarray.go | 186 ++++++++++++++++++++++++ pgtype/aclitemarray_test.go | 151 +++++++++++++++++++ pgtype/pgtype.go | 4 +- pgtype/typed_array_gen.sh | 1 + values.go | 280 +----------------------------------- values_test.go | 51 ------- 10 files changed, 548 insertions(+), 454 deletions(-) delete mode 100644 aclitem_parse_test.go create mode 100644 pgtype/aclitem.go create mode 100644 pgtype/aclitem_test.go create mode 100644 pgtype/aclitemarray.go create mode 100644 pgtype/aclitemarray_test.go diff --git a/aclitem_parse_test.go b/aclitem_parse_test.go deleted file mode 100644 index 5c7c748f..00000000 --- a/aclitem_parse_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package pgx - -import ( - "reflect" - "testing" -) - -func TestEscapeAclItem(t *testing.T) { - tests := []struct { - input string - expected string - }{ - { - "foo", - "foo", - }, - { - `foo, "\}`, - `foo\, \"\\\}`, - }, - } - - for i, tt := range tests { - actual, err := escapeAclItem(tt.input) - - if err != nil { - t.Errorf("%d. Unexpected error %v", i, err) - } - - if actual != tt.expected { - t.Errorf("%d.\nexpected: %s,\nactual: %s", i, tt.expected, actual) - } - } -} - -func TestParseAclItemArray(t *testing.T) { - tests := []struct { - input string - expected []AclItem - errMsg string - }{ - { - "", - []AclItem{}, - "", - }, - { - "one", - []AclItem{"one"}, - "", - }, - { - `"one"`, - []AclItem{"one"}, - "", - }, - { - "one,two,three", - []AclItem{"one", "two", "three"}, - "", - }, - { - `"one","two","three"`, - []AclItem{"one", "two", "three"}, - "", - }, - { - `"one",two,"three"`, - []AclItem{"one", "two", "three"}, - "", - }, - { - `one,two,"three"`, - []AclItem{"one", "two", "three"}, - "", - }, - { - `"one","two",three`, - []AclItem{"one", "two", "three"}, - "", - }, - { - `"one","t w o",three`, - []AclItem{"one", "t w o", "three"}, - "", - }, - { - `"one","t, w o\"\}\\",three`, - []AclItem{"one", `t, w o"}\`, "three"}, - "", - }, - { - `"one","two",three"`, - []AclItem{"one", "two", `three"`}, - "", - }, - { - `"one","two,"three"`, - nil, - "unexpected rune after quoted value", - }, - { - `"one","two","three`, - nil, - "unexpected end of quoted value", - }, - } - - for i, tt := range tests { - actual, err := parseAclItemArray(tt.input) - - if err != nil { - if tt.errMsg == "" { - t.Errorf("%d. Unexpected error %v", i, err) - } else if err.Error() != tt.errMsg { - t.Errorf("%d. Expected error %v did not match actual error %v", i, tt.errMsg, err.Error()) - } - } else if tt.errMsg != "" { - t.Errorf("%d. Expected error not returned: \"%v\"", i, tt.errMsg) - } - - if !reflect.DeepEqual(actual, tt.expected) { - t.Errorf("%d. Expected %v did not match actual %v", i, tt.expected, actual) - } - } -} diff --git a/conn.go b/conn.go index e340f1c6..f55dd82a 100644 --- a/conn.go +++ b/conn.go @@ -268,6 +268,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.closedChan = make(chan error) c.oidPgtypeValues = map[OID]pgtype.Value{ + ACLItemOID: &pgtype.ACLItem{}, + ACLItemArrayOID: &pgtype.ACLItemArray{}, BoolArrayOID: &pgtype.BoolArray{}, BoolOID: &pgtype.Bool{}, ByteaArrayOID: &pgtype.ByteaArray{}, diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go new file mode 100644 index 00000000..bd7b7d45 --- /dev/null +++ b/pgtype/aclitem.go @@ -0,0 +1,104 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" +) + +// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem +// might look like this: +// +// postgres=arwdDxt/postgres +// +// Note, however, that because the user/role name part of an aclitem is +// an identifier, it follows all the usual formatting rules for SQL +// identifiers: if it contains spaces and other special characters, +// it should appear in double-quotes: +// +// postgres=arwdDxt/"role with spaces" +// +type ACLItem struct { + String string + Status Status +} + +func (dst *ACLItem) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case ACLItem: + *dst = value + case string: + *dst = ACLItem{String: value, Status: Present} + case *string: + if value == nil { + *dst = ACLItem{Status: Null} + } else { + *dst = ACLItem{String: *value, Status: Present} + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to ACLItem", value) + } + + return nil +} + +func (src *ACLItem) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *string: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.String + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if src.Status == Null { + el.Set(reflect.Zero(el.Type())) + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return src.AssignTo(el.Interface()) + case reflect.String: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + el.SetString(src.String) + return nil + } + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *ACLItem) DecodeText(src []byte) error { + if src == nil { + *dst = ACLItem{Status: Null} + return nil + } + + *dst = ACLItem{String: string(src), Status: Present} + return nil +} + +func (src ACLItem) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, src.String) + return false, err +} diff --git a/pgtype/aclitem_test.go b/pgtype/aclitem_test.go new file mode 100644 index 00000000..0b2b6cfa --- /dev/null +++ b/pgtype/aclitem_test.go @@ -0,0 +1,97 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestACLItemTranscode(t *testing.T) { + testSuccessfulTranscode(t, "aclitem", []interface{}{ + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, + }) +} + +func TestACLItemConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ACLItem + }{ + {source: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.ACLItem{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var d pgtype.ACLItem + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestACLItemAssignTo(t *testing.T) { + var s string + var ps *string + + simpleTests := []struct { + src pgtype.ACLItem + dst interface{} + expected interface{} + }{ + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.ACLItem + dst interface{} + expected interface{} + }{ + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.ACLItem + dst interface{} + }{ + {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/aclitemarray.go b/pgtype/aclitemarray.go new file mode 100644 index 00000000..d69cd83c --- /dev/null +++ b/pgtype/aclitemarray.go @@ -0,0 +1,186 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type ACLItemArray struct { + Elements []ACLItem + Dimensions []ArrayDimension + Status Status +} + +func (dst *ACLItemArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case ACLItemArray: + *dst = value + + case []string: + if value == nil { + *dst = ACLItemArray{Status: Null} + } else if len(value) == 0 { + *dst = ACLItemArray{Status: Present} + } else { + elements := make([]ACLItem, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = ACLItemArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to ACLItem", value) + } + + return nil +} + +func (src *ACLItemArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]string: + if src.Status == Present { + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *ACLItemArray) DecodeText(src []byte) error { + if src == nil { + *dst = ACLItemArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []ACLItem + + if len(uta.Elements) > 0 { + elements = make([]ACLItem, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem ACLItem + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (src *ACLItemArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil +} diff --git a/pgtype/aclitemarray_test.go b/pgtype/aclitemarray_test.go new file mode 100644 index 00000000..8c01ac66 --- /dev/null +++ b/pgtype/aclitemarray_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestACLItemArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "aclitem[]", []interface{}{ + &pgtype.ACLItemArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{Status: pgtype.Null}, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestACLItemArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ACLItemArray + }{ + { + source: []string{"=r/postgres"}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.ACLItemArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.ACLItemArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestACLItemArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.ACLItemArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"=r/postgres"}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"=r/postgres"}, + }, + { + src: pgtype.ACLItemArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.ACLItemArray + dst interface{} + }{ + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d6cd53c1..d72217ac 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -35,8 +35,8 @@ const ( Int8ArrayOID = 1016 Float4ArrayOID = 1021 Float8ArrayOID = 1022 - AclItemOID = 1033 - AclItemArrayOID = 1034 + ACLItemOID = 1033 + ACLItemArrayOID = 1034 InetArrayOID = 1041 VarcharOID = 1043 DateOID = 1082 diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index c63414c8..876f8a3c 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -10,3 +10,4 @@ erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]fl erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID text_null=NULL typed_array.go.erb > inetarray.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID text_null='"NULL"' typed_array.go.erb > textarray.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOID text_null=NULL typed_array.go.erb > byteaarray.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_oid=ACLItemOID text_null=NULL typed_array.go.erb > aclitemarray.go diff --git a/values.go b/values.go index 80f4ee52..abe12d98 100644 --- a/values.go +++ b/values.go @@ -48,8 +48,8 @@ const ( Int8ArrayOID = 1016 Float4ArrayOID = 1021 Float8ArrayOID = 1022 - AclItemOID = 1033 - AclItemArrayOID = 1034 + ACLItemOID = 1033 + ACLItemArrayOID = 1034 InetArrayOID = 1041 VarcharOID = 1043 DateOID = 1082 @@ -316,58 +316,6 @@ func (s NullString) Encode(w *WriteBuf, oid OID) error { return encodeString(w, oid, s.String) } -// AclItem is used for PostgreSQL's aclitem data type. A sample aclitem -// might look like this: -// -// postgres=arwdDxt/postgres -// -// Note, however, that because the user/role name part of an aclitem is -// an identifier, it follows all the usual formatting rules for SQL -// identifiers: if it contains spaces and other special characters, -// it should appear in double-quotes: -// -// postgres=arwdDxt/"role with spaces" -// -type AclItem string - -// NullAclItem represents a pgx.AclItem that may be null. NullAclItem implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan for prepared and unprepared queries. -// -// If Valid is false then the value is NULL. -type NullAclItem struct { - AclItem AclItem - Valid bool // Valid is true if AclItem is not NULL -} - -func (n *NullAclItem) Scan(vr *ValueReader) error { - if vr.Type().DataType != AclItemOID { - return SerializationError(fmt.Sprintf("NullAclItem.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.AclItem, n.Valid = "", false - return nil - } - - n.Valid = true - n.AclItem = AclItem(decodeText(vr)) - return vr.Err() -} - -// Particularly important to return TextFormatCode, seeing as Postgres -// only ever sends aclitem as text, not binary. -func (n NullAclItem) FormatCode() int16 { return TextFormatCode } - -func (n NullAclItem) Encode(w *WriteBuf, oid OID) error { - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeString(w, oid, string(n.AclItem)) -} - // NullInt16 represents a smallint that may be null. NullInt16 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan for prepared and unprepared queries. @@ -865,8 +813,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return Encode(wbuf, oid, v) case string: return encodeString(wbuf, oid, arg) - case []AclItem: - return encodeAclItemSlice(wbuf, oid, arg) case []byte: return encodeByteSlice(wbuf, oid, arg) } @@ -909,17 +855,10 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return nil } - switch arg := arg.(type) { - case AclItem: - // The aclitem data type goes over the wire using the same format as string, - // so just cast to string and use encodeString - return encodeString(wbuf, oid, string(arg)) - default: - if strippedArg, ok := stripNamedType(&refVal); ok { - return Encode(wbuf, oid, strippedArg) - } - return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + if strippedArg, ok := stripNamedType(&refVal); ok { + return Encode(wbuf, oid, strippedArg) } + return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } func stripNamedType(val *reflect.Value) (interface{}, bool) { @@ -981,15 +920,10 @@ func decodeByOID(vr *ValueReader) (interface{}, error) { // decoding to the built-in functionality. func Decode(vr *ValueReader, d interface{}) error { switch v := d.(type) { - case *AclItem: - // aclitem goes over the wire just like text - *v = AclItem(decodeText(vr)) case *Tid: *v = decodeTid(vr) case *string: *v = decodeText(vr) - case *[]AclItem: - *v = decodeAclItemArray(vr) case *[]interface{}: *v = decodeRecord(vr) default: @@ -1675,207 +1609,3 @@ func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { return length, nil } - -// escapeAclItem escapes an AclItem before it is added to -// its aclitem[] string representation. The PostgreSQL aclitem -// datatype itself can need escapes because it follows the -// formatting rules of SQL identifiers. Think of this function -// as escaping the escapes, so that PostgreSQL's array parser -// will do the right thing. -func escapeAclItem(acl string) (string, error) { - var escapedAclItem bytes.Buffer - reader := strings.NewReader(acl) - for { - rn, _, err := reader.ReadRune() - if err != nil { - if err == io.EOF { - // Here, EOF is an expected end state, not an error. - return escapedAclItem.String(), nil - } - // This error was not expected - return "", err - } - if needsEscape(rn) { - escapedAclItem.WriteRune('\\') - } - escapedAclItem.WriteRune(rn) - } -} - -// needsEscape determines whether or not a rune needs escaping -// before being placed in the textual representation of an -// aclitem[] array. -func needsEscape(rn rune) bool { - return rn == '\\' || rn == ',' || rn == '"' || rn == '}' -} - -// encodeAclItemSlice encodes a slice of AclItems in -// their textual represention for PostgreSQL. -func encodeAclItemSlice(w *WriteBuf, oid OID, aclitems []AclItem) error { - strs := make([]string, len(aclitems)) - var escapedAclItem string - var err error - for i := range strs { - escapedAclItem, err = escapeAclItem(string(aclitems[i])) - if err != nil { - return err - } - strs[i] = string(escapedAclItem) - } - - var buf bytes.Buffer - buf.WriteRune('{') - buf.WriteString(strings.Join(strs, ",")) - buf.WriteRune('}') - str := buf.String() - w.WriteInt32(int32(len(str))) - w.WriteBytes([]byte(str)) - return nil -} - -// parseAclItemArray parses the textual representation -// of the aclitem[] type. The textual representation is chosen because -// Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin). -// See https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO -// for formatting notes. -func parseAclItemArray(arr string) ([]AclItem, error) { - reader := strings.NewReader(arr) - // Difficult to guess a performant initial capacity for a slice of - // aclitems, but let's go with 5. - aclItems := make([]AclItem, 0, 5) - // A single value - aclItem := AclItem("") - for { - // Grab the first/next/last rune to see if we are dealing with a - // quoted value, an unquoted value, or the end of the string. - rn, _, err := reader.ReadRune() - if err != nil { - if err == io.EOF { - // Here, EOF is an expected end state, not an error. - return aclItems, nil - } - // This error was not expected - return nil, err - } - - if rn == '"' { - // Discard the opening quote of the quoted value. - aclItem, err = parseQuotedAclItem(reader) - } else { - // We have just read the first rune of an unquoted (bare) value; - // put it back so that ParseBareValue can read it. - err := reader.UnreadRune() - if err != nil { - return nil, err - } - aclItem, err = parseBareAclItem(reader) - } - - if err != nil { - if err == io.EOF { - // Here, EOF is an expected end state, not an error.. - aclItems = append(aclItems, aclItem) - return aclItems, nil - } - // This error was not expected. - return nil, err - } - aclItems = append(aclItems, aclItem) - } -} - -// parseBareAclItem parses a bare (unquoted) aclitem from reader -func parseBareAclItem(reader *strings.Reader) (AclItem, error) { - var aclItem bytes.Buffer - for { - rn, _, err := reader.ReadRune() - if err != nil { - // Return the read value in case the error is a harmless io.EOF. - // (io.EOF marks the end of a bare aclitem at the end of a string) - return AclItem(aclItem.String()), err - } - if rn == ',' { - // A comma marks the end of a bare aclitem. - return AclItem(aclItem.String()), nil - } else { - aclItem.WriteRune(rn) - } - } -} - -// parseQuotedAclItem parses an aclitem which is in double quotes from reader -func parseQuotedAclItem(reader *strings.Reader) (AclItem, error) { - var aclItem bytes.Buffer - for { - rn, escaped, err := readPossiblyEscapedRune(reader) - if err != nil { - if err == io.EOF { - // Even when it is the last value, the final rune of - // a quoted aclitem should be the final closing quote, not io.EOF. - return AclItem(""), fmt.Errorf("unexpected end of quoted value") - } - // Return the read aclitem in case the error is a harmless io.EOF, - // which will be determined by the caller. - return AclItem(aclItem.String()), err - } - if !escaped && rn == '"' { - // An unescaped double quote marks the end of a quoted value. - // The next rune should either be a comma or the end of the string. - rn, _, err := reader.ReadRune() - if err != nil { - // Return the read value in case the error is a harmless io.EOF, - // which will be determined by the caller. - return AclItem(aclItem.String()), err - } - if rn != ',' { - return AclItem(""), fmt.Errorf("unexpected rune after quoted value") - } - return AclItem(aclItem.String()), nil - } - aclItem.WriteRune(rn) - } -} - -// Returns the next rune from r, unless it is a backslash; -// in that case, it returns the rune after the backslash. The second -// return value tells us whether or not the rune was -// preceeded by a backslash (escaped). -func readPossiblyEscapedRune(reader *strings.Reader) (rune, bool, error) { - rn, _, err := reader.ReadRune() - if err != nil { - return 0, false, err - } - if rn == '\\' { - // Discard the backslash and read the next rune. - rn, _, err = reader.ReadRune() - if err != nil { - return 0, false, err - } - return rn, true, nil - } - return rn, false, nil -} - -func decodeAclItemArray(vr *ValueReader) []AclItem { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into []AclItem")) - return nil - } - - str := vr.ReadString(vr.Len()) - - // Short-circuit empty array. - if str == "{}" { - return []AclItem{} - } - - // Remove the '{' at the front and the '}' at the end, - // so that parseAclItemArray doesn't have to deal with them. - str = str[1 : len(str)-1] - aclItems, err := parseAclItemArray(str) - if err != nil { - vr.Fatal(ProtocolError(err.Error())) - return nil - } - return aclItems -} diff --git a/values_test.go b/values_test.go index 4c02ac0a..9cf2b219 100644 --- a/values_test.go +++ b/values_test.go @@ -568,7 +568,6 @@ func TestNullX(t *testing.T) { s pgx.NullString i16 pgx.NullInt16 i32 pgx.NullInt32 - a pgx.NullAclItem tid pgx.NullTid i64 pgx.NullInt64 f32 pgx.NullFloat32 @@ -591,10 +590,6 @@ func TestNullX(t *testing.T) { {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}}, - {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}}, - {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}}, - // A tricky (and valid) aclitem can still be used, especially with Go's useful backticks - {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}}, {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}}, {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: false}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 0, OffsetNumber: 0}, Valid: false}}}, {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}}, @@ -631,52 +626,6 @@ func TestNullX(t *testing.T) { } } -func assertAclItemSlicesEqual(t *testing.T, query, scan []pgx.AclItem) { - if !reflect.DeepEqual(query, scan) { - t.Errorf("failed to encode aclitem[]\n EXPECTED: %d %v\n ACTUAL: %d %v", len(query), query, len(scan), scan) - } -} - -func TestAclArrayDecoding(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - sql := "select $1::aclitem[]" - var scan []pgx.AclItem - - tests := []struct { - query []pgx.AclItem - }{ - { - []pgx.AclItem{}, - }, - { - []pgx.AclItem{"=r/postgres"}, - }, - { - []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"}, - }, - { - []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/" tricky, ' } "" \ test user "`}, - }, - } - for i, tt := range tests { - err := conn.QueryRow(sql, tt.query).Scan(&scan) - if err != nil { - // t.Errorf(`%d. error reading array: %v`, i, err) - t.Errorf(`%d. error reading array: %v query: %s`, i, err, tt.query) - if pgerr, ok := err.(pgx.PgError); ok { - t.Errorf(`%d. error reading array (detail): %s`, i, pgerr.Detail) - } - continue - } - assertAclItemSlicesEqual(t, tt.query, scan) - ensureConnValid(t, conn) - } -} - func TestArrayDecoding(t *testing.T) { t.Parallel()