diff --git a/connection.go b/connection.go index d88c0c95..f55f5602 100644 --- a/connection.go +++ b/connection.go @@ -113,6 +113,7 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) { defer func() { if c != nil && err != nil { c.conn.Close() + c.alive = false } }() @@ -121,6 +122,7 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) { c.buf = bytes.NewBuffer(make([]byte, 0, c.bufSize)) c.RuntimeParams = make(map[string]string) c.preparedStatements = make(map[string]*preparedStatement) + c.alive = true msg := newStartupMessage() msg.options["user"] = c.parameters.User @@ -144,7 +146,6 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) { } case readyForQuery: c.rxReadyForQuery(r) - c.alive = true return c, nil default: if err = c.processContextFreeMsg(t, r); err != nil { @@ -619,12 +620,6 @@ func (c *Connection) processContextFreeMsg(t byte, r *MessageReader) (err error) } func (c *Connection) rxMsg() (t byte, r *MessageReader, err error) { - defer func() { - if err != nil { - c.die(err) - } - }() - var bodySize int32 t, bodySize, err = c.rxMsgHeader() if err != nil { @@ -641,6 +636,17 @@ func (c *Connection) rxMsg() (t byte, r *MessageReader, err error) { } func (c *Connection) rxMsgHeader() (t byte, bodySize int32, err error) { + if !c.alive { + err = errors.New("Connection is dead") + return + } + + defer func() { + if err != nil { + c.die(err) + } + }() + buf := c.getBuf() if _, err = io.CopyN(buf, c.conn, 5); err != nil { return 0, 0, err @@ -656,6 +662,17 @@ func (c *Connection) rxMsgHeader() (t byte, bodySize int32, err error) { } func (c *Connection) rxMsgBody(bodySize int32) (buf *bytes.Buffer, err error) { + if !c.alive { + err = errors.New("Connection is dead") + return + } + + defer func() { + if err != nil { + c.die(err) + } + }() + buf = c.getBuf() _, err = io.CopyN(buf, c.conn, int64(bodySize)) return @@ -774,6 +791,10 @@ func (c *Connection) txStartupMessage(msg *startupMessage) (err error) { } func (c *Connection) txMsg(identifier byte, buf *bytes.Buffer, flush bool) (err error) { + if !c.alive { + return errors.New("Connection is dead") + } + defer func() { if err != nil { c.die(err)