From b5ce0220f85d032432263b9ac366fcbb7aa07106 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 2 Sep 2019 09:53:26 -0500 Subject: [PATCH] Add CommandTag to Rows interface This allows handling queries where it is unknown if there will be a result set or not. If it is not a result set returning query the command tag will still be available. --- pgxpool/rows.go | 6 ++++++ query_test.go | 15 +++++++++++++++ rows.go | 33 ++++++++++++++++++++++----------- 3 files changed, 43 insertions(+), 11 deletions(-) diff --git a/pgxpool/rows.go b/pgxpool/rows.go index be29cbf4..59279bc2 100644 --- a/pgxpool/rows.go +++ b/pgxpool/rows.go @@ -1,6 +1,7 @@ package pgxpool import ( + "github.com/jackc/pgconn" "github.com/jackc/pgproto3/v2" "github.com/jackc/pgx/v4" ) @@ -11,6 +12,7 @@ type errRows struct { func (errRows) Close() {} func (e errRows) Err() error { return e.err } +func (errRows) CommandTag() pgconn.CommandTag { return nil } func (errRows) FieldDescriptions() []pgproto3.FieldDescription { return nil } func (errRows) Next() bool { return false } func (e errRows) Scan(dest ...interface{}) error { return e.err } @@ -43,6 +45,10 @@ func (rows *poolRows) Err() error { return rows.r.Err() } +func (rows *poolRows) CommandTag() pgconn.CommandTag { + return rows.CommandTag() +} + func (rows *poolRows) FieldDescriptions() []pgproto3.FieldDescription { return rows.r.FieldDescriptions() } diff --git a/query_test.go b/query_test.go index 21d366fe..8093d017 100644 --- a/query_test.go +++ b/query_test.go @@ -48,6 +48,8 @@ func TestConnQueryScan(t *testing.T) { t.Fatalf("conn.Query failed: %v", err) } + assert.Equal(t, "SELECT 10", string(rows.CommandTag())) + if rowCount != 10 { t.Error("Select called onDataRow wrong number of times") } @@ -56,6 +58,19 @@ func TestConnQueryScan(t *testing.T) { } } +func TestConnQueryWithoutResultSetCommandTag(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + rows, err := conn.Query(context.Background(), "create temporary table t (id serial);") + assert.NoError(t, err) + rows.Close() + assert.NoError(t, rows.Err()) + assert.Equal(t, "CREATE TABLE", string(rows.CommandTag())) +} + func TestConnQueryScanWithManyColumns(t *testing.T) { t.Parallel() diff --git a/rows.go b/rows.go index 5ae1427e..7b3a5f17 100644 --- a/rows.go +++ b/rows.go @@ -21,7 +21,12 @@ type Rows interface { // to call Close after rows is already closed. Close() + // Err returns any error that occurred while reading. Err() error + + // CommandTag returns the command tag from this query. It is only available after Rows is closed. + CommandTag() pgconn.CommandTag + FieldDescriptions() []pgproto3.FieldDescription // Next prepares the next row for reading. It returns true if there is another @@ -76,16 +81,17 @@ type rowLog interface { // connRows implements the Rows interface for Conn.Query. type connRows struct { - ctx context.Context - logger rowLog - connInfo *pgtype.ConnInfo - values [][]byte - rowCount int - err error - startTime time.Time - sql string - args []interface{} - closed bool + ctx context.Context + logger rowLog + connInfo *pgtype.ConnInfo + values [][]byte + rowCount int + err error + commandTag pgconn.CommandTag + startTime time.Time + sql string + args []interface{} + closed bool resultReader *pgconn.ResultReader multiResultReader *pgconn.MultiResultReader @@ -103,7 +109,8 @@ func (rows *connRows) Close() { rows.closed = true if rows.resultReader != nil { - _, closeErr := rows.resultReader.Close() + var closeErr error + rows.commandTag, closeErr = rows.resultReader.Close() if rows.err == nil { rows.err = closeErr } @@ -128,6 +135,10 @@ func (rows *connRows) Close() { } } +func (rows *connRows) CommandTag() pgconn.CommandTag { + return rows.commandTag +} + func (rows *connRows) Err() error { return rows.err }