Buffer writes to the PostgreSQL socket

Avoid sending a packet for each write
This commit is contained in:
Jack Christensen 2013-07-26 14:48:41 -05:00
parent ad34dc7264
commit 3873a83a0a

View File

@ -4,6 +4,7 @@
package pgx
import (
"bufio"
"bytes"
"crypto/md5"
"encoding/binary"
@ -30,6 +31,7 @@ type ConnectionParameters struct {
// goroutines.
type Connection struct {
conn net.Conn // the underlying TCP or unix domain socket connection
writer *bufio.Writer // buffered writer to avoid sending tiny packets
buf *bytes.Buffer // work buffer to avoid constant alloc and dealloc
Pid int32 // backend pid
SecretKey int32 // key to use to send a cancel query message to the server
@ -106,7 +108,13 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) {
return nil, err
}
}
defer func() {
if c != nil && err != nil {
c.conn.Close()
}
}()
c.writer = bufio.NewWriter(c.conn)
c.buf = bytes.NewBuffer(make([]byte, 0, sharedBufferSize))
c.RuntimeParams = make(map[string]string)
c.preparedStatements = make(map[string]*preparedStatement)
@ -116,7 +124,9 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) {
if c.parameters.Database != "" {
msg.options["database"] = c.parameters.Database
}
c.txStartupMessage(msg)
if err = c.txStartupMessage(msg); err != nil {
return
}
for {
var t byte
@ -145,7 +155,7 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) {
}
func (c *Connection) Close() (err error) {
err = c.txMsg('X', c.getBuf())
err = c.txMsg('X', c.getBuf(), true)
c.die(errors.New("Closed"))
return err
}
@ -296,7 +306,7 @@ func (c *Connection) Prepare(name, sql string) (err error) {
if w.Err != nil {
return w.Err
}
err = c.txMsg('P', buf)
err = c.txMsg('P', buf, false)
if err != nil {
return
}
@ -310,13 +320,13 @@ func (c *Connection) Prepare(name, sql string) (err error) {
return w.Err
}
err = c.txMsg('D', buf)
err = c.txMsg('D', buf, false)
if err != nil {
return
}
// sync
err = c.txMsg('S', c.getBuf())
err = c.txMsg('S', c.getBuf(), true)
if err != nil {
return err
}
@ -433,7 +443,7 @@ func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err
return
}
return c.txMsg('Q', buf)
return c.txMsg('Q', buf, true)
}
func (c *Connection) sendPreparedQuery(ps *preparedStatement, arguments ...interface{}) (err error) {
@ -477,7 +487,7 @@ func (c *Connection) sendPreparedQuery(ps *preparedStatement, arguments ...inter
return w.Err
}
err = c.txMsg('B', buf)
err = c.txMsg('B', buf, false)
if err != nil {
return err
}
@ -492,13 +502,13 @@ func (c *Connection) sendPreparedQuery(ps *preparedStatement, arguments ...inter
return w.Err
}
err = c.txMsg('E', buf)
err = c.txMsg('E', buf, false)
if err != nil {
return err
}
// sync
err = c.txMsg('S', c.getBuf())
err = c.txMsg('S', c.getBuf(), true)
if err != nil {
return err
}
@ -653,11 +663,11 @@ func (c *Connection) rxAuthenticationX(r *MessageReader) (err error) {
switch code {
case 0: // AuthenticationOk
case 3: // AuthenticationCleartextPassword
c.txPasswordMessage(c.parameters.Password)
err = c.txPasswordMessage(c.parameters.Password)
case 5: // AuthenticationMD5Password
salt := r.ReadString(4)
digestedPassword := "md5" + hexMD5(hexMD5(c.parameters.Password+c.parameters.User)+salt)
c.txPasswordMessage(digestedPassword)
err = c.txPasswordMessage(digestedPassword)
default:
err = errors.New("Received unknown authentication message")
}
@ -752,28 +762,40 @@ func (c *Connection) rxNotificationResponse(r *MessageReader) (err error) {
}
func (c *Connection) txStartupMessage(msg *startupMessage) (err error) {
_, err = c.conn.Write(msg.Bytes())
_, err = c.writer.Write(msg.Bytes())
if err != nil {
return
}
err = c.writer.Flush()
return
}
func (c *Connection) txMsg(identifier byte, buf *bytes.Buffer) (err error) {
func (c *Connection) txMsg(identifier byte, buf *bytes.Buffer, flush bool) (err error) {
defer func() {
if err != nil {
c.die(err)
}
}()
err = binary.Write(c.conn, binary.BigEndian, identifier)
err = binary.Write(c.writer, binary.BigEndian, identifier)
if err != nil {
return
}
err = binary.Write(c.conn, binary.BigEndian, int32(buf.Len()+4))
err = binary.Write(c.writer, binary.BigEndian, int32(buf.Len()+4))
if err != nil {
return
}
_, err = buf.WriteTo(c.conn)
_, err = buf.WriteTo(c.writer)
if err != nil {
return
}
if flush {
err = c.writer.Flush()
}
return
}
@ -788,7 +810,7 @@ func (c *Connection) txPasswordMessage(password string) (err error) {
if err != nil {
return
}
err = c.txMsg('p', buf)
err = c.txMsg('p', buf, true)
return
}