Introduce pgproto3 package

pgproto3 will wrap the message encoding and decoding for the PostgreSQL
frontend/backend protocol version 3.
batch-wip
Jack Christensen 2017-04-29 10:02:38 -05:00
parent e305ece410
commit f04c58338b
34 changed files with 1676 additions and 262 deletions

1
.gitignore vendored
View File

@ -22,3 +22,4 @@ _testmain.go
*.exe *.exe
conn_config_test.go conn_config_test.go
.envrc

View File

@ -52,7 +52,7 @@ install:
- go get -u github.com/shopspring/decimal - go get -u github.com/shopspring/decimal
- go get -u gopkg.in/inconshreveable/log15.v2 - go get -u gopkg.in/inconshreveable/log15.v2
- go get -u github.com/jackc/fake - go get -u github.com/jackc/fake
- go get -u github.com/jackc/pgmock/pgmsg - go get -u github.com/jackc/pgmock/pgproto3
- go get -u github.com/lib/pq - go get -u github.com/lib/pq
- go get -u github.com/hashicorp/go-version - go get -u github.com/hashicorp/go-version
- go get -u github.com/satori/go.uuid - go get -u github.com/satori/go.uuid

304
conn.go
View File

@ -20,7 +20,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/jackc/pgx/chunkreader" "github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype" "github.com/jackc/pgx/pgtype"
) )
@ -88,8 +88,8 @@ type Conn struct {
lastActivityTime time.Time // the last time the connection was used lastActivityTime time.Time // the last time the connection was used
wbuf [1024]byte wbuf [1024]byte
writeBuf WriteBuf writeBuf WriteBuf
pid int32 // backend pid pid uint32 // backend pid
secretKey int32 // key to use to send a cancel query message to the server secretKey uint32 // key to use to send a cancel query message to the server
RuntimeParams map[string]string // parameters that have been reported by the server RuntimeParams map[string]string // parameters that have been reported by the server
config ConnConfig // config used when establishing this connection config ConnConfig // config used when establishing this connection
txStatus byte txStatus byte
@ -98,7 +98,6 @@ type Conn struct {
notifications []*Notification notifications []*Notification
logger Logger logger Logger
logLevel int logLevel int
mr msgReader
fp *fastpath fp *fastpath
poolResetCount int poolResetCount int
preallocatedRows []Rows preallocatedRows []Rows
@ -116,6 +115,8 @@ type Conn struct {
closedChan chan error closedChan chan error
ConnInfo *pgtype.ConnInfo ConnInfo *pgtype.ConnInfo
frontend *pgproto3.Frontend
} }
// PreparedStatement is a description of a prepared statement // PreparedStatement is a description of a prepared statement
@ -133,7 +134,7 @@ type PrepareExOptions struct {
// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system // Notification is a message received from the PostgreSQL LISTEN/NOTIFY system
type Notification struct { type Notification struct {
PID int32 // backend pid that sent the notification PID uint32 // backend pid that sent the notification
Channel string // channel from which notification was received Channel string // channel from which notification was received
Payload string Payload string
} }
@ -213,8 +214,6 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
c.logLevel = LogLevelDebug c.logLevel = LogLevelDebug
} }
c.logger = c.config.Logger c.logger = c.config.Logger
c.mr.log = c.log
c.mr.shouldLog = c.shouldLog
if c.config.User == "" { if c.config.User == "" {
user, err := user.Current() user, err := user.Current()
@ -290,7 +289,10 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
} }
} }
c.mr.cr = chunkreader.NewChunkReader(c.conn) c.frontend, err = pgproto3.NewFrontend(c.conn, c.conn)
if err != nil {
return err
}
msg := newStartupMessage() msg := newStartupMessage()
@ -317,29 +319,27 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
} }
for { for {
var t byte msg, err := c.rxMsg()
var r *msgReader
t, r, err = c.rxMsg()
if err != nil { if err != nil {
return err return err
} }
switch t { switch msg := msg.(type) {
case backendKeyData: case *pgproto3.BackendKeyData:
c.rxBackendKeyData(r) c.rxBackendKeyData(msg)
case authenticationX: case *pgproto3.Authentication:
if err = c.rxAuthenticationX(r); err != nil { if err = c.rxAuthenticationX(msg); err != nil {
return err return err
} }
case readyForQuery: case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(r) c.rxReadyForQuery(msg)
if c.shouldLog(LogLevelInfo) { if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Connection established") c.log(LogLevelInfo, "Connection established")
} }
// Replication connections can't execute the queries to // Replication connections can't execute the queries to
// populate the c.PgTypes and c.pgsqlAfInet // populate the c.PgTypes and c.pgsqlAfInet
if _, ok := msg.options["replication"]; ok { if _, ok := config.RuntimeParams["replication"]; ok {
return nil return nil
} }
@ -352,7 +352,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
return nil return nil
default: default:
if err = c.processContextFreeMsg(t, r); err != nil { if err = c.processContextFreeMsg(msg); err != nil {
return err return err
} }
} }
@ -393,7 +393,7 @@ where (
} }
// PID returns the backend PID for this connection. // PID returns the backend PID for this connection.
func (c *Conn) PID() int32 { func (c *Conn) PID() uint32 {
return c.pid return c.pid
} }
@ -744,22 +744,20 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
var softErr error var softErr error
for { for {
var t byte msg, err := c.rxMsg()
var r *msgReader
t, r, err := c.rxMsg()
if err != nil { if err != nil {
return nil, err return nil, err
} }
switch t { switch msg := msg.(type) {
case parameterDescription: case *pgproto3.ParameterDescription:
ps.ParameterOids = c.rxParameterDescription(r) ps.ParameterOids = c.rxParameterDescription(msg)
if len(ps.ParameterOids) > 65535 && softErr == nil { if len(ps.ParameterOids) > 65535 && softErr == nil {
softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids)) softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids))
} }
case rowDescription: case *pgproto3.RowDescription:
ps.FieldDescriptions = c.rxRowDescription(r) ps.FieldDescriptions = c.rxRowDescription(msg)
for i := range ps.FieldDescriptions { for i := range ps.FieldDescriptions {
if dt, ok := c.ConnInfo.DataTypeForOid(ps.FieldDescriptions[i].DataType); ok { if dt, ok := c.ConnInfo.DataTypeForOid(ps.FieldDescriptions[i].DataType); ok {
ps.FieldDescriptions[i].DataTypeName = dt.Name ps.FieldDescriptions[i].DataTypeName = dt.Name
@ -772,8 +770,8 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
return nil, fmt.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType) return nil, fmt.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType)
} }
} }
case readyForQuery: case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(r) c.rxReadyForQuery(msg)
if softErr == nil { if softErr == nil {
c.preparedStatements[name] = ps c.preparedStatements[name] = ps
@ -781,7 +779,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
return ps, softErr return ps, softErr
default: default:
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
softErr = e softErr = e
} }
} }
@ -830,18 +828,16 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
} }
for { for {
var t byte msg, err := c.rxMsg()
var r *msgReader
t, r, err := c.rxMsg()
if err != nil { if err != nil {
return err return err
} }
switch t { switch msg.(type) {
case closeComplete: case *pgproto3.CloseComplete:
return nil return nil
default: default:
err = c.processContextFreeMsg(t, r) err = c.processContextFreeMsg(msg)
if err != nil { if err != nil {
return err return err
} }
@ -908,12 +904,12 @@ func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notificat
} }
for { for {
t, r, err := c.rxMsg() msg, err := c.rxMsg()
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = c.processContextFreeMsg(t, r) err = c.processContextFreeMsg(msg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1030,62 +1026,48 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
// meaningful in a given context. These messages can occur due to a context // meaningful in a given context. These messages can occur due to a context
// deadline interrupting message processing. For example, an interrupted query // deadline interrupting message processing. For example, an interrupted query
// may have left DataRow messages on the wire. // may have left DataRow messages on the wire.
func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) {
switch t { switch msg := msg.(type) {
case bindComplete: case *pgproto3.ErrorResponse:
case commandComplete: return c.rxErrorResponse(msg)
case dataRow: case *pgproto3.NotificationResponse:
case emptyQueryResponse: c.rxNotificationResponse(msg)
case errorResponse: case *pgproto3.ReadyForQuery:
return c.rxErrorResponse(r) c.rxReadyForQuery(msg)
case noData: case *pgproto3.ParameterStatus:
case noticeResponse: c.rxParameterStatus(msg)
case notificationResponse:
c.rxNotificationResponse(r)
case parameterDescription:
case parseComplete:
case readyForQuery:
c.rxReadyForQuery(r)
case rowDescription:
case 'S':
c.rxParameterStatus(r)
default:
return fmt.Errorf("Received unknown message type: %c", t)
} }
return nil return nil
} }
func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) {
if atomic.LoadInt32(&c.status) < connStatusIdle { if atomic.LoadInt32(&c.status) < connStatusIdle {
return 0, nil, ErrDeadConn return nil, ErrDeadConn
} }
t, err = c.mr.rxMsg() msg, err := c.frontend.Receive()
if err != nil { if err != nil {
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
c.die(err) c.die(err)
} }
return nil, err
} }
c.lastActivityTime = time.Now() c.lastActivityTime = time.Now()
if c.shouldLog(LogLevelTrace) { // fmt.Printf("rxMsg: %#v\n", msg)
c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBodyLen", len(c.mr.msgBody))
}
return t, &c.mr, err return msg, nil
} }
func (c *Conn) rxAuthenticationX(r *msgReader) (err error) { func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
switch r.readInt32() { switch msg.Type {
case 0: // AuthenticationOk case pgproto3.AuthTypeOk:
case 3: // AuthenticationCleartextPassword case pgproto3.AuthTypeCleartextPassword:
err = c.txPasswordMessage(c.config.Password) err = c.txPasswordMessage(c.config.Password)
case 5: // AuthenticationMD5Password case pgproto3.AuthTypeMD5Password:
salt := r.readString(4) digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:]))
digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+salt)
err = c.txPasswordMessage(digestedPassword) err = c.txPasswordMessage(digestedPassword)
default: default:
err = errors.New("Received unknown authentication message") err = errors.New("Received unknown authentication message")
@ -1100,115 +1082,75 @@ func hexMD5(s string) string {
return hex.EncodeToString(hash.Sum(nil)) return hex.EncodeToString(hash.Sum(nil))
} }
func (c *Conn) rxParameterStatus(r *msgReader) { func (c *Conn) rxParameterStatus(msg *pgproto3.ParameterStatus) {
key := r.readCString() c.RuntimeParams[msg.Name] = msg.Value
value := r.readCString()
c.RuntimeParams[key] = value
} }
func (c *Conn) rxErrorResponse(r *msgReader) (err PgError) { func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) PgError {
for { err := PgError{
switch r.readByte() { Severity: msg.Severity,
case 'S': Code: msg.Code,
err.Severity = r.readCString() Message: msg.Message,
case 'C': Detail: msg.Detail,
err.Code = r.readCString() Hint: msg.Hint,
case 'M': Position: msg.Position,
err.Message = r.readCString() InternalPosition: msg.InternalPosition,
case 'D': InternalQuery: msg.InternalQuery,
err.Detail = r.readCString() Where: msg.Where,
case 'H': SchemaName: msg.SchemaName,
err.Hint = r.readCString() TableName: msg.TableName,
case 'P': ColumnName: msg.ColumnName,
s := r.readCString() DataTypeName: msg.DataTypeName,
n, _ := strconv.ParseInt(s, 10, 32) ConstraintName: msg.ConstraintName,
err.Position = int32(n) File: msg.File,
case 'p': Line: msg.Line,
s := r.readCString() Routine: msg.Routine,
n, _ := strconv.ParseInt(s, 10, 32)
err.InternalPosition = int32(n)
case 'q':
err.InternalQuery = r.readCString()
case 'W':
err.Where = r.readCString()
case 's':
err.SchemaName = r.readCString()
case 't':
err.TableName = r.readCString()
case 'c':
err.ColumnName = r.readCString()
case 'd':
err.DataTypeName = r.readCString()
case 'n':
err.ConstraintName = r.readCString()
case 'F':
err.File = r.readCString()
case 'L':
s := r.readCString()
n, _ := strconv.ParseInt(s, 10, 32)
err.Line = int32(n)
case 'R':
err.Routine = r.readCString()
case 0: // End of error message
if err.Severity == "FATAL" {
c.die(err)
}
return
default: // Ignore other error fields
r.readCString()
}
} }
if err.Severity == "FATAL" {
c.die(err)
}
return err
} }
func (c *Conn) rxBackendKeyData(r *msgReader) { func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) {
c.pid = r.readInt32() c.pid = msg.ProcessID
c.secretKey = r.readInt32() c.secretKey = msg.SecretKey
} }
func (c *Conn) rxReadyForQuery(r *msgReader) { func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) {
c.readyForQuery = true c.readyForQuery = true
c.txStatus = r.readByte() c.txStatus = msg.TxStatus
} }
func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { func (c *Conn) rxRowDescription(msg *pgproto3.RowDescription) []FieldDescription {
fieldCount := r.readInt16() fields := make([]FieldDescription, len(msg.Fields))
fields = make([]FieldDescription, fieldCount) for i := 0; i < len(fields); i++ {
for i := int16(0); i < fieldCount; i++ { fields[i].Name = msg.Fields[i].Name
f := &fields[i] fields[i].Table = pgtype.Oid(msg.Fields[i].TableOID)
f.Name = r.readCString() fields[i].AttributeNumber = msg.Fields[i].TableAttributeNumber
f.Table = pgtype.Oid(r.readUint32()) fields[i].DataType = pgtype.Oid(msg.Fields[i].DataTypeOID)
f.AttributeNumber = r.readInt16() fields[i].DataTypeSize = msg.Fields[i].DataTypeSize
f.DataType = pgtype.Oid(r.readUint32()) fields[i].Modifier = msg.Fields[i].TypeModifier
f.DataTypeSize = r.readInt16() fields[i].FormatCode = msg.Fields[i].Format
f.Modifier = r.readInt32()
f.FormatCode = r.readInt16()
} }
return return fields
} }
func (c *Conn) rxParameterDescription(r *msgReader) (parameters []pgtype.Oid) { func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgtype.Oid {
// Internally, PostgreSQL supports greater than 64k parameters to a prepared parameters := make([]pgtype.Oid, len(msg.ParameterOIDs))
// statement. But the parameter description uses a 16-bit integer for the for i := 0; i < len(parameters); i++ {
// count of parameters. If there are more than 64K parameters, this count is parameters[i] = pgtype.Oid(msg.ParameterOIDs[i])
// wrong. So read the count, ignore it, and compute the proper value from
// the size of the message.
r.readInt16()
parameterCount := len(r.msgBody[r.rp:]) / 4
parameters = make([]pgtype.Oid, 0, parameterCount)
for i := 0; i < parameterCount; i++ {
parameters = append(parameters, pgtype.Oid(r.readUint32()))
} }
return return parameters
} }
func (c *Conn) rxNotificationResponse(r *msgReader) { func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) {
n := new(Notification) n := new(Notification)
n.PID = r.readInt32() n.PID = msg.PID
n.Channel = r.readCString() n.Channel = msg.Channel
n.Payload = r.readCString() n.Payload = msg.Payload
c.notifications = append(c.notifications, n) c.notifications = append(c.notifications, n)
} }
@ -1453,21 +1395,19 @@ func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions,
var softErr error var softErr error
for { for {
var t byte msg, err := c.rxMsg()
var r *msgReader
t, r, err = c.rxMsg()
if err != nil { if err != nil {
return commandTag, err return commandTag, err
} }
switch t { switch msg := msg.(type) {
case readyForQuery: case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(r) c.rxReadyForQuery(msg)
return commandTag, softErr return commandTag, softErr
case commandComplete: case *pgproto3.CommandComplete:
commandTag = CommandTag(r.readCString()) commandTag = CommandTag(msg.CommandTag)
default: default:
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
softErr = e softErr = e
} }
} }
@ -1545,19 +1485,19 @@ func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error {
func (c *Conn) ensureConnectionReadyForQuery() error { func (c *Conn) ensureConnectionReadyForQuery() error {
for !c.readyForQuery { for !c.readyForQuery {
t, r, err := c.rxMsg() msg, err := c.rxMsg()
if err != nil { if err != nil {
return err return err
} }
switch t { switch msg := msg.(type) {
case errorResponse: case *pgproto3.ErrorResponse:
pgErr := c.rxErrorResponse(r) pgErr := c.rxErrorResponse(msg)
if pgErr.Severity == "FATAL" { if pgErr.Severity == "FATAL" {
return pgErr return pgErr
} }
default: default:
err = c.processContextFreeMsg(t, r) err = c.processContextFreeMsg(msg)
if err != nil { if err != nil {
return err return err
} }

View File

@ -686,7 +686,7 @@ func TestConnPoolBeginRetry(t *testing.T) {
} }
defer tx.Rollback() defer tx.Rollback()
var txPID int32 var txPID uint32
err = tx.QueryRow("select pg_backend_pid()").Scan(&txPID) err = tx.QueryRow("select pg_backend_pid()").Scan(&txPID)
if err != nil { if err != nil {
t.Fatalf("tx.QueryRow Scan failed: %v", err) t.Fatalf("tx.QueryRow Scan failed: %v", err)

View File

@ -3,6 +3,8 @@ package pgx
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/jackc/pgx/pgproto3"
) )
// CopyFromRows returns a CopyFromSource interface over the provided rows slice // CopyFromRows returns a CopyFromSource interface over the provided rows slice
@ -54,25 +56,25 @@ type copyFrom struct {
func (ct *copyFrom) readUntilReadyForQuery() { func (ct *copyFrom) readUntilReadyForQuery() {
for { for {
t, r, err := ct.conn.rxMsg() msg, err := ct.conn.rxMsg()
if err != nil { if err != nil {
ct.readerErrChan <- err ct.readerErrChan <- err
close(ct.readerErrChan) close(ct.readerErrChan)
return return
} }
switch t { switch msg := msg.(type) {
case readyForQuery: case *pgproto3.ReadyForQuery:
ct.conn.rxReadyForQuery(r) ct.conn.rxReadyForQuery(msg)
close(ct.readerErrChan) close(ct.readerErrChan)
return return
case commandComplete: case *pgproto3.CommandComplete:
case errorResponse: case *pgproto3.ErrorResponse:
ct.readerErrChan <- ct.conn.rxErrorResponse(r) ct.readerErrChan <- ct.conn.rxErrorResponse(msg)
default: default:
err = ct.conn.processContextFreeMsg(t, r) err = ct.conn.processContextFreeMsg(msg)
if err != nil { if err != nil {
ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r) ct.readerErrChan <- ct.conn.processContextFreeMsg(msg)
} }
} }
} }
@ -190,18 +192,16 @@ func (ct *copyFrom) run() (int, error) {
func (c *Conn) readUntilCopyInResponse() error { func (c *Conn) readUntilCopyInResponse() error {
for { for {
var t byte msg, err := c.rxMsg()
var r *msgReader
t, r, err := c.rxMsg()
if err != nil { if err != nil {
return err return err
} }
switch t { switch msg := msg.(type) {
case copyInResponse: case *pgproto3.CopyInResponse:
return nil return nil
default: default:
err = c.processContextFreeMsg(t, r) err = c.processContextFreeMsg(msg)
if err != nil { if err != nil {
return err return err
} }

View File

@ -3,6 +3,7 @@ package pgx
import ( import (
"encoding/binary" "encoding/binary"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype" "github.com/jackc/pgx/pgtype"
) )
@ -71,23 +72,20 @@ func (f *fastpath) Call(oid pgtype.Oid, args []fpArg) (res []byte, err error) {
} }
for { for {
var t byte msg, err := f.cn.rxMsg()
var r *msgReader
t, r, err = f.cn.rxMsg()
if err != nil { if err != nil {
return nil, err return nil, err
} }
switch t { switch msg := msg.(type) {
case 'V': // FunctionCallResponse case *pgproto3.FunctionCallResponse:
data := r.readBytes(r.readInt32()) res = make([]byte, len(msg.Result))
res = make([]byte, len(data)) copy(res, msg.Result)
copy(res, data) case *pgproto3.ReadyForQuery:
case 'Z': // Ready for query f.cn.rxReadyForQuery(msg)
f.cn.rxReadyForQuery(r)
// done // done
return return res, err
default: default:
if err := f.cn.processContextFreeMsg(t, r); err != nil { if err := f.cn.processContextFreeMsg(msg); err != nil {
return nil, err return nil, err
} }
} }

View File

@ -58,11 +58,11 @@ func (s *startupMessage) Bytes() (buf []byte) {
type FieldDescription struct { type FieldDescription struct {
Name string Name string
Table pgtype.Oid Table pgtype.Oid
AttributeNumber int16 AttributeNumber uint16
DataType pgtype.Oid DataType pgtype.Oid
DataTypeSize int16 DataTypeSize int16
DataTypeName string DataTypeName string
Modifier int32 Modifier uint32
FormatCode int16 FormatCode int16
} }

View File

@ -0,0 +1,54 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"fmt"
)
const (
AuthTypeOk = 0
AuthTypeCleartextPassword = 3
AuthTypeMD5Password = 5
)
type Authentication struct {
Type uint32
// MD5Password fields
Salt [4]byte
}
func (*Authentication) Backend() {}
func (dst *Authentication) UnmarshalBinary(src []byte) error {
*dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])}
switch dst.Type {
case AuthTypeOk:
case AuthTypeCleartextPassword:
case AuthTypeMD5Password:
copy(dst.Salt[:], src[4:8])
default:
return fmt.Errorf("unknown authentication type: %d", dst.Type)
}
return nil
}
func (src *Authentication) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('R')
buf.Write(bigEndian.Uint32(0))
buf.Write(bigEndian.Uint32(src.Type))
switch src.Type {
case AuthTypeMD5Password:
buf.Write(src.Salt[:])
}
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes(), nil
}

View File

@ -0,0 +1,47 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
type BackendKeyData struct {
ProcessID uint32
SecretKey uint32
}
func (*BackendKeyData) Backend() {}
func (dst *BackendKeyData) UnmarshalBinary(src []byte) error {
if len(src) != 8 {
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
}
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
return nil
}
func (src *BackendKeyData) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('K')
buf.Write(bigEndian.Uint32(12))
buf.Write(bigEndian.Uint32(src.ProcessID))
buf.Write(bigEndian.Uint32(src.SecretKey))
return buf.Bytes(), nil
}
func (src *BackendKeyData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProcessID uint32
SecretKey uint32
}{
Type: "BackendKeyData",
ProcessID: src.ProcessID,
SecretKey: src.SecretKey,
})
}

37
pgproto3/big_endian.go Normal file
View File

@ -0,0 +1,37 @@
package pgproto3
import (
"encoding/binary"
)
type BigEndianBuf [8]byte
func (b BigEndianBuf) Int16(n int16) []byte {
buf := b[0:2]
binary.BigEndian.PutUint16(buf, uint16(n))
return buf
}
func (b BigEndianBuf) Uint16(n uint16) []byte {
buf := b[0:2]
binary.BigEndian.PutUint16(buf, n)
return buf
}
func (b BigEndianBuf) Int32(n int32) []byte {
buf := b[0:4]
binary.BigEndian.PutUint32(buf, uint32(n))
return buf
}
func (b BigEndianBuf) Uint32(n uint32) []byte {
buf := b[0:4]
binary.BigEndian.PutUint32(buf, n)
return buf
}
func (b BigEndianBuf) Int64(n int64) []byte {
buf := b[0:8]
binary.BigEndian.PutUint64(buf, uint64(n))
return buf
}

29
pgproto3/bind_complete.go Normal file
View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
type BindComplete struct{}
func (*BindComplete) Backend() {}
func (dst *BindComplete) UnmarshalBinary(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
}
return nil
}
func (src *BindComplete) MarshalBinary() ([]byte, error) {
return []byte{'2', 0, 0, 0, 4}, nil
}
func (src *BindComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "BindComplete",
})
}

View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
type CloseComplete struct{}
func (*CloseComplete) Backend() {}
func (dst *CloseComplete) UnmarshalBinary(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
}
return nil
}
func (src *CloseComplete) MarshalBinary() ([]byte, error) {
return []byte{'3', 0, 0, 0, 4}, nil
}
func (src *CloseComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "CloseComplete",
})
}

View File

@ -0,0 +1,47 @@
package pgproto3
import (
"bytes"
"encoding/json"
)
type CommandComplete struct {
CommandTag string
}
func (*CommandComplete) Backend() {}
func (dst *CommandComplete) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
dst.CommandTag = string(b[:len(b)-1])
return nil
}
func (src *CommandComplete) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('C')
buf.Write(bigEndian.Uint32(uint32(4 + len(src.CommandTag) + 1)))
buf.WriteString(src.CommandTag)
buf.WriteByte(0)
return buf.Bytes(), nil
}
func (src *CommandComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
CommandTag string
}{
Type: "CommandComplete",
CommandTag: src.CommandTag,
})
}

View File

@ -0,0 +1,64 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
type CopyBothResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
func (*CopyBothResponse) Backend() {}
func (dst *CopyBothResponse) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
func (src *CopyBothResponse) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('W')
buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes))))
buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes))))
for _, fc := range src.ColumnFormatCodes {
buf.Write(bigEndian.Uint16(fc))
}
return buf.Bytes(), nil
}
func (src *CopyBothResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyBothResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}

41
pgproto3/copy_data.go Normal file
View File

@ -0,0 +1,41 @@
package pgproto3
import (
"bytes"
"encoding/hex"
"encoding/json"
)
type CopyData struct {
Data []byte
}
func (*CopyData) Backend() {}
func (*CopyData) Frontend() {}
func (dst *CopyData) UnmarshalBinary(src []byte) error {
dst.Data = make([]byte, len(src))
copy(dst.Data, src)
return nil
}
func (src *CopyData) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('d')
buf.Write(bigEndian.Uint32(uint32(4 + len(src.Data))))
buf.Write(src.Data)
return buf.Bytes(), nil
}
func (src *CopyData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data string
}{
Type: "CopyData",
Data: hex.EncodeToString(src.Data),
})
}

View File

@ -0,0 +1,64 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
type CopyInResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
func (*CopyInResponse) Backend() {}
func (dst *CopyInResponse) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
func (src *CopyInResponse) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('G')
buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes))))
buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes))))
for _, fc := range src.ColumnFormatCodes {
buf.Write(bigEndian.Uint16(fc))
}
return buf.Bytes(), nil
}
func (src *CopyInResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyInResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}

View File

@ -0,0 +1,64 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
type CopyOutResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
func (*CopyOutResponse) Backend() {}
func (dst *CopyOutResponse) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
func (src *CopyOutResponse) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('H')
buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes))))
buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes))))
for _, fc := range src.ColumnFormatCodes {
buf.Write(bigEndian.Uint16(fc))
}
return buf.Bytes(), nil
}
func (src *CopyOutResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyOutResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}

103
pgproto3/data_row.go Normal file
View File

@ -0,0 +1,103 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/hex"
"encoding/json"
)
type DataRow struct {
Values [][]byte
}
func (*DataRow) Backend() {}
func (dst *DataRow) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 2 {
return &invalidMessageFormatErr{messageType: "DataRow"}
}
fieldCount := int(binary.BigEndian.Uint16(buf.Next(2)))
dst.Values = make([][]byte, fieldCount)
for i := 0; i < fieldCount; i++ {
if buf.Len() < 4 {
return &invalidMessageFormatErr{messageType: "DataRow"}
}
msgSize := int(int32(binary.BigEndian.Uint32(buf.Next(4))))
// null
if msgSize == -1 {
continue
}
value := make([]byte, msgSize)
_, err := buf.Read(value)
if err != nil {
return err
}
dst.Values[i] = value
}
return nil
}
func (src *DataRow) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('D')
buf.Write(bigEndian.Uint32(0))
buf.Write(bigEndian.Uint16(uint16(len(src.Values))))
for _, v := range src.Values {
if v == nil {
buf.Write(bigEndian.Int32(-1))
continue
}
buf.Write(bigEndian.Int32(int32(len(v))))
buf.Write(v)
}
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes(), nil
}
func (src *DataRow) MarshalJSON() ([]byte, error) {
formattedValues := make([]map[string]string, len(src.Values))
for i, v := range src.Values {
if v == nil {
continue
}
var hasNonPrintable bool
for _, b := range v {
if b < 32 {
hasNonPrintable = true
break
}
}
if hasNonPrintable {
formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)}
} else {
formattedValues[i] = map[string]string{"text": string(v)}
}
}
return json.Marshal(struct {
Type string
Values []map[string]string
}{
Type: "DataRow",
Values: formattedValues,
})
}

