From c7d03eb5554c78db59563b04630e989b07b3fe45 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 24 Apr 2019 15:57:50 -0500 Subject: [PATCH] Add RowsFromResultReader --- conn.go | 16 ++++++++------- query_test.go | 32 +++++++++++++++++++++++++++++ rows.go | 57 +++++++++++++++++++++++++++++---------------------- 3 files changed, 73 insertions(+), 32 deletions(-) diff --git a/conn.go b/conn.go index dbfc4017..28b5546f 100644 --- a/conn.go +++ b/conn.go @@ -242,7 +242,7 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedState ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i]) } for i := range ps.FieldDescriptions { - c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i]) + pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i]) } if name != "" { @@ -482,7 +482,7 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) ( ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i]) } for i := range ps.FieldDescriptions { - c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i]) + pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i]) } arguments, err = convertDriverValuers(arguments) @@ -573,7 +573,7 @@ func newencodePreparedStatementArgument(ci *pgtype.ConnInfo, oid pgtype.OID, arg // pgproto3FieldDescriptionToPgxFieldDescription copies and converts the data from a pgproto3.FieldDescription to a // FieldDescription. -func (c *Conn) pgproto3FieldDescriptionToPgxFieldDescription(src *pgproto3.FieldDescription, dst *FieldDescription) { +func pgproto3FieldDescriptionToPgxFieldDescription(connInfo *pgtype.ConnInfo, src *pgproto3.FieldDescription, dst *FieldDescription) { dst.Name = string(src.Name) dst.Table = pgtype.OID(src.TableOID) dst.AttributeNumber = src.TableAttributeNumber @@ -582,7 +582,7 @@ func (c *Conn) pgproto3FieldDescriptionToPgxFieldDescription(src *pgproto3.Field dst.Modifier = src.TypeModifier dst.FormatCode = src.Format - if dt, ok := c.ConnInfo.DataTypeForOID(dst.DataType); ok { + if dt, ok := connInfo.DataTypeForOID(dst.DataType); ok { dst.DataTypeName = dt.Name } } @@ -595,7 +595,8 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows { r := &c.preallocatedRows[len(c.preallocatedRows)-1] c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1] - r.conn = c + r.logger = c + r.connInfo = c.ConnInfo r.startTime = time.Now() r.sql = sql r.args = args @@ -624,7 +625,8 @@ optionLoop: } rows := &connRows{ - conn: c, + logger: c, + connInfo: c.ConnInfo, startTime: time.Now(), sql: sql, args: args, @@ -654,7 +656,7 @@ optionLoop: ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i]) } for i := range ps.FieldDescriptions { - c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i]) + pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i]) } } rows.sql = ps.SQL diff --git a/query_test.go b/query_test.go index 7d799343..a0c807c8 100644 --- a/query_test.go +++ b/query_test.go @@ -1273,3 +1273,35 @@ func TestQueryCloseBefore(t *testing.T) { t.Error("Expected bytes to be sent to server") } } + +func TestRowsFromResultReader(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + resultReader := conn.PgConn().ExecParams(context.Background(), "select generate_series(1,$1)", [][]byte{[]byte("10")}, nil, nil, nil) + + var sum, rowCount int32 + + rows := pgx.RowsFromResultReader(conn.ConnInfo, resultReader) + defer rows.Close() + + for rows.Next() { + var n int32 + rows.Scan(&n) + sum += n + rowCount++ + } + + if rows.Err() != nil { + t.Fatalf("conn.Query failed: %v", rows.Err()) + } + + if rowCount != 10 { + t.Error("wrong number of rows") + } + if sum != 55 { + t.Error("Wrong values returned") + } +} diff --git a/rows.go b/rows.go index dece31ca..b57f16fa 100644 --- a/rows.go +++ b/rows.go @@ -67,10 +67,15 @@ func (r *connRow) Scan(dest ...interface{}) (err error) { return rows.Err() } +type rowLog interface { + shouldLog(lvl LogLevel) bool + log(lvl LogLevel, msg string, data map[string]interface{}) +} + // connRows implements the Rows interface for Conn.Query. type connRows struct { - conn *Conn - batch *Batch + logger rowLog + connInfo *pgtype.ConnInfo values [][]byte fields []FieldDescription rowCount int @@ -81,8 +86,7 @@ type connRows struct { args []interface{} closed bool - resultReader *pgconn.ResultReader - multiResultReader *pgconn.MultiResultReader + resultReader *pgconn.ResultReader } func (rows *connRows) FieldDescriptions() []FieldDescription { @@ -103,25 +107,16 @@ func (rows *connRows) Close() { } } - if rows.multiResultReader != nil { - closeErr := rows.multiResultReader.Close() + if rows.logger != nil { if rows.err == nil { - rows.err = closeErr + if rows.logger.shouldLog(LogLevelInfo) { + endTime := time.Now() + rows.logger.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) + } + } else if rows.logger.shouldLog(LogLevelError) { + rows.logger.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)}) } } - - if rows.err == nil { - if rows.conn.shouldLog(LogLevelInfo) { - endTime := time.Now() - rows.conn.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) - } - } else if rows.conn.shouldLog(LogLevelError) { - rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)}) - } - - if rows.batch != nil && rows.err != nil { - rows.batch.die(rows.err) - } } func (rows *connRows) Err() error { @@ -149,7 +144,7 @@ func (rows *connRows) Next() bool { rrFieldDescriptions := rows.resultReader.FieldDescriptions() rows.fields = make([]FieldDescription, len(rrFieldDescriptions)) for i := range rrFieldDescriptions { - rows.conn.pgproto3FieldDescriptionToPgxFieldDescription(&rrFieldDescriptions[i], &rows.fields[i]) + pgproto3FieldDescriptionToPgxFieldDescription(rows.connInfo, &rrFieldDescriptions[i], &rows.fields[i]) } } rows.rowCount++ @@ -191,7 +186,7 @@ func (rows *connRows) Scan(dest ...interface{}) error { continue } - err := rows.conn.ConnInfo.Scan(fd.DataType, fd.FormatCode, buf, d) + err := rows.connInfo.Scan(fd.DataType, fd.FormatCode, buf, d) if err != nil { rows.fatal(scanArgError{col: i, err: err}) return err @@ -216,7 +211,7 @@ func (rows *connRows) Values() ([]interface{}, error) { continue } - if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok { + if dt, ok := rows.connInfo.DataTypeForOID(fd.DataType); ok { value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value) switch fd.FormatCode { @@ -225,7 +220,7 @@ func (rows *connRows) Values() ([]interface{}, error) { if decoder == nil { decoder = &pgtype.GenericText{} } - err := decoder.DecodeText(rows.conn.ConnInfo, buf) + err := decoder.DecodeText(rows.connInfo, buf) if err != nil { rows.fatal(err) } @@ -235,7 +230,7 @@ func (rows *connRows) Values() ([]interface{}, error) { if decoder == nil { decoder = &pgtype.GenericBinary{} } - err := decoder.DecodeBinary(rows.conn.ConnInfo, buf) + err := decoder.DecodeBinary(rows.connInfo, buf) if err != nil { rows.fatal(err) } @@ -263,3 +258,15 @@ type scanArgError struct { func (e scanArgError) Error() string { return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err) } + +// RowsFromResultReader wraps a *pgconn.ResultReader in a Rows wrapper so a more convenient scanning interface can be +// used. +// +// In most cases, the appropriate pgx query methods should be used instead of sending a query with pgconn and reading +// the results with pgx. +func RowsFromResultReader(connInfo *pgtype.ConnInfo, rr *pgconn.ResultReader) Rows { + return &connRows{ + connInfo: connInfo, + resultReader: rr, + } +}