mirror of https://github.com/jackc/pgx.git
Do not allow protocol messages larger than ~1GB
The PostgreSQL server will reject messages greater than ~1 GB anyway. However, worse than that is that a message that is larger than 4 GB could wrap the 32-bit integer message size and be interpreted by the server as multiple messages. This could allow a malicious client to inject arbitrary protocol messages. https://github.com/jackc/pgx/security/advisories/GHSA-mrww-27vc-gghvpull/1927/head
parent
c1b0a01ca7
commit
adbb38f298
|
@ -1674,25 +1674,55 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
|
|||
// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
|
||||
type Batch struct {
|
||||
buf []byte
|
||||
err error
|
||||
}
|
||||
|
||||
// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
|
||||
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
|
||||
batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
return
|
||||
}
|
||||
batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
|
||||
}
|
||||
|
||||
// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
|
||||
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
|
||||
batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
|
||||
batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
|
||||
batch.buf = (&pgproto3.Execute{}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
|
||||
// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing
|
||||
// multiple queries in a single round trip than using pipeline mode.
|
||||
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
|
||||
if batch.err != nil {
|
||||
return &MultiResultReader{
|
||||
closed: true,
|
||||
err: batch.err,
|
||||
}
|
||||
}
|
||||
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return &MultiResultReader{
|
||||
closed: true,
|
||||
|
@ -1718,7 +1748,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||
pgConn.contextWatcher.Watch(ctx)
|
||||
}
|
||||
|
||||
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
|
||||
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
multiResult.closed = true
|
||||
multiResult.err = batch.err
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
}
|
||||
|
||||
pgConn.enterPotentialWriteReadDeadlock()
|
||||
defer pgConn.exitPotentialWriteReadDeadlock()
|
||||
|
|
|
@ -3363,9 +3363,9 @@ func TestSNISupport(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil))
|
||||
srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))
|
||||
srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))
|
||||
srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)))
|
||||
srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)))
|
||||
srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)))
|
||||
|
||||
serverSNINameChan <- sniHost
|
||||
}()
|
||||
|
@ -3472,3 +3472,10 @@ func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
|
|||
err = pipeline.Close()
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func mustEncode(buf []byte, err error) []byte {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
|
|
@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 4)
|
||||
func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeGSS)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
|
||||
|
|
|
@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
|
||||
func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
|
||||
dst = append(dst, a.Data...)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
|
||||
|
|
|
@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 12)
|
||||
func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
|
||||
dst = append(dst, src.Salt[:]...)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationOk) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeOk)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASL)
|
||||
|
||||
for _, s := range src.AuthMechanisms {
|
||||
|
@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
|
|||
}
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
|
||||
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
|
||||
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Unmarshaler.
|
||||
|
|
|
@ -16,7 +16,8 @@ type Backend struct {
|
|||
// before it is actually transmitted (i.e. before Flush).
|
||||
tracer *tracer
|
||||
|
||||
wbuf []byte
|
||||
wbuf []byte
|
||||
encodeError error
|
||||
|
||||
// Frontend message flyweights
|
||||
bind Bind
|
||||
|
@ -55,11 +56,21 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
|
|||
return &Backend{cr: cr, w: w}
|
||||
}
|
||||
|
||||
// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
|
||||
// called.
|
||||
// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error
|
||||
// encountered will be returned from Flush.
|
||||
func (b *Backend) Send(msg BackendMessage) {
|
||||
if b.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(b.wbuf)
|
||||
b.wbuf = msg.Encode(b.wbuf)
|
||||
newBuf, err := msg.Encode(b.wbuf)
|
||||
if err != nil {
|
||||
b.encodeError = err
|
||||
return
|
||||
}
|
||||
b.wbuf = newBuf
|
||||
|
||||
if b.tracer != nil {
|
||||
b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
|
||||
}
|
||||
|
@ -67,6 +78,12 @@ func (b *Backend) Send(msg BackendMessage) {
|
|||
|
||||
// Flush writes any pending messages to the frontend (i.e. the client).
|
||||
func (b *Backend) Flush() error {
|
||||
if err := b.encodeError; err != nil {
|
||||
b.encodeError = nil
|
||||
b.wbuf = b.wbuf[:0]
|
||||
return &writeError{err: err, safeToRetry: true}
|
||||
}
|
||||
|
||||
n, err := b.w.Write(b.wbuf)
|
||||
|
||||
const maxLen = 1024
|
||||
|
|
|
@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *BackendKeyData) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'K')
|
||||
dst = pgio.AppendUint32(dst, 12)
|
||||
func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'K')
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) {
|
|||
"username": "tester",
|
||||
},
|
||||
}
|
||||
dst := []byte{}
|
||||
dst = want.Encode(dst)
|
||||
dst, err := want.Encode([]byte{})
|
||||
require.NoError(t, err)
|
||||
|
||||
server := &interruptReader{}
|
||||
server.push(dst)
|
||||
|
|
|
@ -108,10 +108,8 @@ func (dst *Bind) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Bind) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'B')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *Bind) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'B')
|
||||
|
||||
dst = append(dst, src.DestinationPortal...)
|
||||
dst = append(dst, 0)
|
||||
|
@ -139,9 +137,7 @@ func (src *Bind) Encode(dst []byte) []byte {
|
|||
dst = pgio.AppendInt16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *BindComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '2', 0, 0, 0, 4)
|
||||
func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, '2', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
package pgproto3_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Maximum allowed size.
|
||||
_, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1 byte too big
|
||||
_, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil)
|
||||
require.Error(t, err)
|
||||
}
|
|
@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||
func (src *CancelRequest) Encode(dst []byte) []byte {
|
||||
func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
|
||||
dst = pgio.AppendInt32(dst, 16)
|
||||
dst = pgio.AppendInt32(dst, cancelRequestCode)
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
return dst
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -4,8 +4,6 @@ import (
|
|||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type Close struct {
|
||||
|
@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Close) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *Close) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'C')
|
||||
dst = append(dst, src.ObjectType)
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CloseComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '3', 0, 0, 0, 4)
|
||||
func (src *CloseComplete) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, '3', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -3,8 +3,6 @@ package pgproto3
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type CommandComplete struct {
|
||||
|
@ -31,17 +29,11 @@ func (dst *CommandComplete) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CommandComplete) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *CommandComplete) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'C')
|
||||
dst = append(dst, src.CommandTag...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -44,19 +44,15 @@ func (dst *CopyBothResponse) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyBothResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'W')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'W')
|
||||
dst = append(dst, src.OverallFormat)
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEncodeDecode(t *testing.T) {
|
||||
|
@ -13,6 +14,7 @@ func TestEncodeDecode(t *testing.T) {
|
|||
err := dstResp.Decode(srcBytes[5:])
|
||||
assert.NoError(t, err, "No errors on decode")
|
||||
dstBytes := []byte{}
|
||||
dstBytes = dstResp.Encode(dstBytes)
|
||||
dstBytes, err = dstResp.Encode(dstBytes)
|
||||
require.NoError(t, err)
|
||||
assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match")
|
||||
}
|
||||
|
|
|
@ -3,8 +3,6 @@ package pgproto3
|
|||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type CopyData struct {
|
||||
|
@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyData) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'd')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||
func (src *CopyData) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'd')
|
||||
dst = append(dst, src.Data...)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyDone) Encode(dst []byte) []byte {
|
||||
return append(dst, 'c', 0, 0, 0, 4)
|
||||
func (src *CopyDone) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'c', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -3,8 +3,6 @@ package pgproto3
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type CopyFail struct {
|
||||
|
@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyFail) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'f')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *CopyFail) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'f')
|
||||
dst = append(dst, src.Message...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -44,10 +44,8 @@ func (dst *CopyInResponse) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyInResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'G')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'G')
|
||||
|
||||
dst = append(dst, src.OverallFormat)
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
|
@ -55,9 +53,7 @@ func (src *CopyInResponse) Encode(dst []byte) []byte {
|
|||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -43,10 +43,8 @@ func (dst *CopyOutResponse) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyOutResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'H')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'H')
|
||||
|
||||
dst = append(dst, src.OverallFormat)
|
||||
|
||||
|
@ -55,9 +53,7 @@ func (src *CopyOutResponse) Encode(dst []byte) []byte {
|
|||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -63,10 +63,8 @@ func (dst *DataRow) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *DataRow) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'D')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *DataRow) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'D')
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
|
||||
for _, v := range src.Values {
|
||||
|
@ -79,9 +77,7 @@ func (src *DataRow) Encode(dst []byte) []byte {
|
|||
dst = append(dst, v...)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -4,8 +4,6 @@ import (
|
|||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type Describe struct {
|
||||
|
@ -37,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Describe) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'D')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *Describe) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'D')
|
||||
dst = append(dst, src.ObjectType)
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, 'I', 0, 0, 0, 4)
|
||||
func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'I', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -2,7 +2,6 @@ package pgproto3
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
)
|
||||
|
@ -111,119 +110,113 @@ func (dst *ErrorResponse) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ErrorResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, src.marshalBinary('E')...)
|
||||
func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'E')
|
||||
dst = src.appendFields(dst)
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte(typeByte)
|
||||
buf.Write(bigEndian.Uint32(0))
|
||||
|
||||
func (src *ErrorResponse) appendFields(dst []byte) []byte {
|
||||
if src.Severity != "" {
|
||||
buf.WriteByte('S')
|
||||
buf.WriteString(src.Severity)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'S')
|
||||
dst = append(dst, src.Severity...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.SeverityUnlocalized != "" {
|
||||
buf.WriteByte('V')
|
||||
buf.WriteString(src.SeverityUnlocalized)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'V')
|
||||
dst = append(dst, src.SeverityUnlocalized...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Code != "" {
|
||||
buf.WriteByte('C')
|
||||
buf.WriteString(src.Code)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'C')
|
||||
dst = append(dst, src.Code...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Message != "" {
|
||||
buf.WriteByte('M')
|
||||
buf.WriteString(src.Message)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'M')
|
||||
dst = append(dst, src.Message...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Detail != "" {
|
||||
buf.WriteByte('D')
|
||||
buf.WriteString(src.Detail)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'D')
|
||||
dst = append(dst, src.Detail...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Hint != "" {
|
||||
buf.WriteByte('H')
|
||||
buf.WriteString(src.Hint)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'H')
|
||||
dst = append(dst, src.Hint...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Position != 0 {
|
||||
buf.WriteByte('P')
|
||||
buf.WriteString(strconv.Itoa(int(src.Position)))
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'P')
|
||||
dst = append(dst, strconv.Itoa(int(src.Position))...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.InternalPosition != 0 {
|
||||
buf.WriteByte('p')
|
||||
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'p')
|
||||
dst = append(dst, strconv.Itoa(int(src.InternalPosition))...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.InternalQuery != "" {
|
||||
buf.WriteByte('q')
|
||||
buf.WriteString(src.InternalQuery)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'q')
|
||||
dst = append(dst, src.InternalQuery...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Where != "" {
|
||||
buf.WriteByte('W')
|
||||
buf.WriteString(src.Where)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'W')
|
||||
dst = append(dst, src.Where...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.SchemaName != "" {
|
||||
buf.WriteByte('s')
|
||||
buf.WriteString(src.SchemaName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 's')
|
||||
dst = append(dst, src.SchemaName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.TableName != "" {
|
||||
buf.WriteByte('t')
|
||||
buf.WriteString(src.TableName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 't')
|
||||
dst = append(dst, src.TableName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.ColumnName != "" {
|
||||
buf.WriteByte('c')
|
||||
buf.WriteString(src.ColumnName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'c')
|
||||
dst = append(dst, src.ColumnName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.DataTypeName != "" {
|
||||
buf.WriteByte('d')
|
||||
buf.WriteString(src.DataTypeName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'd')
|
||||
dst = append(dst, src.DataTypeName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.ConstraintName != "" {
|
||||
buf.WriteByte('n')
|
||||
buf.WriteString(src.ConstraintName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'n')
|
||||
dst = append(dst, src.ConstraintName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.File != "" {
|
||||
buf.WriteByte('F')
|
||||
buf.WriteString(src.File)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'F')
|
||||
dst = append(dst, src.File...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Line != 0 {
|
||||
buf.WriteByte('L')
|
||||
buf.WriteString(strconv.Itoa(int(src.Line)))
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'L')
|
||||
dst = append(dst, strconv.Itoa(int(src.Line))...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Routine != "" {
|
||||
buf.WriteByte('R')
|
||||
buf.WriteString(src.Routine)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'R')
|
||||
dst = append(dst, src.Routine...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
|
||||
for k, v := range src.UnknownFields {
|
||||
buf.WriteByte(k)
|
||||
buf.WriteString(v)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, k)
|
||||
dst = append(dst, v...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 0)
|
||||
|
||||
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
||||
|
||||
return buf.Bytes()
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -46,7 +46,7 @@ func (p *PgFortuneBackend) Run() error {
|
|||
return fmt.Errorf("error generating query response: %w", err)
|
||||
}
|
||||
|
||||
buf := (&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
|
||||
buf := mustEncode((&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
|
||||
{
|
||||
Name: []byte("fortune"),
|
||||
TableOID: 0,
|
||||
|
@ -56,10 +56,10 @@ func (p *PgFortuneBackend) Run() error {
|
|||
TypeModifier: -1,
|
||||
Format: 0,
|
||||
},
|
||||
}}).Encode(nil)
|
||||
buf = (&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf)
|
||||
buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)
|
||||
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
|
||||
}}).Encode(nil))
|
||||
buf = mustEncode((&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf))
|
||||
buf = mustEncode((&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf))
|
||||
buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
|
||||
_, err = p.conn.Write(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing query response: %w", err)
|
||||
|
@ -80,8 +80,8 @@ func (p *PgFortuneBackend) handleStartup() error {
|
|||
|
||||
switch startupMessage.(type) {
|
||||
case *pgproto3.StartupMessage:
|
||||
buf := (&pgproto3.AuthenticationOk{}).Encode(nil)
|
||||
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
|
||||
buf := mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))
|
||||
buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
|
||||
_, err = p.conn.Write(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending ready for query: %w", err)
|
||||
|
@ -102,3 +102,10 @@ func (p *PgFortuneBackend) handleStartup() error {
|
|||
func (p *PgFortuneBackend) Close() error {
|
||||
return p.conn.Close()
|
||||
}
|
||||
|
||||
func mustEncode(buf []byte, err error) []byte {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
|
|
@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Execute) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'E')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *Execute) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'E')
|
||||
dst = append(dst, src.Portal...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint32(dst, src.MaxRows)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -20,8 +20,8 @@ func (dst *Flush) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Flush) Encode(dst []byte) []byte {
|
||||
return append(dst, 'H', 0, 0, 0, 4)
|
||||
func (src *Flush) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'H', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -18,7 +18,8 @@ type Frontend struct {
|
|||
// idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq.
|
||||
tracer *tracer
|
||||
|
||||
wbuf []byte
|
||||
wbuf []byte
|
||||
encodeError error
|
||||
|
||||
// Backend message flyweights
|
||||
authenticationOk AuthenticationOk
|
||||
|
@ -64,16 +65,26 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend {
|
|||
return &Frontend{cr: cr, w: w}
|
||||
}
|
||||
|
||||
// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is
|
||||
// called.
|
||||
// Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error
|
||||
// encountered will be returned from Flush.
|
||||
//
|
||||
// Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods
|
||||
// such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an
|
||||
// extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden
|
||||
// behind an interface.
|
||||
func (f *Frontend) Send(msg FrontendMessage) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
|
@ -81,6 +92,12 @@ func (f *Frontend) Send(msg FrontendMessage) {
|
|||
|
||||
// Flush writes any pending messages to the backend (i.e. the server).
|
||||
func (f *Frontend) Flush() error {
|
||||
if err := f.encodeError; err != nil {
|
||||
f.encodeError = nil
|
||||
f.wbuf = f.wbuf[:0]
|
||||
return &writeError{err: err, safeToRetry: true}
|
||||
}
|
||||
|
||||
if len(f.wbuf) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
@ -116,71 +133,141 @@ func (f *Frontend) Untrace() {
|
|||
f.tracer = nil
|
||||
}
|
||||
|
||||
// SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until
|
||||
// Flush is called.
|
||||
// SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||
// error encountered will be returned from Flush.
|
||||
func (f *Frontend) SendBind(msg *Bind) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
// SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until
|
||||
// Flush is called.
|
||||
// SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||
// error encountered will be returned from Flush.
|
||||
func (f *Frontend) SendParse(msg *Parse) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
// SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until
|
||||
// Flush is called.
|
||||
// SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||
// error encountered will be returned from Flush.
|
||||
func (f *Frontend) SendClose(msg *Close) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
// SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until
|
||||
// Flush is called.
|
||||
// SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is
|
||||
// called. Any error encountered will be returned from Flush.
|
||||
func (f *Frontend) SendDescribe(msg *Describe) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
// SendExecute sends an Execute message to the backend (i.e. the server). The message is not guaranteed to be written until
|
||||
// Flush is called.
|
||||
// SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called.
|
||||
// Any error encountered will be returned from Flush.
|
||||
func (f *Frontend) SendExecute(msg *Execute) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
// SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until
|
||||
// Flush is called.
|
||||
// SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||
// error encountered will be returned from Flush.
|
||||
func (f *Frontend) SendSync(msg *Sync) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
// SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until
|
||||
// Flush is called.
|
||||
// SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||
// error encountered will be returned from Flush.
|
||||
func (f *Frontend) SendQuery(msg *Query) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
|
|
|
@ -71,10 +71,8 @@ func (dst *FunctionCall) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *FunctionCall) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'F')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end
|
||||
func (src *FunctionCall) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'F')
|
||||
dst = pgio.AppendUint32(dst, src.Function)
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
|
||||
for _, argFormatCode := range src.ArgFormatCodes {
|
||||
|
@ -90,6 +88,5 @@ func (src *FunctionCall) Encode(dst []byte) []byte {
|
|||
}
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, src.ResultFormatCode)
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
|
|
@ -39,10 +39,8 @@ func (dst *FunctionCallResponse) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *FunctionCallResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'V')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'V')
|
||||
|
||||
if src.Result == nil {
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
@ -51,9 +49,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte {
|
|||
dst = append(dst, src.Result...)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"encoding/binary"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFunctionCall_EncodeDecode(t *testing.T) {
|
||||
|
@ -30,7 +32,8 @@ func TestFunctionCall_EncodeDecode(t *testing.T) {
|
|||
Arguments: tt.fields.Arguments,
|
||||
ResultFormatCode: tt.fields.ResultFormatCode,
|
||||
}
|
||||
encoded := src.Encode([]byte{})
|
||||
encoded, err := src.Encode([]byte{})
|
||||
require.NoError(t, err)
|
||||
dst := &FunctionCall{}
|
||||
// Check the header
|
||||
msgTypeCode := encoded[0]
|
||||
|
@ -44,7 +47,7 @@ func TestFunctionCall_EncodeDecode(t *testing.T) {
|
|||
t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded))
|
||||
}
|
||||
// Check decoding works as expected
|
||||
err := dst.Decode(encoded[5:])
|
||||
err = dst.Decode(encoded[5:])
|
||||
if err != nil {
|
||||
if !tt.wantErr {
|
||||
t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
|
|
@ -31,10 +31,10 @@ func (dst *GSSEncRequest) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||
func (src *GSSEncRequest) Encode(dst []byte) []byte {
|
||||
func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) {
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
dst = pgio.AppendInt32(dst, gssEncReqNumber)
|
||||
return dst
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -2,8 +2,6 @@ package pgproto3
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type GSSResponse struct {
|
||||
|
@ -18,11 +16,10 @@ func (g *GSSResponse) Decode(data []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (g *GSSResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
|
||||
func (g *GSSResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'p')
|
||||
dst = append(dst, g.Data...)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -20,8 +20,8 @@ func (dst *NoData) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *NoData) Encode(dst []byte) []byte {
|
||||
return append(dst, 'n', 0, 0, 0, 4)
|
||||
func (src *NoData) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'n', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -12,6 +12,8 @@ func (dst *NoticeResponse) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *NoticeResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
|
||||
func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'N')
|
||||
dst = (*ErrorResponse)(src).appendFields(dst)
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
|
|
@ -45,20 +45,14 @@ func (dst *NotificationResponse) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *NotificationResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'A')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'A')
|
||||
dst = pgio.AppendUint32(dst, src.PID)
|
||||
dst = append(dst, src.Channel...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.Payload...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -39,19 +39,15 @@ func (dst *ParameterDescription) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ParameterDescription) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 't')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 't')
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||
for _, oid := range src.ParameterOIDs {
|
||||
dst = pgio.AppendUint32(dst, oid)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -3,8 +3,6 @@ package pgproto3
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type ParameterStatus struct {
|
||||
|
@ -37,19 +35,13 @@ func (dst *ParameterStatus) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ParameterStatus) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'S')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'S')
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.Value...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -52,10 +52,8 @@ func (dst *Parse) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Parse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'P')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *Parse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'P')
|
||||
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
@ -67,9 +65,7 @@ func (src *Parse) Encode(dst []byte) []byte {
|
|||
dst = pgio.AppendUint32(dst, oid)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -20,8 +20,8 @@ func (dst *ParseComplete) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ParseComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '1', 0, 0, 0, 4)
|
||||
func (src *ParseComplete) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, '1', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -3,8 +3,6 @@ package pgproto3
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type PasswordMessage struct {
|
||||
|
@ -32,14 +30,11 @@ func (dst *PasswordMessage) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *PasswordMessage) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
|
||||
|
||||
func (src *PasswordMessage) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'p')
|
||||
dst = append(dst, src.Password...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -4,8 +4,14 @@ import (
|
|||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
// maxMessageBodyLen is the maximum length of a message body in bytes. See PG_LARGE_MESSAGE_LIMIT in the PostgreSQL
|
||||
// source. It is defined as (MaxAllocSize - 1). MaxAllocSize is defined as 0x3fffffff.
|
||||
const maxMessageBodyLen = (0x3fffffff - 1)
|
||||
|
||||
// Message is the interface implemented by an object that can decode and encode
|
||||
// a particular PostgreSQL message.
|
||||
type Message interface {
|
||||
|
@ -14,7 +20,7 @@ type Message interface {
|
|||
Decode(data []byte) error
|
||||
|
||||
// Encode appends itself to dst and returns the new buffer.
|
||||
Encode(dst []byte) []byte
|
||||
Encode(dst []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// FrontendMessage is a message sent by the frontend (i.e. the client).
|
||||
|
@ -92,3 +98,23 @@ func getValueFromJSON(v map[string]string) ([]byte, error) {
|
|||
}
|
||||
return nil, errors.New("unknown protocol representation")
|
||||
}
|
||||
|
||||
// beginMessage begines a new message of type t. It appends the message type and a placeholder for the message length to
|
||||
// dst. It returns the new buffer and the position of the message length placeholder.
|
||||
func beginMessage(dst []byte, t byte) ([]byte, int) {
|
||||
dst = append(dst, t)
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
return dst, sp
|
||||
}
|
||||
|
||||
// finishMessage finishes a message that was started with beginMessage. It computes the message length and writes it to
|
||||
// dst[sp]. If the message length is too large it returns an error. Otherwise it returns the final message buffer.
|
||||
func finishMessage(dst []byte, sp int) ([]byte, error) {
|
||||
messageBodyLen := len(dst[sp:])
|
||||
if messageBodyLen > maxMessageBodyLen {
|
||||
return nil, errors.New("message body too large")
|
||||
}
|
||||
pgio.SetInt32(dst[sp:], int32(messageBodyLen))
|
||||
return dst, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
package pgproto3
|
||||
|
||||
const MaxMessageBodyLen = maxMessageBodyLen
|
|
@ -20,8 +20,8 @@ func (dst *PortalSuspended) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *PortalSuspended) Encode(dst []byte) []byte {
|
||||
return append(dst, 's', 0, 0, 0, 4)
|
||||
func (src *PortalSuspended) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 's', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -3,8 +3,6 @@ package pgproto3
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type Query struct {
|
||||
|
@ -28,14 +26,11 @@ func (dst *Query) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Query) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'Q')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))
|
||||
|
||||
func (src *Query) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'Q')
|
||||
dst = append(dst, src.String...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
package pgproto3_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestQueryBiggerThanMaxMessageBodyLen(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Maximum allowed size. 4 bytes for size and 1 byte for 0 terminated string.
|
||||
_, err := (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-5))}).Encode(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1 byte too big
|
||||
_, err = (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-4))}).Encode(nil)
|
||||
require.Error(t, err)
|
||||
}
|
|
@ -25,8 +25,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ReadyForQuery) Encode(dst []byte) []byte {
|
||||
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus)
|
||||
func (src *ReadyForQuery) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -99,10 +99,8 @@ func (dst *RowDescription) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *RowDescription) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'T')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *RowDescription) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'T')
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
|
||||
for _, fd := range src.Fields {
|
||||
|
@ -117,9 +115,7 @@ func (src *RowDescription) Encode(dst []byte) []byte {
|
|||
dst = pgio.AppendInt16(dst, fd.Format)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -39,10 +39,8 @@ func (dst *SASLInitialResponse) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *SASLInitialResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *SASLInitialResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'p')
|
||||
|
||||
dst = append(dst, []byte(src.AuthMechanism)...)
|
||||
dst = append(dst, 0)
|
||||
|
@ -50,9 +48,7 @@ func (src *SASLInitialResponse) Encode(dst []byte) []byte {
|
|||
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -3,8 +3,6 @@ package pgproto3
|
|||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type SASLResponse struct {
|
||||
|
@ -22,13 +20,10 @@ func (dst *SASLResponse) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *SASLResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||
|
||||
func (src *SASLResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'p')
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -31,10 +31,10 @@ func (dst *SSLRequest) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||
func (src *SSLRequest) Encode(dst []byte) []byte {
|
||||
func (src *SSLRequest) Encode(dst []byte) ([]byte, error) {
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
dst = pgio.AppendInt32(dst, sslRequestNumber)
|
||||
return dst
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -64,7 +64,7 @@ func (dst *StartupMessage) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *StartupMessage) Encode(dst []byte) []byte {
|
||||
func (src *StartupMessage) Encode(dst []byte) ([]byte, error) {
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
|
@ -77,9 +77,7 @@ func (src *StartupMessage) Encode(dst []byte) []byte {
|
|||
}
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -20,8 +20,8 @@ func (dst *Sync) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Sync) Encode(dst []byte) []byte {
|
||||
return append(dst, 'S', 0, 0, 0, 4)
|
||||
func (src *Sync) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'S', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
|
@ -20,8 +20,8 @@ func (dst *Terminate) Decode(src []byte) error {
|
|||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Terminate) Encode(dst []byte) []byte {
|
||||
return append(dst, 'X', 0, 0, 0, 4)
|
||||
func (src *Terminate) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'X', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
|
Loading…
Reference in New Issue