Add DataRowReader and change Connection.SelectFunc to use it

Preparatory step for structure binding. refs #11
pgx-vs-pq
Jack Christensen 2013-05-01 08:51:09 -05:00
parent 78590be058
commit 36e4d74d30
7 changed files with 308 additions and 28 deletions

View File

@ -93,7 +93,7 @@ func (c *Connection) Close() (err error) {
return
}
func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []FieldDescription) error) (err error) {
func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error) (err error) {
if err = c.sendSimpleQuery(sql); err != nil {
return
}
@ -115,7 +115,7 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []Fie
fields = c.rxRowDescription(r)
case dataRow:
if callbackError == nil {
callbackError = onDataRow(r, fields)
callbackError = onDataRow(newDataRowReader(r, fields))
}
case commandComplete:
default:
@ -137,8 +137,8 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []Fie
// pattern when accessing the map
func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) {
rows = make([]map[string]string, 0, 8)
onDataRow := func(r *MessageReader, fields []FieldDescription) error {
rows = append(rows, c.rxDataRow(r, fields))
onDataRow := func(r *DataRowReader) error {
rows = append(rows, c.rxDataRow(r))
return nil
}
err = c.SelectFunc(sql, onDataRow)
@ -312,22 +312,21 @@ func (c *Connection) rxRowDescription(r *MessageReader) (fields []FieldDescripti
return
}
func (c *Connection) rxDataRow(r *MessageReader, fields []FieldDescription) (row map[string]string) {
fieldCount := r.ReadInt16()
func (c *Connection) rxDataRow(r *DataRowReader) (row map[string]string) {
fieldCount := len(r.fields)
mr := r.mr
row = make(map[string]string, fieldCount)
for i := int16(0); i < fieldCount; i++ {
size := r.ReadInt32()
for i := 0; i < fieldCount; i++ {
size := mr.ReadInt32()
if size > -1 {
row[fields[i].Name] = r.ReadByteString(size)
row[r.fields[i].Name] = mr.ReadByteString(size)
}
}
return
}
func (c *Connection) rxDataRowFirstValue(r *MessageReader) (s string, null bool) {
r.ReadInt16() // ignore field count
size := r.ReadInt32()
if size > -1 {
s = r.ReadByteString(size)

View File

@ -7,8 +7,8 @@ import (
func (c *Connection) SelectAllString(sql string) (strings []string, err error) {
strings = make([]string, 0, 8)
onDataRow := func(r *MessageReader, _ []FieldDescription) error {
s, null := c.rxDataRowFirstValue(r)
onDataRow := func(r *DataRowReader) error {
s, null := c.rxDataRowFirstValue(r.mr)
if null {
return errors.New("Unexpected NULL")
}
@ -21,8 +21,8 @@ func (c *Connection) SelectAllString(sql string) (strings []string, err error) {
func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) {
ints = make([]int64, 0, 8)
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
s, null := c.rxDataRowFirstValue(r)
onDataRow := func(r *DataRowReader) (parseError error) {
s, null := c.rxDataRowFirstValue(r.mr)
if null {
return errors.New("Unexpected NULL")
}
@ -37,8 +37,8 @@ func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) {
func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) {
ints = make([]int32, 0, 8)
onDataRow := func(r *MessageReader, fields []FieldDescription) (parseError error) {
s, null := c.rxDataRowFirstValue(r)
onDataRow := func(r *DataRowReader) (parseError error) {
s, null := c.rxDataRowFirstValue(r.mr)
if null {
return errors.New("Unexpected NULL")
}
@ -53,8 +53,8 @@ func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) {
func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) {
ints = make([]int16, 0, 8)
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
s, null := c.rxDataRowFirstValue(r)
onDataRow := func(r *DataRowReader) (parseError error) {
s, null := c.rxDataRowFirstValue(r.mr)
if null {
return errors.New("Unexpected NULL")
}
@ -69,8 +69,8 @@ func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) {
func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) {
floats = make([]float64, 0, 8)
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
s, null := c.rxDataRowFirstValue(r)
onDataRow := func(r *DataRowReader) (parseError error) {
s, null := c.rxDataRowFirstValue(r.mr)
if null {
return errors.New("Unexpected NULL")
}
@ -85,8 +85,8 @@ func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error)
func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) {
floats = make([]float32, 0, 8)
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
s, null := c.rxDataRowFirstValue(r)
onDataRow := func(r *DataRowReader) (parseError error) {
s, null := c.rxDataRowFirstValue(r.mr)
if null {
return errors.New("Unexpected NULL")
}

View File

@ -6,9 +6,9 @@ import (
)
func (c *Connection) SelectString(sql string) (s string, err error) {
onDataRow := func(r *MessageReader, _ []FieldDescription) error {
onDataRow := func(r *DataRowReader) error {
var null bool
s, null = c.rxDataRowFirstValue(r)
s, null = c.rxDataRowFirstValue(r.mr)
if null {
return errors.New("Unexpected NULL")
}

View File

@ -129,9 +129,10 @@ func TestExecute(t *testing.T) {
func TestSelectFunc(t *testing.T) {
conn := getSharedConnection()
rowCount := 0
onDataRow := func(r *MessageReader, fields []FieldDescription) error {
var sum, rowCount int32
onDataRow := func(r *DataRowReader) error {
rowCount++
sum += r.ReadInt32()
return nil
}
@ -140,7 +141,10 @@ func TestSelectFunc(t *testing.T) {
t.Fatal("Select failed: " + err.Error())
}
if rowCount != 10 {
t.Fatal("Select called onDataRow wrong number of times")
t.Error("Select called onDataRow wrong number of times")
}
if sum != 55 {
t.Error("Wrong values returned")
}
}

97
data_row_reader.go Normal file
View File

@ -0,0 +1,97 @@
package pgx
import (
"strconv"
)
type DataRowReader struct {
mr *MessageReader
fields []FieldDescription
}
func newDataRowReader(mr *MessageReader, fields []FieldDescription) (r *DataRowReader) {
r = new(DataRowReader)
r.mr = mr
r.fields = fields
fieldCount := int(mr.ReadInt16())
if fieldCount != len(fields) {
panic("Row description field count and data row field count do not match")
}
return
}
func (r *DataRowReader) ReadString() string {
size := r.mr.ReadInt32()
if size == -1 {
panic("Unexpected null")
}
return r.mr.ReadByteString(size)
}
func (r *DataRowReader) ReadInt64() int64 {
size := r.mr.ReadInt32()
if size == -1 {
panic("Unexpected null")
}
i64, err := strconv.ParseInt(r.mr.ReadByteString(size), 10, 64)
if err != nil {
panic("Number too large")
}
return i64
}
func (r *DataRowReader) ReadInt32() int32 {
size := r.mr.ReadInt32()
if size == -1 {
panic("Unexpected null")
}
i64, err := strconv.ParseInt(r.mr.ReadByteString(size), 10, 32)
if err != nil {
panic("Number too large")
}
return int32(i64)
}
func (r *DataRowReader) ReadInt16() int16 {
size := r.mr.ReadInt32()
if size == -1 {
panic("Unexpected null")
}
i64, err := strconv.ParseInt(r.mr.ReadByteString(size), 10, 16)
if err != nil {
panic("Number too large")
}
return int16(i64)
}
func (r *DataRowReader) ReadFloat64() float64 {
size := r.mr.ReadInt32()
if size == -1 {
panic("Unexpected null")
}
f64, err := strconv.ParseFloat(r.mr.ReadByteString(size), 64)
if err != nil {
panic("Number too large")
}
return f64
}
func (r *DataRowReader) ReadFloat32() float32 {
size := r.mr.ReadInt32()
if size == -1 {
panic("Unexpected null")
}
f64, err := strconv.ParseFloat(r.mr.ReadByteString(size), 32)
if err != nil {
panic("Number too large")
}
return float32(f64)
}

117
data_row_reader_test.go Normal file
View File

@ -0,0 +1,117 @@
package pgx
import (
"testing"
)
func TestDataRowReaderReadString(t *testing.T) {
conn := getSharedConnection()
var s string
onDataRow := func(r *DataRowReader) error {
s = r.ReadString()
return nil
}
err := conn.SelectFunc("select 'Jack'", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if s != "Jack" {
t.Error("Wrong value returned")
}
}
func TestDataRowReaderReadInt64(t *testing.T) {
conn := getSharedConnection()
var n int64
onDataRow := func(r *DataRowReader) error {
n = r.ReadInt64()
return nil
}
err := conn.SelectFunc("select 1", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if n != 1 {
t.Error("Wrong value returned")
}
}
func TestDataRowReaderReadInt32(t *testing.T) {
conn := getSharedConnection()
var n int32
onDataRow := func(r *DataRowReader) error {
n = r.ReadInt32()
return nil
}
err := conn.SelectFunc("select 1", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if n != 1 {
t.Error("Wrong value returned")
}
}
func TestDataRowReaderReadInt16(t *testing.T) {
conn := getSharedConnection()
var n int16
onDataRow := func(r *DataRowReader) error {
n = r.ReadInt16()
return nil
}
err := conn.SelectFunc("select 1", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if n != 1 {
t.Error("Wrong value returned")
}
}
func TestDataRowReaderReadFloat64(t *testing.T) {
conn := getSharedConnection()
var n float64
onDataRow := func(r *DataRowReader) error {
n = r.ReadFloat64()
return nil
}
err := conn.SelectFunc("select 1.5", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if n != 1.5 {
t.Error("Wrong value returned")
}
}
func TestDataRowReaderReadFloat32(t *testing.T) {
conn := getSharedConnection()
var n float32
onDataRow := func(r *DataRowReader) error {
n = r.ReadFloat32()
return nil
}
err := conn.SelectFunc("select 1.5", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if n != 1.5 {
t.Error("Wrong value returned")
}
}

View File

@ -0,0 +1,63 @@
package pgx
import (
"testing"
)
func TestDataRowReaderReadString(t *testing.T) {
conn := getSharedConnection()
var s string
onDataRow := func(r *DataRowReader) error {
s = r.ReadString()
return nil
}
err := conn.SelectFunc("select 'Jack'", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if s != "Jack" {
t.Error("Wrong value returned")
}
}
<% [64, 32, 16].each do |size| %>
func TestDataRowReaderReadInt<%= size %>(t *testing.T) {
conn := getSharedConnection()
var n int<%= size %>
onDataRow := func(r *DataRowReader) error {
n = r.ReadInt<%= size %>()
return nil
}
err := conn.SelectFunc("select 1", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if n != 1 {
t.Error("Wrong value returned")
}
}
<% end %>
<% [64, 32].each do |size| %>
func TestDataRowReaderReadFloat<%= size %>(t *testing.T) {
conn := getSharedConnection()
var n float<%= size %>
onDataRow := func(r *DataRowReader) error {
n = r.ReadFloat<%= size %>()
return nil
}
err := conn.SelectFunc("select 1.5", onDataRow)
if err != nil {
t.Fatal("Select failed: " + err.Error())
}
if n != 1.5 {
t.Error("Wrong value returned")
}
}
<% end %>