Handle SendBatch calls on closed transactions with null connections. This was previously panicking due to a null pointer exception as exposed in the provided unit test.

pull/975/head
Matt Schultz 2021-03-16 14:59:47 -05:00 committed by Jack Christensen
parent 495d482f20
commit a0028cbd0d
2 changed files with 47 additions and 16 deletions

View File

@ -70,10 +70,10 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {
err = errors.New("no result")
}
if br.conn.shouldLog(LogLevelError) {
br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{} {
"sql": query,
br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{
"sql": query,
"args": logQueryArgs(arguments),
"err": err,
"err": err,
})
}
return nil, err
@ -90,9 +90,9 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {
})
}
} else if br.conn.shouldLog(LogLevelInfo) {
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]interface{} {
"sql": query,
"args": logQueryArgs(arguments),
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]interface{}{
"sql": query,
"args": logQueryArgs(arguments),
"commandTag": commandTag,
})
}
@ -107,14 +107,12 @@ func (br *batchResults) Query() (Rows, error) {
query = "batch query"
}
rows := br.conn.getRows(br.ctx, query, arguments)
if br.err != nil {
rows.err = br.err
rows.closed = true
return rows, br.err
return &connRows{err: br.err, closed: true}, br.err
}
rows := br.conn.getRows(br.ctx, query, arguments)
if !br.mrr.NextResult() {
rows.err = br.mrr.Close()
if rows.err == nil {
@ -123,10 +121,10 @@ func (br *batchResults) Query() (Rows, error) {
rows.closed = true
if br.conn.shouldLog(LogLevelError) {
br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]interface{} {
"sql": query,
br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]interface{}{
"sql": query,
"args": logQueryArgs(arguments),
"err": rows.err,
"err": rows.err,
})
}
@ -159,8 +157,8 @@ func (br *batchResults) Close() error {
}
if br.conn.shouldLog(LogLevelInfo) {
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]interface{} {
"sql": query,
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]interface{}{
"sql": query,
"args": logQueryArgs(args),
})
}

View File

@ -586,3 +586,36 @@ func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
require.NoError(t, err)
require.EqualValues(t, 2, n)
}
func TestTxSendBatchClosed(t *testing.T) {
t.Parallel()
db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, db)
tx, err := db.Begin(context.Background())
require.NoError(t, err)
defer tx.Rollback(context.Background())
err = tx.Commit(context.Background())
require.NoError(t, err)
batch := &pgx.Batch{}
batch.Queue("select 1")
batch.Queue("select 2")
batch.Queue("select 3")
br := tx.SendBatch(context.Background(), batch)
defer br.Close()
var n int
_, err = br.Exec()
require.Error(t, err)
err = br.QueryRow().Scan(&n)
require.Error(t, err)
_, err = br.Query()
require.Error(t, err)
}