diff --git a/query.go b/query.go index 811e95b1..e37e6120 100644 --- a/query.go +++ b/query.go @@ -364,19 +364,21 @@ type QueryExOptions struct { } func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) { + rows = c.getRows(sql, args) + err = c.waitForPreviousCancelQuery(ctx) if err != nil { - return nil, err + rows.fatal(err) + return rows, err } if err := c.ensureConnectionReadyForQuery(); err != nil { - return nil, err + rows.fatal(err) + return rows, err } c.lastActivityTime = time.Now() - rows = c.getRows(sql, args) - if err := c.lock(); err != nil { rows.fatal(err) return rows, err @@ -413,14 +415,14 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, if err != nil && fatalWriteErr(n, err) { rows.fatal(err) c.die(err) - return nil, err + return rows, err } c.pendingReadyForQueryCount++ fieldDescriptions, err := c.readUntilRowDescription() if err != nil { rows.fatal(err) - return nil, err + return rows, err } if len(options.ResultFormatCodes) == 0 { diff --git a/query_test.go b/query_test.go index 9379bd23..371c3ec4 100644 --- a/query_test.go +++ b/query_test.go @@ -835,6 +835,43 @@ func TestQueryRowErrors(t *testing.T) { } } +func TestQueryRowExErrorsWrongParameterOIDs(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := ` + with t as ( + select 1::int8 as some_int, 'foo'::text as some_text + ) + select some_int from t where some_text = $1` + paramOIDs := []pgtype.OID{pgtype.TextArrayOID} + queryArgs := []interface{}{"bar"} + expectedErr := "operator does not exist: text = text[] (SQLSTATE 42883)" + var result int64 + + err := conn.QueryRowEx( + context.Background(), + sql, + &pgx.QueryExOptions{ + ParameterOIDs: paramOIDs, + ResultFormatCodes: []int16{pgx.BinaryFormatCode}, + }, + queryArgs..., + ).Scan(&result) + + if err == nil { + t.Errorf("Unexpected success (sql -> %v, paramOIDs -> %v, queryArgs -> %v)", sql, paramOIDs, queryArgs) + } + if err != nil && !strings.Contains(err.Error(), expectedErr) { + t.Errorf("Expected error to contain %s, but got %v (sql -> %v, paramOIDs -> %v, queryArgs -> %v)", + expectedErr, err, sql, paramOIDs, queryArgs) + } + + ensureConnValid(t, conn) +} + func TestQueryRowNoResults(t *testing.T) { t.Parallel()