Add DataRowReader and change Connection.SelectFunc to use it

Preparatory step for structure binding. refs #11
pgx-vs-pq
Jack Christensen 2013-05-01 08:51:09 -05:00
parent 78590be058
commit 36e4d74d30
7 changed files with 308 additions and 28 deletions

View File

@ -93,7 +93,7 @@ func (c *Connection) Close() (err error) {
return return
} }
func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []FieldDescription) error) (err error) { func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error) (err error) {
if err = c.sendSimpleQuery(sql); err != nil { if err = c.sendSimpleQuery(sql); err != nil {
return return
} }
@ -115,7 +115,7 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []Fie
fields = c.rxRowDescription(r) fields = c.rxRowDescription(r)
case dataRow: case dataRow:
if callbackError == nil { if callbackError == nil {
callbackError = onDataRow(r, fields) callbackError = onDataRow(newDataRowReader(r, fields))
} }
case commandComplete: case commandComplete:
default: default:
@ -137,8 +137,8 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []Fie
// pattern when accessing the map // pattern when accessing the map
func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) { func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) {
rows = make([]map[string]string, 0, 8) rows = make([]map[string]string, 0, 8)
onDataRow := func(r *MessageReader, fields []FieldDescription) error { onDataRow := func(r *DataRowReader) error {
rows = append(rows, c.rxDataRow(r, fields)) rows = append(rows, c.rxDataRow(r))
return nil return nil
} }
err = c.SelectFunc(sql, onDataRow) err = c.SelectFunc(sql, onDataRow)
@ -312,22 +312,21 @@ func (c *Connection) rxRowDescription(r *MessageReader) (fields []FieldDescripti
return return
} }
func (c *Connection) rxDataRow(r *MessageReader, fields []FieldDescription) (row map[string]string) { func (c *Connection) rxDataRow(r *DataRowReader) (row map[string]string) {
fieldCount := r.ReadInt16() fieldCount := len(r.fields)
mr := r.mr
row = make(map[string]string, fieldCount) row = make(map[string]string, fieldCount)
for i := int16(0); i < fieldCount; i++ { for i := 0; i < fieldCount; i++ {
size := r.ReadInt32() size := mr.ReadInt32()
if size > -1 { if size > -1 {
row[fields[i].Name] = r.ReadByteString(size) row[r.fields[i].Name] = mr.ReadByteString(size)
} }
} }
return return
} }
func (c *Connection) rxDataRowFirstValue(r *MessageReader) (s string, null bool) { func (c *Connection) rxDataRowFirstValue(r *MessageReader) (s string, null bool) {
r.ReadInt16() // ignore field count
size := r.ReadInt32() size := r.ReadInt32()
if size > -1 { if size > -1 {
s = r.ReadByteString(size) s = r.ReadByteString(size)

View File

@ -7,8 +7,8 @@ import (
func (c *Connection) SelectAllString(sql string) (strings []string, err error) { func (c *Connection) SelectAllString(sql string) (strings []string, err error) {
strings = make([]string, 0, 8) strings = make([]string, 0, 8)
onDataRow := func(r *MessageReader, _ []FieldDescription) error { onDataRow := func(r *DataRowReader) error {
s, null := c.rxDataRowFirstValue(r) s, null := c.rxDataRowFirstValue(r.mr)
if null { if null {
return errors.New("Unexpected NULL") return errors.New("Unexpected NULL")
} }
@ -21,8 +21,8 @@ func (c *Connection) SelectAllString(sql string) (strings []string, err error) {
func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) { func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) {
ints = make([]int64, 0, 8) ints = make([]int64, 0, 8)
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { onDataRow := func(r *DataRowReader) (parseError error) {
s, null := c.rxDataRowFirstValue(r) s, null := c.rxDataRowFirstValue(r.mr)
if null { if null {
return errors.New("Unexpected NULL") return errors.New("Unexpected NULL")
} }
@ -37,8 +37,8 @@ func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) {
func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) { func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) {
ints = make([]int32, 0, 8) ints = make([]int32, 0, 8)
onDataRow := func(r *MessageReader, fields []FieldDescription) (parseError error) { onDataRow := func(r *DataRowReader) (parseError error) {
s, null := c.rxDataRowFirstValue(r) s, null := c.rxDataRowFirstValue(r.mr)
if null { if null {
return errors.New("Unexpected NULL") return errors.New("Unexpected NULL")
} }
@ -53,8 +53,8 @@ func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) {
func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) { func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) {
ints = make([]int16, 0, 8) ints = make([]int16, 0, 8)
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { onDataRow := func(r *DataRowReader) (parseError error) {
s, null := c.rxDataRowFirstValue(r) s, null := c.rxDataRowFirstValue(r.mr)
if null { if null {
return errors.New("Unexpected NULL") return errors.New("Unexpected NULL")
} }
@ -69,8 +69,8 @@ func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) {
func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) { func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) {
floats = make([]float64, 0, 8) floats = make([]float64, 0, 8)
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { onDataRow := func(r *DataRowReader) (parseError error) {
s, null := c.rxDataRowFirstValue(r) s, null := c.rxDataRowFirstValue(r.mr)
if null { if null {
return errors.New("Unexpected NULL") return errors.New("Unexpected NULL")
} }
@ -85,8 +85,8 @@ func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error)
func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) { func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) {
floats = make([]float32, 0, 8) floats = make([]float32, 0, 8)
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { onDataRow := func(r *DataRowReader) (parseError error) {
s, null := c.rxDataRowFirstValue(r) s, null := c.rxDataRowFirstValue(r.mr)
if null { if null {
return errors.New("Unexpected NULL") return errors.New("Unexpected NULL")
} }

View File

@ -6,9 +6,9 @@ import (
) )
func (c *Connection) SelectString(sql string) (s string, err error) { func (c *Connection) SelectString(sql string) (s string, err error) {
onDataRow := func(r *MessageReader, _ []FieldDescription) error { onDataRow := func(r *DataRowReader) error {
var null bool var null bool
s, null = c.rxDataRowFirstValue(r) s, null = c.rxDataRowFirstValue(r.mr)
if null { if null {
return errors.New("Unexpected NULL") return errors.New("Unexpected NULL")
} }

View File

@ -129,9 +129,10 @@ func TestExecute(t *testing.T) {
func TestSelectFunc(t *testing.T) { func TestSelectFunc(t *testing.T) {
conn := getSharedConnection() conn := getSharedConnection()
rowCount := 0 var sum, rowCount int32
onDataRow := func(r *MessageReader, fields []FieldDescription) error { onDataRow := func(r *DataRowReader) error {
rowCount++ rowCount++
sum += r.ReadInt32()
return nil return nil
} }
@ -140,7 +141,10 @@ func TestSelectFunc(t *testing.T) {
t.Fatal("Select failed: " + err.Error()) t.Fatal("Select failed: " + err.Error())
} }
if rowCount != 10 { if rowCount != 10 {
t.Fatal("Select called onDataRow wrong number of times") t.Error("Select called onDataRow wrong number of times")
}
if sum != 55 {
t.Error("Wrong values returned")
} }
} }

97
data_row_reader.go Normal file
View File

@ -0,0 +1,97 @@
package pgx
import (
"strconv"
)
type DataRowReader struct {
mr *MessageReader
fields []FieldDescription
}
func newDataRowReader(mr *MessageReader, fields []FieldDescription) (r *DataRowReader) {
r = new(DataRowReader)
r.mr = mr
r.fields = fields
fieldCount := int(mr.ReadInt16())
if fieldCount != len(fields) {
panic("Row description field count and data row field count do not match")
}
return
}
func (r *DataRowReader) ReadString() string {
size := r.mr.ReadInt32()
if size == -1 {
panic("Unexpected null")
}
return r.mr.ReadByteString(size)
}
func (r *DataRowReader) ReadInt64() int64 {
size := r.mr.ReadInt32()
if size == -1 {
panic("Unexpected null")
}
i64, err := strconv.ParseInt(r.mr.ReadByteString(size), 10, 64)
if err != nil {
panic("Number too large")
}
return i64
}
func (r *DataRowReader) ReadInt32() int32 {
size := r.mr.ReadInt32()
if size == -1 {
panic("Unexpected null")
}
i64, err := strconv.ParseInt(r.mr.ReadByteString(size), 10, 32)
if err != nil {
panic("Number too large")
}
return int32(i64)
}
func (r *DataRowReader) ReadInt16() int16 {
size := r.mr.ReadInt32()
if size == -1 {
panic("Unexpected null")
}
i64, err := strconv.ParseInt(r.mr.ReadByteString(size), 10, 16)
if err != nil {
panic("Number too large")
}
return int16(i64)
}
func (r *DataRowReader) ReadFloat64() float64 {
size := r.mr.ReadInt32()
if size == -1 {
panic("Unexpected null")
}
f64, err := strconv.ParseFloat(r.mr.ReadByteString(size), 64)
if err != nil {
panic("Number too large")
}
return f64
}
func (r *DataRowReader) ReadFloat32() float32 {
size := r.mr.ReadInt32()
if size == -1 {
panic("Unexpected null")
}
f64, err := strconv.ParseFloat(r.mr.ReadByteString(size), 32)
if err != nil {
panic("Number too large")
}
return float32(f64)
}

117
data_row_reader_test.go Normal file
View File

@ -0,0 +1,117 @@
package pgx
import (
"testing"
)
func TestDataRowReaderReadString(t *testing.T) {
conn := getSharedConnection()
var s string
onDataRow := func(r *DataRowReader) error {
s = r.ReadString()
return nil
}
err := conn.SelectFunc("select 'Jack'", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if s != "Jack" {
t.Error("Wrong value returned")
}
}
func TestDataRowReaderReadInt64(t *testing.T) {
conn := getSharedConnection()
var n int64
onDataRow := func(r *DataRowReader) error {
n = r.ReadInt64()
return nil
}
err := conn.SelectFunc("select 1", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if n != 1 {
t.Error("Wrong value returned")
}
}
func TestDataRowReaderReadInt32(t *testing.T) {
conn := getSharedConnection()
var n int32
onDataRow := func(r *DataRowReader) error {
n = r.ReadInt32()
return nil
}
err := conn.SelectFunc("select 1", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if n != 1 {
t.Error("Wrong value returned")
}
}
func TestDataRowReaderReadInt16(t *testing.T) {
conn := getSharedConnection()
var n int16
onDataRow := func(r *DataRowReader) error {
n = r.ReadInt16()
return nil
}
err := conn.SelectFunc("select 1", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if n != 1 {
t.Error("Wrong value returned")
}
}
func TestDataRowReaderReadFloat64(t *testing.T) {
conn := getSharedConnection()
var n float64
onDataRow := func(r *DataRowReader) error {
n = r.ReadFloat64()
return nil
}
err := conn.SelectFunc("select 1.5", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if n != 1.5 {
t.Error("Wrong value returned")
}
}
func TestDataRowReaderReadFloat32(t *testing.T) {
conn := getSharedConnection()
var n float32
onDataRow := func(r *DataRowReader) error {
n = r.ReadFloat32()
return nil
}
err := conn.SelectFunc("select 1.5", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if n != 1.5 {
t.Error("Wrong value returned")
}
}

View File

@ -0,0 +1,63 @@
package pgx
import (
"testing"
)
func TestDataRowReaderReadString(t *testing.T) {
conn := getSharedConnection()
var s string
onDataRow := func(r *DataRowReader) error {
s = r.ReadString()
return nil
}
err := conn.SelectFunc("select 'Jack'", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if s != "Jack" {
t.Error("Wrong value returned")
}
}
<% [64, 32, 16].each do |size| %>
func TestDataRowReaderReadInt<%= size %>(t *testing.T) {
conn := getSharedConnection()
var n int<%= size %>
onDataRow := func(r *DataRowReader) error {
n = r.ReadInt<%= size %>()
return nil
}
err := conn.SelectFunc("select 1", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if n != 1 {
t.Error("Wrong value returned")
}
}
<% end %>
<% [64, 32].each do |size| %>
func TestDataRowReaderReadFloat<%= size %>(t *testing.T) {
conn := getSharedConnection()
var n float<%= size %>
onDataRow := func(r *DataRowReader) error {
n = r.ReadFloat<%= size %>()
return nil
}
err := conn.SelectFunc("select 1.5", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if n != 1.5 {
t.Error("Wrong value returned")
}
}
<% end %>