From f119d5221cb600413eb1abe174c04238eeac8fd9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 19 May 2014 09:32:31 -0500 Subject: [PATCH] Add CommandTag --- conn.go | 18 ++++++++++++++++-- conn_pool.go | 2 +- conn_test.go | 22 ++++++++++++++++++++++ helper_test.go | 2 +- 4 files changed, 40 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index 6e18cdbe..72ea7791 100644 --- a/conn.go +++ b/conn.go @@ -77,6 +77,20 @@ type Notification struct { Payload string } +type CommandTag string + +// RowsAffected returns the number of rows affected. If the CommandTag was not +// for a row affecting command (such as "CREATE TABLE") then it returns 0 +func (ct CommandTag) RowsAffected() int64 { + words := strings.SplitN(string(ct), " ", 2) + if len(words) != 2 { + return 0 + } + + n, _ := strconv.ParseInt(words[1], 10, 64) + return n +} + // NotSingleRowError is returned when exactly 1 row is expected, but 0 or more than // 1 row is returned type NotSingleRowError struct { @@ -760,7 +774,7 @@ func (c *Conn) sendPreparedQuery(ps *preparedStatement, arguments ...interface{} // Execute executes sql. sql can be either a prepared statement name or an SQL string. // arguments will be sanitized before being interpolated into sql strings. arguments // should be referenced positionally from the sql string as $1, $2, etc. -func (c *Conn) Execute(sql string, arguments ...interface{}) (commandTag string, err error) { +func (c *Conn) Execute(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { defer func() { if err != nil { c.logger.Error(fmt.Sprintf("Execute `%s` with %v failed: %v", sql, arguments, err)) @@ -781,7 +795,7 @@ func (c *Conn) Execute(sql string, arguments ...interface{}) (commandTag string, case dataRow: case bindComplete: case commandComplete: - commandTag = r.ReadCString() + commandTag = CommandTag(r.ReadCString()) default: if e := c.processContextFreeMsg(t, r); e != nil && err == nil { err = e diff --git a/conn_pool.go b/conn_pool.go index 3d600c97..24b2b9af 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -223,7 +223,7 @@ func (p *ConnPool) SelectValues(sql string, arguments ...interface{}) (values [] } // Execute acquires a connection, delegates the call to that connection, and releases the connection -func (p *ConnPool) Execute(sql string, arguments ...interface{}) (commandTag string, err error) { +func (p *ConnPool) Execute(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { var c *Conn if c, err = p.Acquire(); err != nil { return diff --git a/conn_test.go b/conn_test.go index d672843a..b55e51a5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -854,3 +854,25 @@ func TestFatalTxError(t *testing.T) { t.Fatal("Connection should not be live but was") } } + +func TestCommandTag(t *testing.T) { + var tests = []struct { + commandTag pgx.CommandTag + rowsAffected int64 + }{ + {commandTag: "UPDATE 0", rowsAffected: 0}, + {commandTag: "UPDATE 1", rowsAffected: 1}, + {commandTag: "DELETE 0", rowsAffected: 0}, + {commandTag: "DELETE 1", rowsAffected: 1}, + {commandTag: "CREATE TABLE", rowsAffected: 0}, + {commandTag: "ALTER TABLE", rowsAffected: 0}, + {commandTag: "DROP TABLE", rowsAffected: 0}, + } + + for i, tt := range tests { + actual := tt.commandTag.RowsAffected() + if tt.rowsAffected != actual { + t.Errorf(`%d. "%s" should have affected %d rows but it was %d`, i, tt.commandTag, tt.rowsAffected, actual) + } + } +} diff --git a/helper_test.go b/helper_test.go index e69f2105..1b43eaae 100644 --- a/helper_test.go +++ b/helper_test.go @@ -26,7 +26,7 @@ func mustPrepare(t testing.TB, conn *pgx.Conn, name, sql string) { } } -func mustExecute(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag string) { +func mustExecute(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgx.CommandTag) { var err error if commandTag, err = conn.Execute(sql, arguments...); err != nil { t.Fatalf("Execute unexpectedly failed with %v: %v", sql, err)