Refactor modifiers so it is represented by a struct

pull/29/head
Vinícius Garcia 2022-09-26 01:37:23 -03:00
parent 7661ba0314
commit 8cba3efa2d
10 changed files with 157 additions and 151 deletions

View File

@ -5,28 +5,40 @@ import (
"database/sql/driver" "database/sql/driver"
) )
// AttrWrapper is the wrapper that allow us to intercept the Scan and Value processes // AttrScanWrapper is the wrapper that allow us to intercept the Scan process
// so we can run the modifiers instead of allowing the database driver to use // so we can run the modifiers instead of allowing the database driver to use
// its default behavior. // its default behavior.
// //
// For that this struct implements both the `sql.Scanner` and `sql.Valuer` interfaces. // For that this struct implements the `sql.Scanner` interface
type AttrWrapper struct { type AttrScanWrapper struct {
Ctx context.Context Ctx context.Context
// When Scanning this value should be a pointer to the attribute AttrPtr interface{}
// and when "Valuing" it should just be the actual value
Attr interface{}
Modifier AttrModifier ScanFn AttrScanner
OpInfo OpInfo OpInfo OpInfo
} }
// Scan implements the sql.Scanner interface // Scan implements the sql.Scanner interface
func (a AttrWrapper) Scan(dbValue interface{}) error { func (a AttrScanWrapper) Scan(dbValue interface{}) error {
return a.Modifier.AttrScan(a.Ctx, a.OpInfo, a.Attr, dbValue) return a.ScanFn(a.Ctx, a.OpInfo, a.AttrPtr, dbValue)
}
// AttrValueWrapper is the wrapper that allow us to intercept the "Valuing" process
// so we can run the modifiers instead of allowing the database driver to use
// its default behavior.
//
// For that this struct implements the `sql.Valuer` interface
type AttrValueWrapper struct {
Ctx context.Context
Attr interface{}
ValueFn AttrValuer
OpInfo OpInfo
} }
// Value implements the sql.Valuer interface // Value implements the sql.Valuer interface
func (a AttrWrapper) Value() (driver.Value, error) { func (a AttrValueWrapper) Value() (driver.Value, error) {
return a.Modifier.AttrValue(a.Ctx, a.OpInfo, a.Attr) return a.ValueFn(a.Ctx, a.OpInfo, a.Attr)
} }

View File

@ -8,30 +8,20 @@ import (
tt "github.com/vingarcia/ksql/internal/testtools" tt "github.com/vingarcia/ksql/internal/testtools"
) )
func TestAttrWrapper(t *testing.T) { func TestAttrScanWrapper(t *testing.T) {
ctx := context.Background() ctx := context.Background()
var scanArgs map[string]interface{} var scanArgs map[string]interface{}
var valueArgs map[string]interface{} wrapper := AttrScanWrapper{
wrapper := AttrWrapper{ Ctx: ctx,
Ctx: ctx, AttrPtr: "fakeAttrPtr",
Attr: "fakeAttr", ScanFn: func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error {
Modifier: AttrModifierMock{ scanArgs = map[string]interface{}{
AttrScanFn: func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error { "opInfo": opInfo,
scanArgs = map[string]interface{}{ "attrPtr": attrPtr,
"opInfo": opInfo, "dbValue": dbValue,
"attrPtr": attrPtr, }
"dbValue": dbValue, return errors.New("fakeScanErrMsg")
}
return errors.New("fakeScanErrMsg")
},
AttrValueFn: func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
valueArgs = map[string]interface{}{
"opInfo": opInfo,
"inputValue": inputValue,
}
return "fakeOutputValue", errors.New("fakeValueErrMsg")
},
}, },
OpInfo: OpInfo{ OpInfo: OpInfo{
Method: "fakeMethod", Method: "fakeMethod",
@ -46,9 +36,30 @@ func TestAttrWrapper(t *testing.T) {
Method: "fakeMethod", Method: "fakeMethod",
DriverName: "fakeDriverName", DriverName: "fakeDriverName",
}, },
"attrPtr": "fakeAttr", "attrPtr": "fakeAttrPtr",
"dbValue": "fakeDbValue", "dbValue": "fakeDbValue",
}) })
}
func TestAttrWrapper(t *testing.T) {
ctx := context.Background()
var valueArgs map[string]interface{}
wrapper := AttrValueWrapper{
Ctx: ctx,
Attr: "fakeAttr",
ValueFn: func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
valueArgs = map[string]interface{}{
"opInfo": opInfo,
"inputValue": inputValue,
}
return "fakeOutputValue", errors.New("fakeValueErrMsg")
},
OpInfo: OpInfo{
Method: "fakeMethod",
DriverName: "fakeDriverName",
},
}
value, err := wrapper.Value() value, err := wrapper.Value()
tt.AssertErrContains(t, err, "fakeValueErrMsg") tt.AssertErrContains(t, err, "fakeValueErrMsg")