View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
type EmptyQueryResponse struct{}
func (*EmptyQueryResponse) Backend() {}
func (dst *EmptyQueryResponse) UnmarshalBinary(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)}
}
return nil
}
func (src *EmptyQueryResponse) MarshalBinary() ([]byte, error) {
return []byte{'I', 0, 0, 0, 4}, nil
}
func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "EmptyQueryResponse",
})
}

197
pgproto3/error_response.go Normal file
View File

@ -0,0 +1,197 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"strconv"
)
type ErrorResponse struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[byte]string
}
func (*ErrorResponse) Backend() {}
func (dst *ErrorResponse) UnmarshalBinary(src []byte) error {
*dst = ErrorResponse{}
buf := bytes.NewBuffer(src)
for {
k, err := buf.ReadByte()
if err != nil {
return err
}
if k == 0 {
break
}
vb, err := buf.ReadBytes(0)
if err != nil {
return err
}
v := string(vb[:len(vb)-1])
switch k {
case 'S':
dst.Severity = v
case 'C':
dst.Code = v
case 'M':
dst.Message = v
case 'D':
dst.Detail = v
case 'H':
dst.Hint = v
case 'P':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.Position = int32(n)
case 'p':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.InternalPosition = int32(n)
case 'q':
dst.InternalQuery = v
case 'W':
dst.Where = v
case 's':
dst.SchemaName = v
case 't':
dst.TableName = v
case 'c':
dst.ColumnName = v
case 'd':
dst.DataTypeName = v
case 'n':
dst.ConstraintName = v
case 'F':
dst.File = v
case 'L':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.Line = int32(n)
case 'R':
dst.Routine = v
default:
if dst.UnknownFields == nil {
dst.UnknownFields = make(map[byte]string)
}
dst.UnknownFields[k] = v
}
}
return nil
}
func (src *ErrorResponse) MarshalBinary() ([]byte, error) {
return src.marshalBinary('E')
}
func (src *ErrorResponse) marshalBinary(typeByte byte) ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte(typeByte)
buf.Write(bigEndian.Uint32(0))
if src.Severity != "" {
buf.WriteString(src.Severity)
buf.WriteByte(0)
}
if src.Code != "" {
buf.WriteString(src.Code)
buf.WriteByte(0)
}
if src.Message != "" {
buf.WriteString(src.Message)
buf.WriteByte(0)
}
if src.Detail != "" {
buf.WriteString(src.Detail)
buf.WriteByte(0)
}
if src.Hint != "" {
buf.WriteString(src.Hint)
buf.WriteByte(0)
}
if src.Position != 0 {
buf.WriteString(strconv.Itoa(int(src.Position)))
buf.WriteByte(0)
}
if src.InternalPosition != 0 {
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
buf.WriteByte(0)
}
if src.InternalQuery != "" {
buf.WriteString(src.InternalQuery)
buf.WriteByte(0)
}
if src.Where != "" {
buf.WriteString(src.Where)
buf.WriteByte(0)
}
if src.SchemaName != "" {
buf.WriteString(src.SchemaName)
buf.WriteByte(0)
}
if src.TableName != "" {
buf.WriteString(src.TableName)
buf.WriteByte(0)
}
if src.ColumnName != "" {
buf.WriteString(src.ColumnName)
buf.WriteByte(0)
}
if src.DataTypeName != "" {
buf.WriteString(src.DataTypeName)
buf.WriteByte(0)
}
if src.ConstraintName != "" {
buf.WriteString(src.ConstraintName)
buf.WriteByte(0)
}
if src.File != "" {
buf.WriteString(src.File)
buf.WriteByte(0)
}
if src.Line != 0 {
buf.WriteString(strconv.Itoa(int(src.Line)))
buf.WriteByte(0)
}
if src.Routine != "" {
buf.WriteString(src.Routine)
buf.WriteByte(0)
}
for k, v := range src.UnknownFields {
buf.WriteByte(k)
buf.WriteByte(0)
buf.WriteString(v)
buf.WriteByte(0)
}
buf.WriteByte(0)
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes(), nil
}

