Add QueryRowContext

context
Jack Christensen 2017-02-04 15:57:06 -06:00
parent 3e13b333d9
commit 24193ee322
2 changed files with 74 additions and 9 deletions

View File

@ -507,12 +507,6 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
}
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
doneChan := make(chan struct{})
go func() {
@ -529,9 +523,9 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}
if err != nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
return rows, ctx.Err()
case doneChan <- struct{}{}:
return nil, err
return rows, err
}
}
@ -540,3 +534,8 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}
return rows, nil
}
func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row {
rows, _ := c.QueryContext(ctx, sql, args...)
return (*Row)(rows)
}

View File

@ -1521,5 +1521,71 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) {
if err != pgx.ErrNoRows {
t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't")
}
}
func TestQueryRowContextSuccess(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
var result int
err := conn.QueryRowContext(ctx, "select 42::integer").Scan(&result)
if err != nil {
t.Fatal(err)
}
if result != 42 {
t.Fatalf("Expected result 42, got %d", result)
}
ensureConnValid(t, conn)
}
func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
var result int
err := conn.QueryRowContext(ctx, "select 10/0").Scan(&result)
if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" {
t.Fatalf("Expected division by zero error, but got %v", err)
}
ensureConnValid(t, conn)
}
func TestQueryRowContextCancelationCancelsQuery(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
time.Sleep(500 * time.Millisecond)
cancelFunc()
}()
var result []byte
err := conn.QueryRowContext(ctx, "select pg_sleep(5)").Scan(&result)
if err != context.Canceled {
t.Fatal("Expected context.Canceled error, got %v", err)
}
checkConn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, checkConn)
var found bool
err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found)
if err != pgx.ErrNoRows {
t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't")
}
}