diff --git a/internal/testtools/assert.go b/internal/testtools/assert.go index 8763f1a..e10c793 100644 --- a/internal/testtools/assert.go +++ b/internal/testtools/assert.go @@ -47,6 +47,19 @@ func AssertErrContains(t *testing.T, err error, substrs ...string) { } } +// AssertContains will check if the input text contains +// all the substrs specified on the substrs argument or +// fail with an appropriate error message. +func AssertContains(t *testing.T, str string, substrs ...string) { + for _, substr := range substrs { + require.True(t, + strings.Contains(str, substr), + "missing substring '%s' in error message: '%s'", + substr, str, + ) + } +} + // AssertApproxDuration checks if the durations v1 and v2 are close up to the tolerance specified. // The format and args slice can be used for generating an appropriate error message if they are not. func AssertApproxDuration(t *testing.T, tolerance time.Duration, v1, v2 time.Duration, format string, args ...interface{}) { diff --git a/internal_mocks.go b/internal_mocks.go index 013e3ac..7ef9b9e 100644 --- a/internal_mocks.go +++ b/internal_mocks.go @@ -62,6 +62,26 @@ func (m mockRows) Columns() ([]string, error) { return m.ColumnsFn() } +// mockResult mocks the ksql.Result interface +type mockResult struct { + LastInsertIdFn func() (int64, error) + RowsAffectedFn func() (int64, error) +} + +func (m mockResult) LastInsertId() (int64, error) { + if m.LastInsertIdFn != nil { + return m.LastInsertIdFn() + } + return 0, nil +} + +func (m mockResult) RowsAffected() (int64, error) { + if m.RowsAffectedFn != nil { + return m.RowsAffectedFn() + } + return 0, nil +} + // mockTx mocks the ksql.Tx interface type mockTx struct { DBAdapter diff --git a/ksql_test.go b/ksql_test.go index ae1a61f..13105ca 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -2,6 +2,7 @@ package ksql import ( "context" + "errors" "fmt" "io" "testing" @@ -103,39 +104,107 @@ func TestClose(t *testing.T) { func TestInjectLogger(t *testing.T) { ctx := context.Background() - t.Run("should work for the Query function", func(t *testing.T) { - var inputQuery string - var inputParams []interface{} - c := DB{ - db: mockDBAdapter{ - QueryContextFn: func(ctx context.Context, query string, params ...interface{}) (Rows, error) { - inputQuery = query - inputParams = params + tests := []struct { + desc string + logLevel string + methodCall func(ctx context.Context, db Provider) error + queryErr error - return mockRows{ - NextFn: func() bool { return false }, - }, nil - }, + expectLoggedQueryToContain []string + expectLoggedParams []interface{} + expectLoggedErrToContain []string + }{ + { + desc: "should work for the Query function", + logLevel: "info", + methodCall: func(ctx context.Context, db Provider) error { + var row []struct { + Count int `ksql:"count"` + } + return db.Query(ctx, &row, `SELECT count(*) AS count FROM users WHERE type = $1 AND age < $2`, "fakeType", 42) }, - } - var loggedQuery string - var loggedParams []interface{} - var loggedErr error - ctx := InjectLogger(ctx, "info", func(ctx context.Context, values LogValues) { - loggedQuery = values.Query - loggedParams = values.Params - loggedErr = values.Err + expectLoggedQueryToContain: []string{"count(*)", "type = $1"}, + expectLoggedParams: []interface{}{"fakeType", 42}, + }, + { + desc: "should work for the Query function when an error is returned", + logLevel: "info", + methodCall: func(ctx context.Context, db Provider) error { + var row []struct { + Count int `ksql:"count"` + } + return db.Query(ctx, &row, `SELECT count(*) AS count FROM users WHERE type = $1 AND age < $2`, "fakeType", 42) + }, + queryErr: errors.New("fakeErrMsg"), + + expectLoggedQueryToContain: []string{"count(*)", "type = $1"}, + expectLoggedParams: []interface{}{"fakeType", 42}, + expectLoggedErrToContain: []string{"fakeErrMsg"}, + }, + { + desc: "should work for the Query function when an error is returned with error level", + logLevel: "error", + methodCall: func(ctx context.Context, db Provider) error { + var row []struct { + Count int `ksql:"count"` + } + return db.Query(ctx, &row, `SELECT count(*) AS count FROM users WHERE type = $1 AND age < $2`, "fakeType", 42) + }, + queryErr: errors.New("fakeErrMsg"), + + expectLoggedQueryToContain: []string{"count(*)", "type = $1"}, + expectLoggedParams: []interface{}{"fakeType", 42}, + expectLoggedErrToContain: []string{"fakeErrMsg"}, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + var inputQuery string + var inputParams []interface{} + c := DB{ + db: mockDBAdapter{ + QueryContextFn: func(ctx context.Context, query string, params ...interface{}) (Rows, error) { + inputQuery = query + inputParams = params + + return mockRows{ + NextFn: func() bool { return false }, + }, test.queryErr + }, + ExecContextFn: func(ctx context.Context, query string, params ...interface{}) (Result, error) { + inputQuery = query + inputParams = params + + return mockResult{}, test.queryErr + }, + }, + } + + var loggedQuery string + var loggedParams []interface{} + var loggedErr error + ctx := InjectLogger(ctx, "info", func(ctx context.Context, values LogValues) { + loggedQuery = values.Query + loggedParams = values.Params + loggedErr = values.Err + }) + + err := test.methodCall(ctx, c) + if test.expectLoggedErrToContain != nil { + tt.AssertErrContains(t, err, test.expectLoggedErrToContain...) + tt.AssertErrContains(t, loggedErr, test.expectLoggedErrToContain...) + } else { + tt.AssertNoErr(t, err) + tt.AssertEqual(t, loggedErr, nil) + } + + tt.AssertEqual(t, loggedQuery, inputQuery) + tt.AssertEqual(t, loggedParams, inputParams) + + tt.AssertContains(t, loggedQuery, test.expectLoggedQueryToContain...) + tt.AssertEqual(t, loggedParams, test.expectLoggedParams) }) - - var row []struct { - Count int `ksql:"count"` - } - err := c.Query(ctx, &row, `SELECT count(*) AS count FROM users WHERE type = $1 AND age < $2`, "fakeType", 42) - tt.AssertNoErr(t, err) - tt.AssertEqual(t, len(row), 0) - tt.AssertEqual(t, loggedQuery, inputQuery) - tt.AssertEqual(t, loggedParams, inputParams) - tt.AssertEqual(t, loggedErr, nil) - }) + } }