From 8cba3efa2d6226fb45a3915669c63c3fd849e670 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Mon, 26 Sep 2022 01:37:23 -0300 Subject: [PATCH] Refactor modifiers so it is represented by a struct --- internal/modifiers/attr_wrapper.go | 36 ++++++++---- internal/modifiers/attr_wrapper_test.go | 55 ++++++++++-------- internal/modifiers/contract.go | 27 +++++++-- internal/modifiers/global_modifiers.go | 4 +- internal/modifiers/global_modifiers_test.go | 30 ++++++---- internal/modifiers/json_modifier.go | 62 ++++++++++----------- internal/modifiers/json_modifier_test.go | 4 +- internal/modifiers/mocks.go | 38 ------------- ksql.go | 44 +++++++-------- test_adapters.go | 8 +-- 10 files changed, 157 insertions(+), 151 deletions(-) delete mode 100644 internal/modifiers/mocks.go diff --git a/internal/modifiers/attr_wrapper.go b/internal/modifiers/attr_wrapper.go index 8800f58..b17f4b9 100644 --- a/internal/modifiers/attr_wrapper.go +++ b/internal/modifiers/attr_wrapper.go @@ -5,28 +5,40 @@ import ( "database/sql/driver" ) -// AttrWrapper is the wrapper that allow us to intercept the Scan and Value processes +// AttrScanWrapper is the wrapper that allow us to intercept the Scan process // so we can run the modifiers instead of allowing the database driver to use // its default behavior. // -// For that this struct implements both the `sql.Scanner` and `sql.Valuer` interfaces. -type AttrWrapper struct { +// For that this struct implements the `sql.Scanner` interface +type AttrScanWrapper struct { Ctx context.Context - // When Scanning this value should be a pointer to the attribute - // and when "Valuing" it should just be the actual value - Attr interface{} + AttrPtr interface{} - Modifier AttrModifier - OpInfo OpInfo + ScanFn AttrScanner + OpInfo OpInfo } // Scan implements the sql.Scanner interface -func (a AttrWrapper) Scan(dbValue interface{}) error { - return a.Modifier.AttrScan(a.Ctx, a.OpInfo, a.Attr, dbValue) +func (a AttrScanWrapper) Scan(dbValue interface{}) error { + return a.ScanFn(a.Ctx, a.OpInfo, a.AttrPtr, dbValue) +} + +// AttrValueWrapper is the wrapper that allow us to intercept the "Valuing" process +// so we can run the modifiers instead of allowing the database driver to use +// its default behavior. +// +// For that this struct implements the `sql.Valuer` interface +type AttrValueWrapper struct { + Ctx context.Context + + Attr interface{} + + ValueFn AttrValuer + OpInfo OpInfo } // Value implements the sql.Valuer interface -func (a AttrWrapper) Value() (driver.Value, error) { - return a.Modifier.AttrValue(a.Ctx, a.OpInfo, a.Attr) +func (a AttrValueWrapper) Value() (driver.Value, error) { + return a.ValueFn(a.Ctx, a.OpInfo, a.Attr) } diff --git a/internal/modifiers/attr_wrapper_test.go b/internal/modifiers/attr_wrapper_test.go index bb9dfd1..a10d695 100644 --- a/internal/modifiers/attr_wrapper_test.go +++ b/internal/modifiers/attr_wrapper_test.go @@ -8,30 +8,20 @@ import ( tt "github.com/vingarcia/ksql/internal/testtools" ) -func TestAttrWrapper(t *testing.T) { +func TestAttrScanWrapper(t *testing.T) { ctx := context.Background() var scanArgs map[string]interface{} - var valueArgs map[string]interface{} - wrapper := AttrWrapper{ - Ctx: ctx, - Attr: "fakeAttr", - Modifier: AttrModifierMock{ - AttrScanFn: func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error { - scanArgs = map[string]interface{}{ - "opInfo": opInfo, - "attrPtr": attrPtr, - "dbValue": dbValue, - } - return errors.New("fakeScanErrMsg") - }, - AttrValueFn: func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) { - valueArgs = map[string]interface{}{ - "opInfo": opInfo, - "inputValue": inputValue, - } - return "fakeOutputValue", errors.New("fakeValueErrMsg") - }, + wrapper := AttrScanWrapper{ + Ctx: ctx, + AttrPtr: "fakeAttrPtr", + ScanFn: func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error { + scanArgs = map[string]interface{}{ + "opInfo": opInfo, + "attrPtr": attrPtr, + "dbValue": dbValue, + } + return errors.New("fakeScanErrMsg") }, OpInfo: OpInfo{ Method: "fakeMethod", @@ -46,9 +36,30 @@ func TestAttrWrapper(t *testing.T) { Method: "fakeMethod", DriverName: "fakeDriverName", }, - "attrPtr": "fakeAttr", + "attrPtr": "fakeAttrPtr", "dbValue": "fakeDbValue", }) +} + +func TestAttrWrapper(t *testing.T) { + ctx := context.Background() + + var valueArgs map[string]interface{} + wrapper := AttrValueWrapper{ + Ctx: ctx, + Attr: "fakeAttr", + ValueFn: func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) { + valueArgs = map[string]interface{}{ + "opInfo": opInfo, + "inputValue": inputValue, + } + return "fakeOutputValue", errors.New("fakeValueErrMsg") + }, + OpInfo: OpInfo{ + Method: "fakeMethod", + DriverName: "fakeDriverName", + }, + } value, err := wrapper.Value() tt.AssertErrContains(t, err, "fakeValueErrMsg") diff --git a/internal/modifiers/contract.go b/internal/modifiers/contract.go index 60243c4..69356e6 100644 --- a/internal/modifiers/contract.go +++ b/internal/modifiers/contract.go @@ -1,13 +1,30 @@ package modifiers -import "context" +import ( + "context" +) -// AttrModifier describes the two operations required to serialize and deserialize an object from the database. -type AttrModifier interface { - AttrScan(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error - AttrValue(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) +// AttrModifier informs KSQL how to use this modifier +type AttrModifier struct { + // The following attributes will tell KSQL to + // leave this attribute out of insertions, updates, + // and queries respectively. + SkipOnInsert bool + SkipOnUpdate bool + SkipOnQuery bool + + // Implement these functions if you want to override the default Scan/Value behavior + // for the target attribute. + Scan AttrScanner + Value AttrValuer } +// AttrScanner describes the operation of deserializing an object received from the database. +type AttrScanner func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error + +// AttrValuer describes the operation of serializing an object when saving it to the database. +type AttrValuer func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) + // OpInfo contains information that might be used by a modifier to determine how it should behave. type OpInfo struct { // A string version of the name of one of diff --git a/internal/modifiers/global_modifiers.go b/internal/modifiers/global_modifiers.go index 73e6257..f904a1b 100644 --- a/internal/modifiers/global_modifiers.go +++ b/internal/modifiers/global_modifiers.go @@ -10,7 +10,7 @@ var modifiers sync.Map func init() { // These are the builtin modifiers - modifiers.Store("json", jsonModifier{}) + modifiers.Store("json", jsonModifier) } // RegisterAttrModifier allow users to add custom modifiers on startup @@ -28,7 +28,7 @@ func LoadGlobalModifier(key string) (AttrModifier, error) { rawModifier, _ := modifiers.Load(key) modifier, ok := rawModifier.(AttrModifier) if !ok { - return nil, fmt.Errorf("no modifier found with name '%s'", key) + return AttrModifier{}, fmt.Errorf("no modifier found with name '%s'", key) } return modifier, nil diff --git a/internal/modifiers/global_modifiers_test.go b/internal/modifiers/global_modifiers_test.go index 4607cbc..b1d3b1d 100644 --- a/internal/modifiers/global_modifiers_test.go +++ b/internal/modifiers/global_modifiers_test.go @@ -8,28 +8,36 @@ import ( func TestRegisterAttrModifier(t *testing.T) { t.Run("should register new modifiers correctly", func(t *testing.T) { - modifier1 := AttrModifierMock{} - modifier2 := AttrModifierMock{} + modifier1 := AttrModifier{ + SkipOnUpdate: true, + } + modifier2 := AttrModifier{ + SkipOnInsert: true, + } - RegisterAttrModifier("fakeModifierName1", &modifier1) - RegisterAttrModifier("fakeModifierName2", &modifier2) + RegisterAttrModifier("fakeModifierName1", modifier1) + RegisterAttrModifier("fakeModifierName2", modifier2) mod, err := LoadGlobalModifier("fakeModifierName1") tt.AssertNoErr(t, err) - tt.AssertEqual(t, mod, &modifier1) + tt.AssertEqual(t, mod, modifier1) mod, err = LoadGlobalModifier("fakeModifierName2") tt.AssertNoErr(t, err) - tt.AssertEqual(t, mod, &modifier2) + tt.AssertEqual(t, mod, modifier2) }) t.Run("should panic registering a modifier and the name already exists", func(t *testing.T) { - modifier1 := AttrModifierMock{} - modifier2 := AttrModifierMock{} + modifier1 := AttrModifier{ + SkipOnUpdate: true, + } + modifier2 := AttrModifier{ + SkipOnInsert: true, + } - RegisterAttrModifier("fakeModifierName", &modifier1) + RegisterAttrModifier("fakeModifierName", modifier1) panicPayload := tt.PanicHandler(func() { - RegisterAttrModifier("fakeModifierName", &modifier2) + RegisterAttrModifier("fakeModifierName", modifier2) }) err, ok := panicPayload.(error) @@ -40,6 +48,6 @@ func TestRegisterAttrModifier(t *testing.T) { t.Run("should return an error when loading an inexistent modifier", func(t *testing.T) { mod, err := LoadGlobalModifier("nonExistentModifier") tt.AssertErrContains(t, err, "nonExistentModifier") - tt.AssertEqual(t, mod, nil) + tt.AssertEqual(t, mod, AttrModifier{}) }) } diff --git a/internal/modifiers/json_modifier.go b/internal/modifiers/json_modifier.go index 2731d3a..7e5cb61 100644 --- a/internal/modifiers/json_modifier.go +++ b/internal/modifiers/json_modifier.go @@ -7,40 +7,36 @@ import ( "reflect" ) -// This type was created to make it easier to adapt -// input attributes to be convertible to and from JSON -// before sending or receiving it from the database. -type jsonModifier struct{} +// This modifier serializes objects as JSON when +// sending it to the database and decodes +// them when receiving. +var jsonModifier = AttrModifier{ + Scan: func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error { + if dbValue == nil { + v := reflect.ValueOf(attrPtr) + // Set the struct to its 0 value just like json.Unmarshal + // does for nil attributes: + v.Elem().Set(reflect.Zero(reflect.TypeOf(attrPtr).Elem())) + return nil + } -// Scan Implements the Scanner interface in order to load -// this field from the JSON stored in the database -func (j jsonModifier) AttrScan(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error { - if dbValue == nil { - v := reflect.ValueOf(attrPtr) - // Set the struct to its 0 value just like json.Unmarshal - // does for nil attributes: - v.Elem().Set(reflect.Zero(reflect.TypeOf(attrPtr).Elem())) - return nil - } + // Required since sqlite3 returns strings not bytes + if v, ok := dbValue.(string); ok { + dbValue = []byte(v) + } - // Required since sqlite3 returns strings not bytes - if v, ok := dbValue.(string); ok { - dbValue = []byte(v) - } + rawJSON, ok := dbValue.([]byte) + if !ok { + return fmt.Errorf("unexpected type received to Scan: %T", dbValue) + } + return json.Unmarshal(rawJSON, attrPtr) + }, - rawJSON, ok := dbValue.([]byte) - if !ok { - return fmt.Errorf("unexpected type received to Scan: %T", dbValue) - } - return json.Unmarshal(rawJSON, attrPtr) -} - -// Value Implements the Valuer interface in order to save -// this field as JSON on the database. -func (j jsonModifier) AttrValue(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) { - b, err := json.Marshal(inputValue) - if opInfo.DriverName == "sqlserver" { - return string(b), err - } - return b, err + Value: func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) { + b, err := json.Marshal(inputValue) + if opInfo.DriverName == "sqlserver" { + return string(b), err + } + return b, err + }, } diff --git a/internal/modifiers/json_modifier_test.go b/internal/modifiers/json_modifier_test.go index 968e468..335356e 100644 --- a/internal/modifiers/json_modifier_test.go +++ b/internal/modifiers/json_modifier_test.go @@ -51,7 +51,7 @@ func TestAttrScan(t *testing.T) { fakeAttr := FakeAttr{ Foo: "notZeroValue", } - err := jsonModifier{}.AttrScan(ctx, OpInfo{}, &fakeAttr, test.dbInput) + err := jsonModifier.Scan(ctx, OpInfo{}, &fakeAttr, test.dbInput) if test.expectErrToContain != nil { tt.AssertErrContains(t, err, test.expectErrToContain...) t.Skip() @@ -109,7 +109,7 @@ func TestAttrValue(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - output, err := jsonModifier{}.AttrValue(ctx, test.opInfoInput, test.attrValue) + output, err := jsonModifier.Value(ctx, test.opInfoInput, test.attrValue) if test.expectErrToContain != nil { tt.AssertErrContains(t, err, test.expectErrToContain...) t.Skip() diff --git a/internal/modifiers/mocks.go b/internal/modifiers/mocks.go deleted file mode 100644 index d023dc9..0000000 --- a/internal/modifiers/mocks.go +++ /dev/null @@ -1,38 +0,0 @@ -package modifiers - -import "context" - -// AttrModifierMock mocks the modifiers.AttrModifier interface -type AttrModifierMock struct { - AttrScanFn func( - ctx context.Context, - opInfo OpInfo, - attrPtr interface{}, - dbValue interface{}, - ) error - - AttrValueFn func( - ctx context.Context, - opInfo OpInfo, - inputValue interface{}, - ) (outputValue interface{}, _ error) -} - -// AttrScan mocks the AttrScan method -func (a AttrModifierMock) AttrScan( - ctx context.Context, - opInfo OpInfo, - attrPtr interface{}, - dbValue interface{}, -) error { - return a.AttrScanFn(ctx, opInfo, attrPtr, dbValue) -} - -// AttrValue mocks the AttrValue method -func (a AttrModifierMock) AttrValue( - ctx context.Context, - opInfo OpInfo, - inputValue interface{}, -) (outputValue interface{}, _ error) { - return a.AttrValueFn(ctx, opInfo, inputValue) -} diff --git a/ksql.go b/ksql.go index f75c4c8..22ba52f 100644 --- a/ksql.go +++ b/ksql.go @@ -719,12 +719,12 @@ func buildInsertQuery( recordValue := recordMap[col] params[i] = recordValue - modifier := info.ByName(col).Modifier - if modifier != nil { - params[i] = modifiers.AttrWrapper{ - Ctx: ctx, - Attr: recordValue, - Modifier: modifier, + valueFn := info.ByName(col).Modifier.Value + if valueFn != nil { + params[i] = modifiers.AttrValueWrapper{ + Ctx: ctx, + Attr: recordValue, + ValueFn: valueFn, OpInfo: modifiers.OpInfo{ DriverName: dialect.DriverName(), Method: "Insert", @@ -828,12 +828,12 @@ func buildUpdateQuery( for i, k := range keys { recordValue := recordMap[k] - modifier := info.ByName(k).Modifier - if modifier != nil { - recordValue = modifiers.AttrWrapper{ - Ctx: ctx, - Attr: recordValue, - Modifier: modifier, + valueFn := info.ByName(k).Modifier.Value + if valueFn != nil { + recordValue = modifiers.AttrValueWrapper{ + Ctx: ctx, + Attr: recordValue, + ValueFn: valueFn, OpInfo: modifiers.OpInfo{ DriverName: dialect.DriverName(), Method: "Update", @@ -1033,11 +1033,11 @@ func getScanArgsForNestedStructs( if fieldInfo.Valid { valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface() - if fieldInfo.Modifier != nil { - valueScanner = &modifiers.AttrWrapper{ - Ctx: ctx, - Attr: valueScanner, - Modifier: fieldInfo.Modifier, + if fieldInfo.Modifier.Scan != nil { + valueScanner = &modifiers.AttrScanWrapper{ + Ctx: ctx, + AttrPtr: valueScanner, + ScanFn: fieldInfo.Modifier.Scan, OpInfo: modifiers.OpInfo{ DriverName: dialect.DriverName(), // We will not differentiate between Query, QueryOne and QueryChunks @@ -1063,11 +1063,11 @@ func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string, valueScanner := nopScannerValue if fieldInfo.Valid { valueScanner = v.Field(fieldInfo.Index).Addr().Interface() - if fieldInfo.Modifier != nil { - valueScanner = &modifiers.AttrWrapper{ - Ctx: ctx, - Attr: valueScanner, - Modifier: fieldInfo.Modifier, + if fieldInfo.Modifier.Scan != nil { + valueScanner = &modifiers.AttrScanWrapper{ + Ctx: ctx, + AttrPtr: valueScanner, + ScanFn: fieldInfo.Modifier.Scan, OpInfo: modifiers.OpInfo{ DriverName: dialect.DriverName(), // We will not differentiate between Query, QueryOne and QueryChunks diff --git a/test_adapters.go b/test_adapters.go index b3f4520..30474ad 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -2802,10 +2802,10 @@ func getUserByID(db DBAdapter, dialect Dialect, result *user, id uint) error { modifier, _ := modifiers.LoadGlobalModifier("json") - value := modifiers.AttrWrapper{ - Ctx: context.TODO(), - Attr: &result.Address, - Modifier: modifier, + value := modifiers.AttrScanWrapper{ + Ctx: context.TODO(), + AttrPtr: &result.Address, + ScanFn: modifier.Scan, OpInfo: modifiers.OpInfo{ DriverName: dialect.DriverName(), // We will not differentiate between Query, QueryOne and QueryChunks