mirror of https://github.com/VinGarcia/ksql.git
Merge pull request #29 from VinGarcia/add-attr-middlewares
Draft: Create a mechanism for users to add their own implementation of modifierspull/32/head
commit
b2e146d5e8
.github/workflows
adapters/kmysql
ksqlmodifiers
|
@ -1,6 +1,9 @@
|
|||
name: CI
|
||||
|
||||
on: [push, pull_request]
|
||||
on:
|
||||
push: {}
|
||||
pull_request:
|
||||
types: [opened, reopened]
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
|
@ -20,9 +23,11 @@ jobs:
|
|||
- name: Run Tests
|
||||
run: ./scripts/run-all-tests.sh
|
||||
- name: Run Coverage
|
||||
run: bash <(curl -s https://codecov.io/bash)
|
||||
run: |
|
||||
curl -Os https://uploader.codecov.io/latest/linux/codecov
|
||||
chmod +x codecov
|
||||
./codecov -t $CODECOV_TOKEN
|
||||
env:
|
||||
CODECOV_TOKEN: 36be8ba6-7ef1-4ec2-b607-67c1055a62ad
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ func startMySQLDB(dbName string) (databaseURL string, closer func()) {
|
|||
}
|
||||
|
||||
hostAndPort := resource.GetHostPort("3306/tcp")
|
||||
databaseUrl := fmt.Sprintf("root:mysql@(%s)/%s?timeout=30s", hostAndPort, dbName)
|
||||
databaseUrl := fmt.Sprintf("root:mysql@(%s)/%s?timeout=30s&parseTime=true", hostAndPort, dbName)
|
||||
|
||||
fmt.Println("Connecting to mariadb on url: ", databaseUrl)
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
// ErrRecordNotFound ...
|
||||
var ErrRecordNotFound error = fmt.Errorf("ksql: the query returned no results: %w", sql.ErrNoRows)
|
||||
var ErrNoValuesToUpdate error = fmt.Errorf("ksql: the input struct contains no values to update")
|
||||
|
||||
// ErrAbortIteration ...
|
||||
var ErrAbortIteration error = fmt.Errorf("ksql: abort iteration, should only be used inside QueryChunks function")
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
package modifiers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
|
||||
"github.com/vingarcia/ksql/ksqlmodifiers"
|
||||
)
|
||||
|
||||
// AttrScanWrapper is the wrapper that allow us to intercept the Scan process
|
||||
// so we can run the modifiers instead of allowing the database driver to use
|
||||
// its default behavior.
|
||||
//
|
||||
// For that this struct implements the `sql.Scanner` interface
|
||||
type AttrScanWrapper struct {
|
||||
Ctx context.Context
|
||||
|
||||
AttrPtr interface{}
|
||||
|
||||
ScanFn ksqlmodifiers.AttrScanner
|
||||
OpInfo ksqlmodifiers.OpInfo
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface
|
||||
func (a AttrScanWrapper) Scan(dbValue interface{}) error {
|
||||
return a.ScanFn(a.Ctx, a.OpInfo, a.AttrPtr, dbValue)
|
||||
}
|
||||
|
||||
// AttrValueWrapper is the wrapper that allow us to intercept the "Valuing" process
|
||||
// so we can run the modifiers instead of allowing the database driver to use
|
||||
// its default behavior.
|
||||
//
|
||||
// For that this struct implements the `sql.Valuer` interface
|
||||
type AttrValueWrapper struct {
|
||||
Ctx context.Context
|
||||
|
||||
Attr interface{}
|
||||
|
||||
ValueFn ksqlmodifiers.AttrValuer
|
||||
OpInfo ksqlmodifiers.OpInfo
|
||||
}
|
||||
|
||||
// Value implements the sql.Valuer interface
|
||||
func (a AttrValueWrapper) Value() (driver.Value, error) {
|
||||
return a.ValueFn(a.Ctx, a.OpInfo, a.Attr)
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
package modifiers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
tt "github.com/vingarcia/ksql/internal/testtools"
|
||||
"github.com/vingarcia/ksql/ksqlmodifiers"
|
||||
)
|
||||
|
||||
func TestAttrScanWrapper(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
var scanArgs map[string]interface{}
|
||||
wrapper := AttrScanWrapper{
|
||||
Ctx: ctx,
|
||||
AttrPtr: "fakeAttrPtr",
|
||||
ScanFn: func(ctx context.Context, opInfo ksqlmodifiers.OpInfo, attrPtr interface{}, dbValue interface{}) error {
|
||||
scanArgs = map[string]interface{}{
|
||||
"opInfo": opInfo,
|
||||
"attrPtr": attrPtr,
|
||||
"dbValue": dbValue,
|
||||
}
|
||||
return errors.New("fakeScanErrMsg")
|
||||
},
|
||||
OpInfo: ksqlmodifiers.OpInfo{
|
||||
Method: "fakeMethod",
|
||||
DriverName: "fakeDriverName",
|
||||
},
|
||||
}
|
||||
|
||||
err := wrapper.Scan("fakeDbValue")
|
||||
tt.AssertErrContains(t, err, "fakeScanErrMsg")
|
||||
tt.AssertEqual(t, scanArgs, map[string]interface{}{
|
||||
"opInfo": ksqlmodifiers.OpInfo{
|
||||
Method: "fakeMethod",
|
||||
DriverName: "fakeDriverName",
|
||||
},
|
||||
"attrPtr": "fakeAttrPtr",
|
||||
"dbValue": "fakeDbValue",
|
||||
})
|
||||
}
|
||||
|
||||
func TestAttrWrapper(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
var valueArgs map[string]interface{}
|
||||
wrapper := AttrValueWrapper{
|
||||
Ctx: ctx,
|
||||
Attr: "fakeAttr",
|
||||
ValueFn: func(ctx context.Context, opInfo ksqlmodifiers.OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
|
||||
valueArgs = map[string]interface{}{
|
||||
"opInfo": opInfo,
|
||||
"inputValue": inputValue,
|
||||
}
|
||||
return "fakeOutputValue", errors.New("fakeValueErrMsg")
|
||||
},
|
||||
OpInfo: ksqlmodifiers.OpInfo{
|
||||
Method: "fakeMethod",
|
||||
DriverName: "fakeDriverName",
|
||||
},
|
||||
}
|
||||
|
||||
value, err := wrapper.Value()
|
||||
tt.AssertErrContains(t, err, "fakeValueErrMsg")
|
||||
tt.AssertEqual(t, valueArgs, map[string]interface{}{
|
||||
"opInfo": ksqlmodifiers.OpInfo{
|
||||
Method: "fakeMethod",
|
||||
DriverName: "fakeDriverName",
|
||||
},
|
||||
"inputValue": "fakeAttr",
|
||||
})
|
||||
tt.AssertEqual(t, value, "fakeOutputValue")
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
package modifiers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/vingarcia/ksql/ksqlmodifiers"
|
||||
)
|
||||
|
||||
// Here we keep all the registered modifiers
|
||||
var modifiers sync.Map
|
||||
|
||||
func init() {
|
||||
// Here we expose the registration function in a public package,
|
||||
// so users can use it:
|
||||
ksqlmodifiers.RegisterAttrModifier = RegisterAttrModifier
|
||||
|
||||
// These are the builtin modifiers:
|
||||
|
||||
// This one is useful for serializing/desserializing structs:
|
||||
modifiers.Store("json", jsonModifier)
|
||||
|
||||
// This next two are useful for the UpdatedAt and Created fields respectively:
|
||||
// They only work on time.Time attributes and will set the attribute to time.Now().
|
||||
modifiers.Store("timeNowUTC", timeNowUTCModifier)
|
||||
modifiers.Store("timeNowUTC/skipUpdates", timeNowUTCSkipUpdatesModifier)
|
||||
|
||||
// These are mostly example modifiers and they are also used
|
||||
// to test the feature of skipping updates, inserts and queries.
|
||||
modifiers.Store("skipUpdates", skipUpdatesModifier)
|
||||
modifiers.Store("skipInserts", skipInsertsModifier)
|
||||
}
|
||||
|
||||
// RegisterAttrModifier allow users to add custom modifiers on startup
|
||||
// it is recommended to do this inside an init() function.
|
||||
func RegisterAttrModifier(key string, modifier ksqlmodifiers.AttrModifier) {
|
||||
_, found := modifiers.Load(key)
|
||||
if found {
|
||||
panic(fmt.Errorf("KSQL: cannot register modifier '%s' name is already in use", key))
|
||||
}
|
||||
|
||||
modifiers.Store(key, modifier)
|
||||
}
|
||||
|
||||
// LoadGlobalModifier is used internally by KSQL to load
|
||||
// modifiers during runtime.
|
||||
func LoadGlobalModifier(key string) (ksqlmodifiers.AttrModifier, error) {
|
||||
rawModifier, _ := modifiers.Load(key)
|
||||
modifier, ok := rawModifier.(ksqlmodifiers.AttrModifier)
|
||||
if !ok {
|
||||
return ksqlmodifiers.AttrModifier{}, fmt.Errorf("no modifier found with name '%s'", key)
|
||||
}
|
||||
|
||||
return modifier, nil
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package modifiers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
tt "github.com/vingarcia/ksql/internal/testtools"
|
||||
"github.com/vingarcia/ksql/ksqlmodifiers"
|
||||
)
|
||||
|
||||
func TestRegisterAttrModifier(t *testing.T) {
|
||||
t.Run("should register new modifiers correctly", func(t *testing.T) {
|
||||
modifier1 := ksqlmodifiers.AttrModifier{
|
||||
SkipOnUpdate: true,
|
||||
}
|
||||
modifier2 := ksqlmodifiers.AttrModifier{
|
||||
SkipOnInsert: true,
|
||||
}
|
||||
|
||||
RegisterAttrModifier("fakeModifierName1", modifier1)
|
||||
RegisterAttrModifier("fakeModifierName2", modifier2)
|
||||
|
||||
mod, err := LoadGlobalModifier("fakeModifierName1")
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, mod, modifier1)
|
||||
|
||||
mod, err = LoadGlobalModifier("fakeModifierName2")
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, mod, modifier2)
|
||||
})
|
||||
|
||||
t.Run("should panic registering a modifier and the name already exists", func(t *testing.T) {
|
||||
modifier1 := ksqlmodifiers.AttrModifier{
|
||||
SkipOnUpdate: true,
|
||||
}
|
||||
modifier2 := ksqlmodifiers.AttrModifier{
|
||||
SkipOnInsert: true,
|
||||
}
|
||||
|
||||
RegisterAttrModifier("fakeModifierName", modifier1)
|
||||
panicPayload := tt.PanicHandler(func() {
|
||||
RegisterAttrModifier("fakeModifierName", modifier2)
|
||||
})
|
||||
|
||||
err, ok := panicPayload.(error)
|
||||
tt.AssertEqual(t, ok, true)
|
||||
tt.AssertErrContains(t, err, "KSQL", "fakeModifierName", "name is already in use")
|
||||
})
|
||||
|
||||
t.Run("should return an error when loading an inexistent modifier", func(t *testing.T) {
|
||||
mod, err := LoadGlobalModifier("nonExistentModifier")
|
||||
tt.AssertErrContains(t, err, "nonExistentModifier")
|
||||
tt.AssertEqual(t, mod, ksqlmodifiers.AttrModifier{})
|
||||
})
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package modifiers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/vingarcia/ksql/ksqlmodifiers"
|
||||
)
|
||||
|
||||
// This modifier serializes objects as JSON when
|
||||
// sending it to the database and decodes
|
||||
// them when receiving.
|
||||
var jsonModifier = ksqlmodifiers.AttrModifier{
|
||||
Scan: func(ctx context.Context, opInfo ksqlmodifiers.OpInfo, attrPtr interface{}, dbValue interface{}) error {
|
||||
if dbValue == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Required since sqlite3 returns strings not bytes
|
||||
if v, ok := dbValue.(string); ok {
|
||||
dbValue = []byte(v)
|
||||
}
|
||||
|
||||
rawJSON, ok := dbValue.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type received to Scan: %T", dbValue)
|
||||
}
|
||||
return json.Unmarshal(rawJSON, attrPtr)
|
||||
},
|
||||
|
||||
Value: func(ctx context.Context, opInfo ksqlmodifiers.OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
|
||||
b, err := json.Marshal(inputValue)
|
||||
// SQL server uses the NVARCHAR type to store JSON and
|
||||
// it expects to receive strings not []byte, thus:
|
||||
if opInfo.DriverName == "sqlserver" {
|
||||
return string(b), err
|
||||
}
|
||||
return b, err
|
||||
},
|
||||
}
|
|
@ -0,0 +1,125 @@
|
|||
package modifiers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
tt "github.com/vingarcia/ksql/internal/testtools"
|
||||
"github.com/vingarcia/ksql/ksqlmodifiers"
|
||||
)
|
||||
|
||||
func TestAttrScan(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
type FakeAttr struct {
|
||||
Foo string `json:"foo"`
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
dbInput interface{}
|
||||
expectedValue interface{}
|
||||
expectErrToContain []string
|
||||
}{
|
||||
{
|
||||
desc: "should not set struct to zero value if input is nil",
|
||||
dbInput: nil,
|
||||
expectedValue: FakeAttr{
|
||||
Foo: "notZeroValue",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should work when input is a byte slice",
|
||||
dbInput: []byte(`{"foo":"bar"}`),
|
||||
expectedValue: FakeAttr{
|
||||
Foo: "bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should work when input is a string",
|
||||
dbInput: `{"foo":"bar"}`,
|
||||
expectedValue: FakeAttr{
|
||||
Foo: "bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should report error if input type is unsupported",
|
||||
dbInput: 10,
|
||||
expectErrToContain: []string{"unexpected type", "int"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
fakeAttr := FakeAttr{
|
||||
Foo: "notZeroValue",
|
||||
}
|
||||
err := jsonModifier.Scan(ctx, ksqlmodifiers.OpInfo{}, &fakeAttr, test.dbInput)
|
||||
if test.expectErrToContain != nil {
|
||||
tt.AssertErrContains(t, err, test.expectErrToContain...)
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, fakeAttr, test.expectedValue)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttrValue(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
type FakeAttr struct {
|
||||
Foo string `json:"foo"`
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
dbInput interface{}
|
||||
opInfoInput ksqlmodifiers.OpInfo
|
||||
attrValue interface{}
|
||||
|
||||
expectedOutput interface{}
|
||||
expectErrToContain []string
|
||||
}{
|
||||
{
|
||||
desc: "should return a byte array when the driver is not sqlserver",
|
||||
dbInput: []byte(`{"foo":"bar"}`),
|
||||
opInfoInput: ksqlmodifiers.OpInfo{
|
||||
DriverName: "notSQLServer",
|
||||
},
|
||||
attrValue: FakeAttr{
|
||||
Foo: "bar",
|
||||
},
|
||||
expectedOutput: tt.ToJSON(t, map[string]interface{}{
|
||||
"foo": "bar",
|
||||
}),
|
||||
},
|
||||
{
|
||||
desc: "should return a string when the driver is sqlserver",
|
||||
dbInput: []byte(`{"foo":"bar"}`),
|
||||
opInfoInput: ksqlmodifiers.OpInfo{
|
||||
DriverName: "sqlserver",
|
||||
},
|
||||
attrValue: FakeAttr{
|
||||
Foo: "bar",
|
||||
},
|
||||
expectedOutput: string(tt.ToJSON(t, map[string]interface{}{
|
||||
"foo": "bar",
|
||||
})),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
output, err := jsonModifier.Value(ctx, test.opInfoInput, test.attrValue)
|
||||
if test.expectErrToContain != nil {
|
||||
tt.AssertErrContains(t, err, test.expectErrToContain...)
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, output, test.expectedOutput)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
package modifiers
|
||||
|
||||
import "github.com/vingarcia/ksql/ksqlmodifiers"
|
||||
|
||||
var skipInsertsModifier = ksqlmodifiers.AttrModifier{
|
||||
SkipOnInsert: true,
|
||||
}
|
||||
|
||||
var skipUpdatesModifier = ksqlmodifiers.AttrModifier{
|
||||
SkipOnUpdate: true,
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package modifiers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/vingarcia/ksql/ksqlmodifiers"
|
||||
)
|
||||
|
||||
// This one is useful for updatedAt timestamps
|
||||
var timeNowUTCModifier = ksqlmodifiers.AttrModifier{
|
||||
Value: func(ctx context.Context, opInfo ksqlmodifiers.OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
|
||||
return time.Now().UTC(), nil
|
||||
},
|
||||
}
|
||||
|
||||
// This one is useful for createdAt timestamps
|
||||
var timeNowUTCSkipUpdatesModifier = ksqlmodifiers.AttrModifier{
|
||||
SkipOnUpdate: true,
|
||||
|
||||
Value: func(ctx context.Context, opInfo ksqlmodifiers.OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
|
||||
return time.Now().UTC(), nil
|
||||
},
|
||||
}
|
|
@ -5,6 +5,9 @@ import (
|
|||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/vingarcia/ksql/internal/modifiers"
|
||||
"github.com/vingarcia/ksql/ksqlmodifiers"
|
||||
)
|
||||
|
||||
// StructInfo stores metainformation of the struct
|
||||
|
@ -20,10 +23,10 @@ type StructInfo struct {
|
|||
// information regarding a specific field
|
||||
// of a struct.
|
||||
type FieldInfo struct {
|
||||
Name string
|
||||
Index int
|
||||
Valid bool
|
||||
SerializeAsJSON bool
|
||||
Name string
|
||||
Index int
|
||||
Valid bool
|
||||
Modifier ksqlmodifiers.AttrModifier
|
||||
}
|
||||
|
||||
// ByIndex returns either the *FieldInfo of a valid
|
||||
|
@ -232,7 +235,7 @@ func (p PtrConverter) Convert(destType reflect.Type) (reflect.Value, error) {
|
|||
//
|
||||
// This should save several calls to `Field(i).Tag.Get("foo")`
|
||||
// which improves performance by a lot.
|
||||
func getTagNames(t reflect.Type) (StructInfo, error) {
|
||||
func getTagNames(t reflect.Type) (_ StructInfo, err error) {
|
||||
info := StructInfo{
|
||||
byIndex: map[int]*FieldInfo{},
|
||||
byName: map[string]*FieldInfo{},
|
||||
|
@ -249,10 +252,13 @@ func getTagNames(t reflect.Type) (StructInfo, error) {
|
|||
}
|
||||
|
||||
tags := strings.Split(name, ",")
|
||||
serializeAsJSON := false
|
||||
var modifier ksqlmodifiers.AttrModifier
|
||||
if len(tags) > 1 {
|
||||
name = tags[0]
|
||||
serializeAsJSON = tags[1] == "json"
|
||||
modifier, err = modifiers.LoadGlobalModifier(tags[1])
|
||||
if err != nil {
|
||||
return StructInfo{}, fmt.Errorf("attribute contains invalid modifier name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, found := info.byName[name]; found {
|
||||
|
@ -263,9 +269,9 @@ func getTagNames(t reflect.Type) (StructInfo, error) {
|
|||
}
|
||||
|
||||
info.add(FieldInfo{
|
||||
Name: name,
|
||||
Index: i,
|
||||
SerializeAsJSON: serializeAsJSON,
|
||||
Name: name,
|
||||
Index: i,
|
||||
Modifier: modifier,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
package tt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func ToJSON(t *testing.T, obj interface{}) []byte {
|
||||
rawJSON, err := json.Marshal(obj)
|
||||
AssertNoErr(t, err)
|
||||
|
||||
return rawJSON
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
package tt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func ParseTime(t *testing.T, timestr string) time.Time {
|
||||
parsedTime, err := time.Parse(time.RFC3339, timestr)
|
||||
AssertNoErr(t, err)
|
||||
return parsedTime
|
||||
}
|
49
json.go
49
json.go
|
@ -1,49 +0,0 @@
|
|||
package ksql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// This type was created to make it easier to adapt
|
||||
// input attributes to be convertible to and from JSON
|
||||
// before sending or receiving it from the database.
|
||||
type jsonSerializable struct {
|
||||
DriverName string
|
||||
Attr interface{}
|
||||
}
|
||||
|
||||
// Scan Implements the Scanner interface in order to load
|
||||
// this field from the JSON stored in the database
|
||||
func (j *jsonSerializable) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
v := reflect.ValueOf(j.Attr)
|
||||
// Set the struct to its 0 value just like json.Unmarshal
|
||||
// does for nil attributes:
|
||||
v.Elem().Set(reflect.Zero(reflect.TypeOf(j.Attr).Elem()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Required since sqlite3 returns strings not bytes
|
||||
if v, ok := value.(string); ok {
|
||||
value = []byte(v)
|
||||
}
|
||||
|
||||
rawJSON, ok := value.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type received to Scan: %T", value)
|
||||
}
|
||||
return json.Unmarshal(rawJSON, j.Attr)
|
||||
}
|
||||
|
||||
// Value Implements the Valuer interface in order to save
|
||||
// this field as JSON on the database.
|
||||
func (j jsonSerializable) Value() (driver.Value, error) {
|
||||
b, err := json.Marshal(j.Attr)
|
||||
if j.DriverName == "sqlserver" {
|
||||
return string(b), err
|
||||
}
|
||||
return b, err
|
||||
}
|
129
ksql.go
129
ksql.go
|
@ -10,7 +10,9 @@ import (
|
|||
"sync"
|
||||
"unicode"
|
||||
|
||||
"github.com/vingarcia/ksql/internal/modifiers"
|
||||
"github.com/vingarcia/ksql/internal/structs"
|
||||
"github.com/vingarcia/ksql/ksqlmodifiers"
|
||||
"github.com/vingarcia/ksql/ksqltest"
|
||||
)
|
||||
|
||||
|
@ -184,7 +186,7 @@ func (c DB) Query(
|
|||
elemPtr = elemPtr.Elem()
|
||||
}
|
||||
|
||||
err = scanRows(c.dialect, rows, elemPtr.Interface())
|
||||
err = scanRows(ctx, c.dialect, rows, elemPtr.Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -257,13 +259,13 @@ func (c DB) QueryOne(
|
|||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
if rows.Err() != nil {
|
||||
return rows.Err()
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return ErrRecordNotFound
|
||||
}
|
||||
|
||||
err = scanRowsFromType(c.dialect, rows, record, t, v)
|
||||
err = scanRowsFromType(ctx, c.dialect, rows, record, t, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -342,7 +344,7 @@ func (c DB) QueryChunks(
|
|||
chunk = reflect.Append(chunk, elemValue)
|
||||
}
|
||||
|
||||
err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface())
|
||||
err = scanRows(ctx, c.dialect, rows, chunk.Index(idx).Addr().Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -419,7 +421,7 @@ func (c DB) Insert(
|
|||
return err
|
||||
}
|
||||
|
||||
query, params, scanValues, err := buildInsertQuery(c.dialect, table, t, v, info, record)
|
||||
query, params, scanValues, err := buildInsertQuery(ctx, c.dialect, table, t, v, info, record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -656,7 +658,7 @@ func (c DB) Patch(
|
|||
return err
|
||||
}
|
||||
|
||||
query, params, err := buildUpdateQuery(c.dialect, table.name, info, record, table.idColumns...)
|
||||
query, params, err := buildUpdateQuery(ctx, c.dialect, table.name, info, record, table.idColumns...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -681,6 +683,7 @@ func (c DB) Patch(
|
|||
}
|
||||
|
||||
func buildInsertQuery(
|
||||
ctx context.Context,
|
||||
dialect Dialect,
|
||||
table Table,
|
||||
t reflect.Type,
|
||||
|
@ -707,18 +710,29 @@ func buildInsertQuery(
|
|||
|
||||
columnNames := []string{}
|
||||
for col := range recordMap {
|
||||
if info.ByName(col).Modifier.SkipOnInsert {
|
||||
continue
|
||||
}
|
||||
|
||||
columnNames = append(columnNames, col)
|
||||
}
|
||||
|
||||
params = make([]interface{}, len(recordMap))
|
||||
valuesQuery := make([]string, len(recordMap))
|
||||
params = make([]interface{}, len(columnNames))
|
||||
valuesQuery := make([]string, len(columnNames))
|
||||
for i, col := range columnNames {
|
||||
recordValue := recordMap[col]
|
||||
params[i] = recordValue
|
||||
if info.ByName(col).SerializeAsJSON {
|
||||
params[i] = jsonSerializable{
|
||||
DriverName: dialect.DriverName(),
|
||||
Attr: recordValue,
|
||||
|
||||
valueFn := info.ByName(col).Modifier.Value
|
||||
if valueFn != nil {
|
||||
params[i] = modifiers.AttrValueWrapper{
|
||||
Ctx: ctx,
|
||||
Attr: recordValue,
|
||||
ValueFn: valueFn,
|
||||
OpInfo: ksqlmodifiers.OpInfo{
|
||||
DriverName: dialect.DriverName(),
|
||||
Method: "Insert",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -761,6 +775,16 @@ func buildInsertQuery(
|
|||
}
|
||||
}
|
||||
|
||||
if len(columnNames) == 0 && dialect.DriverName() != "mysql" {
|
||||
query = fmt.Sprintf(
|
||||
"INSERT INTO %s%s DEFAULT VALUES%s",
|
||||
dialect.Escape(table.name),
|
||||
outputQuery,
|
||||
returningQuery,
|
||||
)
|
||||
return query, params, scanValues, nil
|
||||
}
|
||||
|
||||
// Note that the outputQuery and the returningQuery depend
|
||||
// on the selected driver, thus, they might be empty strings.
|
||||
query = fmt.Sprintf(
|
||||
|
@ -776,21 +800,32 @@ func buildInsertQuery(
|
|||
}
|
||||
|
||||
func buildUpdateQuery(
|
||||
ctx context.Context,
|
||||
dialect Dialect,
|
||||
tableName string,
|
||||
info structs.StructInfo,
|
||||
record interface{},
|
||||
idFieldNames ...string,
|
||||
) (query string, args []interface{}, err error) {
|
||||
recordMap, err := ksqltest.StructToMap(record)
|
||||
recordMap, err := structs.StructToMap(record)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
for key := range recordMap {
|
||||
if info.ByName(key).Modifier.SkipOnUpdate {
|
||||
delete(recordMap, key)
|
||||
}
|
||||
}
|
||||
|
||||
numAttrs := len(recordMap)
|
||||
args = make([]interface{}, numAttrs)
|
||||
numNonIDArgs := numAttrs - len(idFieldNames)
|
||||
whereArgs := args[numNonIDArgs:]
|
||||
|
||||
if numNonIDArgs == 0 {
|
||||
return "", nil, ErrNoValuesToUpdate
|
||||
}
|
||||
|
||||
err = validateIfAllIdsArePresent(idFieldNames, recordMap)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
|
@ -816,10 +851,17 @@ func buildUpdateQuery(
|
|||
var setQuery []string
|
||||
for i, k := range keys {
|
||||
recordValue := recordMap[k]
|
||||
if info.ByName(k).SerializeAsJSON {
|
||||
recordValue = jsonSerializable{
|
||||
DriverName: dialect.DriverName(),
|
||||
Attr: recordValue,
|
||||
|
||||
valueFn := info.ByName(k).Modifier.Value
|
||||
if valueFn != nil {
|
||||
recordValue = modifiers.AttrValueWrapper{
|
||||
Ctx: ctx,
|
||||
Attr: recordValue,
|
||||
ValueFn: valueFn,
|
||||
OpInfo: ksqlmodifiers.OpInfo{
|
||||
DriverName: dialect.DriverName(),
|
||||
Method: "Update",
|
||||
},
|
||||
}
|
||||
}
|
||||
args[i] = recordValue
|
||||
|
@ -930,13 +972,14 @@ func (nopScanner) Scan(value interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func scanRows(dialect Dialect, rows Rows, record interface{}) error {
|
||||
func scanRows(ctx context.Context, dialect Dialect, rows Rows, record interface{}) error {
|
||||
v := reflect.ValueOf(record)
|
||||
t := v.Type()
|
||||
return scanRowsFromType(dialect, rows, record, t, v)
|
||||
return scanRowsFromType(ctx, dialect, rows, record, t, v)
|
||||
}
|
||||
|
||||
func scanRowsFromType(
|
||||
ctx context.Context,
|
||||
dialect Dialect,
|
||||
rows Rows,
|
||||
record interface{},
|
||||
|
@ -964,7 +1007,7 @@ func scanRowsFromType(
|
|||
// This version is positional meaning that it expect the arguments
|
||||
// to follow an specific order. It's ok because we don't allow the
|
||||
// user to type the "SELECT" part of the query for nested structs.
|
||||
scanArgs, err = getScanArgsForNestedStructs(dialect, rows, t, v, info)
|
||||
scanArgs, err = getScanArgsForNestedStructs(ctx, dialect, rows, t, v, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -975,7 +1018,7 @@ func scanRowsFromType(
|
|||
}
|
||||
// Since this version uses the names of the columns it works
|
||||
// with any order of attributes/columns.
|
||||
scanArgs = getScanArgsFromNames(dialect, names, v, info)
|
||||
scanArgs = getScanArgsFromNames(ctx, dialect, names, v, info)
|
||||
}
|
||||
|
||||
err = rows.Scan(scanArgs...)
|
||||
|
@ -985,7 +1028,14 @@ func scanRowsFromType(
|
|||
return nil
|
||||
}
|
||||
|
||||
func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) {
|
||||
func getScanArgsForNestedStructs(
|
||||
ctx context.Context,
|
||||
dialect Dialect,
|
||||
rows Rows,
|
||||
t reflect.Type,
|
||||
v reflect.Value,
|
||||
info structs.StructInfo,
|
||||
) ([]interface{}, error) {
|
||||
scanArgs := []interface{}{}
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if !info.ByIndex(i).Valid {
|
||||
|
@ -1008,10 +1058,18 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r
|
|||
valueScanner := nopScannerValue
|
||||
if fieldInfo.Valid {
|
||||
valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface()
|
||||
if fieldInfo.SerializeAsJSON {
|
||||
valueScanner = &jsonSerializable{
|
||||
DriverName: dialect.DriverName(),
|
||||
Attr: valueScanner,
|
||||
|
||||
if fieldInfo.Modifier.Scan != nil {
|
||||
valueScanner = &modifiers.AttrScanWrapper{
|
||||
Ctx: ctx,
|
||||
AttrPtr: valueScanner,
|
||||
ScanFn: fieldInfo.Modifier.Scan,
|
||||
OpInfo: ksqlmodifiers.OpInfo{
|
||||
DriverName: dialect.DriverName(),
|
||||
// We will not differentiate between Query, QueryOne and QueryChunks
|
||||
// if we did this could lead users to make very strange modifiers
|
||||
Method: "Query",
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1023,7 +1081,7 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r
|
|||
return scanArgs, nil
|
||||
}
|
||||
|
||||
func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} {
|
||||
func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} {
|
||||
scanArgs := []interface{}{}
|
||||
for _, name := range names {
|
||||
fieldInfo := info.ByName(name)
|
||||
|
@ -1031,10 +1089,17 @@ func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info
|
|||
valueScanner := nopScannerValue
|
||||
if fieldInfo.Valid {
|
||||
valueScanner = v.Field(fieldInfo.Index).Addr().Interface()
|
||||
if fieldInfo.SerializeAsJSON {
|
||||
valueScanner = &jsonSerializable{
|
||||
DriverName: dialect.DriverName(),
|
||||
Attr: valueScanner,
|
||||
if fieldInfo.Modifier.Scan != nil {
|
||||
valueScanner = &modifiers.AttrScanWrapper{
|
||||
Ctx: ctx,
|
||||
AttrPtr: valueScanner,
|
||||
ScanFn: fieldInfo.Modifier.Scan,
|
||||
OpInfo: ksqlmodifiers.OpInfo{
|
||||
DriverName: dialect.DriverName(),
|
||||
// We will not differentiate between Query, QueryOne and QueryChunks
|
||||
// if we did this could lead users to make very strange modifiers
|
||||
Method: "Query",
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
package ksqlmodifiers
|
||||
|
||||
import "context"
|
||||
|
||||
// AttrModifier informs KSQL how to use this modifier
|
||||
type AttrModifier struct {
|
||||
// The following attributes will tell KSQL to
|
||||
// leave this attribute out of insertions, updates,
|
||||
// and queries respectively.
|
||||
SkipOnInsert bool
|
||||
SkipOnUpdate bool
|
||||
|
||||
// Implement these functions if you want to override the default Scan/Value behavior
|
||||
// for the target attribute.
|
||||
Scan AttrScanner
|
||||
Value AttrValuer
|
||||
}
|
||||
|
||||
// AttrScanner describes the operation of deserializing an object received from the database.
|
||||
type AttrScanner func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error
|
||||
|
||||
// AttrValuer describes the operation of serializing an object when saving it to the database.
|
||||
type AttrValuer func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error)
|
||||
|
||||
// OpInfo contains information that might be used by a modifier to determine how it should behave.
|
||||
type OpInfo struct {
|
||||
// A string version of the name of one of
|
||||
// the methods of the `ksql.Provider` interface, e.g. `Insert` or `Query`
|
||||
Method string
|
||||
|
||||
// The string representing the current underlying database, e.g.:
|
||||
// "postgres", "sqlite3", "mysql" or "sqlserver".
|
||||
DriverName string
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
// This package exposes only the public types and functions for creating new
|
||||
// modifiers for KSQL.
|
||||
//
|
||||
// For understanding internal details of the code
|
||||
// please read the `internal/modifiers` package.
|
||||
package ksqlmodifiers
|
|
@ -0,0 +1,9 @@
|
|||
package ksqlmodifiers
|
||||
|
||||
// RegisterAttrModifier allow users to add custom modifiers on startup
|
||||
// it is recommended to do this inside an init() function.
|
||||
var RegisterAttrModifier func(key string, modifier AttrModifier)
|
||||
|
||||
// This method is set at startup by the `internal/modifiers` package.
|
||||
// It was done that way in order to keep most of the implementation private
|
||||
// while also avoiding cyclic dependencies.
|
565
test_adapters.go
565
test_adapters.go
|
@ -7,8 +7,11 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/vingarcia/ksql/internal/modifiers"
|
||||
tt "github.com/vingarcia/ksql/internal/testtools"
|
||||
"github.com/vingarcia/ksql/ksqlmodifiers"
|
||||
"github.com/vingarcia/ksql/nullable"
|
||||
)
|
||||
|
||||
|
@ -70,6 +73,7 @@ func RunTestsForAdapter(
|
|||
PatchTest(t, driver, connStr, newDBAdapter)
|
||||
QueryChunksTest(t, driver, connStr, newDBAdapter)
|
||||
TransactionTest(t, driver, connStr, newDBAdapter)
|
||||
ModifiersTest(t, driver, connStr, newDBAdapter)
|
||||
ScanRowsTest(t, driver, connStr, newDBAdapter)
|
||||
})
|
||||
}
|
||||
|
@ -893,6 +897,33 @@ func InsertTest(
|
|||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, inserted.Age, 5455)
|
||||
})
|
||||
|
||||
t.Run("should work and retrieve the ID for structs with no attributes", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
type tsUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name,skipInserts"`
|
||||
}
|
||||
u := tsUser{
|
||||
Name: "Letícia",
|
||||
}
|
||||
err := c.Insert(ctx, usersTable, &u)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, u.ID, 0)
|
||||
|
||||
var untaggedUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name *string `ksql:"name"`
|
||||
}
|
||||
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, untaggedUser.Name, (*string)(nil))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("composite key tables", func(t *testing.T) {
|
||||
|
@ -958,6 +989,55 @@ func InsertTest(
|
|||
tt.AssertEqual(t, userPerms[0].PermID, 42)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("when inserting a struct with no values but composite keys should still retrieve the IDs", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
// Table defined with 3 values, but we'll provide only 2,
|
||||
// the third will be generated for the purposes of this test:
|
||||
table := NewTable("user_permissions", "id", "user_id", "perm_id")
|
||||
type taggedPerm struct {
|
||||
ID uint `ksql:"id"`
|
||||
UserID int `ksql:"user_id"`
|
||||
PermID int `ksql:"perm_id"`
|
||||
Type string `ksql:"type,skipInserts"`
|
||||
}
|
||||
permission := taggedPerm{
|
||||
UserID: 3,
|
||||
PermID: 43,
|
||||
}
|
||||
err := c.Insert(ctx, table, &permission)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, permission.ID, 0)
|
||||
|
||||
fmt.Println("permID:", permission.ID)
|
||||
var untaggedPerm struct {
|
||||
ID uint `ksql:"id"`
|
||||
UserID int `ksql:"user_id"`
|
||||
PermID int `ksql:"perm_id"`
|
||||
Type *string `ksql:"type"`
|
||||
}
|
||||
err = c.QueryOne(ctx, &untaggedPerm, `FROM user_permissions WHERE user_id = 3 AND perm_id = 43`)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, untaggedPerm.Type, (*string)(nil))
|
||||
|
||||
// Should retrieve the generated ID from the database,
|
||||
// only if the database supports returning multiple values:
|
||||
switch c.dialect.InsertMethod() {
|
||||
case insertWithNoIDRetrieval, insertWithLastInsertID:
|
||||
tt.AssertEqual(t, permission.ID, uint(0))
|
||||
tt.AssertEqual(t, untaggedPerm.UserID, 3)
|
||||
tt.AssertEqual(t, untaggedPerm.PermID, 43)
|
||||
case insertWithReturning, insertWithOutput:
|
||||
tt.AssertEqual(t, untaggedPerm.ID, permission.ID)
|
||||
tt.AssertEqual(t, untaggedPerm.UserID, 3)
|
||||
tt.AssertEqual(t, untaggedPerm.PermID, 43)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -1711,6 +1791,24 @@ func PatchTest(
|
|||
tt.AssertNotEqual(t, err, nil)
|
||||
})
|
||||
|
||||
t.Run("should report error if the struct has no fields to update", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
err = c.Update(ctx, usersTable, struct {
|
||||
ID uint `ksql:"id"` // ID fields are not updated
|
||||
Name string `ksql:"name,skipUpdates"` // the skipUpdate modifier should rule this one out
|
||||
Age *int `ksql:"age"` // Age is a nil pointer so it would not be updated
|
||||
}{
|
||||
ID: 1,
|
||||
Name: "some name",
|
||||
})
|
||||
tt.AssertEqual(t, err, ErrNoValuesToUpdate)
|
||||
})
|
||||
|
||||
t.Run("should report error if the id is missing", func(t *testing.T) {
|
||||
t.Run("with a single primary key", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
|
@ -2515,6 +2613,432 @@ func TransactionTest(
|
|||
})
|
||||
}
|
||||
|
||||
func ModifiersTest(
|
||||
t *testing.T,
|
||||
driver string,
|
||||
connStr string,
|
||||
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
|
||||
) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Modifiers", func(t *testing.T) {
|
||||
err := createTables(driver, connStr)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
||||
t.Run("timeNowUTC modifier", func(t *testing.T) {
|
||||
t.Run("should be set to time.Now().UTC() on insertion", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
type tsUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
UpdatedAt time.Time `ksql:"updated_at,timeNowUTC"`
|
||||
}
|
||||
u := tsUser{
|
||||
Name: "Letícia",
|
||||
}
|
||||
err := c.Insert(ctx, usersTable, &u)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, u.ID, 0)
|
||||
|
||||
var untaggedUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
UpdatedAt time.Time `ksql:"updated_at"`
|
||||
}
|
||||
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
now := time.Now()
|
||||
tt.AssertApproxTime(t,
|
||||
2*time.Second, untaggedUser.UpdatedAt, now,
|
||||
"updatedAt should be set to %v, but got: %v", now, untaggedUser.UpdatedAt,
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("should be set to time.Now().UTC() on updates", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
type userWithNoTags struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
UpdatedAt time.Time `ksql:"updated_at"`
|
||||
}
|
||||
untaggedUser := userWithNoTags{
|
||||
Name: "Laura Ribeiro",
|
||||
// Any time different from now:
|
||||
UpdatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
|
||||
}
|
||||
err := c.Insert(ctx, usersTable, &untaggedUser)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
||||
|
||||
type taggedUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
UpdatedAt time.Time `ksql:"updated_at,timeNowUTC"`
|
||||
}
|
||||
u := taggedUser{
|
||||
ID: untaggedUser.ID,
|
||||
Name: "Laurinha Ribeiro",
|
||||
}
|
||||
err = c.Patch(ctx, usersTable, u)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
var untaggedUser2 userWithNoTags
|
||||
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, untaggedUser2.ID, 0)
|
||||
|
||||
now := time.Now()
|
||||
tt.AssertApproxTime(t,
|
||||
2*time.Second, untaggedUser2.UpdatedAt, now,
|
||||
"updatedAt should be set to %v, but got: %v", now, untaggedUser2.UpdatedAt,
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("should not alter the value on queries", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
type userWithNoTags struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
UpdatedAt time.Time `ksql:"updated_at"`
|
||||
}
|
||||
untaggedUser := userWithNoTags{
|
||||
Name: "Marta Ribeiro",
|
||||
// Any time different from now:
|
||||
UpdatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
|
||||
}
|
||||
err := c.Insert(ctx, usersTable, &untaggedUser)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
||||
|
||||
var taggedUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
UpdatedAt time.Time `ksql:"updated_at,timeNowUTC"`
|
||||
}
|
||||
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
|
||||
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
|
||||
tt.AssertEqual(t, taggedUser.UpdatedAt, tt.ParseTime(t, "2000-08-05T14:00:00Z"))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("timeNowUTC/skipUpdates modifier", func(t *testing.T) {
|
||||
t.Run("should be set to time.Now().UTC() on insertion", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
type tsUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
CreatedAt time.Time `ksql:"created_at,timeNowUTC/skipUpdates"`
|
||||
}
|
||||
u := tsUser{
|
||||
Name: "Letícia",
|
||||
}
|
||||
err := c.Insert(ctx, usersTable, &u)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, u.ID, 0)
|
||||
|
||||
var untaggedUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
CreatedAt time.Time `ksql:"created_at"`
|
||||
}
|
||||
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
now := time.Now()
|
||||
tt.AssertApproxTime(t,
|
||||
2*time.Second, untaggedUser.CreatedAt, now,
|
||||
"updatedAt should be set to %v, but got: %v", now, untaggedUser.CreatedAt,
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("should be ignored on updates", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
type userWithNoTags struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
CreatedAt time.Time `ksql:"created_at"`
|
||||
}
|
||||
untaggedUser := userWithNoTags{
|
||||
Name: "Laura Ribeiro",
|
||||
// Any time different from now:
|
||||
CreatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
|
||||
}
|
||||
err := c.Insert(ctx, usersTable, &untaggedUser)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
||||
|
||||
type taggedUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
CreatedAt time.Time `ksql:"created_at,timeNowUTC/skipUpdates"`
|
||||
}
|
||||
u := taggedUser{
|
||||
ID: untaggedUser.ID,
|
||||
Name: "Laurinha Ribeiro",
|
||||
// Some random time that should be ignored:
|
||||
CreatedAt: tt.ParseTime(t, "1999-08-05T14:00:00Z"),
|
||||
}
|
||||
err = c.Patch(ctx, usersTable, u)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
var untaggedUser2 userWithNoTags
|
||||
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, untaggedUser2.CreatedAt, tt.ParseTime(t, "2000-08-05T14:00:00Z"))
|
||||
})
|
||||
|
||||
t.Run("should not alter the value on queries", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
type userWithNoTags struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
CreatedAt time.Time `ksql:"created_at"`
|
||||
}
|
||||
untaggedUser := userWithNoTags{
|
||||
Name: "Marta Ribeiro",
|
||||
// Any time different from now:
|
||||
CreatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
|
||||
}
|
||||
err := c.Insert(ctx, usersTable, &untaggedUser)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
||||
|
||||
var taggedUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
CreatedAt time.Time `ksql:"created_at,timeNowUTC/skipUpdates"`
|
||||
}
|
||||
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
|
||||
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
|
||||
tt.AssertEqual(t, taggedUser.CreatedAt, tt.ParseTime(t, "2000-08-05T14:00:00Z"))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("skipInserts modifier", func(t *testing.T) {
|
||||
t.Run("should ignore the field during insertions", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
type tsUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name,skipInserts"`
|
||||
Age int `ksql:"age"`
|
||||
}
|
||||
u := tsUser{
|
||||
Name: "Letícia",
|
||||
Age: 22,
|
||||
}
|
||||
err := c.Insert(ctx, usersTable, &u)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, u.ID, 0)
|
||||
|
||||
var untaggedUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name *string `ksql:"name"`
|
||||
Age int `ksql:"age"`
|
||||
}
|
||||
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, untaggedUser.Name, (*string)(nil))
|
||||
tt.AssertEqual(t, untaggedUser.Age, 22)
|
||||
})
|
||||
|
||||
t.Run("should have no effect on updates", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
type userWithNoTags struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
Age int `ksql:"age"`
|
||||
}
|
||||
untaggedUser := userWithNoTags{
|
||||
Name: "Laurinha Ribeiro",
|
||||
Age: 11,
|
||||
}
|
||||
err := c.Insert(ctx, usersTable, &untaggedUser)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
||||
|
||||
type taggedUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name,skipInserts"`
|
||||
Age int `ksql:"age"`
|
||||
}
|
||||
u := taggedUser{
|
||||
ID: untaggedUser.ID,
|
||||
Name: "Laura Ribeiro",
|
||||
Age: 12,
|
||||
}
|
||||
err = c.Patch(ctx, usersTable, u)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
var untaggedUser2 userWithNoTags
|
||||
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, untaggedUser2.Name, "Laura Ribeiro")
|
||||
tt.AssertEqual(t, untaggedUser2.Age, 12)
|
||||
})
|
||||
|
||||
t.Run("should not alter the value on queries", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
type userWithNoTags struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
}
|
||||
untaggedUser := userWithNoTags{
|
||||
Name: "Marta Ribeiro",
|
||||
}
|
||||
err := c.Insert(ctx, usersTable, &untaggedUser)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
||||
|
||||
var taggedUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name,skipInserts"`
|
||||
}
|
||||
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
|
||||
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("skipUpdates modifier", func(t *testing.T) {
|
||||
t.Run("should set the field on insertion", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
type tsUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name,skipUpdates"`
|
||||
}
|
||||
u := tsUser{
|
||||
Name: "Letícia",
|
||||
}
|
||||
err := c.Insert(ctx, usersTable, &u)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, u.ID, 0)
|
||||
|
||||
var untaggedUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
}
|
||||
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, untaggedUser.Name, "Letícia")
|
||||
})
|
||||
|
||||
t.Run("should be ignored on updates", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
type userWithNoTags struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
Age int `ksql:"age"`
|
||||
}
|
||||
untaggedUser := userWithNoTags{
|
||||
Name: "Laurinha Ribeiro",
|
||||
Age: 11,
|
||||
}
|
||||
err := c.Insert(ctx, usersTable, &untaggedUser)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
||||
|
||||
type taggedUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name,skipUpdates"`
|
||||
Age int `ksql:"age"`
|
||||
}
|
||||
u := taggedUser{
|
||||
ID: untaggedUser.ID,
|
||||
Name: "Laura Ribeiro",
|
||||
Age: 12,
|
||||
}
|
||||
err = c.Patch(ctx, usersTable, u)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
var untaggedUser2 userWithNoTags
|
||||
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, untaggedUser2.Name, "Laurinha Ribeiro")
|
||||
tt.AssertEqual(t, untaggedUser2.Age, 12)
|
||||
})
|
||||
|
||||
t.Run("should not alter the value on queries", func(t *testing.T) {
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
type userWithNoTags struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
}
|
||||
untaggedUser := userWithNoTags{
|
||||
Name: "Marta Ribeiro",
|
||||
}
|
||||
err := c.Insert(ctx, usersTable, &untaggedUser)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
||||
|
||||
var taggedUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
Name string `ksql:"name,skipUpdates"`
|
||||
}
|
||||
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
|
||||
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ScanRowsTest runs all tests for making sure the ScanRows feature is
|
||||
// working for a given adapter and driver.
|
||||
func ScanRowsTest(
|
||||
|
@ -2546,7 +3070,7 @@ func ScanRowsTest(
|
|||
tt.AssertEqual(t, rows.Next(), true)
|
||||
|
||||
var u user
|
||||
err = scanRows(dialect, rows, &u)
|
||||
err = scanRows(ctx, dialect, rows, &u)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
tt.AssertEqual(t, u.Name, "User2")
|
||||
|
@ -2579,7 +3103,7 @@ func ScanRowsTest(
|
|||
// Omitted for testing purposes:
|
||||
// Name string `ksql:"name"`
|
||||
}
|
||||
err = scanRows(dialect, rows, &u)
|
||||
err = scanRows(ctx, dialect, rows, &u)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
tt.AssertEqual(t, u.Age, 22)
|
||||
|
@ -2602,7 +3126,7 @@ func ScanRowsTest(
|
|||
var u user
|
||||
err = rows.Close()
|
||||
tt.AssertNoErr(t, err)
|
||||
err = scanRows(dialect, rows, &u)
|
||||
err = scanRows(ctx, dialect, rows, &u)
|
||||
tt.AssertNotEqual(t, err, nil)
|
||||
})
|
||||
|
||||
|
@ -2622,7 +3146,7 @@ func ScanRowsTest(
|
|||
defer rows.Close()
|
||||
|
||||
var u user
|
||||
err = scanRows(dialect, rows, u)
|
||||
err = scanRows(ctx, dialect, rows, u)
|
||||
tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "user")
|
||||
})
|
||||
|
||||
|
@ -2642,7 +3166,7 @@ func ScanRowsTest(
|
|||
defer rows.Close()
|
||||
|
||||
var u map[string]interface{}
|
||||
err = scanRows(dialect, rows, &u)
|
||||
err = scanRows(ctx, dialect, rows, &u)
|
||||
tt.AssertErrContains(t, err, "KSQL", "expected", "pointer to struct", "map[string]interface")
|
||||
})
|
||||
})
|
||||
|
@ -2667,28 +3191,36 @@ func createTables(driver string, connStr string) error {
|
|||
id INTEGER PRIMARY KEY,
|
||||
age INTEGER,
|
||||
name TEXT,
|
||||
address BLOB
|
||||
address BLOB,
|
||||
created_at DATETIME,
|
||||
updated_at DATETIME
|
||||
)`)
|
||||
case "postgres":
|
||||
_, err = db.Exec(`CREATE TABLE users (
|
||||
id serial PRIMARY KEY,
|
||||
age INT,
|
||||
name VARCHAR(50),
|
||||
address jsonb
|
||||
address jsonb,
|
||||
created_at TIMESTAMP,
|
||||
updated_at TIMESTAMP
|
||||
)`)
|
||||
case "mysql":
|
||||
_, err = db.Exec(`CREATE TABLE users (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
age INT,
|
||||
name VARCHAR(50),
|
||||
address JSON
|
||||
address JSON,
|
||||
created_at DATETIME,
|
||||
updated_at DATETIME
|
||||
)`)
|
||||
case "sqlserver":
|
||||
_, err = db.Exec(`CREATE TABLE users (
|
||||
id INT IDENTITY(1,1) PRIMARY KEY,
|
||||
age INT,
|
||||
name VARCHAR(50),
|
||||
address NVARCHAR(4000)
|
||||
address NVARCHAR(4000),
|
||||
created_at DATETIME,
|
||||
updated_at DATETIME
|
||||
)`)
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -2798,9 +3330,18 @@ func getUserByID(db DBAdapter, dialect Dialect, result *user, id uint) error {
|
|||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
value := jsonSerializable{
|
||||
DriverName: dialect.DriverName(),
|
||||
Attr: &result.Address,
|
||||
modifier, _ := modifiers.LoadGlobalModifier("json")
|
||||
|
||||
value := modifiers.AttrScanWrapper{
|
||||
Ctx: context.TODO(),
|
||||
AttrPtr: &result.Address,
|
||||
ScanFn: modifier.Scan,
|
||||
OpInfo: ksqlmodifiers.OpInfo{
|
||||
DriverName: dialect.DriverName(),
|
||||
// We will not differentiate between Query, QueryOne and QueryChunks
|
||||
// if we did this could lead users to make very strange modifiers
|
||||
Method: "Query",
|
||||
},
|
||||
}
|
||||
|
||||
err = rows.Scan(&result.ID, &result.Name, &result.Age, &value)
|
||||
|
|
Loading…
Reference in New Issue