diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 9661f99e..e22a0de8 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -586,39 +586,6 @@ func (pgConn *PgConn) flush() error { return err } -// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from -// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to -// call multiple times. -func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) { - if ctx.Done() != nil { - deadlineWasSet := false - doneChan := make(chan struct{}) - go func() { - select { - case <-ctx.Done(): - conn.SetDeadline(deadlineTime) - deadlineWasSet = true - <-doneChan - // TODO - case <-doneChan: - } - }() - - finished := false - return func() { - if !finished { - doneChan <- struct{}{} - if deadlineWasSet { - conn.SetDeadline(time.Time{}) - } - finished = true - } - } - } - - return func() {} -} - // preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == // true. Otherwise returns err. func preferContextOverNetTimeoutError(ctx context.Context, err error) error { @@ -670,6 +637,54 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { return true } +// startOperation gets the connection ready for a new operation. It should be called at the beginning of every public +// method that communicates with the server. The returned cleanup function must be called if err == nil or a goroutine may +// be leaked. The cleanup function is safe to call multiple times. +func (pgConn *PgConn) startOperation(ctx context.Context) (cleanup func(), err error) { + cleanup = contextDoneToConnDeadline(ctx, pgConn.conn) + + err = pgConn.ensureReadyForQuery() + if err != nil { + cleanup() + return cleanup, preferContextOverNetTimeoutError(ctx, err) + } + + return cleanup, nil +} + +// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from +// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to +// call multiple times. +func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) { + if ctx.Done() != nil { + deadlineWasSet := false + doneChan := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + conn.SetDeadline(deadlineTime) + deadlineWasSet = true + <-doneChan + // TODO + case <-doneChan: + } + }() + + finished := false + return func() { + if !finished { + doneChan <- struct{}{} + if deadlineWasSet { + conn.SetDeadline(time.Time{}) + } + finished = true + } + } + } + + return func() {} +} + // ensureReadyForQuery reads until pendingReadyForQueryCount == 0. func (pgConn *PgConn) ensureReadyForQuery() error { for pgConn.pendingReadyForQueryCount > 0 { @@ -706,13 +721,11 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { return nil, errors.New("unflushed previous sends") } - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() - - err := pgConn.ensureReadyForQuery() + cleanup, err := pgConn.startOperation(ctx) if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, err } + defer cleanup() pgConn.SendExec(sql) err = pgConn.flush() @@ -762,13 +775,11 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return nil, errors.New("unflushed previous sends") } - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() - - err := pgConn.ensureReadyForQuery() + cleanup, err := pgConn.startOperation(ctx) if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, err } + defer cleanup() pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats) err = pgConn.flush() @@ -788,13 +799,11 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return nil, errors.New("unflushed previous sends") } - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() - - err := pgConn.ensureReadyForQuery() + cleanup, err := pgConn.startOperation(ctx) if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, err } + defer cleanup() pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats) err = pgConn.flush() @@ -840,13 +849,11 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ return nil, errors.New("unflushed previous sends") } - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() - - err := pgConn.ensureReadyForQuery() + cleanup, err := pgConn.startOperation(ctx) if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, err } + defer cleanup() pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf)