From 74cb87bea027e7924e614a8f2ff0c9719f8dd4b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Wed, 24 Aug 2022 22:45:24 -0300 Subject: [PATCH] Decouple ksql.DB from TagInfoCache so we can replace it during tests --- internal/structs/structs.go | 12 +++++++++ ksql.go | 52 +++++++++++++++++++++---------------- test_adapters.go | 16 ++++++++---- 3 files changed, 53 insertions(+), 27 deletions(-) diff --git a/internal/structs/structs.go b/internal/structs/structs.go index a4faa6b..8f7c3db 100644 --- a/internal/structs/structs.go +++ b/internal/structs/structs.go @@ -69,12 +69,24 @@ func (s StructInfo) NumFields() int { // works fine. var tagInfoCache = &sync.Map{} +// TagInfoCache implements the ksql.TagInfoCache interface +// this abstraction was created for allowing the use of +// mocks during tests. +type TagInfoCache struct{} + // GetTagInfo efficiently returns the type information // using a global private cache // // In the future we might move this cache inside // a struct, but for now this accessor is the one // we are using +func (t TagInfoCache) GetTagInfo(key reflect.Type) (StructInfo, error) { + return getCachedTagInfo(tagInfoCache, key) +} + +// GetTagInfo is the static version of the method above +// created for convenience when the extra abstraction +// is unnecessary. func GetTagInfo(key reflect.Type) (StructInfo, error) { return getCachedTagInfo(tagInfoCache, key) } diff --git a/ksql.go b/ksql.go index e49e93b..75ce62b 100644 --- a/ksql.go +++ b/ksql.go @@ -33,6 +33,8 @@ type DB struct { driver string dialect Dialect db DBAdapter + + tagInfoCache TagInfoCache } // DBAdapter is minimalistic interface to decouple our implementation @@ -45,6 +47,10 @@ type DBAdapter interface { QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) } +type TagInfoCache interface { + GetTagInfo(key reflect.Type) (structs.StructInfo, error) +} + // TxBeginner needs to be implemented by the DBAdapter in order to make it possible // to use the `ksql.Transaction()` function. type TxBeginner interface { @@ -107,6 +113,8 @@ func NewWithAdapter( dialect: dialect, driver: dialectName, db: db, + + tagInfoCache: structs.TagInfoCache{}, }, nil } @@ -142,7 +150,7 @@ func (c DB) Query( slice = slice.Slice(0, 0) } - info, err := structs.GetTagInfo(structType) + info, err := c.tagInfoCache.GetTagInfo(structType) if err != nil { return err } @@ -154,7 +162,7 @@ func (c DB) Query( } if firstToken == "FROM" { - selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()]) + selectPrefix, err := c.buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()]) if err != nil { return err } @@ -185,7 +193,7 @@ func (c DB) Query( elemPtr = elemPtr.Elem() } - err = scanRows(c.dialect, rows, elemPtr.Interface()) + err = c.scanRows(c.dialect, rows, elemPtr.Interface()) if err != nil { return err } @@ -232,7 +240,7 @@ func (c DB) QueryOne( return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record) } - info, err := structs.GetTagInfo(tStruct) + info, err := c.tagInfoCache.GetTagInfo(tStruct) if err != nil { return err } @@ -244,7 +252,7 @@ func (c DB) QueryOne( } if firstToken == "FROM" { - selectPrefix, err := buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()]) + selectPrefix, err := c.buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()]) if err != nil { return err } @@ -264,7 +272,7 @@ func (c DB) QueryOne( return ErrRecordNotFound } - err = scanRowsFromType(c.dialect, rows, record, t, v) + err = c.scanRowsFromType(c.dialect, rows, record, t, v) if err != nil { return err } @@ -305,7 +313,7 @@ func (c DB) QueryChunks( return err } - info, err := structs.GetTagInfo(structType) + info, err := c.tagInfoCache.GetTagInfo(structType) if err != nil { return err } @@ -317,7 +325,7 @@ func (c DB) QueryChunks( } if firstToken == "FROM" { - selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()]) + selectPrefix, err := c.buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()]) if err != nil { return err } @@ -343,7 +351,7 @@ func (c DB) QueryChunks( chunk = reflect.Append(chunk, elemValue) } - err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface()) + err = c.scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface()) if err != nil { return err } @@ -415,7 +423,7 @@ func (c DB) Insert( return fmt.Errorf("can't insert in ksql.Table: %s", err) } - info, err := structs.GetTagInfo(t.Elem()) + info, err := c.tagInfoCache.GetTagInfo(t.Elem()) if err != nil { return err } @@ -652,7 +660,7 @@ func (c DB) Patch( } tStruct = t.Elem() } - info, err := structs.GetTagInfo(tStruct) + info, err := c.tagInfoCache.GetTagInfo(tStruct) if err != nil { return err } @@ -929,13 +937,13 @@ func (nopScanner) Scan(value interface{}) error { return nil } -func scanRows(dialect Dialect, rows Rows, record interface{}) error { +func (c DB) scanRows(dialect Dialect, rows Rows, record interface{}) error { v := reflect.ValueOf(record) t := v.Type() - return scanRowsFromType(dialect, rows, record, t, v) + return c.scanRowsFromType(dialect, rows, record, t, v) } -func scanRowsFromType( +func (c DB) scanRowsFromType( dialect Dialect, rows Rows, record interface{}, @@ -953,7 +961,7 @@ func scanRowsFromType( return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record) } - info, err := structs.GetTagInfo(t) + info, err := c.tagInfoCache.GetTagInfo(t) if err != nil { return err } @@ -963,7 +971,7 @@ func scanRowsFromType( // This version is positional meaning that it expect the arguments // to follow an specific order. It's ok because we don't allow the // user to type the "SELECT" part of the query for nested structs. - scanArgs, err = getScanArgsForNestedStructs(dialect, rows, t, v, info) + scanArgs, err = c.getScanArgsForNestedStructs(dialect, rows, t, v, info) if err != nil { return err } @@ -980,7 +988,7 @@ func scanRowsFromType( return rows.Scan(scanArgs...) } -func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) { +func (c DB) getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) { scanArgs := []interface{}{} for i := 0; i < v.NumField(); i++ { if !info.ByIndex(i).Valid { @@ -988,7 +996,7 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r } // TODO(vingarcia00): Handle case where type is pointer - nestedStructInfo, err := structs.GetTagInfo(t.Field(i).Type) + nestedStructInfo, err := c.tagInfoCache.GetTagInfo(t.Field(i).Type) if err != nil { return nil, err } @@ -1076,7 +1084,7 @@ func getFirstToken(s string) string { return token.String() } -func buildSelectQuery( +func (c DB) buildSelectQuery( dialect Dialect, structType reflect.Type, info structs.StructInfo, @@ -1091,7 +1099,7 @@ func buildSelectQuery( } if info.IsNestedStruct { - query, err = buildSelectQueryForNestedStructs(dialect, structType, info) + query, err = c.buildSelectQueryForNestedStructs(dialect, structType, info) if err != nil { return "", err } @@ -1121,7 +1129,7 @@ func buildSelectQueryForPlainStructs( return "SELECT " + strings.Join(fields, ", ") + " " } -func buildSelectQueryForNestedStructs( +func (c DB) buildSelectQueryForNestedStructs( dialect Dialect, structType reflect.Type, info structs.StructInfo, @@ -1142,7 +1150,7 @@ func buildSelectQueryForNestedStructs( ) } - nestedStructTagInfo, err := structs.GetTagInfo(nestedStructType) + nestedStructTagInfo, err := c.tagInfoCache.GetTagInfo(nestedStructType) if err != nil { return "", err } diff --git a/test_adapters.go b/test_adapters.go index 2362398..8beb629 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/pkg/errors" + "github.com/vingarcia/ksql/internal/structs" tt "github.com/vingarcia/ksql/internal/testtools" "github.com/vingarcia/ksql/nullable" ) @@ -2455,7 +2456,7 @@ func ScanRowsTest( tt.AssertEqual(t, rows.Next(), true) var u user - err = scanRows(dialect, rows, &u) + err = c.scanRows(dialect, rows, &u) tt.AssertNoErr(t, err) tt.AssertEqual(t, u.Name, "User2") @@ -2488,7 +2489,7 @@ func ScanRowsTest( // Omitted for testing purposes: // Name string `ksql:"name"` } - err = scanRows(dialect, rows, &u) + err = c.scanRows(dialect, rows, &u) tt.AssertNoErr(t, err) tt.AssertEqual(t, u.Age, 22) @@ -2504,6 +2505,7 @@ func ScanRowsTest( ctx := context.TODO() db, closer := newDBAdapter(t) defer closer.Close() + c := newTestDB(db, driver) rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'") tt.AssertNoErr(t, err) @@ -2511,7 +2513,7 @@ func ScanRowsTest( var u user err = rows.Close() tt.AssertNoErr(t, err) - err = scanRows(dialect, rows, &u) + err = c.scanRows(dialect, rows, &u) tt.AssertNotEqual(t, err, nil) }) @@ -2525,13 +2527,14 @@ func ScanRowsTest( ctx := context.TODO() db, closer := newDBAdapter(t) defer closer.Close() + c := newTestDB(db, driver) rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'") tt.AssertNoErr(t, err) defer rows.Close() var u user - err = scanRows(dialect, rows, u) + err = c.scanRows(dialect, rows, u) tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "user") }) @@ -2545,13 +2548,14 @@ func ScanRowsTest( ctx := context.TODO() db, closer := newDBAdapter(t) defer closer.Close() + c := newTestDB(db, driver) rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'") tt.AssertNoErr(t, err) defer rows.Close() var u map[string]interface{} - err = scanRows(dialect, rows, &u) + err = c.scanRows(dialect, rows, &u) tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "map[string]interface") }) }) @@ -2684,6 +2688,8 @@ func newTestDB(db DBAdapter, driver string) DB { driver: driver, dialect: supportedDialects[driver], db: db, + + tagInfoCache: structs.TagInfoCache{}, } }