diff --git a/batch.go b/batch.go index 2f85b852..c3c9a36b 100644 --- a/batch.go +++ b/batch.go @@ -47,6 +47,8 @@ type batchResults struct { conn *Conn mrr *pgconn.MultiResultReader err error + b *Batch + ix int } // Exec reads the results from the next query in the batch as if the query has been sent with Exec. @@ -55,20 +57,52 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { return nil, br.err } + query, arguments, _ := br.nextQueryAndArgs() + if !br.mrr.NextResult() { err := br.mrr.Close() if err == nil { err = errors.New("no result") } + if br.conn.shouldLog(LogLevelError) { + br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{} { + "sql": query, + "args": logQueryArgs(arguments), + "err": err, + }) + } return nil, err } - return br.mrr.ResultReader().Close() + commandTag, err := br.mrr.ResultReader().Close() + + if err != nil { + if br.conn.shouldLog(LogLevelError) { + br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{ + "sql": query, + "args": logQueryArgs(arguments), + "err": err, + }) + } + } else if br.conn.shouldLog(LogLevelInfo) { + br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]interface{} { + "sql": query, + "args": logQueryArgs(arguments), + "commandTag": commandTag, + }) + } + + return commandTag, err } // Query reads the results from the next query in the batch as if the query has been sent with Query. func (br *batchResults) Query() (Rows, error) { - rows := br.conn.getRows(br.ctx, "batch query", nil) + query, arguments, ok := br.nextQueryAndArgs() + if !ok { + query = "batch query" + } + + rows := br.conn.getRows(br.ctx, query, arguments) if br.err != nil { rows.err = br.err @@ -82,6 +116,15 @@ func (br *batchResults) Query() (Rows, error) { rows.err = errors.New("no result") } rows.closed = true + + if br.conn.shouldLog(LogLevelError) { + br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]interface{} { + "sql": query, + "args": logQueryArgs(arguments), + "err": rows.err, + }) + } + return rows, rows.err } @@ -103,5 +146,31 @@ func (br *batchResults) Close() error { return br.err } + // log any queries that haven't yet been logged by Exec or Query + for { + query, args, ok := br.nextQueryAndArgs() + if !ok { + break + } + + if br.conn.shouldLog(LogLevelInfo) { + br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]interface{} { + "sql": query, + "args": logQueryArgs(args), + }) + } + } + return br.mrr.Close() } + +func (br *batchResults) nextQueryAndArgs() (query string, args []interface{}, ok bool) { + if br.b != nil && br.ix < len(br.b.items) { + bi := br.b.items[br.ix] + query = bi.query + args = bi.arguments + ok = true + br.ix++ + } + return +} diff --git a/batch_test.go b/batch_test.go index a41d2a16..f487c52a 100644 --- a/batch_test.go +++ b/batch_test.go @@ -637,3 +637,91 @@ func testConnSendBatch(t *testing.T, conn *pgx.Conn, queryCount int) { err := br.Close() require.NoError(t, err) } + +func TestLogBatchStatementsOnExec(t *testing.T) { + l1 := &testLogger{} + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.Logger = l1 + + conn := mustConnect(t, config) + defer closeConn(t, conn) + + l1.logs = l1.logs[0:0] // Clear logs written when establishing connection + + batch := &pgx.Batch{} + batch.Queue("create table foo (id bigint)") + batch.Queue("drop table foo") + + br := conn.SendBatch(context.Background(), batch) + + _, err := br.Exec() + if err != nil { + t.Fatalf("Unexpected error creating table: %v", err) + } + + _, err = br.Exec() + if err != nil { + t.Fatalf("Unexpected error dropping table: %v", err) + } + + if len(l1.logs) != 2 { + t.Fatalf("Expected two log entries but got %d", len(l1.logs)) + } + + if l1.logs[0].msg != "BatchResult.Exec" { + t.Errorf("Expected first log message to be 'BatchResult.Exec' but was '%s", l1.logs[0].msg) + } + + if l1.logs[0].data["sql"] != "create table foo (id bigint)" { + t.Errorf("Expected the first query to be 'create table foo (id bigint)' but was '%s'", l1.logs[0].data["sql"]) + } + + if l1.logs[1].msg != "BatchResult.Exec" { + t.Errorf("Expected second log message to be 'BatchResult.Exec' but was '%s", l1.logs[1].msg) + } + + if l1.logs[1].data["sql"] != "drop table foo" { + t.Errorf("Expected the second query to be 'drop table foo' but was '%s'", l1.logs[1].data["sql"]) + } +} + +func TestLogBatchStatementsOnBatchResultClose(t *testing.T) { + l1 := &testLogger{} + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.Logger = l1 + + conn := mustConnect(t, config) + defer closeConn(t, conn) + + l1.logs = l1.logs[0:0] // Clear logs written when establishing connection + + batch := &pgx.Batch{} + batch.Queue("select generate_series(1,$1)", 100) + batch.Queue("select 1 = 1;") + + br := conn.SendBatch(context.Background(), batch) + + if err := br.Close(); err != nil { + t.Fatalf("Unexpected batch error: %v", err) + } + + if len(l1.logs) != 2 { + t.Fatalf("Expected 2 log statements but found %d", len(l1.logs)) + } + + if l1.logs[0].msg != "BatchResult.Close" { + t.Errorf("Expected first log statement to be 'BatchResult.Close' but was %s", l1.logs[0].msg) + } + + if l1.logs[0].data["sql"] != "select generate_series(1,$1)" { + t.Errorf("Expected first query to be 'select generate_series(1,$1)' but was '%s'", l1.logs[0].data["sql"]) + } + + if l1.logs[1].msg != "BatchResult.Close" { + t.Errorf("Expected second log statement to be 'BatchResult.Close' but was %s", l1.logs[1].msg) + } + + if l1.logs[1].data["sql"] != "select 1 = 1;" { + t.Errorf("Expected second query to be 'select 1 = 1;' but was '%s'", l1.logs[1].data["sql"]) + } +} diff --git a/conn.go b/conn.go index 014e1e2d..0acd5c96 100644 --- a/conn.go +++ b/conn.go @@ -770,6 +770,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { ctx: ctx, conn: c, mrr: mrr, + b: b, + ix: 0, } }