mirror of
https://github.com/jackc/pgx.git
synced 2025-04-27 13:14:32 +00:00
Before this change, the Hstore text-format encoder did not quote keys or values containing non-space whitespace ("\r", "\n", "\v", "\t"). This caused inserts of such values to fail with errors like: ERROR: Syntax error near "r" at position 17 (SQLSTATE XX000). The previous version also quoted curly braces ("{}"), but they do not appear to require quoting. It might be simpler to always quote keys and values, which is what Postgres does when producing its text output, but this is the smaller change.
471 lines
9.8 KiB
Go
471 lines
9.8 KiB
Go
package pgtype
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql/driver"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"unicode"
|
|
"unicode/utf8"
|
|
|
|
"github.com/jackc/pgx/v5/internal/pgio"
|
|
)
|
|
|
|
// HstoreScanner is implemented by types that can receive a decoded hstore.
// The scan plans in this file pass a nil Hstore when the source is SQL NULL.
type HstoreScanner interface {
	ScanHstore(v Hstore) error
}

// HstoreValuer is implemented by types that can supply an hstore for encoding.
// The encode plans in this file treat a returned nil Hstore as SQL NULL.
type HstoreValuer interface {
	HstoreValue() (Hstore, error)
}

// Hstore represents an hstore column that can be null or have null values
// associated with its keys. A nil map is SQL NULL; a nil *string stored under
// a key is a NULL value for that key.
type Hstore map[string]*string

// ScanHstore implements HstoreScanner by storing v into *h. The map is kept
// as-is; it is not copied.
func (h *Hstore) ScanHstore(v Hstore) error {
	*h = v
	return nil
}

