mirror of https://github.com/jackc/pgx.git
pgx uses pgconn.CommandTag instead of own definition
parent
ddd37cf557
commit
7f9540438c
15
batch.go
15
batch.go
|
@ -3,6 +3,7 @@ package pgx
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
@ -162,21 +163,21 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
|
|||
|
||||
// ExecResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with Exec.
|
||||
func (b *Batch) ExecResults() (CommandTag, error) {
|
||||
func (b *Batch) ExecResults() (pgconn.CommandTag, error) {
|
||||
if b.err != nil {
|
||||
return "", b.err
|
||||
return nil, b.err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-b.ctx.Done():
|
||||
b.die(b.ctx.Err())
|
||||
return "", b.ctx.Err()
|
||||
return nil, b.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if err := b.ensureCommandComplete(); err != nil {
|
||||
b.die(err)
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.resultsRead++
|
||||
|
@ -186,16 +187,16 @@ func (b *Batch) ExecResults() (CommandTag, error) {
|
|||
for {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CommandComplete:
|
||||
b.pendingCommandComplete = false
|
||||
return CommandTag(msg.CommandTag), nil
|
||||
return pgconn.CommandTag(msg.CommandTag), nil
|
||||
default:
|
||||
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
42
conn.go
42
conn.go
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -120,21 +119,6 @@ type Notification struct {
|
|||
Payload string
|
||||
}
|
||||
|
||||
// CommandTag is the result of an Exec function
|
||||
type CommandTag string
|
||||
|
||||
// RowsAffected returns the number of rows affected. If the CommandTag was not
|
||||
// for a row affecting command (such as "CREATE TABLE") then it returns 0
|
||||
func (ct CommandTag) RowsAffected() int64 {
|
||||
s := string(ct)
|
||||
index := strings.LastIndex(s, " ")
|
||||
if index == -1 {
|
||||
return 0
|
||||
}
|
||||
n, _ := strconv.ParseInt(s[index+1:], 10, 64)
|
||||
return n
|
||||
}
|
||||
|
||||
// Identifier a PostgreSQL identifier or name. Identifiers can be composed of
|
||||
// multiple parts such as ["schema", "table"] or ["table", "column"].
|
||||
type Identifier []string
|
||||
|
@ -855,7 +839,7 @@ func fatalWriteErr(bytesWritten int, err error) bool {
|
|||
|
||||
// Exec executes sql. sql can be either a prepared statement name or an SQL string.
|
||||
// arguments should be referenced positionally from the sql string as $1, $2, etc.
|
||||
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
return c.ExecEx(context.Background(), sql, nil, arguments...)
|
||||
}
|
||||
|
||||
|
@ -1104,15 +1088,15 @@ func (c *Conn) Ping(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) {
|
||||
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (pgconn.CommandTag, error) {
|
||||
c.lastStmtSent = false
|
||||
err := c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.lock(); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
defer c.unlock()
|
||||
|
||||
|
@ -1134,10 +1118,10 @@ func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||
return commandTag, err
|
||||
}
|
||||
|
||||
func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
|
@ -1147,16 +1131,16 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||
c.lastStmtSent = true
|
||||
err = c.sanitizeAndSendSimpleQuery(sql, arguments...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
} else if options != nil && len(options.ParameterOIDs) > 0 {
|
||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf = appendSync(buf)
|
||||
|
@ -1165,7 +1149,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||
c.lastStmtSent = true
|
||||
if err != nil && fatalWriteErr(n, err) {
|
||||
c.die(err)
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
c.pendingReadyForQueryCount++
|
||||
} else {
|
||||
|
@ -1175,14 +1159,14 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||
var err error
|
||||
ps, err = c.prepareEx("", sql, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
c.lastStmtSent = true
|
||||
err = c.sendPreparedQuery(ps, arguments...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
c.lastStmtSent = true
|
||||
|
@ -1205,7 +1189,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||
c.rxReadyForQuery(msg)
|
||||
return commandTag, softErr
|
||||
case *pgproto3.CommandComplete:
|
||||
commandTag = CommandTag(msg.CommandTag)
|
||||
commandTag = pgconn.CommandTag(msg.CommandTag)
|
||||
default:
|
||||
if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
|
||||
softErr = e
|
||||
|
|
13
conn_pool.go
13
conn_pool.go
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
|
@ -352,7 +353,7 @@ func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
|
|||
}
|
||||
|
||||
// Exec acquires a connection, delegates the call to that connection, and releases the connection
|
||||
func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
var c *Conn
|
||||
if c, err = p.Acquire(); err != nil {
|
||||
return
|
||||
|
@ -362,7 +363,7 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman
|
|||
return c.Exec(sql, arguments...)
|
||||
}
|
||||
|
||||
func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
var c *Conn
|
||||
if c, err = p.Acquire(); err != nil {
|
||||
return
|
||||
|
@ -554,10 +555,10 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C
|
|||
}
|
||||
|
||||
// CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection
|
||||
func (p *ConnPool) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
|
||||
func (p *ConnPool) CopyFromReader(r io.Reader, sql string) (pgconn.CommandTag, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
|
@ -565,10 +566,10 @@ func (p *ConnPool) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
|
|||
}
|
||||
|
||||
// CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection
|
||||
func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) {
|
||||
func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) (pgconn.CommandTag, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
|
|
|
@ -845,7 +845,7 @@ func TestConnPoolExec(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("Unexpected error from pool.Exec: %v", err)
|
||||
}
|
||||
if results != "CREATE TABLE" {
|
||||
if string(results) != "CREATE TABLE" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
|
||||
|
@ -853,7 +853,7 @@ func TestConnPoolExec(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("Unexpected error from pool.Exec: %v", err)
|
||||
}
|
||||
if results != "INSERT 0 1" {
|
||||
if string(results) != "INSERT 0 1" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
|
||||
|
@ -861,7 +861,7 @@ func TestConnPoolExec(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("Unexpected error from pool.Exec: %v", err)
|
||||
}
|
||||
if results != "DROP TABLE" {
|
||||
if string(results) != "DROP TABLE" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
}
|
||||
|
|
57
conn_test.go
57
conn_test.go
|
@ -111,31 +111,31 @@ func TestExec(t *testing.T) {
|
|||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results != "CREATE TABLE" {
|
||||
if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); string(results) != "CREATE TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
|
||||
// Accept parameters
|
||||
if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); results != "INSERT 0 1" {
|
||||
if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); string(results) != "INSERT 0 1" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
|
||||
if results := mustExec(t, conn, "drop table foo;"); results != "DROP TABLE" {
|
||||
if results := mustExec(t, conn, "drop table foo;"); string(results) != "DROP TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
|
||||
// Multiple statements can be executed -- last command tag is returned
|
||||
if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); results != "DROP TABLE" {
|
||||
if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); string(results) != "DROP TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
|
||||
// Can execute longer SQL strings than sharedBufferSize
|
||||
if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); results != "SELECT 1" {
|
||||
if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); string(results) != "SELECT 1" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
|
||||
// Exec no-op which does not return a command tag
|
||||
if results := mustExec(t, conn, "--;"); results != "" {
|
||||
if results := mustExec(t, conn, "--;"); string(results) != "" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
}
|
||||
|
@ -190,7 +190,7 @@ func TestExecExContextWithoutCancelation(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "CREATE TABLE" {
|
||||
if string(commandTag) != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
|
@ -291,7 +291,7 @@ func TestExecExExtendedProtocol(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "CREATE TABLE" {
|
||||
if string(commandTag) != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
|
||||
|
@ -304,7 +304,7 @@ func TestExecExExtendedProtocol(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "INSERT 0 1" {
|
||||
if string(commandTag) != "INSERT 0 1" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
|
||||
|
@ -324,7 +324,7 @@ func TestExecExSimpleProtocol(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "CREATE TABLE" {
|
||||
if string(commandTag) != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
|
@ -340,7 +340,7 @@ func TestExecExSimpleProtocol(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "INSERT 0 1" {
|
||||
if string(commandTag) != "INSERT 0 1" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
|
@ -365,7 +365,7 @@ func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "INSERT 0 1" {
|
||||
if string(commandTag) != "INSERT 0 1" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
|
@ -924,43 +924,18 @@ func TestFatalTxError(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCommandTag(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var tests = []struct {
|
||||
commandTag pgx.CommandTag
|
||||
rowsAffected int64
|
||||
}{
|
||||
{commandTag: "INSERT 0 5", rowsAffected: 5},
|
||||
{commandTag: "UPDATE 0", rowsAffected: 0},
|
||||
{commandTag: "UPDATE 1", rowsAffected: 1},
|
||||
{commandTag: "DELETE 0", rowsAffected: 0},
|
||||
{commandTag: "DELETE 1", rowsAffected: 1},
|
||||
{commandTag: "CREATE TABLE", rowsAffected: 0},
|
||||
{commandTag: "ALTER TABLE", rowsAffected: 0},
|
||||
{commandTag: "DROP TABLE", rowsAffected: 0},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
actual := tt.commandTag.RowsAffected()
|
||||
if tt.rowsAffected != actual {
|
||||
t.Errorf(`%d. "%s" should have affected %d rows but it was %d`, i, tt.commandTag, tt.rowsAffected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertBoolArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results != "CREATE TABLE" {
|
||||
if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); string(results) != "CREATE TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
|
||||
// Accept parameters
|
||||
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results != "INSERT 0 1" {
|
||||
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); string(results) != "INSERT 0 1" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
}
|
||||
|
@ -971,12 +946,12 @@ func TestInsertTimestampArray(t *testing.T) {
|
|||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results != "CREATE TABLE" {
|
||||
if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); string(results) != "CREATE TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
|
||||
// Accept parameters
|
||||
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results != "INSERT 0 1" {
|
||||
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); string(results) != "INSERT 0 1" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
}
|
||||
|
|
21
copy_from.go
21
copy_from.go
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/pkg/errors"
|
||||
|
@ -284,13 +285,13 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF
|
|||
}
|
||||
|
||||
// CopyFromReader uses the PostgreSQL textual format of the copy protocol
|
||||
func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
|
||||
func (c *Conn) CopyFromReader(r io.Reader, sql string) (pgconn.CommandTag, error) {
|
||||
if err := c.sendSimpleQuery(sql); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.readUntilCopyInResponse(); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
buf := c.wbuf
|
||||
|
||||
|
@ -305,7 +306,7 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
|
|||
pgio.SetInt32(buf[sp:], int32(n+4))
|
||||
|
||||
if _, err := c.pgConn.Conn().Write(buf); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -314,25 +315,25 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
|
|||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
if _, err := c.pgConn.Conn().Write(buf); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
return "", err
|
||||
return nil, err
|
||||
case *pgproto3.CommandComplete:
|
||||
return CommandTag(msg.CommandTag), nil
|
||||
return pgconn.CommandTag(msg.CommandTag), nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return "", c.rxErrorResponse(msg)
|
||||
return nil, c.rxErrorResponse(msg)
|
||||
default:
|
||||
return "", c.processContextFreeMsg(msg)
|
||||
return nil, c.processContextFreeMsg(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
19
copy_to.go
19
copy_to.go
|
@ -3,6 +3,7 @@ package pgx
|
|||
import (
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
)
|
||||
|
||||
|
@ -25,19 +26,19 @@ func (c *Conn) readUntilCopyOutResponse() error {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) {
|
||||
func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) (pgconn.CommandTag, error) {
|
||||
if err := c.sendSimpleQuery(sql, args...); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.readUntilCopyOutResponse(); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
|
@ -47,17 +48,17 @@ func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) (Comma
|
|||
_, err := w.Write(msg.Data)
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
return "", nil
|
||||
return nil, nil
|
||||
case *pgproto3.CommandComplete:
|
||||
return CommandTag(msg.CommandTag), nil
|
||||
return pgconn.CommandTag(msg.CommandTag), nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return "", c.rxErrorResponse(msg)
|
||||
return nil, c.rxErrorResponse(msg)
|
||||
default:
|
||||
return "", c.processContextFreeMsg(msg)
|
||||
return nil, c.processContextFreeMsg(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -53,7 +53,7 @@ func closeReplicationConn(t testing.TB, conn *pgx.ReplicationConn) {
|
|||
}
|
||||
}
|
||||
|
||||
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgx.CommandTag) {
|
||||
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag) {
|
||||
var err error
|
||||
if commandTag, err = conn.Exec(sql, arguments...); err != nil {
|
||||
t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)
|
||||
|
|
|
@ -512,3 +512,26 @@ func TestConnCancelQuery(t *testing.T) {
|
|||
t.Errorf("expected pgconn.PgError got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandTag(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var tests = []struct {
|
||||
commandTag pgconn.CommandTag
|
||||
rowsAffected int64
|
||||
}{
|
||||
{commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5},
|
||||
{commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0},
|
||||
{commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1},
|
||||
{commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0},
|
||||
{commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1},
|
||||
{commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0},
|
||||
{commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0},
|
||||
{commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
actual := tt.commandTag.RowsAffected()
|
||||
assert.Equalf(t, tt.rowsAffected, actual, "%d. %v", i, tt.commandTag)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1087,7 +1087,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "INSERT 0 1" {
|
||||
if string(commandTag) != "INSERT 0 1" {
|
||||
t.Fatalf("want %s, got %s", "INSERT 0 1", commandTag)
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
)
|
||||
|
||||
// This function uses a postgresql 9.6 specific column
|
||||
|
@ -87,7 +88,7 @@ func TestSimpleReplicationConnection(t *testing.T) {
|
|||
currentTime := time.Now().Unix()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
var ct pgx.CommandTag
|
||||
var ct pgconn.CommandTag
|
||||
insertedTimes = append(insertedTimes, currentTime)
|
||||
ct, err = conn.Exec("insert into replication_test(a) values($1)", currentTime)
|
||||
if err != nil {
|
||||
|
|
|
@ -17,7 +17,7 @@ import (
|
|||
)
|
||||
|
||||
type execer interface {
|
||||
Exec(sql string, arguments ...interface{}) (commandTag pgx.CommandTag, err error)
|
||||
Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
|
||||
}
|
||||
type queryer interface {
|
||||
Query(sql string, args ...interface{}) (*pgx.Rows, error)
|
||||
|
|
19
tx.go
19
tx.go
|
@ -7,6 +7,7 @@ import (
|
|||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
@ -123,9 +124,9 @@ func (tx *Tx) CommitEx(ctx context.Context) error {
|
|||
}
|
||||
|
||||
commandTag, err := tx.conn.ExecEx(ctx, "commit", nil)
|
||||
if err == nil && commandTag == "COMMIT" {
|
||||
if err == nil && string(commandTag) == "COMMIT" {
|
||||
tx.status = TxStatusCommitSuccess
|
||||
} else if err == nil && commandTag == "ROLLBACK" {
|
||||
} else if err == nil && string(commandTag) == "ROLLBACK" {
|
||||
tx.status = TxStatusCommitFailure
|
||||
tx.err = ErrTxCommitRollback
|
||||
} else {
|
||||
|
@ -175,14 +176,14 @@ func (tx *Tx) RollbackEx(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// Exec delegates to the underlying *Conn
|
||||
func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
return tx.ExecEx(context.Background(), sql, nil, arguments...)
|
||||
}
|
||||
|
||||
// ExecEx delegates to the underlying *Conn
|
||||
func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
if tx.status != TxStatusInProgress {
|
||||
return CommandTag(""), ErrTxClosed
|
||||
return nil, ErrTxClosed
|
||||
}
|
||||
|
||||
return tx.conn.ExecEx(ctx, sql, options, arguments...)
|
||||
|
@ -240,18 +241,18 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr
|
|||
}
|
||||
|
||||
// CopyFromReader delegates to the underlying *Conn
|
||||
func (tx *Tx) CopyFromReader(r io.Reader, sql string) (commandTag CommandTag, err error) {
|
||||
func (tx *Tx) CopyFromReader(r io.Reader, sql string) (commandTag pgconn.CommandTag, err error) {
|
||||
if tx.status != TxStatusInProgress {
|
||||
return CommandTag(""), ErrTxClosed
|
||||
return nil, ErrTxClosed
|
||||
}
|
||||
|
||||
return tx.conn.CopyFromReader(r, sql)
|
||||
}
|
||||
|
||||
// CopyToWriter delegates to the underlying *Conn
|
||||
func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) (commandTag CommandTag, err error) {
|
||||
func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
if tx.status != TxStatusInProgress {
|
||||
return CommandTag(""), ErrTxClosed
|
||||
return nil, ErrTxClosed
|
||||
}
|
||||
|
||||
return tx.conn.CopyToWriter(w, sql, args...)
|
||||
|
|
Loading…
Reference in New Issue