Resolve race on conn.Close/die

Use sync.Mutex instead of atomic operations for clarity.
batch-wip
Jack Christensen 2017-05-21 19:35:37 -05:00
parent 8a7165dd98
commit 749fdfe7d5
1 changed files with 47 additions and 21 deletions

68
conn.go
View File

@ -17,6 +17,7 @@ import (
"regexp"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@ -102,7 +103,8 @@ type Conn struct {
poolResetCount int
preallocatedRows []Rows
status int32 // One of connStatus* constants
mux sync.Mutex
status byte // One of connStatus* constants
causeOfDeath error
readyForQuery bool // connection has received ReadyForQuery message since last query was sent
@ -267,20 +269,25 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
defer func() {
if c != nil && err != nil {
c.conn.Close()
atomic.StoreInt32(&c.status, connStatusClosed)
c.mux.Lock()
c.status = connStatusClosed
c.mux.Unlock()
}
}()
c.RuntimeParams = make(map[string]string)
c.preparedStatements = make(map[string]*PreparedStatement)
c.channels = make(map[string]struct{})
atomic.StoreInt32(&c.status, connStatusIdle)
c.lastActivityTime = time.Now()
c.cancelQueryCompleted = make(chan struct{}, 1)
c.doneChan = make(chan struct{})
c.closedChan = make(chan error)
c.wbuf = make([]byte, 0, 1024)
c.mux.Lock()
c.status = connStatusIdle
c.mux.Unlock()
if tlsConfig != nil {
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "starting TLS handshake", nil)
@ -401,19 +408,17 @@ func (c *Conn) PID() uint32 {
// Close closes a connection. It is safe to call Close on a already closed
// connection.
func (c *Conn) Close() (err error) {
for {
status := atomic.LoadInt32(&c.status)
if status < connStatusIdle {
return nil
}
if atomic.CompareAndSwapInt32(&c.status, status, connStatusClosed) {
break
}
c.mux.Lock()
defer c.mux.Unlock()
if c.status < connStatusIdle {
return nil
}
c.status = connStatusClosed
defer func() {
c.conn.Close()
c.die(errors.New("Closed"))
c.causeOfDeath = errors.New("Closed")
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "closed connection", nil)
}
@ -989,10 +994,14 @@ func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notificat
}
func (c *Conn) IsAlive() bool {
return atomic.LoadInt32(&c.status) >= connStatusIdle
c.mux.Lock()
defer c.mux.Unlock()
return c.status >= connStatusIdle
}
func (c *Conn) CauseOfDeath() error {
c.mux.Lock()
defer c.mux.Unlock()
return c.causeOfDeath
}
@ -1131,7 +1140,7 @@ func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) {
}
func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) {
if atomic.LoadInt32(&c.status) < connStatusIdle {
if !c.IsAlive() {
return nil, ErrDeadConn
}
@ -1283,23 +1292,40 @@ func (c *Conn) txPasswordMessage(password string) (err error) {
}
func (c *Conn) die(err error) {
atomic.StoreInt32(&c.status, connStatusClosed)
c.mux.Lock()
defer c.mux.Unlock()
if c.status == connStatusClosed {
return
}
c.status = connStatusClosed
c.causeOfDeath = err
c.conn.Close()
}
func (c *Conn) lock() error {
if atomic.CompareAndSwapInt32(&c.status, connStatusIdle, connStatusBusy) {
return nil
c.mux.Lock()
defer c.mux.Unlock()
if c.status != connStatusIdle {
return ErrConnBusy
}
return ErrConnBusy
c.status = connStatusBusy
return nil
}
func (c *Conn) unlock() error {
if atomic.CompareAndSwapInt32(&c.status, connStatusBusy, connStatusIdle) {
return nil
c.mux.Lock()
defer c.mux.Unlock()
if c.status != connStatusBusy {
return errors.New("unlock conn that is not busy")
}
return errors.New("unlock conn that is not busy")
c.status = connStatusIdle
return nil
}
func (c *Conn) shouldLog(lvl int) bool {