Extract startOperation

pull/483/head
Jack Christensen 2019-01-02 14:56:24 -06:00
parent ad7a822723
commit 8af697bacf
1 changed files with 60 additions and 53 deletions

View File

@ -586,39 +586,6 @@ func (pgConn *PgConn) flush() error {
return err
}
// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from
// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to
// call multiple times.
func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) {
if ctx.Done() != nil {
deadlineWasSet := false
doneChan := make(chan struct{})
go func() {
select {
case <-ctx.Done():
conn.SetDeadline(deadlineTime)
deadlineWasSet = true
<-doneChan
// TODO
case <-doneChan:
}
}()
finished := false
return func() {
if !finished {
doneChan <- struct{}{}
if deadlineWasSet {
conn.SetDeadline(time.Time{})
}
finished = true
}
}
}
return func() {}
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
@ -670,6 +637,54 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool {
return true
}
// startOperation gets the connection ready for a new operation. It should be called at the beginning of every public
// method that communicates with the server. The returned cleanup function must be called if err == nil or a goroutine may
// be leaked. The cleanup function is safe to call multiple times.
func (pgConn *PgConn) startOperation(ctx context.Context) (cleanup func(), err error) {
cleanup = contextDoneToConnDeadline(ctx, pgConn.conn)
err = pgConn.ensureReadyForQuery()
if err != nil {
cleanup()
return cleanup, preferContextOverNetTimeoutError(ctx, err)
}
return cleanup, nil
}
// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from
// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to
// call multiple times.
func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) {
if ctx.Done() != nil {
deadlineWasSet := false
doneChan := make(chan struct{})
go func() {
select {
case <-ctx.Done():
conn.SetDeadline(deadlineTime)
deadlineWasSet = true
<-doneChan
// TODO
case <-doneChan:
}
}()
finished := false
return func() {
if !finished {
doneChan <- struct{}{}
if deadlineWasSet {
conn.SetDeadline(time.Time{})
}
finished = true
}
}
}
return func() {}
}
// ensureReadyForQuery reads until pendingReadyForQueryCount == 0.
func (pgConn *PgConn) ensureReadyForQuery() error {
for pgConn.pendingReadyForQueryCount > 0 {
@ -706,13 +721,11 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) {
return nil, errors.New("unflushed previous sends")
}
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanup()
err := pgConn.ensureReadyForQuery()
cleanup, err := pgConn.startOperation(ctx)
if err != nil {
return nil, preferContextOverNetTimeoutError(ctx, err)
return nil, err
}
defer cleanup()
pgConn.SendExec(sql)
err = pgConn.flush()
@ -762,13 +775,11 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
return nil, errors.New("unflushed previous sends")
}
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanup()
err := pgConn.ensureReadyForQuery()
cleanup, err := pgConn.startOperation(ctx)
if err != nil {
return nil, preferContextOverNetTimeoutError(ctx, err)
return nil, err
}
defer cleanup()
pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats)
err = pgConn.flush()
@ -788,13 +799,11 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
return nil, errors.New("unflushed previous sends")
}
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanup()
err := pgConn.ensureReadyForQuery()
cleanup, err := pgConn.startOperation(ctx)
if err != nil {
return nil, preferContextOverNetTimeoutError(ctx, err)
return nil, err
}
defer cleanup()
pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats)
err = pgConn.flush()
@ -840,13 +849,11 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
return nil, errors.New("unflushed previous sends")
}
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanup()
err := pgConn.ensureReadyForQuery()
cleanup, err := pgConn.startOperation(ctx)
if err != nil {
return nil, preferContextOverNetTimeoutError(ctx, err)
return nil, err
}
defer cleanup()
pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf)
pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf)