diff --git a/query.go b/query.go index 121dcfe3..fc3f405b 100644 --- a/query.go +++ b/query.go @@ -507,12 +507,6 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - doneChan := make(chan struct{}) go func() { @@ -529,9 +523,9 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{} if err != nil { select { case <-ctx.Done(): - return nil, ctx.Err() + return rows, ctx.Err() case doneChan <- struct{}{}: - return nil, err + return rows, err } } @@ -540,3 +534,8 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{} return rows, nil } + +func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row { + rows, _ := c.QueryContext(ctx, sql, args...) + return (*Row)(rows) +} diff --git a/query_test.go b/query_test.go index ca05fb42..6909ba1e 100644 --- a/query_test.go +++ b/query_test.go @@ -1521,5 +1521,71 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) { if err != pgx.ErrNoRows { t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") } - +} + +func TestQueryRowContextSuccess(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + var result int + err := conn.QueryRowContext(ctx, "select 42::integer").Scan(&result) + if err != nil { + t.Fatal(err) + } + if result != 42 { + t.Fatalf("Expected result 42, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + var result int + err := conn.QueryRowContext(ctx, "select 10/0").Scan(&result) + if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" { + t.Fatalf("Expected division by zero error, but got %v", err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + var result []byte + err := conn.QueryRowContext(ctx, "select pg_sleep(5)").Scan(&result) + if err != context.Canceled { + t.Fatal("Expected context.Canceled error, got %v", err) + } + + checkConn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, checkConn) + + var found bool + err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) + if err != pgx.ErrNoRows { + t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") + } }