Make MsgReader private

scan-io
Jack Christensen 2014-07-12 20:08:17 -05:00
parent f215c8bf5f
commit 4fbd76bee5
4 changed files with 84 additions and 84 deletions

82
conn.go
View File

@ -52,7 +52,7 @@ type Conn struct {
causeOfDeath error causeOfDeath error
logger log.Logger logger log.Logger
rows Rows rows Rows
mr MsgReader mr msgReader
} }
type PreparedStatement struct { type PreparedStatement struct {
@ -172,7 +172,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
for { for {
var t byte var t byte
var r *MsgReader var r *msgReader
t, r, err = c.rxMsg() t, r, err = c.rxMsg()
if err != nil { if err != nil {
return nil, err return nil, err
@ -278,7 +278,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
for { for {
var t byte var t byte
var r *MsgReader var r *msgReader
t, r, err := c.rxMsg() t, r, err := c.rxMsg()
if err != nil { if err != nil {
return nil, err return nil, err
@ -364,7 +364,7 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
} }
var t byte var t byte
var r *MsgReader var r *msgReader
if t, r, err = c.rxMsg(); err == nil { if t, r, err = c.rxMsg(); err == nil {
if err = c.processContextFreeMsg(t, r); err != nil { if err = c.processContextFreeMsg(t, r); err != nil {
return nil, err return nil, err
@ -544,7 +544,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
for { for {
var t byte var t byte
var r *MsgReader var r *msgReader
t, r, err = c.rxMsg() t, r, err = c.rxMsg()
if err != nil { if err != nil {
return commandTag, err return commandTag, err
@ -558,7 +558,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
case dataRow: case dataRow:
case bindComplete: case bindComplete:
case commandComplete: case commandComplete:
commandTag = CommandTag(r.ReadCString()) commandTag = CommandTag(r.readCString())
default: default:
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
softErr = e softErr = e
@ -570,7 +570,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
// Processes messages that are not exclusive to one context such as // Processes messages that are not exclusive to one context such as
// authentication or query response. The response to these messages // authentication or query response. The response to these messages
// is the same regardless of when they occur. // is the same regardless of when they occur.
func (c *Conn) processContextFreeMsg(t byte, r *MsgReader) (err error) { func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) {
switch t { switch t {
case 'S': case 'S':
c.rxParameterStatus(r) c.rxParameterStatus(r)
@ -587,7 +587,7 @@ func (c *Conn) processContextFreeMsg(t byte, r *MsgReader) (err error) {
} }
} }
func (c *Conn) rxMsg() (t byte, r *MsgReader, err error) { func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
if !c.alive { if !c.alive {
return 0, nil, ErrDeadConn return 0, nil, ErrDeadConn
} }
@ -600,13 +600,13 @@ func (c *Conn) rxMsg() (t byte, r *MsgReader, err error) {
return t, &c.mr, err return t, &c.mr, err
} }
func (c *Conn) rxAuthenticationX(r *MsgReader) (err error) { func (c *Conn) rxAuthenticationX(r *msgReader) (err error) {
switch r.ReadInt32() { switch r.readInt32() {
case 0: // AuthenticationOk case 0: // AuthenticationOk
case 3: // AuthenticationCleartextPassword case 3: // AuthenticationCleartextPassword
err = c.txPasswordMessage(c.config.Password) err = c.txPasswordMessage(c.config.Password)
case 5: // AuthenticationMD5Password case 5: // AuthenticationMD5Password
salt := r.ReadString(4) salt := r.readString(4)
digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+salt) digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+salt)
err = c.txPasswordMessage(digestedPassword) err = c.txPasswordMessage(digestedPassword)
default: default:
@ -622,72 +622,72 @@ func hexMD5(s string) string {
return hex.EncodeToString(hash.Sum(nil)) return hex.EncodeToString(hash.Sum(nil))
} }
func (c *Conn) rxParameterStatus(r *MsgReader) { func (c *Conn) rxParameterStatus(r *msgReader) {
key := r.ReadCString() key := r.readCString()
value := r.ReadCString() value := r.readCString()
c.RuntimeParams[key] = value c.RuntimeParams[key] = value
} }
func (c *Conn) rxErrorResponse(r *MsgReader) (err PgError) { func (c *Conn) rxErrorResponse(r *msgReader) (err PgError) {
for { for {
switch r.ReadByte() { switch r.readByte() {
case 'S': case 'S':
err.Severity = r.ReadCString() err.Severity = r.readCString()
case 'C': case 'C':
err.Code = r.ReadCString() err.Code = r.readCString()
case 'M': case 'M':
err.Message = r.ReadCString() err.Message = r.readCString()
case 0: // End of error message case 0: // End of error message
if err.Severity == "FATAL" { if err.Severity == "FATAL" {
c.die(err) c.die(err)
} }
return return
default: // Ignore other error fields default: // Ignore other error fields
r.ReadCString() r.readCString()
} }
} }
} }
func (c *Conn) rxBackendKeyData(r *MsgReader) { func (c *Conn) rxBackendKeyData(r *msgReader) {
c.Pid = r.ReadInt32() c.Pid = r.readInt32()
c.SecretKey = r.ReadInt32() c.SecretKey = r.readInt32()
} }
func (c *Conn) rxReadyForQuery(r *MsgReader) { func (c *Conn) rxReadyForQuery(r *msgReader) {
c.TxStatus = r.ReadByte() c.TxStatus = r.readByte()
} }
func (c *Conn) rxRowDescription(r *MsgReader) (fields []FieldDescription) { func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) {
fieldCount := r.ReadInt16() fieldCount := r.readInt16()
fields = make([]FieldDescription, fieldCount) fields = make([]FieldDescription, fieldCount)
for i := int16(0); i < fieldCount; i++ { for i := int16(0); i < fieldCount; i++ {
f := &fields[i] f := &fields[i]
f.Name = r.ReadCString() f.Name = r.readCString()
f.Table = r.ReadOid() f.Table = r.readOid()
f.AttributeNumber = r.ReadInt16() f.AttributeNumber = r.readInt16()
f.DataType = r.ReadOid() f.DataType = r.readOid()
f.DataTypeSize = r.ReadInt16() f.DataTypeSize = r.readInt16()
f.Modifier = r.ReadInt32() f.Modifier = r.readInt32()
f.FormatCode = r.ReadInt16() f.FormatCode = r.readInt16()
} }
return return
} }
func (c *Conn) rxParameterDescription(r *MsgReader) (parameters []Oid) { func (c *Conn) rxParameterDescription(r *msgReader) (parameters []Oid) {
parameterCount := r.ReadInt16() parameterCount := r.readInt16()
parameters = make([]Oid, 0, parameterCount) parameters = make([]Oid, 0, parameterCount)
for i := int16(0); i < parameterCount; i++ { for i := int16(0); i < parameterCount; i++ {
parameters = append(parameters, r.ReadOid()) parameters = append(parameters, r.readOid())
} }
return return
} }
func (c *Conn) rxNotificationResponse(r *MsgReader) { func (c *Conn) rxNotificationResponse(r *msgReader) {
n := new(Notification) n := new(Notification)
n.Pid = r.ReadInt32() n.Pid = r.readInt32()
n.Channel = r.ReadCString() n.Channel = r.readCString()
n.Payload = r.ReadCString() n.Payload = r.readCString()
c.notifications = append(c.notifications, n) c.notifications = append(c.notifications, n)
} }

View File

@ -8,26 +8,26 @@ import (
"io/ioutil" "io/ioutil"
) )
// MsgReader is a helper that reads values from a PostgreSQL message. // msgReader is a helper that reads values from a PostgreSQL message.
type MsgReader struct { type msgReader struct {
reader *bufio.Reader reader *bufio.Reader
buf [128]byte buf [128]byte
msgBytesRemaining int32 msgBytesRemaining int32
err error err error
} }
// Err returns any error that the MsgReader has experienced // Err returns any error that the msgReader has experienced
func (r *MsgReader) Err() error { func (r *msgReader) Err() error {
return r.err return r.err
} }
// Fatal tells r that a Fatal error has occurred // fatal tells r that a Fatal error has occurred
func (r *MsgReader) Fatal(err error) { func (r *msgReader) fatal(err error) {
r.err = err r.err = err
} }
// rxMsg reads the type and size of the next message. // rxMsg reads the type and size of the next message.
func (r *MsgReader) rxMsg() (t byte, err error) { func (r *msgReader) rxMsg() (t byte, err error) {
if r.err != nil { if r.err != nil {
return 0, err return 0, err
} }
@ -43,123 +43,123 @@ func (r *MsgReader) rxMsg() (t byte, err error) {
return t, err return t, err
} }
func (r *MsgReader) ReadByte() byte { func (r *msgReader) readByte() byte {
if r.err != nil { if r.err != nil {
return 0 return 0
} }
r.msgBytesRemaining -= 1 r.msgBytesRemaining -= 1
if r.msgBytesRemaining < 0 { if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return 0 return 0
} }
b, err := r.reader.ReadByte() b, err := r.reader.ReadByte()
if err != nil { if err != nil {
r.Fatal(err) r.fatal(err)
return 0 return 0
} }
return b return b
} }
func (r *MsgReader) ReadInt16() int16 { func (r *msgReader) readInt16() int16 {
if r.err != nil { if r.err != nil {
return 0 return 0
} }
r.msgBytesRemaining -= 2 r.msgBytesRemaining -= 2
if r.msgBytesRemaining < 0 { if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return 0 return 0
} }
b := r.buf[0:2] b := r.buf[0:2]
_, err := io.ReadFull(r.reader, b) _, err := io.ReadFull(r.reader, b)
if err != nil { if err != nil {
r.Fatal(err) r.fatal(err)
return 0 return 0
} }
return int16(binary.BigEndian.Uint16(b)) return int16(binary.BigEndian.Uint16(b))
} }
func (r *MsgReader) ReadInt32() int32 { func (r *msgReader) readInt32() int32 {
if r.err != nil { if r.err != nil {
return 0 return 0
} }
r.msgBytesRemaining -= 4 r.msgBytesRemaining -= 4
if r.msgBytesRemaining < 0 { if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return 0 return 0
} }
b := r.buf[0:4] b := r.buf[0:4]
_, err := io.ReadFull(r.reader, b) _, err := io.ReadFull(r.reader, b)
if err != nil { if err != nil {
r.Fatal(err) r.fatal(err)
return 0 return 0
} }
return int32(binary.BigEndian.Uint32(b)) return int32(binary.BigEndian.Uint32(b))
} }
func (r *MsgReader) ReadInt64() int64 { func (r *msgReader) readInt64() int64 {
if r.err != nil { if r.err != nil {
return 0 return 0
} }
r.msgBytesRemaining -= 8 r.msgBytesRemaining -= 8
if r.msgBytesRemaining < 0 { if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return 0 return 0
} }
b := r.buf[0:8] b := r.buf[0:8]
_, err := io.ReadFull(r.reader, b) _, err := io.ReadFull(r.reader, b)
if err != nil { if err != nil {
r.Fatal(err) r.fatal(err)
return 0 return 0
} }
return int64(binary.BigEndian.Uint64(b)) return int64(binary.BigEndian.Uint64(b))
} }
func (r *MsgReader) ReadOid() Oid { func (r *msgReader) readOid() Oid {
return Oid(r.ReadInt32()) return Oid(r.readInt32())
} }
// ReadCString reads a null terminated string // readCString reads a null terminated string
func (r *MsgReader) ReadCString() string { func (r *msgReader) readCString() string {
if r.err != nil { if r.err != nil {
return "" return ""
} }
b, err := r.reader.ReadBytes(0) b, err := r.reader.ReadBytes(0)
if err != nil { if err != nil {
r.Fatal(err) r.fatal(err)
return "" return ""
} }
r.msgBytesRemaining -= int32(len(b)) r.msgBytesRemaining -= int32(len(b))
if r.msgBytesRemaining < 0 { if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return "" return ""
} }
return string(b[0 : len(b)-1]) return string(b[0 : len(b)-1])
} }
// ReadString reads count bytes and returns as string // readString reads count bytes and returns as string
func (r *MsgReader) ReadString(count int32) string { func (r *msgReader) readString(count int32) string {
if r.err != nil { if r.err != nil {
return "" return ""
} }
r.msgBytesRemaining -= count r.msgBytesRemaining -= count
if r.msgBytesRemaining < 0 { if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return "" return ""
} }
@ -172,22 +172,22 @@ func (r *MsgReader) ReadString(count int32) string {
_, err := io.ReadFull(r.reader, b) _, err := io.ReadFull(r.reader, b)
if err != nil { if err != nil {
r.Fatal(err) r.fatal(err)
return "" return ""
} }
return string(b) return string(b)
} }
// ReadBytes reads count bytes and returns as []byte // readBytes reads count bytes and returns as []byte
func (r *MsgReader) ReadBytes(count int32) []byte { func (r *msgReader) readBytes(count int32) []byte {
if r.err != nil { if r.err != nil {
return nil return nil
} }
r.msgBytesRemaining -= count r.msgBytesRemaining -= count
if r.msgBytesRemaining < 0 { if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return nil return nil
} }
@ -195,7 +195,7 @@ func (r *MsgReader) ReadBytes(count int32) []byte {
_, err := io.ReadFull(r.reader, b) _, err := io.ReadFull(r.reader, b)
if err != nil { if err != nil {
r.Fatal(err) r.fatal(err)
return nil return nil
} }

View File

@ -31,7 +31,7 @@ func (r *Row) Scan(dest ...interface{}) (err error) {
type Rows struct { type Rows struct {
pool *ConnPool pool *ConnPool
conn *Conn conn *Conn
mr *MsgReader mr *msgReader
fields []FieldDescription fields []FieldDescription
vr ValueReader vr ValueReader
rowCount int rowCount int
@ -134,7 +134,7 @@ func (rows *Rows) Next() bool {
rows.close() rows.close()
return false return false
case dataRow: case dataRow:
fieldCount := r.ReadInt16() fieldCount := r.readInt16()
if int(fieldCount) != len(rows.fields) { if int(fieldCount) != len(rows.fields) {
rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount))) rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount)))
return false return false
@ -165,7 +165,7 @@ func (rows *Rows) nextColumn() (*ValueReader, bool) {
fd := &rows.fields[rows.columnIdx] fd := &rows.fields[rows.columnIdx]
rows.columnIdx++ rows.columnIdx++
size := rows.mr.ReadInt32() size := rows.mr.readInt32()
rows.vr = ValueReader{mr: rows.mr, fd: fd, valueBytesRemaining: size} rows.vr = ValueReader{mr: rows.mr, fd: fd, valueBytesRemaining: size}
return &rows.vr, true return &rows.vr, true
} }

View File

@ -6,7 +6,7 @@ import (
// ValueReader the mechanism for implementing the BinaryDecoder interface. // ValueReader the mechanism for implementing the BinaryDecoder interface.
type ValueReader struct { type ValueReader struct {
mr *MsgReader mr *msgReader
fd *FieldDescription fd *FieldDescription
valueBytesRemaining int32 valueBytesRemaining int32
err error err error
@ -43,7 +43,7 @@ func (r *ValueReader) ReadByte() byte {
return 0 return 0
} }
return r.mr.ReadByte() return r.mr.readByte()
} }
func (r *ValueReader) ReadInt16() int16 { func (r *ValueReader) ReadInt16() int16 {
@ -57,7 +57,7 @@ func (r *ValueReader) ReadInt16() int16 {
return 0 return 0
} }
return r.mr.ReadInt16() return r.mr.readInt16()
} }
func (r *ValueReader) ReadInt32() int32 { func (r *ValueReader) ReadInt32() int32 {
@ -71,7 +71,7 @@ func (r *ValueReader) ReadInt32() int32 {
return 0 return 0
} }
return r.mr.ReadInt32() return r.mr.readInt32()
} }
func (r *ValueReader) ReadInt64() int64 { func (r *ValueReader) ReadInt64() int64 {
@ -85,7 +85,7 @@ func (r *ValueReader) ReadInt64() int64 {
return 0 return 0
} }
return r.mr.ReadInt64() return r.mr.readInt64()
} }
func (r *ValueReader) ReadOid() Oid { func (r *ValueReader) ReadOid() Oid {
@ -104,7 +104,7 @@ func (r *ValueReader) ReadString(count int32) string {
return "" return ""
} }
return r.mr.ReadString(count) return r.mr.readString(count)
} }
// ReadBytes reads count bytes and returns as []byte // ReadBytes reads count bytes and returns as []byte
@ -119,5 +119,5 @@ func (r *ValueReader) ReadBytes(count int32) []byte {
return nil return nil
} }
return r.mr.ReadBytes(count) return r.mr.readBytes(count)
} }