From 78590be058f9a690bc12e66c19a152ddbe34c3f2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 26 Apr 2013 17:06:49 -0500 Subject: [PATCH] Made many things public so SelectFunc is actually usable by others Definitely, need to add higher level methods for other packages to use. May rehide some of these interfaces at that point. --- connection.go | 92 ++++++++++++++++++------------------- connection_select_column.go | 12 ++--- connection_select_value.go | 2 +- connection_test.go | 2 +- message_reader.go | 18 ++++---- messages.go | 16 +++---- 6 files changed, 71 insertions(+), 71 deletions(-) diff --git a/connection.go b/connection.go index 8712d968..b6e66558 100644 --- a/connection.go +++ b/connection.go @@ -61,7 +61,7 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) { for { var t byte - var r *messageReader + var r *MessageReader if t, r, err = c.rxMsg(); err == nil { switch t { case backendKeyData: @@ -93,17 +93,17 @@ func (c *Connection) Close() (err error) { return } -func (c *Connection) SelectFunc(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) { +func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []FieldDescription) error) (err error) { if err = c.sendSimpleQuery(sql); err != nil { return } var callbackError error - var fields []fieldDescription + var fields []FieldDescription for { var t byte - var r *messageReader + var r *MessageReader if t, r, err = c.rxMsg(); err == nil { switch t { case readyForQuery: @@ -137,7 +137,7 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*messageReader, []fie // pattern when accessing the map func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) { rows = make([]map[string]string, 0, 8) - onDataRow := func(r *messageReader, fields []fieldDescription) error { + onDataRow := func(r *MessageReader, fields []FieldDescription) error { rows = append(rows, c.rxDataRow(r, fields)) return nil } @@ -164,7 +164,7 @@ func (c *Connection) Execute(sql string) (commandTag string, err error) { for { var t byte - var r *messageReader + var r *MessageReader if t, r, err = c.rxMsg(); err == nil { switch t { case readyForQuery: @@ -172,7 +172,7 @@ func (c *Connection) Execute(sql string) (commandTag string, err error) { case rowDescription: case dataRow: case commandComplete: - commandTag = r.readString() + commandTag = r.ReadString() default: if err = c.processContextFreeMsg(t, r); err != nil { return @@ -189,7 +189,7 @@ func (c *Connection) Execute(sql string) (commandTag string, err error) { // Processes messages that are not exclusive to one context such as // authentication or query response. The response to these messages // is the same regardless of when they occur. -func (c *Connection) processContextFreeMsg(t byte, r *messageReader) (err error) { +func (c *Connection) processContextFreeMsg(t byte, r *MessageReader) (err error) { switch t { case 'S': c.rxParameterStatus(r) @@ -206,7 +206,7 @@ func (c *Connection) processContextFreeMsg(t byte, r *messageReader) (err error) } -func (c *Connection) rxMsg() (t byte, r *messageReader, err error) { +func (c *Connection) rxMsg() (t byte, r *MessageReader, err error) { var bodySize int32 t, bodySize, err = c.rxMsgHeader() if err != nil { @@ -239,14 +239,14 @@ func (c *Connection) rxMsgBody(bodySize int32) (buf []byte, err error) { return } -func (c *Connection) rxAuthenticationX(r *messageReader) (err error) { - code := r.readInt32() +func (c *Connection) rxAuthenticationX(r *MessageReader) (err error) { + code := r.ReadInt32() switch code { case 0: // AuthenticationOk case 3: // AuthenticationCleartextPassword c.txPasswordMessage(c.parameters.Password) case 5: // AuthenticationMD5Password - salt := r.readByteString(4) + salt := r.ReadByteString(4) digestedPassword := "md5" + hexMD5(hexMD5(c.parameters.Password+c.parameters.User)+salt) c.txPasswordMessage(digestedPassword) default: @@ -262,75 +262,75 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } -func (c *Connection) rxParameterStatus(r *messageReader) { - key := r.readString() - value := r.readString() +func (c *Connection) rxParameterStatus(r *MessageReader) { + key := r.ReadString() + value := r.ReadString() c.runtimeParams[key] = value } -func (c *Connection) rxErrorResponse(r *messageReader) (err PgError) { +func (c *Connection) rxErrorResponse(r *MessageReader) (err PgError) { for { - switch r.readByte() { + switch r.ReadByte() { case 'S': - err.Severity = r.readString() + err.Severity = r.ReadString() case 'C': - err.Code = r.readString() + err.Code = r.ReadString() case 'M': - err.Message = r.readString() + err.Message = r.ReadString() case 0: // End of error message return default: // Ignore other error fields - r.readString() + r.ReadString() } } panic("Unreachable") } -func (c *Connection) rxBackendKeyData(r *messageReader) { - c.pid = r.readInt32() - c.secretKey = r.readInt32() +func (c *Connection) rxBackendKeyData(r *MessageReader) { + c.pid = r.ReadInt32() + c.secretKey = r.ReadInt32() } -func (c *Connection) rxReadyForQuery(r *messageReader) { - c.txStatus = r.readByte() +func (c *Connection) rxReadyForQuery(r *MessageReader) { + c.txStatus = r.ReadByte() } -func (c *Connection) rxRowDescription(r *messageReader) (fields []fieldDescription) { - fieldCount := r.readInt16() - fields = make([]fieldDescription, fieldCount) +func (c *Connection) rxRowDescription(r *MessageReader) (fields []FieldDescription) { + fieldCount := r.ReadInt16() + fields = make([]FieldDescription, fieldCount) for i := int16(0); i < fieldCount; i++ { f := &fields[i] - f.name = r.readString() - f.table = r.readOid() - f.attributeNumber = r.readInt16() - f.dataType = r.readOid() - f.dataTypeSize = r.readInt16() - f.modifier = r.readInt32() - f.formatCode = r.readInt16() + f.Name = r.ReadString() + f.Table = r.ReadOid() + f.AttributeNumber = r.ReadInt16() + f.DataType = r.ReadOid() + f.DataTypeSize = r.ReadInt16() + f.Modifier = r.ReadInt32() + f.FormatCode = r.ReadInt16() } return } -func (c *Connection) rxDataRow(r *messageReader, fields []fieldDescription) (row map[string]string) { - fieldCount := r.readInt16() +func (c *Connection) rxDataRow(r *MessageReader, fields []FieldDescription) (row map[string]string) { + fieldCount := r.ReadInt16() row = make(map[string]string, fieldCount) for i := int16(0); i < fieldCount; i++ { - size := r.readInt32() + size := r.ReadInt32() if size > -1 { - row[fields[i].name] = r.readByteString(size) + row[fields[i].Name] = r.ReadByteString(size) } } return } -func (c *Connection) rxDataRowFirstValue(r *messageReader) (s string, null bool) { - r.readInt16() // ignore field count +func (c *Connection) rxDataRowFirstValue(r *MessageReader) (s string, null bool) { + r.ReadInt16() // ignore field count - size := r.readInt32() + size := r.ReadInt32() if size > -1 { - s = r.readByteString(size) + s = r.ReadByteString(size) } else { null = true } @@ -338,8 +338,8 @@ func (c *Connection) rxDataRowFirstValue(r *messageReader) (s string, null bool) return } -func (c *Connection) rxCommandComplete(r *messageReader) string { - return r.readString() +func (c *Connection) rxCommandComplete(r *MessageReader) string { + return r.ReadString() } func (c *Connection) txStartupMessage(msg *startupMessage) (err error) { diff --git a/connection_select_column.go b/connection_select_column.go index 1f516ef0..7d6442fb 100644 --- a/connection_select_column.go +++ b/connection_select_column.go @@ -7,7 +7,7 @@ import ( func (c *Connection) SelectAllString(sql string) (strings []string, err error) { strings = make([]string, 0, 8) - onDataRow := func(r *messageReader, _ []fieldDescription) error { + onDataRow := func(r *MessageReader, _ []FieldDescription) error { s, null := c.rxDataRowFirstValue(r) if null { return errors.New("Unexpected NULL") @@ -21,7 +21,7 @@ func (c *Connection) SelectAllString(sql string) (strings []string, err error) { func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) { ints = make([]int64, 0, 8) - onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { s, null := c.rxDataRowFirstValue(r) if null { return errors.New("Unexpected NULL") @@ -37,7 +37,7 @@ func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) { func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) { ints = make([]int32, 0, 8) - onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) { + onDataRow := func(r *MessageReader, fields []FieldDescription) (parseError error) { s, null := c.rxDataRowFirstValue(r) if null { return errors.New("Unexpected NULL") @@ -53,7 +53,7 @@ func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) { func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) { ints = make([]int16, 0, 8) - onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { s, null := c.rxDataRowFirstValue(r) if null { return errors.New("Unexpected NULL") @@ -69,7 +69,7 @@ func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) { func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) { floats = make([]float64, 0, 8) - onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { s, null := c.rxDataRowFirstValue(r) if null { return errors.New("Unexpected NULL") @@ -85,7 +85,7 @@ func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) { floats = make([]float32, 0, 8) - onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { s, null := c.rxDataRowFirstValue(r) if null { return errors.New("Unexpected NULL") diff --git a/connection_select_value.go b/connection_select_value.go index 84a667d0..d6516715 100644 --- a/connection_select_value.go +++ b/connection_select_value.go @@ -6,7 +6,7 @@ import ( ) func (c *Connection) SelectString(sql string) (s string, err error) { - onDataRow := func(r *messageReader, _ []fieldDescription) error { + onDataRow := func(r *MessageReader, _ []FieldDescription) error { var null bool s, null = c.rxDataRowFirstValue(r) if null { diff --git a/connection_test.go b/connection_test.go index 3af6a6ed..50415157 100644 --- a/connection_test.go +++ b/connection_test.go @@ -130,7 +130,7 @@ func TestSelectFunc(t *testing.T) { conn := getSharedConnection() rowCount := 0 - onDataRow := func(r *messageReader, fields []fieldDescription) error { + onDataRow := func(r *MessageReader, fields []FieldDescription) error { rowCount++ return nil } diff --git a/message_reader.go b/message_reader.go index 0da93e1f..932e7ed1 100644 --- a/message_reader.go +++ b/message_reader.go @@ -5,38 +5,38 @@ import ( "encoding/binary" ) -type messageReader []byte +type MessageReader []byte -func newMessageReader(buf []byte) *messageReader { - r := messageReader(buf) +func newMessageReader(buf []byte) *MessageReader { + r := MessageReader(buf) return &r } -func (r *messageReader) readByte() byte { +func (r *MessageReader) ReadByte() byte { b := (*r)[0] *r = (*r)[1:] return b } -func (r *messageReader) readInt16() int16 { +func (r *MessageReader) ReadInt16() int16 { n := int16(binary.BigEndian.Uint16((*r)[:2])) *r = (*r)[2:] return n } -func (r *messageReader) readInt32() int32 { +func (r *MessageReader) ReadInt32() int32 { n := int32(binary.BigEndian.Uint32((*r)[:4])) *r = (*r)[4:] return n } -func (r *messageReader) readOid() oid { +func (r *MessageReader) ReadOid() oid { n := oid(binary.BigEndian.Uint32((*r)[:4])) *r = (*r)[4:] return n } -func (r *messageReader) readString() string { +func (r *MessageReader) ReadString() string { n := bytes.IndexByte(*r, 0) s := (*r)[:n] *r = (*r)[n+1:] @@ -44,7 +44,7 @@ func (r *messageReader) readString() string { } // Read count bytes and return as string -func (r *messageReader) readByteString(count int32) string { +func (r *MessageReader) ReadByteString(count int32) string { s := (*r)[:count] *r = (*r)[count:] return string(s) diff --git a/messages.go b/messages.go index 28b4b975..24ba294f 100644 --- a/messages.go +++ b/messages.go @@ -43,14 +43,14 @@ func (self *startupMessage) Bytes() (buf []byte) { type oid int32 -type fieldDescription struct { - name string - table oid - attributeNumber int16 - dataType oid - dataTypeSize int16 - modifier int32 - formatCode int16 +type FieldDescription struct { + Name string + Table oid + AttributeNumber int16 + DataType oid + DataTypeSize int16 + Modifier int32 + FormatCode int16 } type PgError struct {