pgconn.CommandTag is now an opaque type

It now makes a copy instead of retaining driver memory. This is in
preparation to reuse the driver read buffer.
query-exec-mode
Jack Christensen 2022-02-12 10:26:26 -06:00
parent e6680127e3
commit 9c5dfbdfb3
14 changed files with 246 additions and 225 deletions

View File

@ -64,10 +64,10 @@ type batchResults struct {
// Exec reads the results from the next query in the batch as if the query has been sent with Exec. // Exec reads the results from the next query in the batch as if the query has been sent with Exec.
func (br *batchResults) Exec() (pgconn.CommandTag, error) { func (br *batchResults) Exec() (pgconn.CommandTag, error) {
if br.err != nil { if br.err != nil {
return nil, br.err return pgconn.CommandTag{}, br.err
} }
if br.closed { if br.closed {
return nil, fmt.Errorf("batch already closed") return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
} }
query, arguments, _ := br.nextQueryAndArgs() query, arguments, _ := br.nextQueryAndArgs()
@ -84,7 +84,7 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {
"err": err, "err": err,
}) })
} }
return nil, err return pgconn.CommandTag{}, err
} }
commandTag, err := br.mrr.ResultReader().Close() commandTag, err := br.mrr.ResultReader().Close()
@ -151,29 +151,29 @@ func (br *batchResults) Query() (Rows, error) {
// QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. // QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc.
func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
if br.closed { if br.closed {
return nil, fmt.Errorf("batch already closed") return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
} }
rows, err := br.Query() rows, err := br.Query()
if err != nil { if err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
err = rows.Scan(scans...) err = rows.Scan(scans...)
if err != nil { if err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
err = f(rows) err = f(rows)
if err != nil { if err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
return rows.CommandTag(), nil return rows.CommandTag(), nil

18
conn.go
View File

@ -432,7 +432,7 @@ optionLoop:
if c.stmtcache != nil { if c.stmtcache != nil {
sd, err := c.stmtcache.Get(ctx, sql) sd, err := c.stmtcache.Get(ctx, sql)
if err != nil { if err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
if c.stmtcache.Mode() == stmtcache.ModeDescribe { if c.stmtcache.Mode() == stmtcache.ModeDescribe {
@ -443,7 +443,7 @@ optionLoop:
sd, err := c.Prepare(ctx, "", sql) sd, err := c.Prepare(ctx, "", sql)
if err != nil { if err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
return c.execPrepared(ctx, sd, arguments) return c.execPrepared(ctx, sd, arguments)
} }
@ -452,7 +452,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i
if len(arguments) > 0 { if len(arguments) > 0 {
sql, err = c.sanitizeForSimpleQuery(sql, arguments...) sql, err = c.sanitizeForSimpleQuery(sql, arguments...)
if err != nil { if err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
} }
@ -493,7 +493,7 @@ func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, argu
func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) {
err := c.execParamsAndPreparedPrefix(sd, arguments) err := c.execParamsAndPreparedPrefix(sd, arguments)
if err != nil { if err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read() result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read()
@ -504,7 +504,7 @@ func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription,
func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) {
err := c.execParamsAndPreparedPrefix(sd, arguments) err := c.execParamsAndPreparedPrefix(sd, arguments)
if err != nil { if err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read()
@ -688,24 +688,24 @@ type QueryFuncRow interface {
func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
rows, err := c.Query(ctx, sql, args...) rows, err := c.Query(ctx, sql, args...)
if err != nil { if err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
err = rows.Scan(scans...) err = rows.Scan(scans...)
if err != nil { if err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
err = f(rows) err = f(rows)
if err != nil { if err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
return rows.CommandTag(), nil return rows.CommandTag(), nil

View File

@ -188,31 +188,31 @@ func TestExec(t *testing.T) {
t.Parallel() t.Parallel()
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); string(results) != "CREATE TABLE" { if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results.String() != "CREATE TABLE" {
t.Error("Unexpected results from Exec") t.Error("Unexpected results from Exec")
} }
// Accept parameters // Accept parameters
if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); string(results) != "INSERT 0 1" { if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); results.String() != "INSERT 0 1" {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
if results := mustExec(t, conn, "drop table foo;"); string(results) != "DROP TABLE" { if results := mustExec(t, conn, "drop table foo;"); results.String() != "DROP TABLE" {
t.Error("Unexpected results from Exec") t.Error("Unexpected results from Exec")
} }
// Multiple statements can be executed -- last command tag is returned // Multiple statements can be executed -- last command tag is returned
if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); string(results) != "DROP TABLE" { if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); results.String() != "DROP TABLE" {
t.Error("Unexpected results from Exec") t.Error("Unexpected results from Exec")
} }
// Can execute longer SQL strings than sharedBufferSize // Can execute longer SQL strings than sharedBufferSize
if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); string(results) != "SELECT 1" { if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); results.String() != "SELECT 1" {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
// Exec no-op which does not return a command tag // Exec no-op which does not return a command tag
if results := mustExec(t, conn, "--;"); string(results) != "" { if results := mustExec(t, conn, "--;"); results.String() != "" {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
}) })
@ -260,7 +260,7 @@ func TestExecContextWithoutCancelation(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if string(commandTag) != "CREATE TABLE" { if commandTag.String() != "CREATE TABLE" {
t.Fatalf("Unexpected results from Exec: %v", commandTag) t.Fatalf("Unexpected results from Exec: %v", commandTag)
} }
assert.False(t, pgconn.SafeToRetry(err)) assert.False(t, pgconn.SafeToRetry(err))
@ -350,15 +350,15 @@ func TestExecStatementCacheModes(t *testing.T) {
commandTag, err := conn.Exec(context.Background(), "select 1") commandTag, err := conn.Exec(context.Background(), "select 1")
assert.NoError(t, err, tt.name) assert.NoError(t, err, tt.name)
assert.Equal(t, "SELECT 1", string(commandTag), tt.name) assert.Equal(t, "SELECT 1", commandTag.String(), tt.name)
commandTag, err = conn.Exec(context.Background(), "select 1 union all select 1") commandTag, err = conn.Exec(context.Background(), "select 1 union all select 1")
assert.NoError(t, err, tt.name) assert.NoError(t, err, tt.name)
assert.Equal(t, "SELECT 2", string(commandTag), tt.name) assert.Equal(t, "SELECT 2", commandTag.String(), tt.name)
commandTag, err = conn.Exec(context.Background(), "select 1") commandTag, err = conn.Exec(context.Background(), "select 1")
assert.NoError(t, err, tt.name) assert.NoError(t, err, tt.name)
assert.Equal(t, "SELECT 1", string(commandTag), tt.name) assert.Equal(t, "SELECT 1", commandTag.String(), tt.name)
ensureConnValid(t, conn) ensureConnValid(t, conn)
}() }()
@ -378,7 +378,7 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if string(commandTag) != "CREATE TABLE" { if commandTag.String() != "CREATE TABLE" {
t.Fatalf("Unexpected results from Exec: %v", commandTag) t.Fatalf("Unexpected results from Exec: %v", commandTag)
} }
@ -390,7 +390,7 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if string(commandTag) != "INSERT 0 1" { if commandTag.String() != "INSERT 0 1" {
t.Fatalf("Unexpected results from Exec: %v", commandTag) t.Fatalf("Unexpected results from Exec: %v", commandTag)
} }
@ -720,12 +720,12 @@ func TestInsertBoolArray(t *testing.T) {
t.Parallel() t.Parallel()
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); string(results) != "CREATE TABLE" { if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results.String() != "CREATE TABLE" {
t.Error("Unexpected results from Exec") t.Error("Unexpected results from Exec")
} }
// Accept parameters // Accept parameters
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); string(results) != "INSERT 0 1" { if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results.String() != "INSERT 0 1" {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
}) })
@ -735,12 +735,12 @@ func TestInsertTimestampArray(t *testing.T) {
t.Parallel() t.Parallel()
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); string(results) != "CREATE TABLE" { if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results.String() != "CREATE TABLE" {
t.Error("Unexpected results from Exec") t.Error("Unexpected results from Exec")
} }
// Accept parameters // Accept parameters
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); string(results) != "INSERT 0 1" { if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results.String() != "INSERT 0 1" {
t.Errorf("Unexpected results from Exec: %v", results) t.Errorf("Unexpected results from Exec: %v", results)
} }
}) })

View File

@ -0,0 +1,73 @@
package pgconn
import (
"strings"
"testing"
)
func BenchmarkCommandTagRowsAffected(b *testing.B) {
benchmarks := []struct {
commandTag string
rowsAffected int64
}{
{"UPDATE 1", 1},
{"UPDATE 123456789", 123456789},
{"INSERT 0 1", 1},
{"INSERT 0 123456789", 123456789},
}
for _, bm := range benchmarks {
ct := CommandTag{buf: []byte(bm.commandTag)}
b.Run(bm.commandTag, func(b *testing.B) {
var n int64
for i := 0; i < b.N; i++ {
n = ct.RowsAffected()
}
if n != bm.rowsAffected {
b.Errorf("expected %d got %d", bm.rowsAffected, n)
}
})
}
}
func BenchmarkCommandTagTypeFromString(b *testing.B) {
ct := CommandTag{buf: []byte("UPDATE 1")}
var update bool
for i := 0; i < b.N; i++ {
update = strings.HasPrefix(ct.String(), "UPDATE")
}
if !update {
b.Error("expected update")
}
}
func BenchmarkCommandTagInsert(b *testing.B) {
benchmarks := []struct {
commandTag string
is bool
}{
{"INSERT 1", true},
{"INSERT 1234567890", true},
{"UPDATE 1", false},
{"UPDATE 1234567890", false},
{"DELETE 1", false},
{"DELETE 1234567890", false},
{"SELECT 1", false},
{"SELECT 1234567890", false},
{"UNKNOWN 1234567890", false},
}
for _, bm := range benchmarks {
ct := CommandTag{buf: []byte(bm.commandTag)}
b.Run(bm.commandTag, func(b *testing.B) {
var is bool
for i := 0; i < b.N; i++ {
is = ct.Insert()
}
if is != bm.is {
b.Errorf("expected %v got %v", bm.is, is)
}
})
}
}

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"os" "os"
"strings"
"testing" "testing"
"github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn"
@ -253,70 +252,3 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
// conn.ChanToSetDeadline().Ignore() // conn.ChanToSetDeadline().Ignore()
// } // }
// } // }
func BenchmarkCommandTagRowsAffected(b *testing.B) {
benchmarks := []struct {
commandTag string
rowsAffected int64
}{
{"UPDATE 1", 1},
{"UPDATE 123456789", 123456789},
{"INSERT 0 1", 1},
{"INSERT 0 123456789", 123456789},
}
for _, bm := range benchmarks {
ct := pgconn.CommandTag(bm.commandTag)
b.Run(bm.commandTag, func(b *testing.B) {
var n int64
for i := 0; i < b.N; i++ {
n = ct.RowsAffected()
}
if n != bm.rowsAffected {
b.Errorf("expected %d got %d", bm.rowsAffected, n)
}
})
}
}
func BenchmarkCommandTagTypeFromString(b *testing.B) {
ct := pgconn.CommandTag("UPDATE 1")
var update bool
for i := 0; i < b.N; i++ {
update = strings.HasPrefix(ct.String(), "UPDATE")
}
if !update {
b.Error("expected update")
}
}
func BenchmarkCommandTagInsert(b *testing.B) {
benchmarks := []struct {
commandTag string
is bool
}{
{"INSERT 1", true},
{"INSERT 1234567890", true},
{"UPDATE 1", false},
{"UPDATE 1234567890", false},
{"DELETE 1", false},
{"DELETE 1234567890", false},
{"SELECT 1", false},
{"SELECT 1234567890", false},
{"UNKNOWN 1234567890", false},
}
for _, bm := range benchmarks {
ct := pgconn.CommandTag(bm.commandTag)
b.Run(bm.commandTag, func(b *testing.B) {
var is bool
for i := 0; i < b.N; i++ {
is = ct.Insert()
}
if is != bm.is {
b.Errorf("expected %v got %v", bm.is, is)
}
})
}
}

View File

@ -685,15 +685,17 @@ func (pgConn *PgConn) ParameterStatus(key string) string {
} }
// CommandTag is the result of an Exec function // CommandTag is the result of an Exec function
type CommandTag []byte type CommandTag struct {
buf []byte
}
// RowsAffected returns the number of rows affected. If the CommandTag was not // RowsAffected returns the number of rows affected. If the CommandTag was not
// for a row affecting command (e.g. "CREATE TABLE") then it returns 0. // for a row affecting command (e.g. "CREATE TABLE") then it returns 0.
func (ct CommandTag) RowsAffected() int64 { func (ct CommandTag) RowsAffected() int64 {
// Find last non-digit // Find last non-digit
idx := -1 idx := -1
for i := len(ct) - 1; i >= 0; i-- { for i := len(ct.buf) - 1; i >= 0; i-- {
if ct[i] >= '0' && ct[i] <= '9' { if ct.buf[i] >= '0' && ct.buf[i] <= '9' {
idx = i idx = i
} else { } else {
break break
@ -705,7 +707,7 @@ func (ct CommandTag) RowsAffected() int64 {
} }
var n int64 var n int64
for _, b := range ct[idx:] { for _, b := range ct.buf[idx:] {
n = n*10 + int64(b-'0') n = n*10 + int64(b-'0')
} }
@ -713,51 +715,51 @@ func (ct CommandTag) RowsAffected() int64 {
} }
func (ct CommandTag) String() string { func (ct CommandTag) String() string {
return string(ct) return string(ct.buf)
} }
// Insert is true if the command tag starts with "INSERT". // Insert is true if the command tag starts with "INSERT".
func (ct CommandTag) Insert() bool { func (ct CommandTag) Insert() bool {
return len(ct) >= 6 && return len(ct.buf) >= 6 &&
ct[0] == 'I' && ct.buf[0] == 'I' &&
ct[1] == 'N' && ct.buf[1] == 'N' &&
ct[2] == 'S' && ct.buf[2] == 'S' &&
ct[3] == 'E' && ct.buf[3] == 'E' &&
ct[4] == 'R' && ct.buf[4] == 'R' &&
ct[5] == 'T' ct.buf[5] == 'T'
} }
// Update is true if the command tag starts with "UPDATE". // Update is true if the command tag starts with "UPDATE".
func (ct CommandTag) Update() bool { func (ct CommandTag) Update() bool {
return len(ct) >= 6 && return len(ct.buf) >= 6 &&
ct[0] == 'U' && ct.buf[0] == 'U' &&
ct[1] == 'P' && ct.buf[1] == 'P' &&
ct[2] == 'D' && ct.buf[2] == 'D' &&
ct[3] == 'A' && ct.buf[3] == 'A' &&
ct[4] == 'T' && ct.buf[4] == 'T' &&
ct[5] == 'E' ct.buf[5] == 'E'
} }
// Delete is true if the command tag starts with "DELETE". // Delete is true if the command tag starts with "DELETE".
func (ct CommandTag) Delete() bool { func (ct CommandTag) Delete() bool {
return len(ct) >= 6 && return len(ct.buf) >= 6 &&
ct[0] == 'D' && ct.buf[0] == 'D' &&
ct[1] == 'E' && ct.buf[1] == 'E' &&
ct[2] == 'L' && ct.buf[2] == 'L' &&
ct[3] == 'E' && ct.buf[3] == 'E' &&
ct[4] == 'T' && ct.buf[4] == 'T' &&
ct[5] == 'E' ct.buf[5] == 'E'
} }
// Select is true if the command tag starts with "SELECT". // Select is true if the command tag starts with "SELECT".
func (ct CommandTag) Select() bool { func (ct CommandTag) Select() bool {
return len(ct) >= 6 && return len(ct.buf) >= 6 &&
ct[0] == 'S' && ct.buf[0] == 'S' &&
ct[1] == 'E' && ct.buf[1] == 'E' &&
ct[2] == 'L' && ct.buf[2] == 'L' &&
ct[3] == 'E' && ct.buf[3] == 'E' &&
ct[4] == 'C' && ct.buf[4] == 'C' &&
ct[5] == 'T' ct.buf[5] == 'T'
} }
type StatementDescription struct { type StatementDescription struct {
@ -1076,13 +1078,13 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
result := &pgConn.resultReader result := &pgConn.resultReader
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
result.concludeCommand(nil, err) result.concludeCommand(CommandTag{}, err)
result.closed = true result.closed = true
return result return result
} }
if len(paramValues) > math.MaxUint16 { if len(paramValues) > math.MaxUint16 {
result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.concludeCommand(CommandTag{}, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
result.closed = true result.closed = true
pgConn.unlock() pgConn.unlock()
return result return result
@ -1091,7 +1093,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
if ctx != context.Background() { if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
result.concludeCommand(nil, newContextAlreadyDoneError(ctx)) result.concludeCommand(CommandTag{}, newContextAlreadyDoneError(ctx))
result.closed = true result.closed = true
pgConn.unlock() pgConn.unlock()
return result return result
@ -1111,7 +1113,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
n, err := pgConn.conn.Write(buf) n, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) result.concludeCommand(CommandTag{}, &writeError{err: err, safeToRetry: n == 0})
pgConn.contextWatcher.Unwatch() pgConn.contextWatcher.Unwatch()
result.closed = true result.closed = true
pgConn.unlock() pgConn.unlock()
@ -1124,14 +1126,14 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
// CopyTo executes the copy command sql and copies the results to w. // CopyTo executes the copy command sql and copies the results to w.
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return nil, err return CommandTag{}, err
} }
if ctx != context.Background() { if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
pgConn.unlock() pgConn.unlock()
return nil, newContextAlreadyDoneError(ctx) return CommandTag{}, newContextAlreadyDoneError(ctx)
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@ -1146,7 +1148,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
pgConn.unlock() pgConn.unlock()
return nil, &writeError{err: err, safeToRetry: n == 0} return CommandTag{}, &writeError{err: err, safeToRetry: n == 0}
} }
// Read results // Read results
@ -1156,7 +1158,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
return nil, preferContextOverNetTimeoutError(ctx, err) return CommandTag{}, preferContextOverNetTimeoutError(ctx, err)
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@ -1165,13 +1167,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
_, err := w.Write(msg.Data) _, err := w.Write(msg.Data)
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
return nil, err return CommandTag{}, err
} }
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
pgConn.unlock() pgConn.unlock()
return commandTag, pgErr return commandTag, pgErr
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag) commandTag = pgConn.makeCommandTag(msg.CommandTag)
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
pgErr = ErrorResponseToPgError(msg) pgErr = ErrorResponseToPgError(msg)
} }
@ -1184,14 +1186,14 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
// could still block. // could still block.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return nil, err return CommandTag{}, err
} }
defer pgConn.unlock() defer pgConn.unlock()
if ctx != context.Background() { if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, newContextAlreadyDoneError(ctx) return CommandTag{}, newContextAlreadyDoneError(ctx)
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@ -1205,7 +1207,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
n, err := pgConn.conn.Write(buf) n, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
return nil, &writeError{err: err, safeToRetry: n == 0} return CommandTag{}, &writeError{err: err, safeToRetry: n == 0}
} }
// Send copy data // Send copy data
@ -1255,7 +1257,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
return nil, preferContextOverNetTimeoutError(ctx, err) return CommandTag{}, preferContextOverNetTimeoutError(ctx, err)
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@ -1279,7 +1281,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf) _, err = pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
return nil, err return CommandTag{}, err
} }
// Read results // Read results
@ -1288,14 +1290,14 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
return nil, preferContextOverNetTimeoutError(ctx, err) return CommandTag{}, preferContextOverNetTimeoutError(ctx, err)
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
return commandTag, pgErr return commandTag, pgErr
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag) commandTag = pgConn.makeCommandTag(msg.CommandTag)
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
pgErr = ErrorResponseToPgError(msg) pgErr = ErrorResponseToPgError(msg)
} }
@ -1368,7 +1370,7 @@ func (mrr *MultiResultReader) NextResult() bool {
return true return true
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
mrr.pgConn.resultReader = ResultReader{ mrr.pgConn.resultReader = ResultReader{
commandTag: CommandTag(msg.CommandTag), commandTag: mrr.pgConn.makeCommandTag(msg.CommandTag),
commandConcluded: true, commandConcluded: true,
closed: true, closed: true,
} }
@ -1483,7 +1485,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
for !rr.commandConcluded { for !rr.commandConcluded {
_, err := rr.receiveMessage() _, err := rr.receiveMessage()
if err != nil { if err != nil {
return nil, rr.err return CommandTag{}, rr.err
} }
} }
@ -1491,7 +1493,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
for { for {
msg, err := rr.receiveMessage() msg, err := rr.receiveMessage()
if err != nil { if err != nil {
return nil, rr.err return CommandTag{}, rr.err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@ -1538,7 +1540,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
if err != nil { if err != nil {
err = preferContextOverNetTimeoutError(rr.ctx, err) err = preferContextOverNetTimeoutError(rr.ctx, err)
rr.concludeCommand(nil, err) rr.concludeCommand(CommandTag{}, err)
rr.pgConn.contextWatcher.Unwatch() rr.pgConn.contextWatcher.Unwatch()
rr.closed = true rr.closed = true
if rr.multiResultReader == nil { if rr.multiResultReader == nil {
@ -1552,11 +1554,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
case *pgproto3.RowDescription: case *pgproto3.RowDescription:
rr.fieldDescriptions = msg.Fields rr.fieldDescriptions = msg.Fields
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
rr.concludeCommand(CommandTag(msg.CommandTag), nil) rr.concludeCommand(rr.pgConn.makeCommandTag(msg.CommandTag), nil)
case *pgproto3.EmptyQueryResponse: case *pgproto3.EmptyQueryResponse:
rr.concludeCommand(nil, nil) rr.concludeCommand(CommandTag{}, nil)
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
rr.concludeCommand(nil, ErrorResponseToPgError(msg)) rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg))
} }
return msg, nil return msg, nil
@ -1659,6 +1661,13 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) {
return strings.Replace(s, "'", "''", -1), nil return strings.Replace(s, "'", "''", -1), nil
} }
// makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory.
func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
ct := make([]byte, len(buf))
copy(ct, buf)
return CommandTag{buf: ct}
}
// HijackedConn is the result of hijacking a connection. // HijackedConn is the result of hijacking a connection.
// //
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning

View File

@ -0,0 +1,41 @@
package pgconn
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestCommandTag(t *testing.T) {
t.Parallel()
var tests = []struct {
commandTag CommandTag
rowsAffected int64
isInsert bool
isUpdate bool
isDelete bool
isSelect bool
}{
{commandTag: CommandTag{buf: []byte("INSERT 0 5")}, rowsAffected: 5, isInsert: true},
{commandTag: CommandTag{buf: []byte("UPDATE 0")}, rowsAffected: 0, isUpdate: true},
{commandTag: CommandTag{buf: []byte("UPDATE 1")}, rowsAffected: 1, isUpdate: true},
{commandTag: CommandTag{buf: []byte("DELETE 0")}, rowsAffected: 0, isDelete: true},
{commandTag: CommandTag{buf: []byte("DELETE 1")}, rowsAffected: 1, isDelete: true},
{commandTag: CommandTag{buf: []byte("DELETE 1234567890")}, rowsAffected: 1234567890, isDelete: true},
{commandTag: CommandTag{buf: []byte("SELECT 1")}, rowsAffected: 1, isSelect: true},
{commandTag: CommandTag{buf: []byte("SELECT 99999999999")}, rowsAffected: 99999999999, isSelect: true},
{commandTag: CommandTag{buf: []byte("CREATE TABLE")}, rowsAffected: 0},
{commandTag: CommandTag{buf: []byte("ALTER TABLE")}, rowsAffected: 0},
{commandTag: CommandTag{buf: []byte("DROP TABLE")}, rowsAffected: 0},
}
for i, tt := range tests {
ct := tt.commandTag
assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag)
}
}

View File

@ -538,7 +538,7 @@ func TestConnExec(t *testing.T) {
assert.Len(t, results, 1) assert.Len(t, results, 1)
assert.Nil(t, results[0].Err) assert.Nil(t, results[0].Err)
assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
assert.Len(t, results[0].Rows, 1) assert.Len(t, results[0].Rows, 1)
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
@ -579,12 +579,12 @@ func TestConnExecMultipleQueries(t *testing.T) {
assert.Len(t, results, 2) assert.Len(t, results, 2)
assert.Nil(t, results[0].Err) assert.Nil(t, results[0].Err)
assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
assert.Len(t, results[0].Rows, 1) assert.Len(t, results[0].Rows, 1)
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
assert.Nil(t, results[1].Err) assert.Nil(t, results[1].Err)
assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) assert.Equal(t, "SELECT 1", results[1].CommandTag.String())
assert.Len(t, results[1].Rows, 1) assert.Len(t, results[1].Rows, 1)
assert.Equal(t, "1", string(results[1].Rows[0][0])) assert.Equal(t, "1", string(results[1].Rows[0][0]))
@ -741,7 +741,7 @@ func TestConnExecParams(t *testing.T) {
} }
assert.Equal(t, 1, rowCount) assert.Equal(t, 1, rowCount)
commandTag, err := result.Close() commandTag, err := result.Close()
assert.Equal(t, "SELECT 1", string(commandTag)) assert.Equal(t, "SELECT 1", commandTag.String())
assert.NoError(t, err) assert.NoError(t, err)
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
@ -840,7 +840,7 @@ func TestConnExecParamsCanceled(t *testing.T) {
} }
assert.Equal(t, 0, rowCount) assert.Equal(t, 0, rowCount)
commandTag, err := result.Close() commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, pgconn.CommandTag{}, commandTag)
assert.True(t, pgconn.Timeout(err)) assert.True(t, pgconn.Timeout(err))
assert.ErrorIs(t, err, context.DeadlineExceeded) assert.ErrorIs(t, err, context.DeadlineExceeded)
@ -880,7 +880,7 @@ func TestConnExecParamsEmptySQL(t *testing.T) {
defer closeConn(t, pgConn) defer closeConn(t, pgConn)
result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read() result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read()
assert.Nil(t, result.CommandTag) assert.Equal(t, pgconn.CommandTag{}, result.CommandTag)
assert.Len(t, result.Rows, 0) assert.Len(t, result.Rows, 0)
assert.NoError(t, result.Err) assert.NoError(t, result.Err)
@ -907,7 +907,7 @@ func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) {
} }
assert.Equal(t, 1, rowCount) assert.Equal(t, 1, rowCount)
commandTag, err := result.Close() commandTag, err := result.Close()
assert.Equal(t, "SELECT 1", string(commandTag)) assert.Equal(t, "SELECT 1", commandTag.String())
assert.NoError(t, err) assert.NoError(t, err)
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
@ -937,7 +937,7 @@ func TestConnExecPrepared(t *testing.T) {
} }
assert.Equal(t, 1, rowCount) assert.Equal(t, 1, rowCount)
commandTag, err := result.Close() commandTag, err := result.Close()
assert.Equal(t, "SELECT 1", string(commandTag)) assert.Equal(t, "SELECT 1", commandTag.String())
assert.NoError(t, err) assert.NoError(t, err)
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
@ -1025,7 +1025,7 @@ func TestConnExecPreparedCanceled(t *testing.T) {
} }
assert.Equal(t, 0, rowCount) assert.Equal(t, 0, rowCount)
commandTag, err := result.Close() commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, pgconn.CommandTag{}, commandTag)
assert.True(t, pgconn.Timeout(err)) assert.True(t, pgconn.Timeout(err))
assert.True(t, pgConn.IsClosed()) assert.True(t, pgConn.IsClosed())
select { select {
@ -1069,7 +1069,7 @@ func TestConnExecPreparedEmptySQL(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
assert.Nil(t, result.CommandTag) assert.Equal(t, pgconn.CommandTag{}, result.CommandTag)
assert.Len(t, result.Rows, 0) assert.Len(t, result.Rows, 0)
assert.NoError(t, result.Err) assert.NoError(t, result.Err)
@ -1097,15 +1097,15 @@ func TestConnExecBatch(t *testing.T) {
require.Len(t, results[0].Rows, 1) require.Len(t, results[0].Rows, 1)
require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0]))
assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
require.Len(t, results[1].Rows, 1) require.Len(t, results[1].Rows, 1)
require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0]))
assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) assert.Equal(t, "SELECT 1", results[1].CommandTag.String())
require.Len(t, results[2].Rows, 1) require.Len(t, results[2].Rows, 1)
require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0]))
assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
} }
func TestConnExecBatchDeferredError(t *testing.T) { func TestConnExecBatchDeferredError(t *testing.T) {
@ -1199,7 +1199,7 @@ func TestConnExecBatchHuge(t *testing.T) {
for i := range args { for i := range args {
require.Len(t, results[i].Rows, 1) require.Len(t, results[i].Rows, 1)
require.Equal(t, args[i], string(results[i].Rows[0][0])) require.Equal(t, args[i], string(results[i].Rows[0][0]))
assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) assert.Equal(t, "SELECT 1", results[i].CommandTag.String())
} }
} }
@ -1247,47 +1247,13 @@ func TestConnLocking(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, results, 1) assert.Len(t, results, 1)
assert.Nil(t, results[0].Err) assert.Nil(t, results[0].Err)
assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
assert.Len(t, results[0].Rows, 1) assert.Len(t, results[0].Rows, 1)
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
func TestCommandTag(t *testing.T) {
t.Parallel()
var tests = []struct {
commandTag pgconn.CommandTag
rowsAffected int64
isInsert bool
isUpdate bool
isDelete bool
isSelect bool
}{
{commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5, isInsert: true},
{commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0, isUpdate: true},
{commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1, isUpdate: true},
{commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0, isDelete: true},
{commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1, isDelete: true},
{commandTag: pgconn.CommandTag("DELETE 1234567890"), rowsAffected: 1234567890, isDelete: true},
{commandTag: pgconn.CommandTag("SELECT 1"), rowsAffected: 1, isSelect: true},
{commandTag: pgconn.CommandTag("SELECT 99999999999"), rowsAffected: 99999999999, isSelect: true},
{commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0},
{commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0},
{commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0},
}
for i, tt := range tests {
ct := tt.commandTag
assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag)
}
}
func TestConnOnNotice(t *testing.T) { func TestConnOnNotice(t *testing.T) {
t.Parallel() t.Parallel()
@ -1546,7 +1512,7 @@ func TestConnCopyToCanceled(t *testing.T) {
defer cancel() defer cancel()
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
assert.Error(t, err) assert.Error(t, err)
assert.Equal(t, pgconn.CommandTag(nil), res) assert.Equal(t, pgconn.CommandTag{}, res)
assert.True(t, pgConn.IsClosed()) assert.True(t, pgConn.IsClosed())
select { select {
@ -1571,7 +1537,7 @@ func TestConnCopyToPrecanceled(t *testing.T) {
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, pgconn.SafeToRetry(err)) assert.True(t, pgconn.SafeToRetry(err))
assert.Equal(t, pgconn.CommandTag(nil), res) assert.Equal(t, pgconn.CommandTag{}, res)
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@ -1692,7 +1658,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) {
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, pgconn.SafeToRetry(err)) assert.True(t, pgconn.SafeToRetry(err))
assert.Equal(t, pgconn.CommandTag(nil), ct) assert.Equal(t, pgconn.CommandTag{}, ct)
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@ -2014,7 +1980,7 @@ func TestHijackAndConstruct(t *testing.T) {
assert.Len(t, results, 1) assert.Len(t, results, 1)
assert.Nil(t, results[0].Err) assert.Nil(t, results[0].Err)
assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
assert.Len(t, results[0].Rows, 1) assert.Len(t, results[0].Rows, 1)
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))

View File

@ -10,7 +10,7 @@ type errBatchResults struct {
} }
func (br errBatchResults) Exec() (pgconn.CommandTag, error) { func (br errBatchResults) Exec() (pgconn.CommandTag, error) {
return nil, br.err return pgconn.CommandTag{}, br.err
} }
func (br errBatchResults) Query() (pgx.Rows, error) { func (br errBatchResults) Query() (pgx.Rows, error) {
@ -18,7 +18,7 @@ func (br errBatchResults) Query() (pgx.Rows, error) {
} }
func (br errBatchResults) QueryFunc(scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { func (br errBatchResults) QueryFunc(scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) {
return nil, br.err return pgconn.CommandTag{}, br.err
} }
func (br errBatchResults) QueryRow() pgx.Row { func (br errBatchResults) QueryRow() pgx.Row {

View File

@ -27,7 +27,7 @@ type execer interface {
func testExec(t *testing.T, db execer) { func testExec(t *testing.T, db execer) {
results, err := db.Exec(context.Background(), "set time zone 'America/Chicago'") results, err := db.Exec(context.Background(), "set time zone 'America/Chicago'")
require.NoError(t, err) require.NoError(t, err)
assert.EqualValues(t, "SET", results) assert.EqualValues(t, "SET", results.String())
} }
type queryer interface { type queryer interface {

View File

@ -470,7 +470,7 @@ func (p *Pool) Stat() *Stat {
func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) {
c, err := p.Acquire(ctx) c, err := p.Acquire(ctx)
if err != nil { if err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
defer c.Release() defer c.Release()
@ -527,7 +527,7 @@ func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pg
func (p *Pool) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { func (p *Pool) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) {
c, err := p.Acquire(ctx) c, err := p.Acquire(ctx)
if err != nil { if err != nil {
return nil, err return pgconn.CommandTag{}, err
} }
defer c.Release() defer c.Release()

View File

@ -12,7 +12,7 @@ type errRows struct {
func (errRows) Close() {} func (errRows) Close() {}
func (e errRows) Err() error { return e.err } func (e errRows) Err() error { return e.err }
func (errRows) CommandTag() pgconn.CommandTag { return nil } func (errRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} }
func (errRows) FieldDescriptions() []pgproto3.FieldDescription { return nil } func (errRows) FieldDescriptions() []pgproto3.FieldDescription { return nil }
func (errRows) Next() bool { return false } func (errRows) Next() bool { return false }
func (e errRows) Scan(dest ...interface{}) error { return e.err } func (e errRows) Scan(dest ...interface{}) error { return e.err }

View File

@ -45,7 +45,7 @@ func TestConnQueryScan(t *testing.T) {
t.Fatalf("conn.Query failed: %v", err) t.Fatalf("conn.Query failed: %v", err)
} }
assert.Equal(t, "SELECT 10", string(rows.CommandTag())) assert.Equal(t, "SELECT 10", rows.CommandTag().String())
if rowCount != 10 { if rowCount != 10 {
t.Error("Select called onDataRow wrong number of times") t.Error("Select called onDataRow wrong number of times")
@ -79,7 +79,7 @@ func TestConnQueryWithoutResultSetCommandTag(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
rows.Close() rows.Close()
assert.NoError(t, rows.Err()) assert.NoError(t, rows.Err())
assert.Equal(t, "CREATE TABLE", string(rows.CommandTag())) assert.Equal(t, "CREATE TABLE", rows.CommandTag().String())
} }
func TestConnQueryScanWithManyColumns(t *testing.T) { func TestConnQueryScanWithManyColumns(t *testing.T) {
@ -1139,7 +1139,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if string(commandTag) != "INSERT 0 1" { if commandTag.String() != "INSERT 0 1" {
t.Fatalf("want %s, got %s", "INSERT 0 1", commandTag) t.Fatalf("want %s, got %s", "INSERT 0 1", commandTag)
} }
@ -1976,7 +1976,7 @@ func TestConnQueryFuncAbort(t *testing.T) {
}, },
) )
require.EqualError(t, err, "abort") require.EqualError(t, err, "abort")
require.Nil(t, ct) require.Equal(t, pgconn.CommandTag{}, ct)
}) })
} }

8
tx.go
View File

@ -235,7 +235,7 @@ func (tx *dbTx) Commit(ctx context.Context) error {
} }
return err return err
} }
if string(commandTag) == "ROLLBACK" { if commandTag.String() == "ROLLBACK" {
return ErrTxCommitRollback return ErrTxCommitRollback
} }
@ -296,7 +296,7 @@ func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...interface{}) R
// QueryFunc delegates to the underlying *Conn. // QueryFunc delegates to the underlying *Conn.
func (tx *dbTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { func (tx *dbTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
if tx.closed { if tx.closed {
return nil, ErrTxClosed return pgconn.CommandTag{}, ErrTxClosed
} }
return tx.conn.QueryFunc(ctx, sql, args, scans, f) return tx.conn.QueryFunc(ctx, sql, args, scans, f)
@ -380,7 +380,7 @@ func (sp *dbSavepoint) Rollback(ctx context.Context) error {
// Exec delegates to the underlying Tx // Exec delegates to the underlying Tx
func (sp *dbSavepoint) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { func (sp *dbSavepoint) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
if sp.closed { if sp.closed {
return nil, ErrTxClosed return pgconn.CommandTag{}, ErrTxClosed
} }
return sp.tx.Exec(ctx, sql, arguments...) return sp.tx.Exec(ctx, sql, arguments...)
@ -415,7 +415,7 @@ func (sp *dbSavepoint) QueryRow(ctx context.Context, sql string, args ...interfa
// QueryFunc delegates to the underlying Tx. // QueryFunc delegates to the underlying Tx.
func (sp *dbSavepoint) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { func (sp *dbSavepoint) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
if sp.closed { if sp.closed {
return nil, ErrTxClosed return pgconn.CommandTag{}, ErrTxClosed
} }
return sp.tx.QueryFunc(ctx, sql, args, scans, f) return sp.tx.QueryFunc(ctx, sql, args, scans, f)