View File

@ -1,13 +1,30 @@
package modifiers package modifiers
import "context" import (
"context"
)
// AttrModifier describes the two operations required to serialize and deserialize an object from the database. // AttrModifier informs KSQL how to use this modifier
type AttrModifier interface { type AttrModifier struct {
AttrScan(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error // The following attributes will tell KSQL to
AttrValue(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) // leave this attribute out of insertions, updates,
// and queries respectively.
SkipOnInsert bool
SkipOnUpdate bool
SkipOnQuery bool
// Implement these functions if you want to override the default Scan/Value behavior
// for the target attribute.
Scan AttrScanner
Value AttrValuer
} }
// AttrScanner describes the operation of deserializing an object received from the database.
type AttrScanner func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error
// AttrValuer describes the operation of serializing an object when saving it to the database.
type AttrValuer func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error)
// OpInfo contains information that might be used by a modifier to determine how it should behave. // OpInfo contains information that might be used by a modifier to determine how it should behave.
type OpInfo struct { type OpInfo struct {
// A string version of the name of one of // A string version of the name of one of

View File

@ -10,7 +10,7 @@ var modifiers sync.Map
func init() { func init() {
// These are the builtin modifiers // These are the builtin modifiers
modifiers.Store("json", jsonModifier{}) modifiers.Store("json", jsonModifier)
} }
// RegisterAttrModifier allow users to add custom modifiers on startup // RegisterAttrModifier allow users to add custom modifiers on startup
@ -28,7 +28,7 @@ func LoadGlobalModifier(key string) (AttrModifier, error) {
rawModifier, _ := modifiers.Load(key) rawModifier, _ := modifiers.Load(key)
modifier, ok := rawModifier.(AttrModifier) modifier, ok := rawModifier.(AttrModifier)
if !ok { if !ok {
return nil, fmt.Errorf("no modifier found with name '%s'", key) return AttrModifier{}, fmt.Errorf("no modifier found with name '%s'", key)
} }
return modifier, nil return modifier, nil

View File

@ -8,28 +8,36 @@ import (
func TestRegisterAttrModifier(t *testing.T) { func TestRegisterAttrModifier(t *testing.T) {
t.Run("should register new modifiers correctly", func(t *testing.T) { t.Run("should register new modifiers correctly", func(t *testing.T) {
modifier1 := AttrModifierMock{} modifier1 := AttrModifier{
modifier2 := AttrModifierMock{} SkipOnUpdate: true,
}
modifier2 := AttrModifier{
SkipOnInsert: true,
}
RegisterAttrModifier("fakeModifierName1", &modifier1) RegisterAttrModifier("fakeModifierName1", modifier1)
RegisterAttrModifier("fakeModifierName2", &modifier2) RegisterAttrModifier("fakeModifierName2", modifier2)
mod, err := LoadGlobalModifier("fakeModifierName1") mod, err := LoadGlobalModifier("fakeModifierName1")
tt.AssertNoErr(t, err) tt.AssertNoErr(t, err)
tt.AssertEqual(t, mod, &modifier1) tt.AssertEqual(t, mod, modifier1)
mod, err = LoadGlobalModifier("fakeModifierName2") mod, err = LoadGlobalModifier("fakeModifierName2")
tt.AssertNoErr(t, err) tt.AssertNoErr(t, err)
tt.AssertEqual(t, mod, &modifier2) tt.AssertEqual(t, mod, modifier2)
}) })
t.Run("should panic registering a modifier and the name already exists", func(t *testing.T) { t.Run("should panic registering a modifier and the name already exists", func(t *testing.T) {
modifier1 := AttrModifierMock{} modifier1 := AttrModifier{
modifier2 := AttrModifierMock{} SkipOnUpdate: true,
}
modifier2 := AttrModifier{
SkipOnInsert: true,
}
RegisterAttrModifier("fakeModifierName", &modifier1) RegisterAttrModifier("fakeModifierName", modifier1)
panicPayload := tt.PanicHandler(func() { panicPayload := tt.PanicHandler(func() {
RegisterAttrModifier("fakeModifierName", &modifier2) RegisterAttrModifier("fakeModifierName", modifier2)
}) })
err, ok := panicPayload.(error) err, ok := panicPayload.(error)
@ -40,6 +48,6 @@ func TestRegisterAttrModifier(t *testing.T) {
t.Run("should return an error when loading an inexistent modifier", func(t *testing.T) { t.Run("should return an error when loading an inexistent modifier", func(t *testing.T) {
mod, err := LoadGlobalModifier("nonExistentModifier") mod, err := LoadGlobalModifier("nonExistentModifier")
tt.AssertErrContains(t, err, "nonExistentModifier") tt.AssertErrContains(t, err, "nonExistentModifier")
tt.AssertEqual(t, mod, nil) tt.AssertEqual(t, mod, AttrModifier{})
}) })
} }

View File

@ -7,40 +7,36 @@ import (
"reflect" "reflect"
) )
// This type was created to make it easier to adapt // This modifier serializes objects as JSON when
// input attributes to be convertible to and from JSON // sending it to the database and decodes
// before sending or receiving it from the database. // them when receiving.
type jsonModifier struct{} var jsonModifier = AttrModifier{
Scan: func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error {
if dbValue == nil {
v := reflect.ValueOf(attrPtr)
// Set the struct to its 0 value just like json.Unmarshal
// does for nil attributes:
v.Elem().Set(reflect.Zero(reflect.TypeOf(attrPtr).Elem()))
return nil
}
// Scan Implements the Scanner interface in order to load // Required since sqlite3 returns strings not bytes
// this field from the JSON stored in the database if v, ok := dbValue.(string); ok {
func (j jsonModifier) AttrScan(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error { dbValue = []byte(v)
if dbValue == nil { }
v := reflect.ValueOf(attrPtr)
// Set the struct to its 0 value just like json.Unmarshal
// does for nil attributes:
v.Elem().Set(reflect.Zero(reflect.TypeOf(attrPtr).Elem()))
return nil
}
// Required since sqlite3 returns strings not bytes rawJSON, ok := dbValue.([]byte)
if v, ok := dbValue.(string); ok { if !ok {
dbValue = []byte(v) return fmt.Errorf("unexpected type received to Scan: %T", dbValue)
} }
return json.Unmarshal(rawJSON, attrPtr)
},
rawJSON, ok := dbValue.([]byte) Value: func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
if !ok { b, err := json.Marshal(inputValue)
return fmt.Errorf("unexpected type received to Scan: %T", dbValue) if opInfo.DriverName == "sqlserver" {
} return string(b), err
return json.Unmarshal(rawJSON, attrPtr) }
} return b, err
},
// Value Implements the Valuer interface in order to save
// this field as JSON on the database.
func (j jsonModifier) AttrValue(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
b, err := json.Marshal(inputValue)
if opInfo.DriverName == "sqlserver" {
return string(b), err
}
return b, err
} }

View File

@ -51,7 +51,7 @@ func TestAttrScan(t *testing.T) {
fakeAttr := FakeAttr{ fakeAttr := FakeAttr{
Foo: "notZeroValue", Foo: "notZeroValue",
} }
err := jsonModifier{}.AttrScan(ctx, OpInfo{}, &fakeAttr, test.dbInput) err := jsonModifier.Scan(ctx, OpInfo{}, &fakeAttr, test.dbInput)
if test.expectErrToContain != nil { if test.expectErrToContain != nil {
tt.AssertErrContains(t, err, test.expectErrToContain...) tt.AssertErrContains(t, err, test.expectErrToContain...)
t.Skip() t.Skip()
@ -109,7 +109,7 @@ func TestAttrValue(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
output, err := jsonModifier{}.AttrValue(ctx, test.opInfoInput, test.attrValue) output, err := jsonModifier.Value(ctx, test.opInfoInput, test.attrValue)
if test.expectErrToContain != nil { if test.expectErrToContain != nil {
tt.AssertErrContains(t, err, test.expectErrToContain...) tt.AssertErrContains(t, err, test.expectErrToContain...)
t.Skip() t.Skip()

View File

@ -1,38 +0,0 @@
package modifiers
import "context"
// AttrModifierMock mocks the modifiers.AttrModifier interface
type AttrModifierMock struct {
AttrScanFn func(
ctx context.Context,
opInfo OpInfo,
attrPtr interface{},
dbValue interface{},
) error
AttrValueFn func(
ctx context.Context,
opInfo OpInfo,
inputValue interface{},
) (outputValue interface{}, _ error)
}
// AttrScan mocks the AttrScan method
func (a AttrModifierMock) AttrScan(
ctx context.Context,
opInfo OpInfo,
attrPtr interface{},
dbValue interface{},
) error {
return a.AttrScanFn(ctx, opInfo, attrPtr, dbValue)
}
// AttrValue mocks the AttrValue method
func (a AttrModifierMock) AttrValue(
ctx context.Context,
opInfo OpInfo,
inputValue interface{},
) (outputValue interface{}, _ error) {
return a.AttrValueFn(ctx, opInfo, inputValue)
}

44
ksql.go
View File

@ -719,12 +719,12 @@ func buildInsertQuery(
recordValue := recordMap[col] recordValue := recordMap[col]
params[i] = recordValue params[i] = recordValue
modifier := info.ByName(col).Modifier valueFn := info.ByName(col).Modifier.Value
if modifier != nil { if valueFn != nil {
params[i] = modifiers.AttrWrapper{ params[i] = modifiers.AttrValueWrapper{
Ctx: ctx, Ctx: ctx,
Attr: recordValue, Attr: recordValue,
Modifier: modifier, ValueFn: valueFn,
OpInfo: modifiers.OpInfo{ OpInfo: modifiers.OpInfo{
DriverName: dialect.DriverName(), DriverName: dialect.DriverName(),
Method: "Insert", Method: "Insert",
@ -828,12 +828,12 @@ func buildUpdateQuery(
for i, k := range keys { for i, k := range keys {
recordValue := recordMap[k] recordValue := recordMap[k]
modifier := info.ByName(k).Modifier valueFn := info.ByName(k).Modifier.Value
if modifier != nil { if valueFn != nil {
recordValue = modifiers.AttrWrapper{ recordValue = modifiers.AttrValueWrapper{
Ctx: ctx, Ctx: ctx,
Attr: recordValue, Attr: recordValue,
Modifier: modifier, ValueFn: valueFn,
OpInfo: modifiers.OpInfo{ OpInfo: modifiers.OpInfo{
DriverName: dialect.DriverName(), DriverName: dialect.DriverName(),
Method: "Update", Method: "Update",
@ -1033,11 +1033,11 @@ func getScanArgsForNestedStructs(
if fieldInfo.Valid { if fieldInfo.Valid {
valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface() valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface()
if fieldInfo.Modifier != nil { if fieldInfo.Modifier.Scan != nil {
valueScanner = &modifiers.AttrWrapper{ valueScanner = &modifiers.AttrScanWrapper{
Ctx: ctx, Ctx: ctx,
Attr: valueScanner, AttrPtr: valueScanner,
Modifier: fieldInfo.Modifier, ScanFn: fieldInfo.Modifier.Scan,
OpInfo: modifiers.OpInfo{ OpInfo: modifiers.OpInfo{
DriverName: dialect.DriverName(), DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks // We will not differentiate between Query, QueryOne and QueryChunks
@ -1063,11 +1063,11 @@ func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string,
valueScanner := nopScannerValue valueScanner := nopScannerValue
if fieldInfo.Valid { if fieldInfo.Valid {
valueScanner = v.Field(fieldInfo.Index).Addr().Interface() valueScanner = v.Field(fieldInfo.Index).Addr().Interface()
if fieldInfo.Modifier != nil { if fieldInfo.Modifier.Scan != nil {
valueScanner = &modifiers.AttrWrapper{ valueScanner = &modifiers.AttrScanWrapper{
Ctx: ctx, Ctx: ctx,
Attr: valueScanner, AttrPtr: valueScanner,
Modifier: fieldInfo.Modifier, ScanFn: fieldInfo.Modifier.Scan,
OpInfo: modifiers.OpInfo{ OpInfo: modifiers.OpInfo{
DriverName: dialect.DriverName(), DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks // We will not differentiate between Query, QueryOne and QueryChunks

View File

@ -2802,10 +2802,10 @@ func getUserByID(db DBAdapter, dialect Dialect, result *user, id uint) error {
modifier, _ := modifiers.LoadGlobalModifier("json") modifier, _ := modifiers.LoadGlobalModifier("json")
value := modifiers.AttrWrapper{ value := modifiers.AttrScanWrapper{
Ctx: context.TODO(), Ctx: context.TODO(),
Attr: &result.Address, AttrPtr: &result.Address,
Modifier: modifier, ScanFn: modifier.Scan,
OpInfo: modifiers.OpInfo{ OpInfo: modifiers.OpInfo{
DriverName: dialect.DriverName(), DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks // We will not differentiate between Query, QueryOne and QueryChunks