mirror of https://github.com/jackc/pgx.git
Add RowReader.CopyBytes
Implement SelectValueTo in terms of RowReader.CopyBytesscan-io
parent
a1fc6f513a
commit
b27d828311
102
conn.go
102
conn.go
|
@ -333,67 +333,26 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
|
|||
}
|
||||
}()
|
||||
|
||||
err = c.sendQuery(sql, arguments...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var numRowsFound int64
|
||||
var softErr error
|
||||
|
||||
for {
|
||||
var t byte
|
||||
var r *MsgReader
|
||||
qr, _ := c.Query(sql, arguments...)
|
||||
|
||||
t, r, err = c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
for qr.NextRow() {
|
||||
if len(qr.fields) != 1 {
|
||||
qr.Close()
|
||||
return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: int16(len(qr.fields))}
|
||||
}
|
||||
|
||||
if t == dataRow {
|
||||
numRowsFound++
|
||||
|
||||
if numRowsFound > 1 {
|
||||
softErr = NotSingleRowError{RowCount: numRowsFound}
|
||||
}
|
||||
|
||||
if softErr != nil {
|
||||
// Read and discard rest of message
|
||||
continue
|
||||
}
|
||||
|
||||
softErr = c.rxDataRowValueTo(w, r)
|
||||
} else {
|
||||
switch t {
|
||||
case readyForQuery:
|
||||
c.rxReadyForQuery(r)
|
||||
return softErr
|
||||
case rowDescription:
|
||||
case commandComplete:
|
||||
case bindComplete:
|
||||
default:
|
||||
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
|
||||
softErr = e
|
||||
}
|
||||
}
|
||||
numRowsFound++
|
||||
if numRowsFound != 1 {
|
||||
qr.Close()
|
||||
return NotSingleRowError{RowCount: numRowsFound}
|
||||
}
|
||||
|
||||
var rr RowReader
|
||||
rr.CopyBytes(qr, w)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) rxDataRowValueTo(w io.Writer, r *MsgReader) error {
|
||||
columnCount := r.ReadInt16()
|
||||
if columnCount != 1 {
|
||||
return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount}
|
||||
}
|
||||
|
||||
valueSize := r.ReadInt32()
|
||||
if valueSize == -1 {
|
||||
return errors.New("SelectValueTo cannot handle null")
|
||||
}
|
||||
|
||||
r.CopyN(w, valueSize)
|
||||
|
||||
return r.Err()
|
||||
return qr.Err()
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
|
||||
|
@ -554,9 +513,9 @@ func (rr *RowReader) ReadInt32(qr *QueryResult) int32 {
|
|||
return 0
|
||||
}
|
||||
|
||||
// TODO - do something about nulls
|
||||
if size == -1 {
|
||||
panic("Can't handle nulls")
|
||||
qr.Fatal(errors.New("Unexpected null"))
|
||||
return 0
|
||||
}
|
||||
|
||||
return decodeInt4(qr, fd, size)
|
||||
|
@ -568,9 +527,9 @@ func (rr *RowReader) ReadInt64(qr *QueryResult) int64 {
|
|||
return 0
|
||||
}
|
||||
|
||||
// TODO - do something about nulls
|
||||
if size == -1 {
|
||||
panic("Can't handle nulls")
|
||||
qr.Fatal(errors.New("Unexpected null"))
|
||||
return 0
|
||||
}
|
||||
|
||||
return decodeInt8(qr, fd, size)
|
||||
|
@ -584,9 +543,9 @@ func (rr *RowReader) ReadTime(qr *QueryResult) time.Time {
|
|||
return zeroTime
|
||||
}
|
||||
|
||||
// TODO - do something about nulls
|
||||
if size == -1 {
|
||||
panic("Can't handle nulls")
|
||||
qr.Fatal(errors.New("Unexpected null"))
|
||||
return zeroTime
|
||||
}
|
||||
|
||||
return decodeTimestampTz(qr, fd, size)
|
||||
|
@ -600,9 +559,9 @@ func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
|
|||
return zeroTime
|
||||
}
|
||||
|
||||
// TODO - do something about nulls
|
||||
if size == -1 {
|
||||
panic("Can't handle nulls")
|
||||
qr.Fatal(errors.New("Unexpected null"))
|
||||
return zeroTime
|
||||
}
|
||||
|
||||
return decodeDate(qr, fd, size)
|
||||
|
@ -614,6 +573,11 @@ func (rr *RowReader) ReadString(qr *QueryResult) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
if size == -1 {
|
||||
qr.Fatal(errors.New("Unexpected null"))
|
||||
return ""
|
||||
}
|
||||
|
||||
return decodeText(qr, fd, size)
|
||||
}
|
||||
|
||||
|
@ -634,6 +598,20 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
|
|||
}
|
||||
}
|
||||
|
||||
func (rr *RowReader) CopyBytes(qr *QueryResult, w io.Writer) {
|
||||
_, size, ok := qr.NextColumn()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if size == -1 {
|
||||
qr.Fatal(errors.New("Unexpected null"))
|
||||
return
|
||||
}
|
||||
|
||||
qr.MsgReader().CopyN(w, size)
|
||||
}
|
||||
|
||||
type QueryResult struct {
|
||||
pool *ConnPool
|
||||
conn *Conn
|
||||
|
|
45
conn_test.go
45
conn_test.go
|
@ -462,6 +462,42 @@ func TestConnQueryReadTooManyValues(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryResultCopyBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var mimeType string
|
||||
var buf bytes.Buffer
|
||||
|
||||
qr, err := conn.Query("select 'application/json', '[1,2,3,4,5]'::json")
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
|
||||
for qr.NextRow() {
|
||||
var rr pgx.RowReader
|
||||
mimeType = rr.ReadString(qr)
|
||||
rr.CopyBytes(qr, &buf)
|
||||
}
|
||||
qr.Close()
|
||||
|
||||
if qr.Err() != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
|
||||
if mimeType != "application/json" {
|
||||
t.Errorf(`Expected mimeType to be "application/json", but it was "%v"`, mimeType)
|
||||
}
|
||||
|
||||
if bytes.Compare(buf.Bytes(), []byte("[1,2,3,4,5]")) != 0 {
|
||||
t.Fatalf("CopyBytes did not write expected data: %v", string(buf.Bytes()))
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnectionSelectValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -546,14 +582,11 @@ func TestConnectionSelectValueTo(t *testing.T) {
|
|||
|
||||
// Null
|
||||
err = conn.SelectValueTo(&buf, "select null")
|
||||
if err == nil || err.Error() != "SelectValueTo cannot handle null" {
|
||||
if err == nil || err.Error() != "Unexpected null" {
|
||||
t.Fatalf("Expected null error: %#v", err)
|
||||
}
|
||||
if conn.IsAlive() {
|
||||
mustSelectValue(t, conn, "select 1") // ensure it really is alive and usable
|
||||
} else {
|
||||
t.Fatal("SelectValueTo null error should not have killed connection")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestPrepare(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue