diff --git a/internal/structs/structs.go b/internal/structs/structs.go index 2022453..a4faa6b 100644 --- a/internal/structs/structs.go +++ b/internal/structs/structs.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" "strings" + "sync" ) // StructInfo stores metainformation of the struct @@ -66,7 +67,7 @@ func (s StructInfo) NumFields() int { // because the total number of types on a program // should be finite. So keeping a single cache here // works fine. -var tagInfoCache = map[reflect.Type]StructInfo{} +var tagInfoCache = &sync.Map{} // GetTagInfo efficiently returns the type information // using a global private cache @@ -78,9 +79,13 @@ func GetTagInfo(key reflect.Type) (StructInfo, error) { return getCachedTagInfo(tagInfoCache, key) } -func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type) (StructInfo, error) { - if info, found := tagInfoCache[key]; found { - return info, nil +func getCachedTagInfo(tagInfoCache *sync.Map, key reflect.Type) (StructInfo, error) { + if data, found := tagInfoCache.Load(key); found { + if info, ok := data.(StructInfo); !ok { + return StructInfo{}, fmt.Errorf("invalid cache entry, expected type StructInfo, found %T", data) + } else { + return info, nil + } } info, err := getTagNames(key) @@ -88,7 +93,7 @@ func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type return StructInfo{}, err } - tagInfoCache[key] = info + tagInfoCache.Store(key, info) return info, nil } diff --git a/kbuilder/query.go b/kbuilder/query.go index 4aef10f..f6d0a9e 100644 --- a/kbuilder/query.go +++ b/kbuilder/query.go @@ -5,6 +5,7 @@ import ( "reflect" "strconv" "strings" + "sync" "github.com/pkg/errors" "github.com/vingarcia/ksql" @@ -121,7 +122,7 @@ func (w WhereQueries) build(dialect ksql.Dialect) (query string, params []interf return strings.Join(conds, " AND "), params } -// Where adds a new bollean condition to an existing +// Where adds a new boolean condition to an existing // WhereQueries helper. func (w WhereQueries) Where(cond string, params ...interface{}) WhereQueries { return append(w, WhereQuery{ @@ -130,7 +131,7 @@ func (w WhereQueries) Where(cond string, params ...interface{}) WhereQueries { }) } -// WhereIf condionally adds a new boolean expression to the WhereQueries helper. +// WhereIf conditionally adds a new boolean expression to the WhereQueries helper. func (w WhereQueries) WhereIf(cond string, param interface{}) WhereQueries { if param == nil || reflect.ValueOf(param).IsNil() { return w @@ -142,7 +143,7 @@ func (w WhereQueries) WhereIf(cond string, param interface{}) WhereQueries { }) } -// Where adds a new bollean condition to an existing +// Where adds a new boolean condition to an existing // WhereQueries helper. func Where(cond string, params ...interface{}) WhereQueries { return WhereQueries{{ @@ -151,7 +152,7 @@ func Where(cond string, params ...interface{}) WhereQueries { }} } -// WhereIf condionally adds a new boolean expression to the WhereQueries helper +// WhereIf conditionally adds a new boolean expression to the WhereQueries helper func WhereIf(cond string, param interface{}) WhereQueries { if param == nil || reflect.ValueOf(param).IsNil() { return WhereQueries{} @@ -187,7 +188,7 @@ func OrderBy(fields string) OrderByQuery { } } -var cachedSelectQueries = map[reflect.Type]string{} +var cachedSelectQueries = &sync.Map{} // Builds the select query using cached info so that its efficient func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) { @@ -200,8 +201,12 @@ func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) { return "", fmt.Errorf("expected to receive a pointer to struct, but got: %T", obj) } - if query, found := cachedSelectQueries[t]; found { - return query, nil + if data, found := cachedSelectQueries.Load(t); found { + if query, ok := data.(string); !ok { + return "", fmt.Errorf("invalid cache entry, expected type string, found %T", data) + } else { + return query, nil + } } info, err := structs.GetTagInfo(t) @@ -216,6 +221,6 @@ func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) { } query := strings.Join(escapedNames, ", ") - cachedSelectQueries[t] = query + cachedSelectQueries.Store(t, query) return query, nil } diff --git a/ksql.go b/ksql.go index a447eae..e53e54e 100644 --- a/ksql.go +++ b/ksql.go @@ -7,6 +7,7 @@ import ( "io" "reflect" "strings" + "sync" "unicode" "github.com/pkg/errors" @@ -16,10 +17,10 @@ import ( var selectQueryCache = initializeQueryCache() -func initializeQueryCache() map[string]map[reflect.Type]string { - cache := map[string]map[reflect.Type]string{} +func initializeQueryCache() map[string]*sync.Map { + cache := map[string]*sync.Map{} for dname := range supportedDialects { - cache[dname] = map[reflect.Type]string{} + cache[dname] = &sync.Map{} } return cache @@ -1062,10 +1063,14 @@ func buildSelectQuery( dialect Dialect, structType reflect.Type, info structs.StructInfo, - selectQueryCache map[reflect.Type]string, + selectQueryCache *sync.Map, ) (query string, err error) { - if selectQuery, found := selectQueryCache[structType]; found { - return selectQuery, nil + if data, found := selectQueryCache.Load(structType); found { + if selectQuery, ok := data.(string); !ok { + return "", fmt.Errorf("invalid cache entry, expected type string, found %T", data) + } else { + return selectQuery, nil + } } if info.IsNestedStruct { @@ -1077,7 +1082,7 @@ func buildSelectQuery( query = buildSelectQueryForPlainStructs(dialect, structType, info) } - selectQueryCache[structType] = query + selectQueryCache.Store(structType, query) return query, nil }