From 3873a83a0aa679e912f437551567d53184bad847 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 26 Jul 2013 14:48:41 -0500 Subject: [PATCH] Buffer writes to the PostgreSQL socket Avoid sending a packet for each write --- connection.go | 56 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 17 deletions(-) diff --git a/connection.go b/connection.go index 251fb294..0e6a78ce 100644 --- a/connection.go +++ b/connection.go @@ -4,6 +4,7 @@ package pgx import ( + "bufio" "bytes" "crypto/md5" "encoding/binary" @@ -30,6 +31,7 @@ type ConnectionParameters struct { // goroutines. type Connection struct { conn net.Conn // the underlying TCP or unix domain socket connection + writer *bufio.Writer // buffered writer to avoid sending tiny packets buf *bytes.Buffer // work buffer to avoid constant alloc and dealloc Pid int32 // backend pid SecretKey int32 // key to use to send a cancel query message to the server @@ -106,7 +108,13 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) { return nil, err } } + defer func() { + if c != nil && err != nil { + c.conn.Close() + } + }() + c.writer = bufio.NewWriter(c.conn) c.buf = bytes.NewBuffer(make([]byte, 0, sharedBufferSize)) c.RuntimeParams = make(map[string]string) c.preparedStatements = make(map[string]*preparedStatement) @@ -116,7 +124,9 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) { if c.parameters.Database != "" { msg.options["database"] = c.parameters.Database } - c.txStartupMessage(msg) + if err = c.txStartupMessage(msg); err != nil { + return + } for { var t byte @@ -145,7 +155,7 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) { } func (c *Connection) Close() (err error) { - err = c.txMsg('X', c.getBuf()) + err = c.txMsg('X', c.getBuf(), true) c.die(errors.New("Closed")) return err } @@ -296,7 +306,7 @@ func (c *Connection) Prepare(name, sql string) (err error) { if w.Err != nil { return w.Err } - err = c.txMsg('P', buf) + err = c.txMsg('P', buf, false) if err != nil { return } @@ -310,13 +320,13 @@ func (c *Connection) Prepare(name, sql string) (err error) { return w.Err } - err = c.txMsg('D', buf) + err = c.txMsg('D', buf, false) if err != nil { return } // sync - err = c.txMsg('S', c.getBuf()) + err = c.txMsg('S', c.getBuf(), true) if err != nil { return err } @@ -433,7 +443,7 @@ func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err return } - return c.txMsg('Q', buf) + return c.txMsg('Q', buf, true) } func (c *Connection) sendPreparedQuery(ps *preparedStatement, arguments ...interface{}) (err error) { @@ -477,7 +487,7 @@ func (c *Connection) sendPreparedQuery(ps *preparedStatement, arguments ...inter return w.Err } - err = c.txMsg('B', buf) + err = c.txMsg('B', buf, false) if err != nil { return err } @@ -492,13 +502,13 @@ func (c *Connection) sendPreparedQuery(ps *preparedStatement, arguments ...inter return w.Err } - err = c.txMsg('E', buf) + err = c.txMsg('E', buf, false) if err != nil { return err } // sync - err = c.txMsg('S', c.getBuf()) + err = c.txMsg('S', c.getBuf(), true) if err != nil { return err } @@ -653,11 +663,11 @@ func (c *Connection) rxAuthenticationX(r *MessageReader) (err error) { switch code { case 0: // AuthenticationOk case 3: // AuthenticationCleartextPassword - c.txPasswordMessage(c.parameters.Password) + err = c.txPasswordMessage(c.parameters.Password) case 5: // AuthenticationMD5Password salt := r.ReadString(4) digestedPassword := "md5" + hexMD5(hexMD5(c.parameters.Password+c.parameters.User)+salt) - c.txPasswordMessage(digestedPassword) + err = c.txPasswordMessage(digestedPassword) default: err = errors.New("Received unknown authentication message") } @@ -752,28 +762,40 @@ func (c *Connection) rxNotificationResponse(r *MessageReader) (err error) { } func (c *Connection) txStartupMessage(msg *startupMessage) (err error) { - _, err = c.conn.Write(msg.Bytes()) + _, err = c.writer.Write(msg.Bytes()) + if err != nil { + return + } + err = c.writer.Flush() return } -func (c *Connection) txMsg(identifier byte, buf *bytes.Buffer) (err error) { +func (c *Connection) txMsg(identifier byte, buf *bytes.Buffer, flush bool) (err error) { defer func() { if err != nil { c.die(err) } }() - err = binary.Write(c.conn, binary.BigEndian, identifier) + err = binary.Write(c.writer, binary.BigEndian, identifier) if err != nil { return } - err = binary.Write(c.conn, binary.BigEndian, int32(buf.Len()+4)) + err = binary.Write(c.writer, binary.BigEndian, int32(buf.Len()+4)) if err != nil { return } - _, err = buf.WriteTo(c.conn) + _, err = buf.WriteTo(c.writer) + if err != nil { + return + } + + if flush { + err = c.writer.Flush() + } + return } @@ -788,7 +810,7 @@ func (c *Connection) txPasswordMessage(password string) (err error) { if err != nil { return } - err = c.txMsg('p', buf) + err = c.txMsg('p', buf, true) return }