diff --git a/conn.go b/conn.go index 210f57e3..da854592 100644 --- a/conn.go +++ b/conn.go @@ -494,7 +494,7 @@ func (c *Conn) Close() error { } c.status = connStatusClosed - err := c.pgConn.Close() + err := c.pgConn.Close(context.TODO()) c.causeOfDeath = errors.New("Closed") if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "closed connection", nil) diff --git a/pgconn/config.go b/pgconn/config.go index 4d8bee4c..d8872f66 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -1,6 +1,7 @@ package pgconn import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -20,7 +21,7 @@ import ( "github.com/pkg/errors" ) -type AfterConnectFunc func(pgconn *PgConn) error +type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { @@ -466,8 +467,8 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { // AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible // target_session_attrs=read-write. -func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error { - result, err := pgConn.Exec("show transaction_read_only") +func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { + result, err := pgConn.Exec(ctx, "show transaction_read_only") if err != nil { return err } diff --git a/pgconn/helper_test.go b/pgconn/helper_test.go index e6a7c73b..8e7ca92f 100644 --- a/pgconn/helper_test.go +++ b/pgconn/helper_test.go @@ -1,7 +1,9 @@ package pgconn_test import ( + "context" "testing" + "time" "github.com/jackc/pgx/pgconn" @@ -9,5 +11,7 @@ import ( ) func closeConn(t testing.TB, conn *pgconn.PgConn) { - require.Nil(t, conn.Close()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.Nil(t, conn.Close(ctx)) } diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index c243d2f6..311b06a3 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -12,6 +12,7 @@ import ( "net" "strconv" "strings" + "time" "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" @@ -19,6 +20,8 @@ import ( const batchBufferSize = 4096 +var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) + // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. @@ -185,7 +188,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } case *pgproto3.ReadyForQuery: if config.AfterConnectFunc != nil { - err := config.AfterConnectFunc(pgConn) + err := config.AfterConnectFunc(ctx, pgConn) if err != nil { pgConn.NetConn.Close() return nil, fmt.Errorf("AfterConnectFunc: %v", err) @@ -296,24 +299,28 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { return msg, nil } -// Close closes a connection. It is safe to call Close on a already closed -// connection. -func (pgConn *PgConn) Close() error { +// Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by +// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The +// underlying net.Conn.Close() will always be called regardless of any other errors. +func (pgConn *PgConn) Close(ctx context.Context) error { if pgConn.closed { return nil } pgConn.closed = true + defer pgConn.NetConn.Close() + + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + defer cleanupContext() + _, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { - pgConn.NetConn.Close() - return err + return preferContextOverNetTimeoutError(ctx, err) } _, err = pgConn.NetConn.Read(make([]byte, 1)) if err != io.EOF { - pgConn.NetConn.Close() - return err + return preferContextOverNetTimeoutError(ctx, err) } return pgConn.NetConn.Close() @@ -365,30 +372,38 @@ type PgResultReader struct { err error complete bool preloadedRowValues bool + ctx context.Context + cleanupContext func() } // GetResult returns a PgResultReader for the next result. If all results are // consumed it returns nil. If an error occurs it will be reported on the // returned PgResultReader. -func (pgConn *PgConn) GetResult() *PgResultReader { +func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + for pgConn.pendingReadyForQueryCount > 0 { msg, err := pgConn.ReceiveMessage() if err != nil { - return &PgResultReader{pgConn: pgConn, err: err, complete: true} + cleanupContext() + return &PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true} } switch msg := msg.(type) { case *pgproto3.RowDescription: - return &PgResultReader{pgConn: pgConn, fieldDescriptions: msg.Fields} + return &PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields} case *pgproto3.DataRow: - return &PgResultReader{pgConn: pgConn, rowValues: msg.Values, preloadedRowValues: true} + return &PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true} case *pgproto3.CommandComplete: - return &PgResultReader{pgConn: pgConn, commandTag: CommandTag(msg.CommandTag), complete: true} + cleanupContext() + return &PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} case *pgproto3.ErrorResponse: - return &PgResultReader{pgConn: pgConn, err: errorResponseToPgError(msg), complete: true} + cleanupContext() + return &PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} } } + cleanupContext() return nil } @@ -406,6 +421,8 @@ func (rr *PgResultReader) NextRow() bool { for { msg, err := rr.pgConn.ReceiveMessage() if err != nil { + rr.err = preferContextOverNetTimeoutError(rr.ctx, err) + rr.close() return false } @@ -416,13 +433,12 @@ func (rr *PgResultReader) NextRow() bool { rr.rowValues = msg.Values return true case *pgproto3.CommandComplete: - rr.rowValues = nil rr.commandTag = CommandTag(msg.CommandTag) - rr.complete = true + rr.close() return false case *pgproto3.ErrorResponse: rr.err = errorResponseToPgError(msg) - rr.complete = true + rr.close() return false } } @@ -441,46 +457,137 @@ func (rr *PgResultReader) Close() (CommandTag, error) { if rr.complete { return rr.commandTag, rr.err } - - rr.rowValues = nil + defer rr.close() for { msg, err := rr.pgConn.ReceiveMessage() if err != nil { - rr.err = err - rr.complete = true + rr.err = preferContextOverNetTimeoutError(rr.ctx, err) return rr.commandTag, rr.err } switch msg := msg.(type) { case *pgproto3.CommandComplete: rr.commandTag = CommandTag(msg.CommandTag) - rr.complete = true return rr.commandTag, rr.err case *pgproto3.ErrorResponse: rr.err = errorResponseToPgError(msg) - rr.complete = true return rr.commandTag, rr.err } } } +func (rr *PgResultReader) close() { + if rr.complete { + return + } + + rr.cleanupContext() + rr.rowValues = nil + rr.complete = true +} + // Flush sends the enqueued execs to the server. -func (pgConn *PgConn) Flush() error { +func (pgConn *PgConn) Flush(ctx context.Context) error { defer pgConn.resetBatch() + cleanup := contextDoneToConnDeadline(ctx, pgConn.NetConn) + defer cleanup() + n, err := pgConn.NetConn.Write(pgConn.batchBuf) if err != nil { if n > 0 { - // TODO - kill connection - we sent a partial message + // Close connection because cannot recover from partially sent message. + pgConn.NetConn.Close() + pgConn.closed = true } - return err + return preferContextOverNetTimeoutError(ctx, err) } pgConn.pendingReadyForQueryCount += pgConn.batchCount return nil } +// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from +// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to +// call multiple times. +func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) { + if ctx.Done() != nil { + deadlineWasSet := false + doneChan := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + conn.SetDeadline(deadlineTime) + deadlineWasSet = true + <-doneChan + // TODO + case <-doneChan: + } + }() + + finished := false + return func() { + if !finished { + doneChan <- struct{}{} + if deadlineWasSet { + conn.SetDeadline(time.Time{}) + } + finished = true + } + } + } + + return func() {} +} + +// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == +// true. Otherwise returns err. +func preferContextOverNetTimeoutError(ctx context.Context, err error) error { + if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { + return ctx.Err() + } + return err +} + +// RecoverFromTimeout attempts to recover from a timeout error such as is caused by a canceled context. If recovery is +// successful true is returned. If recovery is not successful the connection is closed and false it returned. Recovery +// should usually be possible except in the case of a partial write. This must be called after any context cancellation. +// +// As RecoverFromTimeout may need to read and ignored data already sent from the server, it potentially can block +// indefinitely. Use ctx to guard against this. +func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { + if pgConn.closed { + return false + } + pgConn.resetBatch() + + pgConn.NetConn.SetDeadline(time.Time{}) + + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + defer cleanupContext() + + for pgConn.pendingReadyForQueryCount > 0 { + _, err := pgConn.ReceiveMessage() + if err != nil { + preferContextOverNetTimeoutError(ctx, err) + pgConn.Close(context.Background()) + return false + } + } + + result, err := pgConn.Exec( + context.Background(), // do not use ctx again because deadline goroutine already started above + "select 'RecoverFromTimeout'", + ) + if err != nil || len(result.Rows) != 1 || len(result.Rows[0]) != 1 || string(result.Rows[0][0]) != "RecoverFromTimeout" { + pgConn.Close(context.Background()) + return false + } + + return true +} + func (pgConn *PgConn) resetBatch() { pgConn.batchCount = 0 if len(pgConn.batchBuf) > batchBufferSize { @@ -500,7 +607,7 @@ type PgResult struct { // transactions unless a transaction is already in progress or sql contains transaction control statements. // // Exec must not be called when there are pending results from previous Send* methods (e.g. SendExec). -func (pgConn *PgConn) Exec(sql string) (*PgResult, error) { +func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } @@ -509,14 +616,14 @@ func (pgConn *PgConn) Exec(sql string) (*PgResult, error) { } pgConn.SendExec(sql) - err := pgConn.Flush() + err := pgConn.Flush(ctx) if err != nil { return nil, err } var result *PgResult - for resultReader := pgConn.GetResult(); resultReader != nil; resultReader = pgConn.GetResult() { + for resultReader := pgConn.GetResult(ctx); resultReader != nil; resultReader = pgConn.GetResult(ctx) { rows := [][][]byte{} for resultReader.NextRow() { row := make([][]byte, len(resultReader.Values())) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index f3f22d42..98fd198e 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -6,6 +6,7 @@ import ( "net" "os" "testing" + "time" "github.com/jackc/pgx" "github.com/jackc/pgx/pgconn" @@ -36,8 +37,7 @@ func TestConnect(t *testing.T) { conn, err := pgconn.Connect(context.Background(), connString) require.Nil(t, err) - err = conn.Close() - require.Nil(t, err) + closeConn(t, conn) }) } } @@ -57,8 +57,7 @@ func TestConnectTLS(t *testing.T) { t.Error("not a TLS connection") } - err = conn.Close() - require.Nil(t, err) + closeConn(t, conn) } func TestConnectInvalidUser(t *testing.T) { @@ -74,7 +73,7 @@ func TestConnectInvalidUser(t *testing.T) { conn, err := pgconn.ConnectConfig(context.Background(), config) if err == nil { - conn.Close() + conn.Close(context.Background()) t.Fatal("expected err but got none") } pgErr, ok := err.(pgx.PgError) @@ -92,7 +91,7 @@ func TestConnectWithConnectionRefused(t *testing.T) { // Presumably nothing is listening on 127.0.0.1:1 conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") if err == nil { - conn.Close() + conn.Close(context.Background()) t.Fatal("Expected error establishing connection to bad port") } } @@ -110,7 +109,7 @@ func TestConnectCustomDialer(t *testing.T) { conn, err := pgconn.ConnectConfig(context.Background(), config) require.Nil(t, err) require.True(t, dialed) - conn.Close() + closeConn(t, conn) } func TestConnectWithRuntimeParams(t *testing.T) { @@ -126,12 +125,12 @@ func TestConnectWithRuntimeParams(t *testing.T) { require.Nil(t, err) defer closeConn(t, conn) - result, err := conn.Exec("show application_name") + result, err := conn.Exec(context.Background(), "show application_name") require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result, err = conn.Exec("show search_path") + result, err = conn.Exec(context.Background(), "show search_path") require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "myschema", string(result.Rows[0][0])) @@ -179,7 +178,7 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { } acceptConnCount := 0 - config.AfterConnectFunc = func(conn *pgconn.PgConn) error { + config.AfterConnectFunc = func(ctx context.Context, conn *pgconn.PgConn) error { acceptConnCount += 1 if acceptConnCount < 2 { return errors.New("reject first conn") @@ -214,38 +213,38 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { conn, err := pgconn.ConnectConfig(context.Background(), config) if !assert.NotNil(t, err) { - conn.Close() + conn.Close(context.Background()) } } -func TestExec(t *testing.T) { +func TestConnExec(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec("select current_database()") + result, err := pgConn.Exec(context.Background(), "select current_database()") require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) } -func TestExecMultipleQueries(t *testing.T) { +func TestConnExecMultipleQueries(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec("select current_database(); select 1") + result, err := pgConn.Exec(context.Background(), "select current_database(); select 1") require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "1", string(result.Rows[0][0])) } -func TestExecMultipleQueriesError(t *testing.T) { +func TestConnExecMultipleQueriesError(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec("select 1; select 1/0; select 1") + result, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1") require.NotNil(t, err) require.Nil(t, result) if pgErr, ok := err.(pgconn.PgError); ok { @@ -254,3 +253,37 @@ func TestExecMultipleQueriesError(t *testing.T) { t.Errorf("unexpected error: %v", err) } } + +func TestConnExecContextCanceled(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") + require.Nil(t, result) + assert.Equal(t, context.DeadlineExceeded, err) +} + +func TestConnRecoverFromTimeout(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") + cancel() + require.Nil(t, result) + assert.Equal(t, context.DeadlineExceeded, err) + + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + if assert.True(t, pgConn.RecoverFromTimeout(ctx)) { + result, err := pgConn.Exec(ctx, "select 1") + require.Nil(t, err) + assert.Len(t, result.Rows, 1) + assert.Len(t, result.Rows[0], 1) + assert.Equal(t, "1", string(result.Rows[0][0])) + } + cancel() +}