Add RowReader.CopyBytes

Implement SelectValueTo in terms of RowReader.CopyBytes
scan-io
Jack Christensen 2014-07-05 07:50:46 -05:00
parent a1fc6f513a
commit b27d828311
2 changed files with 79 additions and 68 deletions

102
conn.go
View File

@ -333,67 +333,26 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
}
}()
err = c.sendQuery(sql, arguments...)
if err != nil {
return err
}
var numRowsFound int64
var softErr error
for {
var t byte
var r *MsgReader
qr, _ := c.Query(sql, arguments...)
t, r, err = c.rxMsg()
if err != nil {
return err
for qr.NextRow() {
if len(qr.fields) != 1 {
qr.Close()
return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: int16(len(qr.fields))}
}
if t == dataRow {
numRowsFound++
if numRowsFound > 1 {
softErr = NotSingleRowError{RowCount: numRowsFound}
}
if softErr != nil {
// Read and discard rest of message
continue
}
softErr = c.rxDataRowValueTo(w, r)
} else {
switch t {
case readyForQuery:
c.rxReadyForQuery(r)
return softErr
case rowDescription:
case commandComplete:
case bindComplete:
default:
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
softErr = e
}
}
numRowsFound++
if numRowsFound != 1 {
qr.Close()
return NotSingleRowError{RowCount: numRowsFound}
}
var rr RowReader
rr.CopyBytes(qr, w)
}
}
func (c *Conn) rxDataRowValueTo(w io.Writer, r *MsgReader) error {
columnCount := r.ReadInt16()
if columnCount != 1 {
return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount}
}
valueSize := r.ReadInt32()
if valueSize == -1 {
return errors.New("SelectValueTo cannot handle null")
}
r.CopyN(w, valueSize)
return r.Err()
return qr.Err()
}
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
@ -554,9 +513,9 @@ func (rr *RowReader) ReadInt32(qr *QueryResult) int32 {
return 0
}
// TODO - do something about nulls
if size == -1 {
panic("Can't handle nulls")
qr.Fatal(errors.New("Unexpected null"))
return 0
}
return decodeInt4(qr, fd, size)
@ -568,9 +527,9 @@ func (rr *RowReader) ReadInt64(qr *QueryResult) int64 {
return 0
}
// TODO - do something about nulls
if size == -1 {
panic("Can't handle nulls")
qr.Fatal(errors.New("Unexpected null"))
return 0
}
return decodeInt8(qr, fd, size)
@ -584,9 +543,9 @@ func (rr *RowReader) ReadTime(qr *QueryResult) time.Time {
return zeroTime
}
// TODO - do something about nulls
if size == -1 {
panic("Can't handle nulls")
qr.Fatal(errors.New("Unexpected null"))
return zeroTime
}
return decodeTimestampTz(qr, fd, size)
@ -600,9 +559,9 @@ func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
return zeroTime
}
// TODO - do something about nulls
if size == -1 {
panic("Can't handle nulls")
qr.Fatal(errors.New("Unexpected null"))
return zeroTime
}
return decodeDate(qr, fd, size)
@ -614,6 +573,11 @@ func (rr *RowReader) ReadString(qr *QueryResult) string {
return ""
}
if size == -1 {
qr.Fatal(errors.New("Unexpected null"))
return ""
}
return decodeText(qr, fd, size)
}
@ -634,6 +598,20 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
}
}
func (rr *RowReader) CopyBytes(qr *QueryResult, w io.Writer) {
_, size, ok := qr.NextColumn()
if !ok {
return
}
if size == -1 {
qr.Fatal(errors.New("Unexpected null"))
return
}
qr.MsgReader().CopyN(w, size)
}
type QueryResult struct {
pool *ConnPool
conn *Conn

View File

@ -462,6 +462,42 @@ func TestConnQueryReadTooManyValues(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryResultCopyBytes(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var mimeType string
var buf bytes.Buffer
qr, err := conn.Query("select 'application/json', '[1,2,3,4,5]'::json")
if err != nil {
t.Fatalf("conn.Query failed: ", err)
}
for qr.NextRow() {
var rr pgx.RowReader
mimeType = rr.ReadString(qr)
rr.CopyBytes(qr, &buf)
}
qr.Close()
if qr.Err() != nil {
t.Fatalf("conn.Query failed: ", err)
}
if mimeType != "application/json" {
t.Errorf(`Expected mimeType to be "application/json", but it was "%v"`, mimeType)
}
if bytes.Compare(buf.Bytes(), []byte("[1,2,3,4,5]")) != 0 {
t.Fatalf("CopyBytes did not write expected data: %v", string(buf.Bytes()))
}
ensureConnValid(t, conn)
}
func TestConnectionSelectValue(t *testing.T) {
t.Parallel()
@ -546,14 +582,11 @@ func TestConnectionSelectValueTo(t *testing.T) {
// Null
err = conn.SelectValueTo(&buf, "select null")
if err == nil || err.Error() != "SelectValueTo cannot handle null" {
if err == nil || err.Error() != "Unexpected null" {
t.Fatalf("Expected null error: %#v", err)
}
if conn.IsAlive() {
mustSelectValue(t, conn, "select 1") // ensure it really is alive and usable
} else {
t.Fatal("SelectValueTo null error should not have killed connection")
}
ensureConnValid(t, conn)
}
func TestPrepare(t *testing.T) {