diff --git a/query.go b/query.go index 19b867e2..121dcfe3 100644 --- a/query.go +++ b/query.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "golang.org/x/net/context" "time" ) @@ -49,6 +50,9 @@ type Rows struct { afterClose func(*Rows) unlockConn bool closed bool + + ctx context.Context + doneChan chan struct{} } func (rows *Rows) FieldDescriptions() []FieldDescription { @@ -120,6 +124,15 @@ func (rows *Rows) Close() { return } rows.readUntilReadyForQuery() + + if rows.ctx != nil { + select { + case <-rows.ctx.Done(): + rows.err = rows.ctx.Err() + case rows.doneChan <- struct{}{}: + } + } + rows.close() } @@ -492,3 +505,38 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { rows, _ := c.Query(sql, args...) return (*Row)(rows) } + +func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + doneChan := make(chan struct{}) + + go func() { + select { + case <-ctx.Done(): + c.cancelQuery() + c.Close() + case <-doneChan: + } + }() + + rows, err := c.Query(sql, args...) + + if err != nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case doneChan <- struct{}{}: + return nil, err + } + } + + rows.ctx = ctx + rows.doneChan = doneChan + + return rows, nil +} diff --git a/query_test.go b/query_test.go index f08887b5..ca05fb42 100644 --- a/query_test.go +++ b/query_test.go @@ -4,6 +4,7 @@ import ( "bytes" "database/sql" "fmt" + "golang.org/x/net/context" "strings" "testing" "time" @@ -1412,3 +1413,113 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) { ensureConnValid(t, conn) } + +func TestQueryContextSuccess(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + rows, err := conn.QueryContext(ctx, "select 42::integer") + if err != nil { + t.Fatal(err) + } + + var result, rowCount int + for rows.Next() { + err = rows.Scan(&result) + if err != nil { + t.Fatal(err) + } + rowCount++ + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + + if rowCount != 1 { + t.Fatalf("Expected 1 row, got %d", rowCount) + } + if result != 42 { + t.Fatalf("Expected result 42, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryContextErrorWhileReceivingRows(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + rows, err := conn.QueryContext(ctx, "select 10/(10-n) from generate_series(1, 100) n") + if err != nil { + t.Fatal(err) + } + + var result, rowCount int + for rows.Next() { + err = rows.Scan(&result) + if err != nil { + t.Fatal(err) + } + rowCount++ + } + + if rows.Err() == nil || rows.Err().Error() != "ERROR: division by zero (SQLSTATE 22012)" { + t.Fatalf("Expected division by zero error, but got %v", rows.Err()) + } + + if rowCount != 9 { + t.Fatalf("Expected 9 rows, got %d", rowCount) + } + if result != 10 { + t.Fatalf("Expected result 10, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + rows, err := conn.QueryContext(ctx, "select pg_sleep(5)") + if err != nil { + t.Fatal(err) + } + + for rows.Next() { + t.Fatal("No rows should ever be ready -- context cancel apparently did not happen") + } + + if rows.Err() != context.Canceled { + t.Fatal("Expected context.Canceled error, got %v", rows.Err()) + } + + checkConn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, checkConn) + + var found bool + err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) + if err != pgx.ErrNoRows { + t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") + } + +}