mirror of https://github.com/jackc/pgx.git
Adding hstore support. map[string]string will encode to hstores and throw errors on hstores with NULL values, and there is now a NullHstore type that is basically map[string]NullString and will both accept and decode NULL values properly
parent
0441bcd8e4
commit
821605a8dd
2
conn.go
2
conn.go
|
@ -519,6 +519,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
err = arg.Encode(wbuf, oid)
|
||||
case string:
|
||||
err = encodeText(wbuf, arguments[i])
|
||||
case map[string]string:
|
||||
err = encodeHstore(wbuf, arguments[i])
|
||||
default:
|
||||
switch oid {
|
||||
case BoolOid:
|
||||
|
|
|
@ -0,0 +1,215 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
hsPre = iota
|
||||
hsKey
|
||||
hsSep
|
||||
hsVal
|
||||
hsNul
|
||||
hsNext
|
||||
hsEnd
|
||||
)
|
||||
|
||||
type hstoreParser struct {
|
||||
str string
|
||||
pos int
|
||||
}
|
||||
|
||||
func newHSP(in string) *hstoreParser {
|
||||
return &hstoreParser{
|
||||
pos: 0,
|
||||
str: in,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *hstoreParser) Consume() (r rune, end bool) {
|
||||
if p.pos >= len(p.str) {
|
||||
end = true
|
||||
return
|
||||
}
|
||||
r, w := utf8.DecodeRuneInString(p.str[p.pos:])
|
||||
p.pos += w
|
||||
return
|
||||
}
|
||||
|
||||
func (p *hstoreParser) Peek() (r rune, end bool) {
|
||||
if p.pos >= len(p.str) {
|
||||
end = true
|
||||
return
|
||||
}
|
||||
r, _ = utf8.DecodeRuneInString(p.str[p.pos:])
|
||||
return
|
||||
}
|
||||
|
||||
func parseHstoreToMap(s string) (m map[string]string, err error) {
|
||||
keys, values, err := ParseHstore(s)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
m = make(map[string]string, len(keys))
|
||||
for i, key := range keys {
|
||||
if !values[i].Valid {
|
||||
err = fmt.Errorf("key '%s' has NULL value", key)
|
||||
m = nil
|
||||
return
|
||||
}
|
||||
m[key] = values[i].String
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseHstoreToNullHstore(s string) (store map[string]NullString, err error) {
|
||||
keys, values, err := ParseHstore(s)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
store = make(map[string]NullString, len(keys))
|
||||
|
||||
for i, key := range keys {
|
||||
store[key] = values[i]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ParseHstore(s string) (k []string, v []NullString, err error) {
|
||||
|
||||
buf := bytes.Buffer{}
|
||||
keys := []string{}
|
||||
values := []NullString{}
|
||||
p := newHSP(s)
|
||||
|
||||
r, end := p.Consume()
|
||||
state := hsPre
|
||||
|
||||
for !end {
|
||||
switch state {
|
||||
case hsPre:
|
||||
if r == '"' {
|
||||
state = hsKey
|
||||
} else {
|
||||
err = errors.New("String does not begin with \"")
|
||||
}
|
||||
case hsKey:
|
||||
switch r {
|
||||
case '"': //End of the key
|
||||
if buf.Len() == 0 {
|
||||
err = errors.New("Empty Key is invalid")
|
||||
} else {
|
||||
keys = append(keys, buf.String())
|
||||
buf = bytes.Buffer{}
|
||||
state = hsSep
|
||||
}
|
||||
case '\\': //Potential escaped character
|
||||
n, end := p.Consume()
|
||||
switch {
|
||||
case end:
|
||||
err = errors.New("Found EOS in key, expecting character or \"")
|
||||
case n == '"', n == '\\':
|
||||
buf.WriteRune(n)
|
||||
default:
|
||||
buf.WriteRune(r)
|
||||
buf.WriteRune(n)
|
||||
}
|
||||
default: //Any other character
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
case hsSep:
|
||||
if r == '=' {
|
||||
r, end = p.Consume()
|
||||
switch {
|
||||
case end:
|
||||
err = errors.New("Found EOS after '=', expecting '>'")
|
||||
case r == '>':
|
||||
r, end = p.Consume()
|
||||
switch {
|
||||
case end:
|
||||
err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'")
|
||||
case r == '"':
|
||||
state = hsVal
|
||||
case r == 'N':
|
||||
state = hsNul
|
||||
default:
|
||||
err = fmt.Errorf("Invalid character '%s' after '=>', expecting '\"' or 'NULL'")
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("Invalid character after '=', expecting '>'")
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("Invalid character '%s' after value, expecting '='", r)
|
||||
}
|
||||
case hsVal:
|
||||
switch r {
|
||||
case '"': //End of the value
|
||||
values = append(values, NullString{String: buf.String(), Valid: true})
|
||||
buf = bytes.Buffer{}
|
||||
state = hsNext
|
||||
case '\\': //Potential escaped character
|
||||
n, end := p.Consume()
|
||||
switch {
|
||||
case end:
|
||||
err = errors.New("Found EOS in key, expecting character or \"")
|
||||
case n == '"', n == '\\':
|
||||
buf.WriteRune(n)
|
||||
default:
|
||||
buf.WriteRune(r)
|
||||
buf.WriteRune(n)
|
||||
}
|
||||
default: //Any other character
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
case hsNul:
|
||||
nulBuf := make([]rune, 3)
|
||||
nulBuf[0] = r
|
||||
for i := 1; i < 3; i++ {
|
||||
r, end = p.Consume()
|
||||
if end {
|
||||
err = errors.New("Found EOS in NULL value")
|
||||
return
|
||||
}
|
||||
nulBuf[i] = r
|
||||
}
|
||||
if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' {
|
||||
values = append(values, NullString{String: "", Valid: false})
|
||||
state = hsNext
|
||||
} else {
|
||||
err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf))
|
||||
}
|
||||
case hsNext:
|
||||
if r == ',' {
|
||||
r, end = p.Consume()
|
||||
switch {
|
||||
case end:
|
||||
err = errors.New("Found EOS after ',', expcting space")
|
||||
case (unicode.IsSpace(r)):
|
||||
r, end = p.Consume()
|
||||
state = hsKey
|
||||
default:
|
||||
err = fmt.Errorf("Invalid character '%s' after ', ', expecting \"", r)
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("Invalid character '%s' after value, expecting ','", r)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, end = p.Consume()
|
||||
}
|
||||
if state != hsNext {
|
||||
err = errors.New("Improperly formatted hstore")
|
||||
return
|
||||
}
|
||||
k = keys
|
||||
v = values
|
||||
return
|
||||
}
|
3
query.go
3
query.go
|
@ -254,7 +254,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||
default:
|
||||
rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType))
|
||||
}
|
||||
|
||||
case *map[string]string:
|
||||
*d = decodeHstore(vr)
|
||||
case Scanner:
|
||||
err = d.Scan(vr)
|
||||
if err != nil {
|
||||
|
|
131
values.go
131
values.go
|
@ -1,10 +1,12 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
@ -419,6 +421,81 @@ func (n NullTime) Encode(w *WriteBuf, oid Oid) error {
|
|||
return encodeTimestampTz(w, n.Time)
|
||||
}
|
||||
|
||||
// NullHstore represents an hstore that can be null or have null values
|
||||
// associated with its keys. NullHstore implements the Scanner and TextEncoder
|
||||
// interfaces so it may be used both as an argument to Query[Row] and a
|
||||
// destination for Scan for prepared and unprepared queries.
|
||||
//
|
||||
// If Valid is false, then the value of the entire hstore column is NULL
|
||||
// If any of the NullString values in Store has Valid set to false, the key
|
||||
// appears in the hstore column, but its value is explicitly set to NULL
|
||||
type NullHstore struct {
|
||||
Store map[string]NullString
|
||||
Valid bool
|
||||
}
|
||||
|
||||
func (h *NullHstore) Scan(vr *ValueReader) error {
|
||||
//oid for hstore not standardized, so we check its type name
|
||||
if vr.Type().DataTypeName != "hstore" {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into map[string]string", vr.Type().DataTypeName)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Len() == -1 {
|
||||
h.Valid = false
|
||||
return nil
|
||||
}
|
||||
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
store, err := parseHstoreToNullHstore(vr.ReadString(vr.Len()))
|
||||
if err != nil {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err)))
|
||||
return nil
|
||||
}
|
||||
h.Valid = true
|
||||
h.Store = store
|
||||
return nil
|
||||
case BinaryFormatCode:
|
||||
vr.Fatal(ProtocolError("Can't decode binary hstore"))
|
||||
return nil
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (h NullHstore) FormatCode() int16 { return TextFormatCode }
|
||||
|
||||
func (h NullHstore) Encode(w *WriteBuf, oid Oid) error {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if !h.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
i := 0
|
||||
for k, v := range h.Store {
|
||||
i++
|
||||
ks := strings.Replace(k, `\`, `\\`, -1)
|
||||
ks = strings.Replace(ks, `"`, `\"`, -1)
|
||||
if v.Valid {
|
||||
vs := strings.Replace(v.String, `\`, `\\`, -1)
|
||||
vs = strings.Replace(vs, `"`, `\"`, -1)
|
||||
buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs))
|
||||
} else {
|
||||
buf.WriteString(fmt.Sprintf(`"%s"=>NULL`, ks))
|
||||
}
|
||||
if i < len(h.Store) {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
}
|
||||
w.WriteInt32(int32(buf.Len()))
|
||||
w.WriteBytes(buf.Bytes())
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeBool(vr *ValueReader) bool {
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into bool"))
|
||||
|
@ -1045,6 +1122,60 @@ func encodeTimestamp(w *WriteBuf, value interface{}) error {
|
|||
return encodeText(w, s)
|
||||
}
|
||||
|
||||
func decodeHstore(vr *ValueReader) map[string]string {
|
||||
//oid for hstore not standardized, so we check its type name
|
||||
if vr.Type().DataTypeName != "hstore" {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into map[string]string", vr.Type().DataTypeName)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into map[string]string"))
|
||||
return nil
|
||||
}
|
||||
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
m, err := parseHstoreToMap(vr.ReadString(vr.Len()))
|
||||
if err != nil {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err)))
|
||||
return nil
|
||||
}
|
||||
return m
|
||||
case BinaryFormatCode:
|
||||
vr.Fatal(ProtocolError("Can't decode binary hstore"))
|
||||
return nil
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func encodeHstore(w *WriteBuf, value interface{}) error {
|
||||
var buf bytes.Buffer
|
||||
|
||||
h, ok := value.(map[string]string)
|
||||
if !ok {
|
||||
return fmt.Errorf("Expected map[string]string, received %T", value)
|
||||
}
|
||||
|
||||
i := 0
|
||||
for k, v := range h {
|
||||
i++
|
||||
ks := strings.Replace(k, `\`, `\\`, -1)
|
||||
ks = strings.Replace(ks, `"`, `\"`, -1)
|
||||
vs := strings.Replace(v, `\`, `\\`, -1)
|
||||
vs = strings.Replace(vs, `"`, `\"`, -1)
|
||||
buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs))
|
||||
if i < len(h) {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
}
|
||||
w.WriteInt32(int32(buf.Len()))
|
||||
w.WriteBytes(buf.Bytes())
|
||||
return nil
|
||||
}
|
||||
|
||||
func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
|
||||
numDims := vr.ReadInt32()
|
||||
if numDims == 0 {
|
||||
|
|
Loading…
Reference in New Issue