mirror of https://github.com/jackc/pgx.git
Add RowsFromResultReader
parent
a19ca0638f
commit
c7d03eb555
16
conn.go
16
conn.go
|
@ -242,7 +242,7 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedState
|
|||
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
||||
}
|
||||
for i := range ps.FieldDescriptions {
|
||||
c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
|
||||
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
|
||||
}
|
||||
|
||||
if name != "" {
|
||||
|
@ -482,7 +482,7 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
|
|||
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
||||
}
|
||||
for i := range ps.FieldDescriptions {
|
||||
c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
|
||||
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
|
||||
}
|
||||
|
||||
arguments, err = convertDriverValuers(arguments)
|
||||
|
@ -573,7 +573,7 @@ func newencodePreparedStatementArgument(ci *pgtype.ConnInfo, oid pgtype.OID, arg
|
|||
|
||||
// pgproto3FieldDescriptionToPgxFieldDescription copies and converts the data from a pgproto3.FieldDescription to a
|
||||
// FieldDescription.
|
||||
func (c *Conn) pgproto3FieldDescriptionToPgxFieldDescription(src *pgproto3.FieldDescription, dst *FieldDescription) {
|
||||
func pgproto3FieldDescriptionToPgxFieldDescription(connInfo *pgtype.ConnInfo, src *pgproto3.FieldDescription, dst *FieldDescription) {
|
||||
dst.Name = string(src.Name)
|
||||
dst.Table = pgtype.OID(src.TableOID)
|
||||
dst.AttributeNumber = src.TableAttributeNumber
|
||||
|
@ -582,7 +582,7 @@ func (c *Conn) pgproto3FieldDescriptionToPgxFieldDescription(src *pgproto3.Field
|
|||
dst.Modifier = src.TypeModifier
|
||||
dst.FormatCode = src.Format
|
||||
|
||||
if dt, ok := c.ConnInfo.DataTypeForOID(dst.DataType); ok {
|
||||
if dt, ok := connInfo.DataTypeForOID(dst.DataType); ok {
|
||||
dst.DataTypeName = dt.Name
|
||||
}
|
||||
}
|
||||
|
@ -595,7 +595,8 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows {
|
|||
r := &c.preallocatedRows[len(c.preallocatedRows)-1]
|
||||
c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1]
|
||||
|
||||
r.conn = c
|
||||
r.logger = c
|
||||
r.connInfo = c.ConnInfo
|
||||
r.startTime = time.Now()
|
||||
r.sql = sql
|
||||
r.args = args
|
||||
|
@ -624,7 +625,8 @@ optionLoop:
|
|||
}
|
||||
|
||||
rows := &connRows{
|
||||
conn: c,
|
||||
logger: c,
|
||||
connInfo: c.ConnInfo,
|
||||
startTime: time.Now(),
|
||||
sql: sql,
|
||||
args: args,
|
||||
|
@ -654,7 +656,7 @@ optionLoop:
|
|||
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
||||
}
|
||||
for i := range ps.FieldDescriptions {
|
||||
c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
|
||||
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
|
||||
}
|
||||
}
|
||||
rows.sql = ps.SQL
|
||||
|
|
|
@ -1273,3 +1273,35 @@ func TestQueryCloseBefore(t *testing.T) {
|
|||
t.Error("Expected bytes to be sent to server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRowsFromResultReader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
resultReader := conn.PgConn().ExecParams(context.Background(), "select generate_series(1,$1)", [][]byte{[]byte("10")}, nil, nil, nil)
|
||||
|
||||
var sum, rowCount int32
|
||||
|
||||
rows := pgx.RowsFromResultReader(conn.ConnInfo, resultReader)
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var n int32
|
||||
rows.Scan(&n)
|
||||
sum += n
|
||||
rowCount++
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Fatalf("conn.Query failed: %v", rows.Err())
|
||||
}
|
||||
|
||||
if rowCount != 10 {
|
||||
t.Error("wrong number of rows")
|
||||
}
|
||||
if sum != 55 {
|
||||
t.Error("Wrong values returned")
|
||||
}
|
||||
}
|
||||
|
|
57
rows.go
57
rows.go
|
@ -67,10 +67,15 @@ func (r *connRow) Scan(dest ...interface{}) (err error) {
|
|||
return rows.Err()
|
||||
}
|
||||
|
||||
type rowLog interface {
|
||||
shouldLog(lvl LogLevel) bool
|
||||
log(lvl LogLevel, msg string, data map[string]interface{})
|
||||
}
|
||||
|
||||
// connRows implements the Rows interface for Conn.Query.
|
||||
type connRows struct {
|
||||
conn *Conn
|
||||
batch *Batch
|
||||
logger rowLog
|
||||
connInfo *pgtype.ConnInfo
|
||||
values [][]byte
|
||||
fields []FieldDescription
|
||||
rowCount int
|
||||
|
@ -81,8 +86,7 @@ type connRows struct {
|
|||
args []interface{}
|
||||
closed bool
|
||||
|
||||
resultReader *pgconn.ResultReader
|
||||
multiResultReader *pgconn.MultiResultReader
|
||||
resultReader *pgconn.ResultReader
|
||||
}
|
||||
|
||||
func (rows *connRows) FieldDescriptions() []FieldDescription {
|
||||
|
@ -103,25 +107,16 @@ func (rows *connRows) Close() {
|
|||
}
|
||||
}
|
||||
|
||||
if rows.multiResultReader != nil {
|
||||
closeErr := rows.multiResultReader.Close()
|
||||
if rows.logger != nil {
|
||||
if rows.err == nil {
|
||||
rows.err = closeErr
|
||||
if rows.logger.shouldLog(LogLevelInfo) {
|
||||
endTime := time.Now()
|
||||
rows.logger.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
|
||||
}
|
||||
} else if rows.logger.shouldLog(LogLevelError) {
|
||||
rows.logger.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)})
|
||||
}
|
||||
}
|
||||
|
||||
if rows.err == nil {
|
||||
if rows.conn.shouldLog(LogLevelInfo) {
|
||||
endTime := time.Now()
|
||||
rows.conn.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
|
||||
}
|
||||
} else if rows.conn.shouldLog(LogLevelError) {
|
||||
rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)})
|
||||
}
|
||||
|
||||
if rows.batch != nil && rows.err != nil {
|
||||
rows.batch.die(rows.err)
|
||||
}
|
||||
}
|
||||
|
||||
func (rows *connRows) Err() error {
|
||||
|
@ -149,7 +144,7 @@ func (rows *connRows) Next() bool {
|
|||
rrFieldDescriptions := rows.resultReader.FieldDescriptions()
|
||||
rows.fields = make([]FieldDescription, len(rrFieldDescriptions))
|
||||
for i := range rrFieldDescriptions {
|
||||
rows.conn.pgproto3FieldDescriptionToPgxFieldDescription(&rrFieldDescriptions[i], &rows.fields[i])
|
||||
pgproto3FieldDescriptionToPgxFieldDescription(rows.connInfo, &rrFieldDescriptions[i], &rows.fields[i])
|
||||
}
|
||||
}
|
||||
rows.rowCount++
|
||||
|
@ -191,7 +186,7 @@ func (rows *connRows) Scan(dest ...interface{}) error {
|
|||
continue
|
||||
}
|
||||
|
||||
err := rows.conn.ConnInfo.Scan(fd.DataType, fd.FormatCode, buf, d)
|
||||
err := rows.connInfo.Scan(fd.DataType, fd.FormatCode, buf, d)
|
||||
if err != nil {
|
||||
rows.fatal(scanArgError{col: i, err: err})
|
||||
return err
|
||||
|
@ -216,7 +211,7 @@ func (rows *connRows) Values() ([]interface{}, error) {
|
|||
continue
|
||||
}
|
||||
|
||||
if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok {
|
||||
if dt, ok := rows.connInfo.DataTypeForOID(fd.DataType); ok {
|
||||
value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value)
|
||||
|
||||
switch fd.FormatCode {
|
||||
|
@ -225,7 +220,7 @@ func (rows *connRows) Values() ([]interface{}, error) {
|
|||
if decoder == nil {
|
||||
decoder = &pgtype.GenericText{}
|
||||
}
|
||||
err := decoder.DecodeText(rows.conn.ConnInfo, buf)
|
||||
err := decoder.DecodeText(rows.connInfo, buf)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
}
|
||||
|
@ -235,7 +230,7 @@ func (rows *connRows) Values() ([]interface{}, error) {
|
|||
if decoder == nil {
|
||||
decoder = &pgtype.GenericBinary{}
|
||||
}
|
||||
err := decoder.DecodeBinary(rows.conn.ConnInfo, buf)
|
||||
err := decoder.DecodeBinary(rows.connInfo, buf)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
}
|
||||
|
@ -263,3 +258,15 @@ type scanArgError struct {
|
|||
func (e scanArgError) Error() string {
|
||||
return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
|
||||
}
|
||||
|
||||
// RowsFromResultReader wraps a *pgconn.ResultReader in a Rows wrapper so a more convenient scanning interface can be
|
||||
// used.
|
||||
//
|
||||
// In most cases, the appropriate pgx query methods should be used instead of sending a query with pgconn and reading
|
||||
// the results with pgx.
|
||||
func RowsFromResultReader(connInfo *pgtype.ConnInfo, rr *pgconn.ResultReader) Rows {
|
||||
return &connRows{
|
||||
connInfo: connInfo,
|
||||
resultReader: rr,
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue