mirror of https://github.com/jackc/pgx.git
Remove unneeded WriteBuf
parent
6e64a0c867
commit
458dd24a9f
118
conn.go
118
conn.go
|
@ -20,6 +20,7 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
@ -86,8 +87,7 @@ func (cc *ConnConfig) networkAddress() (network, address string) {
|
|||
type Conn struct {
|
||||
conn net.Conn // the underlying TCP or unix domain socket connection
|
||||
lastActivityTime time.Time // the last time the connection was used
|
||||
wbuf [1024]byte
|
||||
writeBuf WriteBuf
|
||||
wbuf []byte
|
||||
pid uint32 // backend pid
|
||||
secretKey uint32 // key to use to send a cancel query message to the server
|
||||
RuntimeParams map[string]string // parameters that have been reported by the server
|
||||
|
@ -279,6 +279,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
c.cancelQueryCompleted = make(chan struct{}, 1)
|
||||
c.doneChan = make(chan struct{})
|
||||
c.closedChan = make(chan error)
|
||||
c.wbuf = make([]byte, 0, 1024)
|
||||
|
||||
if tlsConfig != nil {
|
||||
if c.shouldLog(LogLevelDebug) {
|
||||
|
@ -707,32 +708,42 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||
}
|
||||
|
||||
// parse
|
||||
wbuf := newWriteBuf(c, 'P')
|
||||
wbuf.WriteCString(name)
|
||||
wbuf.WriteCString(sql)
|
||||
buf := c.wbuf
|
||||
buf = append(buf, 'P')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, name...)
|
||||
buf = append(buf, 0)
|
||||
buf = append(buf, sql...)
|
||||
buf = append(buf, 0)
|
||||
|
||||
if opts != nil {
|
||||
if len(opts.ParameterOids) > 65535 {
|
||||
return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids))
|
||||
}
|
||||
wbuf.WriteInt16(int16(len(opts.ParameterOids)))
|
||||
buf = pgio.AppendInt16(buf, int16(len(opts.ParameterOids)))
|
||||
for _, oid := range opts.ParameterOids {
|
||||
wbuf.WriteInt32(int32(oid))
|
||||
buf = pgio.AppendInt32(buf, int32(oid))
|
||||
}
|
||||
} else {
|
||||
wbuf.WriteInt16(0)
|
||||
buf = pgio.AppendInt16(buf, 0)
|
||||
}
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
// describe
|
||||
wbuf.startMsg('D')
|
||||
wbuf.WriteByte('S')
|
||||
wbuf.WriteCString(name)
|
||||
buf = append(buf, 'D')
|
||||
sp = len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, 'S')
|
||||
buf = append(buf, name...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
// sync
|
||||
wbuf.startMsg('S')
|
||||
wbuf.closeMsg()
|
||||
buf = append(buf, 'S')
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
_, err = c.conn.Write(wbuf.buf)
|
||||
_, err = c.conn.Write(buf)
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
return nil, err
|
||||
|
@ -813,15 +824,20 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
|
|||
delete(c.preparedStatements, name)
|
||||
|
||||
// close
|
||||
wbuf := newWriteBuf(c, 'C')
|
||||
wbuf.WriteByte('S')
|
||||
wbuf.WriteCString(name)
|
||||
buf := c.wbuf
|
||||
buf = append(buf, 'C')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, 'S')
|
||||
buf = append(buf, name...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
// flush
|
||||
wbuf.startMsg('H')
|
||||
wbuf.closeMsg()
|
||||
buf = append(buf, 'H')
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
_, err = c.conn.Write(wbuf.buf)
|
||||
_, err = c.conn.Write(buf)
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
return err
|
||||
|
@ -943,11 +959,15 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
|
|||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
wbuf := newWriteBuf(c, 'Q')
|
||||
wbuf.WriteCString(sql)
|
||||
wbuf.closeMsg()
|
||||
buf := c.wbuf
|
||||
buf = append(buf, 'Q')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, sql...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
_, err := c.conn.Write(wbuf.buf)
|
||||
_, err := c.conn.Write(buf)
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
return err
|
||||
|
@ -975,37 +995,45 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
}
|
||||
|
||||
// bind
|
||||
wbuf := newWriteBuf(c, 'B')
|
||||
wbuf.WriteByte(0)
|
||||
wbuf.WriteCString(ps.Name)
|
||||
buf := c.wbuf
|
||||
buf = append(buf, 'B')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, 0)
|
||||
buf = append(buf, ps.Name...)
|
||||
buf = append(buf, 0)
|
||||
|
||||
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
|
||||
buf = pgio.AppendInt16(buf, int16(len(ps.ParameterOids)))
|
||||
for i, oid := range ps.ParameterOids {
|
||||
wbuf.WriteInt16(chooseParameterFormatCode(c.ConnInfo, oid, arguments[i]))
|
||||
buf = pgio.AppendInt16(buf, chooseParameterFormatCode(c.ConnInfo, oid, arguments[i]))
|
||||
}
|
||||
|
||||
wbuf.WriteInt16(int16(len(arguments)))
|
||||
buf = pgio.AppendInt16(buf, int16(len(arguments)))
|
||||
for i, oid := range ps.ParameterOids {
|
||||
if err := encodePreparedStatementArgument(wbuf, oid, arguments[i]); err != nil {
|
||||
var err error
|
||||
buf, err = encodePreparedStatementArgument(c.ConnInfo, buf, oid, arguments[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
wbuf.WriteInt16(int16(len(ps.FieldDescriptions)))
|
||||
buf = pgio.AppendInt16(buf, int16(len(ps.FieldDescriptions)))
|
||||
for _, fd := range ps.FieldDescriptions {
|
||||
wbuf.WriteInt16(fd.FormatCode)
|
||||
buf = pgio.AppendInt16(buf, fd.FormatCode)
|
||||
}
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
// execute
|
||||
wbuf.startMsg('E')
|
||||
wbuf.WriteByte(0)
|
||||
wbuf.WriteInt32(0)
|
||||
buf = append(buf, 'E')
|
||||
buf = pgio.AppendInt32(buf, 9)
|
||||
buf = append(buf, 0)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
|
||||
// sync
|
||||
wbuf.startMsg('S')
|
||||
wbuf.closeMsg()
|
||||
buf = append(buf, 'S')
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
_, err = c.conn.Write(wbuf.buf)
|
||||
_, err = c.conn.Write(buf)
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
}
|
||||
|
@ -1180,11 +1208,15 @@ func (c *Conn) txStartupMessage(msg *startupMessage) error {
|
|||
}
|
||||
|
||||
func (c *Conn) txPasswordMessage(password string) (err error) {
|
||||
wbuf := newWriteBuf(c, 'p')
|
||||
wbuf.WriteCString(password)
|
||||
wbuf.closeMsg()
|
||||
buf := c.wbuf
|
||||
buf = append(buf, 'p')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, password...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
_, err = c.conn.Write(wbuf.buf)
|
||||
_, err = c.conn.Write(buf)
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
57
copy_from.go
57
copy_from.go
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
)
|
||||
|
||||
|
@ -89,14 +90,14 @@ func (ct *copyFrom) waitForReaderDone() error {
|
|||
|
||||
func (ct *copyFrom) run() (int, error) {
|
||||
quotedTableName := ct.tableName.Sanitize()
|
||||
buf := &bytes.Buffer{}
|
||||
cbuf := &bytes.Buffer{}
|
||||
for i, cn := range ct.columnNames {
|
||||
if i != 0 {
|
||||
buf.WriteString(", ")
|
||||
cbuf.WriteString(", ")
|
||||
}
|
||||
buf.WriteString(quoteIdentifier(cn))
|
||||
cbuf.WriteString(quoteIdentifier(cn))
|
||||
}
|
||||
quotedColumnNames := buf.String()
|
||||
quotedColumnNames := cbuf.String()
|
||||
|
||||
ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
|
||||
if err != nil {
|
||||
|
@ -116,11 +117,14 @@ func (ct *copyFrom) run() (int, error) {
|
|||
go ct.readUntilReadyForQuery()
|
||||
defer ct.waitForReaderDone()
|
||||
|
||||
wbuf := newWriteBuf(ct.conn, copyData)
|
||||
buf := ct.conn.wbuf
|
||||
buf = append(buf, copyData)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000"))
|
||||
wbuf.WriteInt32(0)
|
||||
wbuf.WriteInt32(0)
|
||||
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
|
||||
var sentCount int
|
||||
|
||||
|
@ -131,18 +135,16 @@ func (ct *copyFrom) run() (int, error) {
|
|||
default:
|
||||
}
|
||||
|
||||
if len(wbuf.buf) > 65536 {
|
||||
wbuf.closeMsg()
|
||||
_, err = ct.conn.conn.Write(wbuf.buf)
|
||||
if len(buf) > 65536 {
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
_, err = ct.conn.conn.Write(buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Directly manipulate wbuf to reset to reuse the same buffer
|
||||
wbuf.buf = wbuf.buf[0:5]
|
||||
wbuf.buf[0] = copyData
|
||||
wbuf.sizeIdx = 1
|
||||
buf = buf[0:5]
|
||||
}
|
||||
|
||||
sentCount++
|
||||
|
@ -157,9 +159,9 @@ func (ct *copyFrom) run() (int, error) {
|
|||
return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
||||
}
|
||||
|
||||
wbuf.WriteInt16(int16(len(ct.columnNames)))
|
||||
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
||||
for i, val := range values {
|
||||
err = encodePreparedStatementArgument(wbuf, ps.FieldDescriptions[i].DataType, val)
|
||||
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
|
||||
if err != nil {
|
||||
ct.cancelCopyIn()
|
||||
return 0, err
|
||||
|
@ -173,11 +175,13 @@ func (ct *copyFrom) run() (int, error) {
|
|||
return 0, ct.rowSrc.Err()
|
||||
}
|
||||
|
||||
wbuf.WriteInt16(-1) // terminate the copy stream
|
||||
buf = pgio.AppendInt16(buf, -1) // terminate the copy stream
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
wbuf.startMsg(copyDone)
|
||||
wbuf.closeMsg()
|
||||
_, err = ct.conn.conn.Write(wbuf.buf)
|
||||
buf = append(buf, copyDone)
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
_, err = ct.conn.conn.Write(buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return 0, err
|
||||
|
@ -210,10 +214,15 @@ func (c *Conn) readUntilCopyInResponse() error {
|
|||
}
|
||||
|
||||
func (ct *copyFrom) cancelCopyIn() error {
|
||||
wbuf := newWriteBuf(ct.conn, copyFail)
|
||||
wbuf.WriteCString("client error: abort")
|
||||
wbuf.closeMsg()
|
||||
_, err := ct.conn.conn.Write(wbuf.buf)
|
||||
buf := ct.conn.wbuf
|
||||
buf = append(buf, copyFail)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, "client error: abort"...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
_, err := ct.conn.conn.Write(buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return err
|
||||
|
|
29
fastpath.go
29
fastpath.go
|
@ -3,6 +3,7 @@ package pgx
|
|||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
@ -55,19 +56,23 @@ func (f *fastpath) Call(oid pgtype.Oid, args []fpArg) (res []byte, err error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
wbuf := newWriteBuf(f.cn, 'F') // function call
|
||||
wbuf.WriteInt32(int32(oid)) // function object id
|
||||
wbuf.WriteInt16(1) // # of argument format codes
|
||||
wbuf.WriteInt16(1) // format code: binary
|
||||
wbuf.WriteInt16(int16(len(args))) // # of arguments
|
||||
for _, arg := range args {
|
||||
wbuf.WriteInt32(int32(len(arg))) // length of argument
|
||||
wbuf.WriteBytes(arg) // argument value
|
||||
}
|
||||
wbuf.WriteInt16(1) // response format code (binary)
|
||||
wbuf.closeMsg()
|
||||
buf := f.cn.wbuf
|
||||
buf = append(buf, 'F') // function call
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
if _, err := f.cn.conn.Write(wbuf.buf); err != nil {
|
||||
buf = pgio.AppendInt32(buf, int32(oid)) // function object id
|
||||
buf = pgio.AppendInt16(buf, 1) // # of argument format codes
|
||||
buf = pgio.AppendInt16(buf, 1) // format code: binary
|
||||
buf = pgio.AppendInt16(buf, int16(len(args))) // # of arguments
|
||||
for _, arg := range args {
|
||||
buf = pgio.AppendInt32(buf, int32(len(arg))) // length of argument
|
||||
buf = append(buf, arg...) // argument value
|
||||
}
|
||||
buf = pgio.AppendInt16(buf, 1) // response format code (binary)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
if _, err := f.cn.conn.Write(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
87
messages.go
87
messages.go
|
@ -92,90 +92,3 @@ type PgError struct {
|
|||
func (pe PgError) Error() string {
|
||||
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||
}
|
||||
|
||||
func newWriteBuf(c *Conn, t byte) *WriteBuf {
|
||||
buf := append(c.wbuf[0:0], t, 0, 0, 0, 0)
|
||||
c.writeBuf = WriteBuf{buf: buf, sizeIdx: 1, conn: c}
|
||||
return &c.writeBuf
|
||||
}
|
||||
|
||||
// WriteBuf is used build messages to send to the PostgreSQL server. It is used
|
||||
// by the Encoder interface when implementing custom encoders.
|
||||
type WriteBuf struct {
|
||||
buf []byte
|
||||
convBuf [8]byte
|
||||
sizeIdx int
|
||||
conn *Conn
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) startMsg(t byte) {
|
||||
wb.closeMsg()
|
||||
wb.buf = append(wb.buf, t, 0, 0, 0, 0)
|
||||
wb.sizeIdx = len(wb.buf) - 4
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) closeMsg() {
|
||||
binary.BigEndian.PutUint32(wb.buf[wb.sizeIdx:wb.sizeIdx+4], uint32(len(wb.buf)-wb.sizeIdx))
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) reserveSize() int {
|
||||
sizePosition := len(wb.buf)
|
||||
wb.buf = append(wb.buf, 0, 0, 0, 0)
|
||||
return sizePosition
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) setComputedSize(sizePosition int) {
|
||||
binary.BigEndian.PutUint32(wb.buf[sizePosition:], uint32(len(wb.buf)-sizePosition-4))
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) setSize(sizePosition int, size int32) {
|
||||
binary.BigEndian.PutUint32(wb.buf[sizePosition:], uint32(size))
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteByte(b byte) {
|
||||
wb.buf = append(wb.buf, b)
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteCString(s string) {
|
||||
wb.buf = append(wb.buf, []byte(s)...)
|
||||
wb.buf = append(wb.buf, 0)
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteInt16(n int16) {
|
||||
wb.WriteUint16(uint16(n))
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteUint16(n uint16) (int, error) {
|
||||
binary.BigEndian.PutUint16(wb.convBuf[:2], n)
|
||||
wb.buf = append(wb.buf, wb.convBuf[:2]...)
|
||||
return 2, nil
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteInt32(n int32) {
|
||||
wb.WriteUint32(uint32(n))
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteUint32(n uint32) (int, error) {
|
||||
binary.BigEndian.PutUint32(wb.convBuf[:4], n)
|
||||
wb.buf = append(wb.buf, wb.convBuf[:4]...)
|
||||
return 4, nil
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteInt64(n int64) {
|
||||
wb.WriteUint64(uint64(n))
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteUint64(n uint64) (int, error) {
|
||||
binary.BigEndian.PutUint64(wb.convBuf[:8], n)
|
||||
wb.buf = append(wb.buf, wb.convBuf[:8]...)
|
||||
return 8, nil
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteBytes(b []byte) {
|
||||
wb.buf = append(wb.buf, b...)
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) Write(b []byte) (int, error) {
|
||||
wb.buf = append(wb.buf, b...)
|
||||
return len(b), nil
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
)
|
||||
|
||||
|
@ -175,17 +176,21 @@ type ReplicationConn struct {
|
|||
// message to the server, as well as carries the WAL position of the
|
||||
// client, which then updates the server's replication slot position.
|
||||
func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) {
|
||||
writeBuf := newWriteBuf(rc.c, copyData)
|
||||
writeBuf.WriteByte(standbyStatusUpdate)
|
||||
writeBuf.WriteInt64(int64(k.WalWritePosition))
|
||||
writeBuf.WriteInt64(int64(k.WalFlushPosition))
|
||||
writeBuf.WriteInt64(int64(k.WalApplyPosition))
|
||||
writeBuf.WriteInt64(int64(k.ClientTime))
|
||||
writeBuf.WriteByte(k.ReplyRequested)
|
||||
buf := rc.c.wbuf
|
||||
buf = append(buf, copyData)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
writeBuf.closeMsg()
|
||||
buf = append(buf, standbyStatusUpdate)
|
||||
buf = pgio.AppendInt64(buf, int64(k.WalWritePosition))
|
||||
buf = pgio.AppendInt64(buf, int64(k.WalFlushPosition))
|
||||
buf = pgio.AppendInt64(buf, int64(k.WalApplyPosition))
|
||||
buf = pgio.AppendInt64(buf, int64(k.ClientTime))
|
||||
buf = append(buf, k.ReplyRequested)
|
||||
|
||||
_, err = rc.c.conn.Write(writeBuf.buf)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
_, err = rc.c.conn.Write(buf)
|
||||
if err != nil {
|
||||
rc.c.die(err)
|
||||
}
|
||||
|
|
70
values.go
70
values.go
|
@ -97,84 +97,82 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e
|
|||
return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg))
|
||||
}
|
||||
|
||||
func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
||||
func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype.Oid, arg interface{}) ([]byte, error) {
|
||||
if arg == nil {
|
||||
wbuf.WriteInt32(-1)
|
||||
return nil
|
||||
return pgio.AppendInt32(buf, -1), nil
|
||||
}
|
||||
|
||||
switch arg := arg.(type) {
|
||||
case pgtype.BinaryEncoder:
|
||||
sp := len(wbuf.buf)
|
||||
wbuf.buf = pgio.AppendInt32(wbuf.buf, -1)
|
||||
argBuf, err := arg.EncodeBinary(wbuf.conn.ConnInfo, wbuf.buf)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
argBuf, err := arg.EncodeBinary(ci, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if argBuf != nil {
|
||||
wbuf.buf = argBuf
|
||||
pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4))
|
||||
buf = argBuf
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
||||
}
|
||||
return nil
|
||||
return buf, nil
|
||||
case pgtype.TextEncoder:
|
||||
sp := len(wbuf.buf)
|
||||
wbuf.buf = pgio.AppendInt32(wbuf.buf, -1)
|
||||
argBuf, err := arg.EncodeText(wbuf.conn.ConnInfo, wbuf.buf)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
argBuf, err := arg.EncodeText(ci, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if argBuf != nil {
|
||||
wbuf.buf = argBuf
|
||||
pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4))
|
||||
buf = argBuf
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
||||
}
|
||||
return nil
|
||||
return buf, nil
|
||||
case driver.Valuer:
|
||||
v, err := arg.Value()
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
return encodePreparedStatementArgument(wbuf, oid, v)
|
||||
return encodePreparedStatementArgument(ci, buf, oid, v)
|
||||
case string:
|
||||
wbuf.WriteInt32(int32(len(arg)))
|
||||
wbuf.WriteBytes([]byte(arg))
|
||||
return nil
|
||||
buf = pgio.AppendInt32(buf, int32(len(arg)))
|
||||
buf = append(buf, arg...)
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
refVal := reflect.ValueOf(arg)
|
||||
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
if refVal.IsNil() {
|
||||
wbuf.WriteInt32(-1)
|
||||
return nil
|
||||
return pgio.AppendInt32(buf, -1), nil
|
||||
}
|
||||
arg = refVal.Elem().Interface()
|
||||
return encodePreparedStatementArgument(wbuf, oid, arg)
|
||||
return encodePreparedStatementArgument(ci, buf, oid, arg)
|
||||
}
|
||||
|
||||
if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok {
|
||||
if dt, ok := ci.DataTypeForOid(oid); ok {
|
||||
value := dt.Value
|
||||
err := value.Set(arg)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sp := len(wbuf.buf)
|
||||
wbuf.buf = pgio.AppendInt32(wbuf.buf, -1)
|
||||
argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(wbuf.conn.ConnInfo, wbuf.buf)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if argBuf != nil {
|
||||
wbuf.buf = argBuf
|
||||
pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4))
|
||||
buf = argBuf
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
||||
}
|
||||
return nil
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
if strippedArg, ok := stripNamedType(&refVal); ok {
|
||||
return encodePreparedStatementArgument(wbuf, oid, strippedArg)
|
||||
return encodePreparedStatementArgument(ci, buf, oid, strippedArg)
|
||||
}
|
||||
return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
|
||||
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
|
||||
}
|
||||
|
||||
// chooseParameterFormatCode determines the correct format code for an
|
||||
|
|
Loading…
Reference in New Issue