pgx uses pgconn.CommandTag instead of own definition

pull/483/head
Jack Christensen 2019-01-01 16:55:48 -06:00
parent ddd37cf557
commit 7f9540438c
13 changed files with 106 additions and 118 deletions

View File

@ -3,6 +3,7 @@ package pgx
import ( import (
"context" "context"
"github.com/jackc/pgx/pgconn"
"github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype" "github.com/jackc/pgx/pgtype"
) )
@ -162,21 +163,21 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
// ExecResults reads the results from the next query in the batch as if the // ExecResults reads the results from the next query in the batch as if the
// query has been sent with Exec. // query has been sent with Exec.
func (b *Batch) ExecResults() (CommandTag, error) { func (b *Batch) ExecResults() (pgconn.CommandTag, error) {
if b.err != nil { if b.err != nil {
return "", b.err return nil, b.err
} }
select { select {
case <-b.ctx.Done(): case <-b.ctx.Done():
b.die(b.ctx.Err()) b.die(b.ctx.Err())
return "", b.ctx.Err() return nil, b.ctx.Err()
default: default:
} }
if err := b.ensureCommandComplete(); err != nil { if err := b.ensureCommandComplete(); err != nil {
b.die(err) b.die(err)
return "", err return nil, err
} }
b.resultsRead++ b.resultsRead++
@ -186,16 +187,16 @@ func (b *Batch) ExecResults() (CommandTag, error) {
for { for {
msg, err := b.conn.rxMsg() msg, err := b.conn.rxMsg()
if err != nil { if err != nil {
return "", err return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
b.pendingCommandComplete = false b.pendingCommandComplete = false
return CommandTag(msg.CommandTag), nil return pgconn.CommandTag(msg.CommandTag), nil
default: default:
if err := b.conn.processContextFreeMsg(msg); err != nil { if err := b.conn.processContextFreeMsg(msg); err != nil {
return "", err return nil, err
} }
} }
} }

42
conn.go
View File

@ -4,7 +4,6 @@ import (
"context" "context"
"net" "net"
"reflect" "reflect"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -120,21 +119,6 @@ type Notification struct {
Payload string Payload string
} }
// CommandTag is the result of an Exec function
type CommandTag string
// RowsAffected returns the number of rows affected. If the CommandTag was not
// for a row affecting command (such as "CREATE TABLE") then it returns 0
func (ct CommandTag) RowsAffected() int64 {
s := string(ct)
index := strings.LastIndex(s, " ")
if index == -1 {
return 0
}
n, _ := strconv.ParseInt(s[index+1:], 10, 64)
return n
}
// Identifier a PostgreSQL identifier or name. Identifiers can be composed of // Identifier a PostgreSQL identifier or name. Identifiers can be composed of
// multiple parts such as ["schema", "table"] or ["table", "column"]. // multiple parts such as ["schema", "table"] or ["table", "column"].
type Identifier []string type Identifier []string
@ -855,7 +839,7 @@ func fatalWriteErr(bytesWritten int, err error) bool {
// Exec executes sql. sql can be either a prepared statement name or an SQL string. // Exec executes sql. sql can be either a prepared statement name or an SQL string.
// arguments should be referenced positionally from the sql string as $1, $2, etc. // arguments should be referenced positionally from the sql string as $1, $2, etc.
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
return c.ExecEx(context.Background(), sql, nil, arguments...) return c.ExecEx(context.Background(), sql, nil, arguments...)
} }
@ -1104,15 +1088,15 @@ func (c *Conn) Ping(ctx context.Context) error {
return err return err
} }
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) { func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (pgconn.CommandTag, error) {
c.lastStmtSent = false c.lastStmtSent = false
err := c.waitForPreviousCancelQuery(ctx) err := c.waitForPreviousCancelQuery(ctx)
if err != nil { if err != nil {
return "", err return nil, err
} }
if err := c.lock(); err != nil { if err := c.lock(); err != nil {
return "", err return nil, err
} }
defer c.unlock() defer c.unlock()
@ -1134,10 +1118,10 @@ func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions,
return commandTag, err return commandTag, err
} }
func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
err = c.initContext(ctx) err = c.initContext(ctx)
if err != nil { if err != nil {
return "", err return nil, err
} }
defer func() { defer func() {
err = c.termContext(err) err = c.termContext(err)
@ -1147,16 +1131,16 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
c.lastStmtSent = true c.lastStmtSent = true
err = c.sanitizeAndSendSimpleQuery(sql, arguments...) err = c.sanitizeAndSendSimpleQuery(sql, arguments...)
if err != nil { if err != nil {
return "", err return nil, err
} }
} else if options != nil && len(options.ParameterOIDs) > 0 { } else if options != nil && len(options.ParameterOIDs) > 0 {
if err := c.ensureConnectionReadyForQuery(); err != nil { if err := c.ensureConnectionReadyForQuery(); err != nil {
return "", err return nil, err
} }
buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments) buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments)
if err != nil { if err != nil {
return "", err return nil, err
} }
buf = appendSync(buf) buf = appendSync(buf)
@ -1165,7 +1149,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
c.lastStmtSent = true c.lastStmtSent = true
if err != nil && fatalWriteErr(n, err) { if err != nil && fatalWriteErr(n, err) {
c.die(err) c.die(err)
return "", err return nil, err
} }
c.pendingReadyForQueryCount++ c.pendingReadyForQueryCount++
} else { } else {
@ -1175,14 +1159,14 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
var err error var err error
ps, err = c.prepareEx("", sql, nil) ps, err = c.prepareEx("", sql, nil)
if err != nil { if err != nil {
return "", err return nil, err
} }
} }
c.lastStmtSent = true c.lastStmtSent = true
err = c.sendPreparedQuery(ps, arguments...) err = c.sendPreparedQuery(ps, arguments...)
if err != nil { if err != nil {
return "", err return nil, err
} }
} else { } else {
c.lastStmtSent = true c.lastStmtSent = true
@ -1205,7 +1189,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
c.rxReadyForQuery(msg) c.rxReadyForQuery(msg)
return commandTag, softErr return commandTag, softErr
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag) commandTag = pgconn.CommandTag(msg.CommandTag)
default: default:
if e := c.processContextFreeMsg(msg); e != nil && softErr == nil { if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
softErr = e softErr = e

View File

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/jackc/pgx/pgconn"
"github.com/jackc/pgx/pgtype" "github.com/jackc/pgx/pgtype"
) )
@ -352,7 +353,7 @@ func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
} }
// Exec acquires a connection, delegates the call to that connection, and releases the connection // Exec acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
var c *Conn var c *Conn
if c, err = p.Acquire(); err != nil { if c, err = p.Acquire(); err != nil {
return return
@ -362,7 +363,7 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman
return c.Exec(sql, arguments...) return c.Exec(sql, arguments...)
} }
func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
var c *Conn var c *Conn
if c, err = p.Acquire(); err != nil { if c, err = p.Acquire(); err != nil {
return return
@ -554,10 +555,10 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C
} }
// CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection // CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyFromReader(r io.Reader, sql string) (CommandTag, error) { func (p *ConnPool) CopyFromReader(r io.Reader, sql string) (pgconn.CommandTag, error) {
c, err := p.Acquire() c, err := p.Acquire()
if err != nil { if err != nil {
return "", err return nil, err
} }
defer p.Release(c) defer p.Release(c)
@ -565,10 +566,10 @@ func (p *ConnPool) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
} }
// CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection // CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) { func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) (pgconn.CommandTag, error) {
c, err := p.Acquire() c, err := p.Acquire()
if err != nil { if err != nil {
return "", err return nil, err
} }
defer p.Release(c) defer p.Release(c)

