Add support for tables with composite keys

pull/2/head
Vinícius Garcia 2021-02-15 20:57:18 -03:00
parent e1e711dc91
commit 203b141aca
3 changed files with 180 additions and 25 deletions

View File

@ -224,7 +224,7 @@ ok github.com/vingarcia/kissorm 21.740s
### TODO List
- Allow the ID field to have a different name
- Implement support for JSON fields on the database (encoding/decoding them automatically into structs)
- Implement support for nested objects with prefixed table names
- Improve error messages
- Add tests for tables using composite keys

View File

@ -10,12 +10,20 @@ import (
"github.com/pkg/errors"
)
// DB ...
// DB represents the kissorm client responsible for
// interfacing with the "database/sql" package implementing
// the Kissorm interface `ORMProvider`.
type DB struct {
driver string
dialect dialect
tableName string
db sqlProvider
// Most dbs have a single primary key,
// But in future kissorm should work with compound keys as well
idCols []string
insertMethod insertMethod
}
type sqlProvider interface {
@ -23,12 +31,21 @@ type sqlProvider interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
// New instantiates a new client
type insertMethod int
const (
insertWithReturning insertMethod = iota
insertWithLastInsertID
insertWithNoIDRetrieval
)
// New instantiates a new Kissorm client
func New(
dbDriver string,
connectionString string,
maxOpenConns int,
tableName string,
idCols ...string,
) (DB, error) {
db, err := sql.Open(dbDriver, connectionString)
if err != nil {
@ -45,11 +62,31 @@ func New(
return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
}
if len(idCols) == 0 {
idCols = append(idCols, "id")
}
var insertMethod insertMethod
switch dbDriver {
case "sqlite3":
insertMethod = insertWithLastInsertID
if len(idCols) > 1 {
insertMethod = insertWithNoIDRetrieval
}
case "postgres":
insertMethod = insertWithReturning
default:
return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
}
return DB{
dialect: dialect,
driver: dbDriver,
db: db,
tableName: tableName,
idCols: idCols,
insertMethod: insertMethod,
}, nil
}
@ -272,18 +309,22 @@ func (c DB) Insert(
records ...interface{},
) error {
for _, record := range records {
query, params, err := buildInsertQuery(c.dialect, c.tableName, record, "id")
query, params, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...)
if err != nil {
return err
}
switch c.driver {
case "postgres":
err = c.insertOnPostgres(ctx, record, query, params)
case "sqlite3":
err = c.insertWithLastInsertID(ctx, record, query, params)
switch c.insertMethod {
case insertWithReturning:
err = c.insertWithReturningID(ctx, record, query, params, c.idCols)
case insertWithLastInsertID:
err = c.insertWithLastInsertID(ctx, record, query, params, c.idCols[0])
case insertWithNoIDRetrieval:
err = c.insertWithNoIDRetrieval(ctx, record, query, params)
default:
err = fmt.Errorf("unsupported driver `%s`", c.driver)
// Unsupported drivers should be detected on the New() function,
// So we don't expect the code to ever get into this default case.
err = fmt.Errorf("code error: unsupported driver `%s`", c.driver)
}
if err != nil {
@ -294,13 +335,18 @@ func (c DB) Insert(
return nil
}
func (c DB) insertOnPostgres(
func (c DB) insertWithReturningID(
ctx context.Context,
record interface{},
query string,
params []interface{},
idNames []string,
) error {
query = query + " RETURNING id"
escapedIDNames := []string{}
for _, id := range idNames {
escapedIDNames = append(escapedIDNames, c.dialect.Escape(id))
}
query += " RETURNING " + strings.Join(idNames, ", ")
rows, err := c.db.QueryContext(ctx, query, params...)
if err != nil {
@ -309,7 +355,7 @@ func (c DB) insertOnPostgres(
defer rows.Close()
if !rows.Next() {
err := fmt.Errorf("unexpected error retrieving the id from the database")
err := fmt.Errorf("unexpected error retrieving the id columns from the database")
if rows.Err() != nil {
err = rows.Err()
}
@ -324,8 +370,14 @@ func (c DB) insertOnPostgres(
}
info := getCachedTagInfo(tagInfoCache, t.Elem())
fieldAddr := v.Elem().Field(info.Index["id"]).Addr()
err = rows.Scan(fieldAddr.Interface())
var scanFields []interface{}
for _, id := range idNames {
scanFields = append(
scanFields,
v.Elem().Field(info.Index[id]).Addr().Interface(),
)
}
err = rows.Scan(scanFields...)
if err != nil {
return err
}
@ -338,6 +390,7 @@ func (c DB) insertWithLastInsertID(
record interface{},
query string,
params []interface{},
idName string,
) error {
result, err := c.db.ExecContext(ctx, query, params...)
if err != nil {
@ -347,7 +400,7 @@ func (c DB) insertWithLastInsertID(
v := reflect.ValueOf(record)
t := v.Type()
if err = assertStructPtr(t); err != nil {
return errors.Wrap(err, "can't write id field")
return errors.Wrap(err, "can't write to `"+idName+"` field")
}
info := getCachedTagInfo(tagInfoCache, t.Elem())
@ -360,13 +413,13 @@ func (c DB) insertWithLastInsertID(
vID := reflect.ValueOf(id)
tID := vID.Type()
fieldAddr := v.Elem().Field(info.Index["id"]).Addr()
fieldAddr := v.Elem().Field(info.Index[idName]).Addr()
fieldType := fieldAddr.Type().Elem()
if !tID.ConvertibleTo(fieldType) {
return fmt.Errorf(
"Can't convert last insert id of type int64 into field `%s` of type %s",
"id",
idName,
fieldType,
)
}
@ -375,6 +428,16 @@ func (c DB) insertWithLastInsertID(
return nil
}
func (c DB) insertWithNoIDRetrieval(
ctx context.Context,
record interface{},
query string,
params []interface{},
) error {
_, err := c.db.ExecContext(ctx, query, params...)
return err
}
func assertStructPtr(t reflect.Type) error {
if t.Kind() != reflect.Ptr {
return fmt.Errorf("expected a Kind of Ptr but got: %s", t)
@ -394,13 +457,63 @@ func (c DB) Delete(
return nil
}
query := buildDeleteQuery(c.dialect, c.tableName, ids)
idMaps, err := normalizeIDsAsMaps(c.idCols, ids)
if err != nil {
return err
}
_, err := c.db.ExecContext(ctx, query, ids...)
var query string
var params []interface{}
if len(c.idCols) == 1 {
query, params = buildSingleKeyDeleteQuery(c.dialect, c.tableName, c.idCols[0], idMaps)
} else {
query, params = buildCompositeKeyDeleteQuery(c.dialect, c.tableName, c.idCols, idMaps)
}
_, err = c.db.ExecContext(ctx, query, params...)
return err
}
func normalizeIDsAsMaps(idNames []string, ids []interface{}) ([]map[string]interface{}, error) {
if len(idNames) == 0 {
return nil, fmt.Errorf("internal kissorm error: missing idNames")
}
idMaps := []map[string]interface{}{}
for i := range ids {
t := reflect.TypeOf(ids[i])
switch t.Kind() {
case reflect.Struct:
m, err := StructToMap(ids[i])
if err != nil {
return nil, errors.Wrapf(err, "could not get ID(s) from record on idx %d", i)
}
idMaps = append(idMaps, m)
case reflect.Map:
m, ok := ids[i].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("expected map[string]interface{} but got %T", ids[i])
}
idMaps = append(idMaps, m)
default:
idMaps = append(idMaps, map[string]interface{}{
idNames[0]: ids[i],
})
}
}
for i, m := range idMaps {
for _, id := range idNames {
if _, found := m[id]; !found {
return nil, fmt.Errorf("missing required id field `%s` on record with idx %d", id, i)
}
}
}
return idMaps, nil
}
// Update updates the given instances on the database by id.
//
// Partial updates are supported, i.e. it will ignore nil pointer attributes
@ -409,7 +522,7 @@ func (c DB) Update(
records ...interface{},
) error {
for _, record := range records {
query, params, err := buildUpdateQuery(c.dialect, c.tableName, record, "id")
query, params, err := buildUpdateQuery(c.dialect, c.tableName, record, c.idCols...)
if err != nil {
return err
}
@ -656,15 +769,51 @@ func getCachedTagInfo(tagInfoCache map[reflect.Type]structInfo, key reflect.Type
return info
}
func buildDeleteQuery(dialect dialect, table string, ids []interface{}) string {
func buildSingleKeyDeleteQuery(
dialect dialect,
table string,
idName string,
idMaps []map[string]interface{},
) (query string, params []interface{}) {
values := []string{}
for i := range ids {
for i, m := range idMaps {
values = append(values, dialect.Placeholder(i))
params = append(params, m[idName])
}
return fmt.Sprintf(
"DELETE FROM %s WHERE id IN (%s)",
"DELETE FROM %s WHERE %s IN (%s)",
dialect.Escape(table),
dialect.Escape(idName),
strings.Join(values, ","),
)
), params
}
func buildCompositeKeyDeleteQuery(
dialect dialect,
table string,
idNames []string,
idMaps []map[string]interface{},
) (query string, params []interface{}) {
escapedNames := []string{}
for _, name := range idNames {
escapedNames = append(escapedNames, dialect.Escape(name))
}
values := []string{}
for _, m := range idMaps {
tuple := []string{}
for _, name := range idNames {
params = append(params, m[name])
tuple = append(tuple, dialect.Placeholder(len(values)))
}
values = append(values, "("+strings.Join(tuple, ",")+")")
}
return fmt.Sprintf(
"DELETE FROM %s WHERE (%s) IN (VALUES %s)",
dialect.Escape(table),
strings.Join(escapedNames, ","),
strings.Join(values, ","),
), params
}

View File

@ -1029,6 +1029,12 @@ func newTestDB(db *sql.DB, driver string, tableName string) DB {
dialect: getDriverDialect(driver),
db: db,
tableName: tableName,
idCols: []string{"id"},
insertMethod: map[string]insertMethod{
"sqlite3": insertWithLastInsertID,
"postgres": insertWithReturning,
}[driver],
}
}