70
pgproto3/frontend.go Normal file
View File

@ -0,0 +1,70 @@
package pgproto3
import (
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/jackc/pgx/chunkreader"
)
type Frontend struct {
cr *chunkreader.ChunkReader
w io.Writer
}
func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
cr := chunkreader.NewChunkReader(r)
return &Frontend{cr: cr, w: w}, nil
}
func (b *Frontend) Send(msg FrontendMessage) error {
return errors.New("not implemented")
}
func (b *Frontend) Receive() (BackendMessage, error) {
backendMessages := map[byte]BackendMessage{
'1': &ParseComplete{},
'2': &BindComplete{},
'3': &CloseComplete{},
'A': &NotificationResponse{},
'C': &CommandComplete{},
'd': &CopyData{},
'D': &DataRow{},
'E': &ErrorResponse{},
'G': &CopyInResponse{},
'H': &CopyOutResponse{},
'I': &EmptyQueryResponse{},
'K': &BackendKeyData{},
'n': &NoData{},
'N': &NoticeResponse{},
'R': &Authentication{},
'S': &ParameterStatus{},
't': &ParameterDescription{},
'T': &RowDescription{},
'V': &FunctionCallResponse{},
'W': &CopyBothResponse{},
'Z': &ReadyForQuery{},
}
header, err := b.cr.Next(5)
if err != nil {
return nil, err
}
msgType := header[0]
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
msgBody, err := b.cr.Next(bodyLen)
if err != nil {
return nil, err
}
if msg, ok := backendMessages[msgType]; ok {
err = msg.UnmarshalBinary(msgBody)
return msg, err
}
return nil, fmt.Errorf("unknown message type: %c", msgType)
}