View File

@ -845,7 +845,7 @@ func TestConnPoolExec(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Unexpected error from pool.Exec: %v", err) t.Fatalf("Unexpected error from pool.Exec: %v", err)
} }
if results != "CREATE TABLE" { if string(results) != "CREATE TABLE" {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
@ -853,7 +853,7 @@ func TestConnPoolExec(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Unexpected error from pool.Exec: %v", err) t.Fatalf("Unexpected error from pool.Exec: %v", err)
} }
if results != "INSERT 0 1" { if string(results) != "INSERT 0 1" {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
@ -861,7 +861,7 @@ func TestConnPoolExec(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Unexpected error from pool.Exec: %v", err) t.Fatalf("Unexpected error from pool.Exec: %v", err)
} }
if results != "DROP TABLE" { if string(results) != "DROP TABLE" {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
} }

View File

@ -111,31 +111,31 @@ func TestExec(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn) defer closeConn(t, conn)
if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results != "CREATE TABLE" { if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); string(results) != "CREATE TABLE" {
t.Error("Unexpected results from Exec") t.Error("Unexpected results from Exec")
} }
// Accept parameters // Accept parameters
if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); results != "INSERT 0 1" { if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); string(results) != "INSERT 0 1" {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
if results := mustExec(t, conn, "drop table foo;"); results != "DROP TABLE" { if results := mustExec(t, conn, "drop table foo;"); string(results) != "DROP TABLE" {
t.Error("Unexpected results from Exec") t.Error("Unexpected results from Exec")
} }
// Multiple statements can be executed -- last command tag is returned // Multiple statements can be executed -- last command tag is returned
if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); results != "DROP TABLE" { if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); string(results) != "DROP TABLE" {
t.Error("Unexpected results from Exec") t.Error("Unexpected results from Exec")
} }
// Can execute longer SQL strings than sharedBufferSize // Can execute longer SQL strings than sharedBufferSize
if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); results != "SELECT 1" { if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); string(results) != "SELECT 1" {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
// Exec no-op which does not return a command tag // Exec no-op which does not return a command tag
if results := mustExec(t, conn, "--;"); results != "" { if results := mustExec(t, conn, "--;"); string(results) != "" {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
} }
@ -190,7 +190,7 @@ func TestExecExContextWithoutCancelation(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if commandTag != "CREATE TABLE" { if string(commandTag) != "CREATE TABLE" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag) t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
} }
if !conn.LastStmtSent() { if !conn.LastStmtSent() {
@ -291,7 +291,7 @@ func TestExecExExtendedProtocol(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if commandTag != "CREATE TABLE" { if string(commandTag) != "CREATE TABLE" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag) t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
} }
@ -304,7 +304,7 @@ func TestExecExExtendedProtocol(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if commandTag != "INSERT 0 1" { if string(commandTag) != "INSERT 0 1" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag) t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
} }
@ -324,7 +324,7 @@ func TestExecExSimpleProtocol(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if commandTag != "CREATE TABLE" { if string(commandTag) != "CREATE TABLE" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag) t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
} }
if !conn.LastStmtSent() { if !conn.LastStmtSent() {
@ -340,7 +340,7 @@ func TestExecExSimpleProtocol(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if commandTag != "INSERT 0 1" { if string(commandTag) != "INSERT 0 1" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag) t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
} }
if !conn.LastStmtSent() { if !conn.LastStmtSent() {
@ -365,7 +365,7 @@ func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if commandTag != "INSERT 0 1" { if string(commandTag) != "INSERT 0 1" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag) t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
} }
if !conn.LastStmtSent() { if !conn.LastStmtSent() {
@ -924,43 +924,18 @@ func TestFatalTxError(t *testing.T) {
} }
} }
func TestCommandTag(t *testing.T) {
t.Parallel()
var tests = []struct {
commandTag pgx.CommandTag
rowsAffected int64
}{
{commandTag: "INSERT 0 5", rowsAffected: 5},
{commandTag: "UPDATE 0", rowsAffected: 0},
{commandTag: "UPDATE 1", rowsAffected: 1},
{commandTag: "DELETE 0", rowsAffected: 0},
{commandTag: "DELETE 1", rowsAffected: 1},
{commandTag: "CREATE TABLE", rowsAffected: 0},
{commandTag: "ALTER TABLE", rowsAffected: 0},
{commandTag: "DROP TABLE", rowsAffected: 0},
}
for i, tt := range tests {
actual := tt.commandTag.RowsAffected()
if tt.rowsAffected != actual {
t.Errorf(`%d. "%s" should have affected %d rows but it was %d`, i, tt.commandTag, tt.rowsAffected, actual)
}
}
}
func TestInsertBoolArray(t *testing.T) { func TestInsertBoolArray(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn) defer closeConn(t, conn)
if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results != "CREATE TABLE" { if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); string(results) != "CREATE TABLE" {
t.Error("Unexpected results from Exec") t.Error("Unexpected results from Exec")
} }
// Accept parameters // Accept parameters
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results != "INSERT 0 1" { if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); string(results) != "INSERT 0 1" {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
} }
@ -971,12 +946,12 @@ func TestInsertTimestampArray(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn) defer closeConn(t, conn)
if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results != "CREATE TABLE" { if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); string(results) != "CREATE TABLE" {
t.Error("Unexpected results from Exec") t.Error("Unexpected results from Exec")
} }
// Accept parameters // Accept parameters
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results != "INSERT 0 1" { if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); string(results) != "INSERT 0 1" {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/jackc/pgx/pgconn"
"github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgproto3"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -284,13 +285,13 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF
} }
// CopyFromReader uses the PostgreSQL textual format of the copy protocol // CopyFromReader uses the PostgreSQL textual format of the copy protocol
func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) { func (c *Conn) CopyFromReader(r io.Reader, sql string) (pgconn.CommandTag, error) {
if err := c.sendSimpleQuery(sql); err != nil { if err := c.sendSimpleQuery(sql); err != nil {
return "", err return nil, err
} }
if err := c.readUntilCopyInResponse(); err != nil { if err := c.readUntilCopyInResponse(); err != nil {
return "", err return nil, err
} }
buf := c.wbuf buf := c.wbuf
@ -305,7 +306,7 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
pgio.SetInt32(buf[sp:], int32(n+4)) pgio.SetInt32(buf[sp:], int32(n+4))
if _, err := c.pgConn.Conn().Write(buf); err != nil { if _, err := c.pgConn.Conn().Write(buf); err != nil {
return "", err return nil, err
} }
} }
@ -314,25 +315,25 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
buf = pgio.AppendInt32(buf, 4) buf = pgio.AppendInt32(buf, 4)
if _, err := c.pgConn.Conn().Write(buf); err != nil { if _, err := c.pgConn.Conn().Write(buf); err != nil {
return "", err return nil, err
} }
for { for {
msg, err := c.rxMsg() msg, err := c.rxMsg()
if err != nil { if err != nil {
return "", err return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg) c.rxReadyForQuery(msg)
return "", err return nil, err
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
return CommandTag(msg.CommandTag), nil return pgconn.CommandTag(msg.CommandTag), nil
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
return "", c.rxErrorResponse(msg) return nil, c.rxErrorResponse(msg)
default: default:
return "", c.processContextFreeMsg(msg) return nil, c.processContextFreeMsg(msg)
} }
} }
} }

View File

@ -3,6 +3,7 @@ package pgx
import ( import (
"io" "io"
"github.com/jackc/pgx/pgconn"
"github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgproto3"
) )
@ -25,19 +26,19 @@ func (c *Conn) readUntilCopyOutResponse() error {
} }
} }
func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) { func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) (pgconn.CommandTag, error) {
if err := c.sendSimpleQuery(sql, args...); err != nil { if err := c.sendSimpleQuery(sql, args...); err != nil {
return "", err return nil, err
} }
if err := c.readUntilCopyOutResponse(); err != nil { if err := c.readUntilCopyOutResponse(); err != nil {
return "", err return nil, err
} }
for { for {
msg, err := c.rxMsg() msg, err := c.rxMsg()
if err != nil { if err != nil {
return "", err return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@ -47,17 +48,17 @@ func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) (Comma
_, err := w.Write(msg.Data) _, err := w.Write(msg.Data)
if err != nil { if err != nil {
c.die(err) c.die(err)
return "", err return nil, err
} }
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg) c.rxReadyForQuery(msg)
return "", nil return nil, nil
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
return CommandTag(msg.CommandTag), nil return pgconn.CommandTag(msg.CommandTag), nil
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
return "", c.rxErrorResponse(msg) return nil, c.rxErrorResponse(msg)
default: default:
return "", c.processContextFreeMsg(msg) return nil, c.processContextFreeMsg(msg)
} }
} }
} }

