Add support for custom JSON marshal and unmarshal.

The Codec interface is now implemented by *pgtype.JSONCodec
and *pgtype.JSONBCodec instead of pgtype.JSONCodec and
pgtype.JSONBCodec, respectively. This is technically a breaking
change, but it is extremely unlikely that anyone is depending on this,
and if there is downstream breakage it is trivial to fix.

Fixes #2005.
pull/2026/head
Mitar 2024-05-13 17:36:42 +02:00 committed by Jack Christensen
parent e1b90cf620
commit 732889728f
5 changed files with 100 additions and 32 deletions

View File

@ -8,17 +8,20 @@ import (
"reflect"
)
type JSONCodec struct{}
type JSONCodec struct {
Marshal func(v any) ([]byte, error)
Unmarshal func(data []byte, v any) error
}
func (JSONCodec) FormatSupported(format int16) bool {
func (*JSONCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}
func (JSONCodec) PreferredFormat() int16 {
func (*JSONCodec) PreferredFormat() int16 {
return TextFormatCode
}
func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch value.(type) {
case string:
return encodePlanJSONCodecEitherFormatString{}
@ -44,7 +47,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
//
// https://github.com/jackc/pgx/issues/1681
case json.Marshaler:
return encodePlanJSONCodecEitherFormatMarshal{}
return &encodePlanJSONCodecEitherFormatMarshal{
marshal: c.Marshal,
}
}
// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the
@ -61,7 +66,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
}
}
return encodePlanJSONCodecEitherFormatMarshal{}
return &encodePlanJSONCodecEitherFormatMarshal{
marshal: c.Marshal,
}
}
type encodePlanJSONCodecEitherFormatString struct{}
@ -96,10 +103,12 @@ func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byt
return buf, nil
}
type encodePlanJSONCodecEitherFormatMarshal struct{}
type encodePlanJSONCodecEitherFormatMarshal struct {
marshal func(v any) ([]byte, error)
}
func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
jsonBytes, err := json.Marshal(value)
func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
jsonBytes, err := e.marshal(value)
if err != nil {
return nil, err
}
@ -108,7 +117,7 @@ func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (new
return buf, nil
}
func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch target.(type) {
case *string:
return scanPlanAnyToString{}
@ -141,7 +150,9 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan
return &scanPlanSQLScanner{formatCode: format}
}
return scanPlanJSONToJSONUnmarshal{}
return &scanPlanJSONToJSONUnmarshal{
unmarshal: c.Unmarshal,
}
}
type scanPlanAnyToString struct{}
@ -173,9 +184,11 @@ func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error {
return scanner.ScanBytes(src)
}
type scanPlanJSONToJSONUnmarshal struct{}
type scanPlanJSONToJSONUnmarshal struct {
unmarshal func(data []byte, v any) error
}
func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
if src == nil {
dstValue := reflect.ValueOf(dst)
if dstValue.Kind() == reflect.Ptr {
@ -193,10 +206,10 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
elem := reflect.ValueOf(dst).Elem()
elem.Set(reflect.Zero(elem.Type()))
return json.Unmarshal(src, dst)
return s.unmarshal(src, dst)
}
func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
func (c *JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
}
@ -206,12 +219,12 @@ func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
return dstBuf, nil
}
func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
func (c *JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}
var dst any
err := json.Unmarshal(src, &dst)
err := c.Unmarshal(src, &dst)
return dst, err
}

View File

@ -6,9 +6,11 @@ import (
"database/sql/driver"
"encoding/json"
"errors"
"reflect"
"testing"
pgx "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxtest"
"github.com/stretchr/testify/require"
)
@ -224,3 +226,28 @@ func TestJSONCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) {
require.Equal(t, `{"custom":"thing"}`, jsonStr)
})
}
func TestJSONCodecCustomMarshal(t *testing.T) {
skipCockroachDB(t, "CockroachDB treats json as jsonb. This causes it to format differently than PostgreSQL.")
connTestRunner := defaultConnTestRunner
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
conn.TypeMap().RegisterType(&pgtype.Type{
Name: "json", OID: pgtype.JSONOID, Codec: &pgtype.JSONCodec{
Marshal: func(v any) ([]byte, error) {
return []byte(`{"custom":"value"}`), nil
},
Unmarshal: func(data []byte, v any) error {
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
},
}})
}
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
// There is no space between "custom" and "value" in json type.
{map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom":"value"}`)},
{[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool {
return reflect.DeepEqual(v, map[string]any{"custom": "value"})
}},
})
}

View File

@ -2,29 +2,31 @@ package pgtype
import (
"database/sql/driver"
"encoding/json"
"fmt"
)
type JSONBCodec struct{}
type JSONBCodec struct {
Marshal func(v any) ([]byte, error)
Unmarshal func(data []byte, v any) error
}
func (JSONBCodec) FormatSupported(format int16) bool {
func (*JSONBCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}
func (JSONBCodec) PreferredFormat() int16 {
func (*JSONBCodec) PreferredFormat() int16 {
return TextFormatCode
}
func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
func (c *JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch format {
case BinaryFormatCode:
plan := JSONCodec{}.PlanEncode(m, oid, TextFormatCode, value)
plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, TextFormatCode, value)
if plan != nil {
return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan}
}
case TextFormatCode:
return JSONCodec{}.PlanEncode(m, oid, format, value)
return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, format, value)
}
return nil
@ -39,15 +41,15 @@ func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value any, buf []byte) (ne
return plan.textPlan.Encode(value, buf)
}
func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
func (c *JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format {
case BinaryFormatCode:
plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target)
plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, TextFormatCode, target)
if plan != nil {
return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan}
}
case TextFormatCode:
return JSONCodec{}.PlanScan(m, oid, format, target)
return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, format, target)
}
return nil
@ -73,7 +75,7 @@ func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst any) error {
return plan.textPlan.Scan(src[1:], dst)
}
func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
func (c *JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
}
@ -100,7 +102,7 @@ func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
}
}
func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
func (c *JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}
@ -122,6 +124,6 @@ func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (a
}
var dst any
err := json.Unmarshal(src, &dst)
err := c.Unmarshal(src, &dst)
return dst, err
}

View File

@ -2,9 +2,12 @@ package pgtype_test
import (
"context"
"encoding/json"
"reflect"
"testing"
pgx "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxtest"
"github.com/stretchr/testify/require"
)
@ -80,3 +83,26 @@ func TestJSONBCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) {
require.Equal(t, `{"custom": "thing"}`, jsonStr) // Note that unlike json, jsonb reformats the JSON string.
})
}
func TestJSONBCodecCustomMarshal(t *testing.T) {
connTestRunner := defaultConnTestRunner
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
conn.TypeMap().RegisterType(&pgtype.Type{
Name: "jsonb", OID: pgtype.JSONBOID, Codec: &pgtype.JSONBCodec{
Marshal: func(v any) ([]byte, error) {
return []byte(`{"custom":"value"}`), nil
},
Unmarshal: func(data []byte, v any) error {
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
},
}})
}
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{
// There is space between "custom" and "value" in jsonb type.
{map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom": "value"}`)},
{[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool {
return reflect.DeepEqual(v, map[string]any{"custom": "value"})
}},
})
}

View File

@ -65,8 +65,8 @@ func initDefaultMap() {
defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}})
defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}})
defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}})
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}})
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}})
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: &JSONCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: &JSONBCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})