From 458dd24a9fd8067f86c5ab765b21a52374ce49b3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 2 May 2017 21:26:45 -0500 Subject: [PATCH] Remove unneeded WriteBuf --- conn.go | 118 +++++++++++++++++++++++++++++++------------------ copy_from.go | 57 ++++++++++++++---------- fastpath.go | 29 +++++++----- messages.go | 87 ------------------------------------ replication.go | 23 ++++++---- values.go | 70 ++++++++++++++--------------- 6 files changed, 173 insertions(+), 211 deletions(-) diff --git a/conn.go b/conn.go index bca9f6d8..a1781be2 100644 --- a/conn.go +++ b/conn.go @@ -20,6 +20,7 @@ import ( "sync/atomic" "time" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgtype" ) @@ -86,8 +87,7 @@ func (cc *ConnConfig) networkAddress() (network, address string) { type Conn struct { conn net.Conn // the underlying TCP or unix domain socket connection lastActivityTime time.Time // the last time the connection was used - wbuf [1024]byte - writeBuf WriteBuf + wbuf []byte pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server RuntimeParams map[string]string // parameters that have been reported by the server @@ -279,6 +279,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.cancelQueryCompleted = make(chan struct{}, 1) c.doneChan = make(chan struct{}) c.closedChan = make(chan error) + c.wbuf = make([]byte, 0, 1024) if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { @@ -707,32 +708,42 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } // parse - wbuf := newWriteBuf(c, 'P') - wbuf.WriteCString(name) - wbuf.WriteCString(sql) + buf := c.wbuf + buf = append(buf, 'P') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, name...) + buf = append(buf, 0) + buf = append(buf, sql...) + buf = append(buf, 0) if opts != nil { if len(opts.ParameterOids) > 65535 { return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)) } - wbuf.WriteInt16(int16(len(opts.ParameterOids))) + buf = pgio.AppendInt16(buf, int16(len(opts.ParameterOids))) for _, oid := range opts.ParameterOids { - wbuf.WriteInt32(int32(oid)) + buf = pgio.AppendInt32(buf, int32(oid)) } } else { - wbuf.WriteInt16(0) + buf = pgio.AppendInt16(buf, 0) } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) // describe - wbuf.startMsg('D') - wbuf.WriteByte('S') - wbuf.WriteCString(name) + buf = append(buf, 'D') + sp = len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, 'S') + buf = append(buf, name...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) // sync - wbuf.startMsg('S') - wbuf.closeMsg() + buf = append(buf, 'S') + buf = pgio.AppendInt32(buf, 4) - _, err = c.conn.Write(wbuf.buf) + _, err = c.conn.Write(buf) if err != nil { c.die(err) return nil, err @@ -813,15 +824,20 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { delete(c.preparedStatements, name) // close - wbuf := newWriteBuf(c, 'C') - wbuf.WriteByte('S') - wbuf.WriteCString(name) + buf := c.wbuf + buf = append(buf, 'C') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, 'S') + buf = append(buf, name...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) // flush - wbuf.startMsg('H') - wbuf.closeMsg() + buf = append(buf, 'H') + buf = pgio.AppendInt32(buf, 4) - _, err = c.conn.Write(wbuf.buf) + _, err = c.conn.Write(buf) if err != nil { c.die(err) return err @@ -943,11 +959,15 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { } if len(args) == 0 { - wbuf := newWriteBuf(c, 'Q') - wbuf.WriteCString(sql) - wbuf.closeMsg() + buf := c.wbuf + buf = append(buf, 'Q') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, sql...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err := c.conn.Write(wbuf.buf) + _, err := c.conn.Write(buf) if err != nil { c.die(err) return err @@ -975,37 +995,45 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} } // bind - wbuf := newWriteBuf(c, 'B') - wbuf.WriteByte(0) - wbuf.WriteCString(ps.Name) + buf := c.wbuf + buf = append(buf, 'B') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, 0) + buf = append(buf, ps.Name...) + buf = append(buf, 0) - wbuf.WriteInt16(int16(len(ps.ParameterOids))) + buf = pgio.AppendInt16(buf, int16(len(ps.ParameterOids))) for i, oid := range ps.ParameterOids { - wbuf.WriteInt16(chooseParameterFormatCode(c.ConnInfo, oid, arguments[i])) + buf = pgio.AppendInt16(buf, chooseParameterFormatCode(c.ConnInfo, oid, arguments[i])) } - wbuf.WriteInt16(int16(len(arguments))) + buf = pgio.AppendInt16(buf, int16(len(arguments))) for i, oid := range ps.ParameterOids { - if err := encodePreparedStatementArgument(wbuf, oid, arguments[i]); err != nil { + var err error + buf, err = encodePreparedStatementArgument(c.ConnInfo, buf, oid, arguments[i]) + if err != nil { return err } } - wbuf.WriteInt16(int16(len(ps.FieldDescriptions))) + buf = pgio.AppendInt16(buf, int16(len(ps.FieldDescriptions))) for _, fd := range ps.FieldDescriptions { - wbuf.WriteInt16(fd.FormatCode) + buf = pgio.AppendInt16(buf, fd.FormatCode) } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) // execute - wbuf.startMsg('E') - wbuf.WriteByte(0) - wbuf.WriteInt32(0) + buf = append(buf, 'E') + buf = pgio.AppendInt32(buf, 9) + buf = append(buf, 0) + buf = pgio.AppendInt32(buf, 0) // sync - wbuf.startMsg('S') - wbuf.closeMsg() + buf = append(buf, 'S') + buf = pgio.AppendInt32(buf, 4) - _, err = c.conn.Write(wbuf.buf) + _, err = c.conn.Write(buf) if err != nil { c.die(err) } @@ -1180,11 +1208,15 @@ func (c *Conn) txStartupMessage(msg *startupMessage) error { } func (c *Conn) txPasswordMessage(password string) (err error) { - wbuf := newWriteBuf(c, 'p') - wbuf.WriteCString(password) - wbuf.closeMsg() + buf := c.wbuf + buf = append(buf, 'p') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, password...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err = c.conn.Write(wbuf.buf) + _, err = c.conn.Write(buf) return err } diff --git a/copy_from.go b/copy_from.go index 7d8dead1..f3c77109 100644 --- a/copy_from.go +++ b/copy_from.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" ) @@ -89,14 +90,14 @@ func (ct *copyFrom) waitForReaderDone() error { func (ct *copyFrom) run() (int, error) { quotedTableName := ct.tableName.Sanitize() - buf := &bytes.Buffer{} + cbuf := &bytes.Buffer{} for i, cn := range ct.columnNames { if i != 0 { - buf.WriteString(", ") + cbuf.WriteString(", ") } - buf.WriteString(quoteIdentifier(cn)) + cbuf.WriteString(quoteIdentifier(cn)) } - quotedColumnNames := buf.String() + quotedColumnNames := cbuf.String() ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) if err != nil { @@ -116,11 +117,14 @@ func (ct *copyFrom) run() (int, error) { go ct.readUntilReadyForQuery() defer ct.waitForReaderDone() - wbuf := newWriteBuf(ct.conn, copyData) + buf := ct.conn.wbuf + buf = append(buf, copyData) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000")) - wbuf.WriteInt32(0) - wbuf.WriteInt32(0) + buf = append(buf, "PGCOPY\n\377\r\n\000"...) + buf = pgio.AppendInt32(buf, 0) + buf = pgio.AppendInt32(buf, 0) var sentCount int @@ -131,18 +135,16 @@ func (ct *copyFrom) run() (int, error) { default: } - if len(wbuf.buf) > 65536 { - wbuf.closeMsg() - _, err = ct.conn.conn.Write(wbuf.buf) + if len(buf) > 65536 { + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + _, err = ct.conn.conn.Write(buf) if err != nil { ct.conn.die(err) return 0, err } // Directly manipulate wbuf to reset to reuse the same buffer - wbuf.buf = wbuf.buf[0:5] - wbuf.buf[0] = copyData - wbuf.sizeIdx = 1 + buf = buf[0:5] } sentCount++ @@ -157,9 +159,9 @@ func (ct *copyFrom) run() (int, error) { return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) } - wbuf.WriteInt16(int16(len(ct.columnNames))) + buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) for i, val := range values { - err = encodePreparedStatementArgument(wbuf, ps.FieldDescriptions[i].DataType, val) + buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val) if err != nil { ct.cancelCopyIn() return 0, err @@ -173,11 +175,13 @@ func (ct *copyFrom) run() (int, error) { return 0, ct.rowSrc.Err() } - wbuf.WriteInt16(-1) // terminate the copy stream + buf = pgio.AppendInt16(buf, -1) // terminate the copy stream + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - wbuf.startMsg(copyDone) - wbuf.closeMsg() - _, err = ct.conn.conn.Write(wbuf.buf) + buf = append(buf, copyDone) + buf = pgio.AppendInt32(buf, 4) + + _, err = ct.conn.conn.Write(buf) if err != nil { ct.conn.die(err) return 0, err @@ -210,10 +214,15 @@ func (c *Conn) readUntilCopyInResponse() error { } func (ct *copyFrom) cancelCopyIn() error { - wbuf := newWriteBuf(ct.conn, copyFail) - wbuf.WriteCString("client error: abort") - wbuf.closeMsg() - _, err := ct.conn.conn.Write(wbuf.buf) + buf := ct.conn.wbuf + buf = append(buf, copyFail) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, "client error: abort"...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + _, err := ct.conn.conn.Write(buf) if err != nil { ct.conn.die(err) return err diff --git a/fastpath.go b/fastpath.go index 75681c9c..776be177 100644 --- a/fastpath.go +++ b/fastpath.go @@ -3,6 +3,7 @@ package pgx import ( "encoding/binary" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgtype" ) @@ -55,19 +56,23 @@ func (f *fastpath) Call(oid pgtype.Oid, args []fpArg) (res []byte, err error) { return nil, err } - wbuf := newWriteBuf(f.cn, 'F') // function call - wbuf.WriteInt32(int32(oid)) // function object id - wbuf.WriteInt16(1) // # of argument format codes - wbuf.WriteInt16(1) // format code: binary - wbuf.WriteInt16(int16(len(args))) // # of arguments - for _, arg := range args { - wbuf.WriteInt32(int32(len(arg))) // length of argument - wbuf.WriteBytes(arg) // argument value - } - wbuf.WriteInt16(1) // response format code (binary) - wbuf.closeMsg() + buf := f.cn.wbuf + buf = append(buf, 'F') // function call + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - if _, err := f.cn.conn.Write(wbuf.buf); err != nil { + buf = pgio.AppendInt32(buf, int32(oid)) // function object id + buf = pgio.AppendInt16(buf, 1) // # of argument format codes + buf = pgio.AppendInt16(buf, 1) // format code: binary + buf = pgio.AppendInt16(buf, int16(len(args))) // # of arguments + for _, arg := range args { + buf = pgio.AppendInt32(buf, int32(len(arg))) // length of argument + buf = append(buf, arg...) // argument value + } + buf = pgio.AppendInt16(buf, 1) // response format code (binary) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + if _, err := f.cn.conn.Write(buf); err != nil { return nil, err } diff --git a/messages.go b/messages.go index 0f17a6d2..8e406602 100644 --- a/messages.go +++ b/messages.go @@ -92,90 +92,3 @@ type PgError struct { func (pe PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } - -func newWriteBuf(c *Conn, t byte) *WriteBuf { - buf := append(c.wbuf[0:0], t, 0, 0, 0, 0) - c.writeBuf = WriteBuf{buf: buf, sizeIdx: 1, conn: c} - return &c.writeBuf -} - -// WriteBuf is used build messages to send to the PostgreSQL server. It is used -// by the Encoder interface when implementing custom encoders. -type WriteBuf struct { - buf []byte - convBuf [8]byte - sizeIdx int - conn *Conn -} - -func (wb *WriteBuf) startMsg(t byte) { - wb.closeMsg() - wb.buf = append(wb.buf, t, 0, 0, 0, 0) - wb.sizeIdx = len(wb.buf) - 4 -} - -func (wb *WriteBuf) closeMsg() { - binary.BigEndian.PutUint32(wb.buf[wb.sizeIdx:wb.sizeIdx+4], uint32(len(wb.buf)-wb.sizeIdx)) -} - -func (wb *WriteBuf) reserveSize() int { - sizePosition := len(wb.buf) - wb.buf = append(wb.buf, 0, 0, 0, 0) - return sizePosition -} - -func (wb *WriteBuf) setComputedSize(sizePosition int) { - binary.BigEndian.PutUint32(wb.buf[sizePosition:], uint32(len(wb.buf)-sizePosition-4)) -} - -func (wb *WriteBuf) setSize(sizePosition int, size int32) { - binary.BigEndian.PutUint32(wb.buf[sizePosition:], uint32(size)) -} - -func (wb *WriteBuf) WriteByte(b byte) { - wb.buf = append(wb.buf, b) -} - -func (wb *WriteBuf) WriteCString(s string) { - wb.buf = append(wb.buf, []byte(s)...) - wb.buf = append(wb.buf, 0) -} - -func (wb *WriteBuf) WriteInt16(n int16) { - wb.WriteUint16(uint16(n)) -} - -func (wb *WriteBuf) WriteUint16(n uint16) (int, error) { - binary.BigEndian.PutUint16(wb.convBuf[:2], n) - wb.buf = append(wb.buf, wb.convBuf[:2]...) - return 2, nil -} - -func (wb *WriteBuf) WriteInt32(n int32) { - wb.WriteUint32(uint32(n)) -} - -func (wb *WriteBuf) WriteUint32(n uint32) (int, error) { - binary.BigEndian.PutUint32(wb.convBuf[:4], n) - wb.buf = append(wb.buf, wb.convBuf[:4]...) - return 4, nil -} - -func (wb *WriteBuf) WriteInt64(n int64) { - wb.WriteUint64(uint64(n)) -} - -func (wb *WriteBuf) WriteUint64(n uint64) (int, error) { - binary.BigEndian.PutUint64(wb.convBuf[:8], n) - wb.buf = append(wb.buf, wb.convBuf[:8]...) - return 8, nil -} - -func (wb *WriteBuf) WriteBytes(b []byte) { - wb.buf = append(wb.buf, b...) -} - -func (wb *WriteBuf) Write(b []byte) (int, error) { - wb.buf = append(wb.buf, b...) - return len(b), nil -} diff --git a/replication.go b/replication.go index 594944e0..1260d3e7 100644 --- a/replication.go +++ b/replication.go @@ -7,6 +7,7 @@ import ( "fmt" "time" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" ) @@ -175,17 +176,21 @@ type ReplicationConn struct { // message to the server, as well as carries the WAL position of the // client, which then updates the server's replication slot position. func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { - writeBuf := newWriteBuf(rc.c, copyData) - writeBuf.WriteByte(standbyStatusUpdate) - writeBuf.WriteInt64(int64(k.WalWritePosition)) - writeBuf.WriteInt64(int64(k.WalFlushPosition)) - writeBuf.WriteInt64(int64(k.WalApplyPosition)) - writeBuf.WriteInt64(int64(k.ClientTime)) - writeBuf.WriteByte(k.ReplyRequested) + buf := rc.c.wbuf + buf = append(buf, copyData) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - writeBuf.closeMsg() + buf = append(buf, standbyStatusUpdate) + buf = pgio.AppendInt64(buf, int64(k.WalWritePosition)) + buf = pgio.AppendInt64(buf, int64(k.WalFlushPosition)) + buf = pgio.AppendInt64(buf, int64(k.WalApplyPosition)) + buf = pgio.AppendInt64(buf, int64(k.ClientTime)) + buf = append(buf, k.ReplyRequested) - _, err = rc.c.conn.Write(writeBuf.buf) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + _, err = rc.c.conn.Write(buf) if err != nil { rc.c.die(err) } diff --git a/values.go b/values.go index b1928b86..ca5db50b 100644 --- a/values.go +++ b/values.go @@ -97,84 +97,82 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) } -func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { +func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype.Oid, arg interface{}) ([]byte, error) { if arg == nil { - wbuf.WriteInt32(-1) - return nil + return pgio.AppendInt32(buf, -1), nil } switch arg := arg.(type) { case pgtype.BinaryEncoder: - sp := len(wbuf.buf) - wbuf.buf = pgio.AppendInt32(wbuf.buf, -1) - argBuf, err := arg.EncodeBinary(wbuf.conn.ConnInfo, wbuf.buf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := arg.EncodeBinary(ci, buf) if err != nil { - return err + return nil, err } if argBuf != nil { - wbuf.buf = argBuf - pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4)) + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return nil + return buf, nil case pgtype.TextEncoder: - sp := len(wbuf.buf) - wbuf.buf = pgio.AppendInt32(wbuf.buf, -1) - argBuf, err := arg.EncodeText(wbuf.conn.ConnInfo, wbuf.buf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := arg.EncodeText(ci, buf) if err != nil { - return err + return nil, err } if argBuf != nil { - wbuf.buf = argBuf - pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4)) + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return nil + return buf, nil case driver.Valuer: v, err := arg.Value() if err != nil { - return err + return nil, err } - return encodePreparedStatementArgument(wbuf, oid, v) + return encodePreparedStatementArgument(ci, buf, oid, v) case string: - wbuf.WriteInt32(int32(len(arg))) - wbuf.WriteBytes([]byte(arg)) - return nil + buf = pgio.AppendInt32(buf, int32(len(arg))) + buf = append(buf, arg...) + return buf, nil } refVal := reflect.ValueOf(arg) if refVal.Kind() == reflect.Ptr { if refVal.IsNil() { - wbuf.WriteInt32(-1) - return nil + return pgio.AppendInt32(buf, -1), nil } arg = refVal.Elem().Interface() - return encodePreparedStatementArgument(wbuf, oid, arg) + return encodePreparedStatementArgument(ci, buf, oid, arg) } - if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok { + if dt, ok := ci.DataTypeForOid(oid); ok { value := dt.Value err := value.Set(arg) if err != nil { - return err + return nil, err } - sp := len(wbuf.buf) - wbuf.buf = pgio.AppendInt32(wbuf.buf, -1) - argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(wbuf.conn.ConnInfo, wbuf.buf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) if err != nil { - return err + return nil, err } if argBuf != nil { - wbuf.buf = argBuf - pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4)) + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return nil + return buf, nil } if strippedArg, ok := stripNamedType(&refVal); ok { - return encodePreparedStatementArgument(wbuf, oid, strippedArg) + return encodePreparedStatementArgument(ci, buf, oid, strippedArg) } - return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } // chooseParameterFormatCode determines the correct format code for an