From 5deea5b9712e2afd8764d4dd84c62f9c36e221b5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Sep 2018 21:37:16 -0500 Subject: [PATCH] Convert driver.Valuer's earlier in bind path fixes #449 --- messages.go | 23 +++++++++++++++++++++++ values.go | 18 ------------------ 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/messages.go b/messages.go index 97e89295..01b799b2 100644 --- a/messages.go +++ b/messages.go @@ -1,6 +1,7 @@ package pgx import ( + "database/sql/driver" "math" "reflect" "time" @@ -162,6 +163,12 @@ func appendBind( buf = append(buf, preparedStatement...) buf = append(buf, 0) + var err error + arguments, err = convertDriverValuers(arguments) + if err != nil { + return nil, err + } + buf = pgio.AppendInt16(buf, int16(len(parameterOIDs))) for i, oid := range parameterOIDs { buf = pgio.AppendInt16(buf, chooseParameterFormatCode(connInfo, oid, arguments[i])) @@ -185,6 +192,22 @@ func appendBind( return buf, nil } +func convertDriverValuers(args []interface{}) ([]interface{}, error) { + for i, arg := range args { + switch arg := arg.(type) { + case pgtype.BinaryEncoder: + case pgtype.TextEncoder: + case driver.Valuer: + v, err := callValuerValue(arg) + if err != nil { + return nil, err + } + args[i] = v + } + } + return args, nil +} + // appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it. func appendExecute(buf []byte, portal string, maxRows uint32) []byte { buf = append(buf, 'E') diff --git a/values.go b/values.go index 696d5764..0c571d74 100644 --- a/values.go +++ b/values.go @@ -200,14 +200,6 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype return buf, nil } - if arg, ok := arg.(driver.Valuer); ok { - v, err := callValuerValue(arg) - if err != nil { - return nil, err - } - return encodePreparedStatementArgument(ci, buf, oid, v) - } - if strippedArg, ok := stripNamedType(&refVal); ok { return encodePreparedStatementArgument(ci, buf, oid, strippedArg) } @@ -227,16 +219,6 @@ func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.OID, arg interfac if dt, ok := ci.DataTypeForOID(oid); ok { if _, ok := dt.Value.(pgtype.BinaryEncoder); ok { - if arg, ok := arg.(driver.Valuer); ok { - if err := dt.Value.Set(arg); err != nil { - if value, err := callValuerValue(arg); err == nil { - if _, ok := value.(string); ok { - return TextFormatCode - } - } - } - } - return BinaryFormatCode } }