Add QueryContext

context
Jack Christensen 2017-02-04 15:40:58 -06:00
parent 78adfb13d7
commit 3e13b333d9
2 changed files with 159 additions and 0 deletions

View File

@ -4,6 +4,7 @@ import (
"database/sql"
"errors"
"fmt"
"golang.org/x/net/context"
"time"
)
@ -49,6 +50,9 @@ type Rows struct {
afterClose func(*Rows)
unlockConn bool
closed bool
ctx context.Context
doneChan chan struct{}
}
func (rows *Rows) FieldDescriptions() []FieldDescription {
@ -120,6 +124,15 @@ func (rows *Rows) Close() {
return
}
rows.readUntilReadyForQuery()
if rows.ctx != nil {
select {
case <-rows.ctx.Done():
rows.err = rows.ctx.Err()
case rows.doneChan <- struct{}{}:
}
}
rows.close()
}
@ -492,3 +505,38 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
rows, _ := c.Query(sql, args...)
return (*Row)(rows)
}
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
doneChan := make(chan struct{})
go func() {
select {
case <-ctx.Done():
c.cancelQuery()
c.Close()
case <-doneChan:
}
}()
rows, err := c.Query(sql, args...)
if err != nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
case doneChan <- struct{}{}:
return nil, err
}
}
rows.ctx = ctx
rows.doneChan = doneChan
return rows, nil
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"database/sql"
"fmt"
"golang.org/x/net/context"
"strings"
"testing"
"time"
@ -1412,3 +1413,113 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryContextSuccess(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
rows, err := conn.QueryContext(ctx, "select 42::integer")
if err != nil {
t.Fatal(err)
}
var result, rowCount int
for rows.Next() {
err = rows.Scan(&result)
if err != nil {
t.Fatal(err)
}
rowCount++
}
if rows.Err() != nil {
t.Fatal(rows.Err())
}
if rowCount != 1 {
t.Fatalf("Expected 1 row, got %d", rowCount)
}
if result != 42 {
t.Fatalf("Expected result 42, got %d", result)
}
ensureConnValid(t, conn)
}
func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
rows, err := conn.QueryContext(ctx, "select 10/(10-n) from generate_series(1, 100) n")
if err != nil {
t.Fatal(err)
}
var result, rowCount int
for rows.Next() {
err = rows.Scan(&result)
if err != nil {
t.Fatal(err)
}
rowCount++
}
if rows.Err() == nil || rows.Err().Error() != "ERROR: division by zero (SQLSTATE 22012)" {
t.Fatalf("Expected division by zero error, but got %v", rows.Err())
}
if rowCount != 9 {
t.Fatalf("Expected 9 rows, got %d", rowCount)
}
if result != 10 {
t.Fatalf("Expected result 10, got %d", result)
}
ensureConnValid(t, conn)
}
func TestQueryContextCancelationCancelsQuery(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
time.Sleep(500 * time.Millisecond)
cancelFunc()
}()
rows, err := conn.QueryContext(ctx, "select pg_sleep(5)")
if err != nil {
t.Fatal(err)
}
for rows.Next() {
t.Fatal("No rows should ever be ready -- context cancel apparently did not happen")
}
if rows.Err() != context.Canceled {
t.Fatal("Expected context.Canceled error, got %v", rows.Err())
}
checkConn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, checkConn)
var found bool
err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found)
if err != pgx.ErrNoRows {
t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't")
}
}