mirror of https://github.com/VinGarcia/ksql.git
Add more tests for inserting in tables with composite keys
parent
249d8db409
commit
49f872fb84
39
ksql.go
39
ksql.go
|
@ -182,14 +182,14 @@ func (c DB) Query(
|
|||
}
|
||||
}
|
||||
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the original slice passed by reference:
|
||||
slicePtr.Elem().Set(slice)
|
||||
|
||||
|
@ -407,7 +407,7 @@ func (c DB) Insert(
|
|||
return err
|
||||
}
|
||||
|
||||
query, params, scanValues, err := buildInsertQuery(c.dialect, table.name, t, v, info, record, table.idColumns...)
|
||||
query, params, scanValues, err := buildInsertQuery(c.dialect, table, t, v, info, record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -667,19 +667,26 @@ func (c DB) Update(
|
|||
|
||||
func buildInsertQuery(
|
||||
dialect Dialect,
|
||||
tableName string,
|
||||
table Table,
|
||||
t reflect.Type,
|
||||
v reflect.Value,
|
||||
info structs.StructInfo,
|
||||
record interface{},
|
||||
idNames ...string,
|
||||
) (query string, params []interface{}, scanValues []interface{}, err error) {
|
||||
recordMap, err := kstructs.StructToMap(record)
|
||||
if err != nil {
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
for _, fieldName := range idNames {
|
||||
if table.name == "" {
|
||||
return "", nil, nil, fmt.Errorf("can't insert in ksql.Table: table name cannot be an empty string")
|
||||
}
|
||||
|
||||
for _, fieldName := range table.idColumns {
|
||||
if fieldName == "" {
|
||||
return "", nil, nil, fmt.Errorf("can't insert in ksql.Table: ID columns cannot be empty strings")
|
||||
}
|
||||
|
||||
field, found := recordMap[fieldName]
|
||||
if !found {
|
||||
continue
|
||||
|
@ -721,12 +728,12 @@ func buildInsertQuery(
|
|||
switch dialect.InsertMethod() {
|
||||
case insertWithReturning:
|
||||
escapedIDNames := []string{}
|
||||
for _, id := range idNames {
|
||||
for _, id := range table.idColumns {
|
||||
escapedIDNames = append(escapedIDNames, dialect.Escape(id))
|
||||
}
|
||||
returningQuery = " RETURNING " + strings.Join(escapedIDNames, ", ")
|
||||
|
||||
for _, id := range idNames {
|
||||
for _, id := range table.idColumns {
|
||||
scanValues = append(
|
||||
scanValues,
|
||||
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
|
||||
|
@ -734,12 +741,12 @@ func buildInsertQuery(
|
|||
}
|
||||
case insertWithOutput:
|
||||
escapedIDNames := []string{}
|
||||
for _, id := range idNames {
|
||||
for _, id := range table.idColumns {
|
||||
escapedIDNames = append(escapedIDNames, "INSERTED."+dialect.Escape(id))
|
||||
}
|
||||
outputQuery = " OUTPUT " + strings.Join(escapedIDNames, ", ")
|
||||
|
||||
for _, id := range idNames {
|
||||
for _, id := range table.idColumns {
|
||||
scanValues = append(
|
||||
scanValues,
|
||||
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
|
||||
|
@ -751,7 +758,7 @@ func buildInsertQuery(
|
|||
// on the selected driver, thus, they might be empty strings.
|
||||
query = fmt.Sprintf(
|
||||
"INSERT INTO %s (%s)%s VALUES (%s)%s",
|
||||
dialect.Escape(tableName),
|
||||
dialect.Escape(table.name),
|
||||
strings.Join(escapedColumnNames, ", "),
|
||||
outputQuery,
|
||||
strings.Join(valuesQuery, ", "),
|
||||
|
@ -850,10 +857,10 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error {
|
|||
}
|
||||
}()
|
||||
|
||||
ormCopy := c
|
||||
ormCopy.db = tx
|
||||
dbCopy := c
|
||||
dbCopy.db = tx
|
||||
|
||||
err = fn(ormCopy)
|
||||
err = fn(dbCopy)
|
||||
if err != nil {
|
||||
rollbackErr := tx.Rollback(ctx)
|
||||
if rollbackErr != nil {
|
||||
|
|
73
ksql_test.go
73
ksql_test.go
|
@ -47,9 +47,10 @@ type Post struct {
|
|||
Title string `ksql:"title"`
|
||||
}
|
||||
|
||||
var UserPermissionsTable = NewTable("user_permissions", "user_id", "post_id")
|
||||
var UserPermissionsTable = NewTable("user_permissions", "id", "user_id", "post_id")
|
||||
|
||||
type UserPermissions struct {
|
||||
ID int `ksql:"id"`
|
||||
UserID int `ksql:"user_id"`
|
||||
PostID int `ksql:"post_id"`
|
||||
}
|
||||
|
@ -794,9 +795,43 @@ func TestInsert(t *testing.T) {
|
|||
|
||||
userPerms, err := getUserPermissionsByUser(db, config.driver, 1)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, userPerms, []UserPermissions{
|
||||
{UserID: 1, PostID: 42},
|
||||
})
|
||||
tt.AssertEqual(t, len(userPerms), 1)
|
||||
tt.AssertEqual(t, userPerms[0].UserID, 1)
|
||||
tt.AssertEqual(t, userPerms[0].PostID, 42)
|
||||
})
|
||||
|
||||
t.Run("should accept partially provided values for composite key tables", func(t *testing.T) {
|
||||
db, closer := connectDB(t, config)
|
||||
defer closer.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestDB(db, config.driver)
|
||||
|
||||
permission := UserPermissions{
|
||||
UserID: 2,
|
||||
PostID: 42,
|
||||
}
|
||||
err = c.Insert(ctx, UserPermissionsTable, &permission)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
userPerms, err := getUserPermissionsByUser(db, config.driver, 2)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
// Should retrieve the generated ID from the database,
|
||||
// only if the database supports returning multiple values:
|
||||
switch c.dialect.InsertMethod() {
|
||||
case insertWithNoIDRetrieval, insertWithLastInsertID:
|
||||
tt.AssertEqual(t, permission.ID, 0)
|
||||
tt.AssertEqual(t, len(userPerms), 1)
|
||||
tt.AssertEqual(t, userPerms[0].UserID, 2)
|
||||
tt.AssertEqual(t, userPerms[0].PostID, 42)
|
||||
case insertWithReturning, insertWithOutput:
|
||||
tt.AssertNotEqual(t, permission.ID, 0)
|
||||
tt.AssertEqual(t, len(userPerms), 1)
|
||||
tt.AssertEqual(t, userPerms[0].ID, permission.ID)
|
||||
tt.AssertEqual(t, userPerms[0].UserID, 2)
|
||||
tt.AssertEqual(t, userPerms[0].PostID, 42)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@ -826,11 +861,11 @@ func TestInsert(t *testing.T) {
|
|||
})
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
ifUserForgetToExpandList := []interface{}{
|
||||
cantInsertSlice := []interface{}{
|
||||
&User{Name: "foo", Age: 22},
|
||||
&User{Name: "bar", Age: 32},
|
||||
}
|
||||
err = c.Insert(ctx, UsersTable, ifUserForgetToExpandList)
|
||||
err = c.Insert(ctx, UsersTable, cantInsertSlice)
|
||||
assert.NotEqual(t, nil, err)
|
||||
|
||||
// We might want to support this in the future, but not for now:
|
||||
|
@ -864,6 +899,28 @@ func TestInsert(t *testing.T) {
|
|||
assert.NotEqual(t, nil, err)
|
||||
})
|
||||
|
||||
t.Run("should report error if table contains an empty ID name", func(t *testing.T) {
|
||||
db, closer := connectDB(t, config)
|
||||
defer closer.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestDB(db, config.driver)
|
||||
|
||||
err := c.Insert(ctx, NewTable("users", ""), &User{Name: "fake-name"})
|
||||
tt.AssertErrContains(t, err, "ksql.Table", "ID", "empty string")
|
||||
})
|
||||
|
||||
t.Run("should report error if ksql.Table.name is empty", func(t *testing.T) {
|
||||
db, closer := connectDB(t, config)
|
||||
defer closer.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestDB(db, config.driver)
|
||||
|
||||
err := c.Insert(ctx, NewTable("", "id"), &User{Name: "fake-name"})
|
||||
tt.AssertErrContains(t, err, "ksql.Table", "table name", "empty string")
|
||||
})
|
||||
|
||||
t.Run("should not panic if a column doesn't exist in the database", func(t *testing.T) {
|
||||
db, closer := connectDB(t, config)
|
||||
defer closer.Close()
|
||||
|
@ -2246,7 +2303,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results
|
|||
dialect := supportedDialects[driver]
|
||||
|
||||
rows, err := db.QueryContext(context.TODO(),
|
||||
`SELECT user_id, post_id FROM user_permissions WHERE user_id=`+dialect.Placeholder(0),
|
||||
`SELECT id, user_id, post_id FROM user_permissions WHERE user_id=`+dialect.Placeholder(0),
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -2256,7 +2313,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results
|
|||
|
||||
for rows.Next() {
|
||||
var userPerm UserPermissions
|
||||
err := rows.Scan(&userPerm.UserID, &userPerm.PostID)
|
||||
err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PostID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue