Refactor code so that users can add their own serializers

pull/29/head
Vinícius Garcia 2022-09-14 23:03:33 -03:00
parent 67ad75242a
commit f95cd2b7b2
5 changed files with 171 additions and 62 deletions

66
attr_serializers.go Normal file
View File

@ -0,0 +1,66 @@
package ksql
import (
"context"
"database/sql/driver"
"fmt"
)
// Here we keep all the registered serializers
var serializers = map[string]AttrSerializer{
"json": jsonSerializer{},
}
// RegisterAttrSerializer allow users to add custom serializers on startup
// it is recommended to do this inside an init() function.
func RegisterAttrSerializer(key string, serializer AttrSerializer) {
_, found := serializers[key]
if found {
panic(fmt.Errorf("KSQL: cannot register serializer '%s' name is already in use", key))
}
serializers[key] = serializer
}
// AttrSerializer describes the two operations required to serialize and deserialize an object from the database.
type AttrSerializer interface {
AttrScan(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error
AttrValue(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error)
}
// OpInfo contains information that might be used by a serializer to determine how it should behave.
type OpInfo struct {
// A string version of the name of one of
// the methods of the `ksql.Provider` interface, e.g. `Insert` or `Query`
Method string
// The string representing the current underlying database, e.g.:
// "postgres", "sqlite3", "mysql" or "sqlserver".
DriverName string
}
// attrSerializer is the wrapper that allow us to intercept the Scan and Value processes
// so we can run the serializers instead of allowing the database driver to use
// its default behavior.
//
// For that this struct implements both the `sql.Scanner` and `sql.Valuer` interfaces.
type attrSerializer struct {
ctx context.Context
// When Scanning this value should be a pointer to the attribute
// and when "Valuing" it should just be the actual value
attr interface{}
serializerName string
opInfo OpInfo
}
// Scan implements the sql.Scanner interface
func (a attrSerializer) Scan(dbValue interface{}) error {
return serializers[a.serializerName].AttrScan(a.ctx, a.opInfo, a.attr, dbValue)
}
// Value implements the sql.Valuer interface
func (a attrSerializer) Value() (driver.Value, error) {
return serializers[a.serializerName].AttrValue(a.ctx, a.opInfo, a.attr)
}

View File

@ -20,10 +20,10 @@ type StructInfo struct {
// information regarding a specific field // information regarding a specific field
// of a struct. // of a struct.
type FieldInfo struct { type FieldInfo struct {
Name string Name string
Index int Index int
Valid bool Valid bool
SerializeAsJSON bool SerializerName string
} }
// ByIndex returns either the *FieldInfo of a valid // ByIndex returns either the *FieldInfo of a valid
@ -249,10 +249,10 @@ func getTagNames(t reflect.Type) (StructInfo, error) {
} }
tags := strings.Split(name, ",") tags := strings.Split(name, ",")
serializeAsJSON := false var serializerName string
if len(tags) > 1 { if len(tags) > 1 {
name = tags[0] name = tags[0]
serializeAsJSON = tags[1] == "json" serializerName = tags[1]
} }
if _, found := info.byName[name]; found { if _, found := info.byName[name]; found {
@ -263,9 +263,9 @@ func getTagNames(t reflect.Type) (StructInfo, error) {
} }
info.add(FieldInfo{ info.add(FieldInfo{
Name: name, Name: name,
Index: i, Index: i,
SerializeAsJSON: serializeAsJSON, SerializerName: serializerName,
}) })
} }

31
json.go
View File

@ -1,7 +1,7 @@
package ksql package ksql
import ( import (
"database/sql/driver" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect" "reflect"
@ -10,39 +10,36 @@ import (
// This type was created to make it easier to adapt // This type was created to make it easier to adapt
// input attributes to be convertible to and from JSON // input attributes to be convertible to and from JSON
// before sending or receiving it from the database. // before sending or receiving it from the database.
type jsonSerializable struct { type jsonSerializer struct{}
DriverName string
Attr interface{}
}
// Scan Implements the Scanner interface in order to load // Scan Implements the Scanner interface in order to load
// this field from the JSON stored in the database // this field from the JSON stored in the database
func (j *jsonSerializable) Scan(value interface{}) error { func (j jsonSerializer) AttrScan(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error {
if value == nil { if dbValue == nil {
v := reflect.ValueOf(j.Attr) v := reflect.ValueOf(attrPtr)
// Set the struct to its 0 value just like json.Unmarshal // Set the struct to its 0 value just like json.Unmarshal
// does for nil attributes: // does for nil attributes:
v.Elem().Set(reflect.Zero(reflect.TypeOf(j.Attr).Elem())) v.Elem().Set(reflect.Zero(reflect.TypeOf(attrPtr).Elem()))
return nil return nil
} }
// Required since sqlite3 returns strings not bytes // Required since sqlite3 returns strings not bytes
if v, ok := value.(string); ok { if v, ok := dbValue.(string); ok {
value = []byte(v) dbValue = []byte(v)
} }
rawJSON, ok := value.([]byte) rawJSON, ok := dbValue.([]byte)
if !ok { if !ok {
return fmt.Errorf("unexpected type received to Scan: %T", value) return fmt.Errorf("unexpected type received to Scan: %T", dbValue)
} }
return json.Unmarshal(rawJSON, j.Attr) return json.Unmarshal(rawJSON, attrPtr)
} }
// Value Implements the Valuer interface in order to save // Value Implements the Valuer interface in order to save
// this field as JSON on the database. // this field as JSON on the database.
func (j jsonSerializable) Value() (driver.Value, error) { func (j jsonSerializer) AttrValue(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) {
b, err := json.Marshal(j.Attr) b, err := json.Marshal(inputValue)
if j.DriverName == "sqlserver" { if opInfo.DriverName == "sqlserver" {
return string(b), err return string(b), err
} }
return b, err return b, err

95
ksql.go
View File

@ -185,7 +185,7 @@ func (c DB) Query(
elemPtr = elemPtr.Elem() elemPtr = elemPtr.Elem()
} }
err = scanRows(c.dialect, rows, elemPtr.Interface()) err = scanRows(ctx, c.dialect, rows, elemPtr.Interface())
if err != nil { if err != nil {
return err return err
} }
@ -264,7 +264,7 @@ func (c DB) QueryOne(
return ErrRecordNotFound return ErrRecordNotFound
} }
err = scanRowsFromType(c.dialect, rows, record, t, v) err = scanRowsFromType(ctx, c.dialect, rows, record, t, v)
if err != nil { if err != nil {
return err return err
} }
@ -343,7 +343,7 @@ func (c DB) QueryChunks(
chunk = reflect.Append(chunk, elemValue) chunk = reflect.Append(chunk, elemValue)
} }
err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface()) err = scanRows(ctx, c.dialect, rows, chunk.Index(idx).Addr().Interface())
if err != nil { if err != nil {
return err return err
} }
@ -420,7 +420,7 @@ func (c DB) Insert(
return err return err
} }
query, params, scanValues, err := buildInsertQuery(c.dialect, table, t, v, info, record) query, params, scanValues, err := buildInsertQuery(ctx, c.dialect, table, t, v, info, record)
if err != nil { if err != nil {
return err return err
} }
@ -657,7 +657,7 @@ func (c DB) Patch(
return err return err
} }
query, params, err := buildUpdateQuery(c.dialect, table.name, info, record, table.idColumns...) query, params, err := buildUpdateQuery(ctx, c.dialect, table.name, info, record, table.idColumns...)
if err != nil { if err != nil {
return err return err
} }
@ -682,6 +682,7 @@ func (c DB) Patch(
} }
func buildInsertQuery( func buildInsertQuery(
ctx context.Context,
dialect Dialect, dialect Dialect,
table Table, table Table,
t reflect.Type, t reflect.Type,
@ -716,10 +717,17 @@ func buildInsertQuery(
for i, col := range columnNames { for i, col := range columnNames {
recordValue := recordMap[col] recordValue := recordMap[col]
params[i] = recordValue params[i] = recordValue
if info.ByName(col).SerializeAsJSON {
params[i] = jsonSerializable{ serializerName := info.ByName(col).SerializerName
DriverName: dialect.DriverName(), if serializerName != "" {
Attr: recordValue, params[i] = attrSerializer{
ctx: ctx,
attr: recordValue,
serializerName: serializerName,
opInfo: OpInfo{
DriverName: dialect.DriverName(),
Method: "Insert",
},
} }
} }
@ -777,13 +785,14 @@ func buildInsertQuery(
} }
func buildUpdateQuery( func buildUpdateQuery(
ctx context.Context,
dialect Dialect, dialect Dialect,
tableName string, tableName string,
info structs.StructInfo, info structs.StructInfo,
record interface{}, record interface{},
idFieldNames ...string, idFieldNames ...string,
) (query string, args []interface{}, err error) { ) (query string, args []interface{}, err error) {
recordMap, err := ksqltest.StructToMap(record) recordMap, err := structs.StructToMap(record)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -817,10 +826,17 @@ func buildUpdateQuery(
var setQuery []string var setQuery []string
for i, k := range keys { for i, k := range keys {
recordValue := recordMap[k] recordValue := recordMap[k]
if info.ByName(k).SerializeAsJSON {
recordValue = jsonSerializable{ serializerName := info.ByName(k).SerializerName
DriverName: dialect.DriverName(), if serializerName != "" {
Attr: recordValue, recordValue = attrSerializer{
ctx: ctx,
attr: recordValue,
serializerName: serializerName,
opInfo: OpInfo{
DriverName: dialect.DriverName(),
Method: "Update",
},
} }
} }
args[i] = recordValue args[i] = recordValue
@ -929,13 +945,14 @@ func (nopScanner) Scan(value interface{}) error {
return nil return nil
} }
func scanRows(dialect Dialect, rows Rows, record interface{}) error { func scanRows(ctx context.Context, dialect Dialect, rows Rows, record interface{}) error {
v := reflect.ValueOf(record) v := reflect.ValueOf(record)
t := v.Type() t := v.Type()
return scanRowsFromType(dialect, rows, record, t, v) return scanRowsFromType(ctx, dialect, rows, record, t, v)
} }
func scanRowsFromType( func scanRowsFromType(
ctx context.Context,
dialect Dialect, dialect Dialect,
rows Rows, rows Rows,
record interface{}, record interface{},
@ -963,7 +980,7 @@ func scanRowsFromType(
// This version is positional meaning that it expect the arguments // This version is positional meaning that it expect the arguments
// to follow an specific order. It's ok because we don't allow the // to follow an specific order. It's ok because we don't allow the
// user to type the "SELECT" part of the query for nested structs. // user to type the "SELECT" part of the query for nested structs.
scanArgs, err = getScanArgsForNestedStructs(dialect, rows, t, v, info) scanArgs, err = getScanArgsForNestedStructs(ctx, dialect, rows, t, v, info)
if err != nil { if err != nil {
return err return err
} }
@ -974,7 +991,7 @@ func scanRowsFromType(
} }
// Since this version uses the names of the columns it works // Since this version uses the names of the columns it works
// with any order of attributes/columns. // with any order of attributes/columns.
scanArgs = getScanArgsFromNames(dialect, names, v, info) scanArgs = getScanArgsFromNames(ctx, dialect, names, v, info)
} }
err = rows.Scan(scanArgs...) err = rows.Scan(scanArgs...)
@ -984,7 +1001,14 @@ func scanRowsFromType(
return nil return nil
} }
func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) { func getScanArgsForNestedStructs(
ctx context.Context,
dialect Dialect,
rows Rows,
t reflect.Type,
v reflect.Value,
info structs.StructInfo,
) ([]interface{}, error) {
scanArgs := []interface{}{} scanArgs := []interface{}{}
for i := 0; i < v.NumField(); i++ { for i := 0; i < v.NumField(); i++ {
if !info.ByIndex(i).Valid { if !info.ByIndex(i).Valid {
@ -1007,10 +1031,18 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r
valueScanner := nopScannerValue valueScanner := nopScannerValue
if fieldInfo.Valid { if fieldInfo.Valid {
valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface() valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface()
if fieldInfo.SerializeAsJSON {
valueScanner = &jsonSerializable{ if fieldInfo.SerializerName != "" {
DriverName: dialect.DriverName(), valueScanner = &attrSerializer{
Attr: valueScanner, ctx: ctx,
attr: valueScanner,
serializerName: fieldInfo.SerializerName,
opInfo: OpInfo{
DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks
// if we did this could lead users to make very strange serializers
Method: "Query",
},
} }
} }
} }
@ -1022,7 +1054,7 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r
return scanArgs, nil return scanArgs, nil
} }
func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} { func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} {
scanArgs := []interface{}{} scanArgs := []interface{}{}
for _, name := range names { for _, name := range names {
fieldInfo := info.ByName(name) fieldInfo := info.ByName(name)
@ -1030,10 +1062,17 @@ func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info
valueScanner := nopScannerValue valueScanner := nopScannerValue
if fieldInfo.Valid { if fieldInfo.Valid {
valueScanner = v.Field(fieldInfo.Index).Addr().Interface() valueScanner = v.Field(fieldInfo.Index).Addr().Interface()
if fieldInfo.SerializeAsJSON { if fieldInfo.SerializerName != "" {
valueScanner = &jsonSerializable{ valueScanner = &attrSerializer{
DriverName: dialect.DriverName(), ctx: ctx,
Attr: valueScanner, attr: valueScanner,
serializerName: fieldInfo.SerializerName,
opInfo: OpInfo{
DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks
// if we did this could lead users to make very strange serializers
Method: "Query",
},
} }
} }
} }

View File

@ -2547,7 +2547,7 @@ func ScanRowsTest(
tt.AssertEqual(t, rows.Next(), true) tt.AssertEqual(t, rows.Next(), true)
var u user var u user
err = scanRows(dialect, rows, &u) err = scanRows(ctx, dialect, rows, &u)
tt.AssertNoErr(t, err) tt.AssertNoErr(t, err)
tt.AssertEqual(t, u.Name, "User2") tt.AssertEqual(t, u.Name, "User2")
@ -2580,7 +2580,7 @@ func ScanRowsTest(
// Omitted for testing purposes: // Omitted for testing purposes:
// Name string `ksql:"name"` // Name string `ksql:"name"`
} }
err = scanRows(dialect, rows, &u) err = scanRows(ctx, dialect, rows, &u)
tt.AssertNoErr(t, err) tt.AssertNoErr(t, err)
tt.AssertEqual(t, u.Age, 22) tt.AssertEqual(t, u.Age, 22)
@ -2603,7 +2603,7 @@ func ScanRowsTest(
var u user var u user
err = rows.Close() err = rows.Close()
tt.AssertNoErr(t, err) tt.AssertNoErr(t, err)
err = scanRows(dialect, rows, &u) err = scanRows(ctx, dialect, rows, &u)
tt.AssertNotEqual(t, err, nil) tt.AssertNotEqual(t, err, nil)
}) })
@ -2623,7 +2623,7 @@ func ScanRowsTest(
defer rows.Close() defer rows.Close()
var u user var u user
err = scanRows(dialect, rows, u) err = scanRows(ctx, dialect, rows, u)
tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "user") tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "user")
}) })
@ -2643,7 +2643,7 @@ func ScanRowsTest(
defer rows.Close() defer rows.Close()
var u map[string]interface{} var u map[string]interface{}
err = scanRows(dialect, rows, &u) err = scanRows(ctx, dialect, rows, &u)
tt.AssertErrContains(t, err, "KSQL", "expected", "pointer to struct", "map[string]interface") tt.AssertErrContains(t, err, "KSQL", "expected", "pointer to struct", "map[string]interface")
}) })
}) })
@ -2799,9 +2799,16 @@ func getUserByID(db DBAdapter, dialect Dialect, result *user, id uint) error {
return sql.ErrNoRows return sql.ErrNoRows
} }
value := jsonSerializable{ value := attrSerializer{
DriverName: dialect.DriverName(), ctx: context.TODO(),
Attr: &result.Address, attr: &result.Address,
serializerName: "json",
opInfo: OpInfo{
DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks
// if we did this could lead users to make very strange serializers
Method: "Query",
},
} }
err = rows.Scan(&result.ID, &result.Name, &result.Age, &value) err = rows.Scan(&result.ID, &result.Name, &result.Age, &value)