From 90975ab5c274c25b11861ae873faa99ab176d24d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2017 10:01:07 -0500 Subject: [PATCH] Extract append message functions. In general, pgproto3 types should be used. But these functions may be easier to without incurring additional memory allocations. --- conn.go | 146 +++++++++------------------------------------------- messages.go | 109 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+), 122 deletions(-) diff --git a/conn.go b/conn.go index 2c4f4907..be64f104 100644 --- a/conn.go +++ b/conn.go @@ -763,41 +763,17 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared }() } - // parse - buf := c.wbuf - buf = append(buf, 'P') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, name...) - buf = append(buf, 0) - buf = append(buf, sql...) - buf = append(buf, 0) - - if opts != nil { - if len(opts.ParameterOids) > 65535 { - return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)) - } - buf = pgio.AppendInt16(buf, int16(len(opts.ParameterOids))) - for _, oid := range opts.ParameterOids { - buf = pgio.AppendInt32(buf, int32(oid)) - } - } else { - buf = pgio.AppendInt16(buf, 0) + if opts == nil { + opts = &PrepareExOptions{} } - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - // describe - buf = append(buf, 'D') - sp = len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, 'S') - buf = append(buf, name...) - buf = append(buf, 0) - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + if len(opts.ParameterOids) > 65535 { + return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)) + } - // sync - buf = append(buf, 'S') - buf = pgio.AppendInt32(buf, 4) + buf := appendParse(c.wbuf, name, sql, opts.ParameterOids) + buf = appendDescribe(buf, 'S', name) + buf = appendSync(buf) n, err := c.conn.Write(buf) if err != nil { @@ -1021,13 +997,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { } if len(args) == 0 { - buf := c.wbuf - buf = append(buf, 'Q') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, sql...) - buf = append(buf, 0) - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + buf := appendQuery(c.wbuf, sql) _, err := c.conn.Write(buf) if err != nil { @@ -1056,44 +1026,17 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} return err } - // bind - buf := c.wbuf - buf = append(buf, 'B') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, 0) - buf = append(buf, ps.Name...) - buf = append(buf, 0) - - buf = pgio.AppendInt16(buf, int16(len(ps.ParameterOids))) - for i, oid := range ps.ParameterOids { - buf = pgio.AppendInt16(buf, chooseParameterFormatCode(c.ConnInfo, oid, arguments[i])) + resultFormatCodes := make([]int16, len(ps.FieldDescriptions)) + for i, fd := range ps.FieldDescriptions { + resultFormatCodes[i] = fd.FormatCode + } + buf, err := appendBind(c.wbuf, "", ps.Name, c.ConnInfo, ps.ParameterOids, arguments, resultFormatCodes) + if err != nil { + return err } - buf = pgio.AppendInt16(buf, int16(len(arguments))) - for i, oid := range ps.ParameterOids { - var err error - buf, err = encodePreparedStatementArgument(c.ConnInfo, buf, oid, arguments[i]) - if err != nil { - return err - } - } - - buf = pgio.AppendInt16(buf, int16(len(ps.FieldDescriptions))) - for _, fd := range ps.FieldDescriptions { - buf = pgio.AppendInt16(buf, fd.FormatCode) - } - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - // execute - buf = append(buf, 'E') - buf = pgio.AppendInt32(buf, 9) - buf = append(buf, 0) - buf = pgio.AppendInt32(buf, 0) - - // sync - buf = append(buf, 'S') - buf = pgio.AppendInt32(buf, 4) + buf = appendExecute(buf, "", 0) + buf = appendSync(buf) n, err := c.conn.Write(buf) if err != nil && fatalWriteErr(n, err) { @@ -1476,9 +1419,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, return "", err } - // sync - buf = append(buf, 'S') - buf = pgio.AppendInt32(buf, 4) + buf = appendSync(buf) n, err := c.conn.Write(buf) if err != nil && fatalWriteErr(n, err) { @@ -1539,51 +1480,12 @@ func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOpt return nil, fmt.Errorf("Number of QueryExOptions ParameterOids must be between 0 and 65535, received %d", len(options.ParameterOids)) } - // parse - buf = append(buf, 'P') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, 0) - buf = append(buf, sql...) - buf = append(buf, 0) - - buf = pgio.AppendInt16(buf, int16(len(options.ParameterOids))) - for _, oid := range options.ParameterOids { - buf = pgio.AppendUint32(buf, uint32(oid)) + buf = appendParse(buf, "", sql, options.ParameterOids) + buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOids, arguments, nil) + if err != nil { + return nil, err } - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - // bind - buf = append(buf, 'B') - sp = len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, 0) - buf = append(buf, 0) - - buf = pgio.AppendInt16(buf, int16(len(options.ParameterOids))) - for i, oid := range options.ParameterOids { - buf = pgio.AppendInt16(buf, chooseParameterFormatCode(c.ConnInfo, oid, arguments[i])) - } - - buf = pgio.AppendInt16(buf, int16(len(arguments))) - for i, oid := range options.ParameterOids { - var err error - buf, err = encodePreparedStatementArgument(c.ConnInfo, buf, oid, arguments[i]) - if err != nil { - return nil, err - } - } - - // No result values for an exec - buf = pgio.AppendInt16(buf, 0) - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - // execute - buf = append(buf, 'E') - buf = pgio.AppendInt32(buf, 9) - buf = append(buf, 0) - buf = pgio.AppendInt32(buf, 0) + buf = appendExecute(buf, "", 0) return buf, nil } diff --git a/messages.go b/messages.go index f06f8b41..0bf501b4 100644 --- a/messages.go +++ b/messages.go @@ -1,6 +1,7 @@ package pgx import ( + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgtype" ) @@ -47,3 +48,111 @@ type PgError struct { func (pe PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } + +// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it. +func appendParse(buf []byte, name string, query string, parameterOIDs []pgtype.Oid) []byte { + buf = append(buf, 'P') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, name...) + buf = append(buf, 0) + buf = append(buf, query...) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(parameterOIDs))) + for _, oid := range parameterOIDs { + buf = pgio.AppendUint32(buf, uint32(oid)) + } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// appendDescribe appends a PostgreSQL wire protocol describe message to buf and returns it. +func appendDescribe(buf []byte, objectType byte, name string) []byte { + buf = append(buf, 'D') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, objectType) + buf = append(buf, name...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it. +func appendSync(buf []byte) []byte { + buf = append(buf, 'S') + buf = pgio.AppendInt32(buf, 4) + + return buf +} + +// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it. +func appendBind( + buf []byte, + destinationPortal, + preparedStatement string, + connInfo *pgtype.ConnInfo, + parameterOIDs []pgtype.Oid, + arguments []interface{}, + resultFormatCodes []int16, +) ([]byte, error) { + buf = append(buf, 'B') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, destinationPortal...) + buf = append(buf, 0) + buf = append(buf, preparedStatement...) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(parameterOIDs))) + for i, oid := range parameterOIDs { + buf = pgio.AppendInt16(buf, chooseParameterFormatCode(connInfo, oid, arguments[i])) + } + + buf = pgio.AppendInt16(buf, int16(len(arguments))) + for i, oid := range parameterOIDs { + var err error + buf, err = encodePreparedStatementArgument(connInfo, buf, oid, arguments[i]) + if err != nil { + return nil, err + } + } + + buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes))) + for _, fc := range resultFormatCodes { + buf = pgio.AppendInt16(buf, fc) + } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf, nil +} + +// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it. +func appendExecute(buf []byte, portal string, maxRows uint32) []byte { + buf = append(buf, 'E') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf = append(buf, portal...) + buf = append(buf, 0) + buf = pgio.AppendUint32(buf, maxRows) + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// appendQuery appends a PostgreSQL wire protocol query message to buf and returns it. +func appendQuery(buf []byte, query string) []byte { + buf = append(buf, 'Q') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, query...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +}