pgx/pgtype/hstore.go
Evan Jones dc94db6b3d pgtype.Hstore: add a round-trip test for binary and text codecs
This ensures the output of Encode can pass through Scan and produce
the same input. This found two two minor problems with the text
codec. These are not bugs: These situations do not happen when using
pgx with Postgres. However, I think it is worth fixing to ensure the
code is internally consistent.

The problems with the text codec are:

* It did not correctly distinguish between nil and empty. This is not
  a problem with Postgres, since NULL values are marked separately,
  but the binary codec distinguishes between them, so it seems like
  the text codec should as well.
* It did not output spaces between keys. Postgres produces output in
  this format, and the parser now only strictly parses the Postgres
  format. This is not a bug, but seems like a good idea.
2023-06-29 17:25:47 -05:00

487 lines
12 KiB
Go

package pgtype
import (
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"strings"
"github.com/jackc/pgx/v5/internal/pgio"
)
type HstoreScanner interface {
ScanHstore(v Hstore) error
}
type HstoreValuer interface {
HstoreValue() (Hstore, error)
}
// Hstore represents an hstore column that can be null or have null values
// associated with its keys.
type Hstore map[string]*string
func (h *Hstore) ScanHstore(v Hstore) error {
*h = v
return nil
}
func (h Hstore) HstoreValue() (Hstore, error) {
return h, nil
}
// Scan implements the database/sql Scanner interface.
func (h *Hstore) Scan(src any) error {
if src == nil {
*h = nil
return nil
}
switch src := src.(type) {
case string:
return scanPlanTextAnyToHstoreScanner{}.scanString(src, h)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (h Hstore) Value() (driver.Value, error) {
if h == nil {
return nil, nil
}
buf, err := HstoreCodec{}.PlanEncode(nil, 0, TextFormatCode, h).Encode(h, nil)
if err != nil {
return nil, err
}
return string(buf), err
}
type HstoreCodec struct{}
func (HstoreCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}
func (HstoreCodec) PreferredFormat() int16 {
return BinaryFormatCode
}
func (HstoreCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
if _, ok := value.(HstoreValuer); !ok {
return nil
}
switch format {
case BinaryFormatCode:
return encodePlanHstoreCodecBinary{}
case TextFormatCode:
return encodePlanHstoreCodecText{}
}
return nil
}
type encodePlanHstoreCodecBinary struct{}
func (encodePlanHstoreCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
hstore, err := value.(HstoreValuer).HstoreValue()
if err != nil {
return nil, err
}
if hstore == nil {
return nil, nil
}
buf = pgio.AppendInt32(buf, int32(len(hstore)))
for k, v := range hstore {
buf = pgio.AppendInt32(buf, int32(len(k)))
buf = append(buf, k...)
if v == nil {
buf = pgio.AppendInt32(buf, -1)
} else {
buf = pgio.AppendInt32(buf, int32(len(*v)))
buf = append(buf, (*v)...)
}
}
return buf, nil
}
type encodePlanHstoreCodecText struct{}
func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
hstore, err := value.(HstoreValuer).HstoreValue()
if err != nil {
return nil, err
}
if len(hstore) == 0 {
// distinguish between empty and nil: Not strictly required by Postgres, since its protocol
// explicitly marks NULL column values separately. However, the Binary codec does this, and
// this means we can "round trip" Encode and Scan without data loss.
// nil: []byte(nil); empty: []byte{}
if hstore == nil {
return nil, nil
}
return []byte{}, nil
}
firstPair := true
for k, v := range hstore {
if firstPair {
firstPair = false
} else {
buf = append(buf, ',', ' ')
}
// unconditionally quote hstore keys/values like Postgres does
// this avoids a Mac OS X Postgres hstore parsing bug:
// https://www.postgresql.org/message-id/CA%2BHWA9awUW0%2BRV_gO9r1ABZwGoZxPztcJxPy8vMFSTbTfi4jig%40mail.gmail.com
buf = append(buf, '"')
buf = append(buf, quoteArrayReplacer.Replace(k)...)
buf = append(buf, '"')
buf = append(buf, "=>"...)
if v == nil {
buf = append(buf, "NULL"...)
} else {
buf = append(buf, '"')
buf = append(buf, quoteArrayReplacer.Replace(*v)...)
buf = append(buf, '"')
}
}
return buf, nil
}
func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format {
case BinaryFormatCode:
switch target.(type) {
case HstoreScanner:
return scanPlanBinaryHstoreToHstoreScanner{}
}
case TextFormatCode:
switch target.(type) {
case HstoreScanner:
return scanPlanTextAnyToHstoreScanner{}
}
}
return nil
}
type scanPlanBinaryHstoreToHstoreScanner struct{}
func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error {
scanner := (dst).(HstoreScanner)
if src == nil {
return scanner.ScanHstore(Hstore(nil))
}
rp := 0
const uint32Len = 4
if len(src[rp:]) < uint32Len {
return fmt.Errorf("hstore incomplete %v", src)
}
pairCount := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += uint32Len
hstore := make(Hstore, pairCount)
// one allocation for all *string, rather than one per string, just like text parsing
valueStrings := make([]string, pairCount)
for i := 0; i < pairCount; i++ {
if len(src[rp:]) < uint32Len {
return fmt.Errorf("hstore incomplete %v", src)
}
keyLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += uint32Len
if len(src[rp:]) < keyLen {
return fmt.Errorf("hstore incomplete %v", src)
}
key := string(src[rp : rp+keyLen])
rp += keyLen
if len(src[rp:]) < uint32Len {
return fmt.Errorf("hstore incomplete %v", src)
}
valueLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
if valueLen >= 0 {
valueStrings[i] = string(src[rp : rp+valueLen])
rp += valueLen
hstore[key] = &valueStrings[i]
} else {
hstore[key] = nil
}
}
return scanner.ScanHstore(hstore)
}
type scanPlanTextAnyToHstoreScanner struct{}
func (s scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error {
scanner := (dst).(HstoreScanner)
if src == nil {
return scanner.ScanHstore(Hstore(nil))
}
return s.scanString(string(src), scanner)
}
// scanString does not return nil hstore values because string cannot be nil.
func (scanPlanTextAnyToHstoreScanner) scanString(src string, scanner HstoreScanner) error {
hstore, err := parseHstore(src)
if err != nil {
return err
}
return scanner.ScanHstore(hstore)
}
func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
return codecDecodeToTextFormat(c, m, oid, format, src)
}
func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}
var hstore Hstore
err := codecScan(c, m, oid, format, src, &hstore)
if err != nil {
return nil, err
}
return hstore, nil
}
type hstoreParser struct {
str string
pos int
nextBackslash int
}
func newHSP(in string) *hstoreParser {
return &hstoreParser{
pos: 0,
str: in,
nextBackslash: strings.IndexByte(in, '\\'),
}
}
func (p *hstoreParser) atEnd() bool {
return p.pos >= len(p.str)
}
// consume returns the next byte of the string, or end if the string is done.
func (p *hstoreParser) consume() (b byte, end bool) {
if p.pos >= len(p.str) {
return 0, true
}
b = p.str[p.pos]
p.pos++
return b, false
}
func unexpectedByteErr(actualB byte, expectedB byte) error {
return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB)
}
// consumeExpectedByte consumes expectedB from the string, or returns an error.
func (p *hstoreParser) consumeExpectedByte(expectedB byte) error {
nextB, end := p.consume()
if end {
return fmt.Errorf("expected '%c' ('%#v'); found end", expectedB, expectedB)
}
if nextB != expectedB {
return unexpectedByteErr(nextB, expectedB)
}
return nil
}
// consumeExpected2 consumes two expected bytes or returns an error.
// This was a bit faster than using a string argument (better inlining? Not sure).
func (p *hstoreParser) consumeExpected2(one byte, two byte) error {
if p.pos+2 > len(p.str) {
return errors.New("unexpected end of string")
}
if p.str[p.pos] != one {
return unexpectedByteErr(p.str[p.pos], one)
}
if p.str[p.pos+1] != two {
return unexpectedByteErr(p.str[p.pos+1], two)
}
p.pos += 2
return nil
}
var errEOSInQuoted = errors.New(`found end before closing double-quote ('"')`)
// consumeDoubleQuoted consumes a double-quoted string from p. The double quote must have been
// parsed already. This copies the string from the backing string so it can be garbage collected.
func (p *hstoreParser) consumeDoubleQuoted() (string, error) {
// fast path: assume most keys/values do not contain escapes
nextDoubleQuote := strings.IndexByte(p.str[p.pos:], '"')
if nextDoubleQuote == -1 {
return "", errEOSInQuoted
}
nextDoubleQuote += p.pos
if p.nextBackslash == -1 || p.nextBackslash > nextDoubleQuote {
// clone the string from the source string to ensure it can be garbage collected separately
// TODO: use strings.Clone on Go 1.20; this could get optimized away
s := strings.Clone(p.str[p.pos:nextDoubleQuote])
p.pos = nextDoubleQuote + 1
return s, nil
}
// slow path: string contains escapes
s, err := p.consumeDoubleQuotedWithEscapes(p.nextBackslash)
p.nextBackslash = strings.IndexByte(p.str[p.pos:], '\\')
if p.nextBackslash != -1 {
p.nextBackslash += p.pos
}
return s, err
}
// consumeDoubleQuotedWithEscapes consumes a double-quoted string containing escapes, starting
// at p.pos, and with the first backslash at firstBackslash. This copies the string so it can be
// garbage collected separately.
func (p *hstoreParser) consumeDoubleQuotedWithEscapes(firstBackslash int) (string, error) {
// copy the prefix that does not contain backslashes
var builder strings.Builder
builder.WriteString(p.str[p.pos:firstBackslash])
// skip to the backslash
p.pos = firstBackslash
// copy bytes until the end, unescaping backslashes
for {
nextB, end := p.consume()
if end {
return "", errEOSInQuoted
} else if nextB == '"' {
break
} else if nextB == '\\' {
// escape: skip the backslash and copy the char
nextB, end = p.consume()
if end {
return "", errEOSInQuoted
}
if !(nextB == '\\' || nextB == '"') {
return "", fmt.Errorf("unexpected escape in quoted string: found '%#v'", nextB)
}
builder.WriteByte(nextB)
} else {
// normal byte: copy it
builder.WriteByte(nextB)
}
}
return builder.String(), nil
}
// consumePairSeparator consumes the Hstore pair separator ", " or returns an error.
func (p *hstoreParser) consumePairSeparator() error {
return p.consumeExpected2(',', ' ')
}
// consumeKVSeparator consumes the Hstore key/value separator "=>" or returns an error.
func (p *hstoreParser) consumeKVSeparator() error {
return p.consumeExpected2('=', '>')
}
// consumeDoubleQuotedOrNull consumes the Hstore key/value separator "=>" or returns an error.
func (p *hstoreParser) consumeDoubleQuotedOrNull() (Text, error) {
// peek at the next byte
if p.atEnd() {
return Text{}, errors.New("found end instead of value")
}
next := p.str[p.pos]
if next == 'N' {
// must be the exact string NULL: use consumeExpected2 twice
err := p.consumeExpected2('N', 'U')
if err != nil {
return Text{}, err
}
err = p.consumeExpected2('L', 'L')
if err != nil {
return Text{}, err
}
return Text{String: "", Valid: false}, nil
} else if next != '"' {
return Text{}, unexpectedByteErr(next, '"')
}
// skip the double quote
p.pos += 1
s, err := p.consumeDoubleQuoted()
if err != nil {
return Text{}, err
}
return Text{String: s, Valid: true}, nil
}
func parseHstore(s string) (Hstore, error) {
p := newHSP(s)
// This is an over-estimate of the number of key/value pairs. Use '>' because I am guessing it
// is less likely to occur in keys/values than '=' or ','.
numPairsEstimate := strings.Count(s, ">")
// makes one allocation of strings for the entire Hstore, rather than one allocation per value.
valueStrings := make([]string, 0, numPairsEstimate)
result := make(Hstore, numPairsEstimate)
first := true
for !p.atEnd() {
if !first {
err := p.consumePairSeparator()
if err != nil {
return nil, err
}
} else {
first = false
}
err := p.consumeExpectedByte('"')
if err != nil {
return nil, err
}
key, err := p.consumeDoubleQuoted()
if err != nil {
return nil, err
}
err = p.consumeKVSeparator()
if err != nil {
return nil, err
}
value, err := p.consumeDoubleQuotedOrNull()
if err != nil {
return nil, err
}
if value.Valid {
valueStrings = append(valueStrings, value.String)
result[key] = &valueStrings[len(valueStrings)-1]
} else {
result[key] = nil
}
}
return result, nil
}