mirror of https://github.com/jackc/pgx.git
Add RowsFromResultReader
parent
a19ca0638f
commit
c7d03eb555
16
conn.go
16
conn.go
|
@ -242,7 +242,7 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedState
|
||||||
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
||||||
}
|
}
|
||||||
for i := range ps.FieldDescriptions {
|
for i := range ps.FieldDescriptions {
|
||||||
c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
|
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
if name != "" {
|
if name != "" {
|
||||||
|
@ -482,7 +482,7 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
|
||||||
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
||||||
}
|
}
|
||||||
for i := range ps.FieldDescriptions {
|
for i := range ps.FieldDescriptions {
|
||||||
c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
|
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
arguments, err = convertDriverValuers(arguments)
|
arguments, err = convertDriverValuers(arguments)
|
||||||
|
@ -573,7 +573,7 @@ func newencodePreparedStatementArgument(ci *pgtype.ConnInfo, oid pgtype.OID, arg
|
||||||
|
|
||||||
// pgproto3FieldDescriptionToPgxFieldDescription copies and converts the data from a pgproto3.FieldDescription to a
|
// pgproto3FieldDescriptionToPgxFieldDescription copies and converts the data from a pgproto3.FieldDescription to a
|
||||||
// FieldDescription.
|
// FieldDescription.
|
||||||
func (c *Conn) pgproto3FieldDescriptionToPgxFieldDescription(src *pgproto3.FieldDescription, dst *FieldDescription) {
|
func pgproto3FieldDescriptionToPgxFieldDescription(connInfo *pgtype.ConnInfo, src *pgproto3.FieldDescription, dst *FieldDescription) {
|
||||||
dst.Name = string(src.Name)
|
dst.Name = string(src.Name)
|
||||||
dst.Table = pgtype.OID(src.TableOID)
|
dst.Table = pgtype.OID(src.TableOID)
|
||||||
dst.AttributeNumber = src.TableAttributeNumber
|
dst.AttributeNumber = src.TableAttributeNumber
|
||||||
|
@ -582,7 +582,7 @@ func (c *Conn) pgproto3FieldDescriptionToPgxFieldDescription(src *pgproto3.Field
|
||||||
dst.Modifier = src.TypeModifier
|
dst.Modifier = src.TypeModifier
|
||||||
dst.FormatCode = src.Format
|
dst.FormatCode = src.Format
|
||||||
|
|
||||||
if dt, ok := c.ConnInfo.DataTypeForOID(dst.DataType); ok {
|
if dt, ok := connInfo.DataTypeForOID(dst.DataType); ok {
|
||||||
dst.DataTypeName = dt.Name
|
dst.DataTypeName = dt.Name
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -595,7 +595,8 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows {
|
||||||
r := &c.preallocatedRows[len(c.preallocatedRows)-1]
|
r := &c.preallocatedRows[len(c.preallocatedRows)-1]
|
||||||
c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1]
|
c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1]
|
||||||
|
|
||||||
r.conn = c
|
r.logger = c
|
||||||
|
r.connInfo = c.ConnInfo
|
||||||
r.startTime = time.Now()
|
r.startTime = time.Now()
|
||||||
r.sql = sql
|
r.sql = sql
|
||||||
r.args = args
|
r.args = args
|
||||||
|
@ -624,7 +625,8 @@ optionLoop:
|
||||||
}
|
}
|
||||||
|
|
||||||
rows := &connRows{
|
rows := &connRows{
|
||||||
conn: c,
|
logger: c,
|
||||||
|
connInfo: c.ConnInfo,
|
||||||
startTime: time.Now(),
|
startTime: time.Now(),
|
||||||
sql: sql,
|
sql: sql,
|
||||||
args: args,
|
args: args,
|
||||||
|
@ -654,7 +656,7 @@ optionLoop:
|
||||||
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
||||||
}
|
}
|
||||||
for i := range ps.FieldDescriptions {
|
for i := range ps.FieldDescriptions {
|
||||||
c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
|
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rows.sql = ps.SQL
|
rows.sql = ps.SQL
|
||||||
|
|
|
@ -1273,3 +1273,35 @@ func TestQueryCloseBefore(t *testing.T) {
|
||||||
t.Error("Expected bytes to be sent to server")
|
t.Error("Expected bytes to be sent to server")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRowsFromResultReader(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
resultReader := conn.PgConn().ExecParams(context.Background(), "select generate_series(1,$1)", [][]byte{[]byte("10")}, nil, nil, nil)
|
||||||
|
|
||||||
|
var sum, rowCount int32
|
||||||
|
|
||||||
|
rows := pgx.RowsFromResultReader(conn.ConnInfo, resultReader)
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var n int32
|
||||||
|
rows.Scan(&n)
|
||||||
|
sum += n
|
||||||
|
rowCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() != nil {
|
||||||
|
t.Fatalf("conn.Query failed: %v", rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
if rowCount != 10 {
|
||||||
|
t.Error("wrong number of rows")
|
||||||
|
}
|
||||||
|
if sum != 55 {
|
||||||
|
t.Error("Wrong values returned")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
57
rows.go
57
rows.go
|
@ -67,10 +67,15 @@ func (r *connRow) Scan(dest ...interface{}) (err error) {
|
||||||
return rows.Err()
|
return rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type rowLog interface {
|
||||||
|
shouldLog(lvl LogLevel) bool
|
||||||
|
log(lvl LogLevel, msg string, data map[string]interface{})
|
||||||
|
}
|
||||||
|
|
||||||
// connRows implements the Rows interface for Conn.Query.
|
// connRows implements the Rows interface for Conn.Query.
|
||||||
type connRows struct {
|
type connRows struct {
|
||||||
conn *Conn
|
logger rowLog
|
||||||
batch *Batch
|
connInfo *pgtype.ConnInfo
|
||||||
values [][]byte
|
values [][]byte
|
||||||
fields []FieldDescription
|
fields []FieldDescription
|
||||||
rowCount int
|
rowCount int
|
||||||
|
@ -81,8 +86,7 @@ type connRows struct {
|
||||||
args []interface{}
|
args []interface{}
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
resultReader *pgconn.ResultReader
|
resultReader *pgconn.ResultReader
|
||||||
multiResultReader *pgconn.MultiResultReader
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rows *connRows) FieldDescriptions() []FieldDescription {
|
func (rows *connRows) FieldDescriptions() []FieldDescription {
|
||||||
|
@ -103,25 +107,16 @@ func (rows *connRows) Close() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows.multiResultReader != nil {
|
if rows.logger != nil {
|
||||||
closeErr := rows.multiResultReader.Close()
|
|
||||||
if rows.err == nil {
|
if rows.err == nil {
|
||||||
rows.err = closeErr
|
if rows.logger.shouldLog(LogLevelInfo) {
|
||||||
|
endTime := time.Now()
|
||||||
|
rows.logger.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
|
||||||
|
}
|
||||||
|
} else if rows.logger.shouldLog(LogLevelError) {
|
||||||
|
rows.logger.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows.err == nil {
|
|
||||||
if rows.conn.shouldLog(LogLevelInfo) {
|
|
||||||
endTime := time.Now()
|
|
||||||
rows.conn.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
|
|
||||||
}
|
|
||||||
} else if rows.conn.shouldLog(LogLevelError) {
|
|
||||||
rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)})
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows.batch != nil && rows.err != nil {
|
|
||||||
rows.batch.die(rows.err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rows *connRows) Err() error {
|
func (rows *connRows) Err() error {
|
||||||
|
@ -149,7 +144,7 @@ func (rows *connRows) Next() bool {
|
||||||
rrFieldDescriptions := rows.resultReader.FieldDescriptions()
|
rrFieldDescriptions := rows.resultReader.FieldDescriptions()
|
||||||
rows.fields = make([]FieldDescription, len(rrFieldDescriptions))
|
rows.fields = make([]FieldDescription, len(rrFieldDescriptions))
|
||||||
for i := range rrFieldDescriptions {
|
for i := range rrFieldDescriptions {
|
||||||
rows.conn.pgproto3FieldDescriptionToPgxFieldDescription(&rrFieldDescriptions[i], &rows.fields[i])
|
pgproto3FieldDescriptionToPgxFieldDescription(rows.connInfo, &rrFieldDescriptions[i], &rows.fields[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rows.rowCount++
|
rows.rowCount++
|
||||||
|
@ -191,7 +186,7 @@ func (rows *connRows) Scan(dest ...interface{}) error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := rows.conn.ConnInfo.Scan(fd.DataType, fd.FormatCode, buf, d)
|
err := rows.connInfo.Scan(fd.DataType, fd.FormatCode, buf, d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rows.fatal(scanArgError{col: i, err: err})
|
rows.fatal(scanArgError{col: i, err: err})
|
||||||
return err
|
return err
|
||||||
|
@ -216,7 +211,7 @@ func (rows *connRows) Values() ([]interface{}, error) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok {
|
if dt, ok := rows.connInfo.DataTypeForOID(fd.DataType); ok {
|
||||||
value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value)
|
value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value)
|
||||||
|
|
||||||
switch fd.FormatCode {
|
switch fd.FormatCode {
|
||||||
|
@ -225,7 +220,7 @@ func (rows *connRows) Values() ([]interface{}, error) {
|
||||||
if decoder == nil {
|
if decoder == nil {
|
||||||
decoder = &pgtype.GenericText{}
|
decoder = &pgtype.GenericText{}
|
||||||
}
|
}
|
||||||
err := decoder.DecodeText(rows.conn.ConnInfo, buf)
|
err := decoder.DecodeText(rows.connInfo, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -235,7 +230,7 @@ func (rows *connRows) Values() ([]interface{}, error) {
|
||||||
if decoder == nil {
|
if decoder == nil {
|
||||||
decoder = &pgtype.GenericBinary{}
|
decoder = &pgtype.GenericBinary{}
|
||||||
}
|
}
|
||||||
err := decoder.DecodeBinary(rows.conn.ConnInfo, buf)
|
err := decoder.DecodeBinary(rows.connInfo, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -263,3 +258,15 @@ type scanArgError struct {
|
||||||
func (e scanArgError) Error() string {
|
func (e scanArgError) Error() string {
|
||||||
return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
|
return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RowsFromResultReader wraps a *pgconn.ResultReader in a Rows wrapper so a more convenient scanning interface can be
|
||||||
|
// used.
|
||||||
|
//
|
||||||
|
// In most cases, the appropriate pgx query methods should be used instead of sending a query with pgconn and reading
|
||||||
|
// the results with pgx.
|
||||||
|
func RowsFromResultReader(connInfo *pgtype.ConnInfo, rr *pgconn.ResultReader) Rows {
|
||||||
|
return &connRows{
|
||||||
|
connInfo: connInfo,
|
||||||
|
resultReader: rr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue