Introduce pgproto3 package

pgproto3 will wrap the message encoding and decoding for the PostgreSQL
frontend/backend protocol version 3.
batch-wip
Jack Christensen 2017-04-29 10:02:38 -05:00
parent e305ece410
commit f04c58338b
34 changed files with 1676 additions and 262 deletions

1
.gitignore vendored
View File

@ -22,3 +22,4 @@ _testmain.go
*.exe
conn_config_test.go
.envrc

View File

@ -52,7 +52,7 @@ install:
- go get -u github.com/shopspring/decimal
- go get -u gopkg.in/inconshreveable/log15.v2
- go get -u github.com/jackc/fake
- go get -u github.com/jackc/pgmock/pgmsg
- go get -u github.com/jackc/pgmock/pgproto3
- go get -u github.com/lib/pq
- go get -u github.com/hashicorp/go-version
- go get -u github.com/satori/go.uuid

304
conn.go
View File

@ -20,7 +20,7 @@ import (
"sync/atomic"
"time"
"github.com/jackc/pgx/chunkreader"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
@ -88,8 +88,8 @@ type Conn struct {
lastActivityTime time.Time // the last time the connection was used
wbuf [1024]byte
writeBuf WriteBuf
pid int32 // backend pid
secretKey int32 // key to use to send a cancel query message to the server
pid uint32 // backend pid
secretKey uint32 // key to use to send a cancel query message to the server
RuntimeParams map[string]string // parameters that have been reported by the server
config ConnConfig // config used when establishing this connection
txStatus byte
@ -98,7 +98,6 @@ type Conn struct {
notifications []*Notification
logger Logger
logLevel int
mr msgReader
fp *fastpath
poolResetCount int
preallocatedRows []Rows
@ -116,6 +115,8 @@ type Conn struct {
closedChan chan error
ConnInfo *pgtype.ConnInfo
frontend *pgproto3.Frontend
}
// PreparedStatement is a description of a prepared statement
@ -133,7 +134,7 @@ type PrepareExOptions struct {
// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system
type Notification struct {
PID int32 // backend pid that sent the notification
PID uint32 // backend pid that sent the notification
Channel string // channel from which notification was received
Payload string
}
@ -213,8 +214,6 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
c.logLevel = LogLevelDebug
}
c.logger = c.config.Logger
c.mr.log = c.log
c.mr.shouldLog = c.shouldLog
if c.config.User == "" {
user, err := user.Current()
@ -290,7 +289,10 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
}
}
c.mr.cr = chunkreader.NewChunkReader(c.conn)
c.frontend, err = pgproto3.NewFrontend(c.conn, c.conn)
if err != nil {
return err
}
msg := newStartupMessage()
@ -317,29 +319,27 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
}
for {
var t byte
var r *msgReader
t, r, err = c.rxMsg()
msg, err := c.rxMsg()
if err != nil {
return err
}
switch t {
case backendKeyData:
c.rxBackendKeyData(r)
case authenticationX:
if err = c.rxAuthenticationX(r); err != nil {
switch msg := msg.(type) {
case *pgproto3.BackendKeyData:
c.rxBackendKeyData(msg)
case *pgproto3.Authentication:
if err = c.rxAuthenticationX(msg); err != nil {
return err
}
case readyForQuery:
c.rxReadyForQuery(r)
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Connection established")
}
// Replication connections can't execute the queries to
// populate the c.PgTypes and c.pgsqlAfInet
if _, ok := msg.options["replication"]; ok {
if _, ok := config.RuntimeParams["replication"]; ok {
return nil
}
@ -352,7 +352,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
return nil
default:
if err = c.processContextFreeMsg(t, r); err != nil {
if err = c.processContextFreeMsg(msg); err != nil {
return err
}
}
@ -393,7 +393,7 @@ where (
}
// PID returns the backend PID for this connection.
func (c *Conn) PID() int32 {
func (c *Conn) PID() uint32 {
return c.pid
}
@ -744,22 +744,20 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
var softErr error
for {
var t byte
var r *msgReader
t, r, err := c.rxMsg()
msg, err := c.rxMsg()
if err != nil {
return nil, err
}
switch t {
case parameterDescription:
ps.ParameterOids = c.rxParameterDescription(r)
switch msg := msg.(type) {
case *pgproto3.ParameterDescription:
ps.ParameterOids = c.rxParameterDescription(msg)
if len(ps.ParameterOids) > 65535 && softErr == nil {
softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids))
}
case rowDescription:
ps.FieldDescriptions = c.rxRowDescription(r)
case *pgproto3.RowDescription:
ps.FieldDescriptions = c.rxRowDescription(msg)
for i := range ps.FieldDescriptions {
if dt, ok := c.ConnInfo.DataTypeForOid(ps.FieldDescriptions[i].DataType); ok {
ps.FieldDescriptions[i].DataTypeName = dt.Name
@ -772,8 +770,8 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
return nil, fmt.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType)
}
}
case readyForQuery:
c.rxReadyForQuery(r)
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
if softErr == nil {
c.preparedStatements[name] = ps
@ -781,7 +779,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
return ps, softErr
default:
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
softErr = e
}
}
@ -830,18 +828,16 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
}
for {
var t byte
var r *msgReader
t, r, err := c.rxMsg()
msg, err := c.rxMsg()
if err != nil {
return err
}
switch t {
case closeComplete:
switch msg.(type) {
case *pgproto3.CloseComplete:
return nil
default:
err = c.processContextFreeMsg(t, r)
err = c.processContextFreeMsg(msg)
if err != nil {
return err
}
@ -908,12 +904,12 @@ func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notificat
}
for {
t, r, err := c.rxMsg()
msg, err := c.rxMsg()
if err != nil {
return nil, err
}
err = c.processContextFreeMsg(t, r)
err = c.processContextFreeMsg(msg)
if err != nil {
return nil, err
}
@ -1030,62 +1026,48 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
// meaningful in a given context. These messages can occur due to a context
// deadline interrupting message processing. For example, an interrupted query
// may have left DataRow messages on the wire.
func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) {
switch t {
case bindComplete:
case commandComplete:
case dataRow:
case emptyQueryResponse:
case errorResponse:
return c.rxErrorResponse(r)
case noData:
case noticeResponse:
case notificationResponse:
c.rxNotificationResponse(r)
case parameterDescription:
case parseComplete:
case readyForQuery:
c.rxReadyForQuery(r)
case rowDescription:
case 'S':
c.rxParameterStatus(r)
default:
return fmt.Errorf("Received unknown message type: %c", t)
func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) {
switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
return c.rxErrorResponse(msg)
case *pgproto3.NotificationResponse:
c.rxNotificationResponse(msg)
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
case *pgproto3.ParameterStatus:
c.rxParameterStatus(msg)
}
return nil
}
func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) {
if atomic.LoadInt32(&c.status) < connStatusIdle {
return 0, nil, ErrDeadConn
return nil, ErrDeadConn
}
t, err = c.mr.rxMsg()
msg, err := c.frontend.Receive()
if err != nil {
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
c.die(err)
}
return nil, err
}
c.lastActivityTime = time.Now()
if c.shouldLog(LogLevelTrace) {
c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBodyLen", len(c.mr.msgBody))
}
// fmt.Printf("rxMsg: %#v\n", msg)
return t, &c.mr, err
return msg, nil
}
func (c *Conn) rxAuthenticationX(r *msgReader) (err error) {
switch r.readInt32() {
case 0: // AuthenticationOk
case 3: // AuthenticationCleartextPassword
func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
switch msg.Type {
case pgproto3.AuthTypeOk:
case pgproto3.AuthTypeCleartextPassword:
err = c.txPasswordMessage(c.config.Password)
case 5: // AuthenticationMD5Password
salt := r.readString(4)
digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+salt)
case pgproto3.AuthTypeMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:]))
err = c.txPasswordMessage(digestedPassword)
default:
err = errors.New("Received unknown authentication message")
@ -1100,115 +1082,75 @@ func hexMD5(s string) string {
return hex.EncodeToString(hash.Sum(nil))
}
func (c *Conn) rxParameterStatus(r *msgReader) {
key := r.readCString()
value := r.readCString()
c.RuntimeParams[key] = value
func (c *Conn) rxParameterStatus(msg *pgproto3.ParameterStatus) {
c.RuntimeParams[msg.Name] = msg.Value
}
func (c *Conn) rxErrorResponse(r *msgReader) (err PgError) {
for {
switch r.readByte() {
case 'S':
err.Severity = r.readCString()
case 'C':
err.Code = r.readCString()
case 'M':
err.Message = r.readCString()
case 'D':
err.Detail = r.readCString()
case 'H':
err.Hint = r.readCString()
case 'P':
s := r.readCString()
n, _ := strconv.ParseInt(s, 10, 32)
err.Position = int32(n)
case 'p':
s := r.readCString()
n, _ := strconv.ParseInt(s, 10, 32)
err.InternalPosition = int32(n)
case 'q':
err.InternalQuery = r.readCString()
case 'W':
err.Where = r.readCString()
case 's':
err.SchemaName = r.readCString()
case 't':
err.TableName = r.readCString()
case 'c':
err.ColumnName = r.readCString()
case 'd':
err.DataTypeName = r.readCString()
case 'n':
err.ConstraintName = r.readCString()
case 'F':
err.File = r.readCString()
case 'L':
s := r.readCString()
n, _ := strconv.ParseInt(s, 10, 32)
err.Line = int32(n)
case 'R':
err.Routine = r.readCString()
case 0: // End of error message
if err.Severity == "FATAL" {
c.die(err)
}
return
default: // Ignore other error fields
r.readCString()
}
func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) PgError {
err := PgError{
Severity: msg.Severity,
Code: msg.Code,
Message: msg.Message,
Detail: msg.Detail,
Hint: msg.Hint,
Position: msg.Position,
InternalPosition: msg.InternalPosition,
InternalQuery: msg.InternalQuery,
Where: msg.Where,
SchemaName: msg.SchemaName,
TableName: msg.TableName,
ColumnName: msg.ColumnName,
DataTypeName: msg.DataTypeName,
ConstraintName: msg.ConstraintName,
File: msg.File,
Line: msg.Line,
Routine: msg.Routine,
}
if err.Severity == "FATAL" {
c.die(err)
}
return err
}
func (c *Conn) rxBackendKeyData(r *msgReader) {
c.pid = r.readInt32()
c.secretKey = r.readInt32()
func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) {
c.pid = msg.ProcessID
c.secretKey = msg.SecretKey
}
func (c *Conn) rxReadyForQuery(r *msgReader) {
func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) {
c.readyForQuery = true
c.txStatus = r.readByte()
c.txStatus = msg.TxStatus
}
func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) {
fieldCount := r.readInt16()
fields = make([]FieldDescription, fieldCount)
for i := int16(0); i < fieldCount; i++ {
f := &fields[i]
f.Name = r.readCString()
f.Table = pgtype.Oid(r.readUint32())
f.AttributeNumber = r.readInt16()
f.DataType = pgtype.Oid(r.readUint32())
f.DataTypeSize = r.readInt16()
f.Modifier = r.readInt32()
f.FormatCode = r.readInt16()
func (c *Conn) rxRowDescription(msg *pgproto3.RowDescription) []FieldDescription {
fields := make([]FieldDescription, len(msg.Fields))
for i := 0; i < len(fields); i++ {
fields[i].Name = msg.Fields[i].Name
fields[i].Table = pgtype.Oid(msg.Fields[i].TableOID)
fields[i].AttributeNumber = msg.Fields[i].TableAttributeNumber
fields[i].DataType = pgtype.Oid(msg.Fields[i].DataTypeOID)
fields[i].DataTypeSize = msg.Fields[i].DataTypeSize
fields[i].Modifier = msg.Fields[i].TypeModifier
fields[i].FormatCode = msg.Fields[i].Format
}
return
return fields
}
func (c *Conn) rxParameterDescription(r *msgReader) (parameters []pgtype.Oid) {
// Internally, PostgreSQL supports greater than 64k parameters to a prepared
// statement. But the parameter description uses a 16-bit integer for the
// count of parameters. If there are more than 64K parameters, this count is
// wrong. So read the count, ignore it, and compute the proper value from
// the size of the message.
r.readInt16()
parameterCount := len(r.msgBody[r.rp:]) / 4
parameters = make([]pgtype.Oid, 0, parameterCount)
for i := 0; i < parameterCount; i++ {
parameters = append(parameters, pgtype.Oid(r.readUint32()))
func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgtype.Oid {
parameters := make([]pgtype.Oid, len(msg.ParameterOIDs))
for i := 0; i < len(parameters); i++ {
parameters[i] = pgtype.Oid(msg.ParameterOIDs[i])
}
return
return parameters
}
func (c *Conn) rxNotificationResponse(r *msgReader) {
func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) {
n := new(Notification)
n.PID = r.readInt32()
n.Channel = r.readCString()
n.Payload = r.readCString()
n.PID = msg.PID
n.Channel = msg.Channel
n.Payload = msg.Payload
c.notifications = append(c.notifications, n)
}
@ -1453,21 +1395,19 @@ func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions,
var softErr error
for {
var t byte
var r *msgReader
t, r, err = c.rxMsg()
msg, err := c.rxMsg()
if err != nil {
return commandTag, err
}
switch t {
case readyForQuery:
c.rxReadyForQuery(r)
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
return commandTag, softErr
case commandComplete:
commandTag = CommandTag(r.readCString())
case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag)
default:
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
softErr = e
}
}
@ -1545,19 +1485,19 @@ func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error {
func (c *Conn) ensureConnectionReadyForQuery() error {
for !c.readyForQuery {
t, r, err := c.rxMsg()
msg, err := c.rxMsg()
if err != nil {
return err
}
switch t {
case errorResponse:
pgErr := c.rxErrorResponse(r)
switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
pgErr := c.rxErrorResponse(msg)
if pgErr.Severity == "FATAL" {
return pgErr
}
default:
err = c.processContextFreeMsg(t, r)
err = c.processContextFreeMsg(msg)
if err != nil {
return err
}

View File

@ -686,7 +686,7 @@ func TestConnPoolBeginRetry(t *testing.T) {
}
defer tx.Rollback()
var txPID int32
var txPID uint32
err = tx.QueryRow("select pg_backend_pid()").Scan(&txPID)
if err != nil {
t.Fatalf("tx.QueryRow Scan failed: %v", err)

View File

@ -3,6 +3,8 @@ package pgx
import (
"bytes"
"fmt"
"github.com/jackc/pgx/pgproto3"
)
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
@ -54,25 +56,25 @@ type copyFrom struct {
func (ct *copyFrom) readUntilReadyForQuery() {
for {
t, r, err := ct.conn.rxMsg()
msg, err := ct.conn.rxMsg()
if err != nil {
ct.readerErrChan <- err
close(ct.readerErrChan)
return
}
switch t {
case readyForQuery:
ct.conn.rxReadyForQuery(r)
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
ct.conn.rxReadyForQuery(msg)
close(ct.readerErrChan)
return
case commandComplete:
case errorResponse:
ct.readerErrChan <- ct.conn.rxErrorResponse(r)
case *pgproto3.CommandComplete:
case *pgproto3.ErrorResponse:
ct.readerErrChan <- ct.conn.rxErrorResponse(msg)
default:
err = ct.conn.processContextFreeMsg(t, r)
err = ct.conn.processContextFreeMsg(msg)
if err != nil {
ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r)
ct.readerErrChan <- ct.conn.processContextFreeMsg(msg)
}
}
}
@ -190,18 +192,16 @@ func (ct *copyFrom) run() (int, error) {
func (c *Conn) readUntilCopyInResponse() error {
for {
var t byte
var r *msgReader
t, r, err := c.rxMsg()
msg, err := c.rxMsg()
if err != nil {
return err
}
switch t {
case copyInResponse:
switch msg := msg.(type) {
case *pgproto3.CopyInResponse:
return nil
default:
err = c.processContextFreeMsg(t, r)
err = c.processContextFreeMsg(msg)
if err != nil {
return err
}

View File

@ -3,6 +3,7 @@ package pgx
import (
"encoding/binary"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
@ -71,23 +72,20 @@ func (f *fastpath) Call(oid pgtype.Oid, args []fpArg) (res []byte, err error) {
}
for {
var t byte
var r *msgReader
t, r, err = f.cn.rxMsg()
msg, err := f.cn.rxMsg()
if err != nil {
return nil, err
}
switch t {
case 'V': // FunctionCallResponse
data := r.readBytes(r.readInt32())
res = make([]byte, len(data))
copy(res, data)
case 'Z': // Ready for query
f.cn.rxReadyForQuery(r)
switch msg := msg.(type) {
case *pgproto3.FunctionCallResponse:
res = make([]byte, len(msg.Result))
copy(res, msg.Result)
case *pgproto3.ReadyForQuery:
f.cn.rxReadyForQuery(msg)
// done
return
return res, err
default:
if err := f.cn.processContextFreeMsg(t, r); err != nil {
if err := f.cn.processContextFreeMsg(msg); err != nil {
return nil, err
}
}

View File

@ -58,11 +58,11 @@ func (s *startupMessage) Bytes() (buf []byte) {
type FieldDescription struct {
Name string
Table pgtype.Oid
AttributeNumber int16
AttributeNumber uint16
DataType pgtype.Oid
DataTypeSize int16
DataTypeName string
Modifier int32
Modifier uint32
FormatCode int16
}

View File

@ -0,0 +1,54 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"fmt"
)
const (
AuthTypeOk = 0
AuthTypeCleartextPassword = 3
AuthTypeMD5Password = 5
)
type Authentication struct {
Type uint32
// MD5Password fields
Salt [4]byte
}
func (*Authentication) Backend() {}
func (dst *Authentication) UnmarshalBinary(src []byte) error {
*dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])}
switch dst.Type {
case AuthTypeOk:
case AuthTypeCleartextPassword:
case AuthTypeMD5Password:
copy(dst.Salt[:], src[4:8])
default:
return fmt.Errorf("unknown authentication type: %d", dst.Type)
}
return nil
}
func (src *Authentication) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('R')
buf.Write(bigEndian.Uint32(0))
buf.Write(bigEndian.Uint32(src.Type))
switch src.Type {
case AuthTypeMD5Password:
buf.Write(src.Salt[:])
}
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes(), nil
}

View File

@ -0,0 +1,47 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
type BackendKeyData struct {
ProcessID uint32
SecretKey uint32
}
func (*BackendKeyData) Backend() {}
func (dst *BackendKeyData) UnmarshalBinary(src []byte) error {
if len(src) != 8 {
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
}
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
return nil
}
func (src *BackendKeyData) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('K')
buf.Write(bigEndian.Uint32(12))
buf.Write(bigEndian.Uint32(src.ProcessID))
buf.Write(bigEndian.Uint32(src.SecretKey))
return buf.Bytes(), nil
}
func (src *BackendKeyData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProcessID uint32
SecretKey uint32
}{
Type: "BackendKeyData",
ProcessID: src.ProcessID,
SecretKey: src.SecretKey,
})
}

37
pgproto3/big_endian.go Normal file
View File

@ -0,0 +1,37 @@
package pgproto3
import (
"encoding/binary"
)
type BigEndianBuf [8]byte
func (b BigEndianBuf) Int16(n int16) []byte {
buf := b[0:2]
binary.BigEndian.PutUint16(buf, uint16(n))
return buf
}
func (b BigEndianBuf) Uint16(n uint16) []byte {
buf := b[0:2]
binary.BigEndian.PutUint16(buf, n)
return buf
}
func (b BigEndianBuf) Int32(n int32) []byte {
buf := b[0:4]
binary.BigEndian.PutUint32(buf, uint32(n))
return buf
}
func (b BigEndianBuf) Uint32(n uint32) []byte {
buf := b[0:4]
binary.BigEndian.PutUint32(buf, n)
return buf
}
func (b BigEndianBuf) Int64(n int64) []byte {
buf := b[0:8]
binary.BigEndian.PutUint64(buf, uint64(n))
return buf
}

29
pgproto3/bind_complete.go Normal file
View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
type BindComplete struct{}
func (*BindComplete) Backend() {}
func (dst *BindComplete) UnmarshalBinary(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
}
return nil
}
func (src *BindComplete) MarshalBinary() ([]byte, error) {
return []byte{'2', 0, 0, 0, 4}, nil
}
func (src *BindComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "BindComplete",
})
}

View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
type CloseComplete struct{}
func (*CloseComplete) Backend() {}
func (dst *CloseComplete) UnmarshalBinary(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
}
return nil
}
func (src *CloseComplete) MarshalBinary() ([]byte, error) {
return []byte{'3', 0, 0, 0, 4}, nil
}
func (src *CloseComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "CloseComplete",
})
}

View File

@ -0,0 +1,47 @@
package pgproto3
import (
"bytes"
"encoding/json"
)
type CommandComplete struct {
CommandTag string
}
func (*CommandComplete) Backend() {}
func (dst *CommandComplete) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
dst.CommandTag = string(b[:len(b)-1])
return nil
}
func (src *CommandComplete) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('C')
buf.Write(bigEndian.Uint32(uint32(4 + len(src.CommandTag) + 1)))
buf.WriteString(src.CommandTag)
buf.WriteByte(0)
return buf.Bytes(), nil
}
func (src *CommandComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
CommandTag string
}{
Type: "CommandComplete",
CommandTag: src.CommandTag,
})
}

View File

@ -0,0 +1,64 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
type CopyBothResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
func (*CopyBothResponse) Backend() {}
func (dst *CopyBothResponse) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
func (src *CopyBothResponse) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('W')
buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes))))
buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes))))
for _, fc := range src.ColumnFormatCodes {
buf.Write(bigEndian.Uint16(fc))
}
return buf.Bytes(), nil
}
func (src *CopyBothResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyBothResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}

41
pgproto3/copy_data.go Normal file
View File

@ -0,0 +1,41 @@
package pgproto3
import (
"bytes"
"encoding/hex"
"encoding/json"
)
type CopyData struct {
Data []byte
}
func (*CopyData) Backend() {}
func (*CopyData) Frontend() {}
func (dst *CopyData) UnmarshalBinary(src []byte) error {
dst.Data = make([]byte, len(src))
copy(dst.Data, src)
return nil
}
func (src *CopyData) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('d')
buf.Write(bigEndian.Uint32(uint32(4 + len(src.Data))))
buf.Write(src.Data)
return buf.Bytes(), nil
}
func (src *CopyData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data string
}{
Type: "CopyData",
Data: hex.EncodeToString(src.Data),
})
}

View File

@ -0,0 +1,64 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
type CopyInResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
func (*CopyInResponse) Backend() {}
func (dst *CopyInResponse) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
func (src *CopyInResponse) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('G')
buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes))))
buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes))))
for _, fc := range src.ColumnFormatCodes {
buf.Write(bigEndian.Uint16(fc))
}
return buf.Bytes(), nil
}
func (src *CopyInResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyInResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}

View File

@ -0,0 +1,64 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
type CopyOutResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
func (*CopyOutResponse) Backend() {}
func (dst *CopyOutResponse) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
func (src *CopyOutResponse) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('H')
buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes))))
buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes))))
for _, fc := range src.ColumnFormatCodes {
buf.Write(bigEndian.Uint16(fc))
}
return buf.Bytes(), nil
}
func (src *CopyOutResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyOutResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}

103
pgproto3/data_row.go Normal file
View File

@ -0,0 +1,103 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/hex"
"encoding/json"
)
type DataRow struct {
Values [][]byte
}
func (*DataRow) Backend() {}
func (dst *DataRow) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 2 {
return &invalidMessageFormatErr{messageType: "DataRow"}
}
fieldCount := int(binary.BigEndian.Uint16(buf.Next(2)))
dst.Values = make([][]byte, fieldCount)
for i := 0; i < fieldCount; i++ {
if buf.Len() < 4 {
return &invalidMessageFormatErr{messageType: "DataRow"}
}
msgSize := int(int32(binary.BigEndian.Uint32(buf.Next(4))))
// null
if msgSize == -1 {
continue
}
value := make([]byte, msgSize)
_, err := buf.Read(value)
if err != nil {
return err
}
dst.Values[i] = value
}
return nil
}
func (src *DataRow) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('D')
buf.Write(bigEndian.Uint32(0))
buf.Write(bigEndian.Uint16(uint16(len(src.Values))))
for _, v := range src.Values {
if v == nil {
buf.Write(bigEndian.Int32(-1))
continue
}
buf.Write(bigEndian.Int32(int32(len(v))))
buf.Write(v)
}
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes(), nil
}
func (src *DataRow) MarshalJSON() ([]byte, error) {
formattedValues := make([]map[string]string, len(src.Values))
for i, v := range src.Values {
if v == nil {
continue
}
var hasNonPrintable bool
for _, b := range v {
if b < 32 {
hasNonPrintable = true
break
}
}
if hasNonPrintable {
formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)}
} else {
formattedValues[i] = map[string]string{"text": string(v)}
}
}
return json.Marshal(struct {
Type string
Values []map[string]string
}{
Type: "DataRow",
Values: formattedValues,
})
}

