mirror of https://github.com/jackc/pgx.git
Add support for custom JSON marshal and unmarshal.
The Codec interface is now implemented by *pgtype.JSONCodec and *pgtype.JSONBCodec instead of pgtype.JSONCodec and pgtype.JSONBCodec, respectively. This is technically a breaking change, but it is extremely unlikely that anyone is depending on this, and if there is downstream breakage it is trivial to fix. Fixes #2005.pull/2026/head
parent
e1b90cf620
commit
732889728f
|
@ -8,17 +8,20 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JSONCodec struct{}
|
type JSONCodec struct {
|
||||||
|
Marshal func(v any) ([]byte, error)
|
||||||
|
Unmarshal func(data []byte, v any) error
|
||||||
|
}
|
||||||
|
|
||||||
func (JSONCodec) FormatSupported(format int16) bool {
|
func (*JSONCodec) FormatSupported(format int16) bool {
|
||||||
return format == TextFormatCode || format == BinaryFormatCode
|
return format == TextFormatCode || format == BinaryFormatCode
|
||||||
}
|
}
|
||||||
|
|
||||||
func (JSONCodec) PreferredFormat() int16 {
|
func (*JSONCodec) PreferredFormat() int16 {
|
||||||
return TextFormatCode
|
return TextFormatCode
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
||||||
switch value.(type) {
|
switch value.(type) {
|
||||||
case string:
|
case string:
|
||||||
return encodePlanJSONCodecEitherFormatString{}
|
return encodePlanJSONCodecEitherFormatString{}
|
||||||
|
@ -44,7 +47,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
|
||||||
//
|
//
|
||||||
// https://github.com/jackc/pgx/issues/1681
|
// https://github.com/jackc/pgx/issues/1681
|
||||||
case json.Marshaler:
|
case json.Marshaler:
|
||||||
return encodePlanJSONCodecEitherFormatMarshal{}
|
return &encodePlanJSONCodecEitherFormatMarshal{
|
||||||
|
marshal: c.Marshal,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the
|
// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the
|
||||||
|
@ -61,7 +66,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return encodePlanJSONCodecEitherFormatMarshal{}
|
return &encodePlanJSONCodecEitherFormatMarshal{
|
||||||
|
marshal: c.Marshal,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type encodePlanJSONCodecEitherFormatString struct{}
|
type encodePlanJSONCodecEitherFormatString struct{}
|
||||||
|
@ -96,10 +103,12 @@ func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byt
|
||||||
return buf, nil
|
return buf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type encodePlanJSONCodecEitherFormatMarshal struct{}
|
type encodePlanJSONCodecEitherFormatMarshal struct {
|
||||||
|
marshal func(v any) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||||
jsonBytes, err := json.Marshal(value)
|
jsonBytes, err := e.marshal(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -108,7 +117,7 @@ func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (new
|
||||||
return buf, nil
|
return buf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
case *string:
|
case *string:
|
||||||
return scanPlanAnyToString{}
|
return scanPlanAnyToString{}
|
||||||
|
@ -141,7 +150,9 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan
|
||||||
return &scanPlanSQLScanner{formatCode: format}
|
return &scanPlanSQLScanner{formatCode: format}
|
||||||
}
|
}
|
||||||
|
|
||||||
return scanPlanJSONToJSONUnmarshal{}
|
return &scanPlanJSONToJSONUnmarshal{
|
||||||
|
unmarshal: c.Unmarshal,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type scanPlanAnyToString struct{}
|
type scanPlanAnyToString struct{}
|
||||||
|
@ -173,9 +184,11 @@ func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error {
|
||||||
return scanner.ScanBytes(src)
|
return scanner.ScanBytes(src)
|
||||||
}
|
}
|
||||||
|
|
||||||
type scanPlanJSONToJSONUnmarshal struct{}
|
type scanPlanJSONToJSONUnmarshal struct {
|
||||||
|
unmarshal func(data []byte, v any) error
|
||||||
|
}
|
||||||
|
|
||||||
func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
|
func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
dstValue := reflect.ValueOf(dst)
|
dstValue := reflect.ValueOf(dst)
|
||||||
if dstValue.Kind() == reflect.Ptr {
|
if dstValue.Kind() == reflect.Ptr {
|
||||||
|
@ -193,10 +206,10 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
|
||||||
elem := reflect.ValueOf(dst).Elem()
|
elem := reflect.ValueOf(dst).Elem()
|
||||||
elem.Set(reflect.Zero(elem.Type()))
|
elem.Set(reflect.Zero(elem.Type()))
|
||||||
|
|
||||||
return json.Unmarshal(src, dst)
|
return s.unmarshal(src, dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
func (c *JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -206,12 +219,12 @@ func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
|
||||||
return dstBuf, nil
|
return dstBuf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
func (c *JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var dst any
|
var dst any
|
||||||
err := json.Unmarshal(src, &dst)
|
err := c.Unmarshal(src, &dst)
|
||||||
return dst, err
|
return dst, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,9 +6,11 @@ import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
pgx "github.com/jackc/pgx/v5"
|
pgx "github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
"github.com/jackc/pgx/v5/pgxtest"
|
"github.com/jackc/pgx/v5/pgxtest"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
@ -224,3 +226,28 @@ func TestJSONCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) {
|
||||||
require.Equal(t, `{"custom":"thing"}`, jsonStr)
|
require.Equal(t, `{"custom":"thing"}`, jsonStr)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJSONCodecCustomMarshal(t *testing.T) {
|
||||||
|
skipCockroachDB(t, "CockroachDB treats json as jsonb. This causes it to format differently than PostgreSQL.")
|
||||||
|
|
||||||
|
connTestRunner := defaultConnTestRunner
|
||||||
|
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
conn.TypeMap().RegisterType(&pgtype.Type{
|
||||||
|
Name: "json", OID: pgtype.JSONOID, Codec: &pgtype.JSONCodec{
|
||||||
|
Marshal: func(v any) ([]byte, error) {
|
||||||
|
return []byte(`{"custom":"value"}`), nil
|
||||||
|
},
|
||||||
|
Unmarshal: func(data []byte, v any) error {
|
||||||
|
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
|
||||||
|
},
|
||||||
|
}})
|
||||||
|
}
|
||||||
|
|
||||||
|
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
||||||
|
// There is no space between "custom" and "value" in json type.
|
||||||
|
{map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom":"value"}`)},
|
||||||
|
{[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool {
|
||||||
|
return reflect.DeepEqual(v, map[string]any{"custom": "value"})
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -2,29 +2,31 @@ package pgtype
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JSONBCodec struct{}
|
type JSONBCodec struct {
|
||||||
|
Marshal func(v any) ([]byte, error)
|
||||||
|
Unmarshal func(data []byte, v any) error
|
||||||
|
}
|
||||||
|
|
||||||
func (JSONBCodec) FormatSupported(format int16) bool {
|
func (*JSONBCodec) FormatSupported(format int16) bool {
|
||||||
return format == TextFormatCode || format == BinaryFormatCode
|
return format == TextFormatCode || format == BinaryFormatCode
|
||||||
}
|
}
|
||||||
|
|
||||||
func (JSONBCodec) PreferredFormat() int16 {
|
func (*JSONBCodec) PreferredFormat() int16 {
|
||||||
return TextFormatCode
|
return TextFormatCode
|
||||||
}
|
}
|
||||||
|
|
||||||
func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
func (c *JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
plan := JSONCodec{}.PlanEncode(m, oid, TextFormatCode, value)
|
plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, TextFormatCode, value)
|
||||||
if plan != nil {
|
if plan != nil {
|
||||||
return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan}
|
return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan}
|
||||||
}
|
}
|
||||||
case TextFormatCode:
|
case TextFormatCode:
|
||||||
return JSONCodec{}.PlanEncode(m, oid, format, value)
|
return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, format, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -39,15 +41,15 @@ func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value any, buf []byte) (ne
|
||||||
return plan.textPlan.Encode(value, buf)
|
return plan.textPlan.Encode(value, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (c *JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target)
|
plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, TextFormatCode, target)
|
||||||
if plan != nil {
|
if plan != nil {
|
||||||
return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan}
|
return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan}
|
||||||
}
|
}
|
||||||
case TextFormatCode:
|
case TextFormatCode:
|
||||||
return JSONCodec{}.PlanScan(m, oid, format, target)
|
return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, format, target)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -73,7 +75,7 @@ func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst any) error {
|
||||||
return plan.textPlan.Scan(src[1:], dst)
|
return plan.textPlan.Scan(src[1:], dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
func (c *JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -100,7 +102,7 @@ func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
func (c *JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -122,6 +124,6 @@ func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (a
|
||||||
}
|
}
|
||||||
|
|
||||||
var dst any
|
var dst any
|
||||||
err := json.Unmarshal(src, &dst)
|
err := c.Unmarshal(src, &dst)
|
||||||
return dst, err
|
return dst, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,9 +2,12 @@ package pgtype_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
pgx "github.com/jackc/pgx/v5"
|
pgx "github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
"github.com/jackc/pgx/v5/pgxtest"
|
"github.com/jackc/pgx/v5/pgxtest"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
@ -80,3 +83,26 @@ func TestJSONBCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) {
|
||||||
require.Equal(t, `{"custom": "thing"}`, jsonStr) // Note that unlike json, jsonb reformats the JSON string.
|
require.Equal(t, `{"custom": "thing"}`, jsonStr) // Note that unlike json, jsonb reformats the JSON string.
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJSONBCodecCustomMarshal(t *testing.T) {
|
||||||
|
connTestRunner := defaultConnTestRunner
|
||||||
|
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
conn.TypeMap().RegisterType(&pgtype.Type{
|
||||||
|
Name: "jsonb", OID: pgtype.JSONBOID, Codec: &pgtype.JSONBCodec{
|
||||||
|
Marshal: func(v any) ([]byte, error) {
|
||||||
|
return []byte(`{"custom":"value"}`), nil
|
||||||
|
},
|
||||||
|
Unmarshal: func(data []byte, v any) error {
|
||||||
|
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
|
||||||
|
},
|
||||||
|
}})
|
||||||
|
}
|
||||||
|
|
||||||
|
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{
|
||||||
|
// There is space between "custom" and "value" in jsonb type.
|
||||||
|
{map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom": "value"}`)},
|
||||||
|
{[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool {
|
||||||
|
return reflect.DeepEqual(v, map[string]any{"custom": "value"})
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -65,8 +65,8 @@ func initDefaultMap() {
|
||||||
defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}})
|
defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}})
|
||||||
defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}})
|
defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}})
|
||||||
defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}})
|
defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}})
|
||||||
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}})
|
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: &JSONCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
|
||||||
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}})
|
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: &JSONBCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
|
||||||
defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
|
defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
|
||||||
defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
|
defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
|
||||||
defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})
|
defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})
|
||||||
|
|
Loading…
Reference in New Issue