Merge pull request #2019 from jackc/fix-encode-driver-valuer-on-pointer

Fix encode driver.Valuer on pointer
pull/2035/head
Jack Christensen 2024-05-25 11:20:25 -05:00 committed by GitHub
commit b4911f1da7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 241 additions and 133 deletions

View File

@ -92,7 +92,7 @@ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.
## Supported Go and PostgreSQL Versions
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.20 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.21 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
## Version Policy

View File

@ -10,7 +10,6 @@ import (
"strings"
"time"
"github.com/jackc/pgx/v5/internal/anynil"
"github.com/jackc/pgx/v5/internal/sanitize"
"github.com/jackc/pgx/v5/internal/stmtcache"
"github.com/jackc/pgx/v5/pgconn"
@ -755,7 +754,6 @@ optionLoop:
}
c.eqb.reset()
anynil.NormalizeSlice(args)
rows := c.getRows(ctx, sql, args)
var err error

View File

@ -1,10 +1,8 @@
package pgx
import (
"database/sql/driver"
"fmt"
"github.com/jackc/pgx/v5/internal/anynil"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
)
@ -23,10 +21,15 @@ type ExtendedQueryBuilder struct {
func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
eqb.reset()
anynil.NormalizeSlice(args)
if sd == nil {
return eqb.appendParamsForQueryExecModeExec(m, args)
for i := range args {
err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i])
if err != nil {
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
return err
}
}
return nil
}
if len(sd.ParamOIDs) != len(args) {
@ -113,10 +116,6 @@ func (eqb *ExtendedQueryBuilder) reset() {
}
func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
if anynil.Is(arg) {
return nil, nil
}
if eqb.paramValueBytes == nil {
eqb.paramValueBytes = make([]byte, 0, 128)
}
@ -145,74 +144,3 @@ func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid ui
return m.FormatCodeForOID(oid)
}
// appendParamsForQueryExecModeExec appends the args to eqb.
//
// Parameters must be encoded in the text format because of differences in type conversion between timestamps and
// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the
// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both
// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL
// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date.
// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion
// before converting it to date. This means that dates can be shifted by one day. In text format without that double
// type conversion it takes the date directly and ignores time zone (i.e. it works).
//
// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is
// no way to safely use binary or to specify the parameter OIDs.
func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error {
for _, arg := range args {
if arg == nil {
err := eqb.appendParam(m, 0, TextFormatCode, arg)
if err != nil {
return err
}
} else {
dt, ok := m.TypeForValue(arg)
if !ok {
var tv pgtype.TextValuer
if tv, ok = arg.(pgtype.TextValuer); ok {
t, err := tv.TextValue()
if err != nil {
return err
}
dt, ok = m.TypeForOID(pgtype.TextOID)
if ok {
arg = t
}
}
}
if !ok {
var dv driver.Valuer
if dv, ok = arg.(driver.Valuer); ok {
v, err := dv.Value()
if err != nil {
return err
}
dt, ok = m.TypeForValue(v)
if ok {
arg = v
}
}
}
if !ok {
var str fmt.Stringer
if str, ok = arg.(fmt.Stringer); ok {
dt, ok = m.TypeForOID(pgtype.TextOID)
if ok {
arg = str.String()
}
}
}
if !ok {
return &unknownArgumentTypeQueryExecModeExecError{arg: arg}
}
err := eqb.appendParam(m, dt.OID, TextFormatCode, arg)
if err != nil {
return err
}
}
}
return nil
}

2
go.mod
View File

