From 9c8ef1acddff5f99e8cab446c0fd30a413ba69ab Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 1 Sep 2017 15:57:53 -0500 Subject: [PATCH] Fix pgtype types that can Set database/sql/driver.driver.Valuer Bug was chooseParameterFormatCode would see that type could handle binary format so binary format would be chosen. But encodePreparedStatementArgument would see driver.Valuer first and would encode with that -- which is text mode. So the server would receive a text format value when expecting a binary format value. Discovered while investigating #316 --- query_test.go | 34 +++++++++++++++++++++++++++++++++- values.go | 24 ++++++++++++++++++------ 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/query_test.go b/query_test.go index 371c3ec4..fd29cc24 100644 --- a/query_test.go +++ b/query_test.go @@ -11,6 +11,8 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" + satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" + "github.com/satori/go.uuid" "github.com/shopspring/decimal" ) @@ -855,7 +857,7 @@ func TestQueryRowExErrorsWrongParameterOIDs(t *testing.T) { context.Background(), sql, &pgx.QueryExOptions{ - ParameterOIDs: paramOIDs, + ParameterOIDs: paramOIDs, ResultFormatCodes: []int16{pgx.BinaryFormatCode}, }, queryArgs..., @@ -1000,6 +1002,36 @@ func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) { ensureConnValid(t, conn) } +func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + conn.ConnInfo.RegisterDataType(pgtype.DataType{ + Value: &satori.UUID{}, + Name: "uuid", + OID: 2950, + }) + + expected, err := uuid.FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8") + if err != nil { + t.Fatal(err) + } + + var u2 uuid.UUID + err = conn.QueryRow("select $1::uuid", expected).Scan(&u2) + if err != nil { + t.Fatalf("Scan failed: %v", err) + } + + if expected != u2 { + t.Errorf("Expected u2 to be %v, but it was %v", expected, u2) + } + + ensureConnValid(t, conn) +} + func TestConnQueryDatabaseSQLNullX(t *testing.T) { t.Parallel() diff --git a/values.go b/values.go index 780fca68..efbf5573 100644 --- a/values.go +++ b/values.go @@ -128,12 +128,6 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } return buf, nil - case driver.Valuer: - v, err := arg.Value() - if err != nil { - return nil, err - } - return encodePreparedStatementArgument(ci, buf, oid, v) case string: buf = pgio.AppendInt32(buf, int32(len(arg))) buf = append(buf, arg...) @@ -154,6 +148,16 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype value := dt.Value err := value.Set(arg) if err != nil { + { + if arg, ok := arg.(driver.Valuer); ok { + v, err := arg.Value() + if err != nil { + return nil, err + } + return encodePreparedStatementArgument(ci, buf, oid, v) + } + } + return nil, err } @@ -170,6 +174,14 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype return buf, nil } + if arg, ok := arg.(driver.Valuer); ok { + v, err := arg.Value() + if err != nil { + return nil, err + } + return encodePreparedStatementArgument(ci, buf, oid, v) + } + if strippedArg, ok := stripNamedType(&refVal); ok { return encodePreparedStatementArgument(ci, buf, oid, strippedArg) }