diff --git a/internal/testtools/assert.go b/internal/testtools/assert.go index 8763f1a..e10c793 100644 --- a/internal/testtools/assert.go +++ b/internal/testtools/assert.go @@ -47,6 +47,19 @@ func AssertErrContains(t *testing.T, err error, substrs ...string) { } } +// AssertContains will check if the input text contains +// all the substrs specified on the substrs argument or +// fail with an appropriate error message. +func AssertContains(t *testing.T, str string, substrs ...string) { + for _, substr := range substrs { + require.True(t, + strings.Contains(str, substr), + "missing substring '%s' in error message: '%s'", + substr, str, + ) + } +} + // AssertApproxDuration checks if the durations v1 and v2 are close up to the tolerance specified. // The format and args slice can be used for generating an appropriate error message if they are not. func AssertApproxDuration(t *testing.T, tolerance time.Duration, v1, v2 time.Duration, format string, args ...interface{}) { diff --git a/internal_mocks.go b/internal_mocks.go index 013e3ac..7ef9b9e 100644 --- a/internal_mocks.go +++ b/internal_mocks.go @@ -62,6 +62,26 @@ func (m mockRows) Columns() ([]string, error) { return m.ColumnsFn() } +// mockResult mocks the ksql.Result interface +type mockResult struct { + LastInsertIdFn func() (int64, error) + RowsAffectedFn func() (int64, error) +} + +func (m mockResult) LastInsertId() (int64, error) { + if m.LastInsertIdFn != nil { + return m.LastInsertIdFn() + } + return 0, nil +} + +func (m mockResult) RowsAffected() (int64, error) { + if m.RowsAffectedFn != nil { + return m.RowsAffectedFn() + } + return 0, nil +} + // mockTx mocks the ksql.Tx interface type mockTx struct { DBAdapter diff --git a/ksql.go b/ksql.go index 6181157..45b6a10 100644 --- a/ksql.go +++ b/ksql.go @@ -148,7 +148,7 @@ func (c DB) Query( records interface{}, query string, params ...interface{}, -) error { +) (err error) { slicePtr := reflect.ValueOf(records) slicePtrType := slicePtr.Type() if slicePtrType.Kind() != reflect.Ptr { @@ -187,6 +187,8 @@ func (c DB) Query( query = selectPrefix + query } + defer ctxLog(ctx, query, params, &err) + rows, err := c.db.QueryContext(ctx, query, params...) if err != nil { return fmt.Errorf("error running query: %w", err) @@ -242,7 +244,7 @@ func (c DB) QueryOne( record interface{}, query string, params ...interface{}, -) error { +) (err error) { v := reflect.ValueOf(record) t := v.Type() if t.Kind() != reflect.Ptr { @@ -277,6 +279,8 @@ func (c DB) QueryOne( query = selectPrefix + query } + defer ctxLog(ctx, query, params, &err) + rows, err := c.db.QueryContext(ctx, query, params...) if err != nil { return fmt.Errorf("error running query: %w", err) @@ -423,10 +427,10 @@ func (c DB) Insert( ctx context.Context, table Table, record interface{}, -) error { +) (err error) { v := reflect.ValueOf(record) t := v.Type() - if err := assertStructPtr(t); err != nil { + if err = assertStructPtr(t); err != nil { return fmt.Errorf( "KSQL: expected record to be a pointer to struct, but got: %T", record, @@ -451,6 +455,8 @@ func (c DB) Insert( return err } + defer ctxLog(ctx, query, params, &err) + switch table.insertMethodFor(c.dialect) { case sqldialect.InsertWithReturning, sqldialect.InsertWithOutput: err = c.insertReturningIDs(ctx, query, params, scanValues, table.idColumns) @@ -581,7 +587,7 @@ func (c DB) Delete( ctx context.Context, table Table, idOrRecord interface{}, -) error { +) (err error) { if err := table.validate(); err != nil { return fmt.Errorf("can't delete from ksql.Table: %w", err) } @@ -595,6 +601,8 @@ func (c DB) Delete( var params []interface{} query, params = buildDeleteQuery(c.dialect, table, idMap) + defer ctxLog(ctx, query, params, &err) + result, err := c.db.ExecContext(ctx, query, params...) if err != nil { return err @@ -655,7 +663,7 @@ func (c DB) Patch( ctx context.Context, table Table, record interface{}, -) error { +) (err error) { v := reflect.ValueOf(record) t := v.Type() tStruct := t @@ -680,6 +688,8 @@ func (c DB) Patch( return err } + defer ctxLog(ctx, query, params, &err) + result, err := c.db.ExecContext(ctx, query, params...) if err != nil { return err diff --git a/ksql_test.go b/ksql_test.go index ba997ef..61405a8 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -1,6 +1,8 @@ package ksql import ( + "context" + "errors" "fmt" "io" "testing" @@ -98,3 +100,243 @@ func TestClose(t *testing.T) { tt.AssertErrContains(t, err, "fakeCloseErrMsg") }) } + +func TestInjectLogger(t *testing.T) { + ctx := context.Background() + + tests := []struct { + desc string + methodCall func(ctx context.Context, db Provider) error + queryErr error + + expectLoggedQueryToContain []string + expectLoggedParams map[interface{}]bool + expectLoggedErrToContain []string + }{ + { + desc: "should work for the Query function", + methodCall: func(ctx context.Context, db Provider) error { + var row []struct { + Count int `ksql:"count"` + } + return db.Query(ctx, &row, `FROM users WHERE type = $1 AND age < $2`, "fakeType", 42) + }, + + expectLoggedQueryToContain: []string{"SELECT", "count", "type = $1"}, + expectLoggedParams: map[interface{}]bool{"fakeType": true, 42: true}, + }, + { + desc: "should work for the Query function when an error is returned", + methodCall: func(ctx context.Context, db Provider) error { + var row []struct { + Count int `ksql:"count"` + } + return db.Query(ctx, &row, `FROM users WHERE type = $1 AND age < $2`, "fakeType", 42) + }, + queryErr: errors.New("fakeErrMsg"), + + expectLoggedQueryToContain: []string{"SELECT", "count", "type = $1"}, + expectLoggedParams: map[interface{}]bool{"fakeType": true, 42: true}, + expectLoggedErrToContain: []string{"fakeErrMsg"}, + }, + { + desc: "should work for the QueryOne function", + methodCall: func(ctx context.Context, db Provider) error { + var row struct { + Count int `ksql:"count"` + } + return db.QueryOne(ctx, &row, `FROM users WHERE type = $1 AND age < $2`, "fakeType", 42) + }, + + expectLoggedQueryToContain: []string{"SELECT", "count", "type = $1"}, + expectLoggedParams: map[interface{}]bool{"fakeType": true, 42: true}, + }, + { + desc: "should work for the QueryOne function when an error is returned", + methodCall: func(ctx context.Context, db Provider) error { + var row struct { + Count int `ksql:"count"` + } + return db.QueryOne(ctx, &row, `FROM users WHERE type = $1 AND age < $2`, "fakeType", 42) + }, + queryErr: errors.New("fakeErrMsg"), + + expectLoggedQueryToContain: []string{"SELECT", "count", "type = $1"}, + expectLoggedParams: map[interface{}]bool{"fakeType": true, 42: true}, + expectLoggedErrToContain: []string{"fakeErrMsg"}, + }, + { + desc: "should work for the Insert function", + methodCall: func(ctx context.Context, db Provider) error { + fakeRecord := struct { + ID int `ksql:"id"` + Count int `ksql:"count"` + }{ + ID: 42, + Count: 43, + } + return db.Insert(ctx, NewTable("fakeTable"), &fakeRecord) + }, + + expectLoggedQueryToContain: []string{"INSERT", "fakeTable", `"id"`}, + expectLoggedParams: map[interface{}]bool{42: true, 43: true}, + }, + { + desc: "should work for the Insert function when an error is returned", + methodCall: func(ctx context.Context, db Provider) error { + fakeRecord := struct { + ID int `ksql:"id"` + Count int `ksql:"count"` + }{ + ID: 42, + Count: 43, + } + return db.Insert(ctx, NewTable("fakeTable"), &fakeRecord) + }, + queryErr: errors.New("fakeErrMsg"), + + expectLoggedQueryToContain: []string{"INSERT", "fakeTable", `"id"`}, + expectLoggedParams: map[interface{}]bool{42: true, 43: true}, + expectLoggedErrToContain: []string{"fakeErrMsg"}, + }, + { + desc: "should work for the Patch function", + methodCall: func(ctx context.Context, db Provider) error { + fakeRecord := struct { + ID int `ksql:"id"` + Count int `ksql:"count"` + }{ + ID: 42, + Count: 43, + } + return db.Patch(ctx, NewTable("fakeTable"), &fakeRecord) + }, + + expectLoggedQueryToContain: []string{"UPDATE", "fakeTable", `"id"`}, + expectLoggedParams: map[interface{}]bool{42: true, 43: true}, + }, + { + desc: "should work for the Patch function when an error is returned", + methodCall: func(ctx context.Context, db Provider) error { + fakeRecord := struct { + ID int `ksql:"id"` + Count int `ksql:"count"` + }{ + ID: 42, + Count: 43, + } + return db.Patch(ctx, NewTable("fakeTable"), &fakeRecord) + }, + queryErr: errors.New("fakeErrMsg"), + + expectLoggedQueryToContain: []string{"UPDATE", "fakeTable", `"id"`}, + expectLoggedParams: map[interface{}]bool{42: true, 43: true}, + expectLoggedErrToContain: []string{"fakeErrMsg"}, + }, + { + desc: "should work for the Delete function", + methodCall: func(ctx context.Context, db Provider) error { + fakeRecord := struct { + ID int `ksql:"id"` + Count int `ksql:"count"` + }{ + ID: 42, + Count: 43, + } + return db.Delete(ctx, NewTable("fakeTable"), &fakeRecord) + }, + + expectLoggedQueryToContain: []string{"DELETE", "fakeTable", `"id"`}, + expectLoggedParams: map[interface{}]bool{42: true}, + }, + { + desc: "should work for the Delete function when an error is returned", + methodCall: func(ctx context.Context, db Provider) error { + fakeRecord := struct { + ID int `ksql:"id"` + Count int `ksql:"count"` + }{ + ID: 42, + Count: 43, + } + return db.Delete(ctx, NewTable("fakeTable"), &fakeRecord) + }, + queryErr: errors.New("fakeErrMsg"), + + expectLoggedQueryToContain: []string{"DELETE", "fakeTable", `"id"`}, + expectLoggedParams: map[interface{}]bool{42: true}, + expectLoggedErrToContain: []string{"fakeErrMsg"}, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + var inputQuery string + var inputParams []interface{} + numRows := 1 + c := DB{ + dialect: sqldialect.SupportedDialects["postgres"], + db: mockDBAdapter{ + QueryContextFn: func(ctx context.Context, query string, params ...interface{}) (Rows, error) { + inputQuery = query + inputParams = params + + return mockRows{ + ScanFn: func(args ...interface{}) error { + return nil + }, + // Make sure this mock will return a single row + // for the purposes of this test: + NextFn: func() bool { + numRows-- + return numRows >= 0 + }, + ColumnsFn: func() ([]string, error) { return []string{"count"}, nil }, + }, test.queryErr + }, + ExecContextFn: func(ctx context.Context, query string, params ...interface{}) (Result, error) { + inputQuery = query + inputParams = params + + return mockResult{ + // Make sure this mock will return a single row + // for the purposes of this test: + RowsAffectedFn: func() (int64, error) { + return 1, nil + }, + }, test.queryErr + }, + }, + } + + var loggedQuery string + var loggedParams []interface{} + var loggedErr error + ctx := InjectLogger(ctx, func(ctx context.Context, values LogValues) { + loggedQuery = values.Query + loggedParams = values.Params + loggedErr = values.Err + }) + + err := test.methodCall(ctx, c) + if test.expectLoggedErrToContain != nil { + tt.AssertErrContains(t, err, test.expectLoggedErrToContain...) + tt.AssertErrContains(t, loggedErr, test.expectLoggedErrToContain...) + } else { + tt.AssertNoErr(t, err) + tt.AssertEqual(t, loggedErr, nil) + } + + tt.AssertEqual(t, loggedQuery, inputQuery) + tt.AssertEqual(t, loggedParams, inputParams) + + tt.AssertContains(t, loggedQuery, test.expectLoggedQueryToContain...) + + paramsMap := map[interface{}]bool{} + for _, param := range loggedParams { + paramsMap[param] = true + } + tt.AssertEqual(t, paramsMap, test.expectLoggedParams) + }) + } +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..0df561c --- /dev/null +++ b/logger.go @@ -0,0 +1,108 @@ +package ksql + +import ( + "context" + "encoding/json" + "fmt" +) + +// This variable is only used during tests: +var logPrinter = fmt.Println + +var _ LoggerFn = ErrorLogger + +// ErrorLogger is a builtin logger that can be passed to +// ksql.InjectLogger() to only log when an error occurs. +// +// Note: Only errors that happen after KSQL sends the +// query to the backend adapter will be logged. +// +// Validation errors will just return an error as usual. +func ErrorLogger(ctx context.Context, values LogValues) { + if values.Err == nil { + return + } + + Logger(ctx, values) +} + +var _ LoggerFn = Logger + +// Logger is a builtin logger that can be passed to +// ksql.InjectLogger() to log every query and query errors. +// +// Note: Only errors that happen after KSQL sends the +// query to the backend adapter will be logged. +// +// Validation errors will just return an error as usual. +func Logger(ctx context.Context, values LogValues) { + m := map[string]interface{}{ + "query": values.Query, + "params": values.Params, + } + + if values.Err != nil { + m["error"] = values.Err.Error() + } + + b, _ := json.Marshal(m) + logPrinter(string(b)) +} + +type loggerKey struct{} + +// LogValues is the argument type of ksql.LoggerFn which contains +// the data available for logging whenever a query is executed. +type LogValues struct { + Query string + Params []interface{} + Err error +} + +// LoggerFn is a the type of function received as +// argument of the ksql.InjectLogger function. +type LoggerFn func(ctx context.Context, values LogValues) + +type loggerFn func(ctx context.Context, query string, params []interface{}, err error) + +// InjectLogger is a debugging tool that allows the user to force +// KSQL to log the query, query params and error response whenever +// a query is executed. +// +// Example Usage: +// +// ctx = ksql.InjectLogger(ctx, ksql.Logger) +// +// var user User +// db.Insert(ctx, usersTable, &user) +// +// user.Name = "NewName" +// db.Patch(ctx, usersTable, &user) +// +// var users []User +// db.Query(ctx, &users, someQuery, someParams...) +// db.QueryOne(ctx, &user, someQuery, someParams...) +// +// db.Delete(ctx, usersTable, user.ID) +// +func InjectLogger( + ctx context.Context, + logFn LoggerFn, +) context.Context { + return context.WithValue(ctx, loggerKey{}, loggerFn(func(ctx context.Context, query string, params []interface{}, err error) { + logFn(ctx, LogValues{ + Query: query, + Params: params, + Err: err, + }) + })) +} + +func ctxLog(ctx context.Context, query string, params []interface{}, err *error) { + l := ctx.Value(loggerKey{}) + if l == nil { + return + } + + l.(loggerFn)(ctx, query, params, *err) +} diff --git a/logger_test.go b/logger_test.go new file mode 100644 index 0000000..6ed26e6 --- /dev/null +++ b/logger_test.go @@ -0,0 +1,107 @@ +package ksql + +import ( + "context" + "errors" + "fmt" + "testing" + + tt "github.com/vingarcia/ksql/internal/testtools" +) + +func TestCtxLog(t *testing.T) { + ctx := context.Background() + + defer func() { + logPrinter = fmt.Println + }() + + t.Run("should not log anything nor panic when the logger is not injected", func(t *testing.T) { + var printedArgs []interface{} + logPrinter = func(args ...interface{}) (n int, err error) { + printedArgs = args + return 0, nil + } + + panicPayload := tt.PanicHandler(func() { + ctxLog(ctx, "fakeQuery", []interface{}{}, nil) + }) + tt.AssertEqual(t, panicPayload, nil) + tt.AssertEqual(t, printedArgs, []interface{}(nil)) + }) + +} + +func TestBuiltinLoggers(t *testing.T) { + ctx := context.Background() + + defer func() { + logPrinter = fmt.Println + }() + + t.Run("Logger", func(t *testing.T) { + t.Run("with no errors", func(t *testing.T) { + var printedArgs []interface{} + logPrinter = func(args ...interface{}) (n int, err error) { + printedArgs = args + return 0, nil + } + + Logger(ctx, LogValues{ + Query: "FakeQuery", + Params: []interface{}{"FakeParam"}, + }) + + tt.AssertContains(t, fmt.Sprint(printedArgs...), "FakeQuery", "FakeParam") + }) + + t.Run("with errors", func(t *testing.T) { + var printedArgs []interface{} + logPrinter = func(args ...interface{}) (n int, err error) { + printedArgs = args + return 0, nil + } + + Logger(ctx, LogValues{ + Query: "FakeQuery", + Params: []interface{}{"FakeParam"}, + Err: errors.New("fakeErrMsg"), + }) + + tt.AssertContains(t, fmt.Sprint(printedArgs...), "FakeQuery", "FakeParam", "fakeErrMsg") + }) + }) + + t.Run("ErrorsLogger", func(t *testing.T) { + t.Run("with no errors", func(t *testing.T) { + var printedArgs []interface{} + logPrinter = func(args ...interface{}) (n int, err error) { + printedArgs = args + return 0, nil + } + + ErrorLogger(ctx, LogValues{ + Query: "FakeQuery", + Params: []interface{}{"FakeParam"}, + }) + + tt.AssertEqual(t, printedArgs, []interface{}(nil)) + }) + + t.Run("with errors", func(t *testing.T) { + var printedArgs []interface{} + logPrinter = func(args ...interface{}) (n int, err error) { + printedArgs = args + return 0, nil + } + + ErrorLogger(ctx, LogValues{ + Query: "FakeQuery", + Params: []interface{}{"FakeParam"}, + Err: errors.New("fakeErrMsg"), + }) + + tt.AssertContains(t, fmt.Sprint(printedArgs...), "FakeQuery", "FakeParam", "fakeErrMsg") + }) + }) +}