Add more tests for inserting in tables with composite keys

pull/16/head
Vinícius Garcia 2022-02-11 17:23:47 -03:00
parent 249d8db409
commit 49f872fb84
2 changed files with 88 additions and 24 deletions

39
ksql.go
View File

@ -182,14 +182,14 @@ func (c DB) Query(
}
}
if err := rows.Close(); err != nil {
return err
}
if rows.Err() != nil {
return rows.Err()
}
if err := rows.Close(); err != nil {
return err
}
// Update the original slice passed by reference:
slicePtr.Elem().Set(slice)
@ -407,7 +407,7 @@ func (c DB) Insert(
return err
}
query, params, scanValues, err := buildInsertQuery(c.dialect, table.name, t, v, info, record, table.idColumns...)
query, params, scanValues, err := buildInsertQuery(c.dialect, table, t, v, info, record)
if err != nil {
return err
}
@ -667,19 +667,26 @@ func (c DB) Update(
func buildInsertQuery(
dialect Dialect,
tableName string,
table Table,
t reflect.Type,
v reflect.Value,
info structs.StructInfo,
record interface{},
idNames ...string,
) (query string, params []interface{}, scanValues []interface{}, err error) {
recordMap, err := kstructs.StructToMap(record)
if err != nil {
return "", nil, nil, err
}
for _, fieldName := range idNames {
if table.name == "" {
return "", nil, nil, fmt.Errorf("can't insert in ksql.Table: table name cannot be an empty string")
}
for _, fieldName := range table.idColumns {
if fieldName == "" {
return "", nil, nil, fmt.Errorf("can't insert in ksql.Table: ID columns cannot be empty strings")
}
field, found := recordMap[fieldName]
if !found {
continue
@ -721,12 +728,12 @@ func buildInsertQuery(
switch dialect.InsertMethod() {
case insertWithReturning:
escapedIDNames := []string{}
for _, id := range idNames {
for _, id := range table.idColumns {
escapedIDNames = append(escapedIDNames, dialect.Escape(id))
}
returningQuery = " RETURNING " + strings.Join(escapedIDNames, ", ")
for _, id := range idNames {
for _, id := range table.idColumns {
scanValues = append(
scanValues,
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
@ -734,12 +741,12 @@ func buildInsertQuery(
}
case insertWithOutput:
escapedIDNames := []string{}
for _, id := range idNames {
for _, id := range table.idColumns {
escapedIDNames = append(escapedIDNames, "INSERTED."+dialect.Escape(id))
}
outputQuery = " OUTPUT " + strings.Join(escapedIDNames, ", ")
for _, id := range idNames {
for _, id := range table.idColumns {
scanValues = append(
scanValues,
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
@ -751,7 +758,7 @@ func buildInsertQuery(
// on the selected driver, thus, they might be empty strings.
query = fmt.Sprintf(
"INSERT INTO %s (%s)%s VALUES (%s)%s",
dialect.Escape(tableName),
dialect.Escape(table.name),
strings.Join(escapedColumnNames, ", "),
outputQuery,
strings.Join(valuesQuery, ", "),
@ -850,10 +857,10 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error {
}
}()
ormCopy := c
ormCopy.db = tx
dbCopy := c
dbCopy.db = tx
err = fn(ormCopy)
err = fn(dbCopy)
if err != nil {
rollbackErr := tx.Rollback(ctx)
if rollbackErr != nil {

View File

@ -47,9 +47,10 @@ type Post struct {
Title string `ksql:"title"`
}
var UserPermissionsTable = NewTable("user_permissions", "user_id", "post_id")
var UserPermissionsTable = NewTable("user_permissions", "id", "user_id", "post_id")
type UserPermissions struct {
ID int `ksql:"id"`
UserID int `ksql:"user_id"`
PostID int `ksql:"post_id"`
}
@ -794,9 +795,43 @@ func TestInsert(t *testing.T) {
userPerms, err := getUserPermissionsByUser(db, config.driver, 1)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, userPerms, []UserPermissions{
{UserID: 1, PostID: 42},
})
tt.AssertEqual(t, len(userPerms), 1)
tt.AssertEqual(t, userPerms[0].UserID, 1)
tt.AssertEqual(t, userPerms[0].PostID, 42)
})
t.Run("should accept partially provided values for composite key tables", func(t *testing.T) {
db, closer := connectDB(t, config)
defer closer.Close()
ctx := context.Background()
c := newTestDB(db, config.driver)
permission := UserPermissions{
UserID: 2,
PostID: 42,
}
err = c.Insert(ctx, UserPermissionsTable, &permission)
tt.AssertNoErr(t, err)
userPerms, err := getUserPermissionsByUser(db, config.driver, 2)
tt.AssertNoErr(t, err)
// Should retrieve the generated ID from the database,
// only if the database supports returning multiple values:
switch c.dialect.InsertMethod() {
case insertWithNoIDRetrieval, insertWithLastInsertID:
tt.AssertEqual(t, permission.ID, 0)
tt.AssertEqual(t, len(userPerms), 1)
tt.AssertEqual(t, userPerms[0].UserID, 2)
tt.AssertEqual(t, userPerms[0].PostID, 42)
case insertWithReturning, insertWithOutput:
tt.AssertNotEqual(t, permission.ID, 0)
tt.AssertEqual(t, len(userPerms), 1)
tt.AssertEqual(t, userPerms[0].ID, permission.ID)
tt.AssertEqual(t, userPerms[0].UserID, 2)
tt.AssertEqual(t, userPerms[0].PostID, 42)
}
})
})
})
@ -826,11 +861,11 @@ func TestInsert(t *testing.T) {
})
assert.NotEqual(t, nil, err)
ifUserForgetToExpandList := []interface{}{
cantInsertSlice := []interface{}{
&User{Name: "foo", Age: 22},
&User{Name: "bar", Age: 32},
}
err = c.Insert(ctx, UsersTable, ifUserForgetToExpandList)
err = c.Insert(ctx, UsersTable, cantInsertSlice)
assert.NotEqual(t, nil, err)
// We might want to support this in the future, but not for now:
@ -864,6 +899,28 @@ func TestInsert(t *testing.T) {
assert.NotEqual(t, nil, err)
})
t.Run("should report error if table contains an empty ID name", func(t *testing.T) {
db, closer := connectDB(t, config)
defer closer.Close()
ctx := context.Background()
c := newTestDB(db, config.driver)
err := c.Insert(ctx, NewTable("users", ""), &User{Name: "fake-name"})
tt.AssertErrContains(t, err, "ksql.Table", "ID", "empty string")
})
t.Run("should report error if ksql.Table.name is empty", func(t *testing.T) {
db, closer := connectDB(t, config)
defer closer.Close()
ctx := context.Background()
c := newTestDB(db, config.driver)
err := c.Insert(ctx, NewTable("", "id"), &User{Name: "fake-name"})
tt.AssertErrContains(t, err, "ksql.Table", "table name", "empty string")
})
t.Run("should not panic if a column doesn't exist in the database", func(t *testing.T) {
db, closer := connectDB(t, config)
defer closer.Close()
@ -2246,7 +2303,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results
dialect := supportedDialects[driver]
rows, err := db.QueryContext(context.TODO(),
`SELECT user_id, post_id FROM user_permissions WHERE user_id=`+dialect.Placeholder(0),
`SELECT id, user_id, post_id FROM user_permissions WHERE user_id=`+dialect.Placeholder(0),
userID,
)
if err != nil {
@ -2256,7 +2313,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results
for rows.Next() {
var userPerm UserPermissions
err := rows.Scan(&userPerm.UserID, &userPerm.PostID)
err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PostID)
if err != nil {
return nil, err
}