From a0028cbd0de2eb3c332aefff381bc7a852abb62c Mon Sep 17 00:00:00 2001 From: Matt Schultz Date: Tue, 16 Mar 2021 14:59:47 -0500 Subject: [PATCH] Handle SendBatch calls on closed transactions with null connections. This was previously panicking due to a null pointer exception as exposed in the provided unit test. --- batch.go | 30 ++++++++++++++---------------- tx_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/batch.go b/batch.go index 9f787f99..f412e6f1 100644 --- a/batch.go +++ b/batch.go @@ -70,10 +70,10 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { err = errors.New("no result") } if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{} { - "sql": query, + br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{ + "sql": query, "args": logQueryArgs(arguments), - "err": err, + "err": err, }) } return nil, err @@ -90,9 +90,9 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { }) } } else if br.conn.shouldLog(LogLevelInfo) { - br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]interface{} { - "sql": query, - "args": logQueryArgs(arguments), + br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]interface{}{ + "sql": query, + "args": logQueryArgs(arguments), "commandTag": commandTag, }) } @@ -107,14 +107,12 @@ func (br *batchResults) Query() (Rows, error) { query = "batch query" } - rows := br.conn.getRows(br.ctx, query, arguments) - if br.err != nil { - rows.err = br.err - rows.closed = true - return rows, br.err + return &connRows{err: br.err, closed: true}, br.err } + rows := br.conn.getRows(br.ctx, query, arguments) + if !br.mrr.NextResult() { rows.err = br.mrr.Close() if rows.err == nil { @@ -123,10 +121,10 @@ func (br *batchResults) Query() (Rows, error) { rows.closed = true if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]interface{} { - "sql": query, + br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]interface{}{ + "sql": query, "args": logQueryArgs(arguments), - "err": rows.err, + "err": rows.err, }) } @@ -159,8 +157,8 @@ func (br *batchResults) Close() error { } if br.conn.shouldLog(LogLevelInfo) { - br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]interface{} { - "sql": query, + br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]interface{}{ + "sql": query, "args": logQueryArgs(args), }) } diff --git a/tx_test.go b/tx_test.go index 7634ea59..e9830d32 100644 --- a/tx_test.go +++ b/tx_test.go @@ -586,3 +586,36 @@ func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { require.NoError(t, err) require.EqualValues(t, 2, n) } + +func TestTxSendBatchClosed(t *testing.T) { + t.Parallel() + + db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, db) + + tx, err := db.Begin(context.Background()) + require.NoError(t, err) + defer tx.Rollback(context.Background()) + + err = tx.Commit(context.Background()) + require.NoError(t, err) + + batch := &pgx.Batch{} + batch.Queue("select 1") + batch.Queue("select 2") + batch.Queue("select 3") + + br := tx.SendBatch(context.Background(), batch) + defer br.Close() + + var n int + + _, err = br.Exec() + require.Error(t, err) + + err = br.QueryRow().Scan(&n) + require.Error(t, err) + + _, err = br.Query() + require.Error(t, err) +}