diff --git a/stdlib/sql.go b/stdlib/sql.go index e3307e96..5e8a2ca4 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -507,26 +507,29 @@ func (r *Rows) Next(dest []driver.Value) error { r.valueFuncs = make([]rowValueFunc, len(fieldDescriptions)) for i, fd := range fieldDescriptions { + dataTypeOID := fd.DataTypeOID + format := fd.Format + switch fd.DataTypeOID { case pgtype.BoolOID: var d bool - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) return d, err } case pgtype.ByteaOID: var d []byte - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) return d, err } case pgtype.CIDOID: var d pgtype.CID - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) if err != nil { return nil, err } @@ -534,9 +537,9 @@ func (r *Rows) Next(dest []driver.Value) error { } case pgtype.DateOID: var d pgtype.Date - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) if err != nil { return nil, err } @@ -544,44 +547,44 @@ func (r *Rows) Next(dest []driver.Value) error { } case pgtype.Float4OID: var d float32 - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) return float64(d), err } case pgtype.Float8OID: var d float64 - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) return d, err } case pgtype.Int2OID: var d int16 - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) return int64(d), err } case pgtype.Int4OID: var d int32 - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) return int64(d), err } case pgtype.Int8OID: var d int64 - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) return d, err } case pgtype.JSONOID: var d pgtype.JSON - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) if err != nil { return nil, err } @@ -589,9 +592,9 @@ func (r *Rows) Next(dest []driver.Value) error { } case pgtype.JSONBOID: var d pgtype.JSONB - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) if err != nil { return nil, err } @@ -599,9 +602,9 @@ func (r *Rows) Next(dest []driver.Value) error { } case pgtype.OIDOID: var d pgtype.OIDValue - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) if err != nil { return nil, err } @@ -609,9 +612,9 @@ func (r *Rows) Next(dest []driver.Value) error { } case pgtype.TimestampOID: var d pgtype.Timestamp - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) if err != nil { return nil, err } @@ -619,9 +622,9 @@ func (r *Rows) Next(dest []driver.Value) error { } case pgtype.TimestamptzOID: var d pgtype.Timestamptz - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) if err != nil { return nil, err } @@ -629,9 +632,9 @@ func (r *Rows) Next(dest []driver.Value) error { } case pgtype.XIDOID: var d pgtype.XID - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) if err != nil { return nil, err } @@ -639,9 +642,9 @@ func (r *Rows) Next(dest []driver.Value) error { } default: var d string - scanPlan := ci.PlanScan(fd.DataTypeOID, fd.Format, &d) + scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, fd.DataTypeOID, fd.Format, src, &d) + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) return d, err } } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 3f6958f2..87d82943 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -261,6 +261,22 @@ func TestConnQuery(t *testing.T) { }) } +// https://github.com/jackc/pgx/issues/781 +func TestConnQueryDifferentScanPlansIssue781(t *testing.T) { + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + var s string + var b bool + + rows, err := db.Query("select true, 'foo'") + require.NoError(t, err) + + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&b, &s)) + assert.Equal(t, true, b) + assert.Equal(t, "foo", s) + }) +} + func TestConnQueryNull(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { rows, err := db.Query("select $1::int", nil)