diff --git a/internal/modifiers/attr_wrapper.go b/internal/modifiers/attr_wrapper.go index b17f4b9..a6d277b 100644 --- a/internal/modifiers/attr_wrapper.go +++ b/internal/modifiers/attr_wrapper.go @@ -3,6 +3,8 @@ package modifiers import ( "context" "database/sql/driver" + + "github.com/vingarcia/ksql/kmodifiers" ) // AttrScanWrapper is the wrapper that allow us to intercept the Scan process @@ -15,8 +17,8 @@ type AttrScanWrapper struct { AttrPtr interface{} - ScanFn AttrScanner - OpInfo OpInfo + ScanFn kmodifiers.AttrScanner + OpInfo kmodifiers.OpInfo } // Scan implements the sql.Scanner interface @@ -34,8 +36,8 @@ type AttrValueWrapper struct { Attr interface{} - ValueFn AttrValuer - OpInfo OpInfo + ValueFn kmodifiers.AttrValuer + OpInfo kmodifiers.OpInfo } // Value implements the sql.Valuer interface diff --git a/internal/modifiers/attr_wrapper_test.go b/internal/modifiers/attr_wrapper_test.go index a10d695..d712288 100644 --- a/internal/modifiers/attr_wrapper_test.go +++ b/internal/modifiers/attr_wrapper_test.go @@ -6,6 +6,7 @@ import ( "testing" tt "github.com/vingarcia/ksql/internal/testtools" + "github.com/vingarcia/ksql/kmodifiers" ) func TestAttrScanWrapper(t *testing.T) { @@ -15,7 +16,7 @@ func TestAttrScanWrapper(t *testing.T) { wrapper := AttrScanWrapper{ Ctx: ctx, AttrPtr: "fakeAttrPtr", - ScanFn: func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error { + ScanFn: func(ctx context.Context, opInfo kmodifiers.OpInfo, attrPtr interface{}, dbValue interface{}) error { scanArgs = map[string]interface{}{ "opInfo": opInfo, "attrPtr": attrPtr, @@ -23,7 +24,7 @@ func TestAttrScanWrapper(t *testing.T) { } return errors.New("fakeScanErrMsg") }, - OpInfo: OpInfo{ + OpInfo: kmodifiers.OpInfo{ Method: "fakeMethod", DriverName: "fakeDriverName", }, @@ -32,7 +33,7 @@ func TestAttrScanWrapper(t *testing.T) { err := wrapper.Scan("fakeDbValue") tt.AssertErrContains(t, err, "fakeScanErrMsg") tt.AssertEqual(t, scanArgs, map[string]interface{}{ - "opInfo": OpInfo{ + "opInfo": kmodifiers.OpInfo{ Method: "fakeMethod", DriverName: "fakeDriverName", }, @@ -48,14 +49,14 @@ func TestAttrWrapper(t *testing.T) { wrapper := AttrValueWrapper{ Ctx: ctx, Attr: "fakeAttr", - ValueFn: func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) { + ValueFn: func(ctx context.Context, opInfo kmodifiers.OpInfo, inputValue interface{}) (outputValue interface{}, _ error) { valueArgs = map[string]interface{}{ "opInfo": opInfo, "inputValue": inputValue, } return "fakeOutputValue", errors.New("fakeValueErrMsg") }, - OpInfo: OpInfo{ + OpInfo: kmodifiers.OpInfo{ Method: "fakeMethod", DriverName: "fakeDriverName", }, @@ -64,7 +65,7 @@ func TestAttrWrapper(t *testing.T) { value, err := wrapper.Value() tt.AssertErrContains(t, err, "fakeValueErrMsg") tt.AssertEqual(t, valueArgs, map[string]interface{}{ - "opInfo": OpInfo{ + "opInfo": kmodifiers.OpInfo{ Method: "fakeMethod", DriverName: "fakeDriverName", }, diff --git a/internal/modifiers/global_modifiers.go b/internal/modifiers/global_modifiers.go index 414f506..b77d986 100644 --- a/internal/modifiers/global_modifiers.go +++ b/internal/modifiers/global_modifiers.go @@ -3,12 +3,18 @@ package modifiers import ( "fmt" "sync" + + "github.com/vingarcia/ksql/kmodifiers" ) // Here we keep all the registered modifiers var modifiers sync.Map func init() { + // Here we expose the registration function in a public package, + // so users can use it: + kmodifiers.RegisterAttrModifier = RegisterAttrModifier + // These are the builtin modifiers: // This one is useful for serializing/desserializing structs: @@ -27,7 +33,7 @@ func init() { // RegisterAttrModifier allow users to add custom modifiers on startup // it is recommended to do this inside an init() function. -func RegisterAttrModifier(key string, modifier AttrModifier) { +func RegisterAttrModifier(key string, modifier kmodifiers.AttrModifier) { _, found := modifiers.Load(key) if found { panic(fmt.Errorf("KSQL: cannot register modifier '%s' name is already in use", key)) @@ -38,11 +44,11 @@ func RegisterAttrModifier(key string, modifier AttrModifier) { // LoadGlobalModifier is used internally by KSQL to load // modifiers during runtime. -func LoadGlobalModifier(key string) (AttrModifier, error) { +func LoadGlobalModifier(key string) (kmodifiers.AttrModifier, error) { rawModifier, _ := modifiers.Load(key) - modifier, ok := rawModifier.(AttrModifier) + modifier, ok := rawModifier.(kmodifiers.AttrModifier) if !ok { - return AttrModifier{}, fmt.Errorf("no modifier found with name '%s'", key) + return kmodifiers.AttrModifier{}, fmt.Errorf("no modifier found with name '%s'", key) } return modifier, nil diff --git a/internal/modifiers/global_modifiers_test.go b/internal/modifiers/global_modifiers_test.go index b1d3b1d..9f68368 100644 --- a/internal/modifiers/global_modifiers_test.go +++ b/internal/modifiers/global_modifiers_test.go @@ -4,14 +4,15 @@ import ( "testing" tt "github.com/vingarcia/ksql/internal/testtools" + "github.com/vingarcia/ksql/kmodifiers" ) func TestRegisterAttrModifier(t *testing.T) { t.Run("should register new modifiers correctly", func(t *testing.T) { - modifier1 := AttrModifier{ + modifier1 := kmodifiers.AttrModifier{ SkipOnUpdate: true, } - modifier2 := AttrModifier{ + modifier2 := kmodifiers.AttrModifier{ SkipOnInsert: true, } @@ -28,10 +29,10 @@ func TestRegisterAttrModifier(t *testing.T) { }) t.Run("should panic registering a modifier and the name already exists", func(t *testing.T) { - modifier1 := AttrModifier{ + modifier1 := kmodifiers.AttrModifier{ SkipOnUpdate: true, } - modifier2 := AttrModifier{ + modifier2 := kmodifiers.AttrModifier{ SkipOnInsert: true, } @@ -48,6 +49,6 @@ func TestRegisterAttrModifier(t *testing.T) { t.Run("should return an error when loading an inexistent modifier", func(t *testing.T) { mod, err := LoadGlobalModifier("nonExistentModifier") tt.AssertErrContains(t, err, "nonExistentModifier") - tt.AssertEqual(t, mod, AttrModifier{}) + tt.AssertEqual(t, mod, kmodifiers.AttrModifier{}) }) } diff --git a/internal/modifiers/json_modifier.go b/internal/modifiers/json_modifier.go index 8cb6b8c..0574bb7 100644 --- a/internal/modifiers/json_modifier.go +++ b/internal/modifiers/json_modifier.go @@ -4,13 +4,15 @@ import ( "context" "encoding/json" "fmt" + + "github.com/vingarcia/ksql/kmodifiers" ) // This modifier serializes objects as JSON when // sending it to the database and decodes // them when receiving. -var jsonModifier = AttrModifier{ - Scan: func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error { +var jsonModifier = kmodifiers.AttrModifier{ + Scan: func(ctx context.Context, opInfo kmodifiers.OpInfo, attrPtr interface{}, dbValue interface{}) error { if dbValue == nil { return nil } @@ -27,7 +29,7 @@ var jsonModifier = AttrModifier{ return json.Unmarshal(rawJSON, attrPtr) }, - Value: func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) { + Value: func(ctx context.Context, opInfo kmodifiers.OpInfo, inputValue interface{}) (outputValue interface{}, _ error) { b, err := json.Marshal(inputValue) // SQL server uses the NVARCHAR type to store JSON and // it expects to receive strings not []byte, thus: diff --git a/internal/modifiers/json_modifier_test.go b/internal/modifiers/json_modifier_test.go index a53f56c..110f440 100644 --- a/internal/modifiers/json_modifier_test.go +++ b/internal/modifiers/json_modifier_test.go @@ -5,6 +5,7 @@ import ( "testing" tt "github.com/vingarcia/ksql/internal/testtools" + "github.com/vingarcia/ksql/kmodifiers" ) func TestAttrScan(t *testing.T) { @@ -53,7 +54,7 @@ func TestAttrScan(t *testing.T) { fakeAttr := FakeAttr{ Foo: "notZeroValue", } - err := jsonModifier.Scan(ctx, OpInfo{}, &fakeAttr, test.dbInput) + err := jsonModifier.Scan(ctx, kmodifiers.OpInfo{}, &fakeAttr, test.dbInput) if test.expectErrToContain != nil { tt.AssertErrContains(t, err, test.expectErrToContain...) t.Skip() @@ -75,7 +76,7 @@ func TestAttrValue(t *testing.T) { tests := []struct { desc string dbInput interface{} - opInfoInput OpInfo + opInfoInput kmodifiers.OpInfo attrValue interface{} expectedOutput interface{} @@ -84,7 +85,7 @@ func TestAttrValue(t *testing.T) { { desc: "should return a byte array when the driver is not sqlserver", dbInput: []byte(`{"foo":"bar"}`), - opInfoInput: OpInfo{ + opInfoInput: kmodifiers.OpInfo{ DriverName: "notSQLServer", }, attrValue: FakeAttr{ @@ -97,7 +98,7 @@ func TestAttrValue(t *testing.T) { { desc: "should return a string when the driver is sqlserver", dbInput: []byte(`{"foo":"bar"}`), - opInfoInput: OpInfo{ + opInfoInput: kmodifiers.OpInfo{ DriverName: "sqlserver", }, attrValue: FakeAttr{ diff --git a/internal/modifiers/skip_modifiers.go b/internal/modifiers/skip_modifiers.go index e7e1135..89d8ce7 100644 --- a/internal/modifiers/skip_modifiers.go +++ b/internal/modifiers/skip_modifiers.go @@ -1,9 +1,11 @@ package modifiers -var skipInsertsModifier = AttrModifier{ +import "github.com/vingarcia/ksql/kmodifiers" + +var skipInsertsModifier = kmodifiers.AttrModifier{ SkipOnInsert: true, } -var skipUpdatesModifier = AttrModifier{ +var skipUpdatesModifier = kmodifiers.AttrModifier{ SkipOnUpdate: true, } diff --git a/internal/modifiers/time_modifiers.go b/internal/modifiers/time_modifiers.go index 171060c..abbbc62 100644 --- a/internal/modifiers/time_modifiers.go +++ b/internal/modifiers/time_modifiers.go @@ -3,20 +3,22 @@ package modifiers import ( "context" "time" + + "github.com/vingarcia/ksql/kmodifiers" ) // This one is useful for updatedAt timestamps -var timeNowUTCModifier = AttrModifier{ - Value: func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) { +var timeNowUTCModifier = kmodifiers.AttrModifier{ + Value: func(ctx context.Context, opInfo kmodifiers.OpInfo, inputValue interface{}) (outputValue interface{}, _ error) { return time.Now().UTC(), nil }, } // This one is useful for createdAt timestamps -var timeNowUTCSkipUpdatesModifier = AttrModifier{ +var timeNowUTCSkipUpdatesModifier = kmodifiers.AttrModifier{ SkipOnUpdate: true, - Value: func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) { + Value: func(ctx context.Context, opInfo kmodifiers.OpInfo, inputValue interface{}) (outputValue interface{}, _ error) { return time.Now().UTC(), nil }, } diff --git a/internal/structs/structs.go b/internal/structs/structs.go index 356c528..b7ecdc0 100644 --- a/internal/structs/structs.go +++ b/internal/structs/structs.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/vingarcia/ksql/internal/modifiers" + "github.com/vingarcia/ksql/kmodifiers" ) // StructInfo stores metainformation of the struct @@ -25,7 +26,7 @@ type FieldInfo struct { Name string Index int Valid bool - Modifier modifiers.AttrModifier + Modifier kmodifiers.AttrModifier } // ByIndex returns either the *FieldInfo of a valid @@ -251,7 +252,7 @@ func getTagNames(t reflect.Type) (_ StructInfo, err error) { } tags := strings.Split(name, ",") - var modifier modifiers.AttrModifier + var modifier kmodifiers.AttrModifier if len(tags) > 1 { name = tags[0] modifier, err = modifiers.LoadGlobalModifier(tags[1]) diff --git a/internal/modifiers/contract.go b/kmodifiers/attr_modifier.go similarity index 96% rename from internal/modifiers/contract.go rename to kmodifiers/attr_modifier.go index dda135e..0f07838 100644 --- a/internal/modifiers/contract.go +++ b/kmodifiers/attr_modifier.go @@ -1,8 +1,6 @@ -package modifiers +package kmodifiers -import ( - "context" -) +import "context" // AttrModifier informs KSQL how to use this modifier type AttrModifier struct { diff --git a/kmodifiers/register_new_modifier.go b/kmodifiers/register_new_modifier.go new file mode 100644 index 0000000..b909903 --- /dev/null +++ b/kmodifiers/register_new_modifier.go @@ -0,0 +1,9 @@ +package kmodifiers + +// RegisterAttrModifier allow users to add custom modifiers on startup +// it is recommended to do this inside an init() function. +var RegisterAttrModifier func(key string, modifier AttrModifier) + +// This method is set at startup by the `internal/modifiers` package. +// It was done that way in order to keep most of the implementation private +// while also avoiding cyclic dependencies. diff --git a/ksql.go b/ksql.go index a346b25..e1f7555 100644 --- a/ksql.go +++ b/ksql.go @@ -12,6 +12,7 @@ import ( "github.com/vingarcia/ksql/internal/modifiers" "github.com/vingarcia/ksql/internal/structs" + "github.com/vingarcia/ksql/kmodifiers" "github.com/vingarcia/ksql/ksqltest" ) @@ -728,7 +729,7 @@ func buildInsertQuery( Ctx: ctx, Attr: recordValue, ValueFn: valueFn, - OpInfo: modifiers.OpInfo{ + OpInfo: kmodifiers.OpInfo{ DriverName: dialect.DriverName(), Method: "Insert", }, @@ -857,7 +858,7 @@ func buildUpdateQuery( Ctx: ctx, Attr: recordValue, ValueFn: valueFn, - OpInfo: modifiers.OpInfo{ + OpInfo: kmodifiers.OpInfo{ DriverName: dialect.DriverName(), Method: "Update", }, @@ -1063,7 +1064,7 @@ func getScanArgsForNestedStructs( Ctx: ctx, AttrPtr: valueScanner, ScanFn: fieldInfo.Modifier.Scan, - OpInfo: modifiers.OpInfo{ + OpInfo: kmodifiers.OpInfo{ DriverName: dialect.DriverName(), // We will not differentiate between Query, QueryOne and QueryChunks // if we did this could lead users to make very strange modifiers @@ -1093,7 +1094,7 @@ func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string, Ctx: ctx, AttrPtr: valueScanner, ScanFn: fieldInfo.Modifier.Scan, - OpInfo: modifiers.OpInfo{ + OpInfo: kmodifiers.OpInfo{ DriverName: dialect.DriverName(), // We will not differentiate between Query, QueryOne and QueryChunks // if we did this could lead users to make very strange modifiers diff --git a/test_adapters.go b/test_adapters.go index c0b9d11..7e485d3 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -11,6 +11,7 @@ import ( "github.com/vingarcia/ksql/internal/modifiers" tt "github.com/vingarcia/ksql/internal/testtools" + "github.com/vingarcia/ksql/kmodifiers" "github.com/vingarcia/ksql/nullable" ) @@ -3335,7 +3336,7 @@ func getUserByID(db DBAdapter, dialect Dialect, result *user, id uint) error { Ctx: context.TODO(), AttrPtr: &result.Address, ScanFn: modifier.Scan, - OpInfo: modifiers.OpInfo{ + OpInfo: kmodifiers.OpInfo{ DriverName: dialect.DriverName(), // We will not differentiate between Query, QueryOne and QueryChunks // if we did this could lead users to make very strange modifiers