View File

@ -0,0 +1,73 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/hex"
"encoding/json"
)
type FunctionCallResponse struct {
Result []byte
}
func (*FunctionCallResponse) Backend() {}
func (dst *FunctionCallResponse) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 4 {
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
}
resultSize := int(binary.BigEndian.Uint32(buf.Next(4)))
if buf.Len() != resultSize {
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
}
dst.Result = make([]byte, resultSize)
copy(dst.Result, buf.Bytes())
return nil
}
func (src *FunctionCallResponse) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('V')
buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Result))))
if src.Result == nil {
buf.Write(bigEndian.Int32(-1))
} else {
buf.Write(bigEndian.Int32(int32(len(src.Result))))
buf.Write(src.Result)
}
return buf.Bytes(), nil
}
func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) {
var formattedValue map[string]string
var hasNonPrintable bool
for _, b := range src.Result {
if b < 32 {
hasNonPrintable = true
break
}
}
if hasNonPrintable {
formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)}
} else {
formattedValue = map[string]string{"text": string(src.Result)}
}
return json.Marshal(struct {
Type string
Result map[string]string
}{
Type: "FunctionCallResponse",
Result: formattedValue,
})
}

29
pgproto3/no_data.go Normal file
View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
type NoData struct{}
func (*NoData) Backend() {}
func (dst *NoData) UnmarshalBinary(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)}
}
return nil
}
func (src *NoData) MarshalBinary() ([]byte, error) {
return []byte{'n', 0, 0, 0, 4}, nil
}
func (src *NoData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "NoData",
})
}

View File

@ -0,0 +1,13 @@
package pgproto3
type NoticeResponse ErrorResponse
func (*NoticeResponse) Backend() {}
func (dst *NoticeResponse) UnmarshalBinary(src []byte) error {
return (*ErrorResponse)(dst).UnmarshalBinary(src)
}
func (src *NoticeResponse) MarshalBinary() ([]byte, error) {
return (*ErrorResponse)(src).marshalBinary('N')
}

View File

@ -0,0 +1,65 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
type NotificationResponse struct {
PID uint32
Channel string
Payload string
}
func (*NotificationResponse) Backend() {}
func (dst *NotificationResponse) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
pid := binary.BigEndian.Uint32(buf.Next(4))
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
channel := string(b[:len(b)-1])
b, err = buf.ReadBytes(0)
if err != nil {
return err
}
payload := string(b[:len(b)-1])
*dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload}
return nil
}
func (src *NotificationResponse) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('A')
buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Channel) + len(src.Payload))))
buf.WriteString(src.Channel)
buf.WriteByte(0)
buf.WriteString(src.Payload)
buf.WriteByte(0)
return buf.Bytes(), nil
}
func (src *NotificationResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
PID uint32
Channel string
Payload string
}{
Type: "NotificationResponse",
PID: src.PID,
Channel: src.Channel,
Payload: src.Payload,
})
}

View File

@ -0,0 +1,60 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
type ParameterDescription struct {
ParameterOIDs []uint32
}
func (*ParameterDescription) Backend() {}
func (dst *ParameterDescription) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 2 {
return &invalidMessageFormatErr{messageType: "ParameterDescription"}
}
// Reported parameter count will be incorrect when number of args is greater than uint16
buf.Next(2)
// Instead infer parameter count by remaining size of message
parameterCount := buf.Len() / 4
*dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)}
for i := 0; i < parameterCount; i++ {
dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4))
}
return nil
}
func (src *ParameterDescription) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('t')
buf.Write(bigEndian.Uint32(uint32(4 + 2 + 4*len(src.ParameterOIDs))))
buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs))))
for _, oid := range src.ParameterOIDs {
buf.Write(bigEndian.Uint32(oid))
}
return buf.Bytes(), nil
}
func (src *ParameterDescription) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ParameterOIDs []uint32
}{
Type: "ParameterDescription",
ParameterOIDs: src.ParameterOIDs,
})
}

View File

@ -0,0 +1,62 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
type ParameterStatus struct {
Name string
Value string
}
func (*ParameterStatus) Backend() {}
func (dst *ParameterStatus) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
name := string(b[:len(b)-1])
b, err = buf.ReadBytes(0)
if err != nil {
return err
}
value := string(b[:len(b)-1])
*dst = ParameterStatus{Name: name, Value: value}
return nil
}
func (src *ParameterStatus) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('S')
buf.Write(bigEndian.Uint32(0))
buf.WriteString(src.Name)
buf.WriteByte(0)
buf.WriteString(src.Value)
buf.WriteByte(0)
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes(), nil
}
func (ps *ParameterStatus) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Name string
Value string
}{
Type: "ParameterStatus",
Name: ps.Name,
Value: ps.Value,
})
}

View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
type ParseComplete struct{}
func (*ParseComplete) Backend() {}
func (dst *ParseComplete) UnmarshalBinary(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)}
}
return nil
}
func (src *ParseComplete) MarshalBinary() ([]byte, error) {
return []byte{'1', 0, 0, 0, 4}, nil
}
func (src *ParseComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "ParseComplete",
})
}

88
pgproto3/pgproto3.go Normal file
View File

@ -0,0 +1,88 @@
package pgproto3
import "fmt"
type Message interface {
UnmarshalBinary(data []byte) error
MarshalBinary() (data []byte, err error)
}
type FrontendMessage interface {
Message
Frontend() // no-op method to distinguish frontend from backend methods
}
type BackendMessage interface {
Message
Backend() // no-op method to distinguish frontend from backend methods
}
// func ParseBackend(typeByte byte, body []byte) (BackendMessage, error) {
// switch typeByte {
// case '1':
// return ParseParseComplete(body)
// case '2':
// return ParseBindComplete(body)
// case 'C':
// return ParseCommandComplete(body)
// case 'D':
// return ParseDataRow(body)
// case 'E':
// return ParseErrorResponse(body)
// case 'K':
// return ParseBackendKeyData(body)
// case 'R':
// return ParseAuthentication(body)
// case 'S':
// return ParseParameterStatus(body)
// case 'T':
// return ParseRowDescription(body)
// case 't':
// return ParseParameterDescription(body)
// case 'Z':
// return ParseReadyForQuery(body)
// default:
// return ParseUnknownMessage(typeByte, body)
// }
// }
// func ParseFrontend(typeByte byte, body []byte) (FrontendMessage, error) {
// switch typeByte {
// case 'B':
// return ParseBind(body)
// case 'D':
// return ParseDescribe(body)
// case 'E':
// return ParseExecute(body)
// case 'P':
// return ParseParse(body)
// case 'p':
// return ParsePasswordMessage(body)
// case 'Q':
// return ParseQuery(body)
// case 'S':
// return ParseSync(body)
// case 'X':
// return ParseTerminate(body)
// default:
// return ParseUnknownMessage(typeByte, body)
// }
// }
type invalidMessageLenErr struct {
messageType string
expectedLen int
actualLen int
}
func (e *invalidMessageLenErr) Error() string {
return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen)
}
type invalidMessageFormatErr struct {
messageType string
}
func (e *invalidMessageFormatErr) Error() string {
return fmt.Sprintf("%s body is invalid", e.messageType)
}

43
pgproto3/query.go Normal file
View File

@ -0,0 +1,43 @@
package pgproto3
import (
"bytes"
"encoding/json"
)
type Query struct {
String string
}
func (*Query) Frontend() {}
func (dst *Query) UnmarshalBinary(src []byte) error {
i := bytes.IndexByte(src, 0)
if i != len(src)-1 {
return &invalidMessageFormatErr{messageType: "Query"}
}
dst.String = string(src[:i])
return nil
}
func (src *Query) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('Q')
buf.Write(bigEndian.Uint32(uint32(4 + len(src.String) + 1)))
buf.WriteString(src.String)
buf.WriteByte(0)
return buf.Bytes(), nil
}
func (src *Query) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
String string
}{
Type: "Query",
String: src.String,
})
}

View File

@ -0,0 +1,35 @@
package pgproto3
import (
"encoding/json"
)
type ReadyForQuery struct {
TxStatus byte
}
func (*ReadyForQuery) Backend() {}
func (dst *ReadyForQuery) UnmarshalBinary(src []byte) error {
if len(src) != 1 {
return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)}
}
dst.TxStatus = src[0]
return nil
}
func (src *ReadyForQuery) MarshalBinary() ([]byte, error) {
return []byte{'Z', 0, 0, 0, 5, src.TxStatus}, nil
}
func (src *ReadyForQuery) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
TxStatus string
}{
Type: "ReadyForQuery",
TxStatus: string(src.TxStatus),
})
}

101
pgproto3/row_description.go Normal file
View File

@ -0,0 +1,101 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
)
const (
TextFormat = 0
BinaryFormat = 1
)
type FieldDescription struct {
Name string
TableOID uint32
TableAttributeNumber uint16
DataTypeOID uint32
DataTypeSize int16
TypeModifier uint32
Format int16
}
type RowDescription struct {
Fields []FieldDescription
}
func (*RowDescription) Backend() {}
func (dst *RowDescription) UnmarshalBinary(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 2 {
return &invalidMessageFormatErr{messageType: "RowDescription"}
}
fieldCount := int(binary.BigEndian.Uint16(buf.Next(2)))
*dst = RowDescription{Fields: make([]FieldDescription, fieldCount)}
for i := 0; i < fieldCount; i++ {
var fd FieldDescription
bName, err := buf.ReadBytes(0)
if err != nil {
return err
}
fd.Name = string(bName[:len(bName)-1])
// Since buf.Next() doesn't return an error if we hit the end of the buffer
// check Len ahead of time
if buf.Len() < 18 {
return &invalidMessageFormatErr{messageType: "RowDescription"}
}
fd.TableOID = binary.BigEndian.Uint32(buf.Next(4))
fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2))
fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4))
fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2)))
fd.TypeModifier = binary.BigEndian.Uint32(buf.Next(4))
fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2)))
dst.Fields[i] = fd
}
return nil
}
func (src *RowDescription) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte('T')
buf.Write(bigEndian.Uint32(0))
buf.Write(bigEndian.Uint16(uint16(len(src.Fields))))
for _, fd := range src.Fields {
buf.WriteString(fd.Name)
buf.WriteByte(0)
buf.Write(bigEndian.Uint32(fd.TableOID))
buf.Write(bigEndian.Uint16(fd.TableAttributeNumber))
buf.Write(bigEndian.Uint32(fd.DataTypeOID))
buf.Write(bigEndian.Uint16(uint16(fd.DataTypeSize)))
buf.Write(bigEndian.Uint32(fd.TypeModifier))
buf.Write(bigEndian.Uint16(uint16(fd.Format)))
}
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes(), nil
}
func (src *RowDescription) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Fields []FieldDescription
}{
Type: "RowDescription",
Fields: src.Fields,
})
}

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/jackc/pgx/internal/sanitize" "github.com/jackc/pgx/internal/sanitize"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype" "github.com/jackc/pgx/pgtype"
) )
@ -41,7 +42,7 @@ func (r *Row) Scan(dest ...interface{}) (err error) {
// calling Next() until it returns false, or when a fatal error occurs. // calling Next() until it returns false, or when a fatal error occurs.
type Rows struct { type Rows struct {
conn *Conn conn *Conn
mr *msgReader values [][]byte
fields []FieldDescription fields []FieldDescription
rowCount int rowCount int
columnIdx int columnIdx int
@ -115,15 +116,15 @@ func (rows *Rows) Next() bool {
rows.columnIdx = 0 rows.columnIdx = 0
for { for {
t, r, err := rows.conn.rxMsg() msg, err := rows.conn.rxMsg()
if err != nil { if err != nil {
rows.Fatal(err) rows.Fatal(err)
return false return false
} }
switch t { switch msg := msg.(type) {
case rowDescription: case *pgproto3.RowDescription:
rows.fields = rows.conn.rxRowDescription(r) rows.fields = rows.conn.rxRowDescription(msg)
for i := range rows.fields { for i := range rows.fields {
if dt, ok := rows.conn.ConnInfo.DataTypeForOid(rows.fields[i].DataType); ok { if dt, ok := rows.conn.ConnInfo.DataTypeForOid(rows.fields[i].DataType); ok {
rows.fields[i].DataTypeName = dt.Name rows.fields[i].DataTypeName = dt.Name
@ -133,21 +134,20 @@ func (rows *Rows) Next() bool {
return false return false
} }
} }
case dataRow: case *pgproto3.DataRow:
fieldCount := r.readInt16() if len(msg.Values) != len(rows.fields) {
if int(fieldCount) != len(rows.fields) { rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), len(msg.Values))))
rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount)))
return false return false
} }
rows.mr = r rows.values = msg.Values
return true return true
case commandComplete: case *pgproto3.CommandComplete:
rows.Close() rows.Close()
return false return false
default: default:
err = rows.conn.processContextFreeMsg(t, r) err = rows.conn.processContextFreeMsg(msg)
if err != nil { if err != nil {
rows.Fatal(err) rows.Fatal(err)
return false return false
@ -170,13 +170,9 @@ func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) {
return nil, nil, false return nil, nil, false
} }
buf := rows.values[rows.columnIdx]
fd := &rows.fields[rows.columnIdx] fd := &rows.fields[rows.columnIdx]
rows.columnIdx++ rows.columnIdx++
size := rows.mr.readInt32()
var buf []byte
if size >= 0 {
buf = rows.mr.readBytes(size)
}
return buf, fd, true return buf, fd, true
} }

View File

@ -2,9 +2,12 @@ package pgx
import ( import (
"context" "context"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"time" "time"
"github.com/jackc/pgx/pgproto3"
) )
const ( const (
@ -203,59 +206,64 @@ func (rc *ReplicationConn) CauseOfDeath() error {
} }
func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) { func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) {
var t byte msg, err := rc.c.rxMsg()
var reader *msgReader
t, reader, err = rc.c.rxMsg()
if err != nil { if err != nil {
return return
} }
switch t { switch msg := msg.(type) {
case noticeResponse: case *pgproto3.NoticeResponse:
pgError := rc.c.rxErrorResponse(reader) pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg))
if rc.c.shouldLog(LogLevelInfo) { if rc.c.shouldLog(LogLevelInfo) {
rc.c.log(LogLevelInfo, pgError.Error()) rc.c.log(LogLevelInfo, pgError.Error())
} }
case errorResponse: case *pgproto3.ErrorResponse:
err = rc.c.rxErrorResponse(reader) err = rc.c.rxErrorResponse(msg)
if rc.c.shouldLog(LogLevelError) { if rc.c.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, err.Error()) rc.c.log(LogLevelError, err.Error())
} }
return return
case copyBothResponse: case *pgproto3.CopyBothResponse:
// This is the tail end of the replication process start, // This is the tail end of the replication process start,
// and can be safely ignored // and can be safely ignored
return return
case copyData: case *pgproto3.CopyData:
var msgType byte msgType := msg.Data[0]
msgType = reader.readByte() rp := 1
switch msgType { switch msgType {
case walData: case walData:
walStart := reader.readInt64() walStart := binary.BigEndian.Uint64(msg.Data[rp:])
serverWalEnd := reader.readInt64() rp += 8
serverTime := reader.readInt64() serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:])
walData := reader.readBytes(int32(len(reader.msgBody) - reader.rp)) rp += 8
walMessage := WalMessage{WalStart: uint64(walStart), serverTime := binary.BigEndian.Uint64(msg.Data[rp:])
ServerWalEnd: uint64(serverWalEnd), rp += 8
ServerTime: uint64(serverTime), walData := msg.Data[rp:]
walMessage := WalMessage{WalStart: walStart,
ServerWalEnd: serverWalEnd,
ServerTime: serverTime,
WalData: walData, WalData: walData,
} }
return &ReplicationMessage{WalMessage: &walMessage}, nil return &ReplicationMessage{WalMessage: &walMessage}, nil
case senderKeepalive: case senderKeepalive:
serverWalEnd := reader.readInt64() serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:])
serverTime := reader.readInt64() rp += 8
replyNow := reader.readByte() serverTime := binary.BigEndian.Uint64(msg.Data[rp:])
h := &ServerHeartbeat{ServerWalEnd: uint64(serverWalEnd), ServerTime: uint64(serverTime), ReplyRequested: replyNow} rp += 8
replyNow := msg.Data[rp]
rp += 1
h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow}
return &ReplicationMessage{ServerHeartbeat: h}, nil return &ReplicationMessage{ServerHeartbeat: h}, nil
default: default:
if rc.c.shouldLog(LogLevelError) { if rc.c.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unexpected data playload message type %v", t) rc.c.log(LogLevelError, "Unexpected data playload message type %v", msgType)
} }
} }
default: default:
if rc.c.shouldLog(LogLevelError) { if rc.c.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unexpected replication message type %v", t) rc.c.log(LogLevelError, "Unexpected replication message type %T", msg)
} }
} }
return return
@ -325,21 +333,19 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
rows.Fatal(err) rows.Fatal(err)
} }
var t byte msg, err := rc.c.rxMsg()
var r *msgReader
t, r, err = rc.c.rxMsg()
if err != nil { if err != nil {
return nil, err return nil, err
} }
switch t { switch msg := msg.(type) {
case rowDescription: case *pgproto3.RowDescription:
rows.fields = rc.c.rxRowDescription(r) rows.fields = rc.c.rxRowDescription(msg)
// We don't have c.PgTypes here because we're a replication // We don't have c.PgTypes here because we're a replication
// connection. This means the field descriptions will have // connection. This means the field descriptions will have
// only Oids. Not much we can do about this. // only Oids. Not much we can do about this.
default: default:
if e := rc.c.processContextFreeMsg(t, r); e != nil { if e := rc.c.processContextFreeMsg(msg); e != nil {
rows.Fatal(e) rows.Fatal(e)
return rows, e return rows, e
} }