From a6863a7dd2f9efa3dc2a13f396c49cc7da6f7c13 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 15 Jan 2022 17:47:37 -0600 Subject: [PATCH] Convert Hstore to Codec --- pgtype/builtin_wrappers.go | 37 +++ pgtype/hstore.go | 357 ++++++++++++++------------- pgtype/hstore_array.go | 476 ------------------------------------ pgtype/hstore_array_test.go | 436 --------------------------------- pgtype/hstore_test.go | 334 ++++++++++++------------- pgtype/pgtype.go | 48 ++++ 6 files changed, 432 insertions(+), 1256 deletions(-) delete mode 100644 pgtype/hstore_array.go delete mode 100644 pgtype/hstore_array_test.go diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index df955f18..15d4e083 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -381,3 +381,40 @@ func (w netIPWrapper) InetValue() (Inet, error) { mask := net.CIDRMask(bitCount, bitCount) return Inet{IPNet: &net.IPNet{Mask: mask, IP: net.IP(w)}, Valid: true}, nil } + +type mapStringToPointerStringWrapper map[string]*string + +func (w *mapStringToPointerStringWrapper) ScanHstore(v Hstore) error { + *w = mapStringToPointerStringWrapper(v) + return nil +} + +func (w mapStringToPointerStringWrapper) HstoreValue() (Hstore, error) { + return Hstore(w), nil +} + +type mapStringToStringWrapper map[string]string + +func (w *mapStringToStringWrapper) ScanHstore(v Hstore) error { + *w = make(mapStringToStringWrapper, len(v)) + for k, v := range v { + if v == nil { + return fmt.Errorf("cannot scan NULL to string") + } + (*w)[k] = *v + } + return nil +} + +func (w mapStringToStringWrapper) HstoreValue() (Hstore, error) { + if w == nil { + return nil, nil + } + + hstore := make(Hstore, len(w)) + for k, v := range w { + s := v + hstore[k] = &s + } + return hstore, nil +} diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 69c8a07b..6ff8164c 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -13,114 +13,168 @@ import ( "github.com/jackc/pgio" ) +type HstoreScanner interface { + ScanHstore(v Hstore) error +} + +type HstoreValuer interface { + HstoreValue() (Hstore, error) +} + // Hstore represents an hstore column that can be null or have null values // associated with its keys. -type Hstore struct { - Map map[string]Text - Valid bool -} - -func (dst *Hstore) Set(src interface{}) error { - if src == nil { - *dst = Hstore{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case map[string]string: - m := make(map[string]Text, len(value)) - for k, v := range value { - m[k] = Text{String: v, Valid: true} - } - *dst = Hstore{Map: m, Valid: true} - case map[string]*string: - m := make(map[string]Text, len(value)) - for k, v := range value { - if v == nil { - m[k] = Text{} - } else { - m[k] = Text{String: *v, Valid: true} - } - } - *dst = Hstore{Map: m, Valid: true} - default: - return fmt.Errorf("cannot convert %v to Hstore", src) - } +type Hstore map[string]*string +func (h *Hstore) ScanHstore(v Hstore) error { + *h = v return nil } -func (dst Hstore) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.Map +func (h Hstore) HstoreValue() (Hstore, error) { + return h, nil } -func (src *Hstore) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *map[string]string: - *v = make(map[string]string, len(src.Map)) - for k, val := range src.Map { - if !val.Valid { - return fmt.Errorf("cannot decode %#v into %T", src, dst) - } - (*v)[k] = val.String - } - return nil - case *map[string]*string: - *v = make(map[string]*string, len(src.Map)) - for k, val := range src.Map { - if val.Valid { - (*v)[k] = &val.String - } else { - (*v)[k] = nil - } - } - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { +// Scan implements the database/sql Scanner interface. +func (h *Hstore) Scan(src interface{}) error { if src == nil { - *dst = Hstore{} + *h = nil return nil } - keys, values, err := parseHstore(string(src)) + switch src := src.(type) { + case string: + return scanPlanTextAnyToHstoreScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), h) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (h Hstore) Value() (driver.Value, error) { + if h == nil { + return nil, nil + } + + buf, err := HstoreCodec{}.PlanEncode(nil, 0, TextFormatCode, h).Encode(h, nil) if err != nil { - return err + return nil, err + } + return string(buf), err +} + +type HstoreCodec struct{} + +func (HstoreCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (HstoreCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (HstoreCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(HstoreValuer); !ok { + return nil } - m := make(map[string]Text, len(keys)) - for i := range keys { - m[keys[i]] = values[i] + switch format { + case BinaryFormatCode: + return encodePlanHstoreCodecBinary{} + case TextFormatCode: + return encodePlanHstoreCodecText{} } - *dst = Hstore{Map: m, Valid: true} return nil } -func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { +type encodePlanHstoreCodecBinary struct{} + +func (encodePlanHstoreCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + hstore, err := value.(HstoreValuer).HstoreValue() + if err != nil { + return nil, err + } + + if hstore == nil { + return nil, nil + } + + buf = pgio.AppendInt32(buf, int32(len(hstore))) + + for k, v := range hstore { + buf = pgio.AppendInt32(buf, int32(len(k))) + buf = append(buf, k...) + + if v == nil { + buf = pgio.AppendInt32(buf, -1) + } else { + buf = pgio.AppendInt32(buf, int32(len(*v))) + buf = append(buf, (*v)...) + } + } + + return buf, nil +} + +type encodePlanHstoreCodecText struct{} + +func (encodePlanHstoreCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + hstore, err := value.(HstoreValuer).HstoreValue() + if err != nil { + return nil, err + } + + if hstore == nil { + return nil, nil + } + + firstPair := true + + for k, v := range hstore { + if firstPair { + firstPair = false + } else { + buf = append(buf, ',') + } + + buf = append(buf, quoteHstoreElementIfNeeded(k)...) + buf = append(buf, "=>"...) + + if v == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, quoteHstoreElementIfNeeded(*v)...) + } + } + + return buf, nil +} + +func (HstoreCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case HstoreScanner: + return scanPlanBinaryHstoreToHstoreScanner{} + } + case TextFormatCode: + switch target.(type) { + case HstoreScanner: + return scanPlanTextAnyToHstoreScanner{} + } + } + + return nil +} + +type scanPlanBinaryHstoreToHstoreScanner struct{} + +func (scanPlanBinaryHstoreToHstoreScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(HstoreScanner) + if src == nil { - *dst = Hstore{} - return nil + return scanner.ScanHstore(Hstore{}) } rp := 0 @@ -131,7 +185,7 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 - m := make(map[string]Text, pairCount) + hstore := make(Hstore, pairCount) for i := 0; i < pairCount; i++ { if len(src[rp:]) < 4 { @@ -163,73 +217,58 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { if err != nil { return err } - m[key] = value + + if value.Valid { + hstore[key] = &value.String + } else { + hstore[key] = nil + } } - *dst = Hstore{Map: m, Valid: true} - - return nil + return scanner.ScanHstore(hstore) } -func (src Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { +type scanPlanTextAnyToHstoreScanner struct{} + +func (scanPlanTextAnyToHstoreScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(HstoreScanner) + + if src == nil { + return scanner.ScanHstore(Hstore{}) + } + + keys, values, err := parseHstore(string(src)) + if err != nil { + return err + } + + m := make(Hstore, len(keys)) + for i := range keys { + if values[i].Valid { + m[keys[i]] = &values[i].String + } else { + m[keys[i]] = nil + } + } + + return scanner.ScanHstore(m) +} + +func (c HstoreCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c HstoreCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { return nil, nil } - firstPair := true - - inElemBuf := make([]byte, 0, 32) - for k, v := range src.Map { - if firstPair { - firstPair = false - } else { - buf = append(buf, ',') - } - - buf = append(buf, quoteHstoreElementIfNeeded(k)...) - buf = append(buf, "=>"...) - - elemBuf, err := ci.Encode(TextOID, TextFormatCode, v, inElemBuf) - if err != nil { - return nil, err - } - - if elemBuf == nil { - buf = append(buf, "NULL"...) - } else { - buf = append(buf, quoteHstoreElementIfNeeded(string(elemBuf))...) - } + var hstore Hstore + err := codecScan(c, ci, oid, format, src, &hstore) + if err != nil { + return nil, err } - - return buf, nil -} - -func (src Hstore) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = pgio.AppendInt32(buf, int32(len(src.Map))) - - var err error - for k, v := range src.Map { - buf = pgio.AppendInt32(buf, int32(len(k))) - buf = append(buf, k...) - - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := ci.Encode(TextOID, BinaryFormatCode, v, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, err + return hstore, nil } var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) @@ -420,27 +459,3 @@ func parseHstore(s string) (k []string, v []Text, err error) { v = values return } - -// Scan implements the database/sql Scanner interface. -func (dst *Hstore) Scan(src interface{}) error { - if src == nil { - *dst = Hstore{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Hstore) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go deleted file mode 100644 index 0ca5d4fb..00000000 --- a/pgtype/hstore_array.go +++ /dev/null @@ -1,476 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type HstoreArray struct { - Elements []Hstore - Dimensions []ArrayDimension - Valid bool -} - -func (dst *HstoreArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = HstoreArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []map[string]string: - if value == nil { - *dst = HstoreArray{} - } else if len(value) == 0 { - *dst = HstoreArray{Valid: true} - } else { - elements := make([]Hstore, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = HstoreArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Hstore: - if value == nil { - *dst = HstoreArray{} - } else if len(value) == 0 { - *dst = HstoreArray{Valid: true} - } else { - *dst = HstoreArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = HstoreArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for HstoreArray", src) - } - if elementsLength == 0 { - *dst = HstoreArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to HstoreArray", src) - } - - *dst = HstoreArray{ - Elements: make([]Hstore, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Hstore, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to HstoreArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *HstoreArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to HstoreArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in HstoreArray", err) - } - index++ - - return index, nil -} - -func (dst HstoreArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *HstoreArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]map[string]string: - *v = make([]map[string]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from HstoreArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from HstoreArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = HstoreArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Hstore - - if len(uta.Elements) > 0 { - elements = make([]Hstore, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Hstore - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = HstoreArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = HstoreArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = HstoreArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Hstore, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = HstoreArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("hstore"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "hstore") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *HstoreArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src HstoreArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/hstore_array_test.go b/pgtype/hstore_array_test.go deleted file mode 100644 index 7912b626..00000000 --- a/pgtype/hstore_array_test.go +++ /dev/null @@ -1,436 +0,0 @@ -package pgtype_test - -import ( - "context" - "reflect" - "testing" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestHstoreArrayTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - var hstoreOID uint32 - err := conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='hstore';").Scan(&hstoreOID) - if err != nil { - t.Fatalf("did not find hstore OID, %v", err) - } - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.Hstore{}, Name: "hstore", OID: hstoreOID}) - - var hstoreArrayOID uint32 - err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='_hstore';").Scan(&hstoreArrayOID) - if err != nil { - t.Fatalf("did not find _hstore OID, %v", err) - } - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.HstoreArray{}, Name: "_hstore", OID: hstoreArrayOID}) - - text := func(s string) pgtype.Text { - return pgtype.Text{String: s, Valid: true} - } - - values := []pgtype.Hstore{ - {Map: map[string]pgtype.Text{}, Valid: true}, - {Map: map[string]pgtype.Text{"foo": text("bar")}, Valid: true}, - {Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Valid: true}, - {Map: map[string]pgtype.Text{"NULL": text("bar")}, Valid: true}, - {Map: map[string]pgtype.Text{"foo": text("NULL")}, Valid: true}, - {}, - } - - specialStrings := []string{ - `"`, - `'`, - `\`, - `\\`, - `=>`, - ` `, - `\ / / \\ => " ' " '`, - } - for _, s := range specialStrings { - // Special key values - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Valid: true}) // at beginning - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Valid: true}) // in middle - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Valid: true}) // at end - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Valid: true}) // is key - - // Special value values - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Valid: true}) // at beginning - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Valid: true}) // in middle - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Valid: true}) // at end - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Valid: true}) // is key - } - - src := &pgtype.HstoreArray{ - Elements: values, - Dimensions: []pgtype.ArrayDimension{{Length: int32(len(values)), LowerBound: 1}}, - Valid: true, - } - - _, err = conn.Prepare(context.Background(), "test", "select $1::hstore[]") - if err != nil { - t.Fatal(err) - } - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for _, fc := range formats { - queryResultFormats := pgx.QueryResultFormats{fc.formatCode} - vEncoder := testutil.ForceEncoder(src, fc.formatCode) - if vEncoder == nil { - t.Logf("%#v does not implement %v", src, fc.name) - continue - } - - var result pgtype.HstoreArray - err := conn.QueryRow(context.Background(), "test", queryResultFormats, vEncoder).Scan(&result) - if err != nil { - t.Errorf("%v: %v", fc.name, err) - continue - } - - if result.Valid != src.Valid { - t.Errorf("%v: expected Valid %v, got %v", fc.formatCode, src.Valid, result.Valid) - continue - } - - if len(result.Elements) != len(src.Elements) { - t.Errorf("%v: expected %v elements, got %v", fc.formatCode, len(src.Elements), len(result.Elements)) - continue - } - - for i := range result.Elements { - a := src.Elements[i] - b := result.Elements[i] - - if a.Valid != b.Valid { - t.Errorf("%v element idx %d: expected Valid %v, got %v", fc.formatCode, i, a.Valid, b.Valid) - } - - if len(a.Map) != len(b.Map) { - t.Errorf("%v element idx %d: expected %v pairs, got %v", fc.formatCode, i, len(a.Map), len(b.Map)) - } - - for k := range a.Map { - if a.Map[k] != b.Map[k] { - t.Errorf("%v element idx %d: expected key %v to be %v, got %v", fc.formatCode, i, k, a.Map[k], b.Map[k]) - } - } - } - } -} - -func TestHstoreArraySet(t *testing.T) { - successfulTests := []struct { - src interface{} - result pgtype.HstoreArray - }{ - { - src: []map[string]string{{"foo": "bar"}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - }, - { - src: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true, - }, - }, - { - src: [][][][]map[string]string{ - {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, - {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true, - }, - }, - { - src: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true, - }, - }, - { - src: [2][1][1][3]map[string]string{ - {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, - {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true, - }, - }, - } - - for i, tt := range successfulTests { - var dst pgtype.HstoreArray - err := dst.Set(tt.src) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(dst, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) - } - } -} - -func TestHstoreArrayAssignTo(t *testing.T) { - var hstoreSlice []map[string]string - var hstoreSliceDim2 [][]map[string]string - var hstoreSliceDim4 [][][][]map[string]string - var hstoreArrayDim2 [2][1]map[string]string - var hstoreArrayDim4 [2][1][1][3]map[string]string - - simpleTests := []struct { - src pgtype.HstoreArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &hstoreSlice, - expected: []map[string]string{{"foo": "bar"}}}, - { - src: pgtype.HstoreArray{}, dst: &hstoreSlice, expected: (([]map[string]string)(nil)), - }, - { - src: pgtype.HstoreArray{Valid: true}, dst: &hstoreSlice, expected: []map[string]string{}, - }, - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &hstoreSliceDim2, - expected: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, - }, - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true, - }, - dst: &hstoreSliceDim4, - expected: [][][][]map[string]string{ - {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, - {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, - }, - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &hstoreArrayDim2, - expected: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, - }, - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true, - }, - dst: &hstoreArrayDim4, - expected: [2][1][1][3]map[string]string{ - {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, - {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index dd80f0c5..edd94db7 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -1,31 +1,124 @@ package pgtype_test import ( - "reflect" + "context" "testing" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestHstoreTranscode(t *testing.T) { - text := func(s string) pgtype.Text { - return pgtype.Text{String: s, Valid: true} +func isExpectedEqMapStringString(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + am := a.(map[string]string) + vm := v.(map[string]string) + + if len(am) != len(vm) { + return false + } + + for k, v := range am { + if vm[k] != v { + return false + } + } + + return true + } +} + +func isExpectedEqMapStringPointerString(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + am := a.(map[string]*string) + vm := v.(map[string]*string) + + if len(am) != len(vm) { + return false + } + + for k, v := range am { + if (vm[k] == nil) != (v == nil) { + return false + } + + if v != nil && *vm[k] != *v { + return false + } + } + + return true + } +} + +func TestHstoreCodec(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + var hstoreOID uint32 + err := conn.QueryRow(context.Background(), `select oid from pg_type where typname = 'hstore'`).Scan(&hstoreOID) + if err != nil { + t.Skipf("Skipping: cannot find hstore OID") } - values := []interface{}{ - &pgtype.Hstore{Map: map[string]pgtype.Text{}, Valid: true}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(""), "bar": text(""), "baz": text("123")}, Valid: true}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Valid: true}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Valid: true}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Valid: true}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Valid: true}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"": text("bar")}, Valid: true}, - &pgtype.Hstore{ - Map: map[string]pgtype.Text{"a": text("a"), "b": {}, "c": text("c"), "d": {}, "e": text("e")}, - Valid: true, + conn.ConnInfo().RegisterDataType(pgtype.DataType{Name: "hstore", OID: hstoreOID, Codec: pgtype.HstoreCodec{}}) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + fs := func(s string) *string { + return &s + } + + tests := []PgxTranscodeTestCase{ + { + map[string]string{}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{}), }, - &pgtype.Hstore{}, + { + map[string]string{"foo": "", "bar": "", "baz": "123"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": "", "bar": "", "baz": "123"}), + }, + { + map[string]string{"NULL": "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"NULL": "bar"}), + }, + { + map[string]string{"bar": "NULL"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"bar": "NULL"}), + }, + { + map[string]string{"": "foo"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"": "foo"}), + }, + { + map[string]*string{}, + new(map[string]*string), + isExpectedEqMapStringPointerString(map[string]*string{}), + }, + { + map[string]*string{"foo": fs("bar"), "baq": fs("quz")}, + new(map[string]*string), + isExpectedEqMapStringPointerString(map[string]*string{"foo": fs("bar"), "baq": fs("quz")}), + }, + { + map[string]*string{"foo": nil, "baq": fs("quz")}, + new(map[string]*string), + isExpectedEqMapStringPointerString(map[string]*string{"foo": nil, "baq": fs("quz")}), + }, + {nil, new(*map[string]string), isExpectedEq((*map[string]string)(nil))}, + {nil, new(*map[string]*string), isExpectedEq((*map[string]*string)(nil))}, + {nil, new(*pgtype.Hstore), isExpectedEq((*pgtype.Hstore)(nil))}, } specialStrings := []string{ @@ -39,166 +132,61 @@ func TestHstoreTranscode(t *testing.T) { } for _, s := range specialStrings { // Special key values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Valid: true}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Valid: true}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Valid: true}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Valid: true}) // is key + + // at beginning + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{s + "foo": "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{s + "foo": "bar"}), + }) + // in middle + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{"foo" + s + "bar": "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo" + s + "bar": "bar"}), + }) + // at end + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{"foo" + s: "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo" + s: "bar"}), + }) + // is key + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{s: "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{s: "bar"}), + }) // Special value values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Valid: true}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Valid: true}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Valid: true}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Valid: true}) // is key + + // at beginning + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{"foo": s + "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": s + "bar"}), + }) + // in middle + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{"foo": "foo" + s + "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": "foo" + s + "bar"}), + }) + // at end + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{"foo": "foo" + s}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": "foo" + s}), + }) + // is key + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{"foo": s}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": s}), + }) } - testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { - a := ai.(pgtype.Hstore) - b := bi.(pgtype.Hstore) - - if len(a.Map) != len(b.Map) || a.Valid != b.Valid { - return false - } - - for k := range a.Map { - if a.Map[k] != b.Map[k] { - return false - } - } - - return true - }) -} - -func TestHstoreTranscodeNullable(t *testing.T) { - text := func(s string, valid bool) pgtype.Text { - return pgtype.Text{String: s, Valid: valid} - } - - values := []interface{}{ - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("", false)}, Valid: true}, - } - - specialStrings := []string{ - `"`, - `'`, - `\`, - `\\`, - `=>`, - ` `, - `\ / / \\ => " ' " '`, - } - for _, s := range specialStrings { - // Special key values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("", false)}, Valid: true}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("", false)}, Valid: true}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("", false)}, Valid: true}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("", false)}, Valid: true}) // is key - } - - testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { - a := ai.(pgtype.Hstore) - b := bi.(pgtype.Hstore) - - if len(a.Map) != len(b.Map) || a.Valid != b.Valid { - return false - } - - for k := range a.Map { - if a.Map[k] != b.Map[k] { - return false - } - } - - return true - }) -} - -func TestHstoreSet(t *testing.T) { - successfulTests := []struct { - src map[string]string - result pgtype.Hstore - }{ - {src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, Valid: true}}, - } - - for i, tt := range successfulTests { - var dst pgtype.Hstore - err := dst.Set(tt.src) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(dst, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) - } - } -} - -func TestHstoreSetNullable(t *testing.T) { - successfulTests := []struct { - src map[string]*string - result pgtype.Hstore - }{ - {src: map[string]*string{"foo": nil}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {}}, Valid: true}}, - } - - for i, tt := range successfulTests { - var dst pgtype.Hstore - err := dst.Set(tt.src) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(dst, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) - } - } -} - -func TestHstoreAssignTo(t *testing.T) { - var m map[string]string - - simpleTests := []struct { - src pgtype.Hstore - dst *map[string]string - expected map[string]string - }{ - {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, Valid: true}, dst: &m, expected: map[string]string{"foo": "bar"}}, - {src: pgtype.Hstore{}, dst: &m, expected: ((map[string]string)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(*tt.dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } -} - -func TestHstoreAssignToNullable(t *testing.T) { - var m map[string]*string - - simpleTests := []struct { - src pgtype.Hstore - dst *map[string]*string - expected map[string]*string - }{ - {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {}}, Valid: true}, dst: &m, expected: map[string]*string{"foo": nil}}, - {src: pgtype.Hstore{}, dst: &m, expected: ((map[string]*string)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(*tt.dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } + for _, format := range formats { + testPgxCodecFormat(t, "hstore", tests, conn, format.name, format.code) } } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index cb351fbd..5d0ed882 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -856,6 +856,10 @@ func tryWrapBuiltinTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter return &wrapNetIPNetScanPlan{}, (*netIPNetWrapper)(dst), true case *net.IP: return &wrapNetIPScanPlan{}, (*netIPWrapper)(dst), true + case *map[string]*string: + return &wrapMapStringToPointerStringScanPlan{}, (*mapStringToPointerStringWrapper)(dst), true + case *map[string]string: + return &wrapMapStringToStringScanPlan{}, (*mapStringToStringWrapper)(dst), true } return nil, nil, false @@ -1021,6 +1025,26 @@ func (plan *wrapNetIPScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, return plan.next.Scan(ci, oid, formatCode, src, (*netIPWrapper)(dst.(*net.IP))) } +type wrapMapStringToPointerStringScanPlan struct { + next ScanPlan +} + +func (plan *wrapMapStringToPointerStringScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapMapStringToPointerStringScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*mapStringToPointerStringWrapper)(dst.(*map[string]*string))) +} + +type wrapMapStringToStringScanPlan struct { + next ScanPlan +} + +func (plan *wrapMapStringToStringScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapMapStringToStringScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*mapStringToStringWrapper)(dst.(*map[string]string))) +} + type pointerEmptyInterfaceScanPlan struct { codec Codec } @@ -1362,6 +1386,10 @@ func tryWrapBuiltinTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNext return &wrapNetIPNetEncodePlan{}, netIPNetWrapper(value), true case net.IP: return &wrapNetIPEncodePlan{}, netIPWrapper(value), true + case map[string]*string: + return &wrapMapStringToPointerStringEncodePlan{}, mapStringToPointerStringWrapper(value), true + case map[string]string: + return &wrapMapStringToStringEncodePlan{}, mapStringToStringWrapper(value), true } return nil, nil, false @@ -1527,6 +1555,26 @@ func (plan *wrapNetIPEncodePlan) Encode(value interface{}, buf []byte) (newBuf [ return plan.next.Encode(netIPWrapper(value.(net.IP)), buf) } +type wrapMapStringToPointerStringEncodePlan struct { + next EncodePlan +} + +func (plan *wrapMapStringToPointerStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapMapStringToPointerStringEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(mapStringToPointerStringWrapper(value.(map[string]*string)), buf) +} + +type wrapMapStringToStringEncodePlan struct { + next EncodePlan +} + +func (plan *wrapMapStringToStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapMapStringToStringEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(mapStringToStringWrapper(value.(map[string]string)), buf) +} + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written.