mirror of https://github.com/jackc/pgx.git
Add DataRowReader and change Connection.SelectFunc to use it
Preparatory step for structure binding. refs #11pgx-vs-pq
parent
78590be058
commit
36e4d74d30
|
@ -93,7 +93,7 @@ func (c *Connection) Close() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []FieldDescription) error) (err error) {
|
||||
func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error) (err error) {
|
||||
if err = c.sendSimpleQuery(sql); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -115,7 +115,7 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []Fie
|
|||
fields = c.rxRowDescription(r)
|
||||
case dataRow:
|
||||
if callbackError == nil {
|
||||
callbackError = onDataRow(r, fields)
|
||||
callbackError = onDataRow(newDataRowReader(r, fields))
|
||||
}
|
||||
case commandComplete:
|
||||
default:
|
||||
|
@ -137,8 +137,8 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []Fie
|
|||
// pattern when accessing the map
|
||||
func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) {
|
||||
rows = make([]map[string]string, 0, 8)
|
||||
onDataRow := func(r *MessageReader, fields []FieldDescription) error {
|
||||
rows = append(rows, c.rxDataRow(r, fields))
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
rows = append(rows, c.rxDataRow(r))
|
||||
return nil
|
||||
}
|
||||
err = c.SelectFunc(sql, onDataRow)
|
||||
|
@ -312,22 +312,21 @@ func (c *Connection) rxRowDescription(r *MessageReader) (fields []FieldDescripti
|
|||
return
|
||||
}
|
||||
|
||||
func (c *Connection) rxDataRow(r *MessageReader, fields []FieldDescription) (row map[string]string) {
|
||||
fieldCount := r.ReadInt16()
|
||||
func (c *Connection) rxDataRow(r *DataRowReader) (row map[string]string) {
|
||||
fieldCount := len(r.fields)
|
||||
mr := r.mr
|
||||
|
||||
row = make(map[string]string, fieldCount)
|
||||
for i := int16(0); i < fieldCount; i++ {
|
||||
size := r.ReadInt32()
|
||||
for i := 0; i < fieldCount; i++ {
|
||||
size := mr.ReadInt32()
|
||||
if size > -1 {
|
||||
row[fields[i].Name] = r.ReadByteString(size)
|
||||
row[r.fields[i].Name] = mr.ReadByteString(size)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) rxDataRowFirstValue(r *MessageReader) (s string, null bool) {
|
||||
r.ReadInt16() // ignore field count
|
||||
|
||||
size := r.ReadInt32()
|
||||
if size > -1 {
|
||||
s = r.ReadByteString(size)
|
||||
|
|
|
@ -7,8 +7,8 @@ import (
|
|||
|
||||
func (c *Connection) SelectAllString(sql string) (strings []string, err error) {
|
||||
strings = make([]string, 0, 8)
|
||||
onDataRow := func(r *MessageReader, _ []FieldDescription) error {
|
||||
s, null := c.rxDataRowFirstValue(r)
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
s, null := c.rxDataRowFirstValue(r.mr)
|
||||
if null {
|
||||
return errors.New("Unexpected NULL")
|
||||
}
|
||||
|
@ -21,8 +21,8 @@ func (c *Connection) SelectAllString(sql string) (strings []string, err error) {
|
|||
|
||||
func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) {
|
||||
ints = make([]int64, 0, 8)
|
||||
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r)
|
||||
onDataRow := func(r *DataRowReader) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r.mr)
|
||||
if null {
|
||||
return errors.New("Unexpected NULL")
|
||||
}
|
||||
|
@ -37,8 +37,8 @@ func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) {
|
|||
|
||||
func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) {
|
||||
ints = make([]int32, 0, 8)
|
||||
onDataRow := func(r *MessageReader, fields []FieldDescription) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r)
|
||||
onDataRow := func(r *DataRowReader) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r.mr)
|
||||
if null {
|
||||
return errors.New("Unexpected NULL")
|
||||
}
|
||||
|
@ -53,8 +53,8 @@ func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) {
|
|||
|
||||
func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) {
|
||||
ints = make([]int16, 0, 8)
|
||||
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r)
|
||||
onDataRow := func(r *DataRowReader) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r.mr)
|
||||
if null {
|
||||
return errors.New("Unexpected NULL")
|
||||
}
|
||||
|
@ -69,8 +69,8 @@ func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) {
|
|||
|
||||
func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) {
|
||||
floats = make([]float64, 0, 8)
|
||||
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r)
|
||||
onDataRow := func(r *DataRowReader) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r.mr)
|
||||
if null {
|
||||
return errors.New("Unexpected NULL")
|
||||
}
|
||||
|
@ -85,8 +85,8 @@ func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error)
|
|||
|
||||
func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) {
|
||||
floats = make([]float32, 0, 8)
|
||||
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r)
|
||||
onDataRow := func(r *DataRowReader) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r.mr)
|
||||
if null {
|
||||
return errors.New("Unexpected NULL")
|
||||
}
|
||||
|
|
|
@ -6,9 +6,9 @@ import (
|
|||
)
|
||||
|
||||
func (c *Connection) SelectString(sql string) (s string, err error) {
|
||||
onDataRow := func(r *MessageReader, _ []FieldDescription) error {
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
var null bool
|
||||
s, null = c.rxDataRowFirstValue(r)
|
||||
s, null = c.rxDataRowFirstValue(r.mr)
|
||||
if null {
|
||||
return errors.New("Unexpected NULL")
|
||||
}
|
||||
|
|
|
@ -129,9 +129,10 @@ func TestExecute(t *testing.T) {
|
|||
func TestSelectFunc(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
rowCount := 0
|
||||
onDataRow := func(r *MessageReader, fields []FieldDescription) error {
|
||||
var sum, rowCount int32
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
rowCount++
|
||||
sum += r.ReadInt32()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -140,7 +141,10 @@ func TestSelectFunc(t *testing.T) {
|
|||
t.Fatal("Select failed: " + err.Error())
|
||||
}
|
||||
if rowCount != 10 {
|
||||
t.Fatal("Select called onDataRow wrong number of times")
|
||||
t.Error("Select called onDataRow wrong number of times")
|
||||
}
|
||||
if sum != 55 {
|
||||
t.Error("Wrong values returned")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type DataRowReader struct {
|
||||
mr *MessageReader
|
||||
fields []FieldDescription
|
||||
}
|
||||
|
||||
func newDataRowReader(mr *MessageReader, fields []FieldDescription) (r *DataRowReader) {
|
||||
r = new(DataRowReader)
|
||||
r.mr = mr
|
||||
r.fields = fields
|
||||
|
||||
fieldCount := int(mr.ReadInt16())
|
||||
if fieldCount != len(fields) {
|
||||
panic("Row description field count and data row field count do not match")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *DataRowReader) ReadString() string {
|
||||
size := r.mr.ReadInt32()
|
||||
if size == -1 {
|
||||
panic("Unexpected null")
|
||||
}
|
||||
|
||||
return r.mr.ReadByteString(size)
|
||||
}
|
||||
|
||||
func (r *DataRowReader) ReadInt64() int64 {
|
||||
size := r.mr.ReadInt32()
|
||||
if size == -1 {
|
||||
panic("Unexpected null")
|
||||
}
|
||||
|
||||
i64, err := strconv.ParseInt(r.mr.ReadByteString(size), 10, 64)
|
||||
if err != nil {
|
||||
panic("Number too large")
|
||||
}
|
||||
return i64
|
||||
}
|
||||
|
||||
func (r *DataRowReader) ReadInt32() int32 {
|
||||
size := r.mr.ReadInt32()
|
||||
if size == -1 {
|
||||
panic("Unexpected null")
|
||||
}
|
||||
|
||||
i64, err := strconv.ParseInt(r.mr.ReadByteString(size), 10, 32)
|
||||
if err != nil {
|
||||
panic("Number too large")
|
||||
}
|
||||
return int32(i64)
|
||||
}
|
||||
|
||||
func (r *DataRowReader) ReadInt16() int16 {
|
||||
size := r.mr.ReadInt32()
|
||||
if size == -1 {
|
||||
panic("Unexpected null")
|
||||
}
|
||||
|
||||
i64, err := strconv.ParseInt(r.mr.ReadByteString(size), 10, 16)
|
||||
if err != nil {
|
||||
panic("Number too large")
|
||||
}
|
||||
return int16(i64)
|
||||
}
|
||||
|
||||
func (r *DataRowReader) ReadFloat64() float64 {
|
||||
size := r.mr.ReadInt32()
|
||||
if size == -1 {
|
||||
panic("Unexpected null")
|
||||
}
|
||||
|
||||
f64, err := strconv.ParseFloat(r.mr.ReadByteString(size), 64)
|
||||
if err != nil {
|
||||
panic("Number too large")
|
||||
}
|
||||
return f64
|
||||
}
|
||||
|
||||
func (r *DataRowReader) ReadFloat32() float32 {
|
||||
size := r.mr.ReadInt32()
|
||||
if size == -1 {
|
||||
panic("Unexpected null")
|
||||
}
|
||||
|
||||
f64, err := strconv.ParseFloat(r.mr.ReadByteString(size), 32)
|
||||
if err != nil {
|
||||
panic("Number too large")
|
||||
}
|
||||
return float32(f64)
|
||||
}
|
|
@ -0,0 +1,117 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDataRowReaderReadString(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
var s string
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
s = r.ReadString()
|
||||
return nil
|
||||
}
|
||||
|
||||
err := conn.SelectFunc("select 'Jack'", onDataRow)
|
||||
if err != nil {
|
||||
t.Fatal("Select failed: " + err.Error())
|
||||
}
|
||||
if s != "Jack" {
|
||||
t.Error("Wrong value returned")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestDataRowReaderReadInt64(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
var n int64
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
n = r.ReadInt64()
|
||||
return nil
|
||||
}
|
||||
|
||||
err := conn.SelectFunc("select 1", onDataRow)
|
||||
if err != nil {
|
||||
t.Fatal("Select failed: " + err.Error())
|
||||
}
|
||||
if n != 1 {
|
||||
t.Error("Wrong value returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataRowReaderReadInt32(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
var n int32
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
n = r.ReadInt32()
|
||||
return nil
|
||||
}
|
||||
|
||||
err := conn.SelectFunc("select 1", onDataRow)
|
||||
if err != nil {
|
||||
t.Fatal("Select failed: " + err.Error())
|
||||
}
|
||||
if n != 1 {
|
||||
t.Error("Wrong value returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataRowReaderReadInt16(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
var n int16
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
n = r.ReadInt16()
|
||||
return nil
|
||||
}
|
||||
|
||||
err := conn.SelectFunc("select 1", onDataRow)
|
||||
if err != nil {
|
||||
t.Fatal("Select failed: " + err.Error())
|
||||
}
|
||||
if n != 1 {
|
||||
t.Error("Wrong value returned")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
func TestDataRowReaderReadFloat64(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
var n float64
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
n = r.ReadFloat64()
|
||||
return nil
|
||||
}
|
||||
|
||||
err := conn.SelectFunc("select 1.5", onDataRow)
|
||||
if err != nil {
|
||||
t.Fatal("Select failed: " + err.Error())
|
||||
}
|
||||
if n != 1.5 {
|
||||
t.Error("Wrong value returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataRowReaderReadFloat32(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
var n float32
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
n = r.ReadFloat32()
|
||||
return nil
|
||||
}
|
||||
|
||||
err := conn.SelectFunc("select 1.5", onDataRow)
|
||||
if err != nil {
|
||||
t.Fatal("Select failed: " + err.Error())
|
||||
}
|
||||
if n != 1.5 {
|
||||
t.Error("Wrong value returned")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDataRowReaderReadString(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
var s string
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
s = r.ReadString()
|
||||
return nil
|
||||
}
|
||||
|
||||
err := conn.SelectFunc("select 'Jack'", onDataRow)
|
||||
if err != nil {
|
||||
t.Fatal("Select failed: " + err.Error())
|
||||
}
|
||||
if s != "Jack" {
|
||||
t.Error("Wrong value returned")
|
||||
}
|
||||
}
|
||||
|
||||
<% [64, 32, 16].each do |size| %>
|
||||
func TestDataRowReaderReadInt<%= size %>(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
var n int<%= size %>
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
n = r.ReadInt<%= size %>()
|
||||
return nil
|
||||
}
|
||||
|
||||
err := conn.SelectFunc("select 1", onDataRow)
|
||||
if err != nil {
|
||||
t.Fatal("Select failed: " + err.Error())
|
||||
}
|
||||
if n != 1 {
|
||||
t.Error("Wrong value returned")
|
||||
}
|
||||
}
|
||||
<% end %>
|
||||
|
||||
<% [64, 32].each do |size| %>
|
||||
func TestDataRowReaderReadFloat<%= size %>(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
var n float<%= size %>
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
n = r.ReadFloat<%= size %>()
|
||||
return nil
|
||||
}
|
||||
|
||||
err := conn.SelectFunc("select 1.5", onDataRow)
|
||||
if err != nil {
|
||||
t.Fatal("Select failed: " + err.Error())
|
||||
}
|
||||
if n != 1.5 {
|
||||
t.Error("Wrong value returned")
|
||||
}
|
||||
}
|
||||
<% end %>
|
Loading…
Reference in New Issue