mirror of https://github.com/VinGarcia/ksql.git
Add more tests for inserting in tables with composite keys
parent
249d8db409
commit
49f872fb84
39
ksql.go
39
ksql.go
|
@ -182,14 +182,14 @@ func (c DB) Query(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := rows.Close(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows.Err() != nil {
|
if rows.Err() != nil {
|
||||||
return rows.Err()
|
return rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Update the original slice passed by reference:
|
// Update the original slice passed by reference:
|
||||||
slicePtr.Elem().Set(slice)
|
slicePtr.Elem().Set(slice)
|
||||||
|
|
||||||
|
@ -407,7 +407,7 @@ func (c DB) Insert(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
query, params, scanValues, err := buildInsertQuery(c.dialect, table.name, t, v, info, record, table.idColumns...)
|
query, params, scanValues, err := buildInsertQuery(c.dialect, table, t, v, info, record)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -667,19 +667,26 @@ func (c DB) Update(
|
||||||
|
|
||||||
func buildInsertQuery(
|
func buildInsertQuery(
|
||||||
dialect Dialect,
|
dialect Dialect,
|
||||||
tableName string,
|
table Table,
|
||||||
t reflect.Type,
|
t reflect.Type,
|
||||||
v reflect.Value,
|
v reflect.Value,
|
||||||
info structs.StructInfo,
|
info structs.StructInfo,
|
||||||
record interface{},
|
record interface{},
|
||||||
idNames ...string,
|
|
||||||
) (query string, params []interface{}, scanValues []interface{}, err error) {
|
) (query string, params []interface{}, scanValues []interface{}, err error) {
|
||||||
recordMap, err := kstructs.StructToMap(record)
|
recordMap, err := kstructs.StructToMap(record)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, nil, err
|
return "", nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, fieldName := range idNames {
|
if table.name == "" {
|
||||||
|
return "", nil, nil, fmt.Errorf("can't insert in ksql.Table: table name cannot be an empty string")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, fieldName := range table.idColumns {
|
||||||
|
if fieldName == "" {
|
||||||
|
return "", nil, nil, fmt.Errorf("can't insert in ksql.Table: ID columns cannot be empty strings")
|
||||||
|
}
|
||||||
|
|
||||||
field, found := recordMap[fieldName]
|
field, found := recordMap[fieldName]
|
||||||
if !found {
|
if !found {
|
||||||
continue
|
continue
|
||||||
|
@ -721,12 +728,12 @@ func buildInsertQuery(
|
||||||
switch dialect.InsertMethod() {
|
switch dialect.InsertMethod() {
|
||||||
case insertWithReturning:
|
case insertWithReturning:
|
||||||
escapedIDNames := []string{}
|
escapedIDNames := []string{}
|
||||||
for _, id := range idNames {
|
for _, id := range table.idColumns {
|
||||||
escapedIDNames = append(escapedIDNames, dialect.Escape(id))
|
escapedIDNames = append(escapedIDNames, dialect.Escape(id))
|
||||||
}
|
}
|
||||||
returningQuery = " RETURNING " + strings.Join(escapedIDNames, ", ")
|
returningQuery = " RETURNING " + strings.Join(escapedIDNames, ", ")
|
||||||
|
|
||||||
for _, id := range idNames {
|
for _, id := range table.idColumns {
|
||||||
scanValues = append(
|
scanValues = append(
|
||||||
scanValues,
|
scanValues,
|
||||||
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
|
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
|
||||||
|
@ -734,12 +741,12 @@ func buildInsertQuery(
|
||||||
}
|
}
|
||||||
case insertWithOutput:
|
case insertWithOutput:
|
||||||
escapedIDNames := []string{}
|
escapedIDNames := []string{}
|
||||||
for _, id := range idNames {
|
for _, id := range table.idColumns {
|
||||||
escapedIDNames = append(escapedIDNames, "INSERTED."+dialect.Escape(id))
|
escapedIDNames = append(escapedIDNames, "INSERTED."+dialect.Escape(id))
|
||||||
}
|
}
|
||||||
outputQuery = " OUTPUT " + strings.Join(escapedIDNames, ", ")
|
outputQuery = " OUTPUT " + strings.Join(escapedIDNames, ", ")
|
||||||
|
|
||||||
for _, id := range idNames {
|
for _, id := range table.idColumns {
|
||||||
scanValues = append(
|
scanValues = append(
|
||||||
scanValues,
|
scanValues,
|
||||||
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
|
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
|
||||||
|
@ -751,7 +758,7 @@ func buildInsertQuery(
|
||||||
// on the selected driver, thus, they might be empty strings.
|
// on the selected driver, thus, they might be empty strings.
|
||||||
query = fmt.Sprintf(
|
query = fmt.Sprintf(
|
||||||
"INSERT INTO %s (%s)%s VALUES (%s)%s",
|
"INSERT INTO %s (%s)%s VALUES (%s)%s",
|
||||||
dialect.Escape(tableName),
|
dialect.Escape(table.name),
|
||||||
strings.Join(escapedColumnNames, ", "),
|
strings.Join(escapedColumnNames, ", "),
|
||||||
outputQuery,
|
outputQuery,
|
||||||
strings.Join(valuesQuery, ", "),
|
strings.Join(valuesQuery, ", "),
|
||||||
|
@ -850,10 +857,10 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ormCopy := c
|
dbCopy := c
|
||||||
ormCopy.db = tx
|
dbCopy.db = tx
|
||||||
|
|
||||||
err = fn(ormCopy)
|
err = fn(dbCopy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rollbackErr := tx.Rollback(ctx)
|
rollbackErr := tx.Rollback(ctx)
|
||||||
if rollbackErr != nil {
|
if rollbackErr != nil {
|
||||||
|
|
73
ksql_test.go
73
ksql_test.go
|
@ -47,9 +47,10 @@ type Post struct {
|
||||||
Title string `ksql:"title"`
|
Title string `ksql:"title"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var UserPermissionsTable = NewTable("user_permissions", "user_id", "post_id")
|
var UserPermissionsTable = NewTable("user_permissions", "id", "user_id", "post_id")
|
||||||
|
|
||||||
type UserPermissions struct {
|
type UserPermissions struct {
|
||||||
|
ID int `ksql:"id"`
|
||||||
UserID int `ksql:"user_id"`
|
UserID int `ksql:"user_id"`
|
||||||
PostID int `ksql:"post_id"`
|
PostID int `ksql:"post_id"`
|
||||||
}
|
}
|
||||||
|
@ -794,9 +795,43 @@ func TestInsert(t *testing.T) {
|
||||||
|
|
||||||
userPerms, err := getUserPermissionsByUser(db, config.driver, 1)
|
userPerms, err := getUserPermissionsByUser(db, config.driver, 1)
|
||||||
tt.AssertNoErr(t, err)
|
tt.AssertNoErr(t, err)
|
||||||
tt.AssertEqual(t, userPerms, []UserPermissions{
|
tt.AssertEqual(t, len(userPerms), 1)
|
||||||
{UserID: 1, PostID: 42},
|
tt.AssertEqual(t, userPerms[0].UserID, 1)
|
||||||
})
|
tt.AssertEqual(t, userPerms[0].PostID, 42)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should accept partially provided values for composite key tables", func(t *testing.T) {
|
||||||
|
db, closer := connectDB(t, config)
|
||||||
|
defer closer.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
c := newTestDB(db, config.driver)
|
||||||
|
|
||||||
|
permission := UserPermissions{
|
||||||
|
UserID: 2,
|
||||||
|
PostID: 42,
|
||||||
|
}
|
||||||
|
err = c.Insert(ctx, UserPermissionsTable, &permission)
|
||||||
|
tt.AssertNoErr(t, err)
|
||||||
|
|
||||||
|
userPerms, err := getUserPermissionsByUser(db, config.driver, 2)
|
||||||
|
tt.AssertNoErr(t, err)
|
||||||
|
|
||||||
|
// Should retrieve the generated ID from the database,
|
||||||
|
// only if the database supports returning multiple values:
|
||||||
|
switch c.dialect.InsertMethod() {
|
||||||
|
case insertWithNoIDRetrieval, insertWithLastInsertID:
|
||||||
|
tt.AssertEqual(t, permission.ID, 0)
|
||||||
|
tt.AssertEqual(t, len(userPerms), 1)
|
||||||
|
tt.AssertEqual(t, userPerms[0].UserID, 2)
|
||||||
|
tt.AssertEqual(t, userPerms[0].PostID, 42)
|
||||||
|
case insertWithReturning, insertWithOutput:
|
||||||
|
tt.AssertNotEqual(t, permission.ID, 0)
|
||||||
|
tt.AssertEqual(t, len(userPerms), 1)
|
||||||
|
tt.AssertEqual(t, userPerms[0].ID, permission.ID)
|
||||||
|
tt.AssertEqual(t, userPerms[0].UserID, 2)
|
||||||
|
tt.AssertEqual(t, userPerms[0].PostID, 42)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -826,11 +861,11 @@ func TestInsert(t *testing.T) {
|
||||||
})
|
})
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
|
|
||||||
ifUserForgetToExpandList := []interface{}{
|
cantInsertSlice := []interface{}{
|
||||||
&User{Name: "foo", Age: 22},
|
&User{Name: "foo", Age: 22},
|
||||||
&User{Name: "bar", Age: 32},
|
&User{Name: "bar", Age: 32},
|
||||||
}
|
}
|
||||||
err = c.Insert(ctx, UsersTable, ifUserForgetToExpandList)
|
err = c.Insert(ctx, UsersTable, cantInsertSlice)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
|
|
||||||
// We might want to support this in the future, but not for now:
|
// We might want to support this in the future, but not for now:
|
||||||
|
@ -864,6 +899,28 @@ func TestInsert(t *testing.T) {
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("should report error if table contains an empty ID name", func(t *testing.T) {
|
||||||
|
db, closer := connectDB(t, config)
|
||||||
|
defer closer.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
c := newTestDB(db, config.driver)
|
||||||
|
|
||||||
|
err := c.Insert(ctx, NewTable("users", ""), &User{Name: "fake-name"})
|
||||||
|
tt.AssertErrContains(t, err, "ksql.Table", "ID", "empty string")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should report error if ksql.Table.name is empty", func(t *testing.T) {
|
||||||
|
db, closer := connectDB(t, config)
|
||||||
|
defer closer.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
c := newTestDB(db, config.driver)
|
||||||
|
|
||||||
|
err := c.Insert(ctx, NewTable("", "id"), &User{Name: "fake-name"})
|
||||||
|
tt.AssertErrContains(t, err, "ksql.Table", "table name", "empty string")
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("should not panic if a column doesn't exist in the database", func(t *testing.T) {
|
t.Run("should not panic if a column doesn't exist in the database", func(t *testing.T) {
|
||||||
db, closer := connectDB(t, config)
|
db, closer := connectDB(t, config)
|
||||||
defer closer.Close()
|
defer closer.Close()
|
||||||
|
@ -2246,7 +2303,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results
|
||||||
dialect := supportedDialects[driver]
|
dialect := supportedDialects[driver]
|
||||||
|
|
||||||
rows, err := db.QueryContext(context.TODO(),
|
rows, err := db.QueryContext(context.TODO(),
|
||||||
`SELECT user_id, post_id FROM user_permissions WHERE user_id=`+dialect.Placeholder(0),
|
`SELECT id, user_id, post_id FROM user_permissions WHERE user_id=`+dialect.Placeholder(0),
|
||||||
userID,
|
userID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2256,7 +2313,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var userPerm UserPermissions
|
var userPerm UserPermissions
|
||||||
err := rows.Scan(&userPerm.UserID, &userPerm.PostID)
|
err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PostID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue