Always return non-nil *Rows from Query to fix QueryRow

Since QueryRow delegates to Query, it needs Query to always return
non-nil *Rows to prevent a nil pointer deference when the QueryRow
caller calls Scan(). This commit fixes the few returns in QueryEx that
return nil on errors rather than *Rows with its err field set.
pull/315/head
Kelsey Francis 2017-08-31 11:31:23 -07:00
parent 47c0e9cbac
commit fc18cc8d76
2 changed files with 45 additions and 6 deletions

View File

@ -364,19 +364,21 @@ type QueryExOptions struct {
}
func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
rows = c.getRows(sql, args)
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return nil, err
rows.fatal(err)
return rows, err
}
if err := c.ensureConnectionReadyForQuery(); err != nil {
return nil, err
rows.fatal(err)
return rows, err
}
c.lastActivityTime = time.Now()
rows = c.getRows(sql, args)
if err := c.lock(); err != nil {
rows.fatal(err)
return rows, err
@ -413,14 +415,14 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
if err != nil && fatalWriteErr(n, err) {
rows.fatal(err)
c.die(err)
return nil, err
return rows, err
}
c.pendingReadyForQueryCount++
fieldDescriptions, err := c.readUntilRowDescription()
if err != nil {
rows.fatal(err)
return nil, err
return rows, err
}
if len(options.ResultFormatCodes) == 0 {

View File

@ -835,6 +835,43 @@ func TestQueryRowErrors(t *testing.T) {
}
}
func TestQueryRowExErrorsWrongParameterOIDs(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
sql := `
with t as (
select 1::int8 as some_int, 'foo'::text as some_text
)
select some_int from t where some_text = $1`
paramOIDs := []pgtype.OID{pgtype.TextArrayOID}
queryArgs := []interface{}{"bar"}
expectedErr := "operator does not exist: text = text[] (SQLSTATE 42883)"
var result int64
err := conn.QueryRowEx(
context.Background(),
sql,
&pgx.QueryExOptions{
ParameterOIDs: paramOIDs,
ResultFormatCodes: []int16{pgx.BinaryFormatCode},
},
queryArgs...,
).Scan(&result)
if err == nil {
t.Errorf("Unexpected success (sql -> %v, paramOIDs -> %v, queryArgs -> %v)", sql, paramOIDs, queryArgs)
}
if err != nil && !strings.Contains(err.Error(), expectedErr) {
t.Errorf("Expected error to contain %s, but got %v (sql -> %v, paramOIDs -> %v, queryArgs -> %v)",
expectedErr, err, sql, paramOIDs, queryArgs)
}
ensureConnValid(t, conn)
}
func TestQueryRowNoResults(t *testing.T) {
t.Parallel()