diff --git a/attr_serializers.go b/attr_serializers.go new file mode 100644 index 0000000..613fa0e --- /dev/null +++ b/attr_serializers.go @@ -0,0 +1,66 @@ +package ksql + +import ( + "context" + "database/sql/driver" + "fmt" +) + +// Here we keep all the registered serializers +var serializers = map[string]AttrSerializer{ + "json": jsonSerializer{}, +} + +// RegisterAttrSerializer allow users to add custom serializers on startup +// it is recommended to do this inside an init() function. +func RegisterAttrSerializer(key string, serializer AttrSerializer) { + _, found := serializers[key] + if found { + panic(fmt.Errorf("KSQL: cannot register serializer '%s' name is already in use", key)) + } + + serializers[key] = serializer +} + +// AttrSerializer describes the two operations required to serialize and deserialize an object from the database. +type AttrSerializer interface { + AttrScan(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error + AttrValue(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) +} + +// OpInfo contains information that might be used by a serializer to determine how it should behave. +type OpInfo struct { + // A string version of the name of one of + // the methods of the `ksql.Provider` interface, e.g. `Insert` or `Query` + Method string + + // The string representing the current underlying database, e.g.: + // "postgres", "sqlite3", "mysql" or "sqlserver". + DriverName string +} + +// attrSerializer is the wrapper that allow us to intercept the Scan and Value processes +// so we can run the serializers instead of allowing the database driver to use +// its default behavior. +// +// For that this struct implements both the `sql.Scanner` and `sql.Valuer` interfaces. +type attrSerializer struct { + ctx context.Context + + // When Scanning this value should be a pointer to the attribute + // and when "Valuing" it should just be the actual value + attr interface{} + + serializerName string + opInfo OpInfo +} + +// Scan implements the sql.Scanner interface +func (a attrSerializer) Scan(dbValue interface{}) error { + return serializers[a.serializerName].AttrScan(a.ctx, a.opInfo, a.attr, dbValue) +} + +// Value implements the sql.Valuer interface +func (a attrSerializer) Value() (driver.Value, error) { + return serializers[a.serializerName].AttrValue(a.ctx, a.opInfo, a.attr) +} diff --git a/internal/structs/structs.go b/internal/structs/structs.go index 9419ea1..9c70d17 100644 --- a/internal/structs/structs.go +++ b/internal/structs/structs.go @@ -20,10 +20,10 @@ type StructInfo struct { // information regarding a specific field // of a struct. type FieldInfo struct { - Name string - Index int - Valid bool - SerializeAsJSON bool + Name string + Index int + Valid bool + SerializerName string } // ByIndex returns either the *FieldInfo of a valid @@ -249,10 +249,10 @@ func getTagNames(t reflect.Type) (StructInfo, error) { } tags := strings.Split(name, ",") - serializeAsJSON := false + var serializerName string if len(tags) > 1 { name = tags[0] - serializeAsJSON = tags[1] == "json" + serializerName = tags[1] } if _, found := info.byName[name]; found { @@ -263,9 +263,9 @@ func getTagNames(t reflect.Type) (StructInfo, error) { } info.add(FieldInfo{ - Name: name, - Index: i, - SerializeAsJSON: serializeAsJSON, + Name: name, + Index: i, + SerializerName: serializerName, }) } diff --git a/json.go b/json.go index 03fdda1..d418db4 100644 --- a/json.go +++ b/json.go @@ -1,7 +1,7 @@ package ksql import ( - "database/sql/driver" + "context" "encoding/json" "fmt" "reflect" @@ -10,39 +10,36 @@ import ( // This type was created to make it easier to adapt // input attributes to be convertible to and from JSON // before sending or receiving it from the database. -type jsonSerializable struct { - DriverName string - Attr interface{} -} +type jsonSerializer struct{} // Scan Implements the Scanner interface in order to load // this field from the JSON stored in the database -func (j *jsonSerializable) Scan(value interface{}) error { - if value == nil { - v := reflect.ValueOf(j.Attr) +func (j jsonSerializer) AttrScan(ctx context.Context, opInfo OpInfo, attrPtr interface{}, dbValue interface{}) error { + if dbValue == nil { + v := reflect.ValueOf(attrPtr) // Set the struct to its 0 value just like json.Unmarshal // does for nil attributes: - v.Elem().Set(reflect.Zero(reflect.TypeOf(j.Attr).Elem())) + v.Elem().Set(reflect.Zero(reflect.TypeOf(attrPtr).Elem())) return nil } // Required since sqlite3 returns strings not bytes - if v, ok := value.(string); ok { - value = []byte(v) + if v, ok := dbValue.(string); ok { + dbValue = []byte(v) } - rawJSON, ok := value.([]byte) + rawJSON, ok := dbValue.([]byte) if !ok { - return fmt.Errorf("unexpected type received to Scan: %T", value) + return fmt.Errorf("unexpected type received to Scan: %T", dbValue) } - return json.Unmarshal(rawJSON, j.Attr) + return json.Unmarshal(rawJSON, attrPtr) } // Value Implements the Valuer interface in order to save // this field as JSON on the database. -func (j jsonSerializable) Value() (driver.Value, error) { - b, err := json.Marshal(j.Attr) - if j.DriverName == "sqlserver" { +func (j jsonSerializer) AttrValue(ctx context.Context, opInfo OpInfo, inputValue interface{}) (outputValue interface{}, _ error) { + b, err := json.Marshal(inputValue) + if opInfo.DriverName == "sqlserver" { return string(b), err } return b, err diff --git a/ksql.go b/ksql.go index 41d4869..dbb6903 100644 --- a/ksql.go +++ b/ksql.go @@ -185,7 +185,7 @@ func (c DB) Query( elemPtr = elemPtr.Elem() } - err = scanRows(c.dialect, rows, elemPtr.Interface()) + err = scanRows(ctx, c.dialect, rows, elemPtr.Interface()) if err != nil { return err } @@ -264,7 +264,7 @@ func (c DB) QueryOne( return ErrRecordNotFound } - err = scanRowsFromType(c.dialect, rows, record, t, v) + err = scanRowsFromType(ctx, c.dialect, rows, record, t, v) if err != nil { return err } @@ -343,7 +343,7 @@ func (c DB) QueryChunks( chunk = reflect.Append(chunk, elemValue) } - err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface()) + err = scanRows(ctx, c.dialect, rows, chunk.Index(idx).Addr().Interface()) if err != nil { return err } @@ -420,7 +420,7 @@ func (c DB) Insert( return err } - query, params, scanValues, err := buildInsertQuery(c.dialect, table, t, v, info, record) + query, params, scanValues, err := buildInsertQuery(ctx, c.dialect, table, t, v, info, record) if err != nil { return err } @@ -657,7 +657,7 @@ func (c DB) Patch( return err } - query, params, err := buildUpdateQuery(c.dialect, table.name, info, record, table.idColumns...) + query, params, err := buildUpdateQuery(ctx, c.dialect, table.name, info, record, table.idColumns...) if err != nil { return err } @@ -682,6 +682,7 @@ func (c DB) Patch( } func buildInsertQuery( + ctx context.Context, dialect Dialect, table Table, t reflect.Type, @@ -716,10 +717,17 @@ func buildInsertQuery( for i, col := range columnNames { recordValue := recordMap[col] params[i] = recordValue - if info.ByName(col).SerializeAsJSON { - params[i] = jsonSerializable{ - DriverName: dialect.DriverName(), - Attr: recordValue, + + serializerName := info.ByName(col).SerializerName + if serializerName != "" { + params[i] = attrSerializer{ + ctx: ctx, + attr: recordValue, + serializerName: serializerName, + opInfo: OpInfo{ + DriverName: dialect.DriverName(), + Method: "Insert", + }, } } @@ -777,13 +785,14 @@ func buildInsertQuery( } func buildUpdateQuery( + ctx context.Context, dialect Dialect, tableName string, info structs.StructInfo, record interface{}, idFieldNames ...string, ) (query string, args []interface{}, err error) { - recordMap, err := ksqltest.StructToMap(record) + recordMap, err := structs.StructToMap(record) if err != nil { return "", nil, err } @@ -817,10 +826,17 @@ func buildUpdateQuery( var setQuery []string for i, k := range keys { recordValue := recordMap[k] - if info.ByName(k).SerializeAsJSON { - recordValue = jsonSerializable{ - DriverName: dialect.DriverName(), - Attr: recordValue, + + serializerName := info.ByName(k).SerializerName + if serializerName != "" { + recordValue = attrSerializer{ + ctx: ctx, + attr: recordValue, + serializerName: serializerName, + opInfo: OpInfo{ + DriverName: dialect.DriverName(), + Method: "Update", + }, } } args[i] = recordValue @@ -929,13 +945,14 @@ func (nopScanner) Scan(value interface{}) error { return nil } -func scanRows(dialect Dialect, rows Rows, record interface{}) error { +func scanRows(ctx context.Context, dialect Dialect, rows Rows, record interface{}) error { v := reflect.ValueOf(record) t := v.Type() - return scanRowsFromType(dialect, rows, record, t, v) + return scanRowsFromType(ctx, dialect, rows, record, t, v) } func scanRowsFromType( + ctx context.Context, dialect Dialect, rows Rows, record interface{}, @@ -963,7 +980,7 @@ func scanRowsFromType( // This version is positional meaning that it expect the arguments // to follow an specific order. It's ok because we don't allow the // user to type the "SELECT" part of the query for nested structs. - scanArgs, err = getScanArgsForNestedStructs(dialect, rows, t, v, info) + scanArgs, err = getScanArgsForNestedStructs(ctx, dialect, rows, t, v, info) if err != nil { return err } @@ -974,7 +991,7 @@ func scanRowsFromType( } // Since this version uses the names of the columns it works // with any order of attributes/columns. - scanArgs = getScanArgsFromNames(dialect, names, v, info) + scanArgs = getScanArgsFromNames(ctx, dialect, names, v, info) } err = rows.Scan(scanArgs...) @@ -984,7 +1001,14 @@ func scanRowsFromType( return nil } -func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) { +func getScanArgsForNestedStructs( + ctx context.Context, + dialect Dialect, + rows Rows, + t reflect.Type, + v reflect.Value, + info structs.StructInfo, +) ([]interface{}, error) { scanArgs := []interface{}{} for i := 0; i < v.NumField(); i++ { if !info.ByIndex(i).Valid { @@ -1007,10 +1031,18 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r valueScanner := nopScannerValue if fieldInfo.Valid { valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface() - if fieldInfo.SerializeAsJSON { - valueScanner = &jsonSerializable{ - DriverName: dialect.DriverName(), - Attr: valueScanner, + + if fieldInfo.SerializerName != "" { + valueScanner = &attrSerializer{ + ctx: ctx, + attr: valueScanner, + serializerName: fieldInfo.SerializerName, + opInfo: OpInfo{ + DriverName: dialect.DriverName(), + // We will not differentiate between Query, QueryOne and QueryChunks + // if we did this could lead users to make very strange serializers + Method: "Query", + }, } } } @@ -1022,7 +1054,7 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r return scanArgs, nil } -func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} { +func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} { scanArgs := []interface{}{} for _, name := range names { fieldInfo := info.ByName(name) @@ -1030,10 +1062,17 @@ func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info valueScanner := nopScannerValue if fieldInfo.Valid { valueScanner = v.Field(fieldInfo.Index).Addr().Interface() - if fieldInfo.SerializeAsJSON { - valueScanner = &jsonSerializable{ - DriverName: dialect.DriverName(), - Attr: valueScanner, + if fieldInfo.SerializerName != "" { + valueScanner = &attrSerializer{ + ctx: ctx, + attr: valueScanner, + serializerName: fieldInfo.SerializerName, + opInfo: OpInfo{ + DriverName: dialect.DriverName(), + // We will not differentiate between Query, QueryOne and QueryChunks + // if we did this could lead users to make very strange serializers + Method: "Query", + }, } } } diff --git a/test_adapters.go b/test_adapters.go index 5fdfd5e..e641b61 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -2547,7 +2547,7 @@ func ScanRowsTest( tt.AssertEqual(t, rows.Next(), true) var u user - err = scanRows(dialect, rows, &u) + err = scanRows(ctx, dialect, rows, &u) tt.AssertNoErr(t, err) tt.AssertEqual(t, u.Name, "User2") @@ -2580,7 +2580,7 @@ func ScanRowsTest( // Omitted for testing purposes: // Name string `ksql:"name"` } - err = scanRows(dialect, rows, &u) + err = scanRows(ctx, dialect, rows, &u) tt.AssertNoErr(t, err) tt.AssertEqual(t, u.Age, 22) @@ -2603,7 +2603,7 @@ func ScanRowsTest( var u user err = rows.Close() tt.AssertNoErr(t, err) - err = scanRows(dialect, rows, &u) + err = scanRows(ctx, dialect, rows, &u) tt.AssertNotEqual(t, err, nil) }) @@ -2623,7 +2623,7 @@ func ScanRowsTest( defer rows.Close() var u user - err = scanRows(dialect, rows, u) + err = scanRows(ctx, dialect, rows, u) tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "user") }) @@ -2643,7 +2643,7 @@ func ScanRowsTest( defer rows.Close() var u map[string]interface{} - err = scanRows(dialect, rows, &u) + err = scanRows(ctx, dialect, rows, &u) tt.AssertErrContains(t, err, "KSQL", "expected", "pointer to struct", "map[string]interface") }) }) @@ -2799,9 +2799,16 @@ func getUserByID(db DBAdapter, dialect Dialect, result *user, id uint) error { return sql.ErrNoRows } - value := jsonSerializable{ - DriverName: dialect.DriverName(), - Attr: &result.Address, + value := attrSerializer{ + ctx: context.TODO(), + attr: &result.Address, + serializerName: "json", + opInfo: OpInfo{ + DriverName: dialect.DriverName(), + // We will not differentiate between Query, QueryOne and QueryChunks + // if we did this could lead users to make very strange serializers + Method: "Query", + }, } err = rows.Scan(&result.ID, &result.Name, &result.Age, &value)