View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
type EmptyQueryResponse struct{}
func (*EmptyQueryResponse) Backend() {}
func (dst *EmptyQueryResponse) UnmarshalBinary(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)}
}
return nil
}
func (src *EmptyQueryResponse) MarshalBinary() ([]byte, error) {
return []byte{'I', 0, 0, 0, 4}, nil
}
func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "EmptyQueryResponse",
})
}

197
pgproto3/error_response.go Normal file
View File

@ -0,0 +1,197 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"strconv"
)
type ErrorResponse struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[byte]string
}
func (*ErrorResponse) Backend() {}
func (dst *ErrorResponse) UnmarshalBinary(src []byte) error {
*dst = ErrorResponse{}
buf := bytes.NewBuffer(src)
for {
k, err := buf.ReadByte()
if err != nil {
return err
}
if k == 0 {
break
}
vb, err := buf.ReadBytes(0)
if err != nil {
return err
}
v := string(vb[:len(vb)-1])
switch k {
case 'S':
dst.Severity = v
case 'C':
dst.Code = v
case 'M':
dst.Message = v
case 'D':
dst.Detail = v
case 'H':
dst.Hint = v
case 'P':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.Position = int32(n)
case 'p':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.InternalPosition = int32(n)
case 'q':
dst.InternalQuery = v
case 'W':
dst.Where = v
case 's':
dst.SchemaName = v
case 't':
dst.TableName = v
case 'c':
dst.ColumnName = v
case 'd':
dst.DataTypeName = v
case 'n':
dst.ConstraintName = v
case 'F':
dst.File = v
case 'L':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.Line = int32(n)
case 'R':
dst.Routine = v
default:
if dst.UnknownFields == nil {
dst.UnknownFields = make(map[byte]string)
}
dst.UnknownFields[k] = v
}
}
return nil
}
func (src *ErrorResponse) MarshalBinary() ([]byte, error) {
return src.marshalBinary('E')
}
func (src *ErrorResponse) marshalBinary(typeByte byte) ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte(typeByte)
buf.Write(bigEndian.Uint32(0))
if src.Severity != "" {
buf.WriteString(src.Severity)
buf.WriteByte(0)
}
if src.Code != "" {
buf.WriteString(src.Code)
buf.WriteByte(0)
}
if src.Message != "" {
buf.WriteString(src.Message)
buf.WriteByte(0)
}
if src.Detail != "" {
buf.WriteString(src.Detail)
buf.WriteByte(0)
}
if src.Hint != "" {
buf.WriteString(src.Hint)
buf.WriteByte(0)
}
if src.Position != 0 {
buf.WriteString(strconv.Itoa(int(src.Position)))
buf.WriteByte(0)
}
if src.InternalPosition != 0 {
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
buf.WriteByte(0)
}
if src.InternalQuery != "" {
buf.WriteString(src.InternalQuery)
buf.WriteByte(0)
}
if src.Where != "" {
buf.WriteString(src.Where)
buf.WriteByte(0)
}
if src.SchemaName != "" {
buf.WriteString(src.SchemaName)
buf.WriteByte(0)
}
if src.TableName != "" {
buf.WriteString(src.TableName)
buf.WriteByte(0)
}
if src.ColumnName != "" {
buf.WriteString(src.ColumnName)
buf.WriteByte(0)
}
if src.DataTypeName != "" {
buf.WriteString(src.DataTypeName)
buf.WriteByte(0)
}
if src.ConstraintName != "" {
buf.WriteString(src.ConstraintName)
buf.WriteByte(0)
}
if src.File != "" {
buf.WriteString(src.File)
buf.WriteByte(0)
}
if src.Line != 0 {
buf.WriteString(strconv.Itoa(int(src.Line)))
buf.WriteByte(0)
}
if src.Routine != "" {
buf.WriteString(src.Routine)
buf.WriteByte(0)
}
for k, v := range src.UnknownFields {
buf.WriteByte(k)
buf.WriteByte(0)
buf.WriteString(v)
buf.WriteByte(0)
}
buf.WriteByte(0)
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes(), nil
}

70
pgproto3/frontend.go Normal file
View File

@ -0,0 +1,70 @@
package pgproto3
import (
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/jackc/pgx/chunkreader"
)
type Frontend struct {
cr *chunkreader.ChunkReader
w io.Writer
}
func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
cr := chunkreader.NewChunkReader(r)
return &Frontend{cr: cr, w: w}, nil
}
func (b *Frontend) Send(msg FrontendMessage) error {
return errors.New("not implemented")
}
func (b *Frontend) Receive() (BackendMessage, error) {
backendMessages := map[byte]BackendMessage{
'1': &ParseComplete{},
'2': &BindComplete{},
'3': &CloseComplete{},
'A': &NotificationResponse{},
'C': &CommandComplete{},
'd': &CopyData{},
'D': &DataRow{},
'E': &ErrorResponse{},
'G': &CopyInResponse{},
'H': &CopyOutResponse{},
'I': &EmptyQueryResponse{},
'K': &BackendKeyData{},
'n': &NoData{},
'N': &NoticeResponse{},
'R': &Authentication{},
'S': &ParameterStatus{},
't': &ParameterDescription{},
'T': &RowDescription{},
'V': &FunctionCallResponse{},
'W': &CopyBothResponse{},
'Z': &ReadyForQuery{},
}
header, err := b.cr.Next(5)
if err != nil {
return nil, err
}
msgType := header[0]
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
msgBody, err := b.cr.Next(bodyLen)
if err != nil {
return nil, err
}
if msg, ok := backendMessages[msgType]; ok {
err = msg.UnmarshalBinary(msgBody)
return msg, err
}
return nil, fmt.Errorf("unknown message type: %c", msgType)
}

View File

@ -0,0 +1,73 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/hex"
"encoding/json"
)
type FunctionCallResponse struct {
Result []byte
}
func (*FunctionCallResponse) Backend() {}
func (dst *FunctionCallResponse) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 4 {
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
}
resultSize := int(binary.BigEndian.Uint32(buf.Next(4)))
if buf.Len() != resultSize {
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
}
dst.Result = make([]byte, resultSize)
copy(dst.Result, buf.Bytes())
return nil
}
func (src *FunctionCallResponse) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('V')
buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Result))))
if src.Result == nil {
buf.Write(bigEndian.Int32(-1))
} else {
buf.Write(bigEndian.Int32(int32(len(src.Result))))
buf.Write(src.Result)
}
return buf.Bytes(), nil
}
func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) {
var formattedValue map[string]string
var hasNonPrintable bool
for _, b := range src.Result {
if b < 32 {
hasNonPrintable = true
break
}
}
if hasNonPrintable {
formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)}
} else {
formattedValue = map[string]string{"text": string(src.Result)}
}
return json.Marshal(struct {
Type string
Result map[string]string
}{
Type: "FunctionCallResponse",
Result: formattedValue,
})
}

29
pgproto3/no_data.go Normal file
View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
type NoData struct{}
func (*NoData) Backend() {}
func (dst *NoData) UnmarshalBinary(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)}
}
return nil
}
func (src *NoData) MarshalBinary() ([]byte, error) {
return []byte{'n', 0, 0, 0, 4}, nil
}
func (src *NoData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "NoData",
})
}

View File

@ -0,0 +1,13 @@
package pgproto3
type NoticeResponse ErrorResponse
func (*NoticeResponse) Backend() {}
func (dst *NoticeResponse) UnmarshalBinary(src []byte) error {
return (*ErrorResponse)(dst).UnmarshalBinary(src)
}
func (src *NoticeResponse) MarshalBinary() ([]byte, error) {
return (*ErrorResponse)(src).marshalBinary('N')
}

View File

@ -0,0 +1,65 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
type NotificationResponse struct {
PID uint32
Channel string
Payload string
}
func (*NotificationResponse) Backend() {}
func (dst *NotificationResponse) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
pid := binary.BigEndian.Uint32(buf.Next(4))
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
channel := string(b[:len(b)-1])
b, err = buf.ReadBytes(0)
if err != nil {
return err
}
payload := string(b[:len(b)-1])
*dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload}
return nil
}
func (src *NotificationResponse) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('A')
buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Channel) + len(src.Payload))))
buf.WriteString(src.Channel)
buf.WriteByte(0)
buf.WriteString(src.Payload)
buf.WriteByte(0)
return buf.Bytes(), nil
}
func (src *NotificationResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
PID uint32
Channel string
Payload string
}{
Type: "NotificationResponse",
PID: src.PID,
Channel: src.Channel,
Payload: src.Payload,
})
}

View File

@ -0,0 +1,60 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
type ParameterDescription struct {
ParameterOIDs []uint32
}
func (*ParameterDescription) Backend() {}
func (dst *ParameterDescription) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 2 {
return &invalidMessageFormatErr{messageType: "ParameterDescription"}
}
// Reported parameter count will be incorrect when number of args is greater than uint16
buf.Next(2)
// Instead infer parameter count by remaining size of message
parameterCount := buf.Len() / 4
*dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)}
for i := 0; i < parameterCount; i++ {
dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4))
}
return nil
}
func (src *ParameterDescription) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('t')
buf.Write(bigEndian.Uint32(uint32(4 + 2 + 4*len(src.ParameterOIDs))))
buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs))))
for _, oid := range src.ParameterOIDs {
buf.Write(bigEndian.Uint32(oid))
}
return buf.Bytes(), nil
}
func (src *ParameterDescription) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ParameterOIDs []uint32
}{
Type: "ParameterDescription",
ParameterOIDs: src.ParameterOIDs,
})
}

View File

@ -0,0 +1,62 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
type ParameterStatus struct {
Name string
Value string
}
func (*ParameterStatus) Backend() {}
func (dst *ParameterStatus) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
name := string(b[:len(b)-1])
b, err = buf.ReadBytes(0)
if err != nil {
return err
}
value := string(b[:len(b)-1])
*dst = ParameterStatus{Name: name, Value: value}
return nil
}
func (src *ParameterStatus) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('S')
buf.Write(bigEndian.Uint32(0))
buf.WriteString(src.Name)
buf.WriteByte(0)
buf.WriteString(src.Value)
buf.WriteByte(0)
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes(), nil
}
func (ps *ParameterStatus) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Name string
Value string
}{
Type: "ParameterStatus",
Name: ps.Name,
Value: ps.Value,
})
}

