Add context to potentially blocking methods

pull/483/head
Jack Christensen 2018-12-31 17:17:11 -06:00
parent d229219039
commit 084423ae69
5 changed files with 196 additions and 51 deletions

View File

@ -494,7 +494,7 @@ func (c *Conn) Close() error {
}
c.status = connStatusClosed
err := c.pgConn.Close()
err := c.pgConn.Close(context.TODO())
c.causeOfDeath = errors.New("Closed")
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "closed connection", nil)

View File

@ -1,6 +1,7 @@
package pgconn
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
@ -20,7 +21,7 @@ import (
"github.com/pkg/errors"
)
type AfterConnectFunc func(pgconn *PgConn) error
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
// Config is the settings used to establish a connection to a PostgreSQL server.
type Config struct {
@ -466,8 +467,8 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible
// target_session_attrs=read-write.
func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error {
result, err := pgConn.Exec("show transaction_read_only")
func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
result, err := pgConn.Exec(ctx, "show transaction_read_only")
if err != nil {
return err
}

View File

@ -1,7 +1,9 @@
package pgconn_test
import (
"context"
"testing"
"time"
"github.com/jackc/pgx/pgconn"
@ -9,5 +11,7 @@ import (
)
func closeConn(t testing.TB, conn *pgconn.PgConn) {
require.Nil(t, conn.Close())
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
require.Nil(t, conn.Close(ctx))
}

View File

@ -12,6 +12,7 @@ import (
"net"
"strconv"
"strings"
"time"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
@ -19,6 +20,8 @@ import (
const batchBufferSize = 4096
var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description.
@ -185,7 +188,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
}
case *pgproto3.ReadyForQuery:
if config.AfterConnectFunc != nil {
err := config.AfterConnectFunc(pgConn)
err := config.AfterConnectFunc(ctx, pgConn)
if err != nil {
pgConn.NetConn.Close()
return nil, fmt.Errorf("AfterConnectFunc: %v", err)
@ -296,24 +299,28 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
return msg, nil
}
// Close closes a connection. It is safe to call Close on a already closed
// connection.
func (pgConn *PgConn) Close() error {
// Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by
// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The
// underlying net.Conn.Close() will always be called regardless of any other errors.
func (pgConn *PgConn) Close(ctx context.Context) error {
if pgConn.closed {
return nil
}
pgConn.closed = true
defer pgConn.NetConn.Close()
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn)
defer cleanupContext()
_, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4})
if err != nil {
pgConn.NetConn.Close()
return err
return preferContextOverNetTimeoutError(ctx, err)
}
_, err = pgConn.NetConn.Read(make([]byte, 1))
if err != io.EOF {
pgConn.NetConn.Close()
return err
return preferContextOverNetTimeoutError(ctx, err)
}
return pgConn.NetConn.Close()
@ -365,30 +372,38 @@ type PgResultReader struct {
err error
complete bool
preloadedRowValues bool
ctx context.Context
cleanupContext func()
}
// GetResult returns a PgResultReader for the next result. If all results are
// consumed it returns nil. If an error occurs it will be reported on the
// returned PgResultReader.
func (pgConn *PgConn) GetResult() *PgResultReader {
func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader {
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn)
for pgConn.pendingReadyForQueryCount > 0 {
msg, err := pgConn.ReceiveMessage()
if err != nil {
return &PgResultReader{pgConn: pgConn, err: err, complete: true}
cleanupContext()
return &PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true}
}
switch msg := msg.(type) {
case *pgproto3.RowDescription:
return &PgResultReader{pgConn: pgConn, fieldDescriptions: msg.Fields}
return &PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields}
case *pgproto3.DataRow:
return &PgResultReader{pgConn: pgConn, rowValues: msg.Values, preloadedRowValues: true}
return &PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true}
case *pgproto3.CommandComplete:
return &PgResultReader{pgConn: pgConn, commandTag: CommandTag(msg.CommandTag), complete: true}
cleanupContext()
return &PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true}
case *pgproto3.ErrorResponse:
return &PgResultReader{pgConn: pgConn, err: errorResponseToPgError(msg), complete: true}
cleanupContext()
return &PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true}
}
}
cleanupContext()
return nil
}
@ -406,6 +421,8 @@ func (rr *PgResultReader) NextRow() bool {
for {
msg, err := rr.pgConn.ReceiveMessage()
if err != nil {
rr.err = preferContextOverNetTimeoutError(rr.ctx, err)
rr.close()
return false
}
@ -416,13 +433,12 @@ func (rr *PgResultReader) NextRow() bool {
rr.rowValues = msg.Values
return true
case *pgproto3.CommandComplete:
rr.rowValues = nil
rr.commandTag = CommandTag(msg.CommandTag)
rr.complete = true
rr.close()
return false
case *pgproto3.ErrorResponse:
rr.err = errorResponseToPgError(msg)
rr.complete = true
rr.close()
return false
}
}
@ -441,46 +457,137 @@ func (rr *PgResultReader) Close() (CommandTag, error) {
if rr.complete {
return rr.commandTag, rr.err
}
rr.rowValues = nil
defer rr.close()
for {
msg, err := rr.pgConn.ReceiveMessage()
if err != nil {
rr.err = err
rr.complete = true
rr.err = preferContextOverNetTimeoutError(rr.ctx, err)
return rr.commandTag, rr.err
}
switch msg := msg.(type) {
case *pgproto3.CommandComplete:
rr.commandTag = CommandTag(msg.CommandTag)
rr.complete = true
return rr.commandTag, rr.err
case *pgproto3.ErrorResponse:
rr.err = errorResponseToPgError(msg)
rr.complete = true
return rr.commandTag, rr.err
}
}
}
func (rr *PgResultReader) close() {
if rr.complete {
return
}
rr.cleanupContext()
rr.rowValues = nil
rr.complete = true
}
// Flush sends the enqueued execs to the server.
func (pgConn *PgConn) Flush() error {
func (pgConn *PgConn) Flush(ctx context.Context) error {
defer pgConn.resetBatch()
cleanup := contextDoneToConnDeadline(ctx, pgConn.NetConn)
defer cleanup()
n, err := pgConn.NetConn.Write(pgConn.batchBuf)
if err != nil {
if n > 0 {
// TODO - kill connection - we sent a partial message
// Close connection because cannot recover from partially sent message.
pgConn.NetConn.Close()
pgConn.closed = true
}
return err
return preferContextOverNetTimeoutError(ctx, err)
}
pgConn.pendingReadyForQueryCount += pgConn.batchCount
return nil
}
// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from
// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to
// call multiple times.
func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) {
if ctx.Done() != nil {
deadlineWasSet := false
doneChan := make(chan struct{})
go func() {
select {
case <-ctx.Done():
conn.SetDeadline(deadlineTime)
deadlineWasSet = true
<-doneChan
// TODO
case <-doneChan:
}
}()
finished := false
return func() {
if !finished {
doneChan <- struct{}{}
if deadlineWasSet {
conn.SetDeadline(time.Time{})
}
finished = true
}
}
}
return func() {}
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return ctx.Err()
}
return err
}
// RecoverFromTimeout attempts to recover from a timeout error such as is caused by a canceled context. If recovery is
// successful true is returned. If recovery is not successful the connection is closed and false it returned. Recovery
// should usually be possible except in the case of a partial write. This must be called after any context cancellation.
//
// As RecoverFromTimeout may need to read and ignored data already sent from the server, it potentially can block
// indefinitely. Use ctx to guard against this.
func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool {
if pgConn.closed {
return false
}
pgConn.resetBatch()
pgConn.NetConn.SetDeadline(time.Time{})
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn)
defer cleanupContext()
for pgConn.pendingReadyForQueryCount > 0 {
_, err := pgConn.ReceiveMessage()
if err != nil {
preferContextOverNetTimeoutError(ctx, err)
pgConn.Close(context.Background())
return false
}
}
result, err := pgConn.Exec(
context.Background(), // do not use ctx again because deadline goroutine already started above
"select 'RecoverFromTimeout'",
)
if err != nil || len(result.Rows) != 1 || len(result.Rows[0]) != 1 || string(result.Rows[0][0]) != "RecoverFromTimeout" {
pgConn.Close(context.Background())
return false
}
return true
}
func (pgConn *PgConn) resetBatch() {
pgConn.batchCount = 0
if len(pgConn.batchBuf) > batchBufferSize {
@ -500,7 +607,7 @@ type PgResult struct {
// transactions unless a transaction is already in progress or sql contains transaction control statements.
//
// Exec must not be called when there are pending results from previous Send* methods (e.g. SendExec).
func (pgConn *PgConn) Exec(sql string) (*PgResult, error) {
func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) {
if pgConn.batchCount != 0 {
return nil, errors.New("unflushed previous sends")
}
@ -509,14 +616,14 @@ func (pgConn *PgConn) Exec(sql string) (*PgResult, error) {
}
pgConn.SendExec(sql)
err := pgConn.Flush()
err := pgConn.Flush(ctx)
if err != nil {
return nil, err
}
var result *PgResult
for resultReader := pgConn.GetResult(); resultReader != nil; resultReader = pgConn.GetResult() {
for resultReader := pgConn.GetResult(ctx); resultReader != nil; resultReader = pgConn.GetResult(ctx) {
rows := [][][]byte{}
for resultReader.NextRow() {
row := make([][]byte, len(resultReader.Values()))

View File

@ -6,6 +6,7 @@ import (
"net"
"os"
"testing"
"time"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgconn"
@ -36,8 +37,7 @@ func TestConnect(t *testing.T) {
conn, err := pgconn.Connect(context.Background(), connString)
require.Nil(t, err)
err = conn.Close()
require.Nil(t, err)
closeConn(t, conn)
})
}
}
@ -57,8 +57,7 @@ func TestConnectTLS(t *testing.T) {
t.Error("not a TLS connection")
}
err = conn.Close()
require.Nil(t, err)
closeConn(t, conn)
}
func TestConnectInvalidUser(t *testing.T) {
@ -74,7 +73,7 @@ func TestConnectInvalidUser(t *testing.T) {
conn, err := pgconn.ConnectConfig(context.Background(), config)
if err == nil {
conn.Close()
conn.Close(context.Background())
t.Fatal("expected err but got none")
}
pgErr, ok := err.(pgx.PgError)
@ -92,7 +91,7 @@ func TestConnectWithConnectionRefused(t *testing.T) {
// Presumably nothing is listening on 127.0.0.1:1
conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1")
if err == nil {
conn.Close()
conn.Close(context.Background())
t.Fatal("Expected error establishing connection to bad port")
}
}
@ -110,7 +109,7 @@ func TestConnectCustomDialer(t *testing.T) {
conn, err := pgconn.ConnectConfig(context.Background(), config)
require.Nil(t, err)
require.True(t, dialed)
conn.Close()
closeConn(t, conn)
}
func TestConnectWithRuntimeParams(t *testing.T) {
@ -126,12 +125,12 @@ func TestConnectWithRuntimeParams(t *testing.T) {
require.Nil(t, err)
defer closeConn(t, conn)
result, err := conn.Exec("show application_name")
result, err := conn.Exec(context.Background(), "show application_name")
require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, "pgxtest", string(result.Rows[0][0]))
result, err = conn.Exec("show search_path")
result, err = conn.Exec(context.Background(), "show search_path")
require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, "myschema", string(result.Rows[0][0]))
@ -179,7 +178,7 @@ func TestConnectWithAfterConnectFunc(t *testing.T) {
}
acceptConnCount := 0
config.AfterConnectFunc = func(conn *pgconn.PgConn) error {
config.AfterConnectFunc = func(ctx context.Context, conn *pgconn.PgConn) error {
acceptConnCount += 1
if acceptConnCount < 2 {
return errors.New("reject first conn")
@ -214,38 +213,38 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) {
conn, err := pgconn.ConnectConfig(context.Background(), config)
if !assert.NotNil(t, err) {
conn.Close()
conn.Close(context.Background())
}
}
func TestExec(t *testing.T) {
func TestConnExec(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
result, err := pgConn.Exec("select current_database()")
result, err := pgConn.Exec(context.Background(), "select current_database()")
require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0]))
}
func TestExecMultipleQueries(t *testing.T) {
func TestConnExecMultipleQueries(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
result, err := pgConn.Exec("select current_database(); select 1")
result, err := pgConn.Exec(context.Background(), "select current_database(); select 1")
require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, "1", string(result.Rows[0][0]))
}
func TestExecMultipleQueriesError(t *testing.T) {
func TestConnExecMultipleQueriesError(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
result, err := pgConn.Exec("select 1; select 1/0; select 1")
result, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1")
require.NotNil(t, err)
require.Nil(t, result)
if pgErr, ok := err.(pgconn.PgError); ok {
@ -254,3 +253,37 @@ func TestExecMultipleQueriesError(t *testing.T) {
t.Errorf("unexpected error: %v", err)
}
}
func TestConnExecContextCanceled(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)")
require.Nil(t, result)
assert.Equal(t, context.DeadlineExceeded, err)
}
func TestConnRecoverFromTimeout(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)")
cancel()
require.Nil(t, result)
assert.Equal(t, context.DeadlineExceeded, err)
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
if assert.True(t, pgConn.RecoverFromTimeout(ctx)) {
result, err := pgConn.Exec(ctx, "select 1")
require.Nil(t, err)
assert.Len(t, result.Rows, 1)
assert.Len(t, result.Rows[0], 1)
assert.Equal(t, "1", string(result.Rows[0][0]))
}
cancel()
}