diff --git a/values.go b/values.go index ff2dfbfb..fe8f82fa 100644 --- a/values.go +++ b/values.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "io" "math" "net" "reflect" @@ -2999,12 +3000,42 @@ func decodeTextArray(vr *ValueReader) []string { return a } +func EscapeAclItem(acl string) (string, error) { + var buf bytes.Buffer + r := strings.NewReader(acl) + for { + rn, _, err := r.ReadRune() + if err != nil { + if err == io.EOF { + // This error was expected and is OK + return buf.String(), nil + } + // This error was not expected + return "", err + } + if NeedsEscape(rn) { + buf.WriteRune('\\') + } + buf.WriteRune(rn) + } +} + +func NeedsEscape(rn rune) bool { + return rn == '\\' || rn == ',' || rn == '"' || rn == '}' +} + // XXX: encodeAclItemSlice; using text encoding, not binary func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { // cast aclitems into strings so we can use strings.Join strs := make([]string, len(aclitems)) + var escaped string + var err error for i := range strs { - strs[i] = string(aclitems[i]) + escaped, err = EscapeAclItem(string(aclitems[i])) + if err != nil { + return err + } + strs[i] = string(escaped) } str := strings.Join(strs, ",") @@ -3014,6 +3045,121 @@ func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { return nil } +func ParseAclItemArray(arr string) ([]string, error) { + r := strings.NewReader(arr) + // Difficult to guess a performant initial capacity for a slice of + // values, but let's go with 5. + vals := make([]string, 0, 5) + // A single value + vlu := "" + for { + // Grab the first/next/last rune to see if we are dealing with a + // quoted value, an unquoted value, or the end of the string. + rn, _, err := r.ReadRune() + if err != nil { + if err == io.EOF { + // This error was expected and is OK + return vals, nil + } + // This error was not expected + return nil, err + } + + if rn == '"' { + // Discard the opening quote of the quoted value. + vlu, err = ParseQuotedAclItem(r) + } else { + // We have just read the first rune of an unquoted (bare) value; + // put it back so that ParseBareValue can read it. + err := r.UnreadRune() + if err != nil { + // This error was not expected. + return nil, err + } + vlu, err = ParseBareAclItem(r) + } + + if err != nil { + if err == io.EOF { + // This error was expected and is OK. + vals = append(vals, vlu) + return vals, nil + } + // This error was not expected. + return nil, err + } + vals = append(vals, vlu) + } +} + +func ParseBareAclItem(r *strings.Reader) (string, error) { + var buf bytes.Buffer + for { + rn, _, err := r.ReadRune() + if err != nil { + // Return the read value in case the error is a harmless io.EOF. + // (io.EOF marks the end of a bare value at the end of a string) + return buf.String(), err + } + if rn == ',' { + // A comma marks the end of a bare value. + return buf.String(), nil + } else { + buf.WriteRune(rn) + } + } +} + +func ParseQuotedAclItem(r *strings.Reader) (string, error) { + var buf bytes.Buffer + for { + rn, escaped, err := ReadPossiblyEscapedRune(r) + if err != nil { + if err == io.EOF { + // Even when it is the last value, the final rune of + // a quoted value should be the final closing quote, not io.EOF. + return "", fmt.Errorf("unexpected end of quoted value") + } + // Return the read value in case the error is a harmless io.EOF. + return buf.String(), err + } + if !escaped && rn == '"' { + // An unescaped double quote marks the end of a quoted value. + // The next rune should either be a comma or the end of the string. + rn, _, err := r.ReadRune() + if err != nil { + // Return the read value in case the error is a harmless io.EOF. + return buf.String(), err + } + if rn != ',' { + return "", fmt.Errorf("unexpected rune after quoted value") + } + return buf.String(), nil + } + buf.WriteRune(rn) + } +} + +// Returns the next rune from r, unless it is a backslash; +// in that case, it returns the rune after the backslash. The second +// return value tells us whether or not the rune was +// preceeded by a backslash (escaped). +func ReadPossiblyEscapedRune(r *strings.Reader) (rune, bool, error) { + rn, _, err := r.ReadRune() + if err != nil { + return 0, false, err + } + if rn == '\\' { + // Discard the backslash and read the next rune. + rn, _, err = r.ReadRune() + if err != nil { + return 0, false, err + } + return rn, true, nil + } + return rn, false, nil +} + // XXX: decodeAclItemArray; using text encoding, not binary func decodeAclItemArray(vr *ValueReader) []AclItem { if vr.Len() == -1 { @@ -3030,7 +3176,9 @@ func decodeAclItemArray(vr *ValueReader) []AclItem { // remove the '{' at the front and the '}' at the end str = str[1 : len(str)-1] - strs := strings.Split(str, ",") + strs, _ := ParseAclItemArray(str) + // XXX: what do I do with the error here? + // XXX strs := strings.Split(str, ",") // cast strings into AclItems before returning aclitems := make([]AclItem, len(strs)) diff --git a/values_test.go b/values_test.go index 01e5114b..8c5d1032 100644 --- a/values_test.go +++ b/values_test.go @@ -672,13 +672,14 @@ func TestAclArrayDecoding(t *testing.T) { []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"}, }, { - []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/\" tricky\, ' \} \"\" \\ test user \"`}, + []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/" tricky, ' } "" \ test user "`}, }, } for i, tt := range tests { err := conn.QueryRow(sql, tt.query).Scan(&scan) if err != nil { - t.Errorf(`%d. error reading array: %v`, i, err) + // t.Errorf(`%d. error reading array: %v`, i, err) + t.Errorf(`%d. error reading array: %v query: %s`, i, err, tt.query) if pgerr, ok := err.(pgx.PgError); ok { t.Errorf(`%d. error reading array (detail): %s`, i, pgerr.Detail) }