mirror of https://github.com/VinGarcia/ksql.git
Add support for tables with composite keys
parent
e1e711dc91
commit
203b141aca
|
@ -224,7 +224,7 @@ ok github.com/vingarcia/kissorm 21.740s
|
|||
|
||||
### TODO List
|
||||
|
||||
- Allow the ID field to have a different name
|
||||
- Implement support for JSON fields on the database (encoding/decoding them automatically into structs)
|
||||
- Implement support for nested objects with prefixed table names
|
||||
- Improve error messages
|
||||
- Add tests for tables using composite keys
|
||||
|
|
197
kiss_orm.go
197
kiss_orm.go
|
@ -10,12 +10,20 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// DB ...
|
||||
// DB represents the kissorm client responsible for
|
||||
// interfacing with the "database/sql" package implementing
|
||||
// the Kissorm interface `ORMProvider`.
|
||||
type DB struct {
|
||||
driver string
|
||||
dialect dialect
|
||||
tableName string
|
||||
db sqlProvider
|
||||
|
||||
// Most dbs have a single primary key,
|
||||
// But in future kissorm should work with compound keys as well
|
||||
idCols []string
|
||||
|
||||
insertMethod insertMethod
|
||||
}
|
||||
|
||||
type sqlProvider interface {
|
||||
|
@ -23,12 +31,21 @@ type sqlProvider interface {
|
|||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
// New instantiates a new client
|
||||
type insertMethod int
|
||||
|
||||
const (
|
||||
insertWithReturning insertMethod = iota
|
||||
insertWithLastInsertID
|
||||
insertWithNoIDRetrieval
|
||||
)
|
||||
|
||||
// New instantiates a new Kissorm client
|
||||
func New(
|
||||
dbDriver string,
|
||||
connectionString string,
|
||||
maxOpenConns int,
|
||||
tableName string,
|
||||
idCols ...string,
|
||||
) (DB, error) {
|
||||
db, err := sql.Open(dbDriver, connectionString)
|
||||
if err != nil {
|
||||
|
@ -45,11 +62,31 @@ func New(
|
|||
return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
|
||||
}
|
||||
|
||||
if len(idCols) == 0 {
|
||||
idCols = append(idCols, "id")
|
||||
}
|
||||
|
||||
var insertMethod insertMethod
|
||||
switch dbDriver {
|
||||
case "sqlite3":
|
||||
insertMethod = insertWithLastInsertID
|
||||
if len(idCols) > 1 {
|
||||
insertMethod = insertWithNoIDRetrieval
|
||||
}
|
||||
case "postgres":
|
||||
insertMethod = insertWithReturning
|
||||
default:
|
||||
return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
|
||||
}
|
||||
|
||||
return DB{
|
||||
dialect: dialect,
|
||||
driver: dbDriver,
|
||||
db: db,
|
||||
tableName: tableName,
|
||||
|
||||
idCols: idCols,
|
||||
insertMethod: insertMethod,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -272,18 +309,22 @@ func (c DB) Insert(
|
|||
records ...interface{},
|
||||
) error {
|
||||
for _, record := range records {
|
||||
query, params, err := buildInsertQuery(c.dialect, c.tableName, record, "id")
|
||||
query, params, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch c.driver {
|
||||
case "postgres":
|
||||
err = c.insertOnPostgres(ctx, record, query, params)
|
||||
case "sqlite3":
|
||||
err = c.insertWithLastInsertID(ctx, record, query, params)
|
||||
switch c.insertMethod {
|
||||
case insertWithReturning:
|
||||
err = c.insertWithReturningID(ctx, record, query, params, c.idCols)
|
||||
case insertWithLastInsertID:
|
||||
err = c.insertWithLastInsertID(ctx, record, query, params, c.idCols[0])
|
||||
case insertWithNoIDRetrieval:
|
||||
err = c.insertWithNoIDRetrieval(ctx, record, query, params)
|
||||
default:
|
||||
err = fmt.Errorf("unsupported driver `%s`", c.driver)
|
||||
// Unsupported drivers should be detected on the New() function,
|
||||
// So we don't expect the code to ever get into this default case.
|
||||
err = fmt.Errorf("code error: unsupported driver `%s`", c.driver)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -294,13 +335,18 @@ func (c DB) Insert(
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c DB) insertOnPostgres(
|
||||
func (c DB) insertWithReturningID(
|
||||
ctx context.Context,
|
||||
record interface{},
|
||||
query string,
|
||||
params []interface{},
|
||||
idNames []string,
|
||||
) error {
|
||||
query = query + " RETURNING id"
|
||||
escapedIDNames := []string{}
|
||||
for _, id := range idNames {
|
||||
escapedIDNames = append(escapedIDNames, c.dialect.Escape(id))
|
||||
}
|
||||
query += " RETURNING " + strings.Join(idNames, ", ")
|
||||
|
||||
rows, err := c.db.QueryContext(ctx, query, params...)
|
||||
if err != nil {
|
||||
|
@ -309,7 +355,7 @@ func (c DB) insertOnPostgres(
|
|||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
err := fmt.Errorf("unexpected error retrieving the id from the database")
|
||||
err := fmt.Errorf("unexpected error retrieving the id columns from the database")
|
||||
if rows.Err() != nil {
|
||||
err = rows.Err()
|
||||
}
|
||||
|
@ -324,8 +370,14 @@ func (c DB) insertOnPostgres(
|
|||
}
|
||||
info := getCachedTagInfo(tagInfoCache, t.Elem())
|
||||
|
||||
fieldAddr := v.Elem().Field(info.Index["id"]).Addr()
|
||||
err = rows.Scan(fieldAddr.Interface())
|
||||
var scanFields []interface{}
|
||||
for _, id := range idNames {
|
||||
scanFields = append(
|
||||
scanFields,
|
||||
v.Elem().Field(info.Index[id]).Addr().Interface(),
|
||||
)
|
||||
}
|
||||
err = rows.Scan(scanFields...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -338,6 +390,7 @@ func (c DB) insertWithLastInsertID(
|
|||
record interface{},
|
||||
query string,
|
||||
params []interface{},
|
||||
idName string,
|
||||
) error {
|
||||
result, err := c.db.ExecContext(ctx, query, params...)
|
||||
if err != nil {
|
||||
|
@ -347,7 +400,7 @@ func (c DB) insertWithLastInsertID(
|
|||
v := reflect.ValueOf(record)
|
||||
t := v.Type()
|
||||
if err = assertStructPtr(t); err != nil {
|
||||
return errors.Wrap(err, "can't write id field")
|
||||
return errors.Wrap(err, "can't write to `"+idName+"` field")
|
||||
}
|
||||
|
||||
info := getCachedTagInfo(tagInfoCache, t.Elem())
|
||||
|
@ -360,13 +413,13 @@ func (c DB) insertWithLastInsertID(
|
|||
vID := reflect.ValueOf(id)
|
||||
tID := vID.Type()
|
||||
|
||||
fieldAddr := v.Elem().Field(info.Index["id"]).Addr()
|
||||
fieldAddr := v.Elem().Field(info.Index[idName]).Addr()
|
||||
fieldType := fieldAddr.Type().Elem()
|
||||
|
||||
if !tID.ConvertibleTo(fieldType) {
|
||||
return fmt.Errorf(
|
||||
"Can't convert last insert id of type int64 into field `%s` of type %s",
|
||||
"id",
|
||||
idName,
|
||||
fieldType,
|
||||
)
|
||||
}
|
||||
|
@ -375,6 +428,16 @@ func (c DB) insertWithLastInsertID(
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c DB) insertWithNoIDRetrieval(
|
||||
ctx context.Context,
|
||||
record interface{},
|
||||
query string,
|
||||
params []interface{},
|
||||
) error {
|
||||
_, err := c.db.ExecContext(ctx, query, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
func assertStructPtr(t reflect.Type) error {
|
||||
if t.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("expected a Kind of Ptr but got: %s", t)
|
||||
|
@ -394,13 +457,63 @@ func (c DB) Delete(
|
|||
return nil
|
||||
}
|
||||
|
||||
query := buildDeleteQuery(c.dialect, c.tableName, ids)
|
||||
idMaps, err := normalizeIDsAsMaps(c.idCols, ids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := c.db.ExecContext(ctx, query, ids...)
|
||||
var query string
|
||||
var params []interface{}
|
||||
if len(c.idCols) == 1 {
|
||||
query, params = buildSingleKeyDeleteQuery(c.dialect, c.tableName, c.idCols[0], idMaps)
|
||||
} else {
|
||||
query, params = buildCompositeKeyDeleteQuery(c.dialect, c.tableName, c.idCols, idMaps)
|
||||
}
|
||||
|
||||
_, err = c.db.ExecContext(ctx, query, params...)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func normalizeIDsAsMaps(idNames []string, ids []interface{}) ([]map[string]interface{}, error) {
|
||||
if len(idNames) == 0 {
|
||||
return nil, fmt.Errorf("internal kissorm error: missing idNames")
|
||||
}
|
||||
|
||||
idMaps := []map[string]interface{}{}
|
||||
for i := range ids {
|
||||
t := reflect.TypeOf(ids[i])
|
||||
switch t.Kind() {
|
||||
case reflect.Struct:
|
||||
m, err := StructToMap(ids[i])
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "could not get ID(s) from record on idx %d", i)
|
||||
}
|
||||
idMaps = append(idMaps, m)
|
||||
case reflect.Map:
|
||||
m, ok := ids[i].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected map[string]interface{} but got %T", ids[i])
|
||||
}
|
||||
idMaps = append(idMaps, m)
|
||||
default:
|
||||
idMaps = append(idMaps, map[string]interface{}{
|
||||
idNames[0]: ids[i],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for i, m := range idMaps {
|
||||
for _, id := range idNames {
|
||||
if _, found := m[id]; !found {
|
||||
return nil, fmt.Errorf("missing required id field `%s` on record with idx %d", id, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return idMaps, nil
|
||||
}
|
||||
|
||||
// Update updates the given instances on the database by id.
|
||||
//
|
||||
// Partial updates are supported, i.e. it will ignore nil pointer attributes
|
||||
|
@ -409,7 +522,7 @@ func (c DB) Update(
|
|||
records ...interface{},
|
||||
) error {
|
||||
for _, record := range records {
|
||||
query, params, err := buildUpdateQuery(c.dialect, c.tableName, record, "id")
|
||||
query, params, err := buildUpdateQuery(c.dialect, c.tableName, record, c.idCols...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -656,15 +769,51 @@ func getCachedTagInfo(tagInfoCache map[reflect.Type]structInfo, key reflect.Type
|
|||
return info
|
||||
}
|
||||
|
||||
func buildDeleteQuery(dialect dialect, table string, ids []interface{}) string {
|
||||
func buildSingleKeyDeleteQuery(
|
||||
dialect dialect,
|
||||
table string,
|
||||
idName string,
|
||||
idMaps []map[string]interface{},
|
||||
) (query string, params []interface{}) {
|
||||
values := []string{}
|
||||
for i := range ids {
|
||||
for i, m := range idMaps {
|
||||
values = append(values, dialect.Placeholder(i))
|
||||
params = append(params, m[idName])
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id IN (%s)",
|
||||
"DELETE FROM %s WHERE %s IN (%s)",
|
||||
dialect.Escape(table),
|
||||
dialect.Escape(idName),
|
||||
strings.Join(values, ","),
|
||||
)
|
||||
), params
|
||||
}
|
||||
|
||||
func buildCompositeKeyDeleteQuery(
|
||||
dialect dialect,
|
||||
table string,
|
||||
idNames []string,
|
||||
idMaps []map[string]interface{},
|
||||
) (query string, params []interface{}) {
|
||||
escapedNames := []string{}
|
||||
for _, name := range idNames {
|
||||
escapedNames = append(escapedNames, dialect.Escape(name))
|
||||
}
|
||||
|
||||
values := []string{}
|
||||
for _, m := range idMaps {
|
||||
tuple := []string{}
|
||||
for _, name := range idNames {
|
||||
params = append(params, m[name])
|
||||
tuple = append(tuple, dialect.Placeholder(len(values)))
|
||||
}
|
||||
values = append(values, "("+strings.Join(tuple, ",")+")")
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE (%s) IN (VALUES %s)",
|
||||
dialect.Escape(table),
|
||||
strings.Join(escapedNames, ","),
|
||||
strings.Join(values, ","),
|
||||
), params
|
||||
}
|
||||
|
|
|
@ -1029,6 +1029,12 @@ func newTestDB(db *sql.DB, driver string, tableName string) DB {
|
|||
dialect: getDriverDialect(driver),
|
||||
db: db,
|
||||
tableName: tableName,
|
||||
|
||||
idCols: []string{"id"},
|
||||
insertMethod: map[string]insertMethod{
|
||||
"sqlite3": insertWithLastInsertID,
|
||||
"postgres": insertWithReturning,
|
||||
}[driver],
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue