pgconn.CommandTag is now an opaque type

It now makes a copy instead of retaining driver memory. This is in
preparation to reuse the driver read buffer.
query-exec-mode
Jack Christensen 2022-02-12 10:26:26 -06:00
parent e6680127e3
commit 9c5dfbdfb3
14 changed files with 246 additions and 225 deletions

View File

@ -64,10 +64,10 @@ type batchResults struct {
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
func (br *batchResults) Exec() (pgconn.CommandTag, error) {
if br.err != nil {
return nil, br.err
return pgconn.CommandTag{}, br.err
}
if br.closed {
return nil, fmt.Errorf("batch already closed")
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
}
query, arguments, _ := br.nextQueryAndArgs()
@ -84,7 +84,7 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {
"err": err,
})
}
return nil, err
return pgconn.CommandTag{}, err
}
commandTag, err := br.mrr.ResultReader().Close()
@ -151,29 +151,29 @@ func (br *batchResults) Query() (Rows, error) {
// QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc.
func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
if br.closed {
return nil, fmt.Errorf("batch already closed")
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
}
rows, err := br.Query()
if err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
defer rows.Close()
for rows.Next() {
err = rows.Scan(scans...)
if err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
err = f(rows)
if err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
}
if err := rows.Err(); err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
return rows.CommandTag(), nil

18
conn.go
View File

@ -432,7 +432,7 @@ optionLoop:
if c.stmtcache != nil {
sd, err := c.stmtcache.Get(ctx, sql)
if err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
if c.stmtcache.Mode() == stmtcache.ModeDescribe {
@ -443,7 +443,7 @@ optionLoop:
sd, err := c.Prepare(ctx, "", sql)
if err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
return c.execPrepared(ctx, sd, arguments)
}
@ -452,7 +452,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i
if len(arguments) > 0 {
sql, err = c.sanitizeForSimpleQuery(sql, arguments...)
if err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
}
@ -493,7 +493,7 @@ func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, argu
func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) {
err := c.execParamsAndPreparedPrefix(sd, arguments)
if err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read()
@ -504,7 +504,7 @@ func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription,
func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) {
err := c.execParamsAndPreparedPrefix(sd, arguments)
if err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read()
@ -688,24 +688,24 @@ type QueryFuncRow interface {
func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
rows, err := c.Query(ctx, sql, args...)
if err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
defer rows.Close()
for rows.Next() {
err = rows.Scan(scans...)
if err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
err = f(rows)
if err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
}
if err := rows.Err(); err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
return rows.CommandTag(), nil

View File

@ -188,31 +188,31 @@ func TestExec(t *testing.T) {
t.Parallel()
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); string(results) != "CREATE TABLE" {
if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results.String() != "CREATE TABLE" {
t.Error("Unexpected results from Exec")
}
// Accept parameters
if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); string(results) != "INSERT 0 1" {
if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); results.String() != "INSERT 0 1" {
t.Errorf("Unexpected results from Exec: %v", results)
}
if results := mustExec(t, conn, "drop table foo;"); string(results) != "DROP TABLE" {
if results := mustExec(t, conn, "drop table foo;"); results.String() != "DROP TABLE" {
t.Error("Unexpected results from Exec")
}
// Multiple statements can be executed -- last command tag is returned
if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); string(results) != "DROP TABLE" {
if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); results.String() != "DROP TABLE" {
t.Error("Unexpected results from Exec")
}
// Can execute longer SQL strings than sharedBufferSize
if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); string(results) != "SELECT 1" {
if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); results.String() != "SELECT 1" {
t.Errorf("Unexpected results from Exec: %v", results)
}
// Exec no-op which does not return a command tag
if results := mustExec(t, conn, "--;"); string(results) != "" {
if results := mustExec(t, conn, "--;"); results.String() != "" {
t.Errorf("Unexpected results from Exec: %v", results)
}
})
@ -260,7 +260,7 @@ func TestExecContextWithoutCancelation(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if string(commandTag) != "CREATE TABLE" {
if commandTag.String() != "CREATE TABLE" {
t.Fatalf("Unexpected results from Exec: %v", commandTag)
}
assert.False(t, pgconn.SafeToRetry(err))
@ -350,15 +350,15 @@ func TestExecStatementCacheModes(t *testing.T) {
commandTag, err := conn.Exec(context.Background(), "select 1")
assert.NoError(t, err, tt.name)
assert.Equal(t, "SELECT 1", string(commandTag), tt.name)
assert.Equal(t, "SELECT 1", commandTag.String(), tt.name)
commandTag, err = conn.Exec(context.Background(), "select 1 union all select 1")
assert.NoError(t, err, tt.name)
assert.Equal(t, "SELECT 2", string(commandTag), tt.name)
assert.Equal(t, "SELECT 2", commandTag.String(), tt.name)
commandTag, err = conn.Exec(context.Background(), "select 1")
assert.NoError(t, err, tt.name)
assert.Equal(t, "SELECT 1", string(commandTag), tt.name)
assert.Equal(t, "SELECT 1", commandTag.String(), tt.name)
ensureConnValid(t, conn)
}()
@ -378,7 +378,7 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if string(commandTag) != "CREATE TABLE" {
if commandTag.String() != "CREATE TABLE" {
t.Fatalf("Unexpected results from Exec: %v", commandTag)
}
@ -390,7 +390,7 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if string(commandTag) != "INSERT 0 1" {
if commandTag.String() != "INSERT 0 1" {
t.Fatalf("Unexpected results from Exec: %v", commandTag)
}
@ -720,12 +720,12 @@ func TestInsertBoolArray(t *testing.T) {
t.Parallel()
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); string(results) != "CREATE TABLE" {
if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results.String() != "CREATE TABLE" {
t.Error("Unexpected results from Exec")
}
// Accept parameters
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); string(results) != "INSERT 0 1" {
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results.String() != "INSERT 0 1" {
t.Errorf("Unexpected results from Exec: %v", results)
}
})
@ -735,12 +735,12 @@ func TestInsertTimestampArray(t *testing.T) {
t.Parallel()
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); string(results) != "CREATE TABLE" {
if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results.String() != "CREATE TABLE" {
t.Error("Unexpected results from Exec")
}
// Accept parameters
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); string(results) != "INSERT 0 1" {
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results.String() != "INSERT 0 1" {
t.Errorf("Unexpected results from Exec: %v", results)
}
})

View File

@ -0,0 +1,73 @@
package pgconn
import (
"strings"
"testing"
)
func BenchmarkCommandTagRowsAffected(b *testing.B) {
benchmarks := []struct {
commandTag string
rowsAffected int64
}{
{"UPDATE 1", 1},
{"UPDATE 123456789", 123456789},
{"INSERT 0 1", 1},
{"INSERT 0 123456789", 123456789},
}
for _, bm := range benchmarks {
ct := CommandTag{buf: []byte(bm.commandTag)}
b.Run(bm.commandTag, func(b *testing.B) {
var n int64
for i := 0; i < b.N; i++ {
n = ct.RowsAffected()
}
if n != bm.rowsAffected {
b.Errorf("expected %d got %d", bm.rowsAffected, n)
}
})
}
}
func BenchmarkCommandTagTypeFromString(b *testing.B) {
ct := CommandTag{buf: []byte("UPDATE 1")}
var update bool
for i := 0; i < b.N; i++ {
update = strings.HasPrefix(ct.String(), "UPDATE")
}
if !update {
b.Error("expected update")
}
}
func BenchmarkCommandTagInsert(b *testing.B) {
benchmarks := []struct {
commandTag string
is bool
}{
{"INSERT 1", true},
{"INSERT 1234567890", true},
{"UPDATE 1", false},
{"UPDATE 1234567890", false},
{"DELETE 1", false},
{"DELETE 1234567890", false},
{"SELECT 1", false},
{"SELECT 1234567890", false},
{"UNKNOWN 1234567890", false},
}
for _, bm := range benchmarks {
ct := CommandTag{buf: []byte(bm.commandTag)}
b.Run(bm.commandTag, func(b *testing.B) {
var is bool
for i := 0; i < b.N; i++ {
is = ct.Insert()
}
if is != bm.is {
b.Errorf("expected %v got %v", bm.is, is)
}
})
}
}

View File

@ -4,7 +4,6 @@ import (
"bytes"
"context"
"os"
"strings"
"testing"
"github.com/jackc/pgx/v5/pgconn"
@ -253,70 +252,3 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
// conn.ChanToSetDeadline().Ignore()
// }
// }
func BenchmarkCommandTagRowsAffected(b *testing.B) {
benchmarks := []struct {
commandTag string
rowsAffected int64
}{
{"UPDATE 1", 1},
{"UPDATE 123456789", 123456789},
{"INSERT 0 1", 1},
{"INSERT 0 123456789", 123456789},
}
for _, bm := range benchmarks {
ct := pgconn.CommandTag(bm.commandTag)
b.Run(bm.commandTag, func(b *testing.B) {
var n int64
for i := 0; i < b.N; i++ {
n = ct.RowsAffected()
}
if n != bm.rowsAffected {
b.Errorf("expected %d got %d", bm.rowsAffected, n)
}
})
}
}
func BenchmarkCommandTagTypeFromString(b *testing.B) {
ct := pgconn.CommandTag("UPDATE 1")
var update bool
for i := 0; i < b.N; i++ {
update = strings.HasPrefix(ct.String(), "UPDATE")
}
if !update {
b.Error("expected update")
}
}
func BenchmarkCommandTagInsert(b *testing.B) {
benchmarks := []struct {
commandTag string
is bool
}{
{"INSERT 1", true},
{"INSERT 1234567890", true},
{"UPDATE 1", false},
{"UPDATE 1234567890", false},
{"DELETE 1", false},
{"DELETE 1234567890", false},
{"SELECT 1", false},
{"SELECT 1234567890", false},
{"UNKNOWN 1234567890", false},
}
for _, bm := range benchmarks {
ct := pgconn.CommandTag(bm.commandTag)
b.Run(bm.commandTag, func(b *testing.B) {
var is bool
for i := 0; i < b.N; i++ {
is = ct.Insert()
}
if is != bm.is {
b.Errorf("expected %v got %v", bm.is, is)
}
})
}
}

View File

@ -685,15 +685,17 @@ func (pgConn *PgConn) ParameterStatus(key string) string {
}
// CommandTag is the result of an Exec function
type CommandTag []byte
type CommandTag struct {
buf []byte
}
// RowsAffected returns the number of rows affected. If the CommandTag was not
// for a row affecting command (e.g. "CREATE TABLE") then it returns 0.
func (ct CommandTag) RowsAffected() int64 {
// Find last non-digit
idx := -1
for i := len(ct) - 1; i >= 0; i-- {
if ct[i] >= '0' && ct[i] <= '9' {
for i := len(ct.buf) - 1; i >= 0; i-- {
if ct.buf[i] >= '0' && ct.buf[i] <= '9' {
idx = i
} else {
break
@ -705,7 +707,7 @@ func (ct CommandTag) RowsAffected() int64 {
}
var n int64
for _, b := range ct[idx:] {
for _, b := range ct.buf[idx:] {
n = n*10 + int64(b-'0')
}
@ -713,51 +715,51 @@ func (ct CommandTag) RowsAffected() int64 {
}
func (ct CommandTag) String() string {
return string(ct)
return string(ct.buf)
}
// Insert is true if the command tag starts with "INSERT".
func (ct CommandTag) Insert() bool {
return len(ct) >= 6 &&
ct[0] == 'I' &&
ct[1] == 'N' &&
ct[2] == 'S' &&
ct[3] == 'E' &&
ct[4] == 'R' &&
ct[5] == 'T'
return len(ct.buf) >= 6 &&
ct.buf[0] == 'I' &&
ct.buf[1] == 'N' &&
ct.buf[2] == 'S' &&
ct.buf[3] == 'E' &&
ct.buf[4] == 'R' &&
ct.buf[5] == 'T'
}
// Update is true if the command tag starts with "UPDATE".
func (ct CommandTag) Update() bool {
return len(ct) >= 6 &&
ct[0] == 'U' &&
ct[1] == 'P' &&
ct[2] == 'D' &&
ct[3] == 'A' &&
ct[4] == 'T' &&
ct[5] == 'E'
return len(ct.buf) >= 6 &&
ct.buf[0] == 'U' &&
ct.buf[1] == 'P' &&
ct.buf[2] == 'D' &&
ct.buf[3] == 'A' &&
ct.buf[4] == 'T' &&
ct.buf[5] == 'E'
}
// Delete is true if the command tag starts with "DELETE".
func (ct CommandTag) Delete() bool {
return len(ct) >= 6 &&
ct[0] == 'D' &&
ct[1] == 'E' &&
ct[2] == 'L' &&
ct[3] == 'E' &&
ct[4] == 'T' &&
ct[5] == 'E'
return len(ct.buf) >= 6 &&
ct.buf[0] == 'D' &&
ct.buf[1] == 'E' &&
ct.buf[2] == 'L' &&
ct.buf[3] == 'E' &&
ct.buf[4] == 'T' &&
ct.buf[5] == 'E'
}
// Select is true if the command tag starts with "SELECT".
func (ct CommandTag) Select() bool {
return len(ct) >= 6 &&
ct[0] == 'S' &&
ct[1] == 'E' &&
ct[2] == 'L' &&
ct[3] == 'E' &&
ct[4] == 'C' &&
ct[5] == 'T'
return len(ct.buf) >= 6 &&
ct.buf[0] == 'S' &&
ct.buf[1] == 'E' &&
ct.buf[2] == 'L' &&
ct.buf[3] == 'E' &&
ct.buf[4] == 'C' &&
ct.buf[5] == 'T'
}
type StatementDescription struct {
@ -1076,13 +1078,13 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
result := &pgConn.resultReader
if err := pgConn.lock(); err != nil {
result.concludeCommand(nil, err)
result.concludeCommand(CommandTag{}, err)
result.closed = true
return result
}
if len(paramValues) > math.MaxUint16 {
result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
result.concludeCommand(CommandTag{}, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
result.closed = true
pgConn.unlock()
return result
@ -1091,7 +1093,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
if ctx != context.Background() {
select {
case <-ctx.Done():
result.concludeCommand(nil, newContextAlreadyDoneError(ctx))
result.concludeCommand(CommandTag{}, newContextAlreadyDoneError(ctx))
result.closed = true
pgConn.unlock()
return result
@ -1111,7 +1113,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.asyncClose()
result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0})
result.concludeCommand(CommandTag{}, &writeError{err: err, safeToRetry: n == 0})
pgConn.contextWatcher.Unwatch()
result.closed = true
pgConn.unlock()
@ -1124,14 +1126,14 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
// CopyTo executes the copy command sql and copies the results to w.
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil {
return nil, err
return CommandTag{}, err
}
if ctx != context.Background() {
select {
case <-ctx.Done():
pgConn.unlock()
return nil, newContextAlreadyDoneError(ctx)
return CommandTag{}, newContextAlreadyDoneError(ctx)
default:
}
pgConn.contextWatcher.Watch(ctx)
@ -1146,7 +1148,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
if err != nil {
pgConn.asyncClose()
pgConn.unlock()
return nil, &writeError{err: err, safeToRetry: n == 0}
return CommandTag{}, &writeError{err: err, safeToRetry: n == 0}
}
// Read results
@ -1156,7 +1158,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
return nil, preferContextOverNetTimeoutError(ctx, err)
return CommandTag{}, preferContextOverNetTimeoutError(ctx, err)
}
switch msg := msg.(type) {
@ -1165,13 +1167,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
_, err := w.Write(msg.Data)
if err != nil {
pgConn.asyncClose()
return nil, err
return CommandTag{}, err
}
case *pgproto3.ReadyForQuery:
pgConn.unlock()
return commandTag, pgErr
case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag)
commandTag = pgConn.makeCommandTag(msg.CommandTag)
case *pgproto3.ErrorResponse:
pgErr = ErrorResponseToPgError(msg)
}
@ -1184,14 +1186,14 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
// could still block.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil {
return nil, err
return CommandTag{}, err
}
defer pgConn.unlock()
if ctx != context.Background() {
select {
case <-ctx.Done():
return nil, newContextAlreadyDoneError(ctx)
return CommandTag{}, newContextAlreadyDoneError(ctx)
default:
}
pgConn.contextWatcher.Watch(ctx)
@ -1205,7 +1207,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.asyncClose()
return nil, &writeError{err: err, safeToRetry: n == 0}
return CommandTag{}, &writeError{err: err, safeToRetry: n == 0}
}
// Send copy data
@ -1255,7 +1257,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
return nil, preferContextOverNetTimeoutError(ctx, err)
return CommandTag{}, preferContextOverNetTimeoutError(ctx, err)
}
switch msg := msg.(type) {
@ -1279,7 +1281,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.asyncClose()
return nil, err
return CommandTag{}, err
}
// Read results
@ -1288,14 +1290,14 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
return nil, preferContextOverNetTimeoutError(ctx, err)
return CommandTag{}, preferContextOverNetTimeoutError(ctx, err)
}
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
return commandTag, pgErr
case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag)
commandTag = pgConn.makeCommandTag(msg.CommandTag)
case *pgproto3.ErrorResponse:
pgErr = ErrorResponseToPgError(msg)
}
@ -1368,7 +1370,7 @@ func (mrr *MultiResultReader) NextResult() bool {
return true
case *pgproto3.CommandComplete:
mrr.pgConn.resultReader = ResultReader{
commandTag: CommandTag(msg.CommandTag),
commandTag: mrr.pgConn.makeCommandTag(msg.CommandTag),
commandConcluded: true,
closed: true,
}
@ -1483,7 +1485,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
for !rr.commandConcluded {
_, err := rr.receiveMessage()
if err != nil {
return nil, rr.err
return CommandTag{}, rr.err
}
}
@ -1491,7 +1493,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
for {
msg, err := rr.receiveMessage()
if err != nil {
return nil, rr.err
return CommandTag{}, rr.err
}
switch msg := msg.(type) {
@ -1538,7 +1540,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
if err != nil {
err = preferContextOverNetTimeoutError(rr.ctx, err)
rr.concludeCommand(nil, err)
rr.concludeCommand(CommandTag{}, err)
rr.pgConn.contextWatcher.Unwatch()
rr.closed = true
if rr.multiResultReader == nil {
@ -1552,11 +1554,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
case *pgproto3.RowDescription:
rr.fieldDescriptions = msg.Fields
case *pgproto3.CommandComplete:
rr.concludeCommand(CommandTag(msg.CommandTag), nil)
rr.concludeCommand(rr.pgConn.makeCommandTag(msg.CommandTag), nil)
case *pgproto3.EmptyQueryResponse:
rr.concludeCommand(nil, nil)
rr.concludeCommand(CommandTag{}, nil)
case *pgproto3.ErrorResponse:
rr.concludeCommand(nil, ErrorResponseToPgError(msg))
rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg))
}
return msg, nil
@ -1659,6 +1661,13 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) {
return strings.Replace(s, "'", "''", -1), nil
}
// makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory.
func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
ct := make([]byte, len(buf))
copy(ct, buf)
return CommandTag{buf: ct}
}
// HijackedConn is the result of hijacking a connection.
//
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning

View File

@ -0,0 +1,41 @@
package pgconn
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestCommandTag(t *testing.T) {
t.Parallel()
var tests = []struct {
commandTag CommandTag
rowsAffected int64
isInsert bool
isUpdate bool
isDelete bool
isSelect bool
}{
{commandTag: CommandTag{buf: []byte("INSERT 0 5")}, rowsAffected: 5, isInsert: true},
{commandTag: CommandTag{buf: []byte("UPDATE 0")}, rowsAffected: 0, isUpdate: true},
{commandTag: CommandTag{buf: []byte("UPDATE 1")}, rowsAffected: 1, isUpdate: true},
{commandTag: CommandTag{buf: []byte("DELETE 0")}, rowsAffected: 0, isDelete: true},
{commandTag: CommandTag{buf: []byte("DELETE 1")}, rowsAffected: 1, isDelete: true},
{commandTag: CommandTag{buf: []byte("DELETE 1234567890")}, rowsAffected: 1234567890, isDelete: true},
{commandTag: CommandTag{buf: []byte("SELECT 1")}, rowsAffected: 1, isSelect: true},
{commandTag: CommandTag{buf: []byte("SELECT 99999999999")}, rowsAffected: 99999999999, isSelect: true},
{commandTag: CommandTag{buf: []byte("CREATE TABLE")}, rowsAffected: 0},
{commandTag: CommandTag{buf: []byte("ALTER TABLE")}, rowsAffected: 0},
{commandTag: CommandTag{buf: []byte("DROP TABLE")}, rowsAffected: 0},
}
for i, tt := range tests {
ct := tt.commandTag
assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag)
}
}

View File

@ -538,7 +538,7 @@ func TestConnExec(t *testing.T) {
assert.Len(t, results, 1)
assert.Nil(t, results[0].Err)
assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
assert.Len(t, results[0].Rows, 1)
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
@ -579,12 +579,12 @@ func TestConnExecMultipleQueries(t *testing.T) {
assert.Len(t, results, 2)
assert.Nil(t, results[0].Err)
assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
assert.Len(t, results[0].Rows, 1)
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
assert.Nil(t, results[1].Err)
assert.Equal(t, "SELECT 1", string(results[1].CommandTag))
assert.Equal(t, "SELECT 1", results[1].CommandTag.String())
assert.Len(t, results[1].Rows, 1)
assert.Equal(t, "1", string(results[1].Rows[0][0]))
@ -741,7 +741,7 @@ func TestConnExecParams(t *testing.T) {
}
assert.Equal(t, 1, rowCount)
commandTag, err := result.Close()
assert.Equal(t, "SELECT 1", string(commandTag))
assert.Equal(t, "SELECT 1", commandTag.String())
assert.NoError(t, err)
ensureConnValid(t, pgConn)
@ -840,7 +840,7 @@ func TestConnExecParamsCanceled(t *testing.T) {
}
assert.Equal(t, 0, rowCount)
commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(nil), commandTag)
assert.Equal(t, pgconn.CommandTag{}, commandTag)
assert.True(t, pgconn.Timeout(err))
assert.ErrorIs(t, err, context.DeadlineExceeded)
@ -880,7 +880,7 @@ func TestConnExecParamsEmptySQL(t *testing.T) {
defer closeConn(t, pgConn)
result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read()
assert.Nil(t, result.CommandTag)
assert.Equal(t, pgconn.CommandTag{}, result.CommandTag)
assert.Len(t, result.Rows, 0)
assert.NoError(t, result.Err)
@ -907,7 +907,7 @@ func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) {
}
assert.Equal(t, 1, rowCount)
commandTag, err := result.Close()
assert.Equal(t, "SELECT 1", string(commandTag))
assert.Equal(t, "SELECT 1", commandTag.String())
assert.NoError(t, err)
ensureConnValid(t, pgConn)
@ -937,7 +937,7 @@ func TestConnExecPrepared(t *testing.T) {
}
assert.Equal(t, 1, rowCount)
commandTag, err := result.Close()
assert.Equal(t, "SELECT 1", string(commandTag))
assert.Equal(t, "SELECT 1", commandTag.String())
assert.NoError(t, err)
ensureConnValid(t, pgConn)
@ -1025,7 +1025,7 @@ func TestConnExecPreparedCanceled(t *testing.T) {
}
assert.Equal(t, 0, rowCount)
commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(nil), commandTag)
assert.Equal(t, pgconn.CommandTag{}, commandTag)
assert.True(t, pgconn.Timeout(err))
assert.True(t, pgConn.IsClosed())
select {
@ -1069,7 +1069,7 @@ func TestConnExecPreparedEmptySQL(t *testing.T) {
require.NoError(t, err)
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
assert.Nil(t, result.CommandTag)
assert.Equal(t, pgconn.CommandTag{}, result.CommandTag)
assert.Len(t, result.Rows, 0)
assert.NoError(t, result.Err)
@ -1097,15 +1097,15 @@ func TestConnExecBatch(t *testing.T) {
require.Len(t, results[0].Rows, 1)
require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0]))
assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
require.Len(t, results[1].Rows, 1)
require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0]))
assert.Equal(t, "SELECT 1", string(results[1].CommandTag))
assert.Equal(t, "SELECT 1", results[1].CommandTag.String())
require.Len(t, results[2].Rows, 1)
require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0]))
assert.Equal(t, "SELECT 1", string(results[2].CommandTag))
assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
}
func TestConnExecBatchDeferredError(t *testing.T) {
@ -1199,7 +1199,7 @@ func TestConnExecBatchHuge(t *testing.T) {
for i := range args {
require.Len(t, results[i].Rows, 1)
require.Equal(t, args[i], string(results[i].Rows[0][0]))
assert.Equal(t, "SELECT 1", string(results[i].CommandTag))
assert.Equal(t, "SELECT 1", results[i].CommandTag.String())
}
}
@ -1247,47 +1247,13 @@ func TestConnLocking(t *testing.T) {
assert.NoError(t, err)
assert.Len(t, results, 1)
assert.Nil(t, results[0].Err)
assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
assert.Len(t, results[0].Rows, 1)
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
ensureConnValid(t, pgConn)
}
func TestCommandTag(t *testing.T) {
t.Parallel()
var tests = []struct {
commandTag pgconn.CommandTag
rowsAffected int64
isInsert bool
isUpdate bool
isDelete bool
isSelect bool
}{
{commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5, isInsert: true},
{commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0, isUpdate: true},
{commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1, isUpdate: true},
{commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0, isDelete: true},
{commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1, isDelete: true},
{commandTag: pgconn.CommandTag("DELETE 1234567890"), rowsAffected: 1234567890, isDelete: true},
{commandTag: pgconn.CommandTag("SELECT 1"), rowsAffected: 1, isSelect: true},
{commandTag: pgconn.CommandTag("SELECT 99999999999"), rowsAffected: 99999999999, isSelect: true},
{commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0},
{commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0},
{commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0},
}
for i, tt := range tests {
ct := tt.commandTag
assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag)
}
}
func TestConnOnNotice(t *testing.T) {
t.Parallel()
@ -1546,7 +1512,7 @@ func TestConnCopyToCanceled(t *testing.T) {
defer cancel()
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
assert.Error(t, err)
assert.Equal(t, pgconn.CommandTag(nil), res)
assert.Equal(t, pgconn.CommandTag{}, res)
assert.True(t, pgConn.IsClosed())
select {
@ -1571,7 +1537,7 @@ func TestConnCopyToPrecanceled(t *testing.T) {
require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, pgconn.SafeToRetry(err))
assert.Equal(t, pgconn.CommandTag(nil), res)
assert.Equal(t, pgconn.CommandTag{}, res)
ensureConnValid(t, pgConn)
}
@ -1692,7 +1658,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) {
require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, pgconn.SafeToRetry(err))
assert.Equal(t, pgconn.CommandTag(nil), ct)
assert.Equal(t, pgconn.CommandTag{}, ct)
ensureConnValid(t, pgConn)
}
@ -2014,7 +1980,7 @@ func TestHijackAndConstruct(t *testing.T) {
assert.Len(t, results, 1)
assert.Nil(t, results[0].Err)
assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
assert.Len(t, results[0].Rows, 1)
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))

