diff --git a/conn.go b/conn.go index bd098646..20844e57 100644 --- a/conn.go +++ b/conn.go @@ -795,9 +795,11 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared buf = append(buf, 'S') buf = pgio.AppendInt32(buf, 4) - _, err = c.conn.Write(buf) + n, err := c.conn.Write(buf) if err != nil { - c.die(err) + if fatalWriteErr(n, err) { + c.die(err) + } return nil, err } c.readyForQuery = false @@ -1085,8 +1087,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} buf = append(buf, 'S') buf = pgio.AppendInt32(buf, 4) - _, err = c.conn.Write(buf) - if err != nil { + n, err := c.conn.Write(buf) + if err != nil && fatalWriteErr(n, err) { c.die(err) } c.readyForQuery = false @@ -1094,6 +1096,17 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} return err } +// fatalWriteError takes the response of a net.Conn.Write and determines if it is fatal +func fatalWriteErr(bytesWritten int, err error) bool { + // Partial writes break the connection + if bytesWritten > 0 { + return true + } + + netErr, is := err.(net.Error) + return !(is && netErr.Timeout()) +} + // Exec executes sql. sql can be either a prepared statement name or an SQL string. // arguments should be referenced positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { diff --git a/query.go b/query.go index 681c133b..44bf004a 100644 --- a/query.go +++ b/query.go @@ -398,16 +398,15 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, err = c.initContext(ctx) if err != nil { rows.fatal(err) - return rows, err + return rows, rows.err } err = c.sendPreparedQuery(ps, args...) if err != nil { rows.fatal(err) - err = c.termContext(err) } - return rows, err + return rows, rows.err } func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) {