View File

@ -53,7 +53,7 @@ func closeReplicationConn(t testing.TB, conn *pgx.ReplicationConn) {
} }
} }
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgx.CommandTag) { func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag) {
var err error var err error
if commandTag, err = conn.Exec(sql, arguments...); err != nil { if commandTag, err = conn.Exec(sql, arguments...); err != nil {
t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err) t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)

View File

@ -512,3 +512,26 @@ func TestConnCancelQuery(t *testing.T) {
t.Errorf("expected pgconn.PgError got %v", err) t.Errorf("expected pgconn.PgError got %v", err)
} }
} }
func TestCommandTag(t *testing.T) {
t.Parallel()
var tests = []struct {
commandTag pgconn.CommandTag
rowsAffected int64
}{
{commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5},
{commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0},
{commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1},
{commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0},
{commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1},
{commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0},
{commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0},
{commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0},
}
for i, tt := range tests {
actual := tt.commandTag.RowsAffected()
assert.Equalf(t, tt.rowsAffected, actual, "%d. %v", i, tt.commandTag)
}
}

View File

@ -1087,7 +1087,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if commandTag != "INSERT 0 1" { if string(commandTag) != "INSERT 0 1" {
t.Fatalf("want %s, got %s", "INSERT 0 1", commandTag) t.Fatalf("want %s, got %s", "INSERT 0 1", commandTag)
} }

