Merge pull request from VinGarcia/add-attr-middlewares

Draft: Create a mechanism for users to add their own implementation of modifiers
pull/32/head
Vinícius Garcia 2022-10-18 13:01:54 -03:00 committed by GitHub
commit b2e146d5e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 1181 additions and 107 deletions

View File

@ -1,6 +1,9 @@
name: CI
on: [push, pull_request]
on:
push: {}
pull_request:
types: [opened, reopened]
jobs:
tests:
@ -20,9 +23,11 @@ jobs:
- name: Run Tests
run: ./scripts/run-all-tests.sh
- name: Run Coverage
run: bash <(curl -s https://codecov.io/bash)
run: |
curl -Os https://uploader.codecov.io/latest/linux/codecov
chmod +x codecov
./codecov -t $CODECOV_TOKEN
env:
CODECOV_TOKEN: 36be8ba6-7ef1-4ec2-b607-67c1055a62ad

View File

@ -54,7 +54,7 @@ func startMySQLDB(dbName string) (databaseURL string, closer func()) {
}
hostAndPort := resource.GetHostPort("3306/tcp")
databaseUrl := fmt.Sprintf("root:mysql@(%s)/%s?timeout=30s", hostAndPort, dbName)
databaseUrl := fmt.Sprintf("root:mysql@(%s)/%s?timeout=30s&parseTime=true", hostAndPort, dbName)
fmt.Println("Connecting to mariadb on url: ", databaseUrl)

View File

@ -8,6 +8,7 @@ import (
// ErrRecordNotFound ...
var ErrRecordNotFound error = fmt.Errorf("ksql: the query returned no results: %w", sql.ErrNoRows)
var ErrNoValuesToUpdate error = fmt.Errorf("ksql: the input struct contains no values to update")
// ErrAbortIteration ...
var ErrAbortIteration error = fmt.Errorf("ksql: abort iteration, should only be used inside QueryChunks function")

View File

@ -0,0 +1,46 @@
package modifiers
import (
"context"
"database/sql/driver"
"github.com/vingarcia/ksql/ksqlmodifiers"
)
// AttrScanWrapper is the wrapper that allow us to intercept the Scan process
// so we can run the modifiers instead of allowing the database driver to use
// its default behavior.
//
// For that this struct implements the `sql.Scanner` interface
type AttrScanWrapper struct {
Ctx context.Context
AttrPtr interface{}
ScanFn ksqlmodifiers.AttrScanner
OpInfo ksqlmodifiers.OpInfo
}
// Scan implements the sql.Scanner interface
func (a AttrScanWrapper) Scan(dbValue interface{}) error {
return a.ScanFn(a.Ctx, a.OpInfo, a.AttrPtr, dbValue)
}
// AttrValueWrapper is the wrapper that allow us to intercept the "Valuing" process
// so we can run the modifiers instead of allowing the database driver to use
// its default behavior.
//
// For that this struct implements the `sql.Valuer` interface
type AttrValueWrapper struct {
Ctx context.Context
Attr interface{}
ValueFn ksqlmodifiers.AttrValuer
OpInfo ksqlmodifiers.OpInfo
}
// Value implements the sql.Valuer interface
func (a AttrValueWrapper) Value() (driver.Value, error) {
return a.ValueFn(a.Ctx, a.OpInfo, a.Attr)
}

View File

@ -0,0 +1,75 @@
package modifiers
import (
"context"
"errors"
"testing"
tt "github.com/vingarcia/ksql/internal/testtools"
"github.com/vingarcia/ksql/ksqlmodifiers"
)
func TestAttrScanWrapper(t *testing.T) {
ctx := context.Background()
var scanArgs map[string]interface{}
wrapper := AttrScanWrapper{
Ctx: ctx,
AttrPtr: "fakeAttrPtr",
ScanFn: func(ctx context.Context, opInfo ksqlmodifiers.OpInfo, attrPtr interface{}, dbValue interface{}) error {
scanArgs = map[string]interface{}{
"opInfo": opInfo,
"attrPtr": attrPtr,
"dbValue": dbValue,
}
return errors.New("fakeScanErrMsg")
},
OpInfo: ksqlmodifiers.OpInfo{
Method: "fakeMethod",
DriverName: "fakeDriverName",
},
}
err := wrapper.Scan("fakeDbValue")
tt.AssertErrContains(t, err, "fakeScanErrMsg")
tt.AssertEqual(t, scanArgs, map[string]interface{}{
"opInfo": ksqlmodifiers.OpInfo{
Method: "fakeMethod",
DriverName: "fakeDriverName",
},
"attrPtr": "fakeAttrPtr",
"dbValue": "fakeDbValue",
})
}
func TestAttrWrapper(t *testing.T) {
ctx := context.Background()
var valueArgs map[string]interface{}
wrapper := AttrValueWrapper{
Ctx: ctx,
Attr: "fakeAttr",
ValueFn: func(ctx context.Context, opInfo ksqlmodifiers.OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
valueArgs = map[string]interface{}{
"opInfo": opInfo,
"inputValue": inputValue,
}
return "fakeOutputValue", errors.New("fakeValueErrMsg")
},
OpInfo: ksqlmodifiers.OpInfo{
Method: "fakeMethod",
DriverName: "fakeDriverName",
},
}
value, err := wrapper.Value()
tt.AssertErrContains(t, err, "fakeValueErrMsg")
tt.AssertEqual(t, valueArgs, map[string]interface{}{
"opInfo": ksqlmodifiers.OpInfo{
Method: "fakeMethod",
DriverName: "fakeDriverName",
},
"inputValue": "fakeAttr",
})
tt.AssertEqual(t, value, "fakeOutputValue")
}

View File

@ -0,0 +1,55 @@
package modifiers
import (
"fmt"
"sync"
"github.com/vingarcia/ksql/ksqlmodifiers"
)
// Here we keep all the registered modifiers
var modifiers sync.Map
func init() {
// Here we expose the registration function in a public package,
// so users can use it:
ksqlmodifiers.RegisterAttrModifier = RegisterAttrModifier
// These are the builtin modifiers:
// This one is useful for serializing/desserializing structs:
modifiers.Store("json", jsonModifier)
// This next two are useful for the UpdatedAt and Created fields respectively:
// They only work on time.Time attributes and will set the attribute to time.Now().
modifiers.Store("timeNowUTC", timeNowUTCModifier)
modifiers.Store("timeNowUTC/skipUpdates", timeNowUTCSkipUpdatesModifier)
// These are mostly example modifiers and they are also used
// to test the feature of skipping updates, inserts and queries.
modifiers.Store("skipUpdates", skipUpdatesModifier)
modifiers.Store("skipInserts", skipInsertsModifier)
}
// RegisterAttrModifier allow users to add custom modifiers on startup
// it is recommended to do this inside an init() function.
func RegisterAttrModifier(key string, modifier ksqlmodifiers.AttrModifier) {
_, found := modifiers.Load(key)
if found {
panic(fmt.Errorf("KSQL: cannot register modifier '%s' name is already in use", key))
}
modifiers.Store(key, modifier)
}
// LoadGlobalModifier is used internally by KSQL to load
// modifiers during runtime.
func LoadGlobalModifier(key string) (ksqlmodifiers.AttrModifier, error) {
rawModifier, _ := modifiers.Load(key)
modifier, ok := rawModifier.(ksqlmodifiers.AttrModifier)
if !ok {
return ksqlmodifiers.AttrModifier{}, fmt.Errorf("no modifier found with name '%s'", key)
}
return modifier, nil
}

View File

@ -0,0 +1,54 @@
package modifiers
import (
"testing"
tt "github.com/vingarcia/ksql/internal/testtools"
"github.com/vingarcia/ksql/ksqlmodifiers"
)
func TestRegisterAttrModifier(t *testing.T) {
t.Run("should register new modifiers correctly", func(t *testing.T) {
modifier1 := ksqlmodifiers.AttrModifier{
SkipOnUpdate: true,
}
modifier2 := ksqlmodifiers.AttrModifier{
SkipOnInsert: true,
}
RegisterAttrModifier("fakeModifierName1", modifier1)
RegisterAttrModifier("fakeModifierName2", modifier2)
mod, err := LoadGlobalModifier("fakeModifierName1")
tt.AssertNoErr(t, err)
tt.AssertEqual(t, mod, modifier1)
mod, err = LoadGlobalModifier("fakeModifierName2")
tt.AssertNoErr(t, err)
tt.AssertEqual(t, mod, modifier2)
})
t.Run("should panic registering a modifier and the name already exists", func(t *testing.T) {
modifier1 := ksqlmodifiers.AttrModifier{
SkipOnUpdate: true,
}
modifier2 := ksqlmodifiers.AttrModifier{
SkipOnInsert: true,
}
RegisterAttrModifier("fakeModifierName", modifier1)
panicPayload := tt.PanicHandler(func() {
RegisterAttrModifier("fakeModifierName", modifier2)
})
err, ok := panicPayload.(error)
tt.AssertEqual(t, ok, true)
tt.AssertErrContains(t, err, "KSQL", "fakeModifierName", "name is already in use")
})
t.Run("should return an error when loading an inexistent modifier", func(t *testing.T) {
mod, err := LoadGlobalModifier("nonExistentModifier")
tt.AssertErrContains(t, err, "nonExistentModifier")
tt.AssertEqual(t, mod, ksqlmodifiers.AttrModifier{})
})
}

View File

@ -0,0 +1,41 @@
package modifiers
import (
"context"
"encoding/json"
"fmt"
"github.com/vingarcia/ksql/ksqlmodifiers"
)
// This modifier serializes objects as JSON when
// sending it to the database and decodes
// them when receiving.
var jsonModifier = ksqlmodifiers.AttrModifier{
Scan: func(ctx context.Context, opInfo ksqlmodifiers.OpInfo, attrPtr interface{}, dbValue interface{}) error {
if dbValue == nil {
return nil
}
// Required since sqlite3 returns strings not bytes
if v, ok := dbValue.(string); ok {
dbValue = []byte(v)
}
rawJSON, ok := dbValue.([]byte)
if !ok {
return fmt.Errorf("unexpected type received to Scan: %T", dbValue)
}
return json.Unmarshal(rawJSON, attrPtr)
},
Value: func(ctx context.Context, opInfo ksqlmodifiers.OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
b, err := json.Marshal(inputValue)
// SQL server uses the NVARCHAR type to store JSON and
// it expects to receive strings not []byte, thus:
if opInfo.DriverName == "sqlserver" {
return string(b), err
}
return b, err
},
}

View File

@ -0,0 +1,125 @@
package modifiers
import (
"context"
"testing"
tt "github.com/vingarcia/ksql/internal/testtools"
"github.com/vingarcia/ksql/ksqlmodifiers"
)
func TestAttrScan(t *testing.T) {
ctx := context.Background()
type FakeAttr struct {
Foo string `json:"foo"`
}
tests := []struct {
desc string
dbInput interface{}
expectedValue interface{}
expectErrToContain []string
}{
{
desc: "should not set struct to zero value if input is nil",
dbInput: nil,
expectedValue: FakeAttr{
Foo: "notZeroValue",
},
},
{
desc: "should work when input is a byte slice",
dbInput: []byte(`{"foo":"bar"}`),
expectedValue: FakeAttr{
Foo: "bar",
},
},
{
desc: "should work when input is a string",
dbInput: `{"foo":"bar"}`,
expectedValue: FakeAttr{
Foo: "bar",
},
},
{
desc: "should report error if input type is unsupported",
dbInput: 10,
expectErrToContain: []string{"unexpected type", "int"},
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
fakeAttr := FakeAttr{
Foo: "notZeroValue",
}
err := jsonModifier.Scan(ctx, ksqlmodifiers.OpInfo{}, &fakeAttr, test.dbInput)
if test.expectErrToContain != nil {
tt.AssertErrContains(t, err, test.expectErrToContain...)
t.Skip()
}
tt.AssertNoErr(t, err)
tt.AssertEqual(t, fakeAttr, test.expectedValue)
})
}
}
func TestAttrValue(t *testing.T) {
ctx := context.Background()
type FakeAttr struct {
Foo string `json:"foo"`
}
tests := []struct {
desc string
dbInput interface{}
opInfoInput ksqlmodifiers.OpInfo
attrValue interface{}
expectedOutput interface{}
expectErrToContain []string
}{
{
desc: "should return a byte array when the driver is not sqlserver",
dbInput: []byte(`{"foo":"bar"}`),
opInfoInput: ksqlmodifiers.OpInfo{
DriverName: "notSQLServer",
},
attrValue: FakeAttr{
Foo: "bar",
},
expectedOutput: tt.ToJSON(t, map[string]interface{}{
"foo": "bar",
}),
},
{
desc: "should return a string when the driver is sqlserver",
dbInput: []byte(`{"foo":"bar"}`),
opInfoInput: ksqlmodifiers.OpInfo{
DriverName: "sqlserver",
},
attrValue: FakeAttr{
Foo: "bar",
},
expectedOutput: string(tt.ToJSON(t, map[string]interface{}{
"foo": "bar",
})),
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
output, err := jsonModifier.Value(ctx, test.opInfoInput, test.attrValue)
if test.expectErrToContain != nil {
tt.AssertErrContains(t, err, test.expectErrToContain...)
t.Skip()
}
tt.AssertNoErr(t, err)
tt.AssertEqual(t, output, test.expectedOutput)
})
}
}

View File

@ -0,0 +1,11 @@
package modifiers
import "github.com/vingarcia/ksql/ksqlmodifiers"
var skipInsertsModifier = ksqlmodifiers.AttrModifier{
SkipOnInsert: true,
}
var skipUpdatesModifier = ksqlmodifiers.AttrModifier{
SkipOnUpdate: true,
}

View File

@ -0,0 +1,24 @@
package modifiers
import (
"context"
"time"
"github.com/vingarcia/ksql/ksqlmodifiers"
)
// This one is useful for updatedAt timestamps
var timeNowUTCModifier = ksqlmodifiers.AttrModifier{
Value: func(ctx context.Context, opInfo ksqlmodifiers.OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
return time.Now().UTC(), nil
},
}
// This one is useful for createdAt timestamps
var timeNowUTCSkipUpdatesModifier = ksqlmodifiers.AttrModifier{
SkipOnUpdate: true,
Value: func(ctx context.Context, opInfo ksqlmodifiers.OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
return time.Now().UTC(), nil
},
}

View File

@ -5,6 +5,9 @@ import (
"reflect"
"strings"
"sync"
"github.com/vingarcia/ksql/internal/modifiers"
"github.com/vingarcia/ksql/ksqlmodifiers"
)
// StructInfo stores metainformation of the struct
@ -20,10 +23,10 @@ type StructInfo struct {
// information regarding a specific field
// of a struct.
type FieldInfo struct {
Name string
Index int
Valid bool
SerializeAsJSON bool
Name string
Index int
Valid bool
Modifier ksqlmodifiers.AttrModifier
}
// ByIndex returns either the *FieldInfo of a valid
@ -232,7 +235,7 @@ func (p PtrConverter) Convert(destType reflect.Type) (reflect.Value, error) {
//
// This should save several calls to `Field(i).Tag.Get("foo")`
// which improves performance by a lot.
func getTagNames(t reflect.Type) (StructInfo, error) {
func getTagNames(t reflect.Type) (_ StructInfo, err error) {
info := StructInfo{
byIndex: map[int]*FieldInfo{},
byName: map[string]*FieldInfo{},
@ -249,10 +252,13 @@ func getTagNames(t reflect.Type) (StructInfo, error) {
}
tags := strings.Split(name, ",")
serializeAsJSON := false
var modifier ksqlmodifiers.AttrModifier
if len(tags) > 1 {
name = tags[0]
serializeAsJSON = tags[1] == "json"
modifier, err = modifiers.LoadGlobalModifier(tags[1])
if err != nil {
return StructInfo{}, fmt.Errorf("attribute contains invalid modifier name: %w", err)
}
}
if _, found := info.byName[name]; found {
@ -263,9 +269,9 @@ func getTagNames(t reflect.Type) (StructInfo, error) {
}
info.add(FieldInfo{
Name: name,
Index: i,
SerializeAsJSON: serializeAsJSON,
Name: name,
Index: i,
Modifier: modifier,
})
}

View File

@ -0,0 +1,13 @@
package tt
import (
"encoding/json"
"testing"
)
func ToJSON(t *testing.T, obj interface{}) []byte {
rawJSON, err := json.Marshal(obj)
AssertNoErr(t, err)
return rawJSON
}

View File

@ -0,0 +1,12 @@
package tt
import (
"testing"
"time"
)
func ParseTime(t *testing.T, timestr string) time.Time {
parsedTime, err := time.Parse(time.RFC3339, timestr)
AssertNoErr(t, err)
return parsedTime
}

49
json.go
View File

@ -1,49 +0,0 @@
package ksql
import (
"database/sql/driver"
"encoding/json"
"fmt"
"reflect"
)
// This type was created to make it easier to adapt
// input attributes to be convertible to and from JSON
// before sending or receiving it from the database.
type jsonSerializable struct {
DriverName string
Attr interface{}
}
// Scan Implements the Scanner interface in order to load
// this field from the JSON stored in the database
func (j *jsonSerializable) Scan(value interface{}) error {
if value == nil {
v := reflect.ValueOf(j.Attr)
// Set the struct to its 0 value just like json.Unmarshal
// does for nil attributes:
v.Elem().Set(reflect.Zero(reflect.TypeOf(j.Attr).Elem()))
return nil
}
// Required since sqlite3 returns strings not bytes
if v, ok := value.(string); ok {
value = []byte(v)
}
rawJSON, ok := value.([]byte)
if !ok {
return fmt.Errorf("unexpected type received to Scan: %T", value)
}
return json.Unmarshal(rawJSON, j.Attr)
}
// Value Implements the Valuer interface in order to save
// this field as JSON on the database.
func (j jsonSerializable) Value() (driver.Value, error) {
b, err := json.Marshal(j.Attr)
if j.DriverName == "sqlserver" {
return string(b), err
}
return b, err
}

129
ksql.go
View File

@ -10,7 +10,9 @@ import (
"sync"
"unicode"
"github.com/vingarcia/ksql/internal/modifiers"
"github.com/vingarcia/ksql/internal/structs"
"github.com/vingarcia/ksql/ksqlmodifiers"
"github.com/vingarcia/ksql/ksqltest"
)
@ -184,7 +186,7 @@ func (c DB) Query(
elemPtr = elemPtr.Elem()
}
err = scanRows(c.dialect, rows, elemPtr.Interface())
err = scanRows(ctx, c.dialect, rows, elemPtr.Interface())
if err != nil {
return err
}
@ -257,13 +259,13 @@ func (c DB) QueryOne(
defer rows.Close()
if !rows.Next() {
if rows.Err() != nil {
return rows.Err()
if err := rows.Err(); err != nil {
return err
}
return ErrRecordNotFound
}
err = scanRowsFromType(c.dialect, rows, record, t, v)
err = scanRowsFromType(ctx, c.dialect, rows, record, t, v)
if err != nil {
return err
}
@ -342,7 +344,7 @@ func (c DB) QueryChunks(
chunk = reflect.Append(chunk, elemValue)
}
err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface())
err = scanRows(ctx, c.dialect, rows, chunk.Index(idx).Addr().Interface())
if err != nil {
return err
}
@ -419,7 +421,7 @@ func (c DB) Insert(
return err
}
query, params, scanValues, err := buildInsertQuery(c.dialect, table, t, v, info, record)
query, params, scanValues, err := buildInsertQuery(ctx, c.dialect, table, t, v, info, record)
if err != nil {
return err
}
@ -656,7 +658,7 @@ func (c DB) Patch(
return err
}
query, params, err := buildUpdateQuery(c.dialect, table.name, info, record, table.idColumns...)
query, params, err := buildUpdateQuery(ctx, c.dialect, table.name, info, record, table.idColumns...)
if err != nil {
return err
}
@ -681,6 +683,7 @@ func (c DB) Patch(
}
func buildInsertQuery(
ctx context.Context,
dialect Dialect,
table Table,
t reflect.Type,
@ -707,18 +710,29 @@ func buildInsertQuery(
columnNames := []string{}
for col := range recordMap {
if info.ByName(col).Modifier.SkipOnInsert {
continue
}
columnNames = append(columnNames, col)
}
params = make([]interface{}, len(recordMap))
valuesQuery := make([]string, len(recordMap))
params = make([]interface{}, len(columnNames))
valuesQuery := make([]string, len(columnNames))
for i, col := range columnNames {
recordValue := recordMap[col]
params[i] = recordValue
if info.ByName(col).SerializeAsJSON {
params[i] = jsonSerializable{
DriverName: dialect.DriverName(),
Attr: recordValue,
valueFn := info.ByName(col).Modifier.Value
if valueFn != nil {
params[i] = modifiers.AttrValueWrapper{
Ctx: ctx,
Attr: recordValue,
ValueFn: valueFn,
OpInfo: ksqlmodifiers.OpInfo{
DriverName: dialect.DriverName(),
Method: "Insert",
},
}
}
@ -761,6 +775,16 @@ func buildInsertQuery(
}
}
if len(columnNames) == 0 && dialect.DriverName() != "mysql" {
query = fmt.Sprintf(
"INSERT INTO %s%s DEFAULT VALUES%s",
dialect.Escape(table.name),
outputQuery,
returningQuery,
)
return query, params, scanValues, nil
}
// Note that the outputQuery and the returningQuery depend
// on the selected driver, thus, they might be empty strings.
query = fmt.Sprintf(
@ -776,21 +800,32 @@ func buildInsertQuery(
}
func buildUpdateQuery(
ctx context.Context,
dialect Dialect,
tableName string,
info structs.StructInfo,
record interface{},
idFieldNames ...string,
) (query string, args []interface{}, err error) {
recordMap, err := ksqltest.StructToMap(record)
recordMap, err := structs.StructToMap(record)
if err != nil {
return "", nil, err
}
for key := range recordMap {
if info.ByName(key).Modifier.SkipOnUpdate {
delete(recordMap, key)
}
}
numAttrs := len(recordMap)
args = make([]interface{}, numAttrs)
numNonIDArgs := numAttrs - len(idFieldNames)
whereArgs := args[numNonIDArgs:]
if numNonIDArgs == 0 {
return "", nil, ErrNoValuesToUpdate
}
err = validateIfAllIdsArePresent(idFieldNames, recordMap)
if err != nil {
return "", nil, err
@ -816,10 +851,17 @@ func buildUpdateQuery(
var setQuery []string
for i, k := range keys {
recordValue := recordMap[k]
if info.ByName(k).SerializeAsJSON {
recordValue = jsonSerializable{
DriverName: dialect.DriverName(),
Attr: recordValue,
valueFn := info.ByName(k).Modifier.Value
if valueFn != nil {
recordValue = modifiers.AttrValueWrapper{
Ctx: ctx,
Attr: recordValue,
ValueFn: valueFn,
OpInfo: ksqlmodifiers.OpInfo{
DriverName: dialect.DriverName(),
Method: "Update",
},
}
}
args[i] = recordValue
@ -930,13 +972,14 @@ func (nopScanner) Scan(value interface{}) error {
return nil
}
func scanRows(dialect Dialect, rows Rows, record interface{}) error {
func scanRows(ctx context.Context, dialect Dialect, rows Rows, record interface{}) error {
v := reflect.ValueOf(record)
t := v.Type()
return scanRowsFromType(dialect, rows, record, t, v)
return scanRowsFromType(ctx, dialect, rows, record, t, v)
}
func scanRowsFromType(
ctx context.Context,
dialect Dialect,
rows Rows,
record interface{},
@ -964,7 +1007,7 @@ func scanRowsFromType(
// This version is positional meaning that it expect the arguments
// to follow an specific order. It's ok because we don't allow the
// user to type the "SELECT" part of the query for nested structs.
scanArgs, err = getScanArgsForNestedStructs(dialect, rows, t, v, info)
scanArgs, err = getScanArgsForNestedStructs(ctx, dialect, rows, t, v, info)
if err != nil {
return err
}
@ -975,7 +1018,7 @@ func scanRowsFromType(
}
// Since this version uses the names of the columns it works
// with any order of attributes/columns.
scanArgs = getScanArgsFromNames(dialect, names, v, info)
scanArgs = getScanArgsFromNames(ctx, dialect, names, v, info)
}
err = rows.Scan(scanArgs...)
@ -985,7 +1028,14 @@ func scanRowsFromType(
return nil
}
func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) {
func getScanArgsForNestedStructs(
ctx context.Context,
dialect Dialect,
rows Rows,
t reflect.Type,
v reflect.Value,
info structs.StructInfo,
) ([]interface{}, error) {
scanArgs := []interface{}{}
for i := 0; i < v.NumField(); i++ {
if !info.ByIndex(i).Valid {
@ -1008,10 +1058,18 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r
valueScanner := nopScannerValue
if fieldInfo.Valid {
valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface()
if fieldInfo.SerializeAsJSON {
valueScanner = &jsonSerializable{
DriverName: dialect.DriverName(),
Attr: valueScanner,
if fieldInfo.Modifier.Scan != nil {
valueScanner = &modifiers.AttrScanWrapper{
Ctx: ctx,
AttrPtr: valueScanner,
ScanFn: fieldInfo.Modifier.Scan,
OpInfo: ksqlmodifiers.OpInfo{
DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks
// if we did this could lead users to make very strange modifiers
Method: "Query",
},
}
}
}
@ -1023,7 +1081,7 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r
return scanArgs, nil
}
func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} {
func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} {
scanArgs := []interface{}{}
for _, name := range names {
fieldInfo := info.ByName(name)
@ -1031,10 +1089,17 @@ func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info
valueScanner := nopScannerValue
if fieldInfo.Valid {
valueScanner = v.Field(fieldInfo.Index).Addr().Interface()
if fieldInfo.SerializeAsJSON {
valueScanner = &jsonSerializable{
DriverName: dialect.DriverName(),
Attr: valueScanner,
if fieldInfo.Modifier.Scan != nil {
valueScanner = &modifiers.AttrScanWrapper{
Ctx: ctx,
AttrPtr: valueScanner,
ScanFn: fieldInfo.Modifier.Scan,
OpInfo: ksqlmodifiers.OpInfo{
DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks
// if we did this could lead users to make very strange modifiers
Method: "Query",
},
}
}
}

View File

@ -0,0 +1,34 @@
package ksqlmodifiers
import "context"
// AttrModifier informs KSQL how to use this modifier
type AttrModifier struct {
// The following attributes will tell KSQL to
// leave this attribute out of insertions, updates,
// and queries respectively.
SkipOnInsert bool
SkipOnUpdate bool
// Implement these functions if you want to override the default Scan/Value behavior
// for the target attribute.
Scan AttrScanner
Value AttrValuer
}
// AttrScanner describes the operation of deserializing an object received from the database.
type AttrScanner func(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error
// AttrValuer describes the operation of serializing an object when saving it to the database.
type AttrValuer func(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error)
// OpInfo contains information that might be used by a modifier to determine how it should behave.
type OpInfo struct {
// A string version of the name of one of
// the methods of the `ksql.Provider` interface, e.g. `Insert` or `Query`
Method string
// The string representing the current underlying database, e.g.:
// "postgres", "sqlite3", "mysql" or "sqlserver".
DriverName string
}

6
ksqlmodifiers/godoc.go Normal file
View File

@ -0,0 +1,6 @@
// This package exposes only the public types and functions for creating new
// modifiers for KSQL.
//
// For understanding internal details of the code
// please read the `internal/modifiers` package.
package ksqlmodifiers

View File

@ -0,0 +1,9 @@
package ksqlmodifiers
// RegisterAttrModifier allow users to add custom modifiers on startup
// it is recommended to do this inside an init() function.
var RegisterAttrModifier func(key string, modifier AttrModifier)
// This method is set at startup by the `internal/modifiers` package.
// It was done that way in order to keep most of the implementation private
// while also avoiding cyclic dependencies.

View File

@ -7,8 +7,11 @@ import (
"fmt"
"io"
"testing"
"time"
"github.com/vingarcia/ksql/internal/modifiers"
tt "github.com/vingarcia/ksql/internal/testtools"
"github.com/vingarcia/ksql/ksqlmodifiers"
"github.com/vingarcia/ksql/nullable"
)
@ -70,6 +73,7 @@ func RunTestsForAdapter(
PatchTest(t, driver, connStr, newDBAdapter)
QueryChunksTest(t, driver, connStr, newDBAdapter)
TransactionTest(t, driver, connStr, newDBAdapter)
ModifiersTest(t, driver, connStr, newDBAdapter)
ScanRowsTest(t, driver, connStr, newDBAdapter)
})
}
@ -893,6 +897,33 @@ func InsertTest(
tt.AssertNoErr(t, err)
tt.AssertEqual(t, inserted.Age, 5455)
})
t.Run("should work and retrieve the ID for structs with no attributes", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
ctx := context.Background()
c := newTestDB(db, driver)
type tsUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name,skipInserts"`
}
u := tsUser{
Name: "Letícia",
}
err := c.Insert(ctx, usersTable, &u)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, 0)
var untaggedUser struct {
ID uint `ksql:"id"`
Name *string `ksql:"name"`
}
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedUser.Name, (*string)(nil))
})
})
t.Run("composite key tables", func(t *testing.T) {
@ -958,6 +989,55 @@ func InsertTest(
tt.AssertEqual(t, userPerms[0].PermID, 42)
}
})
t.Run("when inserting a struct with no values but composite keys should still retrieve the IDs", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
ctx := context.Background()
c := newTestDB(db, driver)
// Table defined with 3 values, but we'll provide only 2,
// the third will be generated for the purposes of this test:
table := NewTable("user_permissions", "id", "user_id", "perm_id")
type taggedPerm struct {
ID uint `ksql:"id"`
UserID int `ksql:"user_id"`
PermID int `ksql:"perm_id"`
Type string `ksql:"type,skipInserts"`
}
permission := taggedPerm{
UserID: 3,
PermID: 43,
}
err := c.Insert(ctx, table, &permission)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, permission.ID, 0)
fmt.Println("permID:", permission.ID)
var untaggedPerm struct {
ID uint `ksql:"id"`
UserID int `ksql:"user_id"`
PermID int `ksql:"perm_id"`
Type *string `ksql:"type"`
}
err = c.QueryOne(ctx, &untaggedPerm, `FROM user_permissions WHERE user_id = 3 AND perm_id = 43`)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedPerm.Type, (*string)(nil))
// Should retrieve the generated ID from the database,
// only if the database supports returning multiple values:
switch c.dialect.InsertMethod() {
case insertWithNoIDRetrieval, insertWithLastInsertID:
tt.AssertEqual(t, permission.ID, uint(0))
tt.AssertEqual(t, untaggedPerm.UserID, 3)
tt.AssertEqual(t, untaggedPerm.PermID, 43)
case insertWithReturning, insertWithOutput:
tt.AssertEqual(t, untaggedPerm.ID, permission.ID)
tt.AssertEqual(t, untaggedPerm.UserID, 3)
tt.AssertEqual(t, untaggedPerm.PermID, 43)
}
})
})
})
@ -1711,6 +1791,24 @@ func PatchTest(
tt.AssertNotEqual(t, err, nil)
})
t.Run("should report error if the struct has no fields to update", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
ctx := context.Background()
c := newTestDB(db, driver)
err = c.Update(ctx, usersTable, struct {
ID uint `ksql:"id"` // ID fields are not updated
Name string `ksql:"name,skipUpdates"` // the skipUpdate modifier should rule this one out
Age *int `ksql:"age"` // Age is a nil pointer so it would not be updated
}{
ID: 1,
Name: "some name",
})
tt.AssertEqual(t, err, ErrNoValuesToUpdate)
})
t.Run("should report error if the id is missing", func(t *testing.T) {
t.Run("with a single primary key", func(t *testing.T) {
db, closer := newDBAdapter(t)
@ -2515,6 +2613,432 @@ func TransactionTest(
})
}
func ModifiersTest(
t *testing.T,
driver string,
connStr string,
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
) {
ctx := context.Background()
t.Run("Modifiers", func(t *testing.T) {
err := createTables(driver, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("timeNowUTC modifier", func(t *testing.T) {
t.Run("should be set to time.Now().UTC() on insertion", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type tsUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
UpdatedAt time.Time `ksql:"updated_at,timeNowUTC"`
}
u := tsUser{
Name: "Letícia",
}
err := c.Insert(ctx, usersTable, &u)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, 0)
var untaggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
UpdatedAt time.Time `ksql:"updated_at"`
}
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
now := time.Now()
tt.AssertApproxTime(t,
2*time.Second, untaggedUser.UpdatedAt, now,
"updatedAt should be set to %v, but got: %v", now, untaggedUser.UpdatedAt,
)
})
t.Run("should be set to time.Now().UTC() on updates", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
UpdatedAt time.Time `ksql:"updated_at"`
}
untaggedUser := userWithNoTags{
Name: "Laura Ribeiro",
// Any time different from now:
UpdatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
type taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
UpdatedAt time.Time `ksql:"updated_at,timeNowUTC"`
}
u := taggedUser{
ID: untaggedUser.ID,
Name: "Laurinha Ribeiro",
}
err = c.Patch(ctx, usersTable, u)
tt.AssertNoErr(t, err)
var untaggedUser2 userWithNoTags
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser2.ID, 0)
now := time.Now()
tt.AssertApproxTime(t,
2*time.Second, untaggedUser2.UpdatedAt, now,
"updatedAt should be set to %v, but got: %v", now, untaggedUser2.UpdatedAt,
)
})
t.Run("should not alter the value on queries", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
UpdatedAt time.Time `ksql:"updated_at"`
}
untaggedUser := userWithNoTags{
Name: "Marta Ribeiro",
// Any time different from now:
UpdatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
var taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
UpdatedAt time.Time `ksql:"updated_at,timeNowUTC"`
}
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
tt.AssertEqual(t, taggedUser.UpdatedAt, tt.ParseTime(t, "2000-08-05T14:00:00Z"))
})
})
t.Run("timeNowUTC/skipUpdates modifier", func(t *testing.T) {
t.Run("should be set to time.Now().UTC() on insertion", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type tsUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
CreatedAt time.Time `ksql:"created_at,timeNowUTC/skipUpdates"`
}
u := tsUser{
Name: "Letícia",
}
err := c.Insert(ctx, usersTable, &u)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, 0)
var untaggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
CreatedAt time.Time `ksql:"created_at"`
}
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
now := time.Now()
tt.AssertApproxTime(t,
2*time.Second, untaggedUser.CreatedAt, now,
"updatedAt should be set to %v, but got: %v", now, untaggedUser.CreatedAt,
)
})
t.Run("should be ignored on updates", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
CreatedAt time.Time `ksql:"created_at"`
}
untaggedUser := userWithNoTags{
Name: "Laura Ribeiro",
// Any time different from now:
CreatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
type taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
CreatedAt time.Time `ksql:"created_at,timeNowUTC/skipUpdates"`
}
u := taggedUser{
ID: untaggedUser.ID,
Name: "Laurinha Ribeiro",
// Some random time that should be ignored:
CreatedAt: tt.ParseTime(t, "1999-08-05T14:00:00Z"),
}
err = c.Patch(ctx, usersTable, u)
tt.AssertNoErr(t, err)
var untaggedUser2 userWithNoTags
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedUser2.CreatedAt, tt.ParseTime(t, "2000-08-05T14:00:00Z"))
})
t.Run("should not alter the value on queries", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
CreatedAt time.Time `ksql:"created_at"`
}
untaggedUser := userWithNoTags{
Name: "Marta Ribeiro",
// Any time different from now:
CreatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
var taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
CreatedAt time.Time `ksql:"created_at,timeNowUTC/skipUpdates"`
}
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
tt.AssertEqual(t, taggedUser.CreatedAt, tt.ParseTime(t, "2000-08-05T14:00:00Z"))
})
})
t.Run("skipInserts modifier", func(t *testing.T) {
t.Run("should ignore the field during insertions", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type tsUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name,skipInserts"`
Age int `ksql:"age"`
}
u := tsUser{
Name: "Letícia",
Age: 22,
}
err := c.Insert(ctx, usersTable, &u)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, 0)
var untaggedUser struct {
ID uint `ksql:"id"`
Name *string `ksql:"name"`
Age int `ksql:"age"`
}
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedUser.Name, (*string)(nil))
tt.AssertEqual(t, untaggedUser.Age, 22)
})
t.Run("should have no effect on updates", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
Age int `ksql:"age"`
}
untaggedUser := userWithNoTags{
Name: "Laurinha Ribeiro",
Age: 11,
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
type taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name,skipInserts"`
Age int `ksql:"age"`
}
u := taggedUser{
ID: untaggedUser.ID,
Name: "Laura Ribeiro",
Age: 12,
}
err = c.Patch(ctx, usersTable, u)
tt.AssertNoErr(t, err)
var untaggedUser2 userWithNoTags
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedUser2.Name, "Laura Ribeiro")
tt.AssertEqual(t, untaggedUser2.Age, 12)
})
t.Run("should not alter the value on queries", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
}
untaggedUser := userWithNoTags{
Name: "Marta Ribeiro",
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
var taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name,skipInserts"`
}
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
})
})
t.Run("skipUpdates modifier", func(t *testing.T) {
t.Run("should set the field on insertion", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type tsUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name,skipUpdates"`
}
u := tsUser{
Name: "Letícia",
}
err := c.Insert(ctx, usersTable, &u)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, 0)
var untaggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
}
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedUser.Name, "Letícia")
})
t.Run("should be ignored on updates", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
Age int `ksql:"age"`
}
untaggedUser := userWithNoTags{
Name: "Laurinha Ribeiro",
Age: 11,
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
type taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name,skipUpdates"`
Age int `ksql:"age"`
}
u := taggedUser{
ID: untaggedUser.ID,
Name: "Laura Ribeiro",
Age: 12,
}
err = c.Patch(ctx, usersTable, u)
tt.AssertNoErr(t, err)
var untaggedUser2 userWithNoTags
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedUser2.Name, "Laurinha Ribeiro")
tt.AssertEqual(t, untaggedUser2.Age, 12)
})
t.Run("should not alter the value on queries", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
}
untaggedUser := userWithNoTags{
Name: "Marta Ribeiro",
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
var taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name,skipUpdates"`
}
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
})
})
})
}
// ScanRowsTest runs all tests for making sure the ScanRows feature is
// working for a given adapter and driver.
func ScanRowsTest(
@ -2546,7 +3070,7 @@ func ScanRowsTest(
tt.AssertEqual(t, rows.Next(), true)
var u user
err = scanRows(dialect, rows, &u)
err = scanRows(ctx, dialect, rows, &u)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, u.Name, "User2")
@ -2579,7 +3103,7 @@ func ScanRowsTest(
// Omitted for testing purposes:
// Name string `ksql:"name"`
}
err = scanRows(dialect, rows, &u)
err = scanRows(ctx, dialect, rows, &u)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, u.Age, 22)
@ -2602,7 +3126,7 @@ func ScanRowsTest(
var u user
err = rows.Close()
tt.AssertNoErr(t, err)
err = scanRows(dialect, rows, &u)
err = scanRows(ctx, dialect, rows, &u)
tt.AssertNotEqual(t, err, nil)
})
@ -2622,7 +3146,7 @@ func ScanRowsTest(
defer rows.Close()
var u user
err = scanRows(dialect, rows, u)
err = scanRows(ctx, dialect, rows, u)
tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "user")
})
@ -2642,7 +3166,7 @@ func ScanRowsTest(
defer rows.Close()
var u map[string]interface{}
err = scanRows(dialect, rows, &u)
err = scanRows(ctx, dialect, rows, &u)
tt.AssertErrContains(t, err, "KSQL", "expected", "pointer to struct", "map[string]interface")
})
})
@ -2667,28 +3191,36 @@ func createTables(driver string, connStr string) error {
id INTEGER PRIMARY KEY,
age INTEGER,
name TEXT,
address BLOB
address BLOB,
created_at DATETIME,
updated_at DATETIME
)`)
case "postgres":
_, err = db.Exec(`CREATE TABLE users (
id serial PRIMARY KEY,
age INT,
name VARCHAR(50),
address jsonb
address jsonb,
created_at TIMESTAMP,
updated_at TIMESTAMP
)`)
case "mysql":
_, err = db.Exec(`CREATE TABLE users (
id INT AUTO_INCREMENT PRIMARY KEY,
age INT,
name VARCHAR(50),
address JSON
address JSON,
created_at DATETIME,
updated_at DATETIME
)`)
case "sqlserver":
_, err = db.Exec(`CREATE TABLE users (
id INT IDENTITY(1,1) PRIMARY KEY,
age INT,
name VARCHAR(50),
address NVARCHAR(4000)
address NVARCHAR(4000),
created_at DATETIME,
updated_at DATETIME
)`)
}
if err != nil {
@ -2798,9 +3330,18 @@ func getUserByID(db DBAdapter, dialect Dialect, result *user, id uint) error {
return sql.ErrNoRows
}
value := jsonSerializable{
DriverName: dialect.DriverName(),
Attr: &result.Address,
modifier, _ := modifiers.LoadGlobalModifier("json")
value := modifiers.AttrScanWrapper{
Ctx: context.TODO(),
AttrPtr: &result.Address,
ScanFn: modifier.Scan,
OpInfo: ksqlmodifiers.OpInfo{
DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks
// if we did this could lead users to make very strange modifiers
Method: "Query",
},
}
err = rows.Scan(&result.ID, &result.Name, &result.Age, &value)