From 36e4d74d30342bc972872d6ac1234262c94ef333 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 1 May 2013 08:51:09 -0500 Subject: [PATCH] Add DataRowReader and change Connection.SelectFunc to use it Preparatory step for structure binding. refs #11 --- connection.go | 21 +++---- connection_select_column.go | 24 ++++---- connection_select_value.go | 4 +- connection_test.go | 10 ++- data_row_reader.go | 97 ++++++++++++++++++++++++++++++ data_row_reader_test.go | 117 ++++++++++++++++++++++++++++++++++++ data_row_reader_test.go.erb | 63 +++++++++++++++++++ 7 files changed, 308 insertions(+), 28 deletions(-) create mode 100644 data_row_reader.go create mode 100644 data_row_reader_test.go create mode 100644 data_row_reader_test.go.erb diff --git a/connection.go b/connection.go index b6e66558..93c35a53 100644 --- a/connection.go +++ b/connection.go @@ -93,7 +93,7 @@ func (c *Connection) Close() (err error) { return } -func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []FieldDescription) error) (err error) { +func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error) (err error) { if err = c.sendSimpleQuery(sql); err != nil { return } @@ -115,7 +115,7 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []Fie fields = c.rxRowDescription(r) case dataRow: if callbackError == nil { - callbackError = onDataRow(r, fields) + callbackError = onDataRow(newDataRowReader(r, fields)) } case commandComplete: default: @@ -137,8 +137,8 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []Fie // pattern when accessing the map func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) { rows = make([]map[string]string, 0, 8) - onDataRow := func(r *MessageReader, fields []FieldDescription) error { - rows = append(rows, c.rxDataRow(r, fields)) + onDataRow := func(r *DataRowReader) error { + rows = append(rows, c.rxDataRow(r)) return nil } err = c.SelectFunc(sql, onDataRow) @@ -312,22 +312,21 @@ func (c *Connection) rxRowDescription(r *MessageReader) (fields []FieldDescripti return } -func (c *Connection) rxDataRow(r *MessageReader, fields []FieldDescription) (row map[string]string) { - fieldCount := r.ReadInt16() +func (c *Connection) rxDataRow(r *DataRowReader) (row map[string]string) { + fieldCount := len(r.fields) + mr := r.mr row = make(map[string]string, fieldCount) - for i := int16(0); i < fieldCount; i++ { - size := r.ReadInt32() + for i := 0; i < fieldCount; i++ { + size := mr.ReadInt32() if size > -1 { - row[fields[i].Name] = r.ReadByteString(size) + row[r.fields[i].Name] = mr.ReadByteString(size) } } return } func (c *Connection) rxDataRowFirstValue(r *MessageReader) (s string, null bool) { - r.ReadInt16() // ignore field count - size := r.ReadInt32() if size > -1 { s = r.ReadByteString(size) diff --git a/connection_select_column.go b/connection_select_column.go index 7d6442fb..442e028e 100644 --- a/connection_select_column.go +++ b/connection_select_column.go @@ -7,8 +7,8 @@ import ( func (c *Connection) SelectAllString(sql string) (strings []string, err error) { strings = make([]string, 0, 8) - onDataRow := func(r *MessageReader, _ []FieldDescription) error { - s, null := c.rxDataRowFirstValue(r) + onDataRow := func(r *DataRowReader) error { + s, null := c.rxDataRowFirstValue(r.mr) if null { return errors.New("Unexpected NULL") } @@ -21,8 +21,8 @@ func (c *Connection) SelectAllString(sql string) (strings []string, err error) { func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) { ints = make([]int64, 0, 8) - onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { - s, null := c.rxDataRowFirstValue(r) + onDataRow := func(r *DataRowReader) (parseError error) { + s, null := c.rxDataRowFirstValue(r.mr) if null { return errors.New("Unexpected NULL") } @@ -37,8 +37,8 @@ func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) { func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) { ints = make([]int32, 0, 8) - onDataRow := func(r *MessageReader, fields []FieldDescription) (parseError error) { - s, null := c.rxDataRowFirstValue(r) + onDataRow := func(r *DataRowReader) (parseError error) { + s, null := c.rxDataRowFirstValue(r.mr) if null { return errors.New("Unexpected NULL") } @@ -53,8 +53,8 @@ func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) { func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) { ints = make([]int16, 0, 8) - onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { - s, null := c.rxDataRowFirstValue(r) + onDataRow := func(r *DataRowReader) (parseError error) { + s, null := c.rxDataRowFirstValue(r.mr) if null { return errors.New("Unexpected NULL") } @@ -69,8 +69,8 @@ func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) { func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) { floats = make([]float64, 0, 8) - onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { - s, null := c.rxDataRowFirstValue(r) + onDataRow := func(r *DataRowReader) (parseError error) { + s, null := c.rxDataRowFirstValue(r.mr) if null { return errors.New("Unexpected NULL") } @@ -85,8 +85,8 @@ func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) { floats = make([]float32, 0, 8) - onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) { - s, null := c.rxDataRowFirstValue(r) + onDataRow := func(r *DataRowReader) (parseError error) { + s, null := c.rxDataRowFirstValue(r.mr) if null { return errors.New("Unexpected NULL") } diff --git a/connection_select_value.go b/connection_select_value.go index d6516715..ebe93099 100644 --- a/connection_select_value.go +++ b/connection_select_value.go @@ -6,9 +6,9 @@ import ( ) func (c *Connection) SelectString(sql string) (s string, err error) { - onDataRow := func(r *MessageReader, _ []FieldDescription) error { + onDataRow := func(r *DataRowReader) error { var null bool - s, null = c.rxDataRowFirstValue(r) + s, null = c.rxDataRowFirstValue(r.mr) if null { return errors.New("Unexpected NULL") } diff --git a/connection_test.go b/connection_test.go index 50415157..95e4db23 100644 --- a/connection_test.go +++ b/connection_test.go @@ -129,9 +129,10 @@ func TestExecute(t *testing.T) { func TestSelectFunc(t *testing.T) { conn := getSharedConnection() - rowCount := 0 - onDataRow := func(r *MessageReader, fields []FieldDescription) error { + var sum, rowCount int32 + onDataRow := func(r *DataRowReader) error { rowCount++ + sum += r.ReadInt32() return nil } @@ -140,7 +141,10 @@ func TestSelectFunc(t *testing.T) { t.Fatal("Select failed: " + err.Error()) } if rowCount != 10 { - t.Fatal("Select called onDataRow wrong number of times") + t.Error("Select called onDataRow wrong number of times") + } + if sum != 55 { + t.Error("Wrong values returned") } } diff --git a/data_row_reader.go b/data_row_reader.go new file mode 100644 index 00000000..c6151c53 --- /dev/null +++ b/data_row_reader.go @@ -0,0 +1,97 @@ +package pgx + +import ( + "strconv" +) + +type DataRowReader struct { + mr *MessageReader + fields []FieldDescription +} + +func newDataRowReader(mr *MessageReader, fields []FieldDescription) (r *DataRowReader) { + r = new(DataRowReader) + r.mr = mr + r.fields = fields + + fieldCount := int(mr.ReadInt16()) + if fieldCount != len(fields) { + panic("Row description field count and data row field count do not match") + } + + return +} + +func (r *DataRowReader) ReadString() string { + size := r.mr.ReadInt32() + if size == -1 { + panic("Unexpected null") + } + + return r.mr.ReadByteString(size) +} + +func (r *DataRowReader) ReadInt64() int64 { + size := r.mr.ReadInt32() + if size == -1 { + panic("Unexpected null") + } + + i64, err := strconv.ParseInt(r.mr.ReadByteString(size), 10, 64) + if err != nil { + panic("Number too large") + } + return i64 +} + +func (r *DataRowReader) ReadInt32() int32 { + size := r.mr.ReadInt32() + if size == -1 { + panic("Unexpected null") + } + + i64, err := strconv.ParseInt(r.mr.ReadByteString(size), 10, 32) + if err != nil { + panic("Number too large") + } + return int32(i64) +} + +func (r *DataRowReader) ReadInt16() int16 { + size := r.mr.ReadInt32() + if size == -1 { + panic("Unexpected null") + } + + i64, err := strconv.ParseInt(r.mr.ReadByteString(size), 10, 16) + if err != nil { + panic("Number too large") + } + return int16(i64) +} + +func (r *DataRowReader) ReadFloat64() float64 { + size := r.mr.ReadInt32() + if size == -1 { + panic("Unexpected null") + } + + f64, err := strconv.ParseFloat(r.mr.ReadByteString(size), 64) + if err != nil { + panic("Number too large") + } + return f64 +} + +func (r *DataRowReader) ReadFloat32() float32 { + size := r.mr.ReadInt32() + if size == -1 { + panic("Unexpected null") + } + + f64, err := strconv.ParseFloat(r.mr.ReadByteString(size), 32) + if err != nil { + panic("Number too large") + } + return float32(f64) +} diff --git a/data_row_reader_test.go b/data_row_reader_test.go new file mode 100644 index 00000000..c4ee6600 --- /dev/null +++ b/data_row_reader_test.go @@ -0,0 +1,117 @@ +package pgx + +import ( + "testing" +) + +func TestDataRowReaderReadString(t *testing.T) { + conn := getSharedConnection() + + var s string + onDataRow := func(r *DataRowReader) error { + s = r.ReadString() + return nil + } + + err := conn.SelectFunc("select 'Jack'", onDataRow) + if err != nil { + t.Fatal("Select failed: " + err.Error()) + } + if s != "Jack" { + t.Error("Wrong value returned") + } +} + + +func TestDataRowReaderReadInt64(t *testing.T) { + conn := getSharedConnection() + + var n int64 + onDataRow := func(r *DataRowReader) error { + n = r.ReadInt64() + return nil + } + + err := conn.SelectFunc("select 1", onDataRow) + if err != nil { + t.Fatal("Select failed: " + err.Error()) + } + if n != 1 { + t.Error("Wrong value returned") + } +} + +func TestDataRowReaderReadInt32(t *testing.T) { + conn := getSharedConnection() + + var n int32 + onDataRow := func(r *DataRowReader) error { + n = r.ReadInt32() + return nil + } + + err := conn.SelectFunc("select 1", onDataRow) + if err != nil { + t.Fatal("Select failed: " + err.Error()) + } + if n != 1 { + t.Error("Wrong value returned") + } +} + +func TestDataRowReaderReadInt16(t *testing.T) { + conn := getSharedConnection() + + var n int16 + onDataRow := func(r *DataRowReader) error { + n = r.ReadInt16() + return nil + } + + err := conn.SelectFunc("select 1", onDataRow) + if err != nil { + t.Fatal("Select failed: " + err.Error()) + } + if n != 1 { + t.Error("Wrong value returned") + } +} + + + +func TestDataRowReaderReadFloat64(t *testing.T) { + conn := getSharedConnection() + + var n float64 + onDataRow := func(r *DataRowReader) error { + n = r.ReadFloat64() + return nil + } + + err := conn.SelectFunc("select 1.5", onDataRow) + if err != nil { + t.Fatal("Select failed: " + err.Error()) + } + if n != 1.5 { + t.Error("Wrong value returned") + } +} + +func TestDataRowReaderReadFloat32(t *testing.T) { + conn := getSharedConnection() + + var n float32 + onDataRow := func(r *DataRowReader) error { + n = r.ReadFloat32() + return nil + } + + err := conn.SelectFunc("select 1.5", onDataRow) + if err != nil { + t.Fatal("Select failed: " + err.Error()) + } + if n != 1.5 { + t.Error("Wrong value returned") + } +} + diff --git a/data_row_reader_test.go.erb b/data_row_reader_test.go.erb new file mode 100644 index 00000000..f5d75a4b --- /dev/null +++ b/data_row_reader_test.go.erb @@ -0,0 +1,63 @@ +package pgx + +import ( + "testing" +) + +func TestDataRowReaderReadString(t *testing.T) { + conn := getSharedConnection() + + var s string + onDataRow := func(r *DataRowReader) error { + s = r.ReadString() + return nil + } + + err := conn.SelectFunc("select 'Jack'", onDataRow) + if err != nil { + t.Fatal("Select failed: " + err.Error()) + } + if s != "Jack" { + t.Error("Wrong value returned") + } +} + +<% [64, 32, 16].each do |size| %> +func TestDataRowReaderReadInt<%= size %>(t *testing.T) { + conn := getSharedConnection() + + var n int<%= size %> + onDataRow := func(r *DataRowReader) error { + n = r.ReadInt<%= size %>() + return nil + } + + err := conn.SelectFunc("select 1", onDataRow) + if err != nil { + t.Fatal("Select failed: " + err.Error()) + } + if n != 1 { + t.Error("Wrong value returned") + } +} +<% end %> + +<% [64, 32].each do |size| %> +func TestDataRowReaderReadFloat<%= size %>(t *testing.T) { + conn := getSharedConnection() + + var n float<%= size %> + onDataRow := func(r *DataRowReader) error { + n = r.ReadFloat<%= size %>() + return nil + } + + err := conn.SelectFunc("select 1.5", onDataRow) + if err != nil { + t.Fatal("Select failed: " + err.Error()) + } + if n != 1.5 { + t.Error("Wrong value returned") + } +} +<% end %>