Refactor modifiers into its own package

pull/29/head
Vinícius Garcia 2022-09-21 22:30:20 -03:00
parent cb15295e46
commit 41f4d5487b
8 changed files with 138 additions and 108 deletions

View File

@ -1,66 +0,0 @@
package ksql
import (
"context"
"database/sql/driver"
"fmt"
)
// Here we keep all the registered modifier
var modifiers = map[string]AttrModifier{
"json": jsonModifier{},
}
// RegisterAttrModifier allow users to add custom modifiers on startup
// it is recommended to do this inside an init() function.
func RegisterAttrModifier(key string, modifier AttrModifier) {
_, found := modifiers[key]
if found {
panic(fmt.Errorf("KSQL: cannot register modifier '%s' name is already in use", key))
}
modifiers[key] = modifier
}
// AttrModifier describes the two operations required to serialize and deserialize an object from the database.
type AttrModifier interface {
AttrScan(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error
AttrValue(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error)
}
// OpInfo contains information that might be used by a modifier to determine how it should behave.
type OpInfo struct {
// A string version of the name of one of
// the methods of the `ksql.Provider` interface, e.g. `Insert` or `Query`
Method string
// The string representing the current underlying database, e.g.:
// "postgres", "sqlite3", "mysql" or "sqlserver".
DriverName string
}
// attrModifier is the wrapper that allow us to intercept the Scan and Value processes
// so we can run the modifiers instead of allowing the database driver to use
// its default behavior.
//
// For that this struct implements both the `sql.Scanner` and `sql.Valuer` interfaces.
type attrModifier struct {
ctx context.Context
// When Scanning this value should be a pointer to the attribute
// and when "Valuing" it should just be the actual value
attr interface{}
modifierName string
opInfo OpInfo
}
// Scan implements the sql.Scanner interface
func (a attrModifier) Scan(dbValue interface{}) error {
return modifiers[a.modifierName].AttrScan(a.ctx, a.opInfo, a.attr, dbValue)
}
// Value implements the sql.Valuer interface
func (a attrModifier) Value() (driver.Value, error) {
return modifiers[a.modifierName].AttrValue(a.ctx, a.opInfo, a.attr)
}

View File

@ -0,0 +1,32 @@
package modifiers
import (
"context"
"database/sql/driver"
)
// AttrWrapper is the wrapper that allow us to intercept the Scan and Value processes
// so we can run the modifiers instead of allowing the database driver to use
// its default behavior.
//
// For that this struct implements both the `sql.Scanner` and `sql.Valuer` interfaces.
type AttrWrapper struct {
Ctx context.Context
// When Scanning this value should be a pointer to the attribute
// and when "Valuing" it should just be the actual value
Attr interface{}
Modifier AttrModifier
OpInfo OpInfo
}
// Scan implements the sql.Scanner interface
func (a AttrWrapper) Scan(dbValue interface{}) error {
return a.Modifier.AttrScan(a.Ctx, a.OpInfo, a.Attr, dbValue)
}
// Value implements the sql.Valuer interface
func (a AttrWrapper) Value() (driver.Value, error) {
return a.Modifier.AttrValue(a.Ctx, a.OpInfo, a.Attr)
}

View File

@ -0,0 +1,20 @@
package modifiers
import "context"
// AttrModifier describes the two operations required to serialize and deserialize an object from the database.
type AttrModifier interface {
AttrScan(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error
AttrValue(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error)
}
// OpInfo contains information that might be used by a modifier to determine how it should behave.
type OpInfo struct {
// A string version of the name of one of
// the methods of the `ksql.Provider` interface, e.g. `Insert` or `Query`
Method string
// The string representing the current underlying database, e.g.:
// "postgres", "sqlite3", "mysql" or "sqlserver".
DriverName string
}

View File

@ -0,0 +1,35 @@
package modifiers
import (
"fmt"
"sync"
)
// Here we keep all the registered modifiers
var modifiers sync.Map
func init() {
// These are the builtin modifiers
modifiers.Store("json", jsonModifier{})
}
// RegisterAttrModifier allow users to add custom modifiers on startup
// it is recommended to do this inside an init() function.
func RegisterAttrModifier(key string, modifier AttrModifier) {
_, found := modifiers.Load(key)
if found {
panic(fmt.Errorf("KSQL: cannot register modifier '%s' name is already in use", key))
}
modifiers.Store(key, modifier)
}
func LoadGlobalModifier(key string) (AttrModifier, error) {
rawModifier, _ := modifiers.Load(key)
modifier, ok := rawModifier.(AttrModifier)
if !ok {
return nil, fmt.Errorf("no modifier found with name '%s'", key)
}
return modifier, nil
}

View File

@ -1,4 +1,4 @@
package ksql
package modifiers
import (
"context"

View File

@ -5,6 +5,8 @@ import (
"reflect"
"strings"
"sync"
"github.com/vingarcia/ksql/internal/modifiers"
)
// StructInfo stores metainformation of the struct
@ -20,10 +22,10 @@ type StructInfo struct {
// information regarding a specific field
// of a struct.
type FieldInfo struct {
Name string
Index int
Valid bool
ModifierName string
Name string
Index int
Valid bool
Modifier modifiers.AttrModifier
}
// ByIndex returns either the *FieldInfo of a valid
@ -232,7 +234,7 @@ func (p PtrConverter) Convert(destType reflect.Type) (reflect.Value, error) {
//
// This should save several calls to `Field(i).Tag.Get("foo")`
// which improves performance by a lot.
func getTagNames(t reflect.Type) (StructInfo, error) {
func getTagNames(t reflect.Type) (_ StructInfo, err error) {
info := StructInfo{
byIndex: map[int]*FieldInfo{},
byName: map[string]*FieldInfo{},
@ -249,10 +251,13 @@ func getTagNames(t reflect.Type) (StructInfo, error) {
}
tags := strings.Split(name, ",")
var modifierName string
var modifier modifiers.AttrModifier
if len(tags) > 1 {
name = tags[0]
modifierName = tags[1]
modifier, err = modifiers.LoadGlobalModifier(tags[1])
if err != nil {
return StructInfo{}, fmt.Errorf("attribute contains invalid modifier name: %w", err)
}
}
if _, found := info.byName[name]; found {
@ -263,9 +268,9 @@ func getTagNames(t reflect.Type) (StructInfo, error) {
}
info.add(FieldInfo{
Name: name,
Index: i,
ModifierName: modifierName,
Name: name,
Index: i,
Modifier: modifier,
})
}

53
ksql.go
View File

@ -11,6 +11,7 @@ import (
"unicode"
"github.com/pkg/errors"
"github.com/vingarcia/ksql/internal/modifiers"
"github.com/vingarcia/ksql/internal/structs"
"github.com/vingarcia/ksql/ksqltest"
)
@ -718,13 +719,13 @@ func buildInsertQuery(
recordValue := recordMap[col]
params[i] = recordValue
modifierName := info.ByName(col).ModifierName
if modifierName != "" {
params[i] = attrModifier{
ctx: ctx,
attr: recordValue,
modifierName: modifierName,
opInfo: OpInfo{
modifier := info.ByName(col).Modifier
if modifier != nil {
params[i] = modifiers.AttrWrapper{
Ctx: ctx,
Attr: recordValue,
Modifier: modifier,
OpInfo: modifiers.OpInfo{
DriverName: dialect.DriverName(),
Method: "Insert",
},
@ -827,13 +828,13 @@ func buildUpdateQuery(
for i, k := range keys {
recordValue := recordMap[k]
modifierName := info.ByName(k).ModifierName
if modifierName != "" {
recordValue = attrModifier{
ctx: ctx,
attr: recordValue,
modifierName: modifierName,
opInfo: OpInfo{
modifier := info.ByName(k).Modifier
if modifier != nil {
recordValue = modifiers.AttrWrapper{
Ctx: ctx,
Attr: recordValue,
Modifier: modifier,
OpInfo: modifiers.OpInfo{
DriverName: dialect.DriverName(),
Method: "Update",
},
@ -1032,12 +1033,12 @@ func getScanArgsForNestedStructs(
if fieldInfo.Valid {
valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface()
if fieldInfo.ModifierName != "" {
valueScanner = &attrModifier{
ctx: ctx,
attr: valueScanner,
modifierName: fieldInfo.ModifierName,
opInfo: OpInfo{
if fieldInfo.Modifier != nil {
valueScanner = &modifiers.AttrWrapper{
Ctx: ctx,
Attr: valueScanner,
Modifier: fieldInfo.Modifier,
OpInfo: modifiers.OpInfo{
DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks
// if we did this could lead users to make very strange modifiers
@ -1062,12 +1063,12 @@ func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string,
valueScanner := nopScannerValue
if fieldInfo.Valid {
valueScanner = v.Field(fieldInfo.Index).Addr().Interface()
if fieldInfo.ModifierName != "" {
valueScanner = &attrModifier{
ctx: ctx,
attr: valueScanner,
modifierName: fieldInfo.ModifierName,
opInfo: OpInfo{
if fieldInfo.Modifier != nil {
valueScanner = &modifiers.AttrWrapper{
Ctx: ctx,
Attr: valueScanner,
Modifier: fieldInfo.Modifier,
OpInfo: modifiers.OpInfo{
DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks
// if we did this could lead users to make very strange modifiers

View File

@ -9,6 +9,7 @@ import (
"testing"
"github.com/pkg/errors"
"github.com/vingarcia/ksql/internal/modifiers"
tt "github.com/vingarcia/ksql/internal/testtools"
"github.com/vingarcia/ksql/nullable"
)
@ -2799,11 +2800,13 @@ func getUserByID(db DBAdapter, dialect Dialect, result *user, id uint) error {
return sql.ErrNoRows
}
value := attrModifier{
ctx: context.TODO(),
attr: &result.Address,
modifierName: "json",
opInfo: OpInfo{
modifier, _ := modifiers.LoadGlobalModifier("json")
value := modifiers.AttrWrapper{
Ctx: context.TODO(),
Attr: &result.Address,
Modifier: modifier,
OpInfo: modifiers.OpInfo{
DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks
// if we did this could lead users to make very strange modifiers