mirror of https://github.com/VinGarcia/ksql.git
Refactor modifiers so it is represented by a struct
parent
7661ba0314
commit
8cba3efa2d
|
@ -5,28 +5,40 @@ import (
|
|||
"database/sql/driver"
|
||||
)
|
||||
|
||||
// AttrWrapper is the wrapper that allow us to intercept the Scan and Value processes
|
||||
// AttrScanWrapper is the wrapper that allow us to intercept the Scan process
|
||||
// so we can run the modifiers instead of allowing the database driver to use
|
||||
// its default behavior.
|
||||
//
|
||||
// For that this struct implements both the `sql.Scanner` and `sql.Valuer` interfaces.
|
||||
type AttrWrapper struct {
|
||||
// For that this struct implements the `sql.Scanner` interface
|
||||
type AttrScanWrapper struct {
|
||||
Ctx context.Context
|
||||
|
||||
// When Scanning this value should be a pointer to the attribute
|
||||
// and when "Valuing" it should just be the actual value
|
||||
Attr interface{}
|
||||
AttrPtr interface{}
|
||||
|
||||
Modifier AttrModifier
|
||||
OpInfo OpInfo
|
||||
ScanFn AttrScanner
|
||||
OpInfo OpInfo
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface
|
||||
func (a AttrWrapper) Scan(dbValue interface{}) error {
|
||||
return a.Modifier.AttrScan(a.Ctx, a.OpInfo, a.Attr, dbValue)
|
||||
func (a AttrScanWrapper) Scan(dbValue interface{}) error {
|
||||
return a.ScanFn(a.Ctx, a.OpInfo, a.AttrPtr, dbValue)
|
||||
}
|
||||
|
||||
// AttrValueWrapper is the wrapper that allow us to intercept the "Valuing" process
|
||||
// so we can run the modifiers instead of allowing the database driver to use
|
||||
// its default behavior.
|
||||
//
|
||||
// For that this struct implements the `sql.Valuer` interface
|
||||
type AttrValueWrapper struct {
|
||||
Ctx context.Context
|
||||
|
||||
Attr interface{}
|
||||
|
||||
ValueFn AttrValuer
|
||||
OpInfo OpInfo
|
||||
}
|
||||
|
||||
// Value implements the sql.Valuer interface
|
||||
func (a AttrWrapper) Value() (driver.Value, error) {
|
||||
return a.Modifier.AttrValue(a.Ctx, a.OpInfo, a.Attr)
|
||||
func (a AttrValueWrapper) Value() (driver.Value, error) {
|
||||
return a.ValueFn(a.Ctx, a.OpInfo, a.Attr)
|
||||
}
|
||||
|
|
|
@ -8,30 +8,20 @@ import (
|
|||
tt "github.com/vingarcia/ksql/internal/testtools"
|
||||
)
|
||||
|
||||
func TestAttrWrapper(t *testing.T) {
|
||||
func TestAttrScanWrapper(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
var scanArgs map[string]interface{}
|
||||
var valueArgs map[string]interface{}
|
||||
wrapper := AttrWrapper{
|
||||
Ctx: ctx,
|
||||
Attr: "fakeAttr",
|
||||
Modifier: AttrModifierMock{
|
||||
AttrScanFn: func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error {
|
||||
scanArgs = map[string]interface{}{
|
||||
"opInfo": opInfo,
|
||||
"attrPtr": attrPtr,
|
||||
"dbValue": dbValue,
|
||||
}
|
||||
return errors.New("fakeScanErrMsg")
|
||||
},
|
||||
AttrValueFn: func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
|
||||
valueArgs = map[string]interface{}{
|
||||
"opInfo": opInfo,
|
||||
"inputValue": inputValue,
|
||||
}
|
||||
return "fakeOutputValue", errors.New("fakeValueErrMsg")
|
||||
},
|
||||
wrapper := AttrScanWrapper{
|
||||
Ctx: ctx,
|
||||
AttrPtr: "fakeAttrPtr",
|
||||
ScanFn: func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error {
|
||||
scanArgs = map[string]interface{}{
|
||||
"opInfo": opInfo,
|
||||
"attrPtr": attrPtr,
|
||||
"dbValue": dbValue,
|
||||
}
|
||||
return errors.New("fakeScanErrMsg")
|
||||
},
|
||||
OpInfo: OpInfo{
|
||||
Method: "fakeMethod",
|
||||
|
@ -46,9 +36,30 @@ func TestAttrWrapper(t *testing.T) {
|
|||
Method: "fakeMethod",
|
||||
DriverName: "fakeDriverName",
|
||||
},
|
||||
"attrPtr": "fakeAttr",
|
||||
"attrPtr": "fakeAttrPtr",
|
||||
"dbValue": "fakeDbValue",
|
||||
})
|
||||
}
|
||||
|
||||
func TestAttrWrapper(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
var valueArgs map[string]interface{}
|
||||
wrapper := AttrValueWrapper{
|
||||
Ctx: ctx,
|
||||
Attr: "fakeAttr",
|
||||
ValueFn: func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
|
||||
valueArgs = map[string]interface{}{
|
||||
"opInfo": opInfo,
|
||||
"inputValue": inputValue,
|
||||
}
|
||||
return "fakeOutputValue", errors.New("fakeValueErrMsg")
|
||||
},
|
||||
OpInfo: OpInfo{
|
||||
Method: "fakeMethod",
|
||||
DriverName: "fakeDriverName",
|
||||
},
|
||||
}
|
||||
|
||||
value, err := wrapper.Value()
|
||||
tt.AssertErrContains(t, err, "fakeValueErrMsg")
|
||||
|
|
|
@ -1,13 +1,30 @@
|
|||
package modifiers
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// AttrModifier describes the two operations required to serialize and deserialize an object from the database.
|
||||
type AttrModifier interface {
|
||||
AttrScan(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error
|
||||
AttrValue(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error)
|
||||
// AttrModifier informs KSQL how to use this modifier
|
||||
type AttrModifier struct {
|
||||
// The following attributes will tell KSQL to
|
||||
// leave this attribute out of insertions, updates,
|
||||
// and queries respectively.
|
||||
SkipOnInsert bool
|
||||
SkipOnUpdate bool
|
||||
SkipOnQuery bool
|
||||
|
||||
// Implement these functions if you want to override the default Scan/Value behavior
|
||||
// for the target attribute.
|
||||
Scan AttrScanner
|
||||
Value AttrValuer
|
||||
}
|
||||
|
||||
// AttrScanner describes the operation of deserializing an object received from the database.
|
||||
type AttrScanner func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error
|
||||
|
||||
// AttrValuer describes the operation of serializing an object when saving it to the database.
|
||||
type AttrValuer func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error)
|
||||
|
||||
// OpInfo contains information that might be used by a modifier to determine how it should behave.
|
||||
type OpInfo struct {
|
||||
// A string version of the name of one of
|
||||
|
|
|
@ -10,7 +10,7 @@ var modifiers sync.Map
|
|||
|
||||
func init() {
|
||||
// These are the builtin modifiers
|
||||
modifiers.Store("json", jsonModifier{})
|
||||
modifiers.Store("json", jsonModifier)
|
||||
}
|
||||
|
||||
// RegisterAttrModifier allow users to add custom modifiers on startup
|
||||
|
@ -28,7 +28,7 @@ func LoadGlobalModifier(key string) (AttrModifier, error) {
|
|||
rawModifier, _ := modifiers.Load(key)
|
||||
modifier, ok := rawModifier.(AttrModifier)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no modifier found with name '%s'", key)
|
||||
return AttrModifier{}, fmt.Errorf("no modifier found with name '%s'", key)
|
||||
}
|
||||
|
||||
return modifier, nil
|
||||
|
|
|
@ -8,28 +8,36 @@ import (
|
|||
|
||||
func TestRegisterAttrModifier(t *testing.T) {
|
||||
t.Run("should register new modifiers correctly", func(t *testing.T) {
|
||||
modifier1 := AttrModifierMock{}
|
||||
modifier2 := AttrModifierMock{}
|
||||
modifier1 := AttrModifier{
|
||||
SkipOnUpdate: true,
|
||||
}
|
||||
modifier2 := AttrModifier{
|
||||
SkipOnInsert: true,
|
||||
}
|
||||
|
||||
RegisterAttrModifier("fakeModifierName1", &modifier1)
|
||||
RegisterAttrModifier("fakeModifierName2", &modifier2)
|
||||
RegisterAttrModifier("fakeModifierName1", modifier1)
|
||||
RegisterAttrModifier("fakeModifierName2", modifier2)
|
||||
|
||||
mod, err := LoadGlobalModifier("fakeModifierName1")
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, mod, &modifier1)
|
||||
tt.AssertEqual(t, mod, modifier1)
|
||||
|
||||
mod, err = LoadGlobalModifier("fakeModifierName2")
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, mod, &modifier2)
|
||||
tt.AssertEqual(t, mod, modifier2)
|
||||
})
|
||||
|
||||
t.Run("should panic registering a modifier and the name already exists", func(t *testing.T) {
|
||||
modifier1 := AttrModifierMock{}
|
||||
modifier2 := AttrModifierMock{}
|
||||
modifier1 := AttrModifier{
|
||||
SkipOnUpdate: true,
|
||||
}
|
||||
modifier2 := AttrModifier{
|
||||
SkipOnInsert: true,
|
||||
}
|
||||
|
||||
RegisterAttrModifier("fakeModifierName", &modifier1)
|
||||
RegisterAttrModifier("fakeModifierName", modifier1)
|
||||
panicPayload := tt.PanicHandler(func() {
|
||||
RegisterAttrModifier("fakeModifierName", &modifier2)
|
||||
RegisterAttrModifier("fakeModifierName", modifier2)
|
||||
})
|
||||
|
||||
err, ok := panicPayload.(error)
|
||||
|
@ -40,6 +48,6 @@ func TestRegisterAttrModifier(t *testing.T) {
|
|||
t.Run("should return an error when loading an inexistent modifier", func(t *testing.T) {
|
||||
mod, err := LoadGlobalModifier("nonExistentModifier")
|
||||
tt.AssertErrContains(t, err, "nonExistentModifier")
|
||||
tt.AssertEqual(t, mod, nil)
|
||||
tt.AssertEqual(t, mod, AttrModifier{})
|
||||
})
|
||||
}
|
||||
|
|
|
@ -7,40 +7,36 @@ import (
|
|||
"reflect"
|
||||
)
|
||||
|
||||
// This type was created to make it easier to adapt
|
||||
// input attributes to be convertible to and from JSON
|
||||
// before sending or receiving it from the database.
|
||||
type jsonModifier struct{}
|
||||
// This modifier serializes objects as JSON when
|
||||
// sending it to the database and decodes
|
||||
// them when receiving.
|
||||
var jsonModifier = AttrModifier{
|
||||
Scan: func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error {
|
||||
if dbValue == nil {
|
||||
v := reflect.ValueOf(attrPtr)
|
||||
// Set the struct to its 0 value just like json.Unmarshal
|
||||
// does for nil attributes:
|
||||
v.Elem().Set(reflect.Zero(reflect.TypeOf(attrPtr).Elem()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Scan Implements the Scanner interface in order to load
|
||||
// this field from the JSON stored in the database
|
||||
func (j jsonModifier) AttrScan(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error {
|
||||
if dbValue == nil {
|
||||
v := reflect.ValueOf(attrPtr)
|
||||
// Set the struct to its 0 value just like json.Unmarshal
|
||||
// does for nil attributes:
|
||||
v.Elem().Set(reflect.Zero(reflect.TypeOf(attrPtr).Elem()))
|
||||
return nil
|
||||
}
|
||||
// Required since sqlite3 returns strings not bytes
|
||||
if v, ok := dbValue.(string); ok {
|
||||
dbValue = []byte(v)
|
||||
}
|
||||
|
||||
// Required since sqlite3 returns strings not bytes
|
||||
if v, ok := dbValue.(string); ok {
|
||||
dbValue = []byte(v)
|
||||
}
|
||||
rawJSON, ok := dbValue.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type received to Scan: %T", dbValue)
|
||||
}
|
||||
return json.Unmarshal(rawJSON, attrPtr)
|
||||
},
|
||||
|
||||
rawJSON, ok := dbValue.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type received to Scan: %T", dbValue)
|
||||
}
|
||||
return json.Unmarshal(rawJSON, attrPtr)
|
||||
}
|
||||
|
||||
// Value Implements the Valuer interface in order to save
|
||||
// this field as JSON on the database.
|
||||
func (j jsonModifier) AttrValue(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
|
||||
b, err := json.Marshal(inputValue)
|
||||
if opInfo.DriverName == "sqlserver" {
|
||||
return string(b), err
|
||||
}
|
||||
return b, err
|
||||
Value: func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
|
||||
b, err := json.Marshal(inputValue)
|
||||
if opInfo.DriverName == "sqlserver" {
|
||||
return string(b), err
|
||||
}
|
||||
return b, err
|
||||
},
|
||||
}
|
||||
|
|
|
@ -51,7 +51,7 @@ func TestAttrScan(t *testing.T) {
|
|||
fakeAttr := FakeAttr{
|
||||
Foo: "notZeroValue",
|
||||
}
|
||||
err := jsonModifier{}.AttrScan(ctx, OpInfo{}, &fakeAttr, test.dbInput)
|
||||
err := jsonModifier.Scan(ctx, OpInfo{}, &fakeAttr, test.dbInput)
|
||||
if test.expectErrToContain != nil {
|
||||
tt.AssertErrContains(t, err, test.expectErrToContain...)
|
||||
t.Skip()
|
||||
|
@ -109,7 +109,7 @@ func TestAttrValue(t *testing.T) {
|
|||
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
output, err := jsonModifier{}.AttrValue(ctx, test.opInfoInput, test.attrValue)
|
||||
output, err := jsonModifier.Value(ctx, test.opInfoInput, test.attrValue)
|
||||
if test.expectErrToContain != nil {
|
||||
tt.AssertErrContains(t, err, test.expectErrToContain...)
|
||||
t.Skip()
|
||||
|
|
|
@ -1,38 +0,0 @@
|
|||
package modifiers
|
||||
|
||||
import "context"
|
||||
|
||||
// AttrModifierMock mocks the modifiers.AttrModifier interface
|
||||
type AttrModifierMock struct {
|
||||
AttrScanFn func(
|
||||
ctx context.Context,
|
||||
opInfo OpInfo,
|
||||
attrPtr interface{},
|
||||
dbValue interface{},
|
||||
) error
|
||||
|
||||
AttrValueFn func(
|
||||
ctx context.Context,
|
||||
opInfo OpInfo,
|
||||
inputValue interface{},
|
||||
) (outputValue interface{}, _ error)
|
||||
}
|
||||
|
||||
// AttrScan mocks the AttrScan method
|
||||
func (a AttrModifierMock) AttrScan(
|
||||
ctx context.Context,
|
||||
opInfo OpInfo,
|
||||
attrPtr interface{},
|
||||
dbValue interface{},
|
||||
) error {
|
||||
return a.AttrScanFn(ctx, opInfo, attrPtr, dbValue)
|
||||
}
|
||||
|
||||
// AttrValue mocks the AttrValue method
|
||||
func (a AttrModifierMock) AttrValue(
|
||||
ctx context.Context,
|
||||
opInfo OpInfo,
|
||||
inputValue interface{},
|
||||
) (outputValue interface{}, _ error) {
|
||||
return a.AttrValueFn(ctx, opInfo, inputValue)
|
||||
}
|
44
ksql.go
44
ksql.go
|
@ -719,12 +719,12 @@ func buildInsertQuery(
|
|||
recordValue := recordMap[col]
|
||||
params[i] = recordValue
|
||||
|
||||
modifier := info.ByName(col).Modifier
|
||||
if modifier != nil {
|
||||
params[i] = modifiers.AttrWrapper{
|
||||
Ctx: ctx,
|
||||
Attr: recordValue,
|
||||
Modifier: modifier,
|
||||
valueFn := info.ByName(col).Modifier.Value
|
||||
if valueFn != nil {
|
||||
params[i] = modifiers.AttrValueWrapper{
|
||||
Ctx: ctx,
|
||||
Attr: recordValue,
|
||||
ValueFn: valueFn,
|
||||
OpInfo: modifiers.OpInfo{
|
||||
DriverName: dialect.DriverName(),
|
||||
Method: "Insert",
|
||||
|
@ -828,12 +828,12 @@ func buildUpdateQuery(
|
|||
for i, k := range keys {
|
||||
recordValue := recordMap[k]
|
||||
|
||||
modifier := info.ByName(k).Modifier
|
||||
if modifier != nil {
|
||||
recordValue = modifiers.AttrWrapper{
|
||||
Ctx: ctx,
|
||||
Attr: recordValue,
|
||||
Modifier: modifier,
|
||||
valueFn := info.ByName(k).Modifier.Value
|
||||
if valueFn != nil {
|
||||
recordValue = modifiers.AttrValueWrapper{
|
||||
Ctx: ctx,
|
||||
Attr: recordValue,
|
||||
ValueFn: valueFn,
|
||||
OpInfo: modifiers.OpInfo{
|
||||
DriverName: dialect.DriverName(),
|
||||
Method: "Update",
|
||||
|
@ -1033,11 +1033,11 @@ func getScanArgsForNestedStructs(
|
|||
if fieldInfo.Valid {
|
||||
valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface()
|
||||
|
||||
if fieldInfo.Modifier != nil {
|
||||
valueScanner = &modifiers.AttrWrapper{
|
||||
Ctx: ctx,
|
||||
Attr: valueScanner,
|
||||
Modifier: fieldInfo.Modifier,
|
||||
if fieldInfo.Modifier.Scan != nil {
|
||||
valueScanner = &modifiers.AttrScanWrapper{
|
||||
Ctx: ctx,
|
||||
AttrPtr: valueScanner,
|
||||
ScanFn: fieldInfo.Modifier.Scan,
|
||||
OpInfo: modifiers.OpInfo{
|
||||
DriverName: dialect.DriverName(),
|
||||
// We will not differentiate between Query, QueryOne and QueryChunks
|
||||
|
@ -1063,11 +1063,11 @@ func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string,
|
|||
valueScanner := nopScannerValue
|
||||
if fieldInfo.Valid {
|
||||
valueScanner = v.Field(fieldInfo.Index).Addr().Interface()
|
||||
if fieldInfo.Modifier != nil {
|
||||
valueScanner = &modifiers.AttrWrapper{
|
||||
Ctx: ctx,
|
||||
Attr: valueScanner,
|
||||
Modifier: fieldInfo.Modifier,
|
||||
if fieldInfo.Modifier.Scan != nil {
|
||||
valueScanner = &modifiers.AttrScanWrapper{
|
||||
Ctx: ctx,
|
||||
AttrPtr: valueScanner,
|
||||
ScanFn: fieldInfo.Modifier.Scan,
|
||||
OpInfo: modifiers.OpInfo{
|
||||
DriverName: dialect.DriverName(),
|
||||
// We will not differentiate between Query, QueryOne and QueryChunks
|
||||
|
|
|
@ -2802,10 +2802,10 @@ func getUserByID(db DBAdapter, dialect Dialect, result *user, id uint) error {
|
|||
|
||||
modifier, _ := modifiers.LoadGlobalModifier("json")
|
||||
|
||||
value := modifiers.AttrWrapper{
|
||||
Ctx: context.TODO(),
|
||||
Attr: &result.Address,
|
||||
Modifier: modifier,
|
||||
value := modifiers.AttrScanWrapper{
|
||||
Ctx: context.TODO(),
|
||||
AttrPtr: &result.Address,
|
||||
ScanFn: modifier.Scan,
|
||||
OpInfo: modifiers.OpInfo{
|
||||
DriverName: dialect.DriverName(),
|
||||
// We will not differentiate between Query, QueryOne and QueryChunks
|
||||
|
|
Loading…
Reference in New Issue