From d9522a474187818767db0f5336ff3b63bc300e43 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 20 Jun 2014 16:33:51 -0500 Subject: [PATCH] Inform database/sql when connections die --- conn.go | 8 ++++---- stdlib/sql.go | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index 7821fc67..7c6f261e 100644 --- a/conn.go +++ b/conn.go @@ -116,6 +116,7 @@ func (e ProtocolError) Error() string { } var NotificationTimeoutError = errors.New("Notification Timeout") +var DeadConnError = errors.New("Connection is dead") // Connect establishes a connection with a PostgreSQL server using config. One // of config.Socket or config.Host must be specified. config.User @@ -966,8 +967,7 @@ func (c *Conn) rxMsg() (t byte, r *MessageReader, err error) { func (c *Conn) rxMsgHeader() (t byte, bodySize int32, err error) { if !c.alive { - err = errors.New("Connection is dead") - return + return 0, 0, DeadConnError } defer func() { @@ -987,7 +987,7 @@ func (c *Conn) rxMsgHeader() (t byte, bodySize int32, err error) { func (c *Conn) rxMsgBody(bodySize int32) (*bytes.Buffer, error) { if !c.alive { - return nil, errors.New("Connection is dead") + return nil, DeadConnError } buf := c.getBuf() @@ -1135,7 +1135,7 @@ func (c *Conn) txStartupMessage(msg *startupMessage) (err error) { func (c *Conn) txMsg(identifier byte, buf *bytes.Buffer, flush bool) (err error) { if !c.alive { - return errors.New("Connection is dead") + return DeadConnError } defer func() { diff --git a/stdlib/sql.go b/stdlib/sql.go index b276861a..14f26443 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -36,6 +36,10 @@ type Conn struct { } func (c *Conn) Prepare(query string) (driver.Stmt, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + name := fmt.Sprintf("pgx_%d", c.psCount) c.psCount++ @@ -52,6 +56,10 @@ func (c *Conn) Close() error { } func (c *Conn) Begin() (driver.Tx, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + _, err := c.conn.Execute("begin") if err != nil { return nil, err @@ -74,12 +82,20 @@ func (s *Stmt) NumInput() int { } func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { + if !s.conn.IsAlive() { + return nil, driver.ErrBadConn + } + args := valueToInterface(argsV) commandTag, err := s.conn.Execute(s.ps.Name, args...) return driver.RowsAffected(commandTag.RowsAffected()), err } func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { + if !s.conn.IsAlive() { + return nil, driver.ErrBadConn + } + args := valueToInterface(argsV) rowCount := 0