mirror of https://github.com/jackc/pgx.git
Make MsgReader private
parent
f215c8bf5f
commit
4fbd76bee5
82
conn.go
82
conn.go
|
@ -52,7 +52,7 @@ type Conn struct {
|
|||
causeOfDeath error
|
||||
logger log.Logger
|
||||
rows Rows
|
||||
mr MsgReader
|
||||
mr msgReader
|
||||
}
|
||||
|
||||
type PreparedStatement struct {
|
||||
|
@ -172,7 +172,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
|
||||
for {
|
||||
var t byte
|
||||
var r *MsgReader
|
||||
var r *msgReader
|
||||
t, r, err = c.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -278,7 +278,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
|||
|
||||
for {
|
||||
var t byte
|
||||
var r *MsgReader
|
||||
var r *msgReader
|
||||
t, r, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -364,7 +364,7 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
|
|||
}
|
||||
|
||||
var t byte
|
||||
var r *MsgReader
|
||||
var r *msgReader
|
||||
if t, r, err = c.rxMsg(); err == nil {
|
||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||
return nil, err
|
||||
|
@ -544,7 +544,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
|
|||
|
||||
for {
|
||||
var t byte
|
||||
var r *MsgReader
|
||||
var r *msgReader
|
||||
t, r, err = c.rxMsg()
|
||||
if err != nil {
|
||||
return commandTag, err
|
||||
|
@ -558,7 +558,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
|
|||
case dataRow:
|
||||
case bindComplete:
|
||||
case commandComplete:
|
||||
commandTag = CommandTag(r.ReadCString())
|
||||
commandTag = CommandTag(r.readCString())
|
||||
default:
|
||||
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
|
||||
softErr = e
|
||||
|
@ -570,7 +570,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
|
|||
// Processes messages that are not exclusive to one context such as
|
||||
// authentication or query response. The response to these messages
|
||||
// is the same regardless of when they occur.
|
||||
func (c *Conn) processContextFreeMsg(t byte, r *MsgReader) (err error) {
|
||||
func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) {
|
||||
switch t {
|
||||
case 'S':
|
||||
c.rxParameterStatus(r)
|
||||
|
@ -587,7 +587,7 @@ func (c *Conn) processContextFreeMsg(t byte, r *MsgReader) (err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Conn) rxMsg() (t byte, r *MsgReader, err error) {
|
||||
func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
|
||||
if !c.alive {
|
||||
return 0, nil, ErrDeadConn
|
||||
}
|
||||
|
@ -600,13 +600,13 @@ func (c *Conn) rxMsg() (t byte, r *MsgReader, err error) {
|
|||
return t, &c.mr, err
|
||||
}
|
||||
|
||||
func (c *Conn) rxAuthenticationX(r *MsgReader) (err error) {
|
||||
switch r.ReadInt32() {
|
||||
func (c *Conn) rxAuthenticationX(r *msgReader) (err error) {
|
||||
switch r.readInt32() {
|
||||
case 0: // AuthenticationOk
|
||||
case 3: // AuthenticationCleartextPassword
|
||||
err = c.txPasswordMessage(c.config.Password)
|
||||
case 5: // AuthenticationMD5Password
|
||||
salt := r.ReadString(4)
|
||||
salt := r.readString(4)
|
||||
digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+salt)
|
||||
err = c.txPasswordMessage(digestedPassword)
|
||||
default:
|
||||
|
@ -622,72 +622,72 @@ func hexMD5(s string) string {
|
|||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
func (c *Conn) rxParameterStatus(r *MsgReader) {
|
||||
key := r.ReadCString()
|
||||
value := r.ReadCString()
|
||||
func (c *Conn) rxParameterStatus(r *msgReader) {
|
||||
key := r.readCString()
|
||||
value := r.readCString()
|
||||
c.RuntimeParams[key] = value
|
||||
}
|
||||
|
||||
func (c *Conn) rxErrorResponse(r *MsgReader) (err PgError) {
|
||||
func (c *Conn) rxErrorResponse(r *msgReader) (err PgError) {
|
||||
for {
|
||||
switch r.ReadByte() {
|
||||
switch r.readByte() {
|
||||
case 'S':
|
||||
err.Severity = r.ReadCString()
|
||||
err.Severity = r.readCString()
|
||||
case 'C':
|
||||
err.Code = r.ReadCString()
|
||||
err.Code = r.readCString()
|
||||
case 'M':
|
||||
err.Message = r.ReadCString()
|
||||
err.Message = r.readCString()
|
||||
case 0: // End of error message
|
||||
if err.Severity == "FATAL" {
|
||||
c.die(err)
|
||||
}
|
||||
return
|
||||
default: // Ignore other error fields
|
||||
r.ReadCString()
|
||||
r.readCString()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) rxBackendKeyData(r *MsgReader) {
|
||||
c.Pid = r.ReadInt32()
|
||||
c.SecretKey = r.ReadInt32()
|
||||
func (c *Conn) rxBackendKeyData(r *msgReader) {
|
||||
c.Pid = r.readInt32()
|
||||
c.SecretKey = r.readInt32()
|
||||
}
|
||||
|
||||
func (c *Conn) rxReadyForQuery(r *MsgReader) {
|
||||
c.TxStatus = r.ReadByte()
|
||||
func (c *Conn) rxReadyForQuery(r *msgReader) {
|
||||
c.TxStatus = r.readByte()
|
||||
}
|
||||
|
||||
func (c *Conn) rxRowDescription(r *MsgReader) (fields []FieldDescription) {
|
||||
fieldCount := r.ReadInt16()
|
||||
func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) {
|
||||
fieldCount := r.readInt16()
|
||||
fields = make([]FieldDescription, fieldCount)
|
||||
for i := int16(0); i < fieldCount; i++ {
|
||||
f := &fields[i]
|
||||
f.Name = r.ReadCString()
|
||||
f.Table = r.ReadOid()
|
||||
f.AttributeNumber = r.ReadInt16()
|
||||
f.DataType = r.ReadOid()
|
||||
f.DataTypeSize = r.ReadInt16()
|
||||
f.Modifier = r.ReadInt32()
|
||||
f.FormatCode = r.ReadInt16()
|
||||
f.Name = r.readCString()
|
||||
f.Table = r.readOid()
|
||||
f.AttributeNumber = r.readInt16()
|
||||
f.DataType = r.readOid()
|
||||
f.DataTypeSize = r.readInt16()
|
||||
f.Modifier = r.readInt32()
|
||||
f.FormatCode = r.readInt16()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) rxParameterDescription(r *MsgReader) (parameters []Oid) {
|
||||
parameterCount := r.ReadInt16()
|
||||
func (c *Conn) rxParameterDescription(r *msgReader) (parameters []Oid) {
|
||||
parameterCount := r.readInt16()
|
||||
parameters = make([]Oid, 0, parameterCount)
|
||||
|
||||
for i := int16(0); i < parameterCount; i++ {
|
||||
parameters = append(parameters, r.ReadOid())
|
||||
parameters = append(parameters, r.readOid())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) rxNotificationResponse(r *MsgReader) {
|
||||
func (c *Conn) rxNotificationResponse(r *msgReader) {
|
||||
n := new(Notification)
|
||||
n.Pid = r.ReadInt32()
|
||||
n.Channel = r.ReadCString()
|
||||
n.Payload = r.ReadCString()
|
||||
n.Pid = r.readInt32()
|
||||
n.Channel = r.readCString()
|
||||
n.Payload = r.readCString()
|
||||
c.notifications = append(c.notifications, n)
|
||||
}
|
||||
|
||||
|
|
|
@ -8,26 +8,26 @@ import (
|
|||
"io/ioutil"
|
||||
)
|
||||
|
||||
// MsgReader is a helper that reads values from a PostgreSQL message.
|
||||
type MsgReader struct {
|
||||
// msgReader is a helper that reads values from a PostgreSQL message.
|
||||
type msgReader struct {
|
||||
reader *bufio.Reader
|
||||
buf [128]byte
|
||||
msgBytesRemaining int32
|
||||
err error
|
||||
}
|
||||
|
||||
// Err returns any error that the MsgReader has experienced
|
||||
func (r *MsgReader) Err() error {
|
||||
// Err returns any error that the msgReader has experienced
|
||||
func (r *msgReader) Err() error {
|
||||
return r.err
|
||||
}
|
||||
|
||||
// Fatal tells r that a Fatal error has occurred
|
||||
func (r *MsgReader) Fatal(err error) {
|
||||
// fatal tells r that a Fatal error has occurred
|
||||
func (r *msgReader) fatal(err error) {
|
||||
r.err = err
|
||||
}
|
||||
|
||||
// rxMsg reads the type and size of the next message.
|
||||
func (r *MsgReader) rxMsg() (t byte, err error) {
|
||||
func (r *msgReader) rxMsg() (t byte, err error) {
|
||||
if r.err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -43,123 +43,123 @@ func (r *MsgReader) rxMsg() (t byte, err error) {
|
|||
return t, err
|
||||
}
|
||||
|
||||
func (r *MsgReader) ReadByte() byte {
|
||||
func (r *msgReader) readByte() byte {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 1
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.ReadByte()
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *MsgReader) ReadInt16() int16 {
|
||||
func (r *msgReader) readInt16() int16 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 2
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b := r.buf[0:2]
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return int16(binary.BigEndian.Uint16(b))
|
||||
}
|
||||
|
||||
func (r *MsgReader) ReadInt32() int32 {
|
||||
func (r *msgReader) readInt32() int32 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 4
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b := r.buf[0:4]
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return int32(binary.BigEndian.Uint32(b))
|
||||
}
|
||||
|
||||
func (r *MsgReader) ReadInt64() int64 {
|
||||
func (r *msgReader) readInt64() int64 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 8
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b := r.buf[0:8]
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return int64(binary.BigEndian.Uint64(b))
|
||||
}
|
||||
|
||||
func (r *MsgReader) ReadOid() Oid {
|
||||
return Oid(r.ReadInt32())
|
||||
func (r *msgReader) readOid() Oid {
|
||||
return Oid(r.readInt32())
|
||||
}
|
||||
|
||||
// ReadCString reads a null terminated string
|
||||
func (r *MsgReader) ReadCString() string {
|
||||
// readCString reads a null terminated string
|
||||
func (r *msgReader) readCString() string {
|
||||
if r.err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
b, err := r.reader.ReadBytes(0)
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
r.fatal(err)
|
||||
return ""
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= int32(len(b))
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(b[0 : len(b)-1])
|
||||
}
|
||||
|
||||
// ReadString reads count bytes and returns as string
|
||||
func (r *MsgReader) ReadString(count int32) string {
|
||||
// readString reads count bytes and returns as string
|
||||
func (r *msgReader) readString(count int32) string {
|
||||
if r.err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= count
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return ""
|
||||
}
|
||||
|
||||
|
@ -172,22 +172,22 @@ func (r *MsgReader) ReadString(count int32) string {
|
|||
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
r.fatal(err)
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// ReadBytes reads count bytes and returns as []byte
|
||||
func (r *MsgReader) ReadBytes(count int32) []byte {
|
||||
// readBytes reads count bytes and returns as []byte
|
||||
func (r *msgReader) readBytes(count int32) []byte {
|
||||
if r.err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= count
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -195,7 +195,7 @@ func (r *MsgReader) ReadBytes(count int32) []byte {
|
|||
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
r.fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
6
query.go
6
query.go
|
@ -31,7 +31,7 @@ func (r *Row) Scan(dest ...interface{}) (err error) {
|
|||
type Rows struct {
|
||||
pool *ConnPool
|
||||
conn *Conn
|
||||
mr *MsgReader
|
||||
mr *msgReader
|
||||
fields []FieldDescription
|
||||
vr ValueReader
|
||||
rowCount int
|
||||
|
@ -134,7 +134,7 @@ func (rows *Rows) Next() bool {
|
|||
rows.close()
|
||||
return false
|
||||
case dataRow:
|
||||
fieldCount := r.ReadInt16()
|
||||
fieldCount := r.readInt16()
|
||||
if int(fieldCount) != len(rows.fields) {
|
||||
rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount)))
|
||||
return false
|
||||
|
@ -165,7 +165,7 @@ func (rows *Rows) nextColumn() (*ValueReader, bool) {
|
|||
|
||||
fd := &rows.fields[rows.columnIdx]
|
||||
rows.columnIdx++
|
||||
size := rows.mr.ReadInt32()
|
||||
size := rows.mr.readInt32()
|
||||
rows.vr = ValueReader{mr: rows.mr, fd: fd, valueBytesRemaining: size}
|
||||
return &rows.vr, true
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
// ValueReader the mechanism for implementing the BinaryDecoder interface.
|
||||
type ValueReader struct {
|
||||
mr *MsgReader
|
||||
mr *msgReader
|
||||
fd *FieldDescription
|
||||
valueBytesRemaining int32
|
||||
err error
|
||||
|
@ -43,7 +43,7 @@ func (r *ValueReader) ReadByte() byte {
|
|||
return 0
|
||||
}
|
||||
|
||||
return r.mr.ReadByte()
|
||||
return r.mr.readByte()
|
||||
}
|
||||
|
||||
func (r *ValueReader) ReadInt16() int16 {
|
||||
|
@ -57,7 +57,7 @@ func (r *ValueReader) ReadInt16() int16 {
|
|||
return 0
|
||||
}
|
||||
|
||||
return r.mr.ReadInt16()
|
||||
return r.mr.readInt16()
|
||||
}
|
||||
|
||||
func (r *ValueReader) ReadInt32() int32 {
|
||||
|
@ -71,7 +71,7 @@ func (r *ValueReader) ReadInt32() int32 {
|
|||
return 0
|
||||
}
|
||||
|
||||
return r.mr.ReadInt32()
|
||||
return r.mr.readInt32()
|
||||
}
|
||||
|
||||
func (r *ValueReader) ReadInt64() int64 {
|
||||
|
@ -85,7 +85,7 @@ func (r *ValueReader) ReadInt64() int64 {
|
|||
return 0
|
||||
}
|
||||
|
||||
return r.mr.ReadInt64()
|
||||
return r.mr.readInt64()
|
||||
}
|
||||
|
||||
func (r *ValueReader) ReadOid() Oid {
|
||||
|
@ -104,7 +104,7 @@ func (r *ValueReader) ReadString(count int32) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
return r.mr.ReadString(count)
|
||||
return r.mr.readString(count)
|
||||
}
|
||||
|
||||
// ReadBytes reads count bytes and returns as []byte
|
||||
|
@ -119,5 +119,5 @@ func (r *ValueReader) ReadBytes(count int32) []byte {
|
|||
return nil
|
||||
}
|
||||
|
||||
return r.mr.ReadBytes(count)
|
||||
return r.mr.readBytes(count)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue