Move hstore to pgtype

Also implement binary format
v3-numeric-wip
Jack Christensen 2017-03-12 17:06:06 -05:00
parent 3391818847
commit 7bb1f3677d
6 changed files with 603 additions and 579 deletions

98
conn.go
View File

@ -267,47 +267,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.doneChan = make(chan struct{}) c.doneChan = make(chan struct{})
c.closedChan = make(chan error) c.closedChan = make(chan error)
c.oidPgtypeValues = map[Oid]pgtype.Value{
AclitemArrayOid: &pgtype.AclitemArray{},
AclitemOid: &pgtype.Aclitem{},
BoolArrayOid: &pgtype.BoolArray{},
BoolOid: &pgtype.Bool{},
ByteaArrayOid: &pgtype.ByteaArray{},
ByteaOid: &pgtype.Bytea{},
CharOid: &pgtype.QChar{},
CidOid: &pgtype.Cid{},
CidrArrayOid: &pgtype.CidrArray{},
CidrOid: &pgtype.Inet{},
DateArrayOid: &pgtype.DateArray{},
DateOid: &pgtype.Date{},
Float4ArrayOid: &pgtype.Float4Array{},
Float4Oid: &pgtype.Float4{},
Float8ArrayOid: &pgtype.Float8Array{},
Float8Oid: &pgtype.Float8{},
InetArrayOid: &pgtype.InetArray{},
InetOid: &pgtype.Inet{},
Int2ArrayOid: &pgtype.Int2Array{},
Int2Oid: &pgtype.Int2{},
Int4ArrayOid: &pgtype.Int4Array{},
Int4Oid: &pgtype.Int4{},
Int8ArrayOid: &pgtype.Int8Array{},
Int8Oid: &pgtype.Int8{},
JsonbOid: &pgtype.Jsonb{},
JsonOid: &pgtype.Json{},
NameOid: &pgtype.Name{},
OidOid: &pgtype.Oid{},
TextArrayOid: &pgtype.TextArray{},
TextOid: &pgtype.Text{},
TidOid: &pgtype.Tid{},
TimestampArrayOid: &pgtype.TimestampArray{},
TimestampOid: &pgtype.Timestamp{},
TimestampTzArrayOid: &pgtype.TimestamptzArray{},
TimestampTzOid: &pgtype.Timestamptz{},
VarcharArrayOid: &pgtype.VarcharArray{},
VarcharOid: &pgtype.Text{},
XidOid: &pgtype.Xid{},
}
if tlsConfig != nil { if tlsConfig != nil {
if c.shouldLog(LogLevelDebug) { if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "Starting TLS handshake") c.log(LogLevelDebug, "Starting TLS handshake")
@ -317,6 +276,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
} }
} }
c.loadStaticOidPgtypeValues()
c.mr.cr = chunkreader.NewChunkReader(c.conn) c.mr.cr = chunkreader.NewChunkReader(c.conn)
msg := newStartupMessage() msg := newStartupMessage()
@ -376,6 +337,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
return err return err
} }
} }
c.loadDynamicOidPgtypeValues()
return nil return nil
default: default:
@ -416,6 +378,60 @@ where (
return rows.Err() return rows.Err()
} }
func (c *Conn) loadStaticOidPgtypeValues() {
c.oidPgtypeValues = map[Oid]pgtype.Value{
AclitemArrayOid: &pgtype.AclitemArray{},
AclitemOid: &pgtype.Aclitem{},
BoolArrayOid: &pgtype.BoolArray{},
BoolOid: &pgtype.Bool{},
ByteaArrayOid: &pgtype.ByteaArray{},
ByteaOid: &pgtype.Bytea{},
CharOid: &pgtype.QChar{},
CidOid: &pgtype.Cid{},
CidrArrayOid: &pgtype.CidrArray{},
CidrOid: &pgtype.Inet{},
DateArrayOid: &pgtype.DateArray{},
DateOid: &pgtype.Date{},
Float4ArrayOid: &pgtype.Float4Array{},
Float4Oid: &pgtype.Float4{},
Float8ArrayOid: &pgtype.Float8Array{},
Float8Oid: &pgtype.Float8{},
InetArrayOid: &pgtype.InetArray{},
InetOid: &pgtype.Inet{},
Int2ArrayOid: &pgtype.Int2Array{},
Int2Oid: &pgtype.Int2{},
Int4ArrayOid: &pgtype.Int4Array{},
Int4Oid: &pgtype.Int4{},
Int8ArrayOid: &pgtype.Int8Array{},
Int8Oid: &pgtype.Int8{},
JsonbOid: &pgtype.Jsonb{},
JsonOid: &pgtype.Json{},
NameOid: &pgtype.Name{},
OidOid: &pgtype.Oid{},
TextArrayOid: &pgtype.TextArray{},
TextOid: &pgtype.Text{},
TidOid: &pgtype.Tid{},
TimestampArrayOid: &pgtype.TimestampArray{},
TimestampOid: &pgtype.Timestamp{},
TimestampTzArrayOid: &pgtype.TimestamptzArray{},
TimestampTzOid: &pgtype.Timestamptz{},
VarcharArrayOid: &pgtype.VarcharArray{},
VarcharOid: &pgtype.Text{},
XidOid: &pgtype.Xid{},
}
}
func (c *Conn) loadDynamicOidPgtypeValues() {
nameOids := make(map[string]Oid, len(c.PgTypes))
for k, v := range c.PgTypes {
nameOids[v.Name] = k
}
if oid, ok := nameOids["hstore"]; ok {
c.oidPgtypeValues[oid] = &pgtype.Hstore{}
}
}
// PID returns the backend PID for this connection. // PID returns the backend PID for this connection.
func (c *Conn) PID() int32 { func (c *Conn) PID() int32 {
return c.pid return c.pid

222
hstore.go
View File

@ -1,222 +0,0 @@
package pgx
import (
"bytes"
"errors"
"fmt"
"unicode"
"unicode/utf8"
)
const (
hsPre = iota
hsKey
hsSep
hsVal
hsNul
hsNext
)
type hstoreParser struct {
str string
pos int
}
func newHSP(in string) *hstoreParser {
return &hstoreParser{
pos: 0,
str: in,
}
}
func (p *hstoreParser) Consume() (r rune, end bool) {
if p.pos >= len(p.str) {
end = true
return
}
r, w := utf8.DecodeRuneInString(p.str[p.pos:])
p.pos += w
return
}
func (p *hstoreParser) Peek() (r rune, end bool) {
if p.pos >= len(p.str) {
end = true
return
}
r, _ = utf8.DecodeRuneInString(p.str[p.pos:])
return
}
func parseHstoreToMap(s string) (m map[string]string, err error) {
keys, values, err := ParseHstore(s)
if err != nil {
return
}
m = make(map[string]string, len(keys))
for i, key := range keys {
if !values[i].Valid {
err = fmt.Errorf("key '%s' has NULL value", key)
m = nil
return
}
m[key] = values[i].String
}
return
}
func parseHstoreToNullHstore(s string) (store map[string]NullString, err error) {
keys, values, err := ParseHstore(s)
if err != nil {
return
}
store = make(map[string]NullString, len(keys))
for i, key := range keys {
store[key] = values[i]
}
return
}
// ParseHstore parses the string representation of an hstore column (the same
// you would get from an ordinary SELECT) into two slices of keys and values. it
// is used internally in the default parsing of hstores, but is exported for use
// in handling custom data structures backed by an hstore column without the
// overhead of creating a map[string]string
func ParseHstore(s string) (k []string, v []NullString, err error) {
if s == "" {
return
}
buf := bytes.Buffer{}
keys := []string{}
values := []NullString{}
p := newHSP(s)
r, end := p.Consume()
state := hsPre
for !end {
switch state {
case hsPre:
if r == '"' {
state = hsKey
} else {
err = errors.New("String does not begin with \"")
}
case hsKey:
switch r {
case '"': //End of the key
if buf.Len() == 0 {
err = errors.New("Empty Key is invalid")
} else {
keys = append(keys, buf.String())
buf = bytes.Buffer{}
state = hsSep
}
case '\\': //Potential escaped character
n, end := p.Consume()
switch {
case end:
err = errors.New("Found EOS in key, expecting character or \"")
case n == '"', n == '\\':
buf.WriteRune(n)
default:
buf.WriteRune(r)
buf.WriteRune(n)
}
default: //Any other character
buf.WriteRune(r)
}
case hsSep:
if r == '=' {
r, end = p.Consume()
switch {
case end:
err = errors.New("Found EOS after '=', expecting '>'")
case r == '>':
r, end = p.Consume()
switch {
case end:
err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'")
case r == '"':
state = hsVal
case r == 'N':
state = hsNul
default:
err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r)
}
default:
err = fmt.Errorf("Invalid character after '=', expecting '>'")
}
} else {
err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r)
}
case hsVal:
switch r {
case '"': //End of the value
values = append(values, NullString{String: buf.String(), Valid: true})
buf = bytes.Buffer{}
state = hsNext
case '\\': //Potential escaped character
n, end := p.Consume()
switch {
case end:
err = errors.New("Found EOS in key, expecting character or \"")
case n == '"', n == '\\':
buf.WriteRune(n)
default:
buf.WriteRune(r)
buf.WriteRune(n)
}
default: //Any other character
buf.WriteRune(r)
}
case hsNul:
nulBuf := make([]rune, 3)
nulBuf[0] = r
for i := 1; i < 3; i++ {
r, end = p.Consume()
if end {
err = errors.New("Found EOS in NULL value")
return
}
nulBuf[i] = r
}
if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' {
values = append(values, NullString{String: "", Valid: false})
state = hsNext
} else {
err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf))
}
case hsNext:
if r == ',' {
r, end = p.Consume()
switch {
case end:
err = errors.New("Found EOS after ',', expcting space")
case (unicode.IsSpace(r)):
r, end = p.Consume()
state = hsKey
default:
err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r)
}
} else {
err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r)
}
}
if err != nil {
return
}
r, end = p.Consume()
}
if state != hsNext {
err = errors.New("Improperly formatted hstore")
return
}
k = keys
v = values
return
}

View File

@ -1,181 +0,0 @@
package pgx_test
import (
"github.com/jackc/pgx"
"testing"
)
func TestHstoreTranscode(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
type test struct {
hstore pgx.Hstore
description string
}
tests := []test{
{pgx.Hstore{}, "empty"},
{pgx.Hstore{"foo": "bar"}, "single key/value"},
{pgx.Hstore{"foo": "bar", "baz": "quz"}, "multiple key/values"},
{pgx.Hstore{"NULL": "bar"}, `string "NULL" key`},
{pgx.Hstore{"foo": "NULL"}, `string "NULL" value`},
}
specialStringTests := []struct {
input string
description string
}{
{`"`, `double quote (")`},
{`'`, `single quote (')`},
{`\`, `backslash (\)`},
{`\\`, `multiple backslashes (\\)`},
{`=>`, `separator (=>)`},
{` `, `space`},
{`\ / / \\ => " ' " '`, `multiple special characters`},
}
for _, sst := range specialStringTests {
tests = append(tests, test{pgx.Hstore{sst.input + "foo": "bar"}, "key with " + sst.description + " at beginning"})
tests = append(tests, test{pgx.Hstore{"foo" + sst.input + "foo": "bar"}, "key with " + sst.description + " in middle"})
tests = append(tests, test{pgx.Hstore{"foo" + sst.input: "bar"}, "key with " + sst.description + " at end"})
tests = append(tests, test{pgx.Hstore{sst.input: "bar"}, "key is " + sst.description})
tests = append(tests, test{pgx.Hstore{"foo": sst.input + "bar"}, "value with " + sst.description + " at beginning"})
tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input + "bar"}, "value with " + sst.description + " in middle"})
tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input}, "value with " + sst.description + " at end"})
tests = append(tests, test{pgx.Hstore{"foo": sst.input}, "value is " + sst.description})
}
for _, tt := range tests {
var result pgx.Hstore
err := conn.QueryRow("select $1::hstore", tt.hstore).Scan(&result)
if err != nil {
t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err)
}
for key, inValue := range tt.hstore {
outValue, ok := result[key]
if ok {
if inValue != outValue {
t.Errorf(`%s: Key %s mismatch - expected %s, received %s`, tt.description, key, inValue, outValue)
}
} else {
t.Errorf(`%s: Missing key %s`, tt.description, key)
}
}
ensureConnValid(t, conn)
}
}
func TestNullHstoreTranscode(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
type test struct {
nullHstore pgx.NullHstore
description string
}
tests := []test{
{pgx.NullHstore{}, "null"},
{pgx.NullHstore{Valid: true}, "empty"},
{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}},
Valid: true},
"single key/value"},
{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}, "baz": {String: "quz", Valid: true}},
Valid: true},
"multiple key/values"},
{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"NULL": {String: "bar", Valid: true}},
Valid: true},
`string "NULL" key`},
{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": {String: "NULL", Valid: true}},
Valid: true},
`string "NULL" value`},
{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": {String: "", Valid: false}},
Valid: true},
`NULL value`},
}
specialStringTests := []struct {
input string
description string
}{
{`"`, `double quote (")`},
{`'`, `single quote (')`},
{`\`, `backslash (\)`},
{`\\`, `multiple backslashes (\\)`},
{`=>`, `separator (=>)`},
{` `, `space`},
{`\ / / \\ => " ' " '`, `multiple special characters`},
}
for _, sst := range specialStringTests {
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{sst.input + "foo": {String: "bar", Valid: true}},
Valid: true},
"key with " + sst.description + " at beginning"})
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": {String: "bar", Valid: true}},
Valid: true},
"key with " + sst.description + " in middle"})
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo" + sst.input: {String: "bar", Valid: true}},
Valid: true},
"key with " + sst.description + " at end"})
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{sst.input: {String: "bar", Valid: true}},
Valid: true},
"key is " + sst.description})
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": {String: sst.input + "bar", Valid: true}},
Valid: true},
"value with " + sst.description + " at beginning"})
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input + "bar", Valid: true}},
Valid: true},
"value with " + sst.description + " in middle"})
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input, Valid: true}},
Valid: true},
"value with " + sst.description + " at end"})
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": {String: sst.input, Valid: true}},
Valid: true},
"value is " + sst.description})
}
for _, tt := range tests {
var result pgx.NullHstore
err := conn.QueryRow("select $1::hstore", tt.nullHstore).Scan(&result)
if err != nil {
t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err)
}
if result.Valid != tt.nullHstore.Valid {
t.Errorf(`%s: Valid mismatch - expected %v, received %v`, tt.description, tt.nullHstore.Valid, result.Valid)
}
for key, inValue := range tt.nullHstore.Hstore {
outValue, ok := result.Hstore[key]
if ok {
if inValue != outValue {
t.Errorf(`%s: Key %s mismatch - expected %v, received %v`, tt.description, key, inValue, outValue)
}
} else {
t.Errorf(`%s: Missing key %s`, tt.description, key)
}
}
ensureConnValid(t, conn)
}
}

438
pgtype/hstore.go Normal file
View File

@ -0,0 +1,438 @@
package pgtype
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"strings"
"unicode"
"unicode/utf8"
"github.com/jackc/pgx/pgio"
)
// Hstore represents an hstore column that can be null or have null values
// associated with its keys.
type Hstore struct {
Map map[string]Text
Status Status
}
func (dst *Hstore) Set(src interface{}) error {
switch value := src.(type) {
case map[string]string:
m := make(map[string]Text, len(value))
for k, v := range value {
m[k] = Text{String: v, Status: Present}
}
*dst = Hstore{Map: m, Status: Present}
default:
return fmt.Errorf("cannot convert %v to Tid", src)
}
return nil
}
func (dst *Hstore) Get() interface{} {
switch dst.Status {
case Present:
return dst.Map
case Null:
return nil
default:
return dst.Status
}
}
func (src *Hstore) AssignTo(dst interface{}) error {
switch v := dst.(type) {
case *map[string]string:
switch src.Status {
case Present:
*v = make(map[string]string, len(src.Map))
for k, val := range src.Map {
if val.Status != Present {
return fmt.Errorf("cannot decode %v into %T", src, dst)
}
(*v)[k] = val.String
}
case Null:
*v = nil
default:
return fmt.Errorf("cannot decode %v into %T", src, dst)
}
default:
return fmt.Errorf("cannot decode %v into %T", src, dst)
}
return nil
}
func (dst *Hstore) DecodeText(src []byte) error {
if src == nil {
*dst = Hstore{Status: Null}
return nil
}
keys, values, err := parseHstore(string(src))
if err != nil {
return err
}
m := make(map[string]Text, len(keys))
for i := range keys {
m[keys[i]] = values[i]
}
*dst = Hstore{Map: m, Status: Present}
return nil
}
func (dst *Hstore) DecodeBinary(src []byte) error {
if src == nil {
*dst = Hstore{Status: Null}
return nil
}
rp := 0
if len(src[rp:]) < 4 {
return fmt.Errorf("hstore incomplete %v", src)
}
pairCount := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
m := make(map[string]Text, pairCount)
for i := 0; i < pairCount; i++ {
if len(src[rp:]) < 4 {
return fmt.Errorf("hstore incomplete %v", src)
}
keyLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
if len(src[rp:]) < keyLen {
return fmt.Errorf("hstore incomplete %v", src)
}
key := string(src[rp : rp+keyLen])
rp += keyLen
if len(src[rp:]) < 4 {
return fmt.Errorf("hstore incomplete %v", src)
}
valueLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
var valueBuf []byte
if valueLen >= 0 {
valueBuf = src[rp : rp+valueLen]
}
rp += valueLen
var value Text
err := value.DecodeBinary(valueBuf)
if err != nil {
return err
}
m[key] = value
}
*dst = Hstore{Map: m, Status: Present}
return nil
}
func (src Hstore) EncodeText(w io.Writer) (bool, error) {
switch src.Status {
case Null:
return true, nil
case Undefined:
return false, errUndefined
}
firstPair := true
for k, v := range src.Map {
if firstPair {
firstPair = false
} else {
err := pgio.WriteByte(w, ',')
if err != nil {
return false, err
}
}
_, err := io.WriteString(w, quoteHstoreElementIfNeeded(k))
if err != nil {
return false, err
}
_, err = io.WriteString(w, "=>")
if err != nil {
return false, err
}
elemBuf := &bytes.Buffer{}
null, err := v.EncodeText(elemBuf)
if err != nil {
return false, err
}
if null {
_, err = io.WriteString(w, "NULL")
if err != nil {
return false, err
}
} else {
_, err := io.WriteString(w, quoteHstoreElementIfNeeded(elemBuf.String()))
if err != nil {
return false, err
}
}
}
return false, nil
}
func (src Hstore) EncodeBinary(w io.Writer) (bool, error) {
switch src.Status {
case Null:
return true, nil
case Undefined:
return false, errUndefined
}
_, err := pgio.WriteInt32(w, int32(len(src.Map)))
if err != nil {
return false, err
}
elemBuf := &bytes.Buffer{}
for k, v := range src.Map {
_, err := pgio.WriteInt32(w, int32(len(k)))
if err != nil {
return false, err
}
_, err = io.WriteString(w, k)
if err != nil {
return false, err
}
null, err := v.EncodeText(elemBuf)
if err != nil {
return false, err
}
if null {
_, err := pgio.WriteInt32(w, -1)
if err != nil {
return false, err
}
} else {
_, err := pgio.WriteInt32(w, int32(elemBuf.Len()))
if err != nil {
return false, err
}
_, err = elemBuf.WriteTo(w)
if err != nil {
return false, err
}
}
}
return false, err
}
var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
func quoteHstoreElement(src string) string {
return `"` + quoteArrayReplacer.Replace(src) + `"`
}
func quoteHstoreElementIfNeeded(src string) string {
if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) {
return quoteArrayElement(src)
}
return src
}
const (
hsPre = iota
hsKey
hsSep
hsVal
hsNul
hsNext
)
type hstoreParser struct {
str string
pos int
}
func newHSP(in string) *hstoreParser {
return &hstoreParser{
pos: 0,
str: in,
}
}
func (p *hstoreParser) Consume() (r rune, end bool) {
if p.pos >= len(p.str) {
end = true
return
}
r, w := utf8.DecodeRuneInString(p.str[p.pos:])
p.pos += w
return
}
func (p *hstoreParser) Peek() (r rune, end bool) {
if p.pos >= len(p.str) {
end = true
return
}
r, _ = utf8.DecodeRuneInString(p.str[p.pos:])
return
}
// parseHstore parses the string representation of an hstore column (the same
// you would get from an ordinary SELECT) into two slices of keys and values. it
// is used internally in the default parsing of hstores.
func parseHstore(s string) (k []string, v []Text, err error) {
if s == "" {
return
}
buf := bytes.Buffer{}
keys := []string{}
values := []Text{}
p := newHSP(s)
r, end := p.Consume()
state := hsPre
for !end {
switch state {
case hsPre:
if r == '"' {
state = hsKey
} else {
err = errors.New("String does not begin with \"")
}
case hsKey:
switch r {
case '"': //End of the key
if buf.Len() == 0 {
err = errors.New("Empty Key is invalid")
} else {
keys = append(keys, buf.String())
buf = bytes.Buffer{}
state = hsSep
}
case '\\': //Potential escaped character
n, end := p.Consume()
switch {
case end:
err = errors.New("Found EOS in key, expecting character or \"")
case n == '"', n == '\\':
buf.WriteRune(n)
default:
buf.WriteRune(r)
buf.WriteRune(n)
}
default: //Any other character
buf.WriteRune(r)
}
case hsSep:
if r == '=' {
r, end = p.Consume()
switch {
case end:
err = errors.New("Found EOS after '=', expecting '>'")
case r == '>':
r, end = p.Consume()
switch {
case end:
err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'")
case r == '"':
state = hsVal
case r == 'N':
state = hsNul
default:
err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r)
}
default:
err = fmt.Errorf("Invalid character after '=', expecting '>'")
}
} else {
err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r)
}
case hsVal:
switch r {
case '"': //End of the value
values = append(values, Text{String: buf.String(), Status: Present})
buf = bytes.Buffer{}
state = hsNext
case '\\': //Potential escaped character
n, end := p.Consume()
switch {
case end:
err = errors.New("Found EOS in key, expecting character or \"")
case n == '"', n == '\\':
buf.WriteRune(n)
default:
buf.WriteRune(r)
buf.WriteRune(n)
}
default: //Any other character
buf.WriteRune(r)
}
case hsNul:
nulBuf := make([]rune, 3)
nulBuf[0] = r
for i := 1; i < 3; i++ {
r, end = p.Consume()
if end {
err = errors.New("Found EOS in NULL value")
return
}
nulBuf[i] = r
}
if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' {
values = append(values, Text{Status: Null})
state = hsNext
} else {
err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf))
}
case hsNext:
if r == ',' {
r, end = p.Consume()
switch {
case end:
err = errors.New("Found EOS after ',', expcting space")
case (unicode.IsSpace(r)):
r, end = p.Consume()
state = hsKey
default:
err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r)
}
} else {
err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r)
}
}
if err != nil {
return
}
r, end = p.Consume()
}
if state != hsNext {
err = errors.New("Improperly formatted hstore")
return
}
k = keys
v = values
return
}

108
pgtype/hstore_test.go Normal file
View File

@ -0,0 +1,108 @@
package pgtype_test
import (
"reflect"
"testing"
"github.com/jackc/pgx/pgtype"
)
func TestHstoreTranscode(t *testing.T) {
text := func(s string) pgtype.Text {
return pgtype.Text{String: s, Status: pgtype.Present}
}
values := []interface{}{
pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present},
pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present},
pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present},
pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present},
pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present},
pgtype.Hstore{Status: pgtype.Null},
}
specialStrings := []string{
`"`,
`'`,
`\`,
`\\`,
`=>`,
` `,
`\ / / \\ => " ' " '`,
}
for _, s := range specialStrings {
// Special key values
values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning
values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle
values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end
values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key
// Special value values
values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning
values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle
values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end
values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key
}
testSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool {
a := ai.(pgtype.Hstore)
b := bi.(pgtype.Hstore)
if len(a.Map) != len(b.Map) || a.Status != b.Status {
return false
}
for k := range a.Map {
if a.Map[k] != b.Map[k] {
return false
}
}
return true
})
}
func TestHstoreSet(t *testing.T) {
successfulTests := []struct {
src map[string]string
result pgtype.Hstore
}{
{src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}},
}
for i, tt := range successfulTests {
var dst pgtype.Hstore
err := dst.Set(tt.src)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if !reflect.DeepEqual(dst, tt.result) {
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst)
}
}
}
func TestHstoreAssignTo(t *testing.T) {
var m map[string]string
simpleTests := []struct {
src pgtype.Hstore
dst *map[string]string
expected map[string]string
}{
{src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}, dst: &m, expected: map[string]string{"foo": "bar"}},
{src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: ((map[string]string)(nil))},
}
for i, tt := range simpleTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if !reflect.DeepEqual(*tt.dst, tt.expected) {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst)
}
}
}

135
values.go
View File

@ -10,7 +10,6 @@ import (
"math" "math"
"reflect" "reflect"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgio"
@ -577,140 +576,6 @@ func (n NullTime) Encode(w *WriteBuf, oid Oid) error {
return encodeTime(w, oid, n.Time) return encodeTime(w, oid, n.Time)
} }
// Hstore represents an hstore column. It does not support a null column or null
// key values (use NullHstore for this). Hstore implements the Scanner and
// Encoder interfaces so it may be used both as an argument to Query[Row] and a
// destination for Scan.
type Hstore map[string]string
func (h *Hstore) Scan(vr *ValueReader) error {
//oid for hstore not standardized, so we check its type name
if vr.Type().DataTypeName != "hstore" {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into Hstore", vr.Type().DataTypeName)))
return nil
}
if vr.Len() == -1 {
vr.Fatal(ProtocolError("Cannot decode null column into Hstore"))
return nil
}
switch vr.Type().FormatCode {
case TextFormatCode:
m, err := parseHstoreToMap(vr.ReadString(vr.Len()))
if err != nil {
vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err)))
return nil
}
hm := Hstore(m)
*h = hm
return nil
case BinaryFormatCode:
vr.Fatal(ProtocolError("Can't decode binary hstore"))
return nil
default:
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return nil
}
}
func (h Hstore) FormatCode() int16 { return TextFormatCode }
func (h Hstore) Encode(w *WriteBuf, oid Oid) error {
var buf bytes.Buffer
i := 0
for k, v := range h {
i++
ks := strings.Replace(k, `\`, `\\`, -1)
ks = strings.Replace(ks, `"`, `\"`, -1)
vs := strings.Replace(v, `\`, `\\`, -1)
vs = strings.Replace(vs, `"`, `\"`, -1)
buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs))
if i < len(h) {
buf.WriteString(", ")
}
}
w.WriteInt32(int32(buf.Len()))
w.WriteBytes(buf.Bytes())
return nil
}
// NullHstore represents an hstore column that can be null or have null values
// associated with its keys. NullHstore implements the Scanner and Encoder
// interfaces so it may be used both as an argument to Query[Row] and a
// destination for Scan.
//
// If Valid is false, then the value of the entire hstore column is NULL
// If any of the NullString values in Store has Valid set to false, the key
// appears in the hstore column, but its value is explicitly set to NULL.
type NullHstore struct {
Hstore map[string]NullString
Valid bool
}
func (h *NullHstore) Scan(vr *ValueReader) error {
//oid for hstore not standardized, so we check its type name
if vr.Type().DataTypeName != "hstore" {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into NullHstore", vr.Type().DataTypeName)))
return nil
}
if vr.Len() == -1 {
h.Valid = false
return nil
}
switch vr.Type().FormatCode {
case TextFormatCode:
store, err := parseHstoreToNullHstore(vr.ReadString(vr.Len()))
if err != nil {
vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err)))
return nil
}
h.Valid = true
h.Hstore = store
return nil
case BinaryFormatCode:
vr.Fatal(ProtocolError("Can't decode binary hstore"))
return nil
default:
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return nil
}
}
func (h NullHstore) FormatCode() int16 { return TextFormatCode }
func (h NullHstore) Encode(w *WriteBuf, oid Oid) error {
var buf bytes.Buffer
if !h.Valid {
w.WriteInt32(-1)
return nil
}
i := 0
for k, v := range h.Hstore {
i++
ks := strings.Replace(k, `\`, `\\`, -1)
ks = strings.Replace(ks, `"`, `\"`, -1)
if v.Valid {
vs := strings.Replace(v.String, `\`, `\\`, -1)
vs = strings.Replace(vs, `"`, `\"`, -1)
buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs))
} else {
buf.WriteString(fmt.Sprintf(`"%s"=>NULL`, ks))
}
if i < len(h.Hstore) {
buf.WriteString(", ")
}
}
w.WriteInt32(int32(buf.Len()))
w.WriteBytes(buf.Bytes())
return nil
}
// Encode encodes arg into wbuf as the type oid. This allows implementations // Encode encodes arg into wbuf as the type oid. This allows implementations
// of the Encoder interface to delegate the actual work of encoding to the // of the Encoder interface to delegate the actual work of encoding to the
// built-in functionality. // built-in functionality.