Remove unneeded WriteBuf

batch-wip
Jack Christensen 2017-05-02 21:26:45 -05:00
parent 6e64a0c867
commit 458dd24a9f
6 changed files with 173 additions and 211 deletions

118
conn.go
View File

@ -20,6 +20,7 @@ import (
"sync/atomic"
"time"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
@ -86,8 +87,7 @@ func (cc *ConnConfig) networkAddress() (network, address string) {
type Conn struct {
conn net.Conn // the underlying TCP or unix domain socket connection
lastActivityTime time.Time // the last time the connection was used
wbuf [1024]byte
writeBuf WriteBuf
wbuf []byte
pid uint32 // backend pid
secretKey uint32 // key to use to send a cancel query message to the server
RuntimeParams map[string]string // parameters that have been reported by the server
@ -279,6 +279,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.cancelQueryCompleted = make(chan struct{}, 1)
c.doneChan = make(chan struct{})
c.closedChan = make(chan error)
c.wbuf = make([]byte, 0, 1024)
if tlsConfig != nil {
if c.shouldLog(LogLevelDebug) {
@ -707,32 +708,42 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
}
// parse
wbuf := newWriteBuf(c, 'P')
wbuf.WriteCString(name)
wbuf.WriteCString(sql)
buf := c.wbuf
buf = append(buf, 'P')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, name...)
buf = append(buf, 0)
buf = append(buf, sql...)
buf = append(buf, 0)
if opts != nil {
if len(opts.ParameterOids) > 65535 {
return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids))
}
wbuf.WriteInt16(int16(len(opts.ParameterOids)))
buf = pgio.AppendInt16(buf, int16(len(opts.ParameterOids)))
for _, oid := range opts.ParameterOids {
wbuf.WriteInt32(int32(oid))
buf = pgio.AppendInt32(buf, int32(oid))
}
} else {
wbuf.WriteInt16(0)
buf = pgio.AppendInt16(buf, 0)
}
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
// describe
wbuf.startMsg('D')
wbuf.WriteByte('S')
wbuf.WriteCString(name)
buf = append(buf, 'D')
sp = len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, 'S')
buf = append(buf, name...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
// sync
wbuf.startMsg('S')
wbuf.closeMsg()
buf = append(buf, 'S')
buf = pgio.AppendInt32(buf, 4)
_, err = c.conn.Write(wbuf.buf)
_, err = c.conn.Write(buf)
if err != nil {
c.die(err)
return nil, err
@ -813,15 +824,20 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
delete(c.preparedStatements, name)
// close
wbuf := newWriteBuf(c, 'C')
wbuf.WriteByte('S')
wbuf.WriteCString(name)
buf := c.wbuf
buf = append(buf, 'C')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, 'S')
buf = append(buf, name...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
// flush
wbuf.startMsg('H')
wbuf.closeMsg()
buf = append(buf, 'H')
buf = pgio.AppendInt32(buf, 4)
_, err = c.conn.Write(wbuf.buf)
_, err = c.conn.Write(buf)
if err != nil {
c.die(err)
return err
@ -943,11 +959,15 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
}
if len(args) == 0 {
wbuf := newWriteBuf(c, 'Q')
wbuf.WriteCString(sql)
wbuf.closeMsg()
buf := c.wbuf
buf = append(buf, 'Q')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, sql...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err := c.conn.Write(wbuf.buf)
_, err := c.conn.Write(buf)
if err != nil {
c.die(err)
return err
@ -975,37 +995,45 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
}
// bind
wbuf := newWriteBuf(c, 'B')
wbuf.WriteByte(0)
wbuf.WriteCString(ps.Name)
buf := c.wbuf
buf = append(buf, 'B')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, 0)
buf = append(buf, ps.Name...)
buf = append(buf, 0)
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
buf = pgio.AppendInt16(buf, int16(len(ps.ParameterOids)))
for i, oid := range ps.ParameterOids {
wbuf.WriteInt16(chooseParameterFormatCode(c.ConnInfo, oid, arguments[i]))
buf = pgio.AppendInt16(buf, chooseParameterFormatCode(c.ConnInfo, oid, arguments[i]))
}
wbuf.WriteInt16(int16(len(arguments)))
buf = pgio.AppendInt16(buf, int16(len(arguments)))
for i, oid := range ps.ParameterOids {
if err := encodePreparedStatementArgument(wbuf, oid, arguments[i]); err != nil {
var err error
buf, err = encodePreparedStatementArgument(c.ConnInfo, buf, oid, arguments[i])
if err != nil {
return err
}
}
wbuf.WriteInt16(int16(len(ps.FieldDescriptions)))
buf = pgio.AppendInt16(buf, int16(len(ps.FieldDescriptions)))
for _, fd := range ps.FieldDescriptions {
wbuf.WriteInt16(fd.FormatCode)
buf = pgio.AppendInt16(buf, fd.FormatCode)
}
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
// execute
wbuf.startMsg('E')
wbuf.WriteByte(0)
wbuf.WriteInt32(0)
buf = append(buf, 'E')
buf = pgio.AppendInt32(buf, 9)
buf = append(buf, 0)
buf = pgio.AppendInt32(buf, 0)
// sync
wbuf.startMsg('S')
wbuf.closeMsg()
buf = append(buf, 'S')
buf = pgio.AppendInt32(buf, 4)
_, err = c.conn.Write(wbuf.buf)
_, err = c.conn.Write(buf)
if err != nil {
c.die(err)
}
@ -1180,11 +1208,15 @@ func (c *Conn) txStartupMessage(msg *startupMessage) error {
}
func (c *Conn) txPasswordMessage(password string) (err error) {
wbuf := newWriteBuf(c, 'p')
wbuf.WriteCString(password)
wbuf.closeMsg()
buf := c.wbuf
buf = append(buf, 'p')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, password...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err = c.conn.Write(wbuf.buf)
_, err = c.conn.Write(buf)
return err
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
)
@ -89,14 +90,14 @@ func (ct *copyFrom) waitForReaderDone() error {
func (ct *copyFrom) run() (int, error) {
quotedTableName := ct.tableName.Sanitize()
buf := &bytes.Buffer{}
cbuf := &bytes.Buffer{}
for i, cn := range ct.columnNames {
if i != 0 {
buf.WriteString(", ")
cbuf.WriteString(", ")
}
buf.WriteString(quoteIdentifier(cn))
cbuf.WriteString(quoteIdentifier(cn))
}
quotedColumnNames := buf.String()
quotedColumnNames := cbuf.String()
ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
if err != nil {
@ -116,11 +117,14 @@ func (ct *copyFrom) run() (int, error) {
go ct.readUntilReadyForQuery()
defer ct.waitForReaderDone()
wbuf := newWriteBuf(ct.conn, copyData)
buf := ct.conn.wbuf
buf = append(buf, copyData)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000"))
wbuf.WriteInt32(0)
wbuf.WriteInt32(0)
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
buf = pgio.AppendInt32(buf, 0)
buf = pgio.AppendInt32(buf, 0)
var sentCount int
@ -131,18 +135,16 @@ func (ct *copyFrom) run() (int, error) {
default:
}
if len(wbuf.buf) > 65536 {
wbuf.closeMsg()
_, err = ct.conn.conn.Write(wbuf.buf)
if len(buf) > 65536 {
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err = ct.conn.conn.Write(buf)
if err != nil {
ct.conn.die(err)
return 0, err
}
// Directly manipulate wbuf to reset to reuse the same buffer
wbuf.buf = wbuf.buf[0:5]
wbuf.buf[0] = copyData
wbuf.sizeIdx = 1
buf = buf[0:5]
}
sentCount++
@ -157,9 +159,9 @@ func (ct *copyFrom) run() (int, error) {
return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
}
wbuf.WriteInt16(int16(len(ct.columnNames)))
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
for i, val := range values {
err = encodePreparedStatementArgument(wbuf, ps.FieldDescriptions[i].DataType, val)
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
if err != nil {
ct.cancelCopyIn()
return 0, err
@ -173,11 +175,13 @@ func (ct *copyFrom) run() (int, error) {
return 0, ct.rowSrc.Err()
}
wbuf.WriteInt16(-1) // terminate the copy stream
buf = pgio.AppendInt16(buf, -1) // terminate the copy stream
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
wbuf.startMsg(copyDone)
wbuf.closeMsg()
_, err = ct.conn.conn.Write(wbuf.buf)
buf = append(buf, copyDone)
buf = pgio.AppendInt32(buf, 4)
_, err = ct.conn.conn.Write(buf)
if err != nil {
ct.conn.die(err)
return 0, err
@ -210,10 +214,15 @@ func (c *Conn) readUntilCopyInResponse() error {
}
func (ct *copyFrom) cancelCopyIn() error {
wbuf := newWriteBuf(ct.conn, copyFail)
wbuf.WriteCString("client error: abort")
wbuf.closeMsg()
_, err := ct.conn.conn.Write(wbuf.buf)
buf := ct.conn.wbuf
buf = append(buf, copyFail)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, "client error: abort"...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err := ct.conn.conn.Write(buf)
if err != nil {
ct.conn.die(err)
return err

View File

@ -3,6 +3,7 @@ package pgx
import (
"encoding/binary"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
@ -55,19 +56,23 @@ func (f *fastpath) Call(oid pgtype.Oid, args []fpArg) (res []byte, err error) {
return nil, err
}
wbuf := newWriteBuf(f.cn, 'F') // function call
wbuf.WriteInt32(int32(oid)) // function object id
wbuf.WriteInt16(1) // # of argument format codes
wbuf.WriteInt16(1) // format code: binary
wbuf.WriteInt16(int16(len(args))) // # of arguments
for _, arg := range args {
wbuf.WriteInt32(int32(len(arg))) // length of argument
wbuf.WriteBytes(arg) // argument value
}
wbuf.WriteInt16(1) // response format code (binary)
wbuf.closeMsg()
buf := f.cn.wbuf
buf = append(buf, 'F') // function call
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
if _, err := f.cn.conn.Write(wbuf.buf); err != nil {
buf = pgio.AppendInt32(buf, int32(oid)) // function object id
buf = pgio.AppendInt16(buf, 1) // # of argument format codes
buf = pgio.AppendInt16(buf, 1) // format code: binary
buf = pgio.AppendInt16(buf, int16(len(args))) // # of arguments
for _, arg := range args {
buf = pgio.AppendInt32(buf, int32(len(arg))) // length of argument
buf = append(buf, arg...) // argument value
}
buf = pgio.AppendInt16(buf, 1) // response format code (binary)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
if _, err := f.cn.conn.Write(buf); err != nil {
return nil, err
}

View File

@ -92,90 +92,3 @@ type PgError struct {
func (pe PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
func newWriteBuf(c *Conn, t byte) *WriteBuf {
buf := append(c.wbuf[0:0], t, 0, 0, 0, 0)
c.writeBuf = WriteBuf{buf: buf, sizeIdx: 1, conn: c}
return &c.writeBuf
}
// WriteBuf is used build messages to send to the PostgreSQL server. It is used
// by the Encoder interface when implementing custom encoders.
type WriteBuf struct {
buf []byte
convBuf [8]byte
sizeIdx int
conn *Conn
}
func (wb *WriteBuf) startMsg(t byte) {
wb.closeMsg()
wb.buf = append(wb.buf, t, 0, 0, 0, 0)
wb.sizeIdx = len(wb.buf) - 4
}
func (wb *WriteBuf) closeMsg() {
binary.BigEndian.PutUint32(wb.buf[wb.sizeIdx:wb.sizeIdx+4], uint32(len(wb.buf)-wb.sizeIdx))
}
func (wb *WriteBuf) reserveSize() int {
sizePosition := len(wb.buf)
wb.buf = append(wb.buf, 0, 0, 0, 0)
return sizePosition
}
func (wb *WriteBuf) setComputedSize(sizePosition int) {
binary.BigEndian.PutUint32(wb.buf[sizePosition:], uint32(len(wb.buf)-sizePosition-4))
}
func (wb *WriteBuf) setSize(sizePosition int, size int32) {
binary.BigEndian.PutUint32(wb.buf[sizePosition:], uint32(size))
}
func (wb *WriteBuf) WriteByte(b byte) {
wb.buf = append(wb.buf, b)
}
func (wb *WriteBuf) WriteCString(s string) {
wb.buf = append(wb.buf, []byte(s)...)
wb.buf = append(wb.buf, 0)
}
func (wb *WriteBuf) WriteInt16(n int16) {
wb.WriteUint16(uint16(n))
}
func (wb *WriteBuf) WriteUint16(n uint16) (int, error) {
binary.BigEndian.PutUint16(wb.convBuf[:2], n)
wb.buf = append(wb.buf, wb.convBuf[:2]...)
return 2, nil
}
func (wb *WriteBuf) WriteInt32(n int32) {
wb.WriteUint32(uint32(n))
}
func (wb *WriteBuf) WriteUint32(n uint32) (int, error) {
binary.BigEndian.PutUint32(wb.convBuf[:4], n)
wb.buf = append(wb.buf, wb.convBuf[:4]...)
return 4, nil
}
func (wb *WriteBuf) WriteInt64(n int64) {
wb.WriteUint64(uint64(n))
}
func (wb *WriteBuf) WriteUint64(n uint64) (int, error) {
binary.BigEndian.PutUint64(wb.convBuf[:8], n)
wb.buf = append(wb.buf, wb.convBuf[:8]...)
return 8, nil
}
func (wb *WriteBuf) WriteBytes(b []byte) {
wb.buf = append(wb.buf, b...)
}
func (wb *WriteBuf) Write(b []byte) (int, error) {
wb.buf = append(wb.buf, b...)
return len(b), nil
}

View File

@ -7,6 +7,7 @@ import (
"fmt"
"time"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
)
@ -175,17 +176,21 @@ type ReplicationConn struct {
// message to the server, as well as carries the WAL position of the
// client, which then updates the server's replication slot position.
func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) {
writeBuf := newWriteBuf(rc.c, copyData)
writeBuf.WriteByte(standbyStatusUpdate)
writeBuf.WriteInt64(int64(k.WalWritePosition))
writeBuf.WriteInt64(int64(k.WalFlushPosition))
writeBuf.WriteInt64(int64(k.WalApplyPosition))
writeBuf.WriteInt64(int64(k.ClientTime))
writeBuf.WriteByte(k.ReplyRequested)
buf := rc.c.wbuf
buf = append(buf, copyData)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
writeBuf.closeMsg()
buf = append(buf, standbyStatusUpdate)
buf = pgio.AppendInt64(buf, int64(k.WalWritePosition))
buf = pgio.AppendInt64(buf, int64(k.WalFlushPosition))
buf = pgio.AppendInt64(buf, int64(k.WalApplyPosition))
buf = pgio.AppendInt64(buf, int64(k.ClientTime))
buf = append(buf, k.ReplyRequested)
_, err = rc.c.conn.Write(writeBuf.buf)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err = rc.c.conn.Write(buf)
if err != nil {
rc.c.die(err)
}

View File

@ -97,84 +97,82 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e
return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg))
}
func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype.Oid, arg interface{}) ([]byte, error) {
if arg == nil {
wbuf.WriteInt32(-1)
return nil
return pgio.AppendInt32(buf, -1), nil
}
switch arg := arg.(type) {
case pgtype.BinaryEncoder:
sp := len(wbuf.buf)
wbuf.buf = pgio.AppendInt32(wbuf.buf, -1)
argBuf, err := arg.EncodeBinary(wbuf.conn.ConnInfo, wbuf.buf)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
argBuf, err := arg.EncodeBinary(ci, buf)
if err != nil {
return err
return nil, err
}
if argBuf != nil {
wbuf.buf = argBuf
pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4))
buf = argBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
return nil
return buf, nil
case pgtype.TextEncoder:
sp := len(wbuf.buf)
wbuf.buf = pgio.AppendInt32(wbuf.buf, -1)
argBuf, err := arg.EncodeText(wbuf.conn.ConnInfo, wbuf.buf)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
argBuf, err := arg.EncodeText(ci, buf)
if err != nil {
return err
return nil, err
}
if argBuf != nil {
wbuf.buf = argBuf
pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4))
buf = argBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
return nil
return buf, nil
case driver.Valuer:
v, err := arg.Value()
if err != nil {
return err
return nil, err
}
return encodePreparedStatementArgument(wbuf, oid, v)
return encodePreparedStatementArgument(ci, buf, oid, v)
case string:
wbuf.WriteInt32(int32(len(arg)))
wbuf.WriteBytes([]byte(arg))
return nil
buf = pgio.AppendInt32(buf, int32(len(arg)))
buf = append(buf, arg...)
return buf, nil
}
refVal := reflect.ValueOf(arg)
if refVal.Kind() == reflect.Ptr {
if refVal.IsNil() {
wbuf.WriteInt32(-1)
return nil
return pgio.AppendInt32(buf, -1), nil
}
arg = refVal.Elem().Interface()
return encodePreparedStatementArgument(wbuf, oid, arg)
return encodePreparedStatementArgument(ci, buf, oid, arg)
}
if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok {
if dt, ok := ci.DataTypeForOid(oid); ok {
value := dt.Value
err := value.Set(arg)
if err != nil {
return err
return nil, err
}
sp := len(wbuf.buf)
wbuf.buf = pgio.AppendInt32(wbuf.buf, -1)
argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(wbuf.conn.ConnInfo, wbuf.buf)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf)
if err != nil {
return err
return nil, err
}
if argBuf != nil {
wbuf.buf = argBuf
pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4))
buf = argBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
return nil
return buf, nil
}
if strippedArg, ok := stripNamedType(&refVal); ok {
return encodePreparedStatementArgument(wbuf, oid, strippedArg)
return encodePreparedStatementArgument(ci, buf, oid, strippedArg)
}
return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
}
// chooseParameterFormatCode determines the correct format code for an