// HstoreValue implements HstoreValuer by returning h itself (not a copy).
func (h Hstore) HstoreValue() (Hstore, error) {
	return h, nil
}
|
|
|
|
// Scan implements the database/sql Scanner interface. src holds the hstore
// text representation, as either a string or a []byte. A nil src (including a
// nil []byte) sets *h to nil. Any other source type is an error.
func (h *Hstore) Scan(src any) error {
	if src == nil {
		*h = nil
		return nil
	}

	switch src := src.(type) {
	case string:
		return scanPlanTextAnyToHstoreScanner{}.Scan([]byte(src), h)
	case []byte:
		// database/sql drivers commonly deliver text columns as []byte. The text
		// scan plan converts src to a string before parsing, so src is not
		// retained after Scan returns.
		return scanPlanTextAnyToHstoreScanner{}.Scan(src, h)
	}

	return fmt.Errorf("cannot scan %T", src)
}
|
|
|
|
// Value implements the database/sql/driver Valuer interface. A nil Hstore is
// reported as SQL NULL. Any other Hstore is encoded with HstoreCodec's text
// encode plan, and the result is returned as a string.
func (h Hstore) Value() (driver.Value, error) {
	if h == nil {
		return nil, nil
	}

	plan := HstoreCodec{}.PlanEncode(nil, 0, TextFormatCode, h)
	encoded, err := plan.Encode(h, nil)
	if err != nil {
		return nil, err
	}
	return string(encoded), nil
}
|
|
|
|
type HstoreCodec struct{}
|
|
|
|
func (HstoreCodec) FormatSupported(format int16) bool {
|
|
return format == TextFormatCode || format == BinaryFormatCode
|
|
}
|
|
|
|
func (HstoreCodec) PreferredFormat() int16 {
|
|
return BinaryFormatCode
|
|
}
|
|
|
|
func (HstoreCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
|
if _, ok := value.(HstoreValuer); !ok {
|
|
return nil
|
|
}
|
|
|
|
switch format {
|
|
case BinaryFormatCode:
|
|
return encodePlanHstoreCodecBinary{}
|
|
case TextFormatCode:
|
|
return encodePlanHstoreCodecText{}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type encodePlanHstoreCodecBinary struct{}
|
|
|
|
func (encodePlanHstoreCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
|
hstore, err := value.(HstoreValuer).HstoreValue()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if hstore == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
buf = pgio.AppendInt32(buf, int32(len(hstore)))
|
|
|
|
for k, v := range hstore {
|
|
buf = pgio.AppendInt32(buf, int32(len(k)))
|
|
buf = append(buf, k...)
|
|
|
|
if v == nil {
|
|
buf = pgio.AppendInt32(buf, -1)
|
|
} else {
|
|
buf = pgio.AppendInt32(buf, int32(len(*v)))
|
|
buf = append(buf, (*v)...)
|
|
}
|
|
}
|
|
|
|
return buf, nil
|
|
}
|
|
|
|
type encodePlanHstoreCodecText struct{}
|
|
|
|
func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
|
hstore, err := value.(HstoreValuer).HstoreValue()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if hstore == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
firstPair := true
|
|
|
|
for k, v := range hstore {
|
|
if firstPair {
|
|
firstPair = false
|
|
} else {
|
|
buf = append(buf, ',')
|
|
}
|
|
|
|
buf = append(buf, quoteHstoreElementIfNeeded(k)...)
|
|
buf = append(buf, "=>"...)
|
|
|
|
if v == nil {
|
|
buf = append(buf, "NULL"...)
|
|
} else {
|
|
buf = append(buf, quoteHstoreElementIfNeeded(*v)...)
|
|
}
|
|
}
|
|
|
|
return buf, nil
|
|
}
|
|
|
|
func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
|
|
|
switch format {
|
|
case BinaryFormatCode:
|
|
switch target.(type) {
|
|
case HstoreScanner:
|
|
return scanPlanBinaryHstoreToHstoreScanner{}
|
|
}
|
|
case TextFormatCode:
|
|
switch target.(type) {
|
|
case HstoreScanner:
|
|
return scanPlanTextAnyToHstoreScanner{}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type scanPlanBinaryHstoreToHstoreScanner struct{}
|
|
|
|
func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error {
|
|
scanner := (dst).(HstoreScanner)
|
|
|
|
if src == nil {
|
|
return scanner.ScanHstore(Hstore(nil))
|
|
}
|
|
|
|
rp := 0
|
|
|
|
if len(src[rp:]) < 4 {
|
|
return fmt.Errorf("hstore incomplete %v", src)
|
|
}
|
|
pairCount := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
|
rp += 4
|
|
|
|
hstore := make(Hstore, pairCount)
|
|
|
|
for i := 0; i < pairCount; i++ {
|
|
if len(src[rp:]) < 4 {
|
|
return fmt.Errorf("hstore incomplete %v", src)
|
|
}
|
|
keyLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
|
rp += 4
|
|
|
|
if len(src[rp:]) < keyLen {
|
|
return fmt.Errorf("hstore incomplete %v", src)
|
|
}
|
|
key := string(src[rp : rp+keyLen])
|
|
rp += keyLen
|
|
|
|
if len(src[rp:]) < 4 {
|
|
return fmt.Errorf("hstore incomplete %v", src)
|
|
}
|
|
valueLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
|
rp += 4
|
|
|
|
var valueBuf []byte
|
|
if valueLen >= 0 {
|
|
valueBuf = src[rp : rp+valueLen]
|
|
rp += valueLen
|
|
}
|
|
|
|
var value Text
|
|
err := scanPlanTextAnyToTextScanner{}.Scan(valueBuf, &value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if value.Valid {
|
|
hstore[key] = &value.String
|
|
} else {
|
|
hstore[key] = nil
|
|
}
|
|
}
|
|
|
|
return scanner.ScanHstore(hstore)
|
|
}
|
|
|
|
type scanPlanTextAnyToHstoreScanner struct{}
|
|
|
|
func (scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error {
|
|
scanner := (dst).(HstoreScanner)
|
|
|
|
if src == nil {
|
|
return scanner.ScanHstore(Hstore(nil))
|
|
}
|
|
|
|
keys, values, err := parseHstore(string(src))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
m := make(Hstore, len(keys))
|
|
for i := range keys {
|
|
if values[i].Valid {
|
|
m[keys[i]] = &values[i].String
|
|
} else {
|
|
m[keys[i]] = nil
|
|
}
|
|
}
|
|
|
|
return scanner.ScanHstore(m)
|
|
}
|
|
|
|
// DecodeDatabaseSQLValue decodes src for use as a database/sql driver.Value by
// delegating to codecDecodeToTextFormat (presumably yielding the text-format
// representation; see that helper).
func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
	return codecDecodeToTextFormat(c, m, oid, format, src)
}

// DecodeValue decodes src into an Hstore via codecScan. A nil src (SQL NULL)
// yields an untyped nil rather than a typed nil Hstore.
func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
	if src == nil {
		return nil, nil
	}

	var result Hstore
	if err := codecScan(c, m, oid, format, src, &result); err != nil {
		return nil, err
	}
	return result, nil
}
|
|
|
|
// quoteHstoreElementIfNeeded returns src unchanged when it can appear bare in
// hstore text input. Otherwise it returns quoteArrayElement(src), which is
// assumed to add surrounding double quotes and backslash-escape " and \.
//
// Quoting is required in three cases:
//   - the empty string;
//   - any case-insensitive spelling of NULL, which hstore input would read as a
//     NULL value;
//   - any string containing a comma, double quote, backslash, '=', '>', or
//     whitespace.
//
// From https://www.postgresql.org/docs/current/hstore.html: double-quote keys
// and values that include whitespace, commas, =s or >s, and escape a double
// quote or backslash inside them with a backslash. Whitespace appears to follow
// the C isspace() function: space, \t, \n, \v, \f and \r.
func quoteHstoreElementIfNeeded(src string) string {
	const quoteRequiredChars = `,"\=> ` + "\t\n\v\f\r"

	needsQuotes := src == "" ||
		(len(src) == 4 && strings.EqualFold(src, "null")) ||
		strings.ContainsAny(src, quoteRequiredChars)
	if !needsQuotes {
		return src
	}
	return quoteArrayElement(src)
}
|
|
|
|
const (
|
|
hsPre = iota
|
|
hsKey
|
|
hsSep
|
|
hsVal
|
|
hsNul
|
|
hsNext
|
|
)
|
|
|
|
type hstoreParser struct {
|
|
str string
|
|
pos int
|
|
}
|
|
|
|
func newHSP(in string) *hstoreParser {
|
|
return &hstoreParser{
|
|
pos: 0,
|
|
str: in,
|
|
}
|
|
}
|
|
|
|
func (p *hstoreParser) Consume() (r rune, end bool) {
|
|
if p.pos >= len(p.str) {
|
|
end = true
|
|
return
|
|
}
|
|
r, w := utf8.DecodeRuneInString(p.str[p.pos:])
|
|
p.pos += w
|
|
return
|
|
}
|
|
|
|
func (p *hstoreParser) Peek() (r rune, end bool) {
|
|
if p.pos >= len(p.str) {
|
|
end = true
|
|
return
|
|
}
|
|
r, _ = utf8.DecodeRuneInString(p.str[p.pos:])
|
|
return
|
|
}
|
|
|
|
// parseHstore parses the string representation of an hstore column (the same
|
|
// you would get from an ordinary SELECT) into two slices of keys and values. it
|
|
// is used internally in the default parsing of hstores.
|
|
func parseHstore(s string) (k []string, v []Text, err error) {
|
|
if s == "" {
|
|
return
|
|
}
|
|
|
|
buf := bytes.Buffer{}
|
|
keys := []string{}
|
|
values := []Text{}
|
|
p := newHSP(s)
|
|
|
|
r, end := p.Consume()
|
|
state := hsPre
|
|
|
|
for !end {
|
|
switch state {
|
|
case hsPre:
|
|
if r == '"' {
|
|
state = hsKey
|
|
} else {
|
|
err = errors.New("String does not begin with \"")
|
|
}
|
|
case hsKey:
|
|
switch r {
|
|
case '"': //End of the key
|
|
keys = append(keys, buf.String())
|
|
buf = bytes.Buffer{}
|
|
state = hsSep
|
|
case '\\': //Potential escaped character
|
|
n, end := p.Consume()
|
|
switch {
|
|
case end:
|
|
err = errors.New("Found EOS in key, expecting character or \"")
|
|
case n == '"', n == '\\':
|
|
buf.WriteRune(n)
|
|
default:
|
|
buf.WriteRune(r)
|
|
buf.WriteRune(n)
|
|
}
|
|
default: //Any other character
|
|
buf.WriteRune(r)
|
|
}
|
|
case hsSep:
|
|
if r == '=' {
|
|
r, end = p.Consume()
|
|
switch {
|
|
case end:
|
|
err = errors.New("Found EOS after '=', expecting '>'")
|
|
case r == '>':
|
|
r, end = p.Consume()
|
|
switch {
|
|
case end:
|
|
err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'")
|
|
case r == '"':
|
|
state = hsVal
|
|
case r == 'N':
|
|
state = hsNul
|
|
default:
|
|
err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r)
|
|
}
|
|
default:
|
|
err = fmt.Errorf("Invalid character after '=', expecting '>'")
|
|
}
|
|
} else {
|
|
err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r)
|
|
}
|
|
case hsVal:
|
|
switch r {
|
|
case '"': //End of the value
|
|
values = append(values, Text{String: buf.String(), Valid: true})
|
|
buf = bytes.Buffer{}
|
|
state = hsNext
|
|
case '\\': //Potential escaped character
|
|
n, end := p.Consume()
|
|
switch {
|
|
case end:
|
|
err = errors.New("Found EOS in key, expecting character or \"")
|
|
case n == '"', n == '\\':
|
|
buf.WriteRune(n)
|
|
default:
|
|
buf.WriteRune(r)
|
|
buf.WriteRune(n)
|
|
}
|
|
default: //Any other character
|
|
buf.WriteRune(r)
|
|
}
|
|
case hsNul:
|
|
nulBuf := make([]rune, 3)
|
|
nulBuf[0] = r
|
|
for i := 1; i < 3; i++ {
|
|
r, end = p.Consume()
|
|
if end {
|
|
err = errors.New("Found EOS in NULL value")
|
|
return
|
|
}
|
|
nulBuf[i] = r
|
|
}
|
|
if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' {
|
|
values = append(values, Text{})
|
|
state = hsNext
|
|
} else {
|
|
err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf))
|
|
}
|
|
case hsNext:
|
|
if r == ',' {
|
|
r, end = p.Consume()
|
|
switch {
|
|
case end:
|
|
err = errors.New("Found EOS after ',', expecting space")
|
|
case (unicode.IsSpace(r)):
|
|
// after space is a doublequote to start the key
|
|
r, end = p.Consume()
|
|
if end {
|
|
err = errors.New("Found EOS after space, expecting \"")
|
|
return
|
|
}
|
|
if r != '"' {
|
|
err = fmt.Errorf("Invalid character '%c' after space, expecting \"", r)
|
|
return
|
|
}
|
|
state = hsKey
|
|
default:
|
|
err = fmt.Errorf("Invalid character '%c' after ',', expecting space", r)
|
|
}
|
|
} else {
|
|
err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r)
|
|
}
|
|
}
|
|
|
|
if err != nil {
|
|
return
|
|
}
|
|
r, end = p.Consume()
|
|
}
|
|
if state != hsNext {
|
|
err = errors.New("Improperly formatted hstore")
|
|
return
|
|
}
|
|
k = keys
|
|
v = values
|
|
return
|
|
}
|