Made many things public so SelectFunc is actually usable by others

Definitely, need to add higher level methods for other packages to
use. May rehide some of these interfaces at that point.
pgx-vs-pq
Jack Christensen 2013-04-26 17:06:49 -05:00
parent 19d4a4d577
commit 78590be058
6 changed files with 71 additions and 71 deletions

View File

@ -61,7 +61,7 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) {
for { for {
var t byte var t byte
var r *messageReader var r *MessageReader
if t, r, err = c.rxMsg(); err == nil { if t, r, err = c.rxMsg(); err == nil {
switch t { switch t {
case backendKeyData: case backendKeyData:
@ -93,17 +93,17 @@ func (c *Connection) Close() (err error) {
return return
} }
func (c *Connection) SelectFunc(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) { func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []FieldDescription) error) (err error) {
if err = c.sendSimpleQuery(sql); err != nil { if err = c.sendSimpleQuery(sql); err != nil {
return return
} }
var callbackError error var callbackError error
var fields []fieldDescription var fields []FieldDescription
for { for {
var t byte var t byte
var r *messageReader var r *MessageReader
if t, r, err = c.rxMsg(); err == nil { if t, r, err = c.rxMsg(); err == nil {
switch t { switch t {
case readyForQuery: case readyForQuery:
@ -137,7 +137,7 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*messageReader, []fie
// pattern when accessing the map // pattern when accessing the map
func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) { func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) {
rows = make([]map[string]string, 0, 8) rows = make([]map[string]string, 0, 8)
onDataRow := func(r *messageReader, fields []fieldDescription) error { onDataRow := func(r *MessageReader, fields []FieldDescription) error {
rows = append(rows, c.rxDataRow(r, fields)) rows = append(rows, c.rxDataRow(r, fields))
return nil return nil
} }
@ -164,7 +164,7 @@ func (c *Connection) Execute(sql string) (commandTag string, err error) {
for { for {
var t byte var t byte
var r *messageReader var r *MessageReader
if t, r, err = c.rxMsg(); err == nil { if t, r, err = c.rxMsg(); err == nil {
switch t { switch t {
case readyForQuery: case readyForQuery:
@ -172,7 +172,7 @@ func (c *Connection) Execute(sql string) (commandTag string, err error) {
case rowDescription: case rowDescription:
case dataRow: case dataRow:
case commandComplete: case commandComplete:
commandTag = r.readString() commandTag = r.ReadString()
default: default:
if err = c.processContextFreeMsg(t, r); err != nil { if err = c.processContextFreeMsg(t, r); err != nil {
return return
@ -189,7 +189,7 @@ func (c *Connection) Execute(sql string) (commandTag string, err error) {
// Processes messages that are not exclusive to one context such as // Processes messages that are not exclusive to one context such as
// authentication or query response. The response to these messages // authentication or query response. The response to these messages
// is the same regardless of when they occur. // is the same regardless of when they occur.
func (c *Connection) processContextFreeMsg(t byte, r *messageReader) (err error) { func (c *Connection) processContextFreeMsg(t byte, r *MessageReader) (err error) {
switch t { switch t {
case 'S': case 'S':
c.rxParameterStatus(r) c.rxParameterStatus(r)
@ -206,7 +206,7 @@ func (c *Connection) processContextFreeMsg(t byte, r *messageReader) (err error)
} }
func (c *Connection) rxMsg() (t byte, r *messageReader, err error) { func (c *Connection) rxMsg() (t byte, r *MessageReader, err error) {
var bodySize int32 var bodySize int32
t, bodySize, err = c.rxMsgHeader() t, bodySize, err = c.rxMsgHeader()
if err != nil { if err != nil {
@ -239,14 +239,14 @@ func (c *Connection) rxMsgBody(bodySize int32) (buf []byte, err error) {
return return
} }
func (c *Connection) rxAuthenticationX(r *messageReader) (err error) { func (c *Connection) rxAuthenticationX(r *MessageReader) (err error) {
code := r.readInt32() code := r.ReadInt32()
switch code { switch code {
case 0: // AuthenticationOk case 0: // AuthenticationOk
case 3: // AuthenticationCleartextPassword case 3: // AuthenticationCleartextPassword
c.txPasswordMessage(c.parameters.Password) c.txPasswordMessage(c.parameters.Password)
case 5: // AuthenticationMD5Password case 5: // AuthenticationMD5Password
salt := r.readByteString(4) salt := r.ReadByteString(4)
digestedPassword := "md5" + hexMD5(hexMD5(c.parameters.Password+c.parameters.User)+salt) digestedPassword := "md5" + hexMD5(hexMD5(c.parameters.Password+c.parameters.User)+salt)
c.txPasswordMessage(digestedPassword) c.txPasswordMessage(digestedPassword)
default: default:
@ -262,75 +262,75 @@ func hexMD5(s string) string {
return hex.EncodeToString(hash.Sum(nil)) return hex.EncodeToString(hash.Sum(nil))
} }
func (c *Connection) rxParameterStatus(r *messageReader) { func (c *Connection) rxParameterStatus(r *MessageReader) {
key := r.readString() key := r.ReadString()
value := r.readString() value := r.ReadString()
c.runtimeParams[key] = value c.runtimeParams[key] = value
} }
func (c *Connection) rxErrorResponse(r *messageReader) (err PgError) { func (c *Connection) rxErrorResponse(r *MessageReader) (err PgError) {
for { for {
switch r.readByte() { switch r.ReadByte() {
case 'S': case 'S':
err.Severity = r.readString() err.Severity = r.ReadString()
case 'C': case 'C':
err.Code = r.readString() err.Code = r.ReadString()
case 'M': case 'M':
err.Message = r.readString() err.Message = r.ReadString()
case 0: // End of error message case 0: // End of error message
return return
default: // Ignore other error fields default: // Ignore other error fields
r.readString() r.ReadString()
} }
} }
panic("Unreachable") panic("Unreachable")
} }
func (c *Connection) rxBackendKeyData(r *messageReader) { func (c *Connection) rxBackendKeyData(r *MessageReader) {
c.pid = r.readInt32() c.pid = r.ReadInt32()
c.secretKey = r.readInt32() c.secretKey = r.ReadInt32()
} }
func (c *Connection) rxReadyForQuery(r *messageReader) { func (c *Connection) rxReadyForQuery(r *MessageReader) {
c.txStatus = r.readByte() c.txStatus = r.ReadByte()
} }
func (c *Connection) rxRowDescription(r *messageReader) (fields []fieldDescription) { func (c *Connection) rxRowDescription(r *MessageReader) (fields []FieldDescription) {
fieldCount := r.readInt16() fieldCount := r.ReadInt16()
fields = make([]fieldDescription, fieldCount) fields = make([]FieldDescription, fieldCount)
for i := int16(0); i < fieldCount; i++ { for i := int16(0); i < fieldCount; i++ {
f := &fields[i] f := &fields[i]
f.name = r.readString() f.Name = r.ReadString()
f.table = r.readOid() f.Table = r.ReadOid()
f.attributeNumber = r.readInt16() f.AttributeNumber = r.ReadInt16()
f.dataType = r.readOid() f.DataType = r.ReadOid()
f.dataTypeSize = r.readInt16() f.DataTypeSize = r.ReadInt16()
f.modifier = r.readInt32() f.Modifier = r.ReadInt32()
f.formatCode = r.readInt16() f.FormatCode = r.ReadInt16()
} }
return return
} }
func (c *Connection) rxDataRow(r *messageReader, fields []fieldDescription) (row map[string]string) { func (c *Connection) rxDataRow(r *MessageReader, fields []FieldDescription) (row map[string]string) {
fieldCount := r.readInt16() fieldCount := r.ReadInt16()
row = make(map[string]string, fieldCount) row = make(map[string]string, fieldCount)
for i := int16(0); i < fieldCount; i++ { for i := int16(0); i < fieldCount; i++ {
size := r.readInt32() size := r.ReadInt32()
if size > -1 { if size > -1 {
row[fields[i].name] = r.readByteString(size) row[fields[i].Name] = r.ReadByteString(size)
} }
} }
return return
} }
func (c *Connection) rxDataRowFirstValue(r *messageReader) (s string, null bool) { func (c *Connection) rxDataRowFirstValue(r *MessageReader) (s string, null bool) {
r.readInt16() // ignore field count r.ReadInt16() // ignore field count
size := r.readInt32() size := r.ReadInt32()
if size > -1 { if size > -1 {
s = r.readByteString(size) s = r.ReadByteString(size)
} else { } else {
null = true null = true
} }
@ -338,8 +338,8 @@ func (c *Connection) rxDataRowFirstValue(r *messageReader) (s string, null bool)
return return
} }
func (c *Connection) rxCommandComplete(r *messageReader) string { func (c *Connection) rxCommandComplete(r *MessageReader) string {
return r.readString() return r.ReadString()
} }
func (c *Connection) txStartupMessage(msg *startupMessage) (err error) { func (c *Connection) txStartupMessage(msg *startupMessage) (err error) {

View File

@ -7,7 +7,7 @@ import (
func (c *Connection) SelectAllString(sql string) (strings []string, err error) { func (c *Connection) SelectAllString(sql string) (strings []string, err error) {
strings = make([]string, 0, 8) strings = make([]string, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) error { onDataRow := func(r *MessageReader, _ []FieldDescription) error {
s, null := c.rxDataRowFirstValue(r) s, null := c.rxDataRowFirstValue(r)
if null { if null {
return errors.New("Unexpected NULL") return errors.New("Unexpected NULL")
@ -21,7 +21,7 @@ func (c *Connection) SelectAllString(sql string) (strings []string, err error) {
func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) { func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) {
ints = make([]int64, 0, 8) ints = make([]int64, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
s, null := c.rxDataRowFirstValue(r) s, null := c.rxDataRowFirstValue(r)
if null { if null {
return errors.New("Unexpected NULL") return errors.New("Unexpected NULL")
@ -37,7 +37,7 @@ func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) {
func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) { func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) {
ints = make([]int32, 0, 8) ints = make([]int32, 0, 8)
onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) { onDataRow := func(r *MessageReader, fields []FieldDescription) (parseError error) {
s, null := c.rxDataRowFirstValue(r) s, null := c.rxDataRowFirstValue(r)
if null { if null {
return errors.New("Unexpected NULL") return errors.New("Unexpected NULL")
@ -53,7 +53,7 @@ func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) {
func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) { func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) {
ints = make([]int16, 0, 8) ints = make([]int16, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
s, null := c.rxDataRowFirstValue(r) s, null := c.rxDataRowFirstValue(r)
if null { if null {
return errors.New("Unexpected NULL") return errors.New("Unexpected NULL")
@ -69,7 +69,7 @@ func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) {
func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) { func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) {
floats = make([]float64, 0, 8) floats = make([]float64, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
s, null := c.rxDataRowFirstValue(r) s, null := c.rxDataRowFirstValue(r)
if null { if null {
return errors.New("Unexpected NULL") return errors.New("Unexpected NULL")
@ -85,7 +85,7 @@ func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error)
func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) { func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) {
floats = make([]float32, 0, 8) floats = make([]float32, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
s, null := c.rxDataRowFirstValue(r) s, null := c.rxDataRowFirstValue(r)
if null { if null {
return errors.New("Unexpected NULL") return errors.New("Unexpected NULL")

View File

@ -6,7 +6,7 @@ import (
) )
func (c *Connection) SelectString(sql string) (s string, err error) { func (c *Connection) SelectString(sql string) (s string, err error) {
onDataRow := func(r *messageReader, _ []fieldDescription) error { onDataRow := func(r *MessageReader, _ []FieldDescription) error {
var null bool var null bool
s, null = c.rxDataRowFirstValue(r) s, null = c.rxDataRowFirstValue(r)
if null { if null {

View File

@ -130,7 +130,7 @@ func TestSelectFunc(t *testing.T) {
conn := getSharedConnection() conn := getSharedConnection()
rowCount := 0 rowCount := 0
onDataRow := func(r *messageReader, fields []fieldDescription) error { onDataRow := func(r *MessageReader, fields []FieldDescription) error {
rowCount++ rowCount++
return nil return nil
} }

View File

@ -5,38 +5,38 @@ import (
"encoding/binary" "encoding/binary"
) )
type messageReader []byte type MessageReader []byte
func newMessageReader(buf []byte) *messageReader { func newMessageReader(buf []byte) *MessageReader {
r := messageReader(buf) r := MessageReader(buf)
return &r return &r
} }
func (r *messageReader) readByte() byte { func (r *MessageReader) ReadByte() byte {
b := (*r)[0] b := (*r)[0]
*r = (*r)[1:] *r = (*r)[1:]
return b return b
} }
func (r *messageReader) readInt16() int16 { func (r *MessageReader) ReadInt16() int16 {
n := int16(binary.BigEndian.Uint16((*r)[:2])) n := int16(binary.BigEndian.Uint16((*r)[:2]))
*r = (*r)[2:] *r = (*r)[2:]
return n return n
} }
func (r *messageReader) readInt32() int32 { func (r *MessageReader) ReadInt32() int32 {
n := int32(binary.BigEndian.Uint32((*r)[:4])) n := int32(binary.BigEndian.Uint32((*r)[:4]))
*r = (*r)[4:] *r = (*r)[4:]
return n return n
} }
func (r *messageReader) readOid() oid { func (r *MessageReader) ReadOid() oid {
n := oid(binary.BigEndian.Uint32((*r)[:4])) n := oid(binary.BigEndian.Uint32((*r)[:4]))
*r = (*r)[4:] *r = (*r)[4:]
return n return n
} }
func (r *messageReader) readString() string { func (r *MessageReader) ReadString() string {
n := bytes.IndexByte(*r, 0) n := bytes.IndexByte(*r, 0)
s := (*r)[:n] s := (*r)[:n]
*r = (*r)[n+1:] *r = (*r)[n+1:]
@ -44,7 +44,7 @@ func (r *messageReader) readString() string {
} }
// Read count bytes and return as string // Read count bytes and return as string
func (r *messageReader) readByteString(count int32) string { func (r *MessageReader) ReadByteString(count int32) string {
s := (*r)[:count] s := (*r)[:count]
*r = (*r)[count:] *r = (*r)[count:]
return string(s) return string(s)

View File

@ -43,14 +43,14 @@ func (self *startupMessage) Bytes() (buf []byte) {
type oid int32 type oid int32
type fieldDescription struct { type FieldDescription struct {
name string Name string
table oid Table oid
attributeNumber int16 AttributeNumber int16
dataType oid DataType oid
dataTypeSize int16 DataTypeSize int16
modifier int32 Modifier int32
formatCode int16 FormatCode int16
} }
type PgError struct { type PgError struct {