View File

@ -10,7 +10,7 @@ type errBatchResults struct {
}
func (br errBatchResults) Exec() (pgconn.CommandTag, error) {
return nil, br.err
return pgconn.CommandTag{}, br.err
}
func (br errBatchResults) Query() (pgx.Rows, error) {
@ -18,7 +18,7 @@ func (br errBatchResults) Query() (pgx.Rows, error) {
}
func (br errBatchResults) QueryFunc(scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) {
return nil, br.err
return pgconn.CommandTag{}, br.err
}
func (br errBatchResults) QueryRow() pgx.Row {

View File

@ -27,7 +27,7 @@ type execer interface {
func testExec(t *testing.T, db execer) {
results, err := db.Exec(context.Background(), "set time zone 'America/Chicago'")
require.NoError(t, err)
assert.EqualValues(t, "SET", results)
assert.EqualValues(t, "SET", results.String())
}
type queryer interface {

View File

@ -470,7 +470,7 @@ func (p *Pool) Stat() *Stat {
func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) {
c, err := p.Acquire(ctx)
if err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
defer c.Release()
@ -527,7 +527,7 @@ func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pg
func (p *Pool) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) {
c, err := p.Acquire(ctx)
if err != nil {
return nil, err
return pgconn.CommandTag{}, err
}
defer c.Release()

View File

@ -12,7 +12,7 @@ type errRows struct {
func (errRows) Close() {}
func (e errRows) Err() error { return e.err }
func (errRows) CommandTag() pgconn.CommandTag { return nil }
func (errRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} }
func (errRows) FieldDescriptions() []pgproto3.FieldDescription { return nil }
func (errRows) Next() bool { return false }
func (e errRows) Scan(dest ...interface{}) error { return e.err }

View File

@ -45,7 +45,7 @@ func TestConnQueryScan(t *testing.T) {
t.Fatalf("conn.Query failed: %v", err)
}
assert.Equal(t, "SELECT 10", string(rows.CommandTag()))
assert.Equal(t, "SELECT 10", rows.CommandTag().String())
if rowCount != 10 {
t.Error("Select called onDataRow wrong number of times")
@ -79,7 +79,7 @@ func TestConnQueryWithoutResultSetCommandTag(t *testing.T) {
assert.NoError(t, err)
rows.Close()
assert.NoError(t, rows.Err())
assert.Equal(t, "CREATE TABLE", string(rows.CommandTag()))
assert.Equal(t, "CREATE TABLE", rows.CommandTag().String())
}
func TestConnQueryScanWithManyColumns(t *testing.T) {
@ -1139,7 +1139,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes
if err != nil {
t.Fatal(err)
}
if string(commandTag) != "INSERT 0 1" {
if commandTag.String() != "INSERT 0 1" {
t.Fatalf("want %s, got %s", "INSERT 0 1", commandTag)
}
@ -1976,7 +1976,7 @@ func TestConnQueryFuncAbort(t *testing.T) {
},
)
require.EqualError(t, err, "abort")
require.Nil(t, ct)
require.Equal(t, pgconn.CommandTag{}, ct)
})
}

8
tx.go
View File

@ -235,7 +235,7 @@ func (tx *dbTx) Commit(ctx context.Context) error {
}
return err
}
if string(commandTag) == "ROLLBACK" {
if commandTag.String() == "ROLLBACK" {
return ErrTxCommitRollback
}
@ -296,7 +296,7 @@ func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...interface{}) R
// QueryFunc delegates to the underlying *Conn.
func (tx *dbTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
if tx.closed {
return nil, ErrTxClosed
return pgconn.CommandTag{}, ErrTxClosed
}
return tx.conn.QueryFunc(ctx, sql, args, scans, f)
@ -380,7 +380,7 @@ func (sp *dbSavepoint) Rollback(ctx context.Context) error {
// Exec delegates to the underlying Tx
func (sp *dbSavepoint) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
if sp.closed {
return nil, ErrTxClosed
return pgconn.CommandTag{}, ErrTxClosed
}
return sp.tx.Exec(ctx, sql, arguments...)
@ -415,7 +415,7 @@ func (sp *dbSavepoint) QueryRow(ctx context.Context, sql string, args ...interfa
// QueryFunc delegates to the underlying Tx.
func (sp *dbSavepoint) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
if sp.closed {
return nil, ErrTxClosed
return pgconn.CommandTag{}, ErrTxClosed
}
return sp.tx.QueryFunc(ctx, sql, args, scans, f)