diff --git a/conn.go b/conn.go index e41d4869..523a0fef 100644 --- a/conn.go +++ b/conn.go @@ -52,7 +52,7 @@ type Conn struct { causeOfDeath error logger log.Logger rows Rows - mr MsgReader + mr msgReader } type PreparedStatement struct { @@ -172,7 +172,7 @@ func Connect(config ConnConfig) (c *Conn, err error) { for { var t byte - var r *MsgReader + var r *msgReader t, r, err = c.rxMsg() if err != nil { return nil, err @@ -278,7 +278,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { for { var t byte - var r *MsgReader + var r *msgReader t, r, err := c.rxMsg() if err != nil { return nil, err @@ -364,7 +364,7 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) } var t byte - var r *MsgReader + var r *msgReader if t, r, err = c.rxMsg(); err == nil { if err = c.processContextFreeMsg(t, r); err != nil { return nil, err @@ -544,7 +544,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag for { var t byte - var r *MsgReader + var r *msgReader t, r, err = c.rxMsg() if err != nil { return commandTag, err @@ -558,7 +558,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag case dataRow: case bindComplete: case commandComplete: - commandTag = CommandTag(r.ReadCString()) + commandTag = CommandTag(r.readCString()) default: if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { softErr = e @@ -570,7 +570,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag // Processes messages that are not exclusive to one context such as // authentication or query response. The response to these messages // is the same regardless of when they occur. -func (c *Conn) processContextFreeMsg(t byte, r *MsgReader) (err error) { +func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { switch t { case 'S': c.rxParameterStatus(r) @@ -587,7 +587,7 @@ func (c *Conn) processContextFreeMsg(t byte, r *MsgReader) (err error) { } } -func (c *Conn) rxMsg() (t byte, r *MsgReader, err error) { +func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { if !c.alive { return 0, nil, ErrDeadConn } @@ -600,13 +600,13 @@ func (c *Conn) rxMsg() (t byte, r *MsgReader, err error) { return t, &c.mr, err } -func (c *Conn) rxAuthenticationX(r *MsgReader) (err error) { - switch r.ReadInt32() { +func (c *Conn) rxAuthenticationX(r *msgReader) (err error) { + switch r.readInt32() { case 0: // AuthenticationOk case 3: // AuthenticationCleartextPassword err = c.txPasswordMessage(c.config.Password) case 5: // AuthenticationMD5Password - salt := r.ReadString(4) + salt := r.readString(4) digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+salt) err = c.txPasswordMessage(digestedPassword) default: @@ -622,72 +622,72 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } -func (c *Conn) rxParameterStatus(r *MsgReader) { - key := r.ReadCString() - value := r.ReadCString() +func (c *Conn) rxParameterStatus(r *msgReader) { + key := r.readCString() + value := r.readCString() c.RuntimeParams[key] = value } -func (c *Conn) rxErrorResponse(r *MsgReader) (err PgError) { +func (c *Conn) rxErrorResponse(r *msgReader) (err PgError) { for { - switch r.ReadByte() { + switch r.readByte() { case 'S': - err.Severity = r.ReadCString() + err.Severity = r.readCString() case 'C': - err.Code = r.ReadCString() + err.Code = r.readCString() case 'M': - err.Message = r.ReadCString() + err.Message = r.readCString() case 0: // End of error message if err.Severity == "FATAL" { c.die(err) } return default: // Ignore other error fields - r.ReadCString() + r.readCString() } } } -func (c *Conn) rxBackendKeyData(r *MsgReader) { - c.Pid = r.ReadInt32() - c.SecretKey = r.ReadInt32() +func (c *Conn) rxBackendKeyData(r *msgReader) { + c.Pid = r.readInt32() + c.SecretKey = r.readInt32() } -func (c *Conn) rxReadyForQuery(r *MsgReader) { - c.TxStatus = r.ReadByte() +func (c *Conn) rxReadyForQuery(r *msgReader) { + c.TxStatus = r.readByte() } -func (c *Conn) rxRowDescription(r *MsgReader) (fields []FieldDescription) { - fieldCount := r.ReadInt16() +func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { + fieldCount := r.readInt16() fields = make([]FieldDescription, fieldCount) for i := int16(0); i < fieldCount; i++ { f := &fields[i] - f.Name = r.ReadCString() - f.Table = r.ReadOid() - f.AttributeNumber = r.ReadInt16() - f.DataType = r.ReadOid() - f.DataTypeSize = r.ReadInt16() - f.Modifier = r.ReadInt32() - f.FormatCode = r.ReadInt16() + f.Name = r.readCString() + f.Table = r.readOid() + f.AttributeNumber = r.readInt16() + f.DataType = r.readOid() + f.DataTypeSize = r.readInt16() + f.Modifier = r.readInt32() + f.FormatCode = r.readInt16() } return } -func (c *Conn) rxParameterDescription(r *MsgReader) (parameters []Oid) { - parameterCount := r.ReadInt16() +func (c *Conn) rxParameterDescription(r *msgReader) (parameters []Oid) { + parameterCount := r.readInt16() parameters = make([]Oid, 0, parameterCount) for i := int16(0); i < parameterCount; i++ { - parameters = append(parameters, r.ReadOid()) + parameters = append(parameters, r.readOid()) } return } -func (c *Conn) rxNotificationResponse(r *MsgReader) { +func (c *Conn) rxNotificationResponse(r *msgReader) { n := new(Notification) - n.Pid = r.ReadInt32() - n.Channel = r.ReadCString() - n.Payload = r.ReadCString() + n.Pid = r.readInt32() + n.Channel = r.readCString() + n.Payload = r.readCString() c.notifications = append(c.notifications, n) } diff --git a/msg_reader.go b/msg_reader.go index 32baac75..c958253c 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -8,26 +8,26 @@ import ( "io/ioutil" ) -// MsgReader is a helper that reads values from a PostgreSQL message. -type MsgReader struct { +// msgReader is a helper that reads values from a PostgreSQL message. +type msgReader struct { reader *bufio.Reader buf [128]byte msgBytesRemaining int32 err error } -// Err returns any error that the MsgReader has experienced -func (r *MsgReader) Err() error { +// Err returns any error that the msgReader has experienced +func (r *msgReader) Err() error { return r.err } -// Fatal tells r that a Fatal error has occurred -func (r *MsgReader) Fatal(err error) { +// fatal tells r that a Fatal error has occurred +func (r *msgReader) fatal(err error) { r.err = err } // rxMsg reads the type and size of the next message. -func (r *MsgReader) rxMsg() (t byte, err error) { +func (r *msgReader) rxMsg() (t byte, err error) { if r.err != nil { return 0, err } @@ -43,123 +43,123 @@ func (r *MsgReader) rxMsg() (t byte, err error) { return t, err } -func (r *MsgReader) ReadByte() byte { +func (r *msgReader) readByte() byte { if r.err != nil { return 0 } r.msgBytesRemaining -= 1 if r.msgBytesRemaining < 0 { - r.Fatal(errors.New("read past end of message")) + r.fatal(errors.New("read past end of message")) return 0 } b, err := r.reader.ReadByte() if err != nil { - r.Fatal(err) + r.fatal(err) return 0 } return b } -func (r *MsgReader) ReadInt16() int16 { +func (r *msgReader) readInt16() int16 { if r.err != nil { return 0 } r.msgBytesRemaining -= 2 if r.msgBytesRemaining < 0 { - r.Fatal(errors.New("read past end of message")) + r.fatal(errors.New("read past end of message")) return 0 } b := r.buf[0:2] _, err := io.ReadFull(r.reader, b) if err != nil { - r.Fatal(err) + r.fatal(err) return 0 } return int16(binary.BigEndian.Uint16(b)) } -func (r *MsgReader) ReadInt32() int32 { +func (r *msgReader) readInt32() int32 { if r.err != nil { return 0 } r.msgBytesRemaining -= 4 if r.msgBytesRemaining < 0 { - r.Fatal(errors.New("read past end of message")) + r.fatal(errors.New("read past end of message")) return 0 } b := r.buf[0:4] _, err := io.ReadFull(r.reader, b) if err != nil { - r.Fatal(err) + r.fatal(err) return 0 } return int32(binary.BigEndian.Uint32(b)) } -func (r *MsgReader) ReadInt64() int64 { +func (r *msgReader) readInt64() int64 { if r.err != nil { return 0 } r.msgBytesRemaining -= 8 if r.msgBytesRemaining < 0 { - r.Fatal(errors.New("read past end of message")) + r.fatal(errors.New("read past end of message")) return 0 } b := r.buf[0:8] _, err := io.ReadFull(r.reader, b) if err != nil { - r.Fatal(err) + r.fatal(err) return 0 } return int64(binary.BigEndian.Uint64(b)) } -func (r *MsgReader) ReadOid() Oid { - return Oid(r.ReadInt32()) +func (r *msgReader) readOid() Oid { + return Oid(r.readInt32()) } -// ReadCString reads a null terminated string -func (r *MsgReader) ReadCString() string { +// readCString reads a null terminated string +func (r *msgReader) readCString() string { if r.err != nil { return "" } b, err := r.reader.ReadBytes(0) if err != nil { - r.Fatal(err) + r.fatal(err) return "" } r.msgBytesRemaining -= int32(len(b)) if r.msgBytesRemaining < 0 { - r.Fatal(errors.New("read past end of message")) + r.fatal(errors.New("read past end of message")) return "" } return string(b[0 : len(b)-1]) } -// ReadString reads count bytes and returns as string -func (r *MsgReader) ReadString(count int32) string { +// readString reads count bytes and returns as string +func (r *msgReader) readString(count int32) string { if r.err != nil { return "" } r.msgBytesRemaining -= count if r.msgBytesRemaining < 0 { - r.Fatal(errors.New("read past end of message")) + r.fatal(errors.New("read past end of message")) return "" } @@ -172,22 +172,22 @@ func (r *MsgReader) ReadString(count int32) string { _, err := io.ReadFull(r.reader, b) if err != nil { - r.Fatal(err) + r.fatal(err) return "" } return string(b) } -// ReadBytes reads count bytes and returns as []byte -func (r *MsgReader) ReadBytes(count int32) []byte { +// readBytes reads count bytes and returns as []byte +func (r *msgReader) readBytes(count int32) []byte { if r.err != nil { return nil } r.msgBytesRemaining -= count if r.msgBytesRemaining < 0 { - r.Fatal(errors.New("read past end of message")) + r.fatal(errors.New("read past end of message")) return nil } @@ -195,7 +195,7 @@ func (r *MsgReader) ReadBytes(count int32) []byte { _, err := io.ReadFull(r.reader, b) if err != nil { - r.Fatal(err) + r.fatal(err) return nil } diff --git a/query.go b/query.go index dfdb3b2d..79f9bb8a 100644 --- a/query.go +++ b/query.go @@ -31,7 +31,7 @@ func (r *Row) Scan(dest ...interface{}) (err error) { type Rows struct { pool *ConnPool conn *Conn - mr *MsgReader + mr *msgReader fields []FieldDescription vr ValueReader rowCount int @@ -134,7 +134,7 @@ func (rows *Rows) Next() bool { rows.close() return false case dataRow: - fieldCount := r.ReadInt16() + fieldCount := r.readInt16() if int(fieldCount) != len(rows.fields) { rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount))) return false @@ -165,7 +165,7 @@ func (rows *Rows) nextColumn() (*ValueReader, bool) { fd := &rows.fields[rows.columnIdx] rows.columnIdx++ - size := rows.mr.ReadInt32() + size := rows.mr.readInt32() rows.vr = ValueReader{mr: rows.mr, fd: fd, valueBytesRemaining: size} return &rows.vr, true } diff --git a/value_reader.go b/value_reader.go index 6e8f65e0..750caa7a 100644 --- a/value_reader.go +++ b/value_reader.go @@ -6,7 +6,7 @@ import ( // ValueReader the mechanism for implementing the BinaryDecoder interface. type ValueReader struct { - mr *MsgReader + mr *msgReader fd *FieldDescription valueBytesRemaining int32 err error @@ -43,7 +43,7 @@ func (r *ValueReader) ReadByte() byte { return 0 } - return r.mr.ReadByte() + return r.mr.readByte() } func (r *ValueReader) ReadInt16() int16 { @@ -57,7 +57,7 @@ func (r *ValueReader) ReadInt16() int16 { return 0 } - return r.mr.ReadInt16() + return r.mr.readInt16() } func (r *ValueReader) ReadInt32() int32 { @@ -71,7 +71,7 @@ func (r *ValueReader) ReadInt32() int32 { return 0 } - return r.mr.ReadInt32() + return r.mr.readInt32() } func (r *ValueReader) ReadInt64() int64 { @@ -85,7 +85,7 @@ func (r *ValueReader) ReadInt64() int64 { return 0 } - return r.mr.ReadInt64() + return r.mr.readInt64() } func (r *ValueReader) ReadOid() Oid { @@ -104,7 +104,7 @@ func (r *ValueReader) ReadString(count int32) string { return "" } - return r.mr.ReadString(count) + return r.mr.readString(count) } // ReadBytes reads count bytes and returns as []byte @@ -119,5 +119,5 @@ func (r *ValueReader) ReadBytes(count int32) []byte { return nil } - return r.mr.ReadBytes(count) + return r.mr.readBytes(count) }