Fix stdlib decoding error with certain order and combination of fields

fixes #781
pull/784/head
Jack Christensen 2020-06-29 09:38:53 -05:00
parent 34cbb61138
commit bf47a3d0a4
2 changed files with 51 additions and 32 deletions

View File

@ -507,26 +507,29 @@ func (r *Rows) Next(dest []driver.Value) error {
r.valueFuncs = make([]rowValueFunc, len(fieldDescriptions)) r.valueFuncs = make([]rowValueFunc, len(fieldDescriptions))
for i, fd := range fieldDescriptions { for i, fd := range fieldDescriptions {
dataTypeOID := fd.DataTypeOID
format := fd.Format
switch fd.DataTypeOID { switch fd.DataTypeOID {
case pgtype.BoolOID: case pgtype.BoolOID:
var d bool var d bool
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
return d, err return d, err
} }
case pgtype.ByteaOID: case pgtype.ByteaOID:
var d []byte var d []byte
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
return d, err return d, err
} }
case pgtype.CIDOID: case pgtype.CIDOID:
var d pgtype.CID var d pgtype.CID
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -534,9 +537,9 @@ func (r *Rows) Next(dest []driver.Value) error {
} }
case pgtype.DateOID: case pgtype.DateOID:
var d pgtype.Date var d pgtype.Date
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -544,44 +547,44 @@ func (r *Rows) Next(dest []driver.Value) error {
} }
case pgtype.Float4OID: case pgtype.Float4OID:
var d float32 var d float32
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
return float64(d), err return float64(d), err
} }
case pgtype.Float8OID: case pgtype.Float8OID:
var d float64 var d float64
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
return d, err return d, err
} }
case pgtype.Int2OID: case pgtype.Int2OID:
var d int16 var d int16
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
return int64(d), err return int64(d), err
} }
case pgtype.Int4OID: case pgtype.Int4OID:
var d int32 var d int32
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
return int64(d), err return int64(d), err
} }
case pgtype.Int8OID: case pgtype.Int8OID:
var d int64 var d int64
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
return d, err return d, err
} }
case pgtype.JSONOID: case pgtype.JSONOID:
var d pgtype.JSON var d pgtype.JSON
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -589,9 +592,9 @@ func (r *Rows) Next(dest []driver.Value) error {
} }
case pgtype.JSONBOID: case pgtype.JSONBOID:
var d pgtype.JSONB var d pgtype.JSONB
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -599,9 +602,9 @@ func (r *Rows) Next(dest []driver.Value) error {
} }
case pgtype.OIDOID: case pgtype.OIDOID:
var d pgtype.OIDValue var d pgtype.OIDValue
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -609,9 +612,9 @@ func (r *Rows) Next(dest []driver.Value) error {
} }
case pgtype.TimestampOID: case pgtype.TimestampOID:
var d pgtype.Timestamp var d pgtype.Timestamp
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -619,9 +622,9 @@ func (r *Rows) Next(dest []driver.Value) error {
} }
case pgtype.TimestamptzOID: case pgtype.TimestamptzOID:
var d pgtype.Timestamptz var d pgtype.Timestamptz
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -629,9 +632,9 @@ func (r *Rows) Next(dest []driver.Value) error {
} }
case pgtype.XIDOID: case pgtype.XIDOID:
var d pgtype.XID var d pgtype.XID
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -639,9 +642,9 @@ func (r *Rows) Next(dest []driver.Value) error {
} }
default: default:
var d string var d string
scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) scanPlan := ci.PlanScan(dataTypeOID, format, &d)
r.valueFuncs[i] = func(src []byte) (driver.Value, error) { r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
return d, err return d, err
} }
} }

View File

@ -261,6 +261,22 @@ func TestConnQuery(t *testing.T) {
}) })
} }
// https://github.com/jackc/pgx/issues/781
func TestConnQueryDifferentScanPlansIssue781(t *testing.T) {
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
var s string
var b bool
rows, err := db.Query("select true, 'foo'")
require.NoError(t, err)
require.True(t, rows.Next())
require.NoError(t, rows.Scan(&b, &s))
assert.Equal(t, true, b)
assert.Equal(t, "foo", s)
})
}
func TestConnQueryNull(t *testing.T) { func TestConnQueryNull(t *testing.T) {
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
rows, err := db.Query("select $1::int", nil) rows, err := db.Query("select $1::int", nil)