Add support for custom JSON marshal and unmarshal.

The Codec interface is now implemented by *pgtype.JSONCodec
and *pgtype.JSONBCodec instead of pgtype.JSONCodec and
pgtype.JSONBCodec, respectively. This is technically a breaking
change, but it is extremely unlikely that anyone is depending on this,
and if there is downstream breakage it is trivial to fix.

Fixes #2005.
pull/2026/head
Mitar 2024-05-13 17:36:42 +02:00 committed by Jack Christensen
parent e1b90cf620
commit 732889728f
5 changed files with 100 additions and 32 deletions

View File

@ -8,17 +8,20 @@ import (
"reflect" "reflect"
) )
type JSONCodec struct{} type JSONCodec struct {
Marshal func(v any) ([]byte, error)
Unmarshal func(data []byte, v any) error
}
func (JSONCodec) FormatSupported(format int16) bool { func (*JSONCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode return format == TextFormatCode || format == BinaryFormatCode
} }
func (JSONCodec) PreferredFormat() int16 { func (*JSONCodec) PreferredFormat() int16 {
return TextFormatCode return TextFormatCode
} }
func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch value.(type) { switch value.(type) {
case string: case string:
return encodePlanJSONCodecEitherFormatString{} return encodePlanJSONCodecEitherFormatString{}
@ -44,7 +47,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
// //
// https://github.com/jackc/pgx/issues/1681 // https://github.com/jackc/pgx/issues/1681
case json.Marshaler: case json.Marshaler:
return encodePlanJSONCodecEitherFormatMarshal{} return &encodePlanJSONCodecEitherFormatMarshal{
marshal: c.Marshal,
}
} }
// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the
@ -61,7 +66,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
} }
} }
return encodePlanJSONCodecEitherFormatMarshal{} return &encodePlanJSONCodecEitherFormatMarshal{
marshal: c.Marshal,
}
} }
type encodePlanJSONCodecEitherFormatString struct{} type encodePlanJSONCodecEitherFormatString struct{}
@ -96,10 +103,12 @@ func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byt
return buf, nil return buf, nil
} }
type encodePlanJSONCodecEitherFormatMarshal struct{} type encodePlanJSONCodecEitherFormatMarshal struct {
marshal func(v any) ([]byte, error)
}
func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
jsonBytes, err := json.Marshal(value) jsonBytes, err := e.marshal(value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -108,7 +117,7 @@ func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (new
return buf, nil return buf, nil
} }
func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch target.(type) { switch target.(type) {
case *string: case *string:
return scanPlanAnyToString{} return scanPlanAnyToString{}
@ -141,7 +150,9 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan
return &scanPlanSQLScanner{formatCode: format} return &scanPlanSQLScanner{formatCode: format}
} }
return scanPlanJSONToJSONUnmarshal{} return &scanPlanJSONToJSONUnmarshal{
unmarshal: c.Unmarshal,
}
} }
type scanPlanAnyToString struct{} type scanPlanAnyToString struct{}
@ -173,9 +184,11 @@ func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error {
return scanner.ScanBytes(src) return scanner.ScanBytes(src)
} }
type scanPlanJSONToJSONUnmarshal struct{} type scanPlanJSONToJSONUnmarshal struct {
unmarshal func(data []byte, v any) error
}
func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
if src == nil { if src == nil {
dstValue := reflect.ValueOf(dst) dstValue := reflect.ValueOf(dst)
if dstValue.Kind() == reflect.Ptr { if dstValue.Kind() == reflect.Ptr {
@ -193,10 +206,10 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
elem := reflect.ValueOf(dst).Elem() elem := reflect.ValueOf(dst).Elem()
elem.Set(reflect.Zero(elem.Type())) elem.Set(reflect.Zero(elem.Type()))
return json.Unmarshal(src, dst) return s.unmarshal(src, dst)
} }
func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { func (c *JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil { if src == nil {
return nil, nil return nil, nil
} }
@ -206,12 +219,12 @@ func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
return dstBuf, nil return dstBuf, nil
} }
func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { func (c *JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil { if src == nil {
return nil, nil return nil, nil
} }
var dst any var dst any
err := json.Unmarshal(src, &dst) err := c.Unmarshal(src, &dst)
return dst, err return dst, err
} }

View File

@ -6,9 +6,11 @@ import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "errors"
"reflect"
"testing" "testing"
pgx "github.com/jackc/pgx/v5" pgx "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxtest" "github.com/jackc/pgx/v5/pgxtest"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -224,3 +226,28 @@ func TestJSONCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) {
require.Equal(t, `{"custom":"thing"}`, jsonStr) require.Equal(t, `{"custom":"thing"}`, jsonStr)
}) })
} }
func TestJSONCodecCustomMarshal(t *testing.T) {
skipCockroachDB(t, "CockroachDB treats json as jsonb. This causes it to format differently than PostgreSQL.")
connTestRunner := defaultConnTestRunner
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
conn.TypeMap().RegisterType(&pgtype.Type{
Name: "json", OID: pgtype.JSONOID, Codec: &pgtype.JSONCodec{
Marshal: func(v any) ([]byte, error) {
return []byte(`{"custom":"value"}`), nil
},
Unmarshal: func(data []byte, v any) error {
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
},
}})
}
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
// There is no space between "custom" and "value" in json type.
{map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom":"value"}`)},
{[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool {
return reflect.DeepEqual(v, map[string]any{"custom": "value"})
}},
})
}

