diff --git a/conn.go b/conn.go index b5d37fe4..129cbef1 100644 --- a/conn.go +++ b/conn.go @@ -333,67 +333,26 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{}) } }() - err = c.sendQuery(sql, arguments...) - if err != nil { - return err - } - var numRowsFound int64 - var softErr error - for { - var t byte - var r *MsgReader + qr, _ := c.Query(sql, arguments...) - t, r, err = c.rxMsg() - if err != nil { - return err + for qr.NextRow() { + if len(qr.fields) != 1 { + qr.Close() + return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: int16(len(qr.fields))} } - if t == dataRow { - numRowsFound++ - - if numRowsFound > 1 { - softErr = NotSingleRowError{RowCount: numRowsFound} - } - - if softErr != nil { - // Read and discard rest of message - continue - } - - softErr = c.rxDataRowValueTo(w, r) - } else { - switch t { - case readyForQuery: - c.rxReadyForQuery(r) - return softErr - case rowDescription: - case commandComplete: - case bindComplete: - default: - if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { - softErr = e - } - } + numRowsFound++ + if numRowsFound != 1 { + qr.Close() + return NotSingleRowError{RowCount: numRowsFound} } + + var rr RowReader + rr.CopyBytes(qr, w) } -} - -func (c *Conn) rxDataRowValueTo(w io.Writer, r *MsgReader) error { - columnCount := r.ReadInt16() - if columnCount != 1 { - return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount} - } - - valueSize := r.ReadInt32() - if valueSize == -1 { - return errors.New("SelectValueTo cannot handle null") - } - - r.CopyN(w, valueSize) - - return r.Err() + return qr.Err() } // Prepare creates a prepared statement with name and sql. sql can contain placeholders @@ -554,9 +513,9 @@ func (rr *RowReader) ReadInt32(qr *QueryResult) int32 { return 0 } - // TODO - do something about nulls if size == -1 { - panic("Can't handle nulls") + qr.Fatal(errors.New("Unexpected null")) + return 0 } return decodeInt4(qr, fd, size) @@ -568,9 +527,9 @@ func (rr *RowReader) ReadInt64(qr *QueryResult) int64 { return 0 } - // TODO - do something about nulls if size == -1 { - panic("Can't handle nulls") + qr.Fatal(errors.New("Unexpected null")) + return 0 } return decodeInt8(qr, fd, size) @@ -584,9 +543,9 @@ func (rr *RowReader) ReadTime(qr *QueryResult) time.Time { return zeroTime } - // TODO - do something about nulls if size == -1 { - panic("Can't handle nulls") + qr.Fatal(errors.New("Unexpected null")) + return zeroTime } return decodeTimestampTz(qr, fd, size) @@ -600,9 +559,9 @@ func (rr *RowReader) ReadDate(qr *QueryResult) time.Time { return zeroTime } - // TODO - do something about nulls if size == -1 { - panic("Can't handle nulls") + qr.Fatal(errors.New("Unexpected null")) + return zeroTime } return decodeDate(qr, fd, size) @@ -614,6 +573,11 @@ func (rr *RowReader) ReadString(qr *QueryResult) string { return "" } + if size == -1 { + qr.Fatal(errors.New("Unexpected null")) + return "" + } + return decodeText(qr, fd, size) } @@ -634,6 +598,20 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} { } } +func (rr *RowReader) CopyBytes(qr *QueryResult, w io.Writer) { + _, size, ok := qr.NextColumn() + if !ok { + return + } + + if size == -1 { + qr.Fatal(errors.New("Unexpected null")) + return + } + + qr.MsgReader().CopyN(w, size) +} + type QueryResult struct { pool *ConnPool conn *Conn diff --git a/conn_test.go b/conn_test.go index 86c0554f..60896848 100644 --- a/conn_test.go +++ b/conn_test.go @@ -462,6 +462,42 @@ func TestConnQueryReadTooManyValues(t *testing.T) { ensureConnValid(t, conn) } +func TestQueryResultCopyBytes(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var mimeType string + var buf bytes.Buffer + + qr, err := conn.Query("select 'application/json', '[1,2,3,4,5]'::json") + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + + for qr.NextRow() { + var rr pgx.RowReader + mimeType = rr.ReadString(qr) + rr.CopyBytes(qr, &buf) + } + qr.Close() + + if qr.Err() != nil { + t.Fatalf("conn.Query failed: ", err) + } + + if mimeType != "application/json" { + t.Errorf(`Expected mimeType to be "application/json", but it was "%v"`, mimeType) + } + + if bytes.Compare(buf.Bytes(), []byte("[1,2,3,4,5]")) != 0 { + t.Fatalf("CopyBytes did not write expected data: %v", string(buf.Bytes())) + } + + ensureConnValid(t, conn) +} + func TestConnectionSelectValue(t *testing.T) { t.Parallel() @@ -546,14 +582,11 @@ func TestConnectionSelectValueTo(t *testing.T) { // Null err = conn.SelectValueTo(&buf, "select null") - if err == nil || err.Error() != "SelectValueTo cannot handle null" { + if err == nil || err.Error() != "Unexpected null" { t.Fatalf("Expected null error: %#v", err) } - if conn.IsAlive() { - mustSelectValue(t, conn, "select 1") // ensure it really is alive and usable - } else { - t.Fatal("SelectValueTo null error should not have killed connection") - } + + ensureConnValid(t, conn) } func TestPrepare(t *testing.T) {