diff --git a/logger.go b/logger.go index 805ddc9..2e54e3d 100644 --- a/logger.go +++ b/logger.go @@ -2,8 +2,39 @@ package ksql import ( "context" + "encoding/json" + "fmt" ) +// This variable is only used during tests: +var logPrinter = fmt.Println + +var _ LoggerFn = ErrorsLogger + +func ErrorsLogger(ctx context.Context, values LogValues) { + if values.Err == nil { + return + } + + Logger(ctx, values) +} + +var _ LoggerFn = Logger + +func Logger(ctx context.Context, values LogValues) { + m := map[string]interface{}{ + "query": values.Query, + "params": values.Params, + } + + if values.Err != nil { + m["error"] = values.Err.Error() + } + + b, _ := json.Marshal(m) + logPrinter(string(b)) +} + type loggerKey struct{} type LogValues struct { @@ -12,11 +43,12 @@ type LogValues struct { Err error } +type LoggerFn func(ctx context.Context, values LogValues) type loggerFn func(ctx context.Context, query string, params []interface{}, err error) func InjectLogger( ctx context.Context, - logFn func(ctx context.Context, values LogValues), + logFn LoggerFn, ) context.Context { return context.WithValue(ctx, loggerKey{}, loggerFn(func(ctx context.Context, query string, params []interface{}, err error) { logFn(ctx, LogValues{ diff --git a/logger_test.go b/logger_test.go index 3048ca3..49d8704 100644 --- a/logger_test.go +++ b/logger_test.go @@ -2,6 +2,8 @@ package ksql import ( "context" + "errors" + "fmt" "testing" tt "github.com/vingarcia/ksql/internal/testtools" @@ -17,3 +19,77 @@ func TestCtxLog(t *testing.T) { tt.AssertEqual(t, panicPayload, nil) }) } + +func TestBuiltinLoggers(t *testing.T) { + ctx := context.Background() + + defer func() { + logPrinter = fmt.Println + }() + + t.Run("Logger", func(t *testing.T) { + t.Run("with no errors", func(t *testing.T) { + var printedArgs []interface{} + logPrinter = func(args ...interface{}) (n int, err error) { + printedArgs = args + return 0, nil + } + + Logger(ctx, LogValues{ + Query: "FakeQuery", + Params: []interface{}{"FakeParam"}, + }) + + tt.AssertContains(t, fmt.Sprint(printedArgs...), "FakeQuery", "FakeParam") + }) + + t.Run("with errors", func(t *testing.T) { + var printedArgs []interface{} + logPrinter = func(args ...interface{}) (n int, err error) { + printedArgs = args + return 0, nil + } + + Logger(ctx, LogValues{ + Query: "FakeQuery", + Params: []interface{}{"FakeParam"}, + Err: errors.New("fakeErrMsg"), + }) + + tt.AssertContains(t, fmt.Sprint(printedArgs...), "FakeQuery", "FakeParam", "fakeErrMsg") + }) + }) + + t.Run("ErrorsLogger", func(t *testing.T) { + t.Run("with no errors", func(t *testing.T) { + var printedArgs []interface{} + logPrinter = func(args ...interface{}) (n int, err error) { + printedArgs = args + return 0, nil + } + + ErrorsLogger(ctx, LogValues{ + Query: "FakeQuery", + Params: []interface{}{"FakeParam"}, + }) + + tt.AssertEqual(t, printedArgs, []interface{}(nil)) + }) + + t.Run("with errors", func(t *testing.T) { + var printedArgs []interface{} + logPrinter = func(args ...interface{}) (n int, err error) { + printedArgs = args + return 0, nil + } + + ErrorsLogger(ctx, LogValues{ + Query: "FakeQuery", + Params: []interface{}{"FakeParam"}, + Err: errors.New("fakeErrMsg"), + }) + + tt.AssertContains(t, fmt.Sprint(printedArgs...), "FakeQuery", "FakeParam", "fakeErrMsg") + }) + }) +}