mirror of https://github.com/jackc/pgx.git
Safely handle bad reads of QueryResult
parent
718c4dfdc8
commit
78b8e0b6f2
47
conn.go
47
conn.go
|
@ -593,7 +593,10 @@ type RowReader struct{}
|
|||
// TODO - Read*...
|
||||
|
||||
func (rr *RowReader) ReadInt32(qr *QueryResult) int32 {
|
||||
fd, size := qr.NextColumn()
|
||||
fd, size, ok := qr.NextColumn()
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
// TODO - do something about nulls
|
||||
if size == -1 {
|
||||
|
@ -604,7 +607,10 @@ func (rr *RowReader) ReadInt32(qr *QueryResult) int32 {
|
|||
}
|
||||
|
||||
func (rr *RowReader) ReadInt64(qr *QueryResult) int64 {
|
||||
fd, size := qr.NextColumn()
|
||||
fd, size, ok := qr.NextColumn()
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
// TODO - do something about nulls
|
||||
if size == -1 {
|
||||
|
@ -615,7 +621,12 @@ func (rr *RowReader) ReadInt64(qr *QueryResult) int64 {
|
|||
}
|
||||
|
||||
func (rr *RowReader) ReadTime(qr *QueryResult) time.Time {
|
||||
fd, size := qr.NextColumn()
|
||||
var zeroTime time.Time
|
||||
|
||||
fd, size, ok := qr.NextColumn()
|
||||
if !ok {
|
||||
return zeroTime
|
||||
}
|
||||
|
||||
// TODO - do something about nulls
|
||||
if size == -1 {
|
||||
|
@ -626,7 +637,12 @@ func (rr *RowReader) ReadTime(qr *QueryResult) time.Time {
|
|||
}
|
||||
|
||||
func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
|
||||
fd, size := qr.NextColumn()
|
||||
var zeroTime time.Time
|
||||
|
||||
fd, size, ok := qr.NextColumn()
|
||||
if !ok {
|
||||
return zeroTime
|
||||
}
|
||||
|
||||
// TODO - do something about nulls
|
||||
if size == -1 {
|
||||
|
@ -637,12 +653,19 @@ func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
|
|||
}
|
||||
|
||||
func (rr *RowReader) ReadString(qr *QueryResult) string {
|
||||
_, size := qr.NextColumn()
|
||||
_, size, ok := qr.NextColumn()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return qr.mr.ReadString(size)
|
||||
}
|
||||
|
||||
func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
|
||||
fd, size := qr.NextColumn()
|
||||
fd, size, ok := qr.NextColumn()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if size > -1 {
|
||||
if vt, present := ValueTranscoders[fd.DataType]; present && vt.Decode != nil {
|
||||
|
@ -768,12 +791,20 @@ func (qr *QueryResult) NextRow() bool {
|
|||
}
|
||||
}
|
||||
|
||||
func (qr *QueryResult) NextColumn() (*FieldDescription, int32) {
|
||||
func (qr *QueryResult) NextColumn() (*FieldDescription, int32, bool) {
|
||||
if qr.closed {
|
||||
return nil, 0, false
|
||||
}
|
||||
if len(qr.fields) <= qr.columnIdx {
|
||||
qr.Fatal(ProtocolError("No next column available"))
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
fd := &qr.fields[qr.columnIdx]
|
||||
qr.columnIdx++
|
||||
size := qr.mr.ReadInt32()
|
||||
|
||||
return fd, size
|
||||
return fd, size, true
|
||||
}
|
||||
|
||||
// TODO - document
|
||||
|
|
153
conn_test.go
153
conn_test.go
|
@ -309,6 +309,159 @@ func TestConnQuery(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Do a simple query to ensure the connection is still usable
|
||||
func ensureConnValid(t *testing.T, conn *pgx.Conn) {
|
||||
var sum, rowCount int32
|
||||
|
||||
qr, err := conn.Query("select generate_series(1,$1)", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
defer qr.Close()
|
||||
|
||||
for qr.NextRow() {
|
||||
var rr pgx.RowReader
|
||||
sum += rr.ReadInt32(qr)
|
||||
rowCount++
|
||||
}
|
||||
|
||||
if qr.Err() != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
|
||||
if rowCount != 10 {
|
||||
t.Error("Select called onDataRow wrong number of times")
|
||||
}
|
||||
if sum != 55 {
|
||||
t.Error("Wrong values returned")
|
||||
}
|
||||
}
|
||||
|
||||
// Test that a connection stays valid when query results are closed early
|
||||
func TestConnQueryCloseEarly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
// Immediately close query without reading any rows
|
||||
qr, err := conn.Query("select generate_series(1,$1)", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
qr.Close()
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
|
||||
// Read partial response then close
|
||||
qr, err = conn.Query("select generate_series(1,$1)", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
|
||||
ok := qr.NextRow()
|
||||
if !ok {
|
||||
t.Fatal("qr.NextRow terminated early")
|
||||
}
|
||||
|
||||
var rr pgx.RowReader
|
||||
if n := rr.ReadInt32(qr); n != 1 {
|
||||
t.Fatalf("Expected 1 from first row, but got %v", n)
|
||||
}
|
||||
|
||||
qr.Close()
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Test that a connection stays valid when query results read incorrectly
|
||||
func TestConnQueryReadWrongTypeError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
// Read a single value incorrectly
|
||||
qr, err := conn.Query("select generate_series(1,$1)", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
|
||||
rowsRead := 0
|
||||
|
||||
for qr.NextRow() {
|
||||
var rr pgx.RowReader
|
||||
rr.ReadDate(qr)
|
||||
rowsRead++
|
||||
}
|
||||
|
||||
if rowsRead != 1 {
|
||||
t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
|
||||
}
|
||||
|
||||
if qr.Err() == nil {
|
||||
t.Fatal("Expected QueryResult to have an error after an improper read but it didn't")
|
||||
}
|
||||
|
||||
// Read too many values
|
||||
qr, err = conn.Query("select generate_series(1,$1)", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
|
||||
rowsRead = 0
|
||||
|
||||
for qr.NextRow() {
|
||||
var rr pgx.RowReader
|
||||
rr.ReadInt32(qr)
|
||||
rr.ReadInt32(qr)
|
||||
rowsRead++
|
||||
}
|
||||
|
||||
if rowsRead != 1 {
|
||||
t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
|
||||
}
|
||||
|
||||
if qr.Err() == nil {
|
||||
t.Fatal("Expected QueryResult to have an error after an improper read but it didn't")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Test that a connection stays valid when query results read incorrectly
|
||||
func TestConnQueryReadTooManyValues(t *testing.T) {
|
||||
// t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
// Read too many values
|
||||
qr, err := conn.Query("select generate_series(1,$1)", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
|
||||
rowsRead := 0
|
||||
|
||||
for qr.NextRow() {
|
||||
var rr pgx.RowReader
|
||||
rr.ReadInt32(qr)
|
||||
rr.ReadInt32(qr)
|
||||
rowsRead++
|
||||
}
|
||||
|
||||
if rowsRead != 1 {
|
||||
t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
|
||||
}
|
||||
|
||||
if qr.Err() == nil {
|
||||
t.Fatal("Expected QueryResult to have an error after an improper read but it didn't")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnectionSelectValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue