diff --git a/connection.go b/connection.go index 8712d968..b6e66558 100644 --- a/connection.go +++ b/connection.go @@ -61,7 +61,7 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) { for { var t byte - var r *messageReader + var r *MessageReader if t, r, err = c.rxMsg(); err == nil { switch t { case backendKeyData: @@ -93,17 +93,17 @@ func (c *Connection) Close() (err error) { return } -func (c *Connection) SelectFunc(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) { +func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []FieldDescription) error) (err error) { if err = c.sendSimpleQuery(sql); err != nil { return } var callbackError error - var fields []fieldDescription + var fields []FieldDescription for { var t byte - var r *messageReader + var r *MessageReader if t, r, err = c.rxMsg(); err == nil { switch t { case readyForQuery: @@ -137,7 +137,7 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*messageReader, []fie // pattern when accessing the map func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) { rows = make([]map[string]string, 0, 8) - onDataRow := func(r *messageReader, fields []fieldDescription) error { + onDataRow := func(r *MessageReader, fields []FieldDescription) error { rows = append(rows, c.rxDataRow(r, fields)) return nil } @@ -164,7 +164,7 @@ func (c *Connection) Execute(sql string) (commandTag string, err error) { for { var t byte - var r *messageReader + var r *MessageReader if t, r, err = c.rxMsg(); err == nil { switch t { case readyForQuery: @@ -172,7 +172,7 @@ func (c *Connection) Execute(sql string) (commandTag string, err error) { case rowDescription: case dataRow: case commandComplete: - commandTag = r.readString() + commandTag = r.ReadString() default: if err = c.processContextFreeMsg(t, r); err != nil { return @@ -189,7 +189,7 @@ func (c *Connection) Execute(sql string) (commandTag string, err error) { // Processes messages that are not exclusive to one context such as // authentication or query response. The response to these messages // is the same regardless of when they occur. -func (c *Connection) processContextFreeMsg(t byte, r *messageReader) (err error) { +func (c *Connection) processContextFreeMsg(t byte, r *MessageReader) (err error) { switch t { case 'S': c.rxParameterStatus(r) @@ -206,7 +206,7 @@ func (c *Connection) processContextFreeMsg(t byte, r *messageReader) (err error) } -func (c *Connection) rxMsg() (t byte, r *messageReader, err error) { +func (c *Connection) rxMsg() (t byte, r *MessageReader, err error) { var bodySize int32 t, bodySize, err = c.rxMsgHeader() if err != nil { @@ -239,14 +239,14 @@ func (c *Connection) rxMsgBody(bodySize int32) (buf []byte, err error) { return } -func (c *Connection) rxAuthenticationX(r *messageReader) (err error) { - code := r.readInt32() +func (c *Connection) rxAuthenticationX(r *MessageReader) (err error) { + code := r.ReadInt32() switch code { case 0: // AuthenticationOk case 3: // AuthenticationCleartextPassword c.txPasswordMessage(c.parameters.Password) case 5: // AuthenticationMD5Password - salt := r.readByteString(4) + salt := r.ReadByteString(4) digestedPassword := "md5" + hexMD5(hexMD5(c.parameters.Password+c.parameters.User)+salt) c.txPasswordMessage(digestedPassword) default: @@ -262,75 +262,75 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } -func (c *Connection) rxParameterStatus(r *messageReader) { - key := r.readString() - value := r.readString() +func (c *Connection) rxParameterStatus(r *MessageReader) { + key := r.ReadString() + value := r.ReadString() c.runtimeParams[key] = value } -func (c *Connection) rxErrorResponse(r *messageReader) (err PgError) { +func (c *Connection) rxErrorResponse(r *MessageReader) (err PgError) { for { - switch r.readByte() { + switch r.ReadByte() { case 'S': - err.Severity = r.readString() + err.Severity = r.ReadString() case 'C': - err.Code = r.readString() + err.Code = r.ReadString() case 'M': - err.Message = r.readString() + err.Message = r.ReadString() case 0: // End of error message return default: // Ignore other error fields - r.readString() + r.ReadString() } } panic("Unreachable") } -func (c *Connection) rxBackendKeyData(r *messageReader) { - c.pid = r.readInt32() - c.secretKey = r.readInt32() +func (c *Connection) rxBackendKeyData(r *MessageReader) { + c.pid = r.ReadInt32() + c.secretKey = r.ReadInt32() } -func (c *Connection) rxReadyForQuery(r *messageReader) { - c.txStatus = r.readByte() +func (c *Connection) rxReadyForQuery(r *MessageReader) { + c.txStatus = r.ReadByte() } -func (c *Connection) rxRowDescription(r *messageReader) (fields []fieldDescription) { - fieldCount := r.readInt16() - fields = make([]fieldDescription, fieldCount) +func (c *Connection) rxRowDescription(r *MessageReader) (fields []FieldDescription) { + fieldCount := r.ReadInt16() + fields = make([]FieldDescription, fieldCount) for i := int16(0); i < fieldCount; i++ { f := &fields[i] - f.name = r.readString() - f.table = r.readOid() - f.attributeNumber = r.readInt16() - f.dataType = r.readOid() - f.dataTypeSize = r.readInt16() - f.modifier = r.readInt32() - f.formatCode = r.readInt16() + f.Name = r.ReadString() + f.Table = r.ReadOid() + f.AttributeNumber = r.ReadInt16() + f.DataType = r.ReadOid() + f.DataTypeSize = r.ReadInt16() + f.Modifier = r.ReadInt32() + f.FormatCode = r.ReadInt16() } return } -func (c *Connection) rxDataRow(r *messageReader, fields []fieldDescription) (row map[string]string) { - fieldCount := r.readInt16() +func (c *Connection) rxDataRow(r *MessageReader, fields []FieldDescription) (row map[string]string) { + fieldCount := r.ReadInt16() row = make(map[string]string, fieldCount) for i := int16(0); i < fieldCount; i++ { - size := r.readInt32() + size := r.ReadInt32() if size > -1 { - row[fields[i].name] = r.readByteString(size) + row[fields[i].Name] = r.ReadByteString(size) } } return } -func (c *Connection) rxDataRowFirstValue(r *messageReader) (s string, null bool) { - r.readInt16() // ignore field count +func (c *Connection) rxDataRowFirstValue(r *MessageReader) (s string, null bool) { + r.ReadInt16() // ignore field count - size := r.readInt32() + size := r.ReadInt32() if size > -1 { - s = r.readByteString(size) + s = r.ReadByteString(size) } else { null = true } @@ -338,8 +338,8 @@ func (c *Connection) rxDataRowFirstValue(r *messageReader) (s string, null bool) return } -func (c *Connection) rxCommandComplete(r *messageReader) string { - return r.readString() +func (c *Connection) rxCommandComplete(r *MessageReader) string { + return r.ReadString() } func (c *Connection) txStartupMessage(msg *startupMessage) (err error) { diff --git a/connection_select_column.go b/connection_select_column.go index 1f516ef0..7d6442fb 100644 --- a/connection_select_column.go +++ b/connection_select_column.go @@ -7,7 +7,7 @@ import ( func (c *Connection) SelectAllString(sql string) (strings []string, err error) { strings = make([]string, 0, 8) - onDataRow := func(r *messageReader, _ []fieldDescription) error { + onDataRow := func(r *MessageReader, _ []FieldDescription) error { s, null := c.rxDataRowFirstValue(r) if null { return errors.New("Unexpected NULL") @@ -21,7 +21,7 @@ func (c *Connection) SelectAllString(sql string) (strings []string, err error) { func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) { ints = make([]int64, 0, 8) - onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { s, null := c.rxDataRowFirstValue(r) if null { return errors.New("Unexpected NULL") @@ -37,7 +37,7 @@ func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) { func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) { ints = make([]int32, 0, 8) - onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) { + onDataRow := func(r *MessageReader, fields []FieldDescription) (parseError error) { s, null := c.rxDataRowFirstValue(r) if null { return errors.New("Unexpected NULL") @@ -53,7 +53,7 @@ func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) { func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) { ints = make([]int16, 0, 8) - onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { s, null := c.rxDataRowFirstValue(r) if null { return errors.New("Unexpected NULL") @@ -69,7 +69,7 @@ func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) { func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) { floats = make([]float64, 0, 8) - onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { s, null := c.rxDataRowFirstValue(r) if null { return errors.New("Unexpected NULL") @@ -85,7 +85,7 @@ func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) { floats = make([]float32, 0, 8) - onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { s, null := c.rxDataRowFirstValue(r) if null { return errors.New("Unexpected NULL") diff --git a/connection_select_value.go b/connection_select_value.go index 84a667d0..d6516715 100644 --- a/connection_select_value.go +++ b/connection_select_value.go @@ -6,7 +6,7 @@ import ( ) func (c *Connection) SelectString(sql string) (s string, err error) { - onDataRow := func(r *messageReader, _ []fieldDescription) error { + onDataRow := func(r *MessageReader, _ []FieldDescription) error { var null bool s, null = c.rxDataRowFirstValue(r) if null { diff --git a/connection_test.go b/connection_test.go index 3af6a6ed..50415157 100644 --- a/connection_test.go +++ b/connection_test.go @@ -130,7 +130,7 @@ func TestSelectFunc(t *testing.T) { conn := getSharedConnection() rowCount := 0 - onDataRow := func(r *messageReader, fields []fieldDescription) error { + onDataRow := func(r *MessageReader, fields []FieldDescription) error { rowCount++ return nil } diff --git a/message_reader.go b/message_reader.go index 0da93e1f..932e7ed1 100644 --- a/message_reader.go +++ b/message_reader.go @@ -5,38 +5,38 @@ import ( "encoding/binary" ) -type messageReader []byte +type MessageReader []byte -func newMessageReader(buf []byte) *messageReader { - r := messageReader(buf) +func newMessageReader(buf []byte) *MessageReader { + r := MessageReader(buf) return &r } -func (r *messageReader) readByte() byte { +func (r *MessageReader) ReadByte() byte { b := (*r)[0] *r = (*r)[1:] return b } -func (r *messageReader) readInt16() int16 { +func (r *MessageReader) ReadInt16() int16 { n := int16(binary.BigEndian.Uint16((*r)[:2])) *r = (*r)[2:] return n } -func (r *messageReader) readInt32() int32 { +func (r *MessageReader) ReadInt32() int32 { n := int32(binary.BigEndian.Uint32((*r)[:4])) *r = (*r)[4:] return n } -func (r *messageReader) readOid() oid { +func (r *MessageReader) ReadOid() oid { n := oid(binary.BigEndian.Uint32((*r)[:4])) *r = (*r)[4:] return n } -func (r *messageReader) readString() string { +func (r *MessageReader) ReadString() string { n := bytes.IndexByte(*r, 0) s := (*r)[:n] *r = (*r)[n+1:] @@ -44,7 +44,7 @@ func (r *messageReader) readString() string { } // Read count bytes and return as string -func (r *messageReader) readByteString(count int32) string { +func (r *MessageReader) ReadByteString(count int32) string { s := (*r)[:count] *r = (*r)[count:] return string(s) diff --git a/messages.go b/messages.go index 28b4b975..24ba294f 100644 --- a/messages.go +++ b/messages.go @@ -43,14 +43,14 @@ func (self *startupMessage) Bytes() (buf []byte) { type oid int32 -type fieldDescription struct { - name string - table oid - attributeNumber int16 - dataType oid - dataTypeSize int16 - modifier int32 - formatCode int16 +type FieldDescription struct { + Name string + Table oid + AttributeNumber int16 + DataType oid + DataTypeSize int16 + Modifier int32 + FormatCode int16 } type PgError struct {