Add RowsFromResultReader

pull/483/head
Jack Christensen 2019-04-24 15:57:50 -05:00
parent a19ca0638f
commit c7d03eb555
3 changed files with 73 additions and 32 deletions

16
conn.go
View File

@ -242,7 +242,7 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedState
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
}
for i := range ps.FieldDescriptions {
c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
}
if name != "" {
@ -482,7 +482,7 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
}
for i := range ps.FieldDescriptions {
c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
}
arguments, err = convertDriverValuers(arguments)
@ -573,7 +573,7 @@ func newencodePreparedStatementArgument(ci *pgtype.ConnInfo, oid pgtype.OID, arg
// pgproto3FieldDescriptionToPgxFieldDescription copies and converts the data from a pgproto3.FieldDescription to a
// FieldDescription.
func (c *Conn) pgproto3FieldDescriptionToPgxFieldDescription(src *pgproto3.FieldDescription, dst *FieldDescription) {
func pgproto3FieldDescriptionToPgxFieldDescription(connInfo *pgtype.ConnInfo, src *pgproto3.FieldDescription, dst *FieldDescription) {
dst.Name = string(src.Name)
dst.Table = pgtype.OID(src.TableOID)
dst.AttributeNumber = src.TableAttributeNumber
@ -582,7 +582,7 @@ func (c *Conn) pgproto3FieldDescriptionToPgxFieldDescription(src *pgproto3.Field
dst.Modifier = src.TypeModifier
dst.FormatCode = src.Format
if dt, ok := c.ConnInfo.DataTypeForOID(dst.DataType); ok {
if dt, ok := connInfo.DataTypeForOID(dst.DataType); ok {
dst.DataTypeName = dt.Name
}
}
@ -595,7 +595,8 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows {
r := &c.preallocatedRows[len(c.preallocatedRows)-1]
c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1]
r.conn = c
r.logger = c
r.connInfo = c.ConnInfo
r.startTime = time.Now()
r.sql = sql
r.args = args
@ -624,7 +625,8 @@ optionLoop:
}
rows := &connRows{
conn: c,
logger: c,
connInfo: c.ConnInfo,
startTime: time.Now(),
sql: sql,
args: args,
@ -654,7 +656,7 @@ optionLoop:
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
}
for i := range ps.FieldDescriptions {
c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
}
}
rows.sql = ps.SQL

View File

@ -1273,3 +1273,35 @@ func TestQueryCloseBefore(t *testing.T) {
t.Error("Expected bytes to be sent to server")
}
}
func TestRowsFromResultReader(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
resultReader := conn.PgConn().ExecParams(context.Background(), "select generate_series(1,$1)", [][]byte{[]byte("10")}, nil, nil, nil)
var sum, rowCount int32
rows := pgx.RowsFromResultReader(conn.ConnInfo, resultReader)
defer rows.Close()
for rows.Next() {
var n int32
rows.Scan(&n)
sum += n
rowCount++
}
if rows.Err() != nil {
t.Fatalf("conn.Query failed: %v", rows.Err())
}
if rowCount != 10 {
t.Error("wrong number of rows")
}
if sum != 55 {
t.Error("Wrong values returned")
}
}

57
rows.go
View File

@ -67,10 +67,15 @@ func (r *connRow) Scan(dest ...interface{}) (err error) {
return rows.Err()
}
type rowLog interface {
shouldLog(lvl LogLevel) bool
log(lvl LogLevel, msg string, data map[string]interface{})
}
// connRows implements the Rows interface for Conn.Query.
type connRows struct {
conn *Conn
batch *Batch
logger rowLog
connInfo *pgtype.ConnInfo
values [][]byte
fields []FieldDescription
rowCount int
@ -81,8 +86,7 @@ type connRows struct {
args []interface{}
closed bool
resultReader *pgconn.ResultReader
multiResultReader *pgconn.MultiResultReader
resultReader *pgconn.ResultReader
}
func (rows *connRows) FieldDescriptions() []FieldDescription {
@ -103,25 +107,16 @@ func (rows *connRows) Close() {
}
}
if rows.multiResultReader != nil {
closeErr := rows.multiResultReader.Close()
if rows.logger != nil {
if rows.err == nil {
rows.err = closeErr
if rows.logger.shouldLog(LogLevelInfo) {
endTime := time.Now()
rows.logger.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
}
} else if rows.logger.shouldLog(LogLevelError) {
rows.logger.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)})
}
}
if rows.err == nil {
if rows.conn.shouldLog(LogLevelInfo) {
endTime := time.Now()
rows.conn.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
}
} else if rows.conn.shouldLog(LogLevelError) {
rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)})
}
if rows.batch != nil && rows.err != nil {
rows.batch.die(rows.err)
}
}
func (rows *connRows) Err() error {
@ -149,7 +144,7 @@ func (rows *connRows) Next() bool {
rrFieldDescriptions := rows.resultReader.FieldDescriptions()
rows.fields = make([]FieldDescription, len(rrFieldDescriptions))
for i := range rrFieldDescriptions {
rows.conn.pgproto3FieldDescriptionToPgxFieldDescription(&rrFieldDescriptions[i], &rows.fields[i])
pgproto3FieldDescriptionToPgxFieldDescription(rows.connInfo, &rrFieldDescriptions[i], &rows.fields[i])
}
}
rows.rowCount++
@ -191,7 +186,7 @@ func (rows *connRows) Scan(dest ...interface{}) error {
continue
}
err := rows.conn.ConnInfo.Scan(fd.DataType, fd.FormatCode, buf, d)
err := rows.connInfo.Scan(fd.DataType, fd.FormatCode, buf, d)
if err != nil {
rows.fatal(scanArgError{col: i, err: err})
return err
@ -216,7 +211,7 @@ func (rows *connRows) Values() ([]interface{}, error) {
continue
}
if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok {
if dt, ok := rows.connInfo.DataTypeForOID(fd.DataType); ok {
value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value)
switch fd.FormatCode {
@ -225,7 +220,7 @@ func (rows *connRows) Values() ([]interface{}, error) {
if decoder == nil {
decoder = &pgtype.GenericText{}
}
err := decoder.DecodeText(rows.conn.ConnInfo, buf)
err := decoder.DecodeText(rows.connInfo, buf)
if err != nil {
rows.fatal(err)
}
@ -235,7 +230,7 @@ func (rows *connRows) Values() ([]interface{}, error) {
if decoder == nil {
decoder = &pgtype.GenericBinary{}
}
err := decoder.DecodeBinary(rows.conn.ConnInfo, buf)
err := decoder.DecodeBinary(rows.connInfo, buf)
if err != nil {
rows.fatal(err)
}
@ -263,3 +258,15 @@ type scanArgError struct {
func (e scanArgError) Error() string {
return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
}
// RowsFromResultReader wraps a *pgconn.ResultReader in a Rows wrapper so a more convenient scanning interface can be
// used.
//
// In most cases, the appropriate pgx query methods should be used instead of sending a query with pgconn and reading
// the results with pgx.
func RowsFromResultReader(connInfo *pgtype.ConnInfo, rr *pgconn.ResultReader) Rows {
return &connRows{
connInfo: connInfo,
resultReader: rr,
}
}