@ -1,6 +1,6 @@
module github.com/jackc/pgx/v5
go 1.19
go 1.20
require (
github.com/jackc/pgpassfile v1.0.0

View File

@ -1,36 +0,0 @@
package anynil
import "reflect"
// Is returns true if value is any type of nil. e.g. nil or []byte(nil).
func Is(value any) bool {
if value == nil {
return true
}
refVal := reflect.ValueOf(value)
switch refVal.Kind() {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
return refVal.IsNil()
default:
return false
}
}
// Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified.
func Normalize(v any) any {
if Is(v) {
return nil
}
return v
}
// NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is
// mutated in place.
func NormalizeSlice(s []any) {
for i := range s {
if Is(s[i]) {
s[i] = nil
}
}
}

View File

@ -6,7 +6,6 @@ import (
"fmt"
"reflect"
"github.com/jackc/pgx/v5/internal/anynil"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -230,7 +229,7 @@ func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target any) Scan
// target / arrayScanner might be a pointer to a nil. If it is create one so we can call ScanIndexType to plan the
// scan of the elements.
if anynil.Is(target) {
if isNil, _ := isNilDriverValuer(target); isNil {
arrayScanner = reflect.New(reflect.TypeOf(target).Elem()).Interface().(ArraySetter)
}

View File

@ -139,6 +139,16 @@ Compatibility with database/sql
pgtype also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer
interfaces.
Encoding Typed Nils
pgtype encodes untyped and typed nils (e.g. nil and []byte(nil)) to the SQL NULL value without going through the Codec
system. This means that Codecs and other encoding logic do not have to handle nil or *T(nil).
However, database/sql compatibility requires Value to be called on T(nil) when T implements driver.Valuer. Therefore,
driver.Valuer values are only considered NULL when *T(nil) where driver.Valuer is implemented on T not on *T. See
https://github.com/golang/go/issues/8415 and
https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870.
Child Records
pgtype's support for arrays and composite records can be used to load records and their children in a single query. See

View File

@ -1912,8 +1912,17 @@ func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error)
// (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data
// written.
func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBuf []byte, err error) {
if value == nil {
return nil, nil
if isNil, callNilDriverValuer := isNilDriverValuer(value); isNil {
if callNilDriverValuer {
newBuf, err = (&encodePlanDriverValuer{m: m, oid: oid, formatCode: formatCode}).Encode(value, buf)
if err != nil {
return nil, newEncodeError(value, m, oid, formatCode, err)
}
return newBuf, nil
} else {
return nil, nil
}
}
plan := m.PlanEncode(oid, formatCode, value)
@ -1968,3 +1977,55 @@ func (w *sqlScannerWrapper) Scan(src any) error {
return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v)
}
// canBeNil returns true if value can be nil.
func canBeNil(value any) bool {
refVal := reflect.ValueOf(value)
kind := refVal.Kind()
switch kind {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
return true
default:
return false
}
}
// valuerReflectType is a reflect.Type for driver.Valuer. It has confusing syntax because reflect.TypeOf returns nil
// when it's argument is a nil interface value. So we use a pointer to the interface and call Elem to get the actual
// type. Yuck.
//
// This can be simplified in Go 1.22 with reflect.TypeFor.
//
// var valuerReflectType = reflect.TypeFor[driver.Valuer]()
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
// isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement
// driver.Valuer if it is only implemented by T.
func isNilDriverValuer(value any) (isNil bool, callNilDriverValuer bool) {
if value == nil {
return true, false
}
refVal := reflect.ValueOf(value)
kind := refVal.Kind()
switch kind {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
if !refVal.IsNil() {
return false, false
}
if _, ok := value.(driver.Valuer); ok {
if kind == reflect.Ptr {
// The type assertion will succeed if driver.Valuer is implemented on T or *T. Check if it is implemented on *T
// by checking if it is not implemented on *T.
return true, !refVal.Type().Elem().Implements(valuerReflectType)
} else {
return true, true
}
}
return true, false
default:
return false, false
}
}

View File

@ -4,6 +4,8 @@ import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"os"
@ -1171,6 +1173,161 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes
ensureConnValid(t, conn)
}
type nilPointerAsEmptyJSONObject struct {
ID string
Name string
}
func (v *nilPointerAsEmptyJSONObject) Value() (driver.Value, error) {
if v == nil {
return "{}", nil
}
return json.Marshal(v)
}
// https://github.com/jackc/pgx/issues/1566
func TestConnQueryDatabaseSQLDriverValuerCalledOnNilPointerImplementers(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, "create temporary table t(v json not null)")
var v *nilPointerAsEmptyJSONObject
commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
require.NoError(t, err)
require.Equal(t, "INSERT 0 1", commandTag.String())
var s string
err = conn.QueryRow(context.Background(), "select v from t").Scan(&s)
require.NoError(t, err)
require.Equal(t, "{}", s)
_, err = conn.Exec(context.Background(), `delete from t`)
require.NoError(t, err)
v = &nilPointerAsEmptyJSONObject{ID: "1", Name: "foo"}
commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
require.NoError(t, err)
require.Equal(t, "INSERT 0 1", commandTag.String())
var v2 *nilPointerAsEmptyJSONObject
err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2)
require.NoError(t, err)
require.Equal(t, v, v2)
ensureConnValid(t, conn)
}
type nilSliceAsEmptySlice []byte
func (j nilSliceAsEmptySlice) Value() (driver.Value, error) {
if len(j) == 0 {
return []byte("[]"), nil
}
return []byte(j), nil
}
func (j *nilSliceAsEmptySlice) UnmarshalJSON(data []byte) error {
*j = bytes.Clone(data)
return nil
}
// https://github.com/jackc/pgx/issues/1860
func TestConnQueryDatabaseSQLDriverValuerCalledOnNilSliceImplementers(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, "create temporary table t(v json not null)")
var v nilSliceAsEmptySlice
commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
require.NoError(t, err)
require.Equal(t, "INSERT 0 1", commandTag.String())
var s string
err = conn.QueryRow(context.Background(), "select v from t").Scan(&s)
require.NoError(t, err)
require.Equal(t, "[]", s)
_, err = conn.Exec(context.Background(), `delete from t`)
require.NoError(t, err)
v = nilSliceAsEmptySlice(`{"name": "foo"}`)
commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
require.NoError(t, err)
require.Equal(t, "INSERT 0 1", commandTag.String())
var v2 nilSliceAsEmptySlice
err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2)
require.NoError(t, err)
require.Equal(t, v, v2)
ensureConnValid(t, conn)
}
type nilMapAsEmptyObject map[string]any
func (j nilMapAsEmptyObject) Value() (driver.Value, error) {
if j == nil {
return []byte("{}"), nil
}
return json.Marshal(j)
}
func (j *nilMapAsEmptyObject) UnmarshalJSON(data []byte) error {
var m map[string]any
err := json.Unmarshal(data, &m)
if err != nil {
return err
}
*j = m
return nil
}
// https://github.com/jackc/pgx/pull/2019#discussion_r1605806751
func TestConnQueryDatabaseSQLDriverValuerCalledOnNilMapImplementers(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, "create temporary table t(v json not null)")
var v nilMapAsEmptyObject
commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
require.NoError(t, err)
require.Equal(t, "INSERT 0 1", commandTag.String())
var s string
err = conn.QueryRow(context.Background(), "select v from t").Scan(&s)
require.NoError(t, err)
require.Equal(t, "{}", s)
_, err = conn.Exec(context.Background(), `delete from t`)
require.NoError(t, err)
v = nilMapAsEmptyObject{"name": "foo"}
commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
require.NoError(t, err)
require.Equal(t, "INSERT 0 1", commandTag.String())
var v2 nilMapAsEmptyObject
err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2)
require.NoError(t, err)
require.Equal(t, v, v2)
ensureConnValid(t, conn)
}
func TestConnQueryDatabaseSQLDriverScannerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) {
t.Parallel()

View File

@ -3,7 +3,6 @@ package pgx
import (
"errors"
"github.com/jackc/pgx/v5/internal/anynil"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgtype"
)
@ -15,10 +14,6 @@ const (
)
func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) {
if anynil.Is(arg) {
return nil, nil
}
buf, err := m.Encode(0, TextFormatCode, arg, []byte{})
if err != nil {
return nil, err
@ -30,10 +25,6 @@ func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) {
}
func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) {
if anynil.Is(arg) {
return pgio.AppendInt32(buf, -1), nil
}
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf)