From 1cef9075d94cd69fbbe8038dc8ed26ef8dfe9842 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 19:45:44 -0600 Subject: [PATCH] Simply typed nil and driver.Valuer handling * Convert typed nils to untyped nils at beginning of encoding process. * Restore v4 json/jsonb null behavior * Add anynil internal package --- conn.go | 11 +++++-- go_stdlib.go | 61 --------------------------------------- internal/anynil/anynil.go | 36 +++++++++++++++++++++++ messages.go | 19 ------------ pgtype/json_test.go | 4 +-- pgtype/jsonb_test.go | 4 +-- values.go | 25 +++++++++++----- 7 files changed, 66 insertions(+), 94 deletions(-) delete mode 100644 go_stdlib.go create mode 100644 internal/anynil/anynil.go delete mode 100644 messages.go diff --git a/conn.go b/conn.go index 36994698..20968892 100644 --- a/conn.go +++ b/conn.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/sanitize" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn/stmtcache" @@ -478,7 +479,9 @@ func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, argu c.eqb.Reset() - args, err := convertDriverValuers(arguments) + anynil.NormalizeSlice(arguments) + + args, err := evaluateDriverValuers(arguments) if err != nil { return err } @@ -671,7 +674,8 @@ optionLoop: rows.sql = sd.SQL - args, err = convertDriverValuers(args) + anynil.NormalizeSlice(args) + args, err = evaluateDriverValuers(args) if err != nil { rows.fatal(err) return rows, rows.err @@ -831,7 +835,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} } - args, err := convertDriverValuers(bi.arguments) + anynil.NormalizeSlice(bi.arguments) + args, err := evaluateDriverValuers(bi.arguments) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } diff --git a/go_stdlib.go b/go_stdlib.go deleted file mode 100644 index 9372f9ef..00000000 --- a/go_stdlib.go +++ /dev/null @@ -1,61 +0,0 @@ -package pgx - -import ( - "database/sql/driver" - "reflect" -) - -// This file contains code copied from the Go standard library due to the -// required function not being public. - -// Copyright (c) 2009 The Go Authors. All rights reserved. - -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: - -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above -// copyright notice, this list of conditions and the following disclaimer -// in the documentation and/or other materials provided with the -// distribution. -// * Neither the name of Google Inc. nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. - -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -// From database/sql/convert.go - -var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() - -// callValuerValue returns vr.Value(), with one exception: -// If vr.Value is an auto-generated method on a pointer type and the -// pointer is nil, it would panic at runtime in the panicwrap -// method. Treat it like nil instead. -// Issue 8415. -// -// This is so people can implement driver.Value on value types and -// still use nil pointers to those types to mean nil/NULL, just like -// string/*string. -// -// This function is mirrored in the database/sql/driver package. -func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { - if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && - rv.IsNil() && - rv.Type().Elem().Implements(valuerReflectType) { - return nil, nil - } - return vr.Value() -} diff --git a/internal/anynil/anynil.go b/internal/anynil/anynil.go new file mode 100644 index 00000000..57a45b95 --- /dev/null +++ b/internal/anynil/anynil.go @@ -0,0 +1,36 @@ +package anynil + +import "reflect" + +// Is returns true if value is any type of nil. e.g. nil or []byte(nil). +func Is(value interface{}) bool { + if value == nil { + return true + } + + refVal := reflect.ValueOf(value) + switch refVal.Kind() { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: + return refVal.IsNil() + default: + return false + } +} + +// Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified. +func Normalize(v interface{}) interface{} { + if Is(v) { + return nil + } + return v +} + +// NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is +// mutated in place. +func NormalizeSlice(s []interface{}) { + for i := range s { + if Is(s[i]) { + s[i] = nil + } + } +} diff --git a/messages.go b/messages.go deleted file mode 100644 index 01ece44e..00000000 --- a/messages.go +++ /dev/null @@ -1,19 +0,0 @@ -package pgx - -import ( - "database/sql/driver" -) - -func convertDriverValuers(args []interface{}) ([]interface{}, error) { - for i, arg := range args { - switch arg := arg.(type) { - case driver.Valuer: - v, err := callValuerValue(arg) - if err != nil { - return nil, err - } - args[i] = v - } - } - return args, nil -} diff --git a/pgtype/json_test.go b/pgtype/json_test.go index a1dd63fb..39658bfa 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -48,8 +48,8 @@ func TestJSONCodec(t *testing.T) { {map[string]interface{}{"foo": "bar"}, new(map[string]interface{}), isExpectedEqMap(map[string]interface{}{"foo": "bar"})}, {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, - {map[string]interface{}(nil), new(string), isExpectedEq(`null`)}, - {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte("null"))}, + {map[string]interface{}(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, }) diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index fa5ea20e..c26499c6 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -21,8 +21,8 @@ func TestJSONBTranscode(t *testing.T) { {map[string]interface{}{"foo": "bar"}, new(map[string]interface{}), isExpectedEqMap(map[string]interface{}{"foo": "bar"})}, {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, - {map[string]interface{}(nil), new(string), isExpectedEq(`null`)}, - {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte("null"))}, + {map[string]interface{}(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, }) diff --git a/values.go b/values.go index 7d1933b1..fe7f6444 100644 --- a/values.go +++ b/values.go @@ -7,6 +7,7 @@ import ( "reflect" "time" + "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgtype" ) @@ -25,18 +26,13 @@ func (e SerializationError) Error() string { } func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) { - if arg == nil { - return nil, nil - } - - refVal := reflect.ValueOf(arg) - if refVal.Kind() == reflect.Ptr && refVal.IsNil() { + if anynil.Is(arg) { return nil, nil } switch arg := arg.(type) { case driver.Valuer: - return callValuerValue(arg) + return arg.Value() case float32: return float64(arg), nil case float64: @@ -90,6 +86,7 @@ func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) return string(buf), nil } + refVal := reflect.ValueOf(arg) if refVal.Kind() == reflect.Ptr { arg = refVal.Elem().Interface() return convertSimpleArgument(m, arg) @@ -182,3 +179,17 @@ func stripNamedType(val *reflect.Value) (interface{}, bool) { return nil, false } + +func evaluateDriverValuers(args []interface{}) ([]interface{}, error) { + for i, arg := range args { + switch arg := arg.(type) { + case driver.Valuer: + v, err := arg.Value() + if err != nil { + return nil, err + } + args[i] = v + } + } + return args, nil +}