Merge pull request #23 from matheusoliveira/feat/map-sync

Use sync.Map on global caches to avoid race-condition
pull/24/head
Vinícius Garcia 2022-07-04 22:16:25 -03:00 committed by GitHub
commit 511aa03982
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 20 deletions

View File

@ -4,6 +4,7 @@ import (
"fmt"
"reflect"
"strings"
"sync"
)
// StructInfo stores metainformation of the struct
@ -66,7 +67,7 @@ func (s StructInfo) NumFields() int {
// because the total number of types on a program
// should be finite. So keeping a single cache here
// works fine.
var tagInfoCache = map[reflect.Type]StructInfo{}
var tagInfoCache = &sync.Map{}
// GetTagInfo efficiently returns the type information
// using a global private cache
@ -78,9 +79,13 @@ func GetTagInfo(key reflect.Type) (StructInfo, error) {
return getCachedTagInfo(tagInfoCache, key)
}
func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type) (StructInfo, error) {
if info, found := tagInfoCache[key]; found {
return info, nil
func getCachedTagInfo(tagInfoCache *sync.Map, key reflect.Type) (StructInfo, error) {
if data, found := tagInfoCache.Load(key); found {
if info, ok := data.(StructInfo); !ok {
return StructInfo{}, fmt.Errorf("invalid cache entry, expected type StructInfo, found %T", data)
} else {
return info, nil
}
}
info, err := getTagNames(key)
@ -88,7 +93,7 @@ func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type
return StructInfo{}, err
}
tagInfoCache[key] = info
tagInfoCache.Store(key, info)
return info, nil
}

View File

@ -5,6 +5,7 @@ import (
"reflect"
"strconv"
"strings"
"sync"
"github.com/pkg/errors"
"github.com/vingarcia/ksql"
@ -121,7 +122,7 @@ func (w WhereQueries) build(dialect ksql.Dialect) (query string, params []interf
return strings.Join(conds, " AND "), params
}
// Where adds a new bollean condition to an existing
// Where adds a new boolean condition to an existing
// WhereQueries helper.
func (w WhereQueries) Where(cond string, params ...interface{}) WhereQueries {
return append(w, WhereQuery{
@ -130,7 +131,7 @@ func (w WhereQueries) Where(cond string, params ...interface{}) WhereQueries {
})
}
// WhereIf condionally adds a new boolean expression to the WhereQueries helper.
// WhereIf conditionally adds a new boolean expression to the WhereQueries helper.
func (w WhereQueries) WhereIf(cond string, param interface{}) WhereQueries {
if param == nil || reflect.ValueOf(param).IsNil() {
return w
@ -142,7 +143,7 @@ func (w WhereQueries) WhereIf(cond string, param interface{}) WhereQueries {
})
}
// Where adds a new bollean condition to an existing
// Where adds a new boolean condition to an existing
// WhereQueries helper.
func Where(cond string, params ...interface{}) WhereQueries {
return WhereQueries{{
@ -151,7 +152,7 @@ func Where(cond string, params ...interface{}) WhereQueries {
}}
}
// WhereIf condionally adds a new boolean expression to the WhereQueries helper
// WhereIf conditionally adds a new boolean expression to the WhereQueries helper
func WhereIf(cond string, param interface{}) WhereQueries {
if param == nil || reflect.ValueOf(param).IsNil() {
return WhereQueries{}
@ -187,7 +188,7 @@ func OrderBy(fields string) OrderByQuery {
}
}
var cachedSelectQueries = map[reflect.Type]string{}
var cachedSelectQueries = &sync.Map{}
// Builds the select query using cached info so that its efficient
func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) {
@ -200,8 +201,12 @@ func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) {
return "", fmt.Errorf("expected to receive a pointer to struct, but got: %T", obj)
}
if query, found := cachedSelectQueries[t]; found {
return query, nil
if data, found := cachedSelectQueries.Load(t); found {
if query, ok := data.(string); !ok {
return "", fmt.Errorf("invalid cache entry, expected type string, found %T", data)
} else {
return query, nil
}
}
info, err := structs.GetTagInfo(t)
@ -216,6 +221,6 @@ func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) {
}
query := strings.Join(escapedNames, ", ")
cachedSelectQueries[t] = query
cachedSelectQueries.Store(t, query)
return query, nil
}

19
ksql.go
View File

@ -7,6 +7,7 @@ import (
"io"
"reflect"
"strings"
"sync"
"unicode"
"github.com/pkg/errors"
@ -16,10 +17,10 @@ import (
var selectQueryCache = initializeQueryCache()
func initializeQueryCache() map[string]map[reflect.Type]string {
cache := map[string]map[reflect.Type]string{}
func initializeQueryCache() map[string]*sync.Map {
cache := map[string]*sync.Map{}
for dname := range supportedDialects {
cache[dname] = map[reflect.Type]string{}
cache[dname] = &sync.Map{}
}
return cache
@ -1062,10 +1063,14 @@ func buildSelectQuery(
dialect Dialect,
structType reflect.Type,
info structs.StructInfo,
selectQueryCache map[reflect.Type]string,
selectQueryCache *sync.Map,
) (query string, err error) {
if selectQuery, found := selectQueryCache[structType]; found {
return selectQuery, nil
if data, found := selectQueryCache.Load(structType); found {
if selectQuery, ok := data.(string); !ok {
return "", fmt.Errorf("invalid cache entry, expected type string, found %T", data)
} else {
return selectQuery, nil
}
}
if info.IsNestedStruct {
@ -1077,7 +1082,7 @@ func buildSelectQuery(
query = buildSelectQueryForPlainStructs(dialect, structType, info)
}
selectQueryCache[structType] = query
selectQueryCache.Store(structType, query)
return query, nil
}