Fix pgtype types that can Set database/sql/driver.driver.Valuer

Bug was chooseParameterFormatCode would see that type could handle
binary format so binary format would be chosen. But
encodePreparedStatementArgument would see driver.Valuer first and
would encode with that -- which is text mode. So the server would
receive a text format value when expecting a binary format value.

Discovered while investigating 
pull/317/head
Jack Christensen 2017-09-01 15:57:53 -05:00
parent e2695be13b
commit 9c8ef1acdd
2 changed files with 51 additions and 7 deletions

View File

@ -11,6 +11,8 @@ import (
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgtype"
satori "github.com/jackc/pgx/pgtype/ext/satori-uuid"
"github.com/satori/go.uuid"
"github.com/shopspring/decimal"
)
@ -855,7 +857,7 @@ func TestQueryRowExErrorsWrongParameterOIDs(t *testing.T) {
context.Background(),
sql,
&pgx.QueryExOptions{
ParameterOIDs: paramOIDs,
ParameterOIDs: paramOIDs,
ResultFormatCodes: []int16{pgx.BinaryFormatCode},
},
queryArgs...,
@ -1000,6 +1002,36 @@ func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) {
ensureConnValid(t, conn)
}
func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
conn.ConnInfo.RegisterDataType(pgtype.DataType{
Value: &satori.UUID{},
Name: "uuid",
OID: 2950,
})
expected, err := uuid.FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8")
if err != nil {
t.Fatal(err)
}
var u2 uuid.UUID
err = conn.QueryRow("select $1::uuid", expected).Scan(&u2)
if err != nil {
t.Fatalf("Scan failed: %v", err)
}
if expected != u2 {
t.Errorf("Expected u2 to be %v, but it was %v", expected, u2)
}
ensureConnValid(t, conn)
}
func TestConnQueryDatabaseSQLNullX(t *testing.T) {
t.Parallel()

View File

@ -128,12 +128,6 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
return buf, nil
case driver.Valuer:
v, err := arg.Value()
if err != nil {
return nil, err
}
return encodePreparedStatementArgument(ci, buf, oid, v)
case string:
buf = pgio.AppendInt32(buf, int32(len(arg)))
buf = append(buf, arg...)
@ -154,6 +148,16 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype
value := dt.Value
err := value.Set(arg)
if err != nil {
{
if arg, ok := arg.(driver.Valuer); ok {
v, err := arg.Value()
if err != nil {
return nil, err
}
return encodePreparedStatementArgument(ci, buf, oid, v)
}
}
return nil, err
}
@ -170,6 +174,14 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype
return buf, nil
}
if arg, ok := arg.(driver.Valuer); ok {
v, err := arg.Value()
if err != nil {
return nil, err
}
return encodePreparedStatementArgument(ci, buf, oid, v)
}
if strippedArg, ok := stripNamedType(&refVal); ok {
return encodePreparedStatementArgument(ci, buf, oid, strippedArg)
}