diff --git a/pgtype/json.go b/pgtype/json.go index 580e8505..510b638e 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -3,187 +3,129 @@ package pgtype import ( "database/sql/driver" "encoding/json" - "errors" - "fmt" + "reflect" ) -type JSON struct { - Bytes []byte - Valid bool +type JSONCodec struct{} + +func (JSONCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode } -func (dst *JSON) Set(src interface{}) error { - if src == nil { - *dst = JSON{} - return nil - } +func (JSONCodec) PreferredFormat() int16 { + return TextFormatCode +} - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case string: - *dst = JSON{Bytes: []byte(value), Valid: true} - case *string: - if value == nil { - *dst = JSON{} - } else { - *dst = JSON{Bytes: []byte(*value), Valid: true} - } +func (JSONCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch value.(type) { case []byte: - if value == nil { - *dst = JSON{} - } else { - *dst = JSON{Bytes: value, Valid: true} - } - // Encode* methods are defined on *JSON. If JSON is passed directly then the - // struct itself would be encoded instead of Bytes. This is clearly a footgun - // so detect and return an error. See https://github.com/jackc/pgx/issues/350. - case JSON: - return errors.New("use pointer to pgtype.JSON instead of value") - // Same as above but for JSONB (because they share implementation) - case JSONB: - return errors.New("use pointer to pgtype.JSONB instead of value") - + return encodePlanJSONCodecEitherFormatByteSlice{} default: - buf, err := json.Marshal(value) - if err != nil { - return err - } - *dst = JSON{Bytes: buf, Valid: true} + return encodePlanJSONCodecEitherFormatMarshal{} } - - return nil } -func (dst JSON) Get() interface{} { - if !dst.Valid { - return nil +type encodePlanJSONCodecEitherFormatByteSlice struct{} + +func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + jsonBytes := value.([]byte) + if jsonBytes == nil { + return nil, nil } - var i interface{} - err := json.Unmarshal(dst.Bytes, &i) + buf = append(buf, jsonBytes...) + return buf, nil +} + +type encodePlanJSONCodecEitherFormatMarshal struct{} + +func (encodePlanJSONCodecEitherFormatMarshal) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + jsonBytes, err := json.Marshal(value) if err != nil { - return dst + return nil, err } - return i + + buf = append(buf, jsonBytes...) + return buf, nil } -func (src *JSON) AssignTo(dst interface{}) error { - switch v := dst.(type) { +func (JSONCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch target.(type) { case *string: - if src.Valid { - *v = string(src.Bytes) - } else { - return fmt.Errorf("cannot assign non-valid to %T", dst) - } - case **string: - if src.Valid { - s := string(src.Bytes) - *v = &s - return nil - } else { - *v = nil - return nil - } + return scanPlanAnyToString{} case *[]byte: - if !src.Valid { - *v = nil - } else { - buf := make([]byte, len(src.Bytes)) - copy(buf, src.Bytes) - *v = buf - } + return scanPlanJSONToByteSlice{} + case BytesScanner: + return scanPlanBinaryBytesToBytesScanner{} default: - data := src.Bytes - if data == nil || !src.Valid { - data = []byte("null") + return scanPlanJSONToJSONUnmarshal{} + } + +} + +type scanPlanAnyToString struct{} + +func (scanPlanAnyToString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + p := dst.(*string) + *p = string(src) + return nil +} + +type scanPlanJSONToByteSlice struct{} + +func (scanPlanJSONToByteSlice) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + dstBuf := dst.(*[]byte) + if src == nil { + *dstBuf = nil + return nil + } + + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanJSONToBytesScanner struct{} + +func (scanPlanJSONToBytesScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(BytesScanner) + return scanner.ScanBytes(src) +} + +type scanPlanJSONToJSONUnmarshal struct{} + +func (scanPlanJSONToJSONUnmarshal) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() == reflect.Ptr { + el := dstValue.Elem() + switch el.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map: + el.Set(reflect.Zero(el.Type())) + return nil + } } - - return json.Unmarshal(data, dst) } - return nil + return json.Unmarshal(src, dst) } -func (JSON) PreferredResultFormat() int16 { - return TextFormatCode -} - -func (dst *JSON) DecodeText(ci *ConnInfo, src []byte) error { +func (c JSONCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { - *dst = JSON{} - return nil - } - - *dst = JSON{Bytes: src, Valid: true} - return nil -} - -func (dst *JSON) DecodeBinary(ci *ConnInfo, src []byte) error { - return dst.DecodeText(ci, src) -} - -func (JSON) PreferredParamFormat() int16 { - return TextFormatCode -} - -func (src JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { return nil, nil } - return append(buf, src.Bytes...), nil + dstBuf := make([]byte, len(src)) + copy(dstBuf, src) + return dstBuf, nil } -func (src JSON) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return src.EncodeText(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *JSON) Scan(src interface{}) error { +func (c JSONCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = JSON{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src JSON) Value() (driver.Value, error) { - if !src.Valid { return nil, nil } - return src.Bytes, nil -} - -func (src JSON) MarshalJSON() ([]byte, error) { - if !src.Valid { - return []byte("null"), nil - } - return src.Bytes, nil -} - -func (dst *JSON) UnmarshalJSON(b []byte) error { - if b == nil || string(b) == "null" { - *dst = JSON{} - } else { - *dst = JSON{Bytes: b, Valid: true} - } - return nil + var dst interface{} + err := json.Unmarshal(src, &dst) + return dst, err } diff --git a/pgtype/json_test.go b/pgtype/json_test.go index cb5162d3..156217ac 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -1,177 +1,52 @@ package pgtype_test import ( - "bytes" - "reflect" "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestJSONTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "json", []interface{}{ - &pgtype.JSON{Bytes: []byte("{}"), Valid: true}, - &pgtype.JSON{Bytes: []byte("null"), Valid: true}, - &pgtype.JSON{Bytes: []byte("42"), Valid: true}, - &pgtype.JSON{Bytes: []byte(`"hello"`), Valid: true}, - &pgtype.JSON{}, - }) -} +func isExpectedEqMap(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + aa := a.(map[string]interface{}) + bb := v.(map[string]interface{}) -func TestJSONSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.JSON - }{ - {source: "{}", result: pgtype.JSON{Bytes: []byte("{}"), Valid: true}}, - {source: []byte("{}"), result: pgtype.JSON{Bytes: []byte("{}"), Valid: true}}, - {source: ([]byte)(nil), result: pgtype.JSON{}}, - {source: (*string)(nil), result: pgtype.JSON{}}, - {source: []int{1, 2, 3}, result: pgtype.JSON{Bytes: []byte("[1,2,3]"), Valid: true}}, - {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Valid: true}}, - } - - for i, tt := range successfulTests { - var d pgtype.JSON - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) + if (aa == nil) != (bb == nil) { + return false } - if !reflect.DeepEqual(d, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + if aa == nil { + return true } + + if len(aa) != len(bb) { + return false + } + + for k := range aa { + if aa[k] != bb[k] { + return false + } + } + + return true } } -func TestJSONAssignTo(t *testing.T) { - var s string - var ps *string - var b []byte - - rawStringTests := []struct { - src pgtype.JSON - dst *string - expected string - }{ - {src: pgtype.JSON{Bytes: []byte("{}"), Valid: true}, dst: &s, expected: "{}"}, - } - - for i, tt := range rawStringTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if *tt.dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } - - rawBytesTests := []struct { - src pgtype.JSON - dst *[]byte - expected []byte - }{ - {src: pgtype.JSON{Bytes: []byte("{}"), Valid: true}, dst: &b, expected: []byte("{}")}, - {src: pgtype.JSON{}, dst: &b, expected: (([]byte)(nil))}, - } - - for i, tt := range rawBytesTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if bytes.Compare(tt.expected, *tt.dst) != 0 { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } - - var mapDst map[string]interface{} - type structDst struct { +func TestJSONCodec(t *testing.T) { + type jsonStruct struct { Name string `json:"name"` Age int `json:"age"` } - var strDst structDst - unmarshalTests := []struct { - src pgtype.JSON - dst interface{} - expected interface{} - }{ - {src: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Valid: true}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, - {src: pgtype.JSON{Bytes: []byte(`{"name":"John","age":42}`), Valid: true}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, - } - for i, tt := range unmarshalTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.JSON - dst **string - expected *string - }{ - {src: pgtype.JSON{}, dst: &ps, expected: ((*string)(nil))}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if *tt.dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } -} - -func TestJSONMarshalJSON(t *testing.T) { - successfulTests := []struct { - source pgtype.JSON - result string - }{ - {source: pgtype.JSON{}, result: "null"}, - {source: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Valid: true}, result: "{\"a\": 1}"}, - } - for i, tt := range successfulTests { - r, err := tt.source.MarshalJSON() - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if string(r) != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) - } - } -} - -func TestJSONUnmarshalJSON(t *testing.T) { - successfulTests := []struct { - source string - result pgtype.JSON - }{ - {source: "null", result: pgtype.JSON{}}, - {source: "{\"a\": 1}", result: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Valid: true}}, - } - for i, tt := range successfulTests { - var r pgtype.JSON - err := r.UnmarshalJSON([]byte(tt.source)) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if string(r.Bytes) != string(tt.result.Bytes) || r.Valid != tt.result.Valid { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } + testPgxCodec(t, "json", []PgxTranscodeTestCase{ + {[]byte("{}"), new([]byte), isExpectedEqBytes([]byte("{}"))}, + {[]byte("null"), new([]byte), isExpectedEqBytes([]byte("null"))}, + {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, + {[]byte(`"hello"`), new([]byte), isExpectedEqBytes([]byte(`"hello"`))}, + {[]byte(`"hello"`), new(string), isExpectedEq(`"hello"`)}, + {map[string]interface{}{"foo": "bar"}, new(map[string]interface{}), isExpectedEqMap(map[string]interface{}{"foo": "bar"})}, + {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, + {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, + }) } diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 38d56499..6e329150 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -2,35 +2,64 @@ package pgtype import ( "database/sql/driver" + "encoding/json" "fmt" ) -type JSONB JSON +type JSONBCodec struct{} -func (dst *JSONB) Set(src interface{}) error { - return (*JSON)(dst).Set(src) +func (JSONBCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode } -func (dst JSONB) Get() interface{} { - return (JSON)(dst).Get() -} - -func (src *JSONB) AssignTo(dst interface{}) error { - return (*JSON)(src).AssignTo(dst) -} - -func (JSONB) PreferredResultFormat() int16 { +func (JSONBCodec) PreferredFormat() int16 { return TextFormatCode } -func (dst *JSONB) DecodeText(ci *ConnInfo, src []byte) error { - return (*JSON)(dst).DecodeText(ci, src) +func (JSONBCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + plan := JSONCodec{}.PlanEncode(ci, oid, TextFormatCode, value) + if plan != nil { + return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan} + } + case TextFormatCode: + return JSONCodec{}.PlanEncode(ci, oid, format, value) + } + + return nil } -func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { +type encodePlanJSONBCodecBinaryWrapper struct { + textPlan EncodePlan +} + +func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + buf = append(buf, 1) + return plan.textPlan.Encode(value, buf) +} + +func (JSONBCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case BinaryFormatCode: + plan := JSONCodec{}.PlanScan(ci, oid, TextFormatCode, target, actualTarget) + if plan != nil { + return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan} + } + case TextFormatCode: + return JSONCodec{}.PlanScan(ci, oid, format, target, actualTarget) + } + + return nil +} + +type scanPlanJSONBCodecBinaryUnwrapper struct { + textPlan ScanPlan +} + +func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { - *dst = JSONB{} - return nil + return plan.textPlan.Scan(ci, oid, formatCode, src, dst) } if len(src) == 0 { @@ -41,42 +70,58 @@ func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { return fmt.Errorf("unknown jsonb version number %d", src[0]) } - *dst = JSONB{Bytes: src[1:], Valid: true} - return nil - + return plan.textPlan.Scan(ci, oid, formatCode, src[1:], dst) } -func (JSONB) PreferredParamFormat() int16 { - return TextFormatCode -} - -func (src JSONB) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (JSON)(src).EncodeText(ci, buf) -} - -func (src JSONB) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { +func (c JSONBCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { return nil, nil } - buf = append(buf, 1) - return append(buf, src.Bytes...), nil + switch format { + case BinaryFormatCode: + if len(src) == 0 { + return nil, fmt.Errorf("jsonb too short") + } + + if src[0] != 1 { + return nil, fmt.Errorf("unknown jsonb version number %d", src[0]) + } + + dstBuf := make([]byte, len(src)-1) + copy(dstBuf, src[1:]) + return dstBuf, nil + case TextFormatCode: + dstBuf := make([]byte, len(src)) + copy(dstBuf, src) + return dstBuf, nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } } -// Scan implements the database/sql Scanner interface. -func (dst *JSONB) Scan(src interface{}) error { - return (*JSON)(dst).Scan(src) -} +func (c JSONBCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } -// Value implements the database/sql/driver Valuer interface. -func (src JSONB) Value() (driver.Value, error) { - return (JSON)(src).Value() -} + switch format { + case BinaryFormatCode: + if len(src) == 0 { + return nil, fmt.Errorf("jsonb too short") + } -func (src JSONB) MarshalJSON() ([]byte, error) { - return (JSON)(src).MarshalJSON() -} + if src[0] != 1 { + return nil, fmt.Errorf("unknown jsonb version number %d", src[0]) + } -func (dst *JSONB) UnmarshalJSON(b []byte) error { - return (*JSON)(dst).UnmarshalJSON(b) + src = src[1:] + case TextFormatCode: + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } + + var dst interface{} + err := json.Unmarshal(src, &dst) + return dst, err } diff --git a/pgtype/jsonb_array.go b/pgtype/jsonb_array.go deleted file mode 100644 index 81ed9f29..00000000 --- a/pgtype/jsonb_array.go +++ /dev/null @@ -1,504 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type JSONBArray struct { - Elements []JSONB - Dimensions []ArrayDimension - Valid bool -} - -func (dst *JSONBArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = JSONBArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []string: - if value == nil { - *dst = JSONBArray{} - } else if len(value) == 0 { - *dst = JSONBArray{Valid: true} - } else { - elements := make([]JSONB, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = JSONBArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case [][]byte: - if value == nil { - *dst = JSONBArray{} - } else if len(value) == 0 { - *dst = JSONBArray{Valid: true} - } else { - elements := make([]JSONB, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = JSONBArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []JSONB: - if value == nil { - *dst = JSONBArray{} - } else if len(value) == 0 { - *dst = JSONBArray{Valid: true} - } else { - *dst = JSONBArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = JSONBArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for JSONBArray", src) - } - if elementsLength == 0 { - *dst = JSONBArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to JSONBArray", src) - } - - *dst = JSONBArray{ - Elements: make([]JSONB, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]JSONB, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to JSONBArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *JSONBArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to JSONBArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in JSONBArray", err) - } - index++ - - return index, nil -} - -func (dst JSONBArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *JSONBArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from JSONBArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from JSONBArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *JSONBArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = JSONBArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []JSONB - - if len(uta.Elements) > 0 { - elements = make([]JSONB, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem JSONB - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = JSONBArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *JSONBArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = JSONBArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = JSONBArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]JSONB, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = JSONBArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src JSONBArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src JSONBArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("jsonb"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "jsonb") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *JSONBArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src JSONBArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/jsonb_array_test.go b/pgtype/jsonb_array_test.go deleted file mode 100644 index 0fc4d40e..00000000 --- a/pgtype/jsonb_array_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestJSONBArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "jsonb[]", []interface{}{ - &pgtype.JSONBArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.JSONBArray{ - Elements: []pgtype.JSONB{ - {Bytes: []byte(`"foo"`), Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.JSONBArray{}, - &pgtype.JSONBArray{ - Elements: []pgtype.JSONB{ - {Bytes: []byte(`"foo"`), Valid: true}, - {Bytes: []byte("null"), Valid: true}, - {Bytes: []byte("42"), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, - Valid: true, - }, - }) -} diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index 3a0d62c2..282caeb1 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -1,142 +1,25 @@ package pgtype_test import ( - "bytes" - "reflect" "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestJSONBTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - if _, ok := conn.ConnInfo().DataTypeForName("jsonb"); !ok { - t.Skip("Skipping due to no jsonb type") - } - - testutil.TestSuccessfulTranscode(t, "jsonb", []interface{}{ - &pgtype.JSONB{Bytes: []byte("{}"), Valid: true}, - &pgtype.JSONB{Bytes: []byte("null"), Valid: true}, - &pgtype.JSONB{Bytes: []byte("42"), Valid: true}, - &pgtype.JSONB{Bytes: []byte(`"hello"`), Valid: true}, - &pgtype.JSONB{}, - }) -} - -func TestJSONBSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.JSONB - }{ - {source: "{}", result: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}}, - {source: []byte("{}"), result: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}}, - {source: ([]byte)(nil), result: pgtype.JSONB{}}, - {source: (*string)(nil), result: pgtype.JSONB{}}, - {source: []int{1, 2, 3}, result: pgtype.JSONB{Bytes: []byte("[1,2,3]"), Valid: true}}, - {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Valid: true}}, - } - - for i, tt := range successfulTests { - var d pgtype.JSONB - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(d, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } -} - -func TestJSONBAssignTo(t *testing.T) { - var s string - var ps *string - var b []byte - - rawStringTests := []struct { - src pgtype.JSONB - dst *string - expected string - }{ - {src: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}, dst: &s, expected: "{}"}, - } - - for i, tt := range rawStringTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if *tt.dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } - - rawBytesTests := []struct { - src pgtype.JSONB - dst *[]byte - expected []byte - }{ - {src: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}, dst: &b, expected: []byte("{}")}, - {src: pgtype.JSONB{}, dst: &b, expected: (([]byte)(nil))}, - } - - for i, tt := range rawBytesTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if bytes.Compare(tt.expected, *tt.dst) != 0 { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } - - var mapDst map[string]interface{} - type structDst struct { + type jsonStruct struct { Name string `json:"name"` Age int `json:"age"` } - var strDst structDst - unmarshalTests := []struct { - src pgtype.JSONB - dst interface{} - expected interface{} - }{ - {src: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Valid: true}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, - {src: pgtype.JSONB{Bytes: []byte(`{"name":"John","age":42}`), Valid: true}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, - } - for i, tt := range unmarshalTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.JSONB - dst **string - expected *string - }{ - {src: pgtype.JSONB{}, dst: &ps, expected: ((*string)(nil))}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if *tt.dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } + testPgxCodec(t, "jsonb", []PgxTranscodeTestCase{ + {[]byte("{}"), new([]byte), isExpectedEqBytes([]byte("{}"))}, + {[]byte("null"), new([]byte), isExpectedEqBytes([]byte("null"))}, + {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, + {[]byte(`"hello"`), new([]byte), isExpectedEqBytes([]byte(`"hello"`))}, + {[]byte(`"hello"`), new(string), isExpectedEq(`"hello"`)}, + {map[string]interface{}{"foo": "bar"}, new(map[string]interface{}), isExpectedEqMap(map[string]interface{}{"foo": "bar"})}, + {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, + {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, + }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index a6d77356..744671ab 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -27,6 +27,7 @@ const ( XIDOID = 28 CIDOID = 29 JSONOID = 114 + JSONArrayOID = 199 PointOID = 600 LsegOID = 601 PathOID = 602 @@ -289,6 +290,8 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID}) ci.RegisterDataType(DataType{Value: &TimestamptzArray{}, Name: "_timestamptz", OID: TimestamptzArrayOID}) ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) + ci.RegisterDataType(DataType{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementCodec: JSONBCodec{}, ElementOID: JSONBOID}}) + ci.RegisterDataType(DataType{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementCodec: JSONCodec{}, ElementOID: JSONOID}}) ci.RegisterDataType(DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: VarcharOID}}) ci.RegisterDataType(DataType{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementCodec: BitsCodec{}, ElementOID: BitOID}}) ci.RegisterDataType(DataType{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementCodec: BitsCodec{}, ElementOID: VarbitOID}}) @@ -316,9 +319,8 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) // ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) ci.RegisterDataType(DataType{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) - ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) - ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) - ci.RegisterDataType(DataType{Value: &JSONBArray{}, Name: "_jsonb", OID: JSONBArrayOID}) + ci.RegisterDataType(DataType{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) + ci.RegisterDataType(DataType{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) ci.RegisterDataType(DataType{Name: "line", OID: LineOID, Codec: LineCodec{}}) ci.RegisterDataType(DataType{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) diff --git a/stdlib/sql.go b/stdlib/sql.go index cbb8544e..40693ded 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -670,25 +670,15 @@ func (r *Rows) Next(dest []driver.Value) error { err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) return d, err } - case pgtype.JSONOID: - var d pgtype.JSON + case pgtype.JSONOID, pgtype.JSONBOID: + var d []byte scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) if err != nil { return nil, err } - return d.Value() - } - case pgtype.JSONBOID: - var d pgtype.JSONB - scanPlan := ci.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) - if err != nil { - return nil, err - } - return d.Value() + return d, nil } case pgtype.TimestampOID: var d pgtype.Timestamp diff --git a/values.go b/values.go index a60d4129..b5ce4f7c 100644 --- a/values.go +++ b/values.go @@ -35,32 +35,6 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e } switch arg := arg.(type) { - - // https://github.com/jackc/pgx/issues/409 Changed JSON and JSONB to surface - // []byte to database/sql instead of string. But that caused problems with the - // simple protocol because the driver.Valuer case got taken before the - // pgtype.TextEncoder case. And driver.Valuer needed to be first in the usual - // case because of https://github.com/jackc/pgx/issues/339. So instead we - // special case JSON and JSONB. - case *pgtype.JSON: - buf, err := arg.EncodeText(ci, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil - case *pgtype.JSONB: - buf, err := arg.EncodeText(ci, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil - case driver.Valuer: return callValuerValue(arg) case pgtype.TextEncoder: