Merge pull request #23 from matheusoliveira/feat/map-sync

Use sync.Map on global caches to avoid race-condition
pull/24/head
Vinícius Garcia 2022-07-04 22:16:25 -03:00 committed by GitHub
commit 511aa03982
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 20 deletions

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"sync"
) )
// StructInfo stores metainformation of the struct // StructInfo stores metainformation of the struct
@ -66,7 +67,7 @@ func (s StructInfo) NumFields() int {
// because the total number of types on a program // because the total number of types on a program
// should be finite. So keeping a single cache here // should be finite. So keeping a single cache here
// works fine. // works fine.
var tagInfoCache = map[reflect.Type]StructInfo{} var tagInfoCache = &sync.Map{}
// GetTagInfo efficiently returns the type information // GetTagInfo efficiently returns the type information
// using a global private cache // using a global private cache
@ -78,9 +79,13 @@ func GetTagInfo(key reflect.Type) (StructInfo, error) {
return getCachedTagInfo(tagInfoCache, key) return getCachedTagInfo(tagInfoCache, key)
} }
func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type) (StructInfo, error) { func getCachedTagInfo(tagInfoCache *sync.Map, key reflect.Type) (StructInfo, error) {
if info, found := tagInfoCache[key]; found { if data, found := tagInfoCache.Load(key); found {
return info, nil if info, ok := data.(StructInfo); !ok {
return StructInfo{}, fmt.Errorf("invalid cache entry, expected type StructInfo, found %T", data)
} else {
return info, nil
}
} }
info, err := getTagNames(key) info, err := getTagNames(key)
@ -88,7 +93,7 @@ func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type
return StructInfo{}, err return StructInfo{}, err
} }
tagInfoCache[key] = info tagInfoCache.Store(key, info)
return info, nil return info, nil
} }

View File

@ -5,6 +5,7 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"sync"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/vingarcia/ksql" "github.com/vingarcia/ksql"
@ -121,7 +122,7 @@ func (w WhereQueries) build(dialect ksql.Dialect) (query string, params []interf
return strings.Join(conds, " AND "), params return strings.Join(conds, " AND "), params
} }
// Where adds a new bollean condition to an existing // Where adds a new boolean condition to an existing
// WhereQueries helper. // WhereQueries helper.
func (w WhereQueries) Where(cond string, params ...interface{}) WhereQueries { func (w WhereQueries) Where(cond string, params ...interface{}) WhereQueries {
return append(w, WhereQuery{ return append(w, WhereQuery{
@ -130,7 +131,7 @@ func (w WhereQueries) Where(cond string, params ...interface{}) WhereQueries {
}) })
} }
// WhereIf condionally adds a new boolean expression to the WhereQueries helper. // WhereIf conditionally adds a new boolean expression to the WhereQueries helper.
func (w WhereQueries) WhereIf(cond string, param interface{}) WhereQueries { func (w WhereQueries) WhereIf(cond string, param interface{}) WhereQueries {
if param == nil || reflect.ValueOf(param).IsNil() { if param == nil || reflect.ValueOf(param).IsNil() {
return w return w
@ -142,7 +143,7 @@ func (w WhereQueries) WhereIf(cond string, param interface{}) WhereQueries {
}) })
} }
// Where adds a new bollean condition to an existing // Where adds a new boolean condition to an existing
// WhereQueries helper. // WhereQueries helper.
func Where(cond string, params ...interface{}) WhereQueries { func Where(cond string, params ...interface{}) WhereQueries {
return WhereQueries{{ return WhereQueries{{
@ -151,7 +152,7 @@ func Where(cond string, params ...interface{}) WhereQueries {
}} }}
} }
// WhereIf condionally adds a new boolean expression to the WhereQueries helper // WhereIf conditionally adds a new boolean expression to the WhereQueries helper
func WhereIf(cond string, param interface{}) WhereQueries { func WhereIf(cond string, param interface{}) WhereQueries {
if param == nil || reflect.ValueOf(param).IsNil() { if param == nil || reflect.ValueOf(param).IsNil() {
return WhereQueries{} return WhereQueries{}
@ -187,7 +188,7 @@ func OrderBy(fields string) OrderByQuery {
} }
} }
var cachedSelectQueries = map[reflect.Type]string{} var cachedSelectQueries = &sync.Map{}
// Builds the select query using cached info so that its efficient // Builds the select query using cached info so that its efficient
func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) { func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) {
@ -200,8 +201,12 @@ func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) {
return "", fmt.Errorf("expected to receive a pointer to struct, but got: %T", obj) return "", fmt.Errorf("expected to receive a pointer to struct, but got: %T", obj)
} }
if query, found := cachedSelectQueries[t]; found { if data, found := cachedSelectQueries.Load(t); found {
return query, nil if query, ok := data.(string); !ok {
return "", fmt.Errorf("invalid cache entry, expected type string, found %T", data)
} else {
return query, nil
}
} }
info, err := structs.GetTagInfo(t) info, err := structs.GetTagInfo(t)
@ -216,6 +221,6 @@ func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) {
} }
query := strings.Join(escapedNames, ", ") query := strings.Join(escapedNames, ", ")
cachedSelectQueries[t] = query cachedSelectQueries.Store(t, query)
return query, nil return query, nil
} }

19
ksql.go
View File

@ -7,6 +7,7 @@ import (
"io" "io"
"reflect" "reflect"
"strings" "strings"
"sync"
"unicode" "unicode"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -16,10 +17,10 @@ import (
var selectQueryCache = initializeQueryCache() var selectQueryCache = initializeQueryCache()
func initializeQueryCache() map[string]map[reflect.Type]string { func initializeQueryCache() map[string]*sync.Map {
cache := map[string]map[reflect.Type]string{} cache := map[string]*sync.Map{}
for dname := range supportedDialects { for dname := range supportedDialects {
cache[dname] = map[reflect.Type]string{} cache[dname] = &sync.Map{}
} }
return cache return cache
@ -1062,10 +1063,14 @@ func buildSelectQuery(
dialect Dialect, dialect Dialect,
structType reflect.Type, structType reflect.Type,
info structs.StructInfo, info structs.StructInfo,
selectQueryCache map[reflect.Type]string, selectQueryCache *sync.Map,
) (query string, err error) { ) (query string, err error) {
if selectQuery, found := selectQueryCache[structType]; found { if data, found := selectQueryCache.Load(structType); found {
return selectQuery, nil if selectQuery, ok := data.(string); !ok {
return "", fmt.Errorf("invalid cache entry, expected type string, found %T", data)
} else {
return selectQuery, nil
}
} }
if info.IsNestedStruct { if info.IsNestedStruct {
@ -1077,7 +1082,7 @@ func buildSelectQuery(
query = buildSelectQueryForPlainStructs(dialect, structType, info) query = buildSelectQueryForPlainStructs(dialect, structType, info)
} }
selectQueryCache[structType] = query selectQueryCache.Store(structType, query)
return query, nil return query, nil
} }