From 2bf5a614018d35b5db4259b5541497ab35175882 Mon Sep 17 00:00:00 2001 From: Adrian-Stefan Mares Date: Tue, 4 Jul 2023 12:53:41 +0200 Subject: [PATCH] fix: Do not use infinite timers --- pgconn/pgconn.go | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 1fb65320..46920cb4 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -263,7 +263,8 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba } func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig, - ignoreNotPreferredErr bool) (*PgConn, error) { + ignoreNotPreferredErr bool, +) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config pgConn.cleanupDone = make(chan struct{}) @@ -298,6 +299,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.status = connStatusConnecting pgConn.bgReader = bgreader.New(pgConn.conn) pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) + pgConn.slowWriteTimer.Stop() pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn) startupMsg := pgproto3.StartupMessage{ @@ -476,7 +478,8 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa err = &pgconnError{ msg: "receive message failed", err: normalizeTimeoutError(ctx, err), - safeToRetry: true} + safeToRetry: true, + } } return msg, err } @@ -1336,7 +1339,6 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { msg, err := mrr.pgConn.receiveMessage() - if err != nil { mrr.pgConn.contextWatcher.Unwatch() mrr.err = normalizeTimeoutError(mrr.ctx, err) @@ -1647,8 +1649,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) pgConn.enterPotentialWriteReadDeadlock() + defer pgConn.exitPotentialWriteReadDeadlock() _, err := pgConn.conn.Write(batch.buf) - pgConn.exitPotentialWriteReadDeadlock() if err != nil { multiResult.closed = true multiResult.err = err @@ -1719,20 +1721,22 @@ func (pgConn *PgConn) enterPotentialWriteReadDeadlock() { // // In addition, on Windows the default timer resolution is 15.6ms. So setting the timer to less than that is // ineffective. - pgConn.slowWriteTimer.Reset(15 * time.Millisecond) + if pgConn.slowWriteTimer.Reset(15 * time.Millisecond) { + panic("BUG: slow write timer already active") + } } // exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock. func (pgConn *PgConn) exitPotentialWriteReadDeadlock() { - if !pgConn.slowWriteTimer.Reset(time.Duration(math.MaxInt64)) { - pgConn.slowWriteTimer.Stop() - } + // The state of the timer is not relevant upon exiting the potential slow write. It may both + // fire (due to a slow write), or not fire (due to a fast write). + _ = pgConn.slowWriteTimer.Stop() } func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error { pgConn.enterPotentialWriteReadDeadlock() + defer pgConn.exitPotentialWriteReadDeadlock() err := pgConn.frontend.Flush() - pgConn.exitPotentialWriteReadDeadlock() return err } @@ -1796,6 +1800,7 @@ func Construct(hc *HijackedConn) (*PgConn, error) { pgConn.contextWatcher = newContextWatcher(pgConn.conn) pgConn.bgReader = bgreader.New(pgConn.conn) pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) + pgConn.slowWriteTimer.Stop() return pgConn, nil } @@ -1997,7 +2002,6 @@ func (p *Pipeline) GetResults() (results any, err error) { } } - } func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {