diff --git a/connection.go b/connection.go index f30d3644..4dba463d 100644 --- a/connection.go +++ b/connection.go @@ -106,7 +106,7 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) { } func (c *Connection) Close() (err error) { - return c.txMsg('X', c.getBuf(0)) + return c.txMsg('X', c.getBuf()) } func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error) (err error) { @@ -212,8 +212,7 @@ func (c *Connection) SelectValues(sql string) (values []interface{}, err error) } func (c *Connection) sendSimpleQuery(sql string) (err error) { - bufSize := len(sql) + 1 // sql, null string terminator (1) - buf := c.getBuf(bufSize) + buf := c.getBuf() _, err = buf.WriteString(sql) if err != nil { @@ -293,7 +292,7 @@ func (c *Connection) rxMsg() (t byte, r *MessageReader, err error) { } func (c *Connection) rxMsgHeader() (t byte, bodySize int32, err error) { - buf := c.getBuf(5) + buf := c.getBuf() if _, err = io.CopyN(buf, c.conn, 5); err != nil { return 0, 0, err } @@ -308,7 +307,7 @@ func (c *Connection) rxMsgHeader() (t byte, bodySize int32, err error) { } func (c *Connection) rxMsgBody(bodySize int32) (buf *bytes.Buffer, err error) { - buf = c.getBuf(int(bodySize)) + buf = c.getBuf() _, err = io.CopyN(buf, c.conn, int64(bodySize)) return } @@ -421,8 +420,7 @@ func (c *Connection) txMsg(identifier byte, buf *bytes.Buffer) (err error) { } func (c *Connection) txPasswordMessage(password string) (err error) { - bufSize := len(password) + 1 // password, null string terminator (1) - buf := c.getBuf(bufSize) + buf := c.getBuf() _, err = buf.WriteString(password) if err != nil { @@ -436,12 +434,13 @@ func (c *Connection) txPasswordMessage(password string) (err error) { return } -// Gets a buffer of up to n bytes. If it is a large request it will return a new buffer, if it is small enough it will return the shared connection buffer -func (c *Connection) getBuf(n int) *bytes.Buffer { - if n < sharedBufferSize { - c.buf.Reset() - return c.buf - } else { - return bytes.NewBuffer(make([]byte, n)) +// Gets the shared connection buffer. Since bytes.Buffer never releases memory from +// its internal byte array, check on the size and create a new bytes.Buffer so the +// old one can get GC'ed +func (c *Connection) getBuf() *bytes.Buffer { + c.buf.Reset() + if cap(c.buf.Bytes()) > sharedBufferSize { + c.buf = bytes.NewBuffer(make([]byte, sharedBufferSize)) } + return c.buf }