diff --git a/CHANGELOG.md b/CHANGELOG.md index bd4001dd..b627ddda 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ * Types now have Valid boolean field instead of Status byte. This matches database/sql pattern. * Extracted integrations with github.com/shopspring/decimal and github.com/gofrs/uuid to https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. +# 4.15.0 (February 7, 2022) + +* Upgrade to pgconn v1.11.0 +* Upgrade to pgtype v1.10.0 +* Upgrade puddle to v1.2.1 +* Make BatchResults.Close safe to be called multiple times # 4.14.1 (November 28, 2021) diff --git a/batch.go b/batch.go index 18ee8339..caa5a02f 100644 --- a/batch.go +++ b/batch.go @@ -3,6 +3,7 @@ package pgx import ( "context" "errors" + "fmt" "github.com/jackc/pgx/v5/pgconn" ) @@ -46,17 +47,18 @@ type BatchResults interface { // Close closes the batch operation. This must be called before the underlying connection can be used again. Any error // that occurred during a batch operation may have made it impossible to resyncronize the connection with the server. - // In this case the underlying connection will have been closed. + // In this case the underlying connection will have been closed. Close is safe to call multiple times. Close() error } type batchResults struct { - ctx context.Context - conn *Conn - mrr *pgconn.MultiResultReader - err error - b *Batch - ix int + ctx context.Context + conn *Conn + mrr *pgconn.MultiResultReader + err error + b *Batch + ix int + closed bool } // Exec reads the results from the next query in the batch as if the query has been sent with Exec. @@ -64,6 +66,9 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { if br.err != nil { return nil, br.err } + if br.closed { + return nil, fmt.Errorf("batch already closed") + } query, arguments, _ := br.nextQueryAndArgs() @@ -114,6 +119,11 @@ func (br *batchResults) Query() (Rows, error) { return &connRows{err: br.err, closed: true}, br.err } + if br.closed { + alreadyClosedErr := fmt.Errorf("batch already closed") + return &connRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr + } + rows := br.conn.getRows(br.ctx, query, arguments) if !br.mrr.NextResult() { @@ -140,6 +150,10 @@ func (br *batchResults) Query() (Rows, error) { // QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { + if br.closed { + return nil, fmt.Errorf("batch already closed") + } + rows, err := br.Query() if err != nil { return nil, err @@ -179,6 +193,11 @@ func (br *batchResults) Close() error { return br.err } + if br.closed { + return nil + } + br.closed = true + // log any queries that haven't yet been logged by Exec or Query for { query, args, ok := br.nextQueryAndArgs() diff --git a/batch_test.go b/batch_test.go index dc57b379..32901830 100644 --- a/batch_test.go +++ b/batch_test.go @@ -2,6 +2,7 @@ package pgx_test import ( "context" + "errors" "os" "testing" @@ -33,6 +34,7 @@ func TestConnSendBatch(t *testing.T) { batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3) batch.Queue("select id, description, amount from ledger order by id") batch.Queue("select id, description, amount from ledger order by id") + batch.Queue("select * from ledger where false") batch.Queue("select sum(amount) from ledger") br := conn.SendBatch(context.Background(), batch) @@ -127,6 +129,11 @@ func TestConnSendBatch(t *testing.T) { t.Error(err) } + err = br.QueryRow().Scan(&id, &description, &amount) + if !errors.Is(err, pgx.ErrNoRows) { + t.Errorf("expected pgx.ErrNoRows but got: %v", err) + } + err = br.QueryRow().Scan(&amount) if err != nil { t.Error(err) diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 20586b81..427e0ea9 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -979,3 +979,55 @@ func TestCreateMinPoolReturnsFirstError(t *testing.T) { require.True(t, connectAttempts >= 5, "Expected %d got %d", 5, connectAttempts) require.ErrorIs(t, err, mockErr) } + +func TestPoolSendBatchBatchCloseTwice(t *testing.T) { + t.Parallel() + + pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + errChan := make(chan error) + testCount := 5000 + + for i := 0; i < testCount; i++ { + go func() { + batch := &pgx.Batch{} + batch.Queue("select 1") + batch.Queue("select 2") + + br := pool.SendBatch(context.Background(), batch) + defer br.Close() + + var err error + var n int32 + err = br.QueryRow().Scan(&n) + if err != nil { + errChan <- err + return + } + if n != 1 { + errChan <- fmt.Errorf("expected 1 got %v", n) + return + } + + err = br.QueryRow().Scan(&n) + if err != nil { + errChan <- err + return + } + if n != 2 { + errChan <- fmt.Errorf("expected 2 got %v", n) + return + } + + err = br.Close() + errChan <- err + }() + } + + for i := 0; i < testCount; i++ { + err := <-errChan + assert.NoError(t, err) + } +} diff --git a/rows.go b/rows.go index 06e0a933..aa5310fe 100644 --- a/rows.go +++ b/rows.go @@ -42,10 +42,13 @@ type Rows interface { // Scan reads the values from the current row into dest values positionally. // dest can include pointers to core types, values implementing the Scanner - // interface, and nil. nil will skip the value entirely. + // interface, and nil. nil will skip the value entirely. It is an error to + // call Scan without first calling Next() and checking that it returned true. Scan(dest ...interface{}) error - // Values returns the decoded row values. + // Values returns the decoded row values. As with Scan(), it is an error to + // call Values without first calling Next() and checking that it returned + // true. Values() ([]interface{}, error) // RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid until the next Next