View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
type ParseComplete struct{}
func (*ParseComplete) Backend() {}
func (dst *ParseComplete) UnmarshalBinary(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)}
}
return nil
}
func (src *ParseComplete) MarshalBinary() ([]byte, error) {
return []byte{'1', 0, 0, 0, 4}, nil
}
func (src *ParseComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "ParseComplete",
})
}

88
pgproto3/pgproto3.go Normal file
View File

@ -0,0 +1,88 @@
package pgproto3
import "fmt"
type Message interface {
UnmarshalBinary(data []byte) error
MarshalBinary() (data []byte, err error)
}
type FrontendMessage interface {
Message
Frontend() // no-op method to distinguish frontend from backend methods
}
type BackendMessage interface {
Message
Backend() // no-op method to distinguish frontend from backend methods
}
// func ParseBackend(typeByte byte, body []byte) (BackendMessage, error) {
// switch typeByte {
// case '1':
// return ParseParseComplete(body)
// case '2':
// return ParseBindComplete(body)
// case 'C':
// return ParseCommandComplete(body)
// case 'D':
// return ParseDataRow(body)
// case 'E':
// return ParseErrorResponse(body)
// case 'K':
// return ParseBackendKeyData(body)
// case 'R':
// return ParseAuthentication(body)
// case 'S':
// return ParseParameterStatus(body)
// case 'T':
// return ParseRowDescription(body)
// case 't':
// return ParseParameterDescription(body)
// case 'Z':
// return ParseReadyForQuery(body)
// default:
// return ParseUnknownMessage(typeByte, body)
// }
// }
// func ParseFrontend(typeByte byte, body []byte) (FrontendMessage, error) {
// switch typeByte {
// case 'B':
// return ParseBind(body)
// case 'D':
// return ParseDescribe(body)
// case 'E':
// return ParseExecute(body)
// case 'P':
// return ParseParse(body)
// case 'p':
// return ParsePasswordMessage(body)
// case 'Q':
// return ParseQuery(body)
// case 'S':
// return ParseSync(body)
// case 'X':
// return ParseTerminate(body)
// default:
// return ParseUnknownMessage(typeByte, body)
// }
// }
type invalidMessageLenErr struct {
messageType string
expectedLen int
actualLen int
}
func (e *invalidMessageLenErr) Error() string {
return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen)
}
type invalidMessageFormatErr struct {
messageType string
}
func (e *invalidMessageFormatErr) Error() string {
return fmt.Sprintf("%s body is invalid", e.messageType)
}

43
pgproto3/query.go Normal file
View File

@ -0,0 +1,43 @@
package pgproto3
import (
"bytes"
"encoding/json"
)
type Query struct {
String string
}
func (*Query) Frontend() {}
func (dst *Query) UnmarshalBinary(src []byte) error {
i := bytes.IndexByte(src, 0)
if i != len(src)-1 {
return &invalidMessageFormatErr{messageType: "Query"}
}
dst.String = string(src[:i])
return nil
}
func (src *Query) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('Q')
buf.Write(bigEndian.Uint32(uint32(4 + len(src.String) + 1)))
buf.WriteString(src.String)
buf.WriteByte(0)
return buf.Bytes(), nil
}
func (src *Query) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
String string
}{
Type: "Query",
String: src.String,
})
}

View File

@ -0,0 +1,35 @@
package pgproto3
import (
"encoding/json"
)
type ReadyForQuery struct {
TxStatus byte
}
func (*ReadyForQuery) Backend() {}
func (dst *ReadyForQuery) UnmarshalBinary(src []byte) error {
if len(src) != 1 {
return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)}
}
dst.TxStatus = src[0]
return nil
}
func (src *ReadyForQuery) MarshalBinary() ([]byte, error) {
return []byte{'Z', 0, 0, 0, 5, src.TxStatus}, nil
}
func (src *ReadyForQuery) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
TxStatus string
}{
Type: "ReadyForQuery",
TxStatus: string(src.TxStatus),
})
}

101
pgproto3/row_description.go Normal file
View File

@ -0,0 +1,101 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
const (
TextFormat = 0
BinaryFormat = 1
)
type FieldDescription struct {
Name string
TableOID uint32
TableAttributeNumber uint16
DataTypeOID uint32
DataTypeSize int16
TypeModifier uint32
Format int16
}
type RowDescription struct {
Fields []FieldDescription
}
func (*RowDescription) Backend() {}
func (dst *RowDescription) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 2 {
return &invalidMessageFormatErr{messageType: "RowDescription"}
}
fieldCount := int(binary.BigEndian.Uint16(buf.Next(2)))
*dst = RowDescription{Fields: make([]FieldDescription, fieldCount)}
for i := 0; i < fieldCount; i++ {
var fd FieldDescription
bName, err := buf.ReadBytes(0)
if err != nil {
return err
}
fd.Name = string(bName[:len(bName)-1])
// Since buf.Next() doesn't return an error if we hit the end of the buffer
// check Len ahead of time
if buf.Len() < 18 {
return &invalidMessageFormatErr{messageType: "RowDescription"}
}
fd.TableOID = binary.BigEndian.Uint32(buf.Next(4))
fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2))
fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4))
fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2)))
fd.TypeModifier = binary.BigEndian.Uint32(buf.Next(4))
fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2)))
dst.Fields[i] = fd
}
return nil
}
func (src *RowDescription) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('T')
buf.Write(bigEndian.Uint32(0))
buf.Write(bigEndian.Uint16(uint16(len(src.Fields))))
for _, fd := range src.Fields {
buf.WriteString(fd.Name)
buf.WriteByte(0)
buf.Write(bigEndian.Uint32(fd.TableOID))
buf.Write(bigEndian.Uint16(fd.TableAttributeNumber))
buf.Write(bigEndian.Uint32(fd.DataTypeOID))
buf.Write(bigEndian.Uint16(uint16(fd.DataTypeSize)))
buf.Write(bigEndian.Uint32(fd.TypeModifier))
buf.Write(bigEndian.Uint16(uint16(fd.Format)))
}
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes(), nil
}
func (src *RowDescription) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Fields []FieldDescription
}{
Type: "RowDescription",
Fields: src.Fields,
})
}

View File

@ -8,6 +8,7 @@ import (
"time"
"github.com/jackc/pgx/internal/sanitize"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
@ -41,7 +42,7 @@ func (r *Row) Scan(dest ...interface{}) (err error) {
// calling Next() until it returns false, or when a fatal error occurs.
type Rows struct {
conn *Conn
mr *msgReader
values [][]byte
fields []FieldDescription
rowCount int
columnIdx int
@ -115,15 +116,15 @@ func (rows *Rows) Next() bool {
rows.columnIdx = 0
for {
t, r, err := rows.conn.rxMsg()
msg, err := rows.conn.rxMsg()
if err != nil {
rows.Fatal(err)
return false
}
switch t {
case rowDescription:
rows.fields = rows.conn.rxRowDescription(r)
switch msg := msg.(type) {
case *pgproto3.RowDescription:
rows.fields = rows.conn.rxRowDescription(msg)
for i := range rows.fields {
if dt, ok := rows.conn.ConnInfo.DataTypeForOid(rows.fields[i].DataType); ok {
rows.fields[i].DataTypeName = dt.Name
@ -133,21 +134,20 @@ func (rows *Rows) Next() bool {
return false
}
}
case dataRow:
fieldCount := r.readInt16()
if int(fieldCount) != len(rows.fields) {
rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount)))
case *pgproto3.DataRow:
if len(msg.Values) != len(rows.fields) {
rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), len(msg.Values))))
return false
}
rows.mr = r
rows.values = msg.Values
return true
case commandComplete:
case *pgproto3.CommandComplete:
rows.Close()
return false
default:
err = rows.conn.processContextFreeMsg(t, r)
err = rows.conn.processContextFreeMsg(msg)
if err != nil {
rows.Fatal(err)
return false
@ -170,13 +170,9 @@ func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) {
return nil, nil, false
}
buf := rows.values[rows.columnIdx]
fd := &rows.fields[rows.columnIdx]
rows.columnIdx++
size := rows.mr.readInt32()
var buf []byte
if size >= 0 {
buf = rows.mr.readBytes(size)
}
return buf, fd, true
}

View File

@ -2,9 +2,12 @@ package pgx
import (
"context"
"encoding/binary"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/pgproto3"
)
const (
@ -203,59 +206,64 @@ func (rc *ReplicationConn) CauseOfDeath() error {
}
func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) {
var t byte
var reader *msgReader
t, reader, err = rc.c.rxMsg()
msg, err := rc.c.rxMsg()
if err != nil {
return
}
switch t {
case noticeResponse:
pgError := rc.c.rxErrorResponse(reader)
switch msg := msg.(type) {
case *pgproto3.NoticeResponse:
pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg))
if rc.c.shouldLog(LogLevelInfo) {
rc.c.log(LogLevelInfo, pgError.Error())
}
case errorResponse:
err = rc.c.rxErrorResponse(reader)
case *pgproto3.ErrorResponse:
err = rc.c.rxErrorResponse(msg)
if rc.c.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, err.Error())
}
return
case copyBothResponse:
case *pgproto3.CopyBothResponse:
// This is the tail end of the replication process start,
// and can be safely ignored
return
case copyData:
var msgType byte
msgType = reader.readByte()
case *pgproto3.CopyData:
msgType := msg.Data[0]
rp := 1
switch msgType {
case walData:
walStart := reader.readInt64()
serverWalEnd := reader.readInt64()
serverTime := reader.readInt64()
walData := reader.readBytes(int32(len(reader.msgBody) - reader.rp))
walMessage := WalMessage{WalStart: uint64(walStart),
ServerWalEnd: uint64(serverWalEnd),
ServerTime: uint64(serverTime),
walStart := binary.BigEndian.Uint64(msg.Data[rp:])
rp += 8
serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:])
rp += 8
serverTime := binary.BigEndian.Uint64(msg.Data[rp:])
rp += 8
walData := msg.Data[rp:]
walMessage := WalMessage{WalStart: walStart,
ServerWalEnd: serverWalEnd,
ServerTime: serverTime,
WalData: walData,
}
return &ReplicationMessage{WalMessage: &walMessage}, nil
case senderKeepalive:
serverWalEnd := reader.readInt64()
serverTime := reader.readInt64()
replyNow := reader.readByte()
h := &ServerHeartbeat{ServerWalEnd: uint64(serverWalEnd), ServerTime: uint64(serverTime), ReplyRequested: replyNow}
serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:])
rp += 8
serverTime := binary.BigEndian.Uint64(msg.Data[rp:])
rp += 8
replyNow := msg.Data[rp]
rp += 1
h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow}
return &ReplicationMessage{ServerHeartbeat: h}, nil
default:
if rc.c.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unexpected data playload message type %v", t)
rc.c.log(LogLevelError, "Unexpected data playload message type %v", msgType)
}
}
default:
if rc.c.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unexpected replication message type %v", t)
rc.c.log(LogLevelError, "Unexpected replication message type %T", msg)
}
}
return
@ -325,21 +333,19 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
rows.Fatal(err)
}
var t byte
var r *msgReader
t, r, err = rc.c.rxMsg()
msg, err := rc.c.rxMsg()
if err != nil {
return nil, err
}
switch t {
case rowDescription:
rows.fields = rc.c.rxRowDescription(r)
switch msg := msg.(type) {
case *pgproto3.RowDescription:
rows.fields = rc.c.rxRowDescription(msg)
// We don't have c.PgTypes here because we're a replication
// connection. This means the field descriptions will have
// only Oids. Not much we can do about this.
default:
if e := rc.c.processContextFreeMsg(t, r); e != nil {
if e := rc.c.processContextFreeMsg(msg); e != nil {
rows.Fatal(e)
return rows, e
}