mirror of https://github.com/jackc/pgx.git
Introduce pgproto3 package
pgproto3 will wrap the message encoding and decoding for the PostgreSQL frontend/backend protocol version 3.batch-wip
parent
e305ece410
commit
f04c58338b
|
@ -22,3 +22,4 @@ _testmain.go
|
|||
*.exe
|
||||
|
||||
conn_config_test.go
|
||||
.envrc
|
||||
|
|
|
@ -52,7 +52,7 @@ install:
|
|||
- go get -u github.com/shopspring/decimal
|
||||
- go get -u gopkg.in/inconshreveable/log15.v2
|
||||
- go get -u github.com/jackc/fake
|
||||
- go get -u github.com/jackc/pgmock/pgmsg
|
||||
- go get -u github.com/jackc/pgmock/pgproto3
|
||||
- go get -u github.com/lib/pq
|
||||
- go get -u github.com/hashicorp/go-version
|
||||
- go get -u github.com/satori/go.uuid
|
||||
|
|
304
conn.go
304
conn.go
|
@ -20,7 +20,7 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/chunkreader"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
|
@ -88,8 +88,8 @@ type Conn struct {
|
|||
lastActivityTime time.Time // the last time the connection was used
|
||||
wbuf [1024]byte
|
||||
writeBuf WriteBuf
|
||||
pid int32 // backend pid
|
||||
secretKey int32 // key to use to send a cancel query message to the server
|
||||
pid uint32 // backend pid
|
||||
secretKey uint32 // key to use to send a cancel query message to the server
|
||||
RuntimeParams map[string]string // parameters that have been reported by the server
|
||||
config ConnConfig // config used when establishing this connection
|
||||
txStatus byte
|
||||
|
@ -98,7 +98,6 @@ type Conn struct {
|
|||
notifications []*Notification
|
||||
logger Logger
|
||||
logLevel int
|
||||
mr msgReader
|
||||
fp *fastpath
|
||||
poolResetCount int
|
||||
preallocatedRows []Rows
|
||||
|
@ -116,6 +115,8 @@ type Conn struct {
|
|||
closedChan chan error
|
||||
|
||||
ConnInfo *pgtype.ConnInfo
|
||||
|
||||
frontend *pgproto3.Frontend
|
||||
}
|
||||
|
||||
// PreparedStatement is a description of a prepared statement
|
||||
|
@ -133,7 +134,7 @@ type PrepareExOptions struct {
|
|||
|
||||
// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system
|
||||
type Notification struct {
|
||||
PID int32 // backend pid that sent the notification
|
||||
PID uint32 // backend pid that sent the notification
|
||||
Channel string // channel from which notification was received
|
||||
Payload string
|
||||
}
|
||||
|
@ -213,8 +214,6 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
|
|||
c.logLevel = LogLevelDebug
|
||||
}
|
||||
c.logger = c.config.Logger
|
||||
c.mr.log = c.log
|
||||
c.mr.shouldLog = c.shouldLog
|
||||
|
||||
if c.config.User == "" {
|
||||
user, err := user.Current()
|
||||
|
@ -290,7 +289,10 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
}
|
||||
}
|
||||
|
||||
c.mr.cr = chunkreader.NewChunkReader(c.conn)
|
||||
c.frontend, err = pgproto3.NewFrontend(c.conn, c.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg := newStartupMessage()
|
||||
|
||||
|
@ -317,29 +319,27 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
}
|
||||
|
||||
for {
|
||||
var t byte
|
||||
var r *msgReader
|
||||
t, r, err = c.rxMsg()
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case backendKeyData:
|
||||
c.rxBackendKeyData(r)
|
||||
case authenticationX:
|
||||
if err = c.rxAuthenticationX(r); err != nil {
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.BackendKeyData:
|
||||
c.rxBackendKeyData(msg)
|
||||
case *pgproto3.Authentication:
|
||||
if err = c.rxAuthenticationX(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
case readyForQuery:
|
||||
c.rxReadyForQuery(r)
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "Connection established")
|
||||
}
|
||||
|
||||
// Replication connections can't execute the queries to
|
||||
// populate the c.PgTypes and c.pgsqlAfInet
|
||||
if _, ok := msg.options["replication"]; ok {
|
||||
if _, ok := config.RuntimeParams["replication"]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -352,7 +352,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
|
||||
return nil
|
||||
default:
|
||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||
if err = c.processContextFreeMsg(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -393,7 +393,7 @@ where (
|
|||
}
|
||||
|
||||
// PID returns the backend PID for this connection.
|
||||
func (c *Conn) PID() int32 {
|
||||
func (c *Conn) PID() uint32 {
|
||||
return c.pid
|
||||
}
|
||||
|
||||
|
@ -744,22 +744,20 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||
var softErr error
|
||||
|
||||
for {
|
||||
var t byte
|
||||
var r *msgReader
|
||||
t, r, err := c.rxMsg()
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case parameterDescription:
|
||||
ps.ParameterOids = c.rxParameterDescription(r)
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ParameterDescription:
|
||||
ps.ParameterOids = c.rxParameterDescription(msg)
|
||||
|
||||
if len(ps.ParameterOids) > 65535 && softErr == nil {
|
||||
softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids))
|
||||
}
|
||||
case rowDescription:
|
||||
ps.FieldDescriptions = c.rxRowDescription(r)
|
||||
case *pgproto3.RowDescription:
|
||||
ps.FieldDescriptions = c.rxRowDescription(msg)
|
||||
for i := range ps.FieldDescriptions {
|
||||
if dt, ok := c.ConnInfo.DataTypeForOid(ps.FieldDescriptions[i].DataType); ok {
|
||||
ps.FieldDescriptions[i].DataTypeName = dt.Name
|
||||
|
@ -772,8 +770,8 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||
return nil, fmt.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType)
|
||||
}
|
||||
}
|
||||
case readyForQuery:
|
||||
c.rxReadyForQuery(r)
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
|
||||
if softErr == nil {
|
||||
c.preparedStatements[name] = ps
|
||||
|
@ -781,7 +779,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||
|
||||
return ps, softErr
|
||||
default:
|
||||
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
|
||||
if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
|
||||
softErr = e
|
||||
}
|
||||
}
|
||||
|
@ -830,18 +828,16 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
|
|||
}
|
||||
|
||||
for {
|
||||
var t byte
|
||||
var r *msgReader
|
||||
t, r, err := c.rxMsg()
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case closeComplete:
|
||||
switch msg.(type) {
|
||||
case *pgproto3.CloseComplete:
|
||||
return nil
|
||||
default:
|
||||
err = c.processContextFreeMsg(t, r)
|
||||
err = c.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -908,12 +904,12 @@ func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notificat
|
|||
}
|
||||
|
||||
for {
|
||||
t, r, err := c.rxMsg()
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = c.processContextFreeMsg(t, r)
|
||||
err = c.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1030,62 +1026,48 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
|
|||
// meaningful in a given context. These messages can occur due to a context
|
||||
// deadline interrupting message processing. For example, an interrupted query
|
||||
// may have left DataRow messages on the wire.
|
||||
func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) {
|
||||
switch t {
|
||||
case bindComplete:
|
||||
case commandComplete:
|
||||
case dataRow:
|
||||
case emptyQueryResponse:
|
||||
case errorResponse:
|
||||
return c.rxErrorResponse(r)
|
||||
case noData:
|
||||
case noticeResponse:
|
||||
case notificationResponse:
|
||||
c.rxNotificationResponse(r)
|
||||
case parameterDescription:
|
||||
case parseComplete:
|
||||
case readyForQuery:
|
||||
c.rxReadyForQuery(r)
|
||||
case rowDescription:
|
||||
case 'S':
|
||||
c.rxParameterStatus(r)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("Received unknown message type: %c", t)
|
||||
func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) {
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ErrorResponse:
|
||||
return c.rxErrorResponse(msg)
|
||||
case *pgproto3.NotificationResponse:
|
||||
c.rxNotificationResponse(msg)
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
case *pgproto3.ParameterStatus:
|
||||
c.rxParameterStatus(msg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
|
||||
func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) {
|
||||
if atomic.LoadInt32(&c.status) < connStatusIdle {
|
||||
return 0, nil, ErrDeadConn
|
||||
return nil, ErrDeadConn
|
||||
}
|
||||
|
||||
t, err = c.mr.rxMsg()
|
||||
msg, err := c.frontend.Receive()
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||
c.die(err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.lastActivityTime = time.Now()
|
||||
|
||||
if c.shouldLog(LogLevelTrace) {
|
||||
c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBodyLen", len(c.mr.msgBody))
|
||||
}
|
||||
// fmt.Printf("rxMsg: %#v\n", msg)
|
||||
|
||||
return t, &c.mr, err
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (c *Conn) rxAuthenticationX(r *msgReader) (err error) {
|
||||
switch r.readInt32() {
|
||||
case 0: // AuthenticationOk
|
||||
case 3: // AuthenticationCleartextPassword
|
||||
func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
|
||||
switch msg.Type {
|
||||
case pgproto3.AuthTypeOk:
|
||||
case pgproto3.AuthTypeCleartextPassword:
|
||||
err = c.txPasswordMessage(c.config.Password)
|
||||
case 5: // AuthenticationMD5Password
|
||||
salt := r.readString(4)
|
||||
digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+salt)
|
||||
case pgproto3.AuthTypeMD5Password:
|
||||
digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:]))
|
||||
err = c.txPasswordMessage(digestedPassword)
|
||||
default:
|
||||
err = errors.New("Received unknown authentication message")
|
||||
|
@ -1100,115 +1082,75 @@ func hexMD5(s string) string {
|
|||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
func (c *Conn) rxParameterStatus(r *msgReader) {
|
||||
key := r.readCString()
|
||||
value := r.readCString()
|
||||
c.RuntimeParams[key] = value
|
||||
func (c *Conn) rxParameterStatus(msg *pgproto3.ParameterStatus) {
|
||||
c.RuntimeParams[msg.Name] = msg.Value
|
||||
}
|
||||
|
||||
func (c *Conn) rxErrorResponse(r *msgReader) (err PgError) {
|
||||
for {
|
||||
switch r.readByte() {
|
||||
case 'S':
|
||||
err.Severity = r.readCString()
|
||||
case 'C':
|
||||
err.Code = r.readCString()
|
||||
case 'M':
|
||||
err.Message = r.readCString()
|
||||
case 'D':
|
||||
err.Detail = r.readCString()
|
||||
case 'H':
|
||||
err.Hint = r.readCString()
|
||||
case 'P':
|
||||
s := r.readCString()
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
err.Position = int32(n)
|
||||
case 'p':
|
||||
s := r.readCString()
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
err.InternalPosition = int32(n)
|
||||
case 'q':
|
||||
err.InternalQuery = r.readCString()
|
||||
case 'W':
|
||||
err.Where = r.readCString()
|
||||
case 's':
|
||||
err.SchemaName = r.readCString()
|
||||
case 't':
|
||||
err.TableName = r.readCString()
|
||||
case 'c':
|
||||
err.ColumnName = r.readCString()
|
||||
case 'd':
|
||||
err.DataTypeName = r.readCString()
|
||||
case 'n':
|
||||
err.ConstraintName = r.readCString()
|
||||
case 'F':
|
||||
err.File = r.readCString()
|
||||
case 'L':
|
||||
s := r.readCString()
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
err.Line = int32(n)
|
||||
case 'R':
|
||||
err.Routine = r.readCString()
|
||||
|
||||
case 0: // End of error message
|
||||
if err.Severity == "FATAL" {
|
||||
c.die(err)
|
||||
}
|
||||
return
|
||||
default: // Ignore other error fields
|
||||
r.readCString()
|
||||
}
|
||||
func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) PgError {
|
||||
err := PgError{
|
||||
Severity: msg.Severity,
|
||||
Code: msg.Code,
|
||||
Message: msg.Message,
|
||||
Detail: msg.Detail,
|
||||
Hint: msg.Hint,
|
||||
Position: msg.Position,
|
||||
InternalPosition: msg.InternalPosition,
|
||||
InternalQuery: msg.InternalQuery,
|
||||
Where: msg.Where,
|
||||
SchemaName: msg.SchemaName,
|
||||
TableName: msg.TableName,
|
||||
ColumnName: msg.ColumnName,
|
||||
DataTypeName: msg.DataTypeName,
|
||||
ConstraintName: msg.ConstraintName,
|
||||
File: msg.File,
|
||||
Line: msg.Line,
|
||||
Routine: msg.Routine,
|
||||
}
|
||||
|
||||
if err.Severity == "FATAL" {
|
||||
c.die(err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) rxBackendKeyData(r *msgReader) {
|
||||
c.pid = r.readInt32()
|
||||
c.secretKey = r.readInt32()
|
||||
func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) {
|
||||
c.pid = msg.ProcessID
|
||||
c.secretKey = msg.SecretKey
|
||||
}
|
||||
|
||||
func (c *Conn) rxReadyForQuery(r *msgReader) {
|
||||
func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) {
|
||||
c.readyForQuery = true
|
||||
c.txStatus = r.readByte()
|
||||
c.txStatus = msg.TxStatus
|
||||
}
|
||||
|
||||
func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) {
|
||||
fieldCount := r.readInt16()
|
||||
fields = make([]FieldDescription, fieldCount)
|
||||
for i := int16(0); i < fieldCount; i++ {
|
||||
f := &fields[i]
|
||||
f.Name = r.readCString()
|
||||
f.Table = pgtype.Oid(r.readUint32())
|
||||
f.AttributeNumber = r.readInt16()
|
||||
f.DataType = pgtype.Oid(r.readUint32())
|
||||
f.DataTypeSize = r.readInt16()
|
||||
f.Modifier = r.readInt32()
|
||||
f.FormatCode = r.readInt16()
|
||||
func (c *Conn) rxRowDescription(msg *pgproto3.RowDescription) []FieldDescription {
|
||||
fields := make([]FieldDescription, len(msg.Fields))
|
||||
for i := 0; i < len(fields); i++ {
|
||||
fields[i].Name = msg.Fields[i].Name
|
||||
fields[i].Table = pgtype.Oid(msg.Fields[i].TableOID)
|
||||
fields[i].AttributeNumber = msg.Fields[i].TableAttributeNumber
|
||||
fields[i].DataType = pgtype.Oid(msg.Fields[i].DataTypeOID)
|
||||
fields[i].DataTypeSize = msg.Fields[i].DataTypeSize
|
||||
fields[i].Modifier = msg.Fields[i].TypeModifier
|
||||
fields[i].FormatCode = msg.Fields[i].Format
|
||||
}
|
||||
return
|
||||
return fields
|
||||
}
|
||||
|
||||
func (c *Conn) rxParameterDescription(r *msgReader) (parameters []pgtype.Oid) {
|
||||
// Internally, PostgreSQL supports greater than 64k parameters to a prepared
|
||||
// statement. But the parameter description uses a 16-bit integer for the
|
||||
// count of parameters. If there are more than 64K parameters, this count is
|
||||
// wrong. So read the count, ignore it, and compute the proper value from
|
||||
// the size of the message.
|
||||
r.readInt16()
|
||||
parameterCount := len(r.msgBody[r.rp:]) / 4
|
||||
|
||||
parameters = make([]pgtype.Oid, 0, parameterCount)
|
||||
|
||||
for i := 0; i < parameterCount; i++ {
|
||||
parameters = append(parameters, pgtype.Oid(r.readUint32()))
|
||||
func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgtype.Oid {
|
||||
parameters := make([]pgtype.Oid, len(msg.ParameterOIDs))
|
||||
for i := 0; i < len(parameters); i++ {
|
||||
parameters[i] = pgtype.Oid(msg.ParameterOIDs[i])
|
||||
}
|
||||
return
|
||||
return parameters
|
||||
}
|
||||
|
||||
func (c *Conn) rxNotificationResponse(r *msgReader) {
|
||||
func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) {
|
||||
n := new(Notification)
|
||||
n.PID = r.readInt32()
|
||||
n.Channel = r.readCString()
|
||||
n.Payload = r.readCString()
|
||||
n.PID = msg.PID
|
||||
n.Channel = msg.Channel
|
||||
n.Payload = msg.Payload
|
||||
c.notifications = append(c.notifications, n)
|
||||
}
|
||||
|
||||
|
@ -1453,21 +1395,19 @@ func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||
var softErr error
|
||||
|
||||
for {
|
||||
var t byte
|
||||
var r *msgReader
|
||||
t, r, err = c.rxMsg()
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case readyForQuery:
|
||||
c.rxReadyForQuery(r)
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
return commandTag, softErr
|
||||
case commandComplete:
|
||||
commandTag = CommandTag(r.readCString())
|
||||
case *pgproto3.CommandComplete:
|
||||
commandTag = CommandTag(msg.CommandTag)
|
||||
default:
|
||||
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
|
||||
if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
|
||||
softErr = e
|
||||
}
|
||||
}
|
||||
|
@ -1545,19 +1485,19 @@ func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error {
|
|||
|
||||
func (c *Conn) ensureConnectionReadyForQuery() error {
|
||||
for !c.readyForQuery {
|
||||
t, r, err := c.rxMsg()
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case errorResponse:
|
||||
pgErr := c.rxErrorResponse(r)
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgErr := c.rxErrorResponse(msg)
|
||||
if pgErr.Severity == "FATAL" {
|
||||
return pgErr
|
||||
}
|
||||
default:
|
||||
err = c.processContextFreeMsg(t, r)
|
||||
err = c.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -686,7 +686,7 @@ func TestConnPoolBeginRetry(t *testing.T) {
|
|||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
var txPID int32
|
||||
var txPID uint32
|
||||
err = tx.QueryRow("select pg_backend_pid()").Scan(&txPID)
|
||||
if err != nil {
|
||||
t.Fatalf("tx.QueryRow Scan failed: %v", err)
|
||||
|
|
30
copy_from.go
30
copy_from.go
|
@ -3,6 +3,8 @@ package pgx
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
)
|
||||
|
||||
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
|
||||
|
@ -54,25 +56,25 @@ type copyFrom struct {
|
|||
|
||||
func (ct *copyFrom) readUntilReadyForQuery() {
|
||||
for {
|
||||
t, r, err := ct.conn.rxMsg()
|
||||
msg, err := ct.conn.rxMsg()
|
||||
if err != nil {
|
||||
ct.readerErrChan <- err
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
}
|
||||
|
||||
switch t {
|
||||
case readyForQuery:
|
||||
ct.conn.rxReadyForQuery(r)
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
ct.conn.rxReadyForQuery(msg)
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
case commandComplete:
|
||||
case errorResponse:
|
||||
ct.readerErrChan <- ct.conn.rxErrorResponse(r)
|
||||
case *pgproto3.CommandComplete:
|
||||
case *pgproto3.ErrorResponse:
|
||||
ct.readerErrChan <- ct.conn.rxErrorResponse(msg)
|
||||
default:
|
||||
err = ct.conn.processContextFreeMsg(t, r)
|
||||
err = ct.conn.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r)
|
||||
ct.readerErrChan <- ct.conn.processContextFreeMsg(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -190,18 +192,16 @@ func (ct *copyFrom) run() (int, error) {
|
|||
|
||||
func (c *Conn) readUntilCopyInResponse() error {
|
||||
for {
|
||||
var t byte
|
||||
var r *msgReader
|
||||
t, r, err := c.rxMsg()
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case copyInResponse:
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CopyInResponse:
|
||||
return nil
|
||||
default:
|
||||
err = c.processContextFreeMsg(t, r)
|
||||
err = c.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
22
fastpath.go
22
fastpath.go
|
@ -3,6 +3,7 @@ package pgx
|
|||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
|
@ -71,23 +72,20 @@ func (f *fastpath) Call(oid pgtype.Oid, args []fpArg) (res []byte, err error) {
|
|||
}
|
||||
|
||||
for {
|
||||
var t byte
|
||||
var r *msgReader
|
||||
t, r, err = f.cn.rxMsg()
|
||||
msg, err := f.cn.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch t {
|
||||
case 'V': // FunctionCallResponse
|
||||
data := r.readBytes(r.readInt32())
|
||||
res = make([]byte, len(data))
|
||||
copy(res, data)
|
||||
case 'Z': // Ready for query
|
||||
f.cn.rxReadyForQuery(r)
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.FunctionCallResponse:
|
||||
res = make([]byte, len(msg.Result))
|
||||
copy(res, msg.Result)
|
||||
case *pgproto3.ReadyForQuery:
|
||||
f.cn.rxReadyForQuery(msg)
|
||||
// done
|
||||
return
|
||||
return res, err
|
||||
default:
|
||||
if err := f.cn.processContextFreeMsg(t, r); err != nil {
|
||||
if err := f.cn.processContextFreeMsg(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,11 +58,11 @@ func (s *startupMessage) Bytes() (buf []byte) {
|
|||
type FieldDescription struct {
|
||||
Name string
|
||||
Table pgtype.Oid
|
||||
AttributeNumber int16
|
||||
AttributeNumber uint16
|
||||
DataType pgtype.Oid
|
||||
DataTypeSize int16
|
||||
DataTypeName string
|
||||
Modifier int32
|
||||
Modifier uint32
|
||||
FormatCode int16
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthTypeOk = 0
|
||||
AuthTypeCleartextPassword = 3
|
||||
AuthTypeMD5Password = 5
|
||||
)
|
||||
|
||||
type Authentication struct {
|
||||
Type uint32
|
||||
|
||||
// MD5Password fields
|
||||
Salt [4]byte
|
||||
}
|
||||
|
||||
func (*Authentication) Backend() {}
|
||||
|
||||
func (dst *Authentication) UnmarshalBinary(src []byte) error {
|
||||
*dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])}
|
||||
|
||||
switch dst.Type {
|
||||
case AuthTypeOk:
|
||||
case AuthTypeCleartextPassword:
|
||||
case AuthTypeMD5Password:
|
||||
copy(dst.Salt[:], src[4:8])
|
||||
default:
|
||||
return fmt.Errorf("unknown authentication type: %d", dst.Type)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Authentication) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteByte('R')
|
||||
buf.Write(bigEndian.Uint32(0))
|
||||
buf.Write(bigEndian.Uint32(src.Type))
|
||||
|
||||
switch src.Type {
|
||||
case AuthTypeMD5Password:
|
||||
buf.Write(src.Salt[:])
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type BackendKeyData struct {
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}
|
||||
|
||||
func (*BackendKeyData) Backend() {}
|
||||
|
||||
func (dst *BackendKeyData) UnmarshalBinary(src []byte) error {
|
||||
if len(src) != 8 {
|
||||
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
|
||||
}
|
||||
|
||||
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
|
||||
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *BackendKeyData) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteByte('K')
|
||||
buf.Write(bigEndian.Uint32(12))
|
||||
buf.Write(bigEndian.Uint32(src.ProcessID))
|
||||
buf.Write(bigEndian.Uint32(src.SecretKey))
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (src *BackendKeyData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}{
|
||||
Type: "BackendKeyData",
|
||||
ProcessID: src.ProcessID,
|
||||
SecretKey: src.SecretKey,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
type BigEndianBuf [8]byte
|
||||
|
||||
func (b BigEndianBuf) Int16(n int16) []byte {
|
||||
buf := b[0:2]
|
||||
binary.BigEndian.PutUint16(buf, uint16(n))
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Uint16(n uint16) []byte {
|
||||
buf := b[0:2]
|
||||
binary.BigEndian.PutUint16(buf, n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Int32(n int32) []byte {
|
||||
buf := b[0:4]
|
||||
binary.BigEndian.PutUint32(buf, uint32(n))
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Uint32(n uint32) []byte {
|
||||
buf := b[0:4]
|
||||
binary.BigEndian.PutUint32(buf, n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Int64(n int64) []byte {
|
||||
buf := b[0:8]
|
||||
binary.BigEndian.PutUint64(buf, uint64(n))
|
||||
return buf
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type BindComplete struct{}
|
||||
|
||||
func (*BindComplete) Backend() {}
|
||||
|
||||
func (dst *BindComplete) UnmarshalBinary(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *BindComplete) MarshalBinary() ([]byte, error) {
|
||||
return []byte{'2', 0, 0, 0, 4}, nil
|
||||
}
|
||||
|
||||
func (src *BindComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "BindComplete",
|
||||
})
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type CloseComplete struct{}
|
||||
|
||||
func (*CloseComplete) Backend() {}
|
||||
|
||||
func (dst *CloseComplete) UnmarshalBinary(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CloseComplete) MarshalBinary() ([]byte, error) {
|
||||
return []byte{'3', 0, 0, 0, 4}, nil
|
||||
}
|
||||
|
||||
func (src *CloseComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "CloseComplete",
|
||||
})
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type CommandComplete struct {
|
||||
CommandTag string
|
||||
}
|
||||
|
||||
func (*CommandComplete) Backend() {}
|
||||
|
||||
func (dst *CommandComplete) UnmarshalBinary(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.CommandTag = string(b[:len(b)-1])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CommandComplete) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte('C')
|
||||
buf.Write(bigEndian.Uint32(uint32(4 + len(src.CommandTag) + 1)))
|
||||
|
||||
buf.WriteString(src.CommandTag)
|
||||
buf.WriteByte(0)
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (src *CommandComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
CommandTag string
|
||||
}{
|
||||
Type: "CommandComplete",
|
||||
CommandTag: src.CommandTag,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type CopyBothResponse struct {
|
||||
OverallFormat byte
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
|
||||
func (*CopyBothResponse) Backend() {}
|
||||
|
||||
func (dst *CopyBothResponse) UnmarshalBinary(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||
}
|
||||
|
||||
overallFormat := buf.Next(1)[0]
|
||||
|
||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
if buf.Len() != columnCount*2 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CopyBothResponse) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte('W')
|
||||
buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes))))
|
||||
|
||||
buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes))))
|
||||
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
buf.Write(bigEndian.Uint16(fc))
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (src *CopyBothResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ColumnFormatCodes []uint16
|
||||
}{
|
||||
Type: "CopyBothResponse",
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type CopyData struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (*CopyData) Backend() {}
|
||||
func (*CopyData) Frontend() {}
|
||||
|
||||
func (dst *CopyData) UnmarshalBinary(src []byte) error {
|
||||
dst.Data = make([]byte, len(src))
|
||||
copy(dst.Data, src)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CopyData) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte('d')
|
||||
buf.Write(bigEndian.Uint32(uint32(4 + len(src.Data))))
|
||||
buf.Write(src.Data)
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (src *CopyData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data string
|
||||
}{
|
||||
Type: "CopyData",
|
||||
Data: hex.EncodeToString(src.Data),
|
||||
})
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type CopyInResponse struct {
|
||||
OverallFormat byte
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
|
||||
func (*CopyInResponse) Backend() {}
|
||||
|
||||
func (dst *CopyInResponse) UnmarshalBinary(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
|
||||
}
|
||||
|
||||
overallFormat := buf.Next(1)[0]
|
||||
|
||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
if buf.Len() != columnCount*2 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
*dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CopyInResponse) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte('G')
|
||||
buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes))))
|
||||
|
||||
buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes))))
|
||||
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
buf.Write(bigEndian.Uint16(fc))
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (src *CopyInResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ColumnFormatCodes []uint16
|
||||
}{
|
||||
Type: "CopyInResponse",
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type CopyOutResponse struct {
|
||||
OverallFormat byte
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
|
||||
func (*CopyOutResponse) Backend() {}
|
||||
|
||||
func (dst *CopyOutResponse) UnmarshalBinary(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
|
||||
}
|
||||
|
||||
overallFormat := buf.Next(1)[0]
|
||||
|
||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
if buf.Len() != columnCount*2 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
*dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CopyOutResponse) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte('H')
|
||||
buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes))))
|
||||
|
||||
buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes))))
|
||||
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
buf.Write(bigEndian.Uint16(fc))
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (src *CopyOutResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ColumnFormatCodes []uint16
|
||||
}{
|
||||
Type: "CopyOutResponse",
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,103 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type DataRow struct {
|
||||
Values [][]byte
|
||||
}
|
||||
|
||||
func (*DataRow) Backend() {}
|
||||
|
||||
func (dst *DataRow) UnmarshalBinary(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
fieldCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
|
||||
dst.Values = make([][]byte, fieldCount)
|
||||
|
||||
for i := 0; i < fieldCount; i++ {
|
||||
if buf.Len() < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
|
||||
msgSize := int(int32(binary.BigEndian.Uint32(buf.Next(4))))
|
||||
|
||||
// null
|
||||
if msgSize == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
value := make([]byte, msgSize)
|
||||
_, err := buf.Read(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Values[i] = value
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *DataRow) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte('D')
|
||||
buf.Write(bigEndian.Uint32(0))
|
||||
|
||||
buf.Write(bigEndian.Uint16(uint16(len(src.Values))))
|
||||
|
||||
for _, v := range src.Values {
|
||||
if v == nil {
|
||||
buf.Write(bigEndian.Int32(-1))
|
||||
continue
|
||||
}
|
||||
|
||||
buf.Write(bigEndian.Int32(int32(len(v))))
|
||||
buf.Write(v)
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (src *DataRow) MarshalJSON() ([]byte, error) {
|
||||
formattedValues := make([]map[string]string, len(src.Values))
|
||||
for i, v := range src.Values {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var hasNonPrintable bool
|
||||
for _, b := range v {
|
||||
if b < 32 {
|
||||
hasNonPrintable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasNonPrintable {
|
||||
formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)}
|
||||
} else {
|
||||
formattedValues[i] = map[string]string{"text": string(v)}
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Values []map[string]string
|
||||
}{
|
||||
Type: "DataRow",
|
||||
Values: formattedValues,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type EmptyQueryResponse struct{}
|
||||
|
||||
func (*EmptyQueryResponse) Backend() {}
|
||||
|
||||
func (dst *EmptyQueryResponse) UnmarshalBinary(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *EmptyQueryResponse) MarshalBinary() ([]byte, error) {
|
||||
return []byte{'I', 0, 0, 0, 4}, nil
|
||||
}
|
||||
|
||||
func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "EmptyQueryResponse",
|
||||
})
|
||||
}
|
|
@ -0,0 +1,197 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type ErrorResponse struct {
|
||||
Severity string
|
||||
Code string
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position int32
|
||||
InternalPosition int32
|
||||
InternalQuery string
|
||||
Where string
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
DataTypeName string
|
||||
ConstraintName string
|
||||
File string
|
||||
Line int32
|
||||
Routine string
|
||||
|
||||
UnknownFields map[byte]string
|
||||
}
|
||||
|
||||
func (*ErrorResponse) Backend() {}
|
||||
|
||||
func (dst *ErrorResponse) UnmarshalBinary(src []byte) error {
|
||||
*dst = ErrorResponse{}
|
||||
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
for {
|
||||
k, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if k == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
vb, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v := string(vb[:len(vb)-1])
|
||||
|
||||
switch k {
|
||||
case 'S':
|
||||
dst.Severity = v
|
||||
case 'C':
|
||||
dst.Code = v
|
||||
case 'M':
|
||||
dst.Message = v
|
||||
case 'D':
|
||||
dst.Detail = v
|
||||
case 'H':
|
||||
dst.Hint = v
|
||||
case 'P':
|
||||
s := v
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
dst.Position = int32(n)
|
||||
case 'p':
|
||||
s := v
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
dst.InternalPosition = int32(n)
|
||||
case 'q':
|
||||
dst.InternalQuery = v
|
||||
case 'W':
|
||||
dst.Where = v
|
||||
case 's':
|
||||
dst.SchemaName = v
|
||||
case 't':
|
||||
dst.TableName = v
|
||||
case 'c':
|
||||
dst.ColumnName = v
|
||||
case 'd':
|
||||
dst.DataTypeName = v
|
||||
case 'n':
|
||||
dst.ConstraintName = v
|
||||
case 'F':
|
||||
dst.File = v
|
||||
case 'L':
|
||||
s := v
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
dst.Line = int32(n)
|
||||
case 'R':
|
||||
dst.Routine = v
|
||||
|
||||
default:
|
||||
if dst.UnknownFields == nil {
|
||||
dst.UnknownFields = make(map[byte]string)
|
||||
}
|
||||
dst.UnknownFields[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ErrorResponse) MarshalBinary() ([]byte, error) {
|
||||
return src.marshalBinary('E')
|
||||
}
|
||||
|
||||
func (src *ErrorResponse) marshalBinary(typeByte byte) ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte(typeByte)
|
||||
buf.Write(bigEndian.Uint32(0))
|
||||
|
||||
if src.Severity != "" {
|
||||
buf.WriteString(src.Severity)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Code != "" {
|
||||
buf.WriteString(src.Code)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Message != "" {
|
||||
buf.WriteString(src.Message)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Detail != "" {
|
||||
buf.WriteString(src.Detail)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Hint != "" {
|
||||
buf.WriteString(src.Hint)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Position != 0 {
|
||||
buf.WriteString(strconv.Itoa(int(src.Position)))
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.InternalPosition != 0 {
|
||||
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.InternalQuery != "" {
|
||||
buf.WriteString(src.InternalQuery)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Where != "" {
|
||||
buf.WriteString(src.Where)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.SchemaName != "" {
|
||||
buf.WriteString(src.SchemaName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.TableName != "" {
|
||||
buf.WriteString(src.TableName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.ColumnName != "" {
|
||||
buf.WriteString(src.ColumnName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.DataTypeName != "" {
|
||||
buf.WriteString(src.DataTypeName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.ConstraintName != "" {
|
||||
buf.WriteString(src.ConstraintName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.File != "" {
|
||||
buf.WriteString(src.File)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Line != 0 {
|
||||
buf.WriteString(strconv.Itoa(int(src.Line)))
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Routine != "" {
|
||||
buf.WriteString(src.Routine)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
|
||||
for k, v := range src.UnknownFields {
|
||||
buf.WriteByte(k)
|
||||
buf.WriteByte(0)
|
||||
buf.WriteString(v)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
buf.WriteByte(0)
|
||||
|
||||
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/chunkreader"
|
||||
)
|
||||
|
||||
type Frontend struct {
|
||||
cr *chunkreader.ChunkReader
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
|
||||
cr := chunkreader.NewChunkReader(r)
|
||||
return &Frontend{cr: cr, w: w}, nil
|
||||
}
|
||||
|
||||
func (b *Frontend) Send(msg FrontendMessage) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *Frontend) Receive() (BackendMessage, error) {
|
||||
backendMessages := map[byte]BackendMessage{
|
||||
'1': &ParseComplete{},
|
||||
'2': &BindComplete{},
|
||||
'3': &CloseComplete{},
|
||||
'A': &NotificationResponse{},
|
||||
'C': &CommandComplete{},
|
||||
'd': &CopyData{},
|
||||
'D': &DataRow{},
|
||||
'E': &ErrorResponse{},
|
||||
'G': &CopyInResponse{},
|
||||
'H': &CopyOutResponse{},
|
||||
'I': &EmptyQueryResponse{},
|
||||
'K': &BackendKeyData{},
|
||||
'n': &NoData{},
|
||||
'N': &NoticeResponse{},
|
||||
'R': &Authentication{},
|
||||
'S': &ParameterStatus{},
|
||||
't': &ParameterDescription{},
|
||||
'T': &RowDescription{},
|
||||
'V': &FunctionCallResponse{},
|
||||
'W': &CopyBothResponse{},
|
||||
'Z': &ReadyForQuery{},
|
||||
}
|
||||
|
||||
header, err := b.cr.Next(5)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msgType := header[0]
|
||||
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||
|
||||
msgBody, err := b.cr.Next(bodyLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if msg, ok := backendMessages[msgType]; ok {
|
||||
err = msg.UnmarshalBinary(msgBody)
|
||||
return msg, err
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown message type: %c", msgType)
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type FunctionCallResponse struct {
|
||||
Result []byte
|
||||
}
|
||||
|
||||
func (*FunctionCallResponse) Backend() {}
|
||||
|
||||
func (dst *FunctionCallResponse) UnmarshalBinary(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||
}
|
||||
resultSize := int(binary.BigEndian.Uint32(buf.Next(4)))
|
||||
if buf.Len() != resultSize {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||
}
|
||||
|
||||
dst.Result = make([]byte, resultSize)
|
||||
copy(dst.Result, buf.Bytes())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *FunctionCallResponse) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte('V')
|
||||
buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Result))))
|
||||
|
||||
if src.Result == nil {
|
||||
buf.Write(bigEndian.Int32(-1))
|
||||
} else {
|
||||
buf.Write(bigEndian.Int32(int32(len(src.Result))))
|
||||
buf.Write(src.Result)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) {
|
||||
var formattedValue map[string]string
|
||||
var hasNonPrintable bool
|
||||
for _, b := range src.Result {
|
||||
if b < 32 {
|
||||
hasNonPrintable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasNonPrintable {
|
||||
formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)}
|
||||
} else {
|
||||
formattedValue = map[string]string{"text": string(src.Result)}
|
||||
}
|
||||
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Result map[string]string
|
||||
}{
|
||||
Type: "FunctionCallResponse",
|
||||
Result: formattedValue,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type NoData struct{}
|
||||
|
||||
func (*NoData) Backend() {}
|
||||
|
||||
func (dst *NoData) UnmarshalBinary(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *NoData) MarshalBinary() ([]byte, error) {
|
||||
return []byte{'n', 0, 0, 0, 4}, nil
|
||||
}
|
||||
|
||||
func (src *NoData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "NoData",
|
||||
})
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
package pgproto3
|
||||
|
||||
type NoticeResponse ErrorResponse
|
||||
|
||||
func (*NoticeResponse) Backend() {}
|
||||
|
||||
func (dst *NoticeResponse) UnmarshalBinary(src []byte) error {
|
||||
return (*ErrorResponse)(dst).UnmarshalBinary(src)
|
||||
}
|
||||
|
||||
func (src *NoticeResponse) MarshalBinary() ([]byte, error) {
|
||||
return (*ErrorResponse)(src).marshalBinary('N')
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type NotificationResponse struct {
|
||||
PID uint32
|
||||
Channel string
|
||||
Payload string
|
||||
}
|
||||
|
||||
func (*NotificationResponse) Backend() {}
|
||||
|
||||
func (dst *NotificationResponse) UnmarshalBinary(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
pid := binary.BigEndian.Uint32(buf.Next(4))
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
channel := string(b[:len(b)-1])
|
||||
|
||||
b, err = buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload := string(b[:len(b)-1])
|
||||
|
||||
*dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *NotificationResponse) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte('A')
|
||||
buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Channel) + len(src.Payload))))
|
||||
|
||||
buf.WriteString(src.Channel)
|
||||
buf.WriteByte(0)
|
||||
buf.WriteString(src.Payload)
|
||||
buf.WriteByte(0)
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (src *NotificationResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
PID uint32
|
||||
Channel string
|
||||
Payload string
|
||||
}{
|
||||
Type: "NotificationResponse",
|
||||
PID: src.PID,
|
||||
Channel: src.Channel,
|
||||
Payload: src.Payload,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type ParameterDescription struct {
|
||||
ParameterOIDs []uint32
|
||||
}
|
||||
|
||||
func (*ParameterDescription) Backend() {}
|
||||
|
||||
func (dst *ParameterDescription) UnmarshalBinary(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "ParameterDescription"}
|
||||
}
|
||||
|
||||
// Reported parameter count will be incorrect when number of args is greater than uint16
|
||||
buf.Next(2)
|
||||
// Instead infer parameter count by remaining size of message
|
||||
parameterCount := buf.Len() / 4
|
||||
|
||||
*dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)}
|
||||
|
||||
for i := 0; i < parameterCount; i++ {
|
||||
dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ParameterDescription) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte('t')
|
||||
buf.Write(bigEndian.Uint32(uint32(4 + 2 + 4*len(src.ParameterOIDs))))
|
||||
|
||||
buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs))))
|
||||
|
||||
for _, oid := range src.ParameterOIDs {
|
||||
buf.Write(bigEndian.Uint32(oid))
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (src *ParameterDescription) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ParameterOIDs []uint32
|
||||
}{
|
||||
Type: "ParameterDescription",
|
||||
ParameterOIDs: src.ParameterOIDs,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type ParameterStatus struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
func (*ParameterStatus) Backend() {}
|
||||
|
||||
func (dst *ParameterStatus) UnmarshalBinary(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
name := string(b[:len(b)-1])
|
||||
|
||||
b, err = buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value := string(b[:len(b)-1])
|
||||
|
||||
*dst = ParameterStatus{Name: name, Value: value}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ParameterStatus) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte('S')
|
||||
buf.Write(bigEndian.Uint32(0))
|
||||
|
||||
buf.WriteString(src.Name)
|
||||
buf.WriteByte(0)
|
||||
buf.WriteString(src.Value)
|
||||
buf.WriteByte(0)
|
||||
|
||||
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (ps *ParameterStatus) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Name string
|
||||
Value string
|
||||
}{
|
||||
Type: "ParameterStatus",
|
||||
Name: ps.Name,
|
||||
Value: ps.Value,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type ParseComplete struct{}
|
||||
|
||||
func (*ParseComplete) Backend() {}
|
||||
|
||||
func (dst *ParseComplete) UnmarshalBinary(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ParseComplete) MarshalBinary() ([]byte, error) {
|
||||
return []byte{'1', 0, 0, 0, 4}, nil
|
||||
}
|
||||
|
||||
func (src *ParseComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "ParseComplete",
|
||||
})
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
package pgproto3
|
||||
|
||||
import "fmt"
|
||||
|
||||
type Message interface {
|
||||
UnmarshalBinary(data []byte) error
|
||||
MarshalBinary() (data []byte, err error)
|
||||
}
|
||||
|
||||
type FrontendMessage interface {
|
||||
Message
|
||||
Frontend() // no-op method to distinguish frontend from backend methods
|
||||
}
|
||||
|
||||
type BackendMessage interface {
|
||||
Message
|
||||
Backend() // no-op method to distinguish frontend from backend methods
|
||||
}
|
||||
|
||||
// func ParseBackend(typeByte byte, body []byte) (BackendMessage, error) {
|
||||
// switch typeByte {
|
||||
// case '1':
|
||||
// return ParseParseComplete(body)
|
||||
// case '2':
|
||||
// return ParseBindComplete(body)
|
||||
// case 'C':
|
||||
// return ParseCommandComplete(body)
|
||||
// case 'D':
|
||||
// return ParseDataRow(body)
|
||||
// case 'E':
|
||||
// return ParseErrorResponse(body)
|
||||
// case 'K':
|
||||
// return ParseBackendKeyData(body)
|
||||
// case 'R':
|
||||
// return ParseAuthentication(body)
|
||||
// case 'S':
|
||||
// return ParseParameterStatus(body)
|
||||
// case 'T':
|
||||
// return ParseRowDescription(body)
|
||||
// case 't':
|
||||
// return ParseParameterDescription(body)
|
||||
// case 'Z':
|
||||
// return ParseReadyForQuery(body)
|
||||
// default:
|
||||
// return ParseUnknownMessage(typeByte, body)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func ParseFrontend(typeByte byte, body []byte) (FrontendMessage, error) {
|
||||
// switch typeByte {
|
||||
// case 'B':
|
||||
// return ParseBind(body)
|
||||
// case 'D':
|
||||
// return ParseDescribe(body)
|
||||
// case 'E':
|
||||
// return ParseExecute(body)
|
||||
// case 'P':
|
||||
// return ParseParse(body)
|
||||
// case 'p':
|
||||
// return ParsePasswordMessage(body)
|
||||
// case 'Q':
|
||||
// return ParseQuery(body)
|
||||
// case 'S':
|
||||
// return ParseSync(body)
|
||||
// case 'X':
|
||||
// return ParseTerminate(body)
|
||||
// default:
|
||||
// return ParseUnknownMessage(typeByte, body)
|
||||
// }
|
||||
// }
|
||||
|
||||
type invalidMessageLenErr struct {
|
||||
messageType string
|
||||
expectedLen int
|
||||
actualLen int
|
||||
}
|
||||
|
||||
func (e *invalidMessageLenErr) Error() string {
|
||||
return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen)
|
||||
}
|
||||
|
||||
type invalidMessageFormatErr struct {
|
||||
messageType string
|
||||
}
|
||||
|
||||
func (e *invalidMessageFormatErr) Error() string {
|
||||
return fmt.Sprintf("%s body is invalid", e.messageType)
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Query struct {
|
||||
String string
|
||||
}
|
||||
|
||||
func (*Query) Frontend() {}
|
||||
|
||||
func (dst *Query) UnmarshalBinary(src []byte) error {
|
||||
i := bytes.IndexByte(src, 0)
|
||||
if i != len(src)-1 {
|
||||
return &invalidMessageFormatErr{messageType: "Query"}
|
||||
}
|
||||
|
||||
dst.String = string(src[:i])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Query) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteByte('Q')
|
||||
buf.Write(bigEndian.Uint32(uint32(4 + len(src.String) + 1)))
|
||||
buf.WriteString(src.String)
|
||||
buf.WriteByte(0)
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (src *Query) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
String string
|
||||
}{
|
||||
Type: "Query",
|
||||
String: src.String,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type ReadyForQuery struct {
|
||||
TxStatus byte
|
||||
}
|
||||
|
||||
func (*ReadyForQuery) Backend() {}
|
||||
|
||||
func (dst *ReadyForQuery) UnmarshalBinary(src []byte) error {
|
||||
if len(src) != 1 {
|
||||
return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)}
|
||||
}
|
||||
|
||||
dst.TxStatus = src[0]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ReadyForQuery) MarshalBinary() ([]byte, error) {
|
||||
return []byte{'Z', 0, 0, 0, 5, src.TxStatus}, nil
|
||||
}
|
||||
|
||||
func (src *ReadyForQuery) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
TxStatus string
|
||||
}{
|
||||
Type: "ReadyForQuery",
|
||||
TxStatus: string(src.TxStatus),
|
||||
})
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
const (
|
||||
TextFormat = 0
|
||||
BinaryFormat = 1
|
||||
)
|
||||
|
||||
type FieldDescription struct {
|
||||
Name string
|
||||
TableOID uint32
|
||||
TableAttributeNumber uint16
|
||||
DataTypeOID uint32
|
||||
DataTypeSize int16
|
||||
TypeModifier uint32
|
||||
Format int16
|
||||
}
|
||||
|
||||
type RowDescription struct {
|
||||
Fields []FieldDescription
|
||||
}
|
||||
|
||||
func (*RowDescription) Backend() {}
|
||||
|
||||
func (dst *RowDescription) UnmarshalBinary(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
||||
}
|
||||
fieldCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
|
||||
*dst = RowDescription{Fields: make([]FieldDescription, fieldCount)}
|
||||
|
||||
for i := 0; i < fieldCount; i++ {
|
||||
var fd FieldDescription
|
||||
bName, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fd.Name = string(bName[:len(bName)-1])
|
||||
|
||||
// Since buf.Next() doesn't return an error if we hit the end of the buffer
|
||||
// check Len ahead of time
|
||||
if buf.Len() < 18 {
|
||||
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
||||
}
|
||||
|
||||
fd.TableOID = binary.BigEndian.Uint32(buf.Next(4))
|
||||
fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2))
|
||||
fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4))
|
||||
fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
fd.TypeModifier = binary.BigEndian.Uint32(buf.Next(4))
|
||||
fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
|
||||
dst.Fields[i] = fd
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *RowDescription) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte('T')
|
||||
buf.Write(bigEndian.Uint32(0))
|
||||
|
||||
buf.Write(bigEndian.Uint16(uint16(len(src.Fields))))
|
||||
|
||||
for _, fd := range src.Fields {
|
||||
buf.WriteString(fd.Name)
|
||||
buf.WriteByte(0)
|
||||
|
||||
buf.Write(bigEndian.Uint32(fd.TableOID))
|
||||
buf.Write(bigEndian.Uint16(fd.TableAttributeNumber))
|
||||
buf.Write(bigEndian.Uint32(fd.DataTypeOID))
|
||||
buf.Write(bigEndian.Uint16(uint16(fd.DataTypeSize)))
|
||||
buf.Write(bigEndian.Uint32(fd.TypeModifier))
|
||||
buf.Write(bigEndian.Uint16(uint16(fd.Format)))
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (src *RowDescription) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Fields []FieldDescription
|
||||
}{
|
||||
Type: "RowDescription",
|
||||
Fields: src.Fields,
|
||||
})
|
||||
}
|
30
query.go
30
query.go
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/jackc/pgx/internal/sanitize"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
|
@ -41,7 +42,7 @@ func (r *Row) Scan(dest ...interface{}) (err error) {
|
|||
// calling Next() until it returns false, or when a fatal error occurs.
|
||||
type Rows struct {
|
||||
conn *Conn
|
||||
mr *msgReader
|
||||
values [][]byte
|
||||
fields []FieldDescription
|
||||
rowCount int
|
||||
columnIdx int
|
||||
|
@ -115,15 +116,15 @@ func (rows *Rows) Next() bool {
|
|||
rows.columnIdx = 0
|
||||
|
||||
for {
|
||||
t, r, err := rows.conn.rxMsg()
|
||||
msg, err := rows.conn.rxMsg()
|
||||
if err != nil {
|
||||
rows.Fatal(err)
|
||||
return false
|
||||
}
|
||||
|
||||
switch t {
|
||||
case rowDescription:
|
||||
rows.fields = rows.conn.rxRowDescription(r)
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.RowDescription:
|
||||
rows.fields = rows.conn.rxRowDescription(msg)
|
||||
for i := range rows.fields {
|
||||
if dt, ok := rows.conn.ConnInfo.DataTypeForOid(rows.fields[i].DataType); ok {
|
||||
rows.fields[i].DataTypeName = dt.Name
|
||||
|
@ -133,21 +134,20 @@ func (rows *Rows) Next() bool {
|
|||
return false
|
||||
}
|
||||
}
|
||||
case dataRow:
|
||||
fieldCount := r.readInt16()
|
||||
if int(fieldCount) != len(rows.fields) {
|
||||
rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount)))
|
||||
case *pgproto3.DataRow:
|
||||
if len(msg.Values) != len(rows.fields) {
|
||||
rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), len(msg.Values))))
|
||||
return false
|
||||
}
|
||||
|
||||
rows.mr = r
|
||||
rows.values = msg.Values
|
||||
return true
|
||||
case commandComplete:
|
||||
case *pgproto3.CommandComplete:
|
||||
rows.Close()
|
||||
return false
|
||||
|
||||
default:
|
||||
err = rows.conn.processContextFreeMsg(t, r)
|
||||
err = rows.conn.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
rows.Fatal(err)
|
||||
return false
|
||||
|
@ -170,13 +170,9 @@ func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) {
|
|||
return nil, nil, false
|
||||
}
|
||||
|
||||
buf := rows.values[rows.columnIdx]
|
||||
fd := &rows.fields[rows.columnIdx]
|
||||
rows.columnIdx++
|
||||
size := rows.mr.readInt32()
|
||||
var buf []byte
|
||||
if size >= 0 {
|
||||
buf = rows.mr.readBytes(size)
|
||||
}
|
||||
return buf, fd, true
|
||||
}
|
||||
|
||||
|
|
|
@ -2,9 +2,12 @@ package pgx
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -203,59 +206,64 @@ func (rc *ReplicationConn) CauseOfDeath() error {
|
|||
}
|
||||
|
||||
func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) {
|
||||
var t byte
|
||||
var reader *msgReader
|
||||
t, reader, err = rc.c.rxMsg()
|
||||
msg, err := rc.c.rxMsg()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch t {
|
||||
case noticeResponse:
|
||||
pgError := rc.c.rxErrorResponse(reader)
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.NoticeResponse:
|
||||
pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg))
|
||||
if rc.c.shouldLog(LogLevelInfo) {
|
||||
rc.c.log(LogLevelInfo, pgError.Error())
|
||||
}
|
||||
case errorResponse:
|
||||
err = rc.c.rxErrorResponse(reader)
|
||||
case *pgproto3.ErrorResponse:
|
||||
err = rc.c.rxErrorResponse(msg)
|
||||
if rc.c.shouldLog(LogLevelError) {
|
||||
rc.c.log(LogLevelError, err.Error())
|
||||
}
|
||||
return
|
||||
case copyBothResponse:
|
||||
case *pgproto3.CopyBothResponse:
|
||||
// This is the tail end of the replication process start,
|
||||
// and can be safely ignored
|
||||
return
|
||||
case copyData:
|
||||
var msgType byte
|
||||
msgType = reader.readByte()
|
||||
case *pgproto3.CopyData:
|
||||
msgType := msg.Data[0]
|
||||
rp := 1
|
||||
|
||||
switch msgType {
|
||||
case walData:
|
||||
walStart := reader.readInt64()
|
||||
serverWalEnd := reader.readInt64()
|
||||
serverTime := reader.readInt64()
|
||||
walData := reader.readBytes(int32(len(reader.msgBody) - reader.rp))
|
||||
walMessage := WalMessage{WalStart: uint64(walStart),
|
||||
ServerWalEnd: uint64(serverWalEnd),
|
||||
ServerTime: uint64(serverTime),
|
||||
walStart := binary.BigEndian.Uint64(msg.Data[rp:])
|
||||
rp += 8
|
||||
serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:])
|
||||
rp += 8
|
||||
serverTime := binary.BigEndian.Uint64(msg.Data[rp:])
|
||||
rp += 8
|
||||
walData := msg.Data[rp:]
|
||||
walMessage := WalMessage{WalStart: walStart,
|
||||
ServerWalEnd: serverWalEnd,
|
||||
ServerTime: serverTime,
|
||||
WalData: walData,
|
||||
}
|
||||
|
||||
return &ReplicationMessage{WalMessage: &walMessage}, nil
|
||||
case senderKeepalive:
|
||||
serverWalEnd := reader.readInt64()
|
||||
serverTime := reader.readInt64()
|
||||
replyNow := reader.readByte()
|
||||
h := &ServerHeartbeat{ServerWalEnd: uint64(serverWalEnd), ServerTime: uint64(serverTime), ReplyRequested: replyNow}
|
||||
serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:])
|
||||
rp += 8
|
||||
serverTime := binary.BigEndian.Uint64(msg.Data[rp:])
|
||||
rp += 8
|
||||
replyNow := msg.Data[rp]
|
||||
rp += 1
|
||||
h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow}
|
||||
return &ReplicationMessage{ServerHeartbeat: h}, nil
|
||||
default:
|
||||
if rc.c.shouldLog(LogLevelError) {
|
||||
rc.c.log(LogLevelError, "Unexpected data playload message type %v", t)
|
||||
rc.c.log(LogLevelError, "Unexpected data playload message type %v", msgType)
|
||||
}
|
||||
}
|
||||
default:
|
||||
if rc.c.shouldLog(LogLevelError) {
|
||||
rc.c.log(LogLevelError, "Unexpected replication message type %v", t)
|
||||
rc.c.log(LogLevelError, "Unexpected replication message type %T", msg)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
@ -325,21 +333,19 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
|
|||
rows.Fatal(err)
|
||||
}
|
||||
|
||||
var t byte
|
||||
var r *msgReader
|
||||
t, r, err = rc.c.rxMsg()
|
||||
msg, err := rc.c.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case rowDescription:
|
||||
rows.fields = rc.c.rxRowDescription(r)
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.RowDescription:
|
||||
rows.fields = rc.c.rxRowDescription(msg)
|
||||
// We don't have c.PgTypes here because we're a replication
|
||||
// connection. This means the field descriptions will have
|
||||
// only Oids. Not much we can do about this.
|
||||
default:
|
||||
if e := rc.c.processContextFreeMsg(t, r); e != nil {
|
||||
if e := rc.c.processContextFreeMsg(msg); e != nil {
|
||||
rows.Fatal(e)
|
||||
return rows, e
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue