Safely handle bad reads of QueryResult

scan-io
Jack Christensen 2014-06-30 19:01:36 -05:00
parent 718c4dfdc8
commit 78b8e0b6f2
2 changed files with 192 additions and 8 deletions

47
conn.go
View File

@ -593,7 +593,10 @@ type RowReader struct{}
// TODO - Read*...
func (rr *RowReader) ReadInt32(qr *QueryResult) int32 {
fd, size := qr.NextColumn()
fd, size, ok := qr.NextColumn()
if !ok {
return 0
}
// TODO - do something about nulls
if size == -1 {
@ -604,7 +607,10 @@ func (rr *RowReader) ReadInt32(qr *QueryResult) int32 {
}
func (rr *RowReader) ReadInt64(qr *QueryResult) int64 {
fd, size := qr.NextColumn()
fd, size, ok := qr.NextColumn()
if !ok {
return 0
}
// TODO - do something about nulls
if size == -1 {
@ -615,7 +621,12 @@ func (rr *RowReader) ReadInt64(qr *QueryResult) int64 {
}
func (rr *RowReader) ReadTime(qr *QueryResult) time.Time {
fd, size := qr.NextColumn()
var zeroTime time.Time
fd, size, ok := qr.NextColumn()
if !ok {
return zeroTime
}
// TODO - do something about nulls
if size == -1 {
@ -626,7 +637,12 @@ func (rr *RowReader) ReadTime(qr *QueryResult) time.Time {
}
func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
fd, size := qr.NextColumn()
var zeroTime time.Time
fd, size, ok := qr.NextColumn()
if !ok {
return zeroTime
}
// TODO - do something about nulls
if size == -1 {
@ -637,12 +653,19 @@ func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
}
func (rr *RowReader) ReadString(qr *QueryResult) string {
_, size := qr.NextColumn()
_, size, ok := qr.NextColumn()
if !ok {
return ""
}
return qr.mr.ReadString(size)
}
func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
fd, size := qr.NextColumn()
fd, size, ok := qr.NextColumn()
if !ok {
return nil
}
if size > -1 {
if vt, present := ValueTranscoders[fd.DataType]; present && vt.Decode != nil {
@ -768,12 +791,20 @@ func (qr *QueryResult) NextRow() bool {
}
}
func (qr *QueryResult) NextColumn() (*FieldDescription, int32) {
func (qr *QueryResult) NextColumn() (*FieldDescription, int32, bool) {
if qr.closed {
return nil, 0, false
}
if len(qr.fields) <= qr.columnIdx {
qr.Fatal(ProtocolError("No next column available"))
return nil, 0, false
}
fd := &qr.fields[qr.columnIdx]
qr.columnIdx++
size := qr.mr.ReadInt32()
return fd, size
return fd, size, true
}
// TODO - document

View File

@ -309,6 +309,159 @@ func TestConnQuery(t *testing.T) {
}
}
// Do a simple query to ensure the connection is still usable
func ensureConnValid(t *testing.T, conn *pgx.Conn) {
var sum, rowCount int32
qr, err := conn.Query("select generate_series(1,$1)", 10)
if err != nil {
t.Fatalf("conn.Query failed: ", err)
}
defer qr.Close()
for qr.NextRow() {
var rr pgx.RowReader
sum += rr.ReadInt32(qr)
rowCount++
}
if qr.Err() != nil {
t.Fatalf("conn.Query failed: ", err)
}
if rowCount != 10 {
t.Error("Select called onDataRow wrong number of times")
}
if sum != 55 {
t.Error("Wrong values returned")
}
}
// Test that a connection stays valid when query results are closed early
func TestConnQueryCloseEarly(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
// Immediately close query without reading any rows
qr, err := conn.Query("select generate_series(1,$1)", 10)
if err != nil {
t.Fatalf("conn.Query failed: ", err)
}
qr.Close()
ensureConnValid(t, conn)
// Read partial response then close
qr, err = conn.Query("select generate_series(1,$1)", 10)
if err != nil {
t.Fatalf("conn.Query failed: ", err)
}
ok := qr.NextRow()
if !ok {
t.Fatal("qr.NextRow terminated early")
}
var rr pgx.RowReader
if n := rr.ReadInt32(qr); n != 1 {
t.Fatalf("Expected 1 from first row, but got %v", n)
}
qr.Close()
ensureConnValid(t, conn)
}
// Test that a connection stays valid when query results read incorrectly
func TestConnQueryReadWrongTypeError(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
// Read a single value incorrectly
qr, err := conn.Query("select generate_series(1,$1)", 10)
if err != nil {
t.Fatalf("conn.Query failed: ", err)
}
rowsRead := 0
for qr.NextRow() {
var rr pgx.RowReader
rr.ReadDate(qr)
rowsRead++
}
if rowsRead != 1 {
t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
}
if qr.Err() == nil {
t.Fatal("Expected QueryResult to have an error after an improper read but it didn't")
}
// Read too many values
qr, err = conn.Query("select generate_series(1,$1)", 10)
if err != nil {
t.Fatalf("conn.Query failed: ", err)
}
rowsRead = 0
for qr.NextRow() {
var rr pgx.RowReader
rr.ReadInt32(qr)
rr.ReadInt32(qr)
rowsRead++
}
if rowsRead != 1 {
t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
}
if qr.Err() == nil {
t.Fatal("Expected QueryResult to have an error after an improper read but it didn't")
}
ensureConnValid(t, conn)
}
// Test that a connection stays valid when query results read incorrectly
func TestConnQueryReadTooManyValues(t *testing.T) {
// t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
// Read too many values
qr, err := conn.Query("select generate_series(1,$1)", 10)
if err != nil {
t.Fatalf("conn.Query failed: ", err)
}
rowsRead := 0
for qr.NextRow() {
var rr pgx.RowReader
rr.ReadInt32(qr)
rr.ReadInt32(qr)
rowsRead++
}
if rowsRead != 1 {
t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
}
if qr.Err() == nil {
t.Fatal("Expected QueryResult to have an error after an improper read but it didn't")
}
ensureConnValid(t, conn)
}
func TestConnectionSelectValue(t *testing.T) {
t.Parallel()