Add context to potentially blocking methods

This commit is contained in:
Jack Christensen 2018-12-31 17:17:11 -06:00
parent d229219039
commit 084423ae69
5 changed files with 196 additions and 51 deletions

View File

@ -494,7 +494,7 @@ func (c *Conn) Close() error {
} }
c.status = connStatusClosed c.status = connStatusClosed
err := c.pgConn.Close() err := c.pgConn.Close(context.TODO())
c.causeOfDeath = errors.New("Closed") c.causeOfDeath = errors.New("Closed")
if c.shouldLog(LogLevelInfo) { if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "closed connection", nil) c.log(LogLevelInfo, "closed connection", nil)

View File

@ -1,6 +1,7 @@
package pgconn package pgconn
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
@ -20,7 +21,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
type AfterConnectFunc func(pgconn *PgConn) error type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
// Config is the settings used to establish a connection to a PostgreSQL server. // Config is the settings used to establish a connection to a PostgreSQL server.
type Config struct { type Config struct {
@ -466,8 +467,8 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible // AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible
// target_session_attrs=read-write. // target_session_attrs=read-write.
func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error { func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
result, err := pgConn.Exec("show transaction_read_only") result, err := pgConn.Exec(ctx, "show transaction_read_only")
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,7 +1,9 @@
package pgconn_test package pgconn_test
import ( import (
"context"
"testing" "testing"
"time"
"github.com/jackc/pgx/pgconn" "github.com/jackc/pgx/pgconn"
@ -9,5 +11,7 @@ import (
) )
func closeConn(t testing.TB, conn *pgconn.PgConn) { func closeConn(t testing.TB, conn *pgconn.PgConn) {
require.Nil(t, conn.Close()) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
require.Nil(t, conn.Close(ctx))
} }

View File

@ -12,6 +12,7 @@ import (
"net" "net"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgproto3"
@ -19,6 +20,8 @@ import (
const batchBufferSize = 4096 const batchBufferSize = 4096
var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)
// PgError represents an error reported by the PostgreSQL server. See // PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description. // detailed field description.
@ -185,7 +188,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
} }
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
if config.AfterConnectFunc != nil { if config.AfterConnectFunc != nil {
err := config.AfterConnectFunc(pgConn) err := config.AfterConnectFunc(ctx, pgConn)
if err != nil { if err != nil {
pgConn.NetConn.Close() pgConn.NetConn.Close()
return nil, fmt.Errorf("AfterConnectFunc: %v", err) return nil, fmt.Errorf("AfterConnectFunc: %v", err)
@ -296,24 +299,28 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
return msg, nil return msg, nil
} }
// Close closes a connection. It is safe to call Close on a already closed // Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by
// connection. // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The
func (pgConn *PgConn) Close() error { // underlying net.Conn.Close() will always be called regardless of any other errors.
func (pgConn *PgConn) Close(ctx context.Context) error {
if pgConn.closed { if pgConn.closed {
return nil return nil
} }
pgConn.closed = true pgConn.closed = true
defer pgConn.NetConn.Close()
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn)
defer cleanupContext()
_, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) _, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4})
if err != nil { if err != nil {
pgConn.NetConn.Close() return preferContextOverNetTimeoutError(ctx, err)
return err
} }
_, err = pgConn.NetConn.Read(make([]byte, 1)) _, err = pgConn.NetConn.Read(make([]byte, 1))
if err != io.EOF { if err != io.EOF {
pgConn.NetConn.Close() return preferContextOverNetTimeoutError(ctx, err)
return err
} }
return pgConn.NetConn.Close() return pgConn.NetConn.Close()
@ -365,30 +372,38 @@ type PgResultReader struct {
err error err error
complete bool complete bool
preloadedRowValues bool preloadedRowValues bool
ctx context.Context
cleanupContext func()
} }
// GetResult returns a PgResultReader for the next result. If all results are // GetResult returns a PgResultReader for the next result. If all results are
// consumed it returns nil. If an error occurs it will be reported on the // consumed it returns nil. If an error occurs it will be reported on the
// returned PgResultReader. // returned PgResultReader.
func (pgConn *PgConn) GetResult() *PgResultReader { func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader {
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn)
for pgConn.pendingReadyForQueryCount > 0 { for pgConn.pendingReadyForQueryCount > 0 {
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.ReceiveMessage()
if err != nil { if err != nil {
return &PgResultReader{pgConn: pgConn, err: err, complete: true} cleanupContext()
return &PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true}
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.RowDescription: case *pgproto3.RowDescription:
return &PgResultReader{pgConn: pgConn, fieldDescriptions: msg.Fields} return &PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields}
case *pgproto3.DataRow: case *pgproto3.DataRow:
return &PgResultReader{pgConn: pgConn, rowValues: msg.Values, preloadedRowValues: true} return &PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true}
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
return &PgResultReader{pgConn: pgConn, commandTag: CommandTag(msg.CommandTag), complete: true} cleanupContext()
return &PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true}
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
return &PgResultReader{pgConn: pgConn, err: errorResponseToPgError(msg), complete: true} cleanupContext()
return &PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true}
} }
} }
cleanupContext()
return nil return nil
} }
@ -406,6 +421,8 @@ func (rr *PgResultReader) NextRow() bool {
for { for {
msg, err := rr.pgConn.ReceiveMessage() msg, err := rr.pgConn.ReceiveMessage()
if err != nil { if err != nil {
rr.err = preferContextOverNetTimeoutError(rr.ctx, err)
rr.close()
return false return false
} }
@ -416,13 +433,12 @@ func (rr *PgResultReader) NextRow() bool {
rr.rowValues = msg.Values rr.rowValues = msg.Values
return true return true
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
rr.rowValues = nil
rr.commandTag = CommandTag(msg.CommandTag) rr.commandTag = CommandTag(msg.CommandTag)
rr.complete = true rr.close()
return false return false
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
rr.err = errorResponseToPgError(msg) rr.err = errorResponseToPgError(msg)
rr.complete = true rr.close()
return false return false
} }
} }
@ -441,46 +457,137 @@ func (rr *PgResultReader) Close() (CommandTag, error) {
if rr.complete { if rr.complete {
return rr.commandTag, rr.err return rr.commandTag, rr.err
} }
defer rr.close()
rr.rowValues = nil
for { for {
msg, err := rr.pgConn.ReceiveMessage() msg, err := rr.pgConn.ReceiveMessage()
if err != nil { if err != nil {
rr.err = err rr.err = preferContextOverNetTimeoutError(rr.ctx, err)
rr.complete = true
return rr.commandTag, rr.err return rr.commandTag, rr.err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
rr.commandTag = CommandTag(msg.CommandTag) rr.commandTag = CommandTag(msg.CommandTag)
rr.complete = true
return rr.commandTag, rr.err return rr.commandTag, rr.err
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
rr.err = errorResponseToPgError(msg) rr.err = errorResponseToPgError(msg)
rr.complete = true
return rr.commandTag, rr.err return rr.commandTag, rr.err
} }
} }
} }
func (rr *PgResultReader) close() {
if rr.complete {
return
}
rr.cleanupContext()
rr.rowValues = nil
rr.complete = true
}
// Flush sends the enqueued execs to the server. // Flush sends the enqueued execs to the server.
func (pgConn *PgConn) Flush() error { func (pgConn *PgConn) Flush(ctx context.Context) error {
defer pgConn.resetBatch() defer pgConn.resetBatch()
cleanup := contextDoneToConnDeadline(ctx, pgConn.NetConn)
defer cleanup()
n, err := pgConn.NetConn.Write(pgConn.batchBuf) n, err := pgConn.NetConn.Write(pgConn.batchBuf)
if err != nil { if err != nil {
if n > 0 { if n > 0 {
// TODO - kill connection - we sent a partial message // Close connection because cannot recover from partially sent message.
pgConn.NetConn.Close()
pgConn.closed = true
} }
return err return preferContextOverNetTimeoutError(ctx, err)
} }
pgConn.pendingReadyForQueryCount += pgConn.batchCount pgConn.pendingReadyForQueryCount += pgConn.batchCount
return nil return nil
} }
// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from
// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to
// call multiple times.
func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) {
if ctx.Done() != nil {
deadlineWasSet := false
doneChan := make(chan struct{})
go func() {
select {
case <-ctx.Done():
conn.SetDeadline(deadlineTime)
deadlineWasSet = true
<-doneChan
// TODO
case <-doneChan:
}
}()
finished := false
return func() {
if !finished {
doneChan <- struct{}{}
if deadlineWasSet {
conn.SetDeadline(time.Time{})
}
finished = true
}
}
}
return func() {}
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return ctx.Err()
}
return err
}
// RecoverFromTimeout attempts to recover from a timeout error such as is caused by a canceled context. If recovery is
// successful true is returned. If recovery is not successful the connection is closed and false it returned. Recovery
// should usually be possible except in the case of a partial write. This must be called after any context cancellation.
//
// As RecoverFromTimeout may need to read and ignored data already sent from the server, it potentially can block
// indefinitely. Use ctx to guard against this.
func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool {
if pgConn.closed {
return false
}
pgConn.resetBatch()
pgConn.NetConn.SetDeadline(time.Time{})
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn)
defer cleanupContext()
for pgConn.pendingReadyForQueryCount > 0 {
_, err := pgConn.ReceiveMessage()
if err != nil {
preferContextOverNetTimeoutError(ctx, err)
pgConn.Close(context.Background())
return false
}
}
result, err := pgConn.Exec(
context.Background(), // do not use ctx again because deadline goroutine already started above
"select 'RecoverFromTimeout'",
)
if err != nil || len(result.Rows) != 1 || len(result.Rows[0]) != 1 || string(result.Rows[0][0]) != "RecoverFromTimeout" {
pgConn.Close(context.Background())
return false
}
return true
}
func (pgConn *PgConn) resetBatch() { func (pgConn *PgConn) resetBatch() {
pgConn.batchCount = 0 pgConn.batchCount = 0
if len(pgConn.batchBuf) > batchBufferSize { if len(pgConn.batchBuf) > batchBufferSize {
@ -500,7 +607,7 @@ type PgResult struct {
// transactions unless a transaction is already in progress or sql contains transaction control statements. // transactions unless a transaction is already in progress or sql contains transaction control statements.
// //
// Exec must not be called when there are pending results from previous Send* methods (e.g. SendExec). // Exec must not be called when there are pending results from previous Send* methods (e.g. SendExec).
func (pgConn *PgConn) Exec(sql string) (*PgResult, error) { func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) {
if pgConn.batchCount != 0 { if pgConn.batchCount != 0 {
return nil, errors.New("unflushed previous sends") return nil, errors.New("unflushed previous sends")
} }
@ -509,14 +616,14 @@ func (pgConn *PgConn) Exec(sql string) (*PgResult, error) {
} }
pgConn.SendExec(sql) pgConn.SendExec(sql)
err := pgConn.Flush() err := pgConn.Flush(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var result *PgResult var result *PgResult
for resultReader := pgConn.GetResult(); resultReader != nil; resultReader = pgConn.GetResult() { for resultReader := pgConn.GetResult(ctx); resultReader != nil; resultReader = pgConn.GetResult(ctx) {
rows := [][][]byte{} rows := [][][]byte{}
for resultReader.NextRow() { for resultReader.NextRow() {
row := make([][]byte, len(resultReader.Values())) row := make([][]byte, len(resultReader.Values()))

View File

@ -6,6 +6,7 @@ import (
"net" "net"
"os" "os"
"testing" "testing"
"time"
"github.com/jackc/pgx" "github.com/jackc/pgx"
"github.com/jackc/pgx/pgconn" "github.com/jackc/pgx/pgconn"
@ -36,8 +37,7 @@ func TestConnect(t *testing.T) {
conn, err := pgconn.Connect(context.Background(), connString) conn, err := pgconn.Connect(context.Background(), connString)
require.Nil(t, err) require.Nil(t, err)
err = conn.Close() closeConn(t, conn)
require.Nil(t, err)
}) })
} }
} }
@ -57,8 +57,7 @@ func TestConnectTLS(t *testing.T) {
t.Error("not a TLS connection") t.Error("not a TLS connection")
} }
err = conn.Close() closeConn(t, conn)
require.Nil(t, err)
} }
func TestConnectInvalidUser(t *testing.T) { func TestConnectInvalidUser(t *testing.T) {
@ -74,7 +73,7 @@ func TestConnectInvalidUser(t *testing.T) {
conn, err := pgconn.ConnectConfig(context.Background(), config) conn, err := pgconn.ConnectConfig(context.Background(), config)
if err == nil { if err == nil {
conn.Close() conn.Close(context.Background())
t.Fatal("expected err but got none") t.Fatal("expected err but got none")
} }
pgErr, ok := err.(pgx.PgError) pgErr, ok := err.(pgx.PgError)
@ -92,7 +91,7 @@ func TestConnectWithConnectionRefused(t *testing.T) {
// Presumably nothing is listening on 127.0.0.1:1 // Presumably nothing is listening on 127.0.0.1:1
conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1")
if err == nil { if err == nil {
conn.Close() conn.Close(context.Background())
t.Fatal("Expected error establishing connection to bad port") t.Fatal("Expected error establishing connection to bad port")
} }
} }
@ -110,7 +109,7 @@ func TestConnectCustomDialer(t *testing.T) {
conn, err := pgconn.ConnectConfig(context.Background(), config) conn, err := pgconn.ConnectConfig(context.Background(), config)
require.Nil(t, err) require.Nil(t, err)
require.True(t, dialed) require.True(t, dialed)
conn.Close() closeConn(t, conn)
} }
func TestConnectWithRuntimeParams(t *testing.T) { func TestConnectWithRuntimeParams(t *testing.T) {
@ -126,12 +125,12 @@ func TestConnectWithRuntimeParams(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
defer closeConn(t, conn) defer closeConn(t, conn)
result, err := conn.Exec("show application_name") result, err := conn.Exec(context.Background(), "show application_name")
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, "pgxtest", string(result.Rows[0][0])) assert.Equal(t, "pgxtest", string(result.Rows[0][0]))
result, err = conn.Exec("show search_path") result, err = conn.Exec(context.Background(), "show search_path")
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, "myschema", string(result.Rows[0][0])) assert.Equal(t, "myschema", string(result.Rows[0][0]))
@ -179,7 +178,7 @@ func TestConnectWithAfterConnectFunc(t *testing.T) {
} }
acceptConnCount := 0 acceptConnCount := 0
config.AfterConnectFunc = func(conn *pgconn.PgConn) error { config.AfterConnectFunc = func(ctx context.Context, conn *pgconn.PgConn) error {
acceptConnCount += 1 acceptConnCount += 1
if acceptConnCount < 2 { if acceptConnCount < 2 {
return errors.New("reject first conn") return errors.New("reject first conn")
@ -214,38 +213,38 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) {
conn, err := pgconn.ConnectConfig(context.Background(), config) conn, err := pgconn.ConnectConfig(context.Background(), config)
if !assert.NotNil(t, err) { if !assert.NotNil(t, err) {
conn.Close() conn.Close(context.Background())
} }
} }
func TestExec(t *testing.T) { func TestConnExec(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err) require.Nil(t, err)
defer closeConn(t, pgConn) defer closeConn(t, pgConn)
result, err := pgConn.Exec("select current_database()") result, err := pgConn.Exec(context.Background(), "select current_database()")
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0]))
} }
func TestExecMultipleQueries(t *testing.T) { func TestConnExecMultipleQueries(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err) require.Nil(t, err)
defer closeConn(t, pgConn) defer closeConn(t, pgConn)
result, err := pgConn.Exec("select current_database(); select 1") result, err := pgConn.Exec(context.Background(), "select current_database(); select 1")
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, "1", string(result.Rows[0][0])) assert.Equal(t, "1", string(result.Rows[0][0]))
} }
func TestExecMultipleQueriesError(t *testing.T) { func TestConnExecMultipleQueriesError(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err) require.Nil(t, err)
defer closeConn(t, pgConn) defer closeConn(t, pgConn)
result, err := pgConn.Exec("select 1; select 1/0; select 1") result, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1")
require.NotNil(t, err) require.NotNil(t, err)
require.Nil(t, result) require.Nil(t, result)
if pgErr, ok := err.(pgconn.PgError); ok { if pgErr, ok := err.(pgconn.PgError); ok {
@ -254,3 +253,37 @@ func TestExecMultipleQueriesError(t *testing.T) {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
} }
func TestConnExecContextCanceled(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)")
require.Nil(t, result)
assert.Equal(t, context.DeadlineExceeded, err)
}
func TestConnRecoverFromTimeout(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)")
cancel()
require.Nil(t, result)
assert.Equal(t, context.DeadlineExceeded, err)
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
if assert.True(t, pgConn.RecoverFromTimeout(ctx)) {
result, err := pgConn.Exec(ctx, "select 1")
require.Nil(t, err)
assert.Len(t, result.Rows, 1)
assert.Len(t, result.Rows[0], 1)
assert.Equal(t, "1", string(result.Rows[0][0]))
}
cancel()
}