View File

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/jackc/pgx" "github.com/jackc/pgx"
"github.com/jackc/pgx/pgconn"
) )
// This function uses a postgresql 9.6 specific column // This function uses a postgresql 9.6 specific column
@ -87,7 +88,7 @@ func TestSimpleReplicationConnection(t *testing.T) {
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
var ct pgx.CommandTag var ct pgconn.CommandTag
insertedTimes = append(insertedTimes, currentTime) insertedTimes = append(insertedTimes, currentTime)
ct, err = conn.Exec("insert into replication_test(a) values($1)", currentTime) ct, err = conn.Exec("insert into replication_test(a) values($1)", currentTime)
if err != nil { if err != nil {

View File

@ -17,7 +17,7 @@ import (
) )
type execer interface { type execer interface {
Exec(sql string, arguments ...interface{}) (commandTag pgx.CommandTag, err error) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
} }
type queryer interface { type queryer interface {
Query(sql string, args ...interface{}) (*pgx.Rows, error) Query(sql string, args ...interface{}) (*pgx.Rows, error)

19
tx.go
View File

@ -7,6 +7,7 @@ import (
"io" "io"
"time" "time"
"github.com/jackc/pgx/pgconn"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -123,9 +124,9 @@ func (tx *Tx) CommitEx(ctx context.Context) error {
} }
commandTag, err := tx.conn.ExecEx(ctx, "commit", nil) commandTag, err := tx.conn.ExecEx(ctx, "commit", nil)
if err == nil && commandTag == "COMMIT" { if err == nil && string(commandTag) == "COMMIT" {
tx.status = TxStatusCommitSuccess tx.status = TxStatusCommitSuccess
} else if err == nil && commandTag == "ROLLBACK" { } else if err == nil && string(commandTag) == "ROLLBACK" {
tx.status = TxStatusCommitFailure tx.status = TxStatusCommitFailure
tx.err = ErrTxCommitRollback tx.err = ErrTxCommitRollback
} else { } else {
@ -175,14 +176,14 @@ func (tx *Tx) RollbackEx(ctx context.Context) error {
} }
// Exec delegates to the underlying *Conn // Exec delegates to the underlying *Conn
func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
return tx.ExecEx(context.Background(), sql, nil, arguments...) return tx.ExecEx(context.Background(), sql, nil, arguments...)
} }
// ExecEx delegates to the underlying *Conn // ExecEx delegates to the underlying *Conn
func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
if tx.status != TxStatusInProgress { if tx.status != TxStatusInProgress {
return CommandTag(""), ErrTxClosed return nil, ErrTxClosed
} }
return tx.conn.ExecEx(ctx, sql, options, arguments...) return tx.conn.ExecEx(ctx, sql, options, arguments...)
@ -240,18 +241,18 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr
} }
// CopyFromReader delegates to the underlying *Conn // CopyFromReader delegates to the underlying *Conn
func (tx *Tx) CopyFromReader(r io.Reader, sql string) (commandTag CommandTag, err error) { func (tx *Tx) CopyFromReader(r io.Reader, sql string) (commandTag pgconn.CommandTag, err error) {
if tx.status != TxStatusInProgress { if tx.status != TxStatusInProgress {
return CommandTag(""), ErrTxClosed return nil, ErrTxClosed
} }
return tx.conn.CopyFromReader(r, sql) return tx.conn.CopyFromReader(r, sql)
} }
// CopyToWriter delegates to the underlying *Conn // CopyToWriter delegates to the underlying *Conn
func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) (commandTag CommandTag, err error) { func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) (commandTag pgconn.CommandTag, err error) {
if tx.status != TxStatusInProgress { if tx.status != TxStatusInProgress {
return CommandTag(""), ErrTxClosed return nil, ErrTxClosed
} }
return tx.conn.CopyToWriter(w, sql, args...) return tx.conn.CopyToWriter(w, sql, args...)