diff --git a/batch.go b/batch.go index caa5a02f..689877a9 100644 --- a/batch.go +++ b/batch.go @@ -64,10 +64,10 @@ type batchResults struct { // Exec reads the results from the next query in the batch as if the query has been sent with Exec. func (br *batchResults) Exec() (pgconn.CommandTag, error) { if br.err != nil { - return nil, br.err + return pgconn.CommandTag{}, br.err } if br.closed { - return nil, fmt.Errorf("batch already closed") + return pgconn.CommandTag{}, fmt.Errorf("batch already closed") } query, arguments, _ := br.nextQueryAndArgs() @@ -84,7 +84,7 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { "err": err, }) } - return nil, err + return pgconn.CommandTag{}, err } commandTag, err := br.mrr.ResultReader().Close() @@ -151,29 +151,29 @@ func (br *batchResults) Query() (Rows, error) { // QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { if br.closed { - return nil, fmt.Errorf("batch already closed") + return pgconn.CommandTag{}, fmt.Errorf("batch already closed") } rows, err := br.Query() if err != nil { - return nil, err + return pgconn.CommandTag{}, err } defer rows.Close() for rows.Next() { err = rows.Scan(scans...) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } err = f(rows) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } } if err := rows.Err(); err != nil { - return nil, err + return pgconn.CommandTag{}, err } return rows.CommandTag(), nil diff --git a/conn.go b/conn.go index 8e0707c4..a03871ad 100644 --- a/conn.go +++ b/conn.go @@ -432,7 +432,7 @@ optionLoop: if c.stmtcache != nil { sd, err := c.stmtcache.Get(ctx, sql) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } if c.stmtcache.Mode() == stmtcache.ModeDescribe { @@ -443,7 +443,7 @@ optionLoop: sd, err := c.Prepare(ctx, "", sql) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } return c.execPrepared(ctx, sd, arguments) } @@ -452,7 +452,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i if len(arguments) > 0 { sql, err = c.sanitizeForSimpleQuery(sql, arguments...) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } } @@ -493,7 +493,7 @@ func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, argu func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { err := c.execParamsAndPreparedPrefix(sd, arguments) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read() @@ -504,7 +504,7 @@ func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { err := c.execParamsAndPreparedPrefix(sd, arguments) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() @@ -688,24 +688,24 @@ type QueryFuncRow interface { func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { rows, err := c.Query(ctx, sql, args...) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } defer rows.Close() for rows.Next() { err = rows.Scan(scans...) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } err = f(rows) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } } if err := rows.Err(); err != nil { - return nil, err + return pgconn.CommandTag{}, err } return rows.CommandTag(), nil diff --git a/conn_test.go b/conn_test.go index e35def64..0cbc0040 100644 --- a/conn_test.go +++ b/conn_test.go @@ -188,31 +188,31 @@ func TestExec(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); string(results) != "CREATE TABLE" { + if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } // Accept parameters - if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); string(results) != "INSERT 0 1" { + if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); results.String() != "INSERT 0 1" { t.Errorf("Unexpected results from Exec: %v", results) } - if results := mustExec(t, conn, "drop table foo;"); string(results) != "DROP TABLE" { + if results := mustExec(t, conn, "drop table foo;"); results.String() != "DROP TABLE" { t.Error("Unexpected results from Exec") } // Multiple statements can be executed -- last command tag is returned - if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); string(results) != "DROP TABLE" { + if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); results.String() != "DROP TABLE" { t.Error("Unexpected results from Exec") } // Can execute longer SQL strings than sharedBufferSize - if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); string(results) != "SELECT 1" { + if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); results.String() != "SELECT 1" { t.Errorf("Unexpected results from Exec: %v", results) } // Exec no-op which does not return a command tag - if results := mustExec(t, conn, "--;"); string(results) != "" { + if results := mustExec(t, conn, "--;"); results.String() != "" { t.Errorf("Unexpected results from Exec: %v", results) } }) @@ -260,7 +260,7 @@ func TestExecContextWithoutCancelation(t *testing.T) { if err != nil { t.Fatal(err) } - if string(commandTag) != "CREATE TABLE" { + if commandTag.String() != "CREATE TABLE" { t.Fatalf("Unexpected results from Exec: %v", commandTag) } assert.False(t, pgconn.SafeToRetry(err)) @@ -350,15 +350,15 @@ func TestExecStatementCacheModes(t *testing.T) { commandTag, err := conn.Exec(context.Background(), "select 1") assert.NoError(t, err, tt.name) - assert.Equal(t, "SELECT 1", string(commandTag), tt.name) + assert.Equal(t, "SELECT 1", commandTag.String(), tt.name) commandTag, err = conn.Exec(context.Background(), "select 1 union all select 1") assert.NoError(t, err, tt.name) - assert.Equal(t, "SELECT 2", string(commandTag), tt.name) + assert.Equal(t, "SELECT 2", commandTag.String(), tt.name) commandTag, err = conn.Exec(context.Background(), "select 1") assert.NoError(t, err, tt.name) - assert.Equal(t, "SELECT 1", string(commandTag), tt.name) + assert.Equal(t, "SELECT 1", commandTag.String(), tt.name) ensureConnValid(t, conn) }() @@ -378,7 +378,7 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) { if err != nil { t.Fatal(err) } - if string(commandTag) != "CREATE TABLE" { + if commandTag.String() != "CREATE TABLE" { t.Fatalf("Unexpected results from Exec: %v", commandTag) } @@ -390,7 +390,7 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) { if err != nil { t.Fatal(err) } - if string(commandTag) != "INSERT 0 1" { + if commandTag.String() != "INSERT 0 1" { t.Fatalf("Unexpected results from Exec: %v", commandTag) } @@ -720,12 +720,12 @@ func TestInsertBoolArray(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); string(results) != "CREATE TABLE" { + if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } // Accept parameters - if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); string(results) != "INSERT 0 1" { + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results.String() != "INSERT 0 1" { t.Errorf("Unexpected results from Exec: %v", results) } }) @@ -735,12 +735,12 @@ func TestInsertTimestampArray(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); string(results) != "CREATE TABLE" { + if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } // Accept parameters - if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); string(results) != "INSERT 0 1" { + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results.String() != "INSERT 0 1" { t.Errorf("Unexpected results from Exec: %v", results) } }) diff --git a/pgconn/benchmark_private_test.go b/pgconn/benchmark_private_test.go new file mode 100644 index 00000000..e074c75c --- /dev/null +++ b/pgconn/benchmark_private_test.go @@ -0,0 +1,73 @@ +package pgconn + +import ( + "strings" + "testing" +) + +func BenchmarkCommandTagRowsAffected(b *testing.B) { + benchmarks := []struct { + commandTag string + rowsAffected int64 + }{ + {"UPDATE 1", 1}, + {"UPDATE 123456789", 123456789}, + {"INSERT 0 1", 1}, + {"INSERT 0 123456789", 123456789}, + } + + for _, bm := range benchmarks { + ct := CommandTag{buf: []byte(bm.commandTag)} + b.Run(bm.commandTag, func(b *testing.B) { + var n int64 + for i := 0; i < b.N; i++ { + n = ct.RowsAffected() + } + if n != bm.rowsAffected { + b.Errorf("expected %d got %d", bm.rowsAffected, n) + } + }) + } +} + +func BenchmarkCommandTagTypeFromString(b *testing.B) { + ct := CommandTag{buf: []byte("UPDATE 1")} + + var update bool + for i := 0; i < b.N; i++ { + update = strings.HasPrefix(ct.String(), "UPDATE") + } + if !update { + b.Error("expected update") + } +} + +func BenchmarkCommandTagInsert(b *testing.B) { + benchmarks := []struct { + commandTag string + is bool + }{ + {"INSERT 1", true}, + {"INSERT 1234567890", true}, + {"UPDATE 1", false}, + {"UPDATE 1234567890", false}, + {"DELETE 1", false}, + {"DELETE 1234567890", false}, + {"SELECT 1", false}, + {"SELECT 1234567890", false}, + {"UNKNOWN 1234567890", false}, + } + + for _, bm := range benchmarks { + ct := CommandTag{buf: []byte(bm.commandTag)} + b.Run(bm.commandTag, func(b *testing.B) { + var is bool + for i := 0; i < b.N; i++ { + is = ct.Insert() + } + if is != bm.is { + b.Errorf("expected %v got %v", bm.is, is) + } + }) + } +} diff --git a/pgconn/benchmark_test.go b/pgconn/benchmark_test.go index 088a9bd9..ffa42243 100644 --- a/pgconn/benchmark_test.go +++ b/pgconn/benchmark_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "os" - "strings" "testing" "github.com/jackc/pgx/v5/pgconn" @@ -253,70 +252,3 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { // conn.ChanToSetDeadline().Ignore() // } // } - -func BenchmarkCommandTagRowsAffected(b *testing.B) { - benchmarks := []struct { - commandTag string - rowsAffected int64 - }{ - {"UPDATE 1", 1}, - {"UPDATE 123456789", 123456789}, - {"INSERT 0 1", 1}, - {"INSERT 0 123456789", 123456789}, - } - - for _, bm := range benchmarks { - ct := pgconn.CommandTag(bm.commandTag) - b.Run(bm.commandTag, func(b *testing.B) { - var n int64 - for i := 0; i < b.N; i++ { - n = ct.RowsAffected() - } - if n != bm.rowsAffected { - b.Errorf("expected %d got %d", bm.rowsAffected, n) - } - }) - } -} - -func BenchmarkCommandTagTypeFromString(b *testing.B) { - ct := pgconn.CommandTag("UPDATE 1") - - var update bool - for i := 0; i < b.N; i++ { - update = strings.HasPrefix(ct.String(), "UPDATE") - } - if !update { - b.Error("expected update") - } -} - -func BenchmarkCommandTagInsert(b *testing.B) { - benchmarks := []struct { - commandTag string - is bool - }{ - {"INSERT 1", true}, - {"INSERT 1234567890", true}, - {"UPDATE 1", false}, - {"UPDATE 1234567890", false}, - {"DELETE 1", false}, - {"DELETE 1234567890", false}, - {"SELECT 1", false}, - {"SELECT 1234567890", false}, - {"UNKNOWN 1234567890", false}, - } - - for _, bm := range benchmarks { - ct := pgconn.CommandTag(bm.commandTag) - b.Run(bm.commandTag, func(b *testing.B) { - var is bool - for i := 0; i < b.N; i++ { - is = ct.Insert() - } - if is != bm.is { - b.Errorf("expected %v got %v", bm.is, is) - } - }) - } -} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 843bbef4..16d54f3a 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -685,15 +685,17 @@ func (pgConn *PgConn) ParameterStatus(key string) string { } // CommandTag is the result of an Exec function -type CommandTag []byte +type CommandTag struct { + buf []byte +} // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. func (ct CommandTag) RowsAffected() int64 { // Find last non-digit idx := -1 - for i := len(ct) - 1; i >= 0; i-- { - if ct[i] >= '0' && ct[i] <= '9' { + for i := len(ct.buf) - 1; i >= 0; i-- { + if ct.buf[i] >= '0' && ct.buf[i] <= '9' { idx = i } else { break @@ -705,7 +707,7 @@ func (ct CommandTag) RowsAffected() int64 { } var n int64 - for _, b := range ct[idx:] { + for _, b := range ct.buf[idx:] { n = n*10 + int64(b-'0') } @@ -713,51 +715,51 @@ func (ct CommandTag) RowsAffected() int64 { } func (ct CommandTag) String() string { - return string(ct) + return string(ct.buf) } // Insert is true if the command tag starts with "INSERT". func (ct CommandTag) Insert() bool { - return len(ct) >= 6 && - ct[0] == 'I' && - ct[1] == 'N' && - ct[2] == 'S' && - ct[3] == 'E' && - ct[4] == 'R' && - ct[5] == 'T' + return len(ct.buf) >= 6 && + ct.buf[0] == 'I' && + ct.buf[1] == 'N' && + ct.buf[2] == 'S' && + ct.buf[3] == 'E' && + ct.buf[4] == 'R' && + ct.buf[5] == 'T' } // Update is true if the command tag starts with "UPDATE". func (ct CommandTag) Update() bool { - return len(ct) >= 6 && - ct[0] == 'U' && - ct[1] == 'P' && - ct[2] == 'D' && - ct[3] == 'A' && - ct[4] == 'T' && - ct[5] == 'E' + return len(ct.buf) >= 6 && + ct.buf[0] == 'U' && + ct.buf[1] == 'P' && + ct.buf[2] == 'D' && + ct.buf[3] == 'A' && + ct.buf[4] == 'T' && + ct.buf[5] == 'E' } // Delete is true if the command tag starts with "DELETE". func (ct CommandTag) Delete() bool { - return len(ct) >= 6 && - ct[0] == 'D' && - ct[1] == 'E' && - ct[2] == 'L' && - ct[3] == 'E' && - ct[4] == 'T' && - ct[5] == 'E' + return len(ct.buf) >= 6 && + ct.buf[0] == 'D' && + ct.buf[1] == 'E' && + ct.buf[2] == 'L' && + ct.buf[3] == 'E' && + ct.buf[4] == 'T' && + ct.buf[5] == 'E' } // Select is true if the command tag starts with "SELECT". func (ct CommandTag) Select() bool { - return len(ct) >= 6 && - ct[0] == 'S' && - ct[1] == 'E' && - ct[2] == 'L' && - ct[3] == 'E' && - ct[4] == 'C' && - ct[5] == 'T' + return len(ct.buf) >= 6 && + ct.buf[0] == 'S' && + ct.buf[1] == 'E' && + ct.buf[2] == 'L' && + ct.buf[3] == 'E' && + ct.buf[4] == 'C' && + ct.buf[5] == 'T' } type StatementDescription struct { @@ -1076,13 +1078,13 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by result := &pgConn.resultReader if err := pgConn.lock(); err != nil { - result.concludeCommand(nil, err) + result.concludeCommand(CommandTag{}, err) result.closed = true return result } if len(paramValues) > math.MaxUint16 { - result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.concludeCommand(CommandTag{}, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true pgConn.unlock() return result @@ -1091,7 +1093,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by if ctx != context.Background() { select { case <-ctx.Done(): - result.concludeCommand(nil, newContextAlreadyDoneError(ctx)) + result.concludeCommand(CommandTag{}, newContextAlreadyDoneError(ctx)) result.closed = true pgConn.unlock() return result @@ -1111,7 +1113,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { n, err := pgConn.conn.Write(buf) if err != nil { pgConn.asyncClose() - result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) + result.concludeCommand(CommandTag{}, &writeError{err: err, safeToRetry: n == 0}) pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() @@ -1124,14 +1126,14 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, err + return CommandTag{}, err } if ctx != context.Background() { select { case <-ctx.Done(): pgConn.unlock() - return nil, newContextAlreadyDoneError(ctx) + return CommandTag{}, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -1146,7 +1148,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm if err != nil { pgConn.asyncClose() pgConn.unlock() - return nil, &writeError{err: err, safeToRetry: n == 0} + return CommandTag{}, &writeError{err: err, safeToRetry: n == 0} } // Read results @@ -1156,7 +1158,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1165,13 +1167,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm _, err := w.Write(msg.Data) if err != nil { pgConn.asyncClose() - return nil, err + return CommandTag{}, err } case *pgproto3.ReadyForQuery: pgConn.unlock() return commandTag, pgErr case *pgproto3.CommandComplete: - commandTag = CommandTag(msg.CommandTag) + commandTag = pgConn.makeCommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) } @@ -1184,14 +1186,14 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, err + return CommandTag{}, err } defer pgConn.unlock() if ctx != context.Background() { select { case <-ctx.Done(): - return nil, newContextAlreadyDoneError(ctx) + return CommandTag{}, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -1205,7 +1207,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co n, err := pgConn.conn.Write(buf) if err != nil { pgConn.asyncClose() - return nil, &writeError{err: err, safeToRetry: n == 0} + return CommandTag{}, &writeError{err: err, safeToRetry: n == 0} } // Send copy data @@ -1255,7 +1257,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1279,7 +1281,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.asyncClose() - return nil, err + return CommandTag{}, err } // Read results @@ -1288,14 +1290,14 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { case *pgproto3.ReadyForQuery: return commandTag, pgErr case *pgproto3.CommandComplete: - commandTag = CommandTag(msg.CommandTag) + commandTag = pgConn.makeCommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) } @@ -1368,7 +1370,7 @@ func (mrr *MultiResultReader) NextResult() bool { return true case *pgproto3.CommandComplete: mrr.pgConn.resultReader = ResultReader{ - commandTag: CommandTag(msg.CommandTag), + commandTag: mrr.pgConn.makeCommandTag(msg.CommandTag), commandConcluded: true, closed: true, } @@ -1483,7 +1485,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { for !rr.commandConcluded { _, err := rr.receiveMessage() if err != nil { - return nil, rr.err + return CommandTag{}, rr.err } } @@ -1491,7 +1493,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { for { msg, err := rr.receiveMessage() if err != nil { - return nil, rr.err + return CommandTag{}, rr.err } switch msg := msg.(type) { @@ -1538,7 +1540,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error if err != nil { err = preferContextOverNetTimeoutError(rr.ctx, err) - rr.concludeCommand(nil, err) + rr.concludeCommand(CommandTag{}, err) rr.pgConn.contextWatcher.Unwatch() rr.closed = true if rr.multiResultReader == nil { @@ -1552,11 +1554,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.RowDescription: rr.fieldDescriptions = msg.Fields case *pgproto3.CommandComplete: - rr.concludeCommand(CommandTag(msg.CommandTag), nil) + rr.concludeCommand(rr.pgConn.makeCommandTag(msg.CommandTag), nil) case *pgproto3.EmptyQueryResponse: - rr.concludeCommand(nil, nil) + rr.concludeCommand(CommandTag{}, nil) case *pgproto3.ErrorResponse: - rr.concludeCommand(nil, ErrorResponseToPgError(msg)) + rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg)) } return msg, nil @@ -1659,6 +1661,13 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) { return strings.Replace(s, "'", "''", -1), nil } +// makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory. +func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { + ct := make([]byte, len(buf)) + copy(ct, buf) + return CommandTag{buf: ct} +} + // HijackedConn is the result of hijacking a connection. // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning diff --git a/pgconn/pgconn_private_test.go b/pgconn/pgconn_private_test.go new file mode 100644 index 00000000..4368f717 --- /dev/null +++ b/pgconn/pgconn_private_test.go @@ -0,0 +1,41 @@ +package pgconn + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCommandTag(t *testing.T) { + t.Parallel() + + var tests = []struct { + commandTag CommandTag + rowsAffected int64 + isInsert bool + isUpdate bool + isDelete bool + isSelect bool + }{ + {commandTag: CommandTag{buf: []byte("INSERT 0 5")}, rowsAffected: 5, isInsert: true}, + {commandTag: CommandTag{buf: []byte("UPDATE 0")}, rowsAffected: 0, isUpdate: true}, + {commandTag: CommandTag{buf: []byte("UPDATE 1")}, rowsAffected: 1, isUpdate: true}, + {commandTag: CommandTag{buf: []byte("DELETE 0")}, rowsAffected: 0, isDelete: true}, + {commandTag: CommandTag{buf: []byte("DELETE 1")}, rowsAffected: 1, isDelete: true}, + {commandTag: CommandTag{buf: []byte("DELETE 1234567890")}, rowsAffected: 1234567890, isDelete: true}, + {commandTag: CommandTag{buf: []byte("SELECT 1")}, rowsAffected: 1, isSelect: true}, + {commandTag: CommandTag{buf: []byte("SELECT 99999999999")}, rowsAffected: 99999999999, isSelect: true}, + {commandTag: CommandTag{buf: []byte("CREATE TABLE")}, rowsAffected: 0}, + {commandTag: CommandTag{buf: []byte("ALTER TABLE")}, rowsAffected: 0}, + {commandTag: CommandTag{buf: []byte("DROP TABLE")}, rowsAffected: 0}, + } + + for i, tt := range tests { + ct := tt.commandTag + assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag) + } +} diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index d1ba29d2..4d975f32 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -538,7 +538,7 @@ func TestConnExec(t *testing.T) { assert.Len(t, results, 1) assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) assert.Len(t, results[0].Rows, 1) assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) @@ -579,12 +579,12 @@ func TestConnExecMultipleQueries(t *testing.T) { assert.Len(t, results, 2) assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) assert.Len(t, results[0].Rows, 1) assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) assert.Nil(t, results[1].Err) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Equal(t, "SELECT 1", results[1].CommandTag.String()) assert.Len(t, results[1].Rows, 1) assert.Equal(t, "1", string(results[1].Rows[0][0])) @@ -741,7 +741,7 @@ func TestConnExecParams(t *testing.T) { } assert.Equal(t, 1, rowCount) commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Equal(t, "SELECT 1", commandTag.String()) assert.NoError(t, err) ensureConnValid(t, pgConn) @@ -840,7 +840,7 @@ func TestConnExecParamsCanceled(t *testing.T) { } assert.Equal(t, 0, rowCount) commandTag, err := result.Close() - assert.Equal(t, pgconn.CommandTag(nil), commandTag) + assert.Equal(t, pgconn.CommandTag{}, commandTag) assert.True(t, pgconn.Timeout(err)) assert.ErrorIs(t, err, context.DeadlineExceeded) @@ -880,7 +880,7 @@ func TestConnExecParamsEmptySQL(t *testing.T) { defer closeConn(t, pgConn) result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read() - assert.Nil(t, result.CommandTag) + assert.Equal(t, pgconn.CommandTag{}, result.CommandTag) assert.Len(t, result.Rows, 0) assert.NoError(t, result.Err) @@ -907,7 +907,7 @@ func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) { } assert.Equal(t, 1, rowCount) commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Equal(t, "SELECT 1", commandTag.String()) assert.NoError(t, err) ensureConnValid(t, pgConn) @@ -937,7 +937,7 @@ func TestConnExecPrepared(t *testing.T) { } assert.Equal(t, 1, rowCount) commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Equal(t, "SELECT 1", commandTag.String()) assert.NoError(t, err) ensureConnValid(t, pgConn) @@ -1025,7 +1025,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { } assert.Equal(t, 0, rowCount) commandTag, err := result.Close() - assert.Equal(t, pgconn.CommandTag(nil), commandTag) + assert.Equal(t, pgconn.CommandTag{}, commandTag) assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) select { @@ -1069,7 +1069,7 @@ func TestConnExecPreparedEmptySQL(t *testing.T) { require.NoError(t, err) result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() - assert.Nil(t, result.CommandTag) + assert.Equal(t, pgconn.CommandTag{}, result.CommandTag) assert.Len(t, result.Rows, 0) assert.NoError(t, result.Err) @@ -1097,15 +1097,15 @@ func TestConnExecBatch(t *testing.T) { require.Len(t, results[0].Rows, 1) require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) require.Len(t, results[1].Rows, 1) require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Equal(t, "SELECT 1", results[1].CommandTag.String()) require.Len(t, results[2].Rows, 1) require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) + assert.Equal(t, "SELECT 1", results[2].CommandTag.String()) } func TestConnExecBatchDeferredError(t *testing.T) { @@ -1199,7 +1199,7 @@ func TestConnExecBatchHuge(t *testing.T) { for i := range args { require.Len(t, results[i].Rows, 1) require.Equal(t, args[i], string(results[i].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) + assert.Equal(t, "SELECT 1", results[i].CommandTag.String()) } } @@ -1247,47 +1247,13 @@ func TestConnLocking(t *testing.T) { assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) assert.Len(t, results[0].Rows, 1) assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) ensureConnValid(t, pgConn) } -func TestCommandTag(t *testing.T) { - t.Parallel() - - var tests = []struct { - commandTag pgconn.CommandTag - rowsAffected int64 - isInsert bool - isUpdate bool - isDelete bool - isSelect bool - }{ - {commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5, isInsert: true}, - {commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0, isUpdate: true}, - {commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1, isUpdate: true}, - {commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0, isDelete: true}, - {commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1, isDelete: true}, - {commandTag: pgconn.CommandTag("DELETE 1234567890"), rowsAffected: 1234567890, isDelete: true}, - {commandTag: pgconn.CommandTag("SELECT 1"), rowsAffected: 1, isSelect: true}, - {commandTag: pgconn.CommandTag("SELECT 99999999999"), rowsAffected: 99999999999, isSelect: true}, - {commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0}, - {commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0}, - {commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0}, - } - - for i, tt := range tests { - ct := tt.commandTag - assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag) - assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag) - assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag) - assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag) - assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag) - } -} - func TestConnOnNotice(t *testing.T) { t.Parallel() @@ -1546,7 +1512,7 @@ func TestConnCopyToCanceled(t *testing.T) { defer cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") assert.Error(t, err) - assert.Equal(t, pgconn.CommandTag(nil), res) + assert.Equal(t, pgconn.CommandTag{}, res) assert.True(t, pgConn.IsClosed()) select { @@ -1571,7 +1537,7 @@ func TestConnCopyToPrecanceled(t *testing.T) { require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, pgconn.SafeToRetry(err)) - assert.Equal(t, pgconn.CommandTag(nil), res) + assert.Equal(t, pgconn.CommandTag{}, res) ensureConnValid(t, pgConn) } @@ -1692,7 +1658,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) { require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, pgconn.SafeToRetry(err)) - assert.Equal(t, pgconn.CommandTag(nil), ct) + assert.Equal(t, pgconn.CommandTag{}, ct) ensureConnValid(t, pgConn) } @@ -2014,7 +1980,7 @@ func TestHijackAndConstruct(t *testing.T) { assert.Len(t, results, 1) assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) assert.Len(t, results[0].Rows, 1) assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) diff --git a/pgxpool/batch_results.go b/pgxpool/batch_results.go index 8bec35cb..aa1d609d 100644 --- a/pgxpool/batch_results.go +++ b/pgxpool/batch_results.go @@ -10,7 +10,7 @@ type errBatchResults struct { } func (br errBatchResults) Exec() (pgconn.CommandTag, error) { - return nil, br.err + return pgconn.CommandTag{}, br.err } func (br errBatchResults) Query() (pgx.Rows, error) { @@ -18,7 +18,7 @@ func (br errBatchResults) Query() (pgx.Rows, error) { } func (br errBatchResults) QueryFunc(scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - return nil, br.err + return pgconn.CommandTag{}, br.err } func (br errBatchResults) QueryRow() pgx.Row { diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index c6f3b77b..7b9f9f29 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -27,7 +27,7 @@ type execer interface { func testExec(t *testing.T, db execer) { results, err := db.Exec(context.Background(), "set time zone 'America/Chicago'") require.NoError(t, err) - assert.EqualValues(t, "SET", results) + assert.EqualValues(t, "SET", results.String()) } type queryer interface { diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 41fb4d5b..30d02879 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -470,7 +470,7 @@ func (p *Pool) Stat() *Stat { func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { c, err := p.Acquire(ctx) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } defer c.Release() @@ -527,7 +527,7 @@ func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pg func (p *Pool) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { c, err := p.Acquire(ctx) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } defer c.Release() diff --git a/pgxpool/rows.go b/pgxpool/rows.go index 0c97dc91..f3f24649 100644 --- a/pgxpool/rows.go +++ b/pgxpool/rows.go @@ -12,7 +12,7 @@ type errRows struct { func (errRows) Close() {} func (e errRows) Err() error { return e.err } -func (errRows) CommandTag() pgconn.CommandTag { return nil } +func (errRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} } func (errRows) FieldDescriptions() []pgproto3.FieldDescription { return nil } func (errRows) Next() bool { return false } func (e errRows) Scan(dest ...interface{}) error { return e.err } diff --git a/query_test.go b/query_test.go index c85802b2..2f8975ac 100644 --- a/query_test.go +++ b/query_test.go @@ -45,7 +45,7 @@ func TestConnQueryScan(t *testing.T) { t.Fatalf("conn.Query failed: %v", err) } - assert.Equal(t, "SELECT 10", string(rows.CommandTag())) + assert.Equal(t, "SELECT 10", rows.CommandTag().String()) if rowCount != 10 { t.Error("Select called onDataRow wrong number of times") @@ -79,7 +79,7 @@ func TestConnQueryWithoutResultSetCommandTag(t *testing.T) { assert.NoError(t, err) rows.Close() assert.NoError(t, rows.Err()) - assert.Equal(t, "CREATE TABLE", string(rows.CommandTag())) + assert.Equal(t, "CREATE TABLE", rows.CommandTag().String()) } func TestConnQueryScanWithManyColumns(t *testing.T) { @@ -1139,7 +1139,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes if err != nil { t.Fatal(err) } - if string(commandTag) != "INSERT 0 1" { + if commandTag.String() != "INSERT 0 1" { t.Fatalf("want %s, got %s", "INSERT 0 1", commandTag) } @@ -1976,7 +1976,7 @@ func TestConnQueryFuncAbort(t *testing.T) { }, ) require.EqualError(t, err, "abort") - require.Nil(t, ct) + require.Equal(t, pgconn.CommandTag{}, ct) }) } diff --git a/tx.go b/tx.go index 3ed0ca67..6b85b303 100644 --- a/tx.go +++ b/tx.go @@ -235,7 +235,7 @@ func (tx *dbTx) Commit(ctx context.Context) error { } return err } - if string(commandTag) == "ROLLBACK" { + if commandTag.String() == "ROLLBACK" { return ErrTxCommitRollback } @@ -296,7 +296,7 @@ func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...interface{}) R // QueryFunc delegates to the underlying *Conn. func (tx *dbTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { if tx.closed { - return nil, ErrTxClosed + return pgconn.CommandTag{}, ErrTxClosed } return tx.conn.QueryFunc(ctx, sql, args, scans, f) @@ -380,7 +380,7 @@ func (sp *dbSavepoint) Rollback(ctx context.Context) error { // Exec delegates to the underlying Tx func (sp *dbSavepoint) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { if sp.closed { - return nil, ErrTxClosed + return pgconn.CommandTag{}, ErrTxClosed } return sp.tx.Exec(ctx, sql, arguments...) @@ -415,7 +415,7 @@ func (sp *dbSavepoint) QueryRow(ctx context.Context, sql string, args ...interfa // QueryFunc delegates to the underlying Tx. func (sp *dbSavepoint) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { if sp.closed { - return nil, ErrTxClosed + return pgconn.CommandTag{}, ErrTxClosed } return sp.tx.QueryFunc(ctx, sql, args, scans, f)