mirror of https://github.com/jackc/pgx.git
Merge pull request #2019 from jackc/fix-encode-driver-valuer-on-pointer
Fix encode driver.Valuer on pointerpull/2035/head
commit
b4911f1da7
|
@ -92,7 +92,7 @@ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.
|
|||
|
||||
## Supported Go and PostgreSQL Versions
|
||||
|
||||
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.20 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
|
||||
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.21 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
|
||||
|
||||
## Version Policy
|
||||
|
||||
|
|
2
conn.go
2
conn.go
|
@ -10,7 +10,6 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/anynil"
|
||||
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||
"github.com/jackc/pgx/v5/internal/stmtcache"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
@ -755,7 +754,6 @@ optionLoop:
|
|||
}
|
||||
|
||||
c.eqb.reset()
|
||||
anynil.NormalizeSlice(args)
|
||||
rows := c.getRows(ctx, sql, args)
|
||||
|
||||
var err error
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/anynil"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
@ -23,10 +21,15 @@ type ExtendedQueryBuilder struct {
|
|||
func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
|
||||
eqb.reset()
|
||||
|
||||
anynil.NormalizeSlice(args)
|
||||
|
||||
if sd == nil {
|
||||
return eqb.appendParamsForQueryExecModeExec(m, args)
|
||||
for i := range args {
|
||||
err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(sd.ParamOIDs) != len(args) {
|
||||
|
@ -113,10 +116,6 @@ func (eqb *ExtendedQueryBuilder) reset() {
|
|||
}
|
||||
|
||||
func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
|
||||
if anynil.Is(arg) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if eqb.paramValueBytes == nil {
|
||||
eqb.paramValueBytes = make([]byte, 0, 128)
|
||||
}
|
||||
|
@ -145,74 +144,3 @@ func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid ui
|
|||
|
||||
return m.FormatCodeForOID(oid)
|
||||
}
|
||||
|
||||
// appendParamsForQueryExecModeExec appends the args to eqb.
|
||||
//
|
||||
// Parameters must be encoded in the text format because of differences in type conversion between timestamps and
|
||||
// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the
|
||||
// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both
|
||||
// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL
|
||||
// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date.
|
||||
// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion
|
||||
// before converting it to date. This means that dates can be shifted by one day. In text format without that double
|
||||
// type conversion it takes the date directly and ignores time zone (i.e. it works).
|
||||
//
|
||||
// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is
|
||||
// no way to safely use binary or to specify the parameter OIDs.
|
||||
func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error {
|
||||
for _, arg := range args {
|
||||
if arg == nil {
|
||||
err := eqb.appendParam(m, 0, TextFormatCode, arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
dt, ok := m.TypeForValue(arg)
|
||||
if !ok {
|
||||
var tv pgtype.TextValuer
|
||||
if tv, ok = arg.(pgtype.TextValuer); ok {
|
||||
t, err := tv.TextValue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dt, ok = m.TypeForOID(pgtype.TextOID)
|
||||
if ok {
|
||||
arg = t
|
||||
}
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
var dv driver.Valuer
|
||||
if dv, ok = arg.(driver.Valuer); ok {
|
||||
v, err := dv.Value()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dt, ok = m.TypeForValue(v)
|
||||
if ok {
|
||||
arg = v
|
||||
}
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
var str fmt.Stringer
|
||||
if str, ok = arg.(fmt.Stringer); ok {
|
||||
dt, ok = m.TypeForOID(pgtype.TextOID)
|
||||
if ok {
|
||||
arg = str.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return &unknownArgumentTypeQueryExecModeExecError{arg: arg}
|
||||
}
|
||||
err := eqb.appendParam(m, dt.OID, TextFormatCode, arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
2
go.mod
2
go.mod
|
@ -1,6 +1,6 @@
|
|||
module github.com/jackc/pgx/v5
|
||||
|
||||
go 1.19
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/jackc/pgpassfile v1.0.0
|
||||
|
|
|
@ -1,36 +0,0 @@
|
|||
package anynil
|
||||
|
||||
import "reflect"
|
||||
|
||||
// Is returns true if value is any type of nil. e.g. nil or []byte(nil).
|
||||
func Is(value any) bool {
|
||||
if value == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
refVal := reflect.ValueOf(value)
|
||||
switch refVal.Kind() {
|
||||
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
|
||||
return refVal.IsNil()
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified.
|
||||
func Normalize(v any) any {
|
||||
if Is(v) {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is
|
||||
// mutated in place.
|
||||
func NormalizeSlice(s []any) {
|
||||
for i := range s {
|
||||
if Is(s[i]) {
|
||||
s[i] = nil
|
||||
}
|
||||
}
|
||||
}
|
|
@ -6,7 +6,6 @@ import (
|
|||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/anynil"
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
|
@ -230,7 +229,7 @@ func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target any) Scan
|
|||
|
||||
// target / arrayScanner might be a pointer to a nil. If it is create one so we can call ScanIndexType to plan the
|
||||
// scan of the elements.
|
||||
if anynil.Is(target) {
|
||||
if isNil, _ := isNilDriverValuer(target); isNil {
|
||||
arrayScanner = reflect.New(reflect.TypeOf(target).Elem()).Interface().(ArraySetter)
|
||||
}
|
||||
|
||||
|
|
|
@ -139,6 +139,16 @@ Compatibility with database/sql
|
|||
pgtype also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer
|
||||
interfaces.
|
||||
|
||||
Encoding Typed Nils
|
||||
|
||||
pgtype encodes untyped and typed nils (e.g. nil and []byte(nil)) to the SQL NULL value without going through the Codec
|
||||
system. This means that Codecs and other encoding logic do not have to handle nil or *T(nil).
|
||||
|
||||
However, database/sql compatibility requires Value to be called on T(nil) when T implements driver.Valuer. Therefore,
|
||||
driver.Valuer values are only considered NULL when *T(nil) where driver.Valuer is implemented on T not on *T. See
|
||||
https://github.com/golang/go/issues/8415 and
|
||||
https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870.
|
||||
|
||||
Child Records
|
||||
|
||||
pgtype's support for arrays and composite records can be used to load records and their children in a single query. See
|
||||
|
|
|
@ -1912,8 +1912,17 @@ func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error)
|
|||
// (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data
|
||||
// written.
|
||||
func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBuf []byte, err error) {
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
if isNil, callNilDriverValuer := isNilDriverValuer(value); isNil {
|
||||
if callNilDriverValuer {
|
||||
newBuf, err = (&encodePlanDriverValuer{m: m, oid: oid, formatCode: formatCode}).Encode(value, buf)
|
||||
if err != nil {
|
||||
return nil, newEncodeError(value, m, oid, formatCode, err)
|
||||
}
|
||||
|
||||
return newBuf, nil
|
||||
} else {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
plan := m.PlanEncode(oid, formatCode, value)
|
||||
|
@ -1968,3 +1977,55 @@ func (w *sqlScannerWrapper) Scan(src any) error {
|
|||
|
||||
return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v)
|
||||
}
|
||||
|
||||
// canBeNil returns true if value can be nil.
|
||||
func canBeNil(value any) bool {
|
||||
refVal := reflect.ValueOf(value)
|
||||
kind := refVal.Kind()
|
||||
switch kind {
|
||||
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// valuerReflectType is a reflect.Type for driver.Valuer. It has confusing syntax because reflect.TypeOf returns nil
|
||||
// when it's argument is a nil interface value. So we use a pointer to the interface and call Elem to get the actual
|
||||
// type. Yuck.
|
||||
//
|
||||
// This can be simplified in Go 1.22 with reflect.TypeFor.
|
||||
//
|
||||
// var valuerReflectType = reflect.TypeFor[driver.Valuer]()
|
||||
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
|
||||
|
||||
// isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement
|
||||
// driver.Valuer if it is only implemented by T.
|
||||
func isNilDriverValuer(value any) (isNil bool, callNilDriverValuer bool) {
|
||||
if value == nil {
|
||||
return true, false
|
||||
}
|
||||
|
||||
refVal := reflect.ValueOf(value)
|
||||
kind := refVal.Kind()
|
||||
switch kind {
|
||||
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
|
||||
if !refVal.IsNil() {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if _, ok := value.(driver.Valuer); ok {
|
||||
if kind == reflect.Ptr {
|
||||
// The type assertion will succeed if driver.Valuer is implemented on T or *T. Check if it is implemented on *T
|
||||
// by checking if it is not implemented on *T.
|
||||
return true, !refVal.Type().Elem().Implements(valuerReflectType)
|
||||
} else {
|
||||
return true, true
|
||||
}
|
||||
}
|
||||
|
||||
return true, false
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
|
|
157
query_test.go
157
query_test.go
|
@ -4,6 +4,8 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -1171,6 +1173,161 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
type nilPointerAsEmptyJSONObject struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
func (v *nilPointerAsEmptyJSONObject) Value() (driver.Value, error) {
|
||||
if v == nil {
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1566
|
||||
func TestConnQueryDatabaseSQLDriverValuerCalledOnNilPointerImplementers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "create temporary table t(v json not null)")
|
||||
|
||||
var v *nilPointerAsEmptyJSONObject
|
||||
commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "INSERT 0 1", commandTag.String())
|
||||
|
||||
var s string
|
||||
err = conn.QueryRow(context.Background(), "select v from t").Scan(&s)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "{}", s)
|
||||
|
||||
_, err = conn.Exec(context.Background(), `delete from t`)
|
||||
require.NoError(t, err)
|
||||
|
||||
v = &nilPointerAsEmptyJSONObject{ID: "1", Name: "foo"}
|
||||
commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "INSERT 0 1", commandTag.String())
|
||||
|
||||
var v2 *nilPointerAsEmptyJSONObject
|
||||
err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, v, v2)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
type nilSliceAsEmptySlice []byte
|
||||
|
||||
func (j nilSliceAsEmptySlice) Value() (driver.Value, error) {
|
||||
if len(j) == 0 {
|
||||
return []byte("[]"), nil
|
||||
}
|
||||
|
||||
return []byte(j), nil
|
||||
}
|
||||
|
||||
func (j *nilSliceAsEmptySlice) UnmarshalJSON(data []byte) error {
|
||||
*j = bytes.Clone(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1860
|
||||
func TestConnQueryDatabaseSQLDriverValuerCalledOnNilSliceImplementers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "create temporary table t(v json not null)")
|
||||
|
||||
var v nilSliceAsEmptySlice
|
||||
commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "INSERT 0 1", commandTag.String())
|
||||
|
||||
var s string
|
||||
err = conn.QueryRow(context.Background(), "select v from t").Scan(&s)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "[]", s)
|
||||
|
||||
_, err = conn.Exec(context.Background(), `delete from t`)
|
||||
require.NoError(t, err)
|
||||
|
||||
v = nilSliceAsEmptySlice(`{"name": "foo"}`)
|
||||
commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "INSERT 0 1", commandTag.String())
|
||||
|
||||
var v2 nilSliceAsEmptySlice
|
||||
err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, v, v2)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
type nilMapAsEmptyObject map[string]any
|
||||
|
||||
func (j nilMapAsEmptyObject) Value() (driver.Value, error) {
|
||||
if j == nil {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
func (j *nilMapAsEmptyObject) UnmarshalJSON(data []byte) error {
|
||||
var m map[string]any
|
||||
err := json.Unmarshal(data, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*j = m
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/pull/2019#discussion_r1605806751
|
||||
func TestConnQueryDatabaseSQLDriverValuerCalledOnNilMapImplementers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "create temporary table t(v json not null)")
|
||||
|
||||
var v nilMapAsEmptyObject
|
||||
commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "INSERT 0 1", commandTag.String())
|
||||
|
||||
var s string
|
||||
err = conn.QueryRow(context.Background(), "select v from t").Scan(&s)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "{}", s)
|
||||
|
||||
_, err = conn.Exec(context.Background(), `delete from t`)
|
||||
require.NoError(t, err)
|
||||
|
||||
v = nilMapAsEmptyObject{"name": "foo"}
|
||||
commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "INSERT 0 1", commandTag.String())
|
||||
|
||||
var v2 nilMapAsEmptyObject
|
||||
err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, v, v2)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnQueryDatabaseSQLDriverScannerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ package pgx
|
|||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/anynil"
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
@ -15,10 +14,6 @@ const (
|
|||
)
|
||||
|
||||
func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) {
|
||||
if anynil.Is(arg) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
buf, err := m.Encode(0, TextFormatCode, arg, []byte{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -30,10 +25,6 @@ func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) {
|
|||
}
|
||||
|
||||
func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) {
|
||||
if anynil.Is(arg) {
|
||||
return pgio.AppendInt32(buf, -1), nil
|
||||
}
|
||||
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf)
|
||||
|
|
Loading…
Reference in New Issue