View File

@ -2,29 +2,31 @@ package pgtype
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json"
"fmt" "fmt"
) )
type JSONBCodec struct{} type JSONBCodec struct {
Marshal func(v any) ([]byte, error)
Unmarshal func(data []byte, v any) error
}
func (JSONBCodec) FormatSupported(format int16) bool { func (*JSONBCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode return format == TextFormatCode || format == BinaryFormatCode
} }
func (JSONBCodec) PreferredFormat() int16 { func (*JSONBCodec) PreferredFormat() int16 {
return TextFormatCode return TextFormatCode
} }
func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { func (c *JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
plan := JSONCodec{}.PlanEncode(m, oid, TextFormatCode, value) plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, TextFormatCode, value)
if plan != nil { if plan != nil {
return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan} return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan}
} }
case TextFormatCode: case TextFormatCode:
return JSONCodec{}.PlanEncode(m, oid, format, value) return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, format, value)
} }
return nil return nil
@ -39,15 +41,15 @@ func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value any, buf []byte) (ne
return plan.textPlan.Encode(value, buf) return plan.textPlan.Encode(value, buf)
} }
func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (c *JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target) plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, TextFormatCode, target)
if plan != nil { if plan != nil {
return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan} return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan}
} }
case TextFormatCode: case TextFormatCode:
return JSONCodec{}.PlanScan(m, oid, format, target) return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, format, target)
} }
return nil return nil
@ -73,7 +75,7 @@ func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst any) error {
return plan.textPlan.Scan(src[1:], dst) return plan.textPlan.Scan(src[1:], dst)
} }
func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { func (c *JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil { if src == nil {
return nil, nil return nil, nil
} }
@ -100,7 +102,7 @@ func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
} }
} }
func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { func (c *JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil { if src == nil {
return nil, nil return nil, nil
} }
@ -122,6 +124,6 @@ func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (a
} }
var dst any var dst any
err := json.Unmarshal(src, &dst) err := c.Unmarshal(src, &dst)
return dst, err return dst, err
} }

View File

@ -2,9 +2,12 @@ package pgtype_test
import ( import (
"context" "context"
"encoding/json"
"reflect"
"testing" "testing"
pgx "github.com/jackc/pgx/v5" pgx "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxtest" "github.com/jackc/pgx/v5/pgxtest"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -80,3 +83,26 @@ func TestJSONBCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) {
require.Equal(t, `{"custom": "thing"}`, jsonStr) // Note that unlike json, jsonb reformats the JSON string. require.Equal(t, `{"custom": "thing"}`, jsonStr) // Note that unlike json, jsonb reformats the JSON string.
}) })
} }
func TestJSONBCodecCustomMarshal(t *testing.T) {
connTestRunner := defaultConnTestRunner
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
conn.TypeMap().RegisterType(&pgtype.Type{
Name: "jsonb", OID: pgtype.JSONBOID, Codec: &pgtype.JSONBCodec{
Marshal: func(v any) ([]byte, error) {
return []byte(`{"custom":"value"}`), nil
},
Unmarshal: func(data []byte, v any) error {
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
},
}})
}
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{
// There is space between "custom" and "value" in jsonb type.
{map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom": "value"}`)},
{[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool {
return reflect.DeepEqual(v, map[string]any{"custom": "value"})
}},
})
}

View File

@ -65,8 +65,8 @@ func initDefaultMap() {
defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}})
defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}})
defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}})
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: &JSONCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: &JSONBCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})