mirror of
https://github.com/jackc/pgx.git
synced 2025-04-27 13:14:32 +00:00
Safely handle bad reads of QueryResult
This commit is contained in:
parent
718c4dfdc8
commit
78b8e0b6f2
47
conn.go
47
conn.go
@ -593,7 +593,10 @@ type RowReader struct{}
|
|||||||
// TODO - Read*...
|
// TODO - Read*...
|
||||||
|
|
||||||
func (rr *RowReader) ReadInt32(qr *QueryResult) int32 {
|
func (rr *RowReader) ReadInt32(qr *QueryResult) int32 {
|
||||||
fd, size := qr.NextColumn()
|
fd, size, ok := qr.NextColumn()
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
// TODO - do something about nulls
|
// TODO - do something about nulls
|
||||||
if size == -1 {
|
if size == -1 {
|
||||||
@ -604,7 +607,10 @@ func (rr *RowReader) ReadInt32(qr *QueryResult) int32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rr *RowReader) ReadInt64(qr *QueryResult) int64 {
|
func (rr *RowReader) ReadInt64(qr *QueryResult) int64 {
|
||||||
fd, size := qr.NextColumn()
|
fd, size, ok := qr.NextColumn()
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
// TODO - do something about nulls
|
// TODO - do something about nulls
|
||||||
if size == -1 {
|
if size == -1 {
|
||||||
@ -615,7 +621,12 @@ func (rr *RowReader) ReadInt64(qr *QueryResult) int64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rr *RowReader) ReadTime(qr *QueryResult) time.Time {
|
func (rr *RowReader) ReadTime(qr *QueryResult) time.Time {
|
||||||
fd, size := qr.NextColumn()
|
var zeroTime time.Time
|
||||||
|
|
||||||
|
fd, size, ok := qr.NextColumn()
|
||||||
|
if !ok {
|
||||||
|
return zeroTime
|
||||||
|
}
|
||||||
|
|
||||||
// TODO - do something about nulls
|
// TODO - do something about nulls
|
||||||
if size == -1 {
|
if size == -1 {
|
||||||
@ -626,7 +637,12 @@ func (rr *RowReader) ReadTime(qr *QueryResult) time.Time {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
|
func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
|
||||||
fd, size := qr.NextColumn()
|
var zeroTime time.Time
|
||||||
|
|
||||||
|
fd, size, ok := qr.NextColumn()
|
||||||
|
if !ok {
|
||||||
|
return zeroTime
|
||||||
|
}
|
||||||
|
|
||||||
// TODO - do something about nulls
|
// TODO - do something about nulls
|
||||||
if size == -1 {
|
if size == -1 {
|
||||||
@ -637,12 +653,19 @@ func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rr *RowReader) ReadString(qr *QueryResult) string {
|
func (rr *RowReader) ReadString(qr *QueryResult) string {
|
||||||
_, size := qr.NextColumn()
|
_, size, ok := qr.NextColumn()
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
return qr.mr.ReadString(size)
|
return qr.mr.ReadString(size)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
|
func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
|
||||||
fd, size := qr.NextColumn()
|
fd, size, ok := qr.NextColumn()
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if size > -1 {
|
if size > -1 {
|
||||||
if vt, present := ValueTranscoders[fd.DataType]; present && vt.Decode != nil {
|
if vt, present := ValueTranscoders[fd.DataType]; present && vt.Decode != nil {
|
||||||
@ -768,12 +791,20 @@ func (qr *QueryResult) NextRow() bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (qr *QueryResult) NextColumn() (*FieldDescription, int32) {
|
func (qr *QueryResult) NextColumn() (*FieldDescription, int32, bool) {
|
||||||
|
if qr.closed {
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
if len(qr.fields) <= qr.columnIdx {
|
||||||
|
qr.Fatal(ProtocolError("No next column available"))
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
fd := &qr.fields[qr.columnIdx]
|
fd := &qr.fields[qr.columnIdx]
|
||||||
qr.columnIdx++
|
qr.columnIdx++
|
||||||
size := qr.mr.ReadInt32()
|
size := qr.mr.ReadInt32()
|
||||||
|
|
||||||
return fd, size
|
return fd, size, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - document
|
// TODO - document
|
||||||
|
153
conn_test.go
153
conn_test.go
@ -309,6 +309,159 @@ func TestConnQuery(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do a simple query to ensure the connection is still usable
|
||||||
|
func ensureConnValid(t *testing.T, conn *pgx.Conn) {
|
||||||
|
var sum, rowCount int32
|
||||||
|
|
||||||
|
qr, err := conn.Query("select generate_series(1,$1)", 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn.Query failed: ", err)
|
||||||
|
}
|
||||||
|
defer qr.Close()
|
||||||
|
|
||||||
|
for qr.NextRow() {
|
||||||
|
var rr pgx.RowReader
|
||||||
|
sum += rr.ReadInt32(qr)
|
||||||
|
rowCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
if qr.Err() != nil {
|
||||||
|
t.Fatalf("conn.Query failed: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rowCount != 10 {
|
||||||
|
t.Error("Select called onDataRow wrong number of times")
|
||||||
|
}
|
||||||
|
if sum != 55 {
|
||||||
|
t.Error("Wrong values returned")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that a connection stays valid when query results are closed early
|
||||||
|
func TestConnQueryCloseEarly(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
// Immediately close query without reading any rows
|
||||||
|
qr, err := conn.Query("select generate_series(1,$1)", 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn.Query failed: ", err)
|
||||||
|
}
|
||||||
|
qr.Close()
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
|
||||||
|
// Read partial response then close
|
||||||
|
qr, err = conn.Query("select generate_series(1,$1)", 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn.Query failed: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok := qr.NextRow()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("qr.NextRow terminated early")
|
||||||
|
}
|
||||||
|
|
||||||
|
var rr pgx.RowReader
|
||||||
|
if n := rr.ReadInt32(qr); n != 1 {
|
||||||
|
t.Fatalf("Expected 1 from first row, but got %v", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
qr.Close()
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that a connection stays valid when query results read incorrectly
|
||||||
|
func TestConnQueryReadWrongTypeError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
// Read a single value incorrectly
|
||||||
|
qr, err := conn.Query("select generate_series(1,$1)", 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn.Query failed: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowsRead := 0
|
||||||
|
|
||||||
|
for qr.NextRow() {
|
||||||
|
var rr pgx.RowReader
|
||||||
|
rr.ReadDate(qr)
|
||||||
|
rowsRead++
|
||||||
|
}
|
||||||
|
|
||||||
|
if rowsRead != 1 {
|
||||||
|
t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
if qr.Err() == nil {
|
||||||
|
t.Fatal("Expected QueryResult to have an error after an improper read but it didn't")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read too many values
|
||||||
|
qr, err = conn.Query("select generate_series(1,$1)", 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn.Query failed: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowsRead = 0
|
||||||
|
|
||||||
|
for qr.NextRow() {
|
||||||
|
var rr pgx.RowReader
|
||||||
|
rr.ReadInt32(qr)
|
||||||
|
rr.ReadInt32(qr)
|
||||||
|
rowsRead++
|
||||||
|
}
|
||||||
|
|
||||||
|
if rowsRead != 1 {
|
||||||
|
t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
if qr.Err() == nil {
|
||||||
|
t.Fatal("Expected QueryResult to have an error after an improper read but it didn't")
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that a connection stays valid when query results read incorrectly
|
||||||
|
func TestConnQueryReadTooManyValues(t *testing.T) {
|
||||||
|
// t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
// Read too many values
|
||||||
|
qr, err := conn.Query("select generate_series(1,$1)", 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn.Query failed: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowsRead := 0
|
||||||
|
|
||||||
|
for qr.NextRow() {
|
||||||
|
var rr pgx.RowReader
|
||||||
|
rr.ReadInt32(qr)
|
||||||
|
rr.ReadInt32(qr)
|
||||||
|
rowsRead++
|
||||||
|
}
|
||||||
|
|
||||||
|
if rowsRead != 1 {
|
||||||
|
t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
if qr.Err() == nil {
|
||||||
|
t.Fatal("Expected QueryResult to have an error after an improper read but it didn't")
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnectionSelectValue(t *testing.T) {
|
func TestConnectionSelectValue(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user