diff --git a/conn.go b/conn.go index 4085722c..d541e942 100644 --- a/conn.go +++ b/conn.go @@ -267,47 +267,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.doneChan = make(chan struct{}) c.closedChan = make(chan error) - c.oidPgtypeValues = map[Oid]pgtype.Value{ - AclitemArrayOid: &pgtype.AclitemArray{}, - AclitemOid: &pgtype.Aclitem{}, - BoolArrayOid: &pgtype.BoolArray{}, - BoolOid: &pgtype.Bool{}, - ByteaArrayOid: &pgtype.ByteaArray{}, - ByteaOid: &pgtype.Bytea{}, - CharOid: &pgtype.QChar{}, - CidOid: &pgtype.Cid{}, - CidrArrayOid: &pgtype.CidrArray{}, - CidrOid: &pgtype.Inet{}, - DateArrayOid: &pgtype.DateArray{}, - DateOid: &pgtype.Date{}, - Float4ArrayOid: &pgtype.Float4Array{}, - Float4Oid: &pgtype.Float4{}, - Float8ArrayOid: &pgtype.Float8Array{}, - Float8Oid: &pgtype.Float8{}, - InetArrayOid: &pgtype.InetArray{}, - InetOid: &pgtype.Inet{}, - Int2ArrayOid: &pgtype.Int2Array{}, - Int2Oid: &pgtype.Int2{}, - Int4ArrayOid: &pgtype.Int4Array{}, - Int4Oid: &pgtype.Int4{}, - Int8ArrayOid: &pgtype.Int8Array{}, - Int8Oid: &pgtype.Int8{}, - JsonbOid: &pgtype.Jsonb{}, - JsonOid: &pgtype.Json{}, - NameOid: &pgtype.Name{}, - OidOid: &pgtype.Oid{}, - TextArrayOid: &pgtype.TextArray{}, - TextOid: &pgtype.Text{}, - TidOid: &pgtype.Tid{}, - TimestampArrayOid: &pgtype.TimestampArray{}, - TimestampOid: &pgtype.Timestamp{}, - TimestampTzArrayOid: &pgtype.TimestamptzArray{}, - TimestampTzOid: &pgtype.Timestamptz{}, - VarcharArrayOid: &pgtype.VarcharArray{}, - VarcharOid: &pgtype.Text{}, - XidOid: &pgtype.Xid{}, - } - if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { c.log(LogLevelDebug, "Starting TLS handshake") @@ -317,6 +276,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } + c.loadStaticOidPgtypeValues() + c.mr.cr = chunkreader.NewChunkReader(c.conn) msg := newStartupMessage() @@ -376,6 +337,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl return err } } + c.loadDynamicOidPgtypeValues() return nil default: @@ -416,6 +378,60 @@ where ( return rows.Err() } +func (c *Conn) loadStaticOidPgtypeValues() { + c.oidPgtypeValues = map[Oid]pgtype.Value{ + AclitemArrayOid: &pgtype.AclitemArray{}, + AclitemOid: &pgtype.Aclitem{}, + BoolArrayOid: &pgtype.BoolArray{}, + BoolOid: &pgtype.Bool{}, + ByteaArrayOid: &pgtype.ByteaArray{}, + ByteaOid: &pgtype.Bytea{}, + CharOid: &pgtype.QChar{}, + CidOid: &pgtype.Cid{}, + CidrArrayOid: &pgtype.CidrArray{}, + CidrOid: &pgtype.Inet{}, + DateArrayOid: &pgtype.DateArray{}, + DateOid: &pgtype.Date{}, + Float4ArrayOid: &pgtype.Float4Array{}, + Float4Oid: &pgtype.Float4{}, + Float8ArrayOid: &pgtype.Float8Array{}, + Float8Oid: &pgtype.Float8{}, + InetArrayOid: &pgtype.InetArray{}, + InetOid: &pgtype.Inet{}, + Int2ArrayOid: &pgtype.Int2Array{}, + Int2Oid: &pgtype.Int2{}, + Int4ArrayOid: &pgtype.Int4Array{}, + Int4Oid: &pgtype.Int4{}, + Int8ArrayOid: &pgtype.Int8Array{}, + Int8Oid: &pgtype.Int8{}, + JsonbOid: &pgtype.Jsonb{}, + JsonOid: &pgtype.Json{}, + NameOid: &pgtype.Name{}, + OidOid: &pgtype.Oid{}, + TextArrayOid: &pgtype.TextArray{}, + TextOid: &pgtype.Text{}, + TidOid: &pgtype.Tid{}, + TimestampArrayOid: &pgtype.TimestampArray{}, + TimestampOid: &pgtype.Timestamp{}, + TimestampTzArrayOid: &pgtype.TimestamptzArray{}, + TimestampTzOid: &pgtype.Timestamptz{}, + VarcharArrayOid: &pgtype.VarcharArray{}, + VarcharOid: &pgtype.Text{}, + XidOid: &pgtype.Xid{}, + } +} + +func (c *Conn) loadDynamicOidPgtypeValues() { + nameOids := make(map[string]Oid, len(c.PgTypes)) + for k, v := range c.PgTypes { + nameOids[v.Name] = k + } + + if oid, ok := nameOids["hstore"]; ok { + c.oidPgtypeValues[oid] = &pgtype.Hstore{} + } +} + // PID returns the backend PID for this connection. func (c *Conn) PID() int32 { return c.pid diff --git a/hstore.go b/hstore.go deleted file mode 100644 index 0ab9f779..00000000 --- a/hstore.go +++ /dev/null @@ -1,222 +0,0 @@ -package pgx - -import ( - "bytes" - "errors" - "fmt" - "unicode" - "unicode/utf8" -) - -const ( - hsPre = iota - hsKey - hsSep - hsVal - hsNul - hsNext -) - -type hstoreParser struct { - str string - pos int -} - -func newHSP(in string) *hstoreParser { - return &hstoreParser{ - pos: 0, - str: in, - } -} - -func (p *hstoreParser) Consume() (r rune, end bool) { - if p.pos >= len(p.str) { - end = true - return - } - r, w := utf8.DecodeRuneInString(p.str[p.pos:]) - p.pos += w - return -} - -func (p *hstoreParser) Peek() (r rune, end bool) { - if p.pos >= len(p.str) { - end = true - return - } - r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) - return -} - -func parseHstoreToMap(s string) (m map[string]string, err error) { - keys, values, err := ParseHstore(s) - if err != nil { - return - } - m = make(map[string]string, len(keys)) - for i, key := range keys { - if !values[i].Valid { - err = fmt.Errorf("key '%s' has NULL value", key) - m = nil - return - } - m[key] = values[i].String - } - return -} - -func parseHstoreToNullHstore(s string) (store map[string]NullString, err error) { - keys, values, err := ParseHstore(s) - if err != nil { - return - } - - store = make(map[string]NullString, len(keys)) - - for i, key := range keys { - store[key] = values[i] - } - return -} - -// ParseHstore parses the string representation of an hstore column (the same -// you would get from an ordinary SELECT) into two slices of keys and values. it -// is used internally in the default parsing of hstores, but is exported for use -// in handling custom data structures backed by an hstore column without the -// overhead of creating a map[string]string -func ParseHstore(s string) (k []string, v []NullString, err error) { - if s == "" { - return - } - - buf := bytes.Buffer{} - keys := []string{} - values := []NullString{} - p := newHSP(s) - - r, end := p.Consume() - state := hsPre - - for !end { - switch state { - case hsPre: - if r == '"' { - state = hsKey - } else { - err = errors.New("String does not begin with \"") - } - case hsKey: - switch r { - case '"': //End of the key - if buf.Len() == 0 { - err = errors.New("Empty Key is invalid") - } else { - keys = append(keys, buf.String()) - buf = bytes.Buffer{} - state = hsSep - } - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsSep: - if r == '=' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=', expecting '>'") - case r == '>': - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") - case r == '"': - state = hsVal - case r == 'N': - state = hsNul - default: - err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) - } - default: - err = fmt.Errorf("Invalid character after '=', expecting '>'") - } - } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) - } - case hsVal: - switch r { - case '"': //End of the value - values = append(values, NullString{String: buf.String(), Valid: true}) - buf = bytes.Buffer{} - state = hsNext - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsNul: - nulBuf := make([]rune, 3) - nulBuf[0] = r - for i := 1; i < 3; i++ { - r, end = p.Consume() - if end { - err = errors.New("Found EOS in NULL value") - return - } - nulBuf[i] = r - } - if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { - values = append(values, NullString{String: "", Valid: false}) - state = hsNext - } else { - err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) - } - case hsNext: - if r == ',' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after ',', expcting space") - case (unicode.IsSpace(r)): - r, end = p.Consume() - state = hsKey - default: - err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) - } - } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) - } - } - - if err != nil { - return - } - r, end = p.Consume() - } - if state != hsNext { - err = errors.New("Improperly formatted hstore") - return - } - k = keys - v = values - return -} diff --git a/hstore_test.go b/hstore_test.go deleted file mode 100644 index c948f0cd..00000000 --- a/hstore_test.go +++ /dev/null @@ -1,181 +0,0 @@ -package pgx_test - -import ( - "github.com/jackc/pgx" - "testing" -) - -func TestHstoreTranscode(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type test struct { - hstore pgx.Hstore - description string - } - - tests := []test{ - {pgx.Hstore{}, "empty"}, - {pgx.Hstore{"foo": "bar"}, "single key/value"}, - {pgx.Hstore{"foo": "bar", "baz": "quz"}, "multiple key/values"}, - {pgx.Hstore{"NULL": "bar"}, `string "NULL" key`}, - {pgx.Hstore{"foo": "NULL"}, `string "NULL" value`}, - } - - specialStringTests := []struct { - input string - description string - }{ - {`"`, `double quote (")`}, - {`'`, `single quote (')`}, - {`\`, `backslash (\)`}, - {`\\`, `multiple backslashes (\\)`}, - {`=>`, `separator (=>)`}, - {` `, `space`}, - {`\ / / \\ => " ' " '`, `multiple special characters`}, - } - for _, sst := range specialStringTests { - tests = append(tests, test{pgx.Hstore{sst.input + "foo": "bar"}, "key with " + sst.description + " at beginning"}) - tests = append(tests, test{pgx.Hstore{"foo" + sst.input + "foo": "bar"}, "key with " + sst.description + " in middle"}) - tests = append(tests, test{pgx.Hstore{"foo" + sst.input: "bar"}, "key with " + sst.description + " at end"}) - tests = append(tests, test{pgx.Hstore{sst.input: "bar"}, "key is " + sst.description}) - - tests = append(tests, test{pgx.Hstore{"foo": sst.input + "bar"}, "value with " + sst.description + " at beginning"}) - tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input + "bar"}, "value with " + sst.description + " in middle"}) - tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input}, "value with " + sst.description + " at end"}) - tests = append(tests, test{pgx.Hstore{"foo": sst.input}, "value is " + sst.description}) - } - - for _, tt := range tests { - var result pgx.Hstore - err := conn.QueryRow("select $1::hstore", tt.hstore).Scan(&result) - if err != nil { - t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err) - } - - for key, inValue := range tt.hstore { - outValue, ok := result[key] - if ok { - if inValue != outValue { - t.Errorf(`%s: Key %s mismatch - expected %s, received %s`, tt.description, key, inValue, outValue) - } - } else { - t.Errorf(`%s: Missing key %s`, tt.description, key) - } - } - - ensureConnValid(t, conn) - } -} - -func TestNullHstoreTranscode(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type test struct { - nullHstore pgx.NullHstore - description string - } - - tests := []test{ - {pgx.NullHstore{}, "null"}, - {pgx.NullHstore{Valid: true}, "empty"}, - {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}}, - Valid: true}, - "single key/value"}, - {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}, "baz": {String: "quz", Valid: true}}, - Valid: true}, - "multiple key/values"}, - {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"NULL": {String: "bar", Valid: true}}, - Valid: true}, - `string "NULL" key`}, - {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "NULL", Valid: true}}, - Valid: true}, - `string "NULL" value`}, - {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "", Valid: false}}, - Valid: true}, - `NULL value`}, - } - - specialStringTests := []struct { - input string - description string - }{ - {`"`, `double quote (")`}, - {`'`, `single quote (')`}, - {`\`, `backslash (\)`}, - {`\\`, `multiple backslashes (\\)`}, - {`=>`, `separator (=>)`}, - {` `, `space`}, - {`\ / / \\ => " ' " '`, `multiple special characters`}, - } - for _, sst := range specialStringTests { - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{sst.input + "foo": {String: "bar", Valid: true}}, - Valid: true}, - "key with " + sst.description + " at beginning"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": {String: "bar", Valid: true}}, - Valid: true}, - "key with " + sst.description + " in middle"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo" + sst.input: {String: "bar", Valid: true}}, - Valid: true}, - "key with " + sst.description + " at end"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{sst.input: {String: "bar", Valid: true}}, - Valid: true}, - "key is " + sst.description}) - - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: sst.input + "bar", Valid: true}}, - Valid: true}, - "value with " + sst.description + " at beginning"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input + "bar", Valid: true}}, - Valid: true}, - "value with " + sst.description + " in middle"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input, Valid: true}}, - Valid: true}, - "value with " + sst.description + " at end"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: sst.input, Valid: true}}, - Valid: true}, - "value is " + sst.description}) - } - - for _, tt := range tests { - var result pgx.NullHstore - err := conn.QueryRow("select $1::hstore", tt.nullHstore).Scan(&result) - if err != nil { - t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err) - } - - if result.Valid != tt.nullHstore.Valid { - t.Errorf(`%s: Valid mismatch - expected %v, received %v`, tt.description, tt.nullHstore.Valid, result.Valid) - } - - for key, inValue := range tt.nullHstore.Hstore { - outValue, ok := result.Hstore[key] - if ok { - if inValue != outValue { - t.Errorf(`%s: Key %s mismatch - expected %v, received %v`, tt.description, key, inValue, outValue) - } - } else { - t.Errorf(`%s: Missing key %s`, tt.description, key) - } - } - - ensureConnValid(t, conn) - } -} diff --git a/pgtype/hstore.go b/pgtype/hstore.go new file mode 100644 index 00000000..11bfb9a7 --- /dev/null +++ b/pgtype/hstore.go @@ -0,0 +1,438 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "strings" + "unicode" + "unicode/utf8" + + "github.com/jackc/pgx/pgio" +) + +// Hstore represents an hstore column that can be null or have null values +// associated with its keys. +type Hstore struct { + Map map[string]Text + Status Status +} + +func (dst *Hstore) Set(src interface{}) error { + switch value := src.(type) { + case map[string]string: + m := make(map[string]Text, len(value)) + for k, v := range value { + m[k] = Text{String: v, Status: Present} + } + *dst = Hstore{Map: m, Status: Present} + default: + return fmt.Errorf("cannot convert %v to Tid", src) + } + + return nil +} + +func (dst *Hstore) Get() interface{} { + switch dst.Status { + case Present: + return dst.Map + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Hstore) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *map[string]string: + switch src.Status { + case Present: + *v = make(map[string]string, len(src.Map)) + for k, val := range src.Map { + if val.Status != Present { + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + (*v)[k] = val.String + } + case Null: + *v = nil + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Hstore) DecodeText(src []byte) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + keys, values, err := parseHstore(string(src)) + if err != nil { + return err + } + + m := make(map[string]Text, len(keys)) + for i := range keys { + m[keys[i]] = values[i] + } + + *dst = Hstore{Map: m, Status: Present} + return nil +} + +func (dst *Hstore) DecodeBinary(src []byte) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + rp := 0 + + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + m := make(map[string]Text, pairCount) + + for i := 0; i < pairCount; i++ { + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + if len(src[rp:]) < keyLen { + return fmt.Errorf("hstore incomplete %v", src) + } + key := string(src[rp : rp+keyLen]) + rp += keyLen + + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + var valueBuf []byte + if valueLen >= 0 { + valueBuf = src[rp : rp+valueLen] + } + rp += valueLen + + var value Text + err := value.DecodeBinary(valueBuf) + if err != nil { + return err + } + m[key] = value + } + + *dst = Hstore{Map: m, Status: Present} + + return nil +} + +func (src Hstore) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + firstPair := true + + for k, v := range src.Map { + if firstPair { + firstPair = false + } else { + err := pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + _, err := io.WriteString(w, quoteHstoreElementIfNeeded(k)) + if err != nil { + return false, err + } + + _, err = io.WriteString(w, "=>") + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + null, err := v.EncodeText(elemBuf) + if err != nil { + return false, err + } + + if null { + _, err = io.WriteString(w, "NULL") + if err != nil { + return false, err + } + } else { + _, err := io.WriteString(w, quoteHstoreElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + } + + return false, nil +} + +func (src Hstore) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := pgio.WriteInt32(w, int32(len(src.Map))) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + for k, v := range src.Map { + _, err := pgio.WriteInt32(w, int32(len(k))) + if err != nil { + return false, err + } + _, err = io.WriteString(w, k) + if err != nil { + return false, err + } + + null, err := v.EncodeText(elemBuf) + if err != nil { + return false, err + } + if null { + _, err := pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err := pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err +} + +var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteHstoreElement(src string) string { + return `"` + quoteArrayReplacer.Replace(src) + `"` +} + +func quoteHstoreElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) { + return quoteArrayElement(src) + } + return src +} + +const ( + hsPre = iota + hsKey + hsSep + hsVal + hsNul + hsNext +) + +type hstoreParser struct { + str string + pos int +} + +func newHSP(in string) *hstoreParser { + return &hstoreParser{ + pos: 0, + str: in, + } +} + +func (p *hstoreParser) Consume() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, w := utf8.DecodeRuneInString(p.str[p.pos:]) + p.pos += w + return +} + +func (p *hstoreParser) Peek() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) + return +} + +// parseHstore parses the string representation of an hstore column (the same +// you would get from an ordinary SELECT) into two slices of keys and values. it +// is used internally in the default parsing of hstores. +func parseHstore(s string) (k []string, v []Text, err error) { + if s == "" { + return + } + + buf := bytes.Buffer{} + keys := []string{} + values := []Text{} + p := newHSP(s) + + r, end := p.Consume() + state := hsPre + + for !end { + switch state { + case hsPre: + if r == '"' { + state = hsKey + } else { + err = errors.New("String does not begin with \"") + } + case hsKey: + switch r { + case '"': //End of the key + if buf.Len() == 0 { + err = errors.New("Empty Key is invalid") + } else { + keys = append(keys, buf.String()) + buf = bytes.Buffer{} + state = hsSep + } + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsSep: + if r == '=' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=', expecting '>'") + case r == '>': + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") + case r == '"': + state = hsVal + case r == 'N': + state = hsNul + default: + err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) + } + default: + err = fmt.Errorf("Invalid character after '=', expecting '>'") + } + } else { + err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) + } + case hsVal: + switch r { + case '"': //End of the value + values = append(values, Text{String: buf.String(), Status: Present}) + buf = bytes.Buffer{} + state = hsNext + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsNul: + nulBuf := make([]rune, 3) + nulBuf[0] = r + for i := 1; i < 3; i++ { + r, end = p.Consume() + if end { + err = errors.New("Found EOS in NULL value") + return + } + nulBuf[i] = r + } + if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { + values = append(values, Text{Status: Null}) + state = hsNext + } else { + err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) + } + case hsNext: + if r == ',' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after ',', expcting space") + case (unicode.IsSpace(r)): + r, end = p.Consume() + state = hsKey + default: + err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) + } + } else { + err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) + } + } + + if err != nil { + return + } + r, end = p.Consume() + } + if state != hsNext { + err = errors.New("Improperly formatted hstore") + return + } + k = keys + v = values + return +} diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go new file mode 100644 index 00000000..fbe8dee5 --- /dev/null +++ b/pgtype/hstore_test.go @@ -0,0 +1,108 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestHstoreTranscode(t *testing.T) { + text := func(s string) pgtype.Text { + return pgtype.Text{String: s, Status: pgtype.Present} + } + + values := []interface{}{ + pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + pgtype.Hstore{Status: pgtype.Null}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + + // Special value values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + } + + testSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { + a := ai.(pgtype.Hstore) + b := bi.(pgtype.Hstore) + + if len(a.Map) != len(b.Map) || a.Status != b.Status { + return false + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + return false + } + } + + return true + }) +} + +func TestHstoreSet(t *testing.T) { + successfulTests := []struct { + src map[string]string + result pgtype.Hstore + }{ + {src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var dst pgtype.Hstore + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + +func TestHstoreAssignTo(t *testing.T) { + var m map[string]string + + simpleTests := []struct { + src pgtype.Hstore + dst *map[string]string + expected map[string]string + }{ + {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}, dst: &m, expected: map[string]string{"foo": "bar"}}, + {src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: ((map[string]string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(*tt.dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/values.go b/values.go index e976d0d3..e1c8f731 100644 --- a/values.go +++ b/values.go @@ -10,7 +10,6 @@ import ( "math" "reflect" "strconv" - "strings" "time" "github.com/jackc/pgx/pgio" @@ -577,140 +576,6 @@ func (n NullTime) Encode(w *WriteBuf, oid Oid) error { return encodeTime(w, oid, n.Time) } -// Hstore represents an hstore column. It does not support a null column or null -// key values (use NullHstore for this). Hstore implements the Scanner and -// Encoder interfaces so it may be used both as an argument to Query[Row] and a -// destination for Scan. -type Hstore map[string]string - -func (h *Hstore) Scan(vr *ValueReader) error { - //oid for hstore not standardized, so we check its type name - if vr.Type().DataTypeName != "hstore" { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into Hstore", vr.Type().DataTypeName))) - return nil - } - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null column into Hstore")) - return nil - } - - switch vr.Type().FormatCode { - case TextFormatCode: - m, err := parseHstoreToMap(vr.ReadString(vr.Len())) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err))) - return nil - } - hm := Hstore(m) - *h = hm - return nil - case BinaryFormatCode: - vr.Fatal(ProtocolError("Can't decode binary hstore")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } -} - -func (h Hstore) FormatCode() int16 { return TextFormatCode } - -func (h Hstore) Encode(w *WriteBuf, oid Oid) error { - var buf bytes.Buffer - - i := 0 - for k, v := range h { - i++ - ks := strings.Replace(k, `\`, `\\`, -1) - ks = strings.Replace(ks, `"`, `\"`, -1) - vs := strings.Replace(v, `\`, `\\`, -1) - vs = strings.Replace(vs, `"`, `\"`, -1) - buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs)) - if i < len(h) { - buf.WriteString(", ") - } - } - w.WriteInt32(int32(buf.Len())) - w.WriteBytes(buf.Bytes()) - return nil -} - -// NullHstore represents an hstore column that can be null or have null values -// associated with its keys. NullHstore implements the Scanner and Encoder -// interfaces so it may be used both as an argument to Query[Row] and a -// destination for Scan. -// -// If Valid is false, then the value of the entire hstore column is NULL -// If any of the NullString values in Store has Valid set to false, the key -// appears in the hstore column, but its value is explicitly set to NULL. -type NullHstore struct { - Hstore map[string]NullString - Valid bool -} - -func (h *NullHstore) Scan(vr *ValueReader) error { - //oid for hstore not standardized, so we check its type name - if vr.Type().DataTypeName != "hstore" { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into NullHstore", vr.Type().DataTypeName))) - return nil - } - - if vr.Len() == -1 { - h.Valid = false - return nil - } - - switch vr.Type().FormatCode { - case TextFormatCode: - store, err := parseHstoreToNullHstore(vr.ReadString(vr.Len())) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err))) - return nil - } - h.Valid = true - h.Hstore = store - return nil - case BinaryFormatCode: - vr.Fatal(ProtocolError("Can't decode binary hstore")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } -} - -func (h NullHstore) FormatCode() int16 { return TextFormatCode } - -func (h NullHstore) Encode(w *WriteBuf, oid Oid) error { - var buf bytes.Buffer - - if !h.Valid { - w.WriteInt32(-1) - return nil - } - - i := 0 - for k, v := range h.Hstore { - i++ - ks := strings.Replace(k, `\`, `\\`, -1) - ks = strings.Replace(ks, `"`, `\"`, -1) - if v.Valid { - vs := strings.Replace(v.String, `\`, `\\`, -1) - vs = strings.Replace(vs, `"`, `\"`, -1) - buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs)) - } else { - buf.WriteString(fmt.Sprintf(`"%s"=>NULL`, ks)) - } - if i < len(h.Hstore) { - buf.WriteString(", ") - } - } - w.WriteInt32(int32(buf.Len())) - w.WriteBytes(buf.Bytes()) - return nil -} - // Encode encodes arg into wbuf as the type oid. This allows implementations // of the Encoder interface to delegate the actual work of encoding to the // built-in functionality.