diff --git a/ksql.go b/ksql.go index 6181157..74494ce 100644 --- a/ksql.go +++ b/ksql.go @@ -148,7 +148,7 @@ func (c DB) Query( records interface{}, query string, params ...interface{}, -) error { +) (err error) { slicePtr := reflect.ValueOf(records) slicePtrType := slicePtr.Type() if slicePtrType.Kind() != reflect.Ptr { @@ -187,6 +187,8 @@ func (c DB) Query( query = selectPrefix + query } + defer ctxLog(ctx, query, params, &err) + rows, err := c.db.QueryContext(ctx, query, params...) if err != nil { return fmt.Errorf("error running query: %w", err) @@ -242,7 +244,7 @@ func (c DB) QueryOne( record interface{}, query string, params ...interface{}, -) error { +) (err error) { v := reflect.ValueOf(record) t := v.Type() if t.Kind() != reflect.Ptr { @@ -277,6 +279,8 @@ func (c DB) QueryOne( query = selectPrefix + query } + defer ctxLog(ctx, query, params, &err) + rows, err := c.db.QueryContext(ctx, query, params...) if err != nil { return fmt.Errorf("error running query: %w", err) @@ -423,10 +427,10 @@ func (c DB) Insert( ctx context.Context, table Table, record interface{}, -) error { +) (err error) { v := reflect.ValueOf(record) t := v.Type() - if err := assertStructPtr(t); err != nil { + if err = assertStructPtr(t); err != nil { return fmt.Errorf( "KSQL: expected record to be a pointer to struct, but got: %T", record, @@ -447,6 +451,7 @@ func (c DB) Insert( } query, params, scanValues, err := buildInsertQuery(ctx, c.dialect, table, t, v, info, record) + defer ctxLog(ctx, query, params, &err) if err != nil { return err } @@ -581,7 +586,7 @@ func (c DB) Delete( ctx context.Context, table Table, idOrRecord interface{}, -) error { +) (err error) { if err := table.validate(); err != nil { return fmt.Errorf("can't delete from ksql.Table: %w", err) } @@ -595,6 +600,8 @@ func (c DB) Delete( var params []interface{} query, params = buildDeleteQuery(c.dialect, table, idMap) + defer ctxLog(ctx, query, params, &err) + result, err := c.db.ExecContext(ctx, query, params...) if err != nil { return err @@ -655,7 +662,7 @@ func (c DB) Patch( ctx context.Context, table Table, record interface{}, -) error { +) (err error) { v := reflect.ValueOf(record) t := v.Type() tStruct := t @@ -676,6 +683,7 @@ func (c DB) Patch( } query, params, err := buildUpdateQuery(ctx, c.dialect, table.name, info, recordMap, table.idColumns...) + defer ctxLog(ctx, query, params, &err) if err != nil { return err } diff --git a/ksql_test.go b/ksql_test.go index ba997ef..ae1a61f 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -1,6 +1,7 @@ package ksql import ( + "context" "fmt" "io" "testing" @@ -98,3 +99,43 @@ func TestClose(t *testing.T) { tt.AssertErrContains(t, err, "fakeCloseErrMsg") }) } + +func TestInjectLogger(t *testing.T) { + ctx := context.Background() + + t.Run("should work for the Query function", func(t *testing.T) { + var inputQuery string + var inputParams []interface{} + c := DB{ + db: mockDBAdapter{ + QueryContextFn: func(ctx context.Context, query string, params ...interface{}) (Rows, error) { + inputQuery = query + inputParams = params + + return mockRows{ + NextFn: func() bool { return false }, + }, nil + }, + }, + } + + var loggedQuery string + var loggedParams []interface{} + var loggedErr error + ctx := InjectLogger(ctx, "info", func(ctx context.Context, values LogValues) { + loggedQuery = values.Query + loggedParams = values.Params + loggedErr = values.Err + }) + + var row []struct { + Count int `ksql:"count"` + } + err := c.Query(ctx, &row, `SELECT count(*) AS count FROM users WHERE type = $1 AND age < $2`, "fakeType", 42) + tt.AssertNoErr(t, err) + tt.AssertEqual(t, len(row), 0) + tt.AssertEqual(t, loggedQuery, inputQuery) + tt.AssertEqual(t, loggedParams, inputParams) + tt.AssertEqual(t, loggedErr, nil) + }) +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..de3ef85 --- /dev/null +++ b/logger.go @@ -0,0 +1,63 @@ +package ksql + +import "context" + +type loggerKey struct{} + +type LogValues struct { + Query string + Params []interface{} + Err error +} + +func InjectLogger( + ctx context.Context, + level string, + logFn func(ctx context.Context, values LogValues), +) context.Context { + if level != "info" { + level = "error" + } + + return context.WithValue(ctx, loggerKey{}, logger{ + level: level, + logFn: func(ctx context.Context, query string, params []interface{}, err error) { + logFn(ctx, LogValues{ + Query: query, + Params: params, + Err: err, + }) + }, + }) +} + +func ctxLog(ctx context.Context, query string, params []interface{}, err *error) { + l := ctx.Value(loggerKey{}) + if l == nil { + return + } + + if *err != nil { + l.(logger)._error(ctx, query, params, *err) + return + } + + l.(logger)._info(ctx, query, params, nil) +} + +type logger struct { + level string + logFn func(ctx context.Context, query string, params []interface{}, err error) +} + +func (l logger) _info(ctx context.Context, query string, params []interface{}, err error) { + if l.level == "error" { + return + } + + l.logFn(ctx, query, params, err) +} + +func (l logger) _error(ctx context.Context, query string, params []interface{}, err error) { + l.logFn(ctx, query, params, err) +}