diff --git a/conn.go b/conn.go index 78bdcedc..7243a4d1 100644 --- a/conn.go +++ b/conn.go @@ -18,10 +18,17 @@ import ( "regexp" "strconv" "strings" - "sync" + "sync/atomic" "time" ) +const ( + connStatusUninitialized = iota + connStatusClosed + connStatusIdle + connStatusBusy +) + // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(network, addr string) (net.Conn, error) @@ -80,12 +87,10 @@ type Conn struct { fp *fastpath pgsqlAfInet *byte pgsqlAfInet6 *byte - busy bool poolResetCount int preallocatedRows []Rows - closingLock sync.Mutex - alive bool + status int32 // One of connStatus* constants causeOfDeath error // context support @@ -252,14 +257,14 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl defer func() { if c != nil && err != nil { c.conn.Close() - c.alive = false + atomic.StoreInt32(&c.status, connStatusClosed) } }() c.RuntimeParams = make(map[string]string) c.preparedStatements = make(map[string]*PreparedStatement) c.channels = make(map[string]struct{}) - c.alive = true + atomic.StoreInt32(&c.status, connStatusIdle) c.lastActivityTime = time.Now() c.doneChan = make(chan struct{}) c.closedChan = make(chan error) @@ -399,11 +404,14 @@ func (c *Conn) loadInetConstants() error { // Close closes a connection. It is safe to call Close on a already closed // connection. func (c *Conn) Close() (err error) { - c.closingLock.Lock() - defer c.closingLock.Unlock() - - if !c.alive { - return nil + for { + status := atomic.LoadInt32(&c.status) + if status < connStatusIdle { + return nil + } + if atomic.CompareAndSwapInt32(&c.status, status, connStatusClosed) { + break + } } _, err = c.conn.Write([]byte{'X', 0, 0, 0, 4}) @@ -893,10 +901,7 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { } func (c *Conn) IsAlive() bool { - c.closingLock.Lock() - alive := c.alive - c.closingLock.Unlock() - return alive + return atomic.LoadInt32(&c.status) >= connStatusIdle } func (c *Conn) CauseOfDeath() error { @@ -1071,12 +1076,9 @@ func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { } func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { - c.closingLock.Lock() - if !c.alive { - c.closingLock.Unlock() + if atomic.LoadInt32(&c.status) < connStatusIdle { return 0, nil, ErrDeadConn } - c.closingLock.Unlock() t, err = c.mr.rxMsg() if err != nil { @@ -1261,25 +1263,23 @@ func (c *Conn) txPasswordMessage(password string) (err error) { } func (c *Conn) die(err error) { - c.alive = false + atomic.StoreInt32(&c.status, connStatusClosed) c.causeOfDeath = err c.conn.Close() } func (c *Conn) lock() error { - if c.busy { - return ErrConnBusy + if atomic.CompareAndSwapInt32(&c.status, connStatusIdle, connStatusBusy) { + return nil } - c.busy = true - return nil + return ErrConnBusy } func (c *Conn) unlock() error { - if !c.busy { - return errors.New("unlock conn that is not busy") + if atomic.CompareAndSwapInt32(&c.status, connStatusBusy, connStatusIdle) { + return nil } - c.busy = false - return nil + return errors.New("unlock conn that is not busy") } func (c *Conn) shouldLog(lvl int) bool {