mirror of https://github.com/jackc/pgx.git
Add support for custom JSON marshal and unmarshal.
The Codec interface is now implemented by *pgtype.JSONCodec and *pgtype.JSONBCodec instead of pgtype.JSONCodec and pgtype.JSONBCodec, respectively. This is technically a breaking change, but it is extremely unlikely that anyone is depending on this, and if there is downstream breakage it is trivial to fix. Fixes #2005.pull/2026/head
parent
e1b90cf620
commit
732889728f
|
@ -8,17 +8,20 @@ import (
|
|||
"reflect"
|
||||
)
|
||||
|
||||
type JSONCodec struct{}
|
||||
type JSONCodec struct {
|
||||
Marshal func(v any) ([]byte, error)
|
||||
Unmarshal func(data []byte, v any) error
|
||||
}
|
||||
|
||||
func (JSONCodec) FormatSupported(format int16) bool {
|
||||
func (*JSONCodec) FormatSupported(format int16) bool {
|
||||
return format == TextFormatCode || format == BinaryFormatCode
|
||||
}
|
||||
|
||||
func (JSONCodec) PreferredFormat() int16 {
|
||||
func (*JSONCodec) PreferredFormat() int16 {
|
||||
return TextFormatCode
|
||||
}
|
||||
|
||||
func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
||||
func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
||||
switch value.(type) {
|
||||
case string:
|
||||
return encodePlanJSONCodecEitherFormatString{}
|
||||
|
@ -44,7 +47,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
|
|||
//
|
||||
// https://github.com/jackc/pgx/issues/1681
|
||||
case json.Marshaler:
|
||||
return encodePlanJSONCodecEitherFormatMarshal{}
|
||||
return &encodePlanJSONCodecEitherFormatMarshal{
|
||||
marshal: c.Marshal,
|
||||
}
|
||||
}
|
||||
|
||||
// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the
|
||||
|
@ -61,7 +66,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
|
|||
}
|
||||
}
|
||||
|
||||
return encodePlanJSONCodecEitherFormatMarshal{}
|
||||
return &encodePlanJSONCodecEitherFormatMarshal{
|
||||
marshal: c.Marshal,
|
||||
}
|
||||
}
|
||||
|
||||
type encodePlanJSONCodecEitherFormatString struct{}
|
||||
|
@ -96,10 +103,12 @@ func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byt
|
|||
return buf, nil
|
||||
}
|
||||
|
||||
type encodePlanJSONCodecEitherFormatMarshal struct{}
|
||||
type encodePlanJSONCodecEitherFormatMarshal struct {
|
||||
marshal func(v any) ([]byte, error)
|
||||
}
|
||||
|
||||
func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
jsonBytes, err := json.Marshal(value)
|
||||
func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
jsonBytes, err := e.marshal(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -108,7 +117,7 @@ func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (new
|
|||
return buf, nil
|
||||
}
|
||||
|
||||
func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||
func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||
switch target.(type) {
|
||||
case *string:
|
||||
return scanPlanAnyToString{}
|
||||
|
@ -141,7 +150,9 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan
|
|||
return &scanPlanSQLScanner{formatCode: format}
|
||||
}
|
||||
|
||||
return scanPlanJSONToJSONUnmarshal{}
|
||||
return &scanPlanJSONToJSONUnmarshal{
|
||||
unmarshal: c.Unmarshal,
|
||||
}
|
||||
}
|
||||
|
||||
type scanPlanAnyToString struct{}
|
||||
|
@ -173,9 +184,11 @@ func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error {
|
|||
return scanner.ScanBytes(src)
|
||||
}
|
||||
|
||||
type scanPlanJSONToJSONUnmarshal struct{}
|
||||
type scanPlanJSONToJSONUnmarshal struct {
|
||||
unmarshal func(data []byte, v any) error
|
||||
}
|
||||
|
||||
func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
|
||||
func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
|
||||
if src == nil {
|
||||
dstValue := reflect.ValueOf(dst)
|
||||
if dstValue.Kind() == reflect.Ptr {
|
||||
|
@ -193,10 +206,10 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
|
|||
elem := reflect.ValueOf(dst).Elem()
|
||||
elem.Set(reflect.Zero(elem.Type()))
|
||||
|
||||
return json.Unmarshal(src, dst)
|
||||
return s.unmarshal(src, dst)
|
||||
}
|
||||
|
||||
func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||
func (c *JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||
if src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -206,12 +219,12 @@ func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
|
|||
return dstBuf, nil
|
||||
}
|
||||
|
||||
func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||
func (c *JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||
if src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var dst any
|
||||
err := json.Unmarshal(src, &dst)
|
||||
err := c.Unmarshal(src, &dst)
|
||||
return dst, err
|
||||
}
|
||||
|
|
|
@ -6,9 +6,11 @@ import (
|
|||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
pgx "github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxtest"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -224,3 +226,28 @@ func TestJSONCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) {
|
|||
require.Equal(t, `{"custom":"thing"}`, jsonStr)
|
||||
})
|
||||
}
|
||||
|
||||
func TestJSONCodecCustomMarshal(t *testing.T) {
|
||||
skipCockroachDB(t, "CockroachDB treats json as jsonb. This causes it to format differently than PostgreSQL.")
|
||||
|
||||
connTestRunner := defaultConnTestRunner
|
||||
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
conn.TypeMap().RegisterType(&pgtype.Type{
|
||||
Name: "json", OID: pgtype.JSONOID, Codec: &pgtype.JSONCodec{
|
||||
Marshal: func(v any) ([]byte, error) {
|
||||
return []byte(`{"custom":"value"}`), nil
|
||||
},
|
||||
Unmarshal: func(data []byte, v any) error {
|
||||
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
|
||||
},
|
||||
}})
|
||||
}
|
||||
|
||||
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
||||
// There is no space between "custom" and "value" in json type.
|
||||
{map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom":"value"}`)},
|
||||
{[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool {
|
||||
return reflect.DeepEqual(v, map[string]any{"custom": "value"})
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
|
|
@ -2,29 +2,31 @@ package pgtype
|
|||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type JSONBCodec struct{}
|
||||
type JSONBCodec struct {
|
||||
Marshal func(v any) ([]byte, error)
|
||||
Unmarshal func(data []byte, v any) error
|
||||
}
|
||||
|
||||
func (JSONBCodec) FormatSupported(format int16) bool {
|
||||
func (*JSONBCodec) FormatSupported(format int16) bool {
|
||||
return format == TextFormatCode || format == BinaryFormatCode
|
||||
}
|
||||
|
||||
func (JSONBCodec) PreferredFormat() int16 {
|
||||
func (*JSONBCodec) PreferredFormat() int16 {
|
||||
return TextFormatCode
|
||||
}
|
||||
|
||||
func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
||||
func (c *JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
||||
switch format {
|
||||
case BinaryFormatCode:
|
||||
plan := JSONCodec{}.PlanEncode(m, oid, TextFormatCode, value)
|
||||
plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, TextFormatCode, value)
|
||||
if plan != nil {
|
||||
return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan}
|
||||
}
|
||||
case TextFormatCode:
|
||||
return JSONCodec{}.PlanEncode(m, oid, format, value)
|
||||
return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, format, value)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -39,15 +41,15 @@ func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value any, buf []byte) (ne
|
|||
return plan.textPlan.Encode(value, buf)
|
||||
}
|
||||
|
||||
func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||
func (c *JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||
switch format {
|
||||
case BinaryFormatCode:
|
||||
plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target)
|
||||
plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, TextFormatCode, target)
|
||||
if plan != nil {
|
||||
return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan}
|
||||
}
|
||||
case TextFormatCode:
|
||||
return JSONCodec{}.PlanScan(m, oid, format, target)
|
||||
return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, format, target)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -73,7 +75,7 @@ func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst any) error {
|
|||
return plan.textPlan.Scan(src[1:], dst)
|
||||
}
|
||||
|
||||
func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||
func (c *JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||
if src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -100,7 +102,7 @@ func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
|
|||
}
|
||||
}
|
||||
|
||||
func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||
func (c *JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||
if src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -122,6 +124,6 @@ func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (a
|
|||
}
|
||||
|
||||
var dst any
|
||||
err := json.Unmarshal(src, &dst)
|
||||
err := c.Unmarshal(src, &dst)
|
||||
return dst, err
|
||||
}
|
||||
|
|
|
@ -2,9 +2,12 @@ package pgtype_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
pgx "github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxtest"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -80,3 +83,26 @@ func TestJSONBCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) {
|
|||
require.Equal(t, `{"custom": "thing"}`, jsonStr) // Note that unlike json, jsonb reformats the JSON string.
|
||||
})
|
||||
}
|
||||
|
||||
func TestJSONBCodecCustomMarshal(t *testing.T) {
|
||||
connTestRunner := defaultConnTestRunner
|
||||
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
conn.TypeMap().RegisterType(&pgtype.Type{
|
||||
Name: "jsonb", OID: pgtype.JSONBOID, Codec: &pgtype.JSONBCodec{
|
||||
Marshal: func(v any) ([]byte, error) {
|
||||
return []byte(`{"custom":"value"}`), nil
|
||||
},
|
||||
Unmarshal: func(data []byte, v any) error {
|
||||
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
|
||||
},
|
||||
}})
|
||||
}
|
||||
|
||||
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{
|
||||
// There is space between "custom" and "value" in jsonb type.
|
||||
{map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom": "value"}`)},
|
||||
{[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool {
|
||||
return reflect.DeepEqual(v, map[string]any{"custom": "value"})
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
|
|
@ -65,8 +65,8 @@ func initDefaultMap() {
|
|||
defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}})
|
||||
defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}})
|
||||
defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}})
|
||||
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}})
|
||||
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}})
|
||||
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: &JSONCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
|
||||
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: &JSONBCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
|
||||
defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
|
||||
defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
|
||||
defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})
|
||||
|
|
Loading…
Reference in New Issue