diff --git a/ksql.go b/ksql.go index 5d02154..c420c51 100644 --- a/ksql.go +++ b/ksql.go @@ -609,9 +609,14 @@ func normalizeIDsAsMap(idNames []string, idOrMap interface{}) (idMap map[string] } } - for _, id := range idNames { - if _, found := idMap[id]; !found { - return nil, fmt.Errorf("missing required id field `%s` on input record", id) + for _, idName := range idNames { + id, found := idMap[idName] + if !found { + return nil, fmt.Errorf("missing required id field `%s` on input record", idName) + } + + if id == nil || reflect.ValueOf(id).IsZero() { + return nil, fmt.Errorf("invalid value '%v' received for id column: '%s'", id, idName) } } diff --git a/ksql_test.go b/ksql_test.go index f2c8cbb..6b14912 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -47,12 +47,12 @@ type Post struct { Title string `ksql:"title"` } -var UserPermissionsTable = NewTable("user_permissions", "user_id", "post_id") +var UserPermissionsTable = NewTable("user_permissions", "user_id", "perm_id") type UserPermission struct { ID int `ksql:"id"` UserID int `ksql:"user_id"` - PostID int `ksql:"post_id"` + PermID int `ksql:"perm_id"` } type testConfig struct { @@ -787,10 +787,10 @@ func TestInsert(t *testing.T) { ctx := context.Background() c := newTestDB(db, config.driver) - table := NewTable("user_permissions", "id", "user_id", "post_id") + table := NewTable("user_permissions", "id", "user_id", "perm_id") err = c.Insert(ctx, table, &UserPermission{ UserID: 1, - PostID: 42, + PermID: 42, }) tt.AssertNoErr(t, err) @@ -798,7 +798,7 @@ func TestInsert(t *testing.T) { tt.AssertNoErr(t, err) tt.AssertEqual(t, len(userPerms), 1) tt.AssertEqual(t, userPerms[0].UserID, 1) - tt.AssertEqual(t, userPerms[0].PostID, 42) + tt.AssertEqual(t, userPerms[0].PermID, 42) }) t.Run("should accept partially provided values for composite key tables", func(t *testing.T) { @@ -810,10 +810,10 @@ func TestInsert(t *testing.T) { // Table defined with 3 values, but we'll provide only 2, // the third will be generated for the purposes of this test: - table := NewTable("user_permissions", "id", "user_id", "post_id") + table := NewTable("user_permissions", "id", "user_id", "perm_id") permission := UserPermission{ UserID: 2, - PostID: 42, + PermID: 42, } err = c.Insert(ctx, table, &permission) tt.AssertNoErr(t, err) @@ -828,13 +828,13 @@ func TestInsert(t *testing.T) { tt.AssertEqual(t, permission.ID, 0) tt.AssertEqual(t, len(userPerms), 1) tt.AssertEqual(t, userPerms[0].UserID, 2) - tt.AssertEqual(t, userPerms[0].PostID, 42) + tt.AssertEqual(t, userPerms[0].PermID, 42) case insertWithReturning, insertWithOutput: tt.AssertNotEqual(t, permission.ID, 0) tt.AssertEqual(t, len(userPerms), 1) tt.AssertEqual(t, userPerms[0].ID, permission.ID) tt.AssertEqual(t, userPerms[0].UserID, 2) - tt.AssertEqual(t, userPerms[0].PostID, 42) + tt.AssertEqual(t, userPerms[0].PermID, 42) } }) }) @@ -1096,7 +1096,7 @@ func TestDelete(t *testing.T) { // This permission should not be deleted, we'll use the id to check it: p0 := UserPermission{ UserID: 1, - PostID: 44, + PermID: 44, } err = c.Insert(ctx, NewTable("user_permissions", "id"), &p0) tt.AssertNoErr(t, err) @@ -1104,7 +1104,7 @@ func TestDelete(t *testing.T) { p1 := UserPermission{ UserID: 1, - PostID: 42, + PermID: 42, } err = c.Insert(ctx, NewTable("user_permissions", "id"), &p1) tt.AssertNoErr(t, err) @@ -1116,7 +1116,7 @@ func TestDelete(t *testing.T) { tt.AssertNoErr(t, err) tt.AssertEqual(t, len(userPerms), 1) tt.AssertEqual(t, userPerms[0].UserID, 1) - tt.AssertEqual(t, userPerms[0].PostID, 44) + tt.AssertEqual(t, userPerms[0].PermID, 44) }) t.Run("using maps", func(t *testing.T) { @@ -1129,7 +1129,7 @@ func TestDelete(t *testing.T) { // This permission should not be deleted, we'll use the id to check it: p0 := UserPermission{ UserID: 2, - PostID: 44, + PermID: 44, } err = c.Insert(ctx, NewTable("user_permissions", "id"), &p0) tt.AssertNoErr(t, err) @@ -1137,14 +1137,14 @@ func TestDelete(t *testing.T) { p1 := UserPermission{ UserID: 2, - PostID: 42, + PermID: 42, } err = c.Insert(ctx, NewTable("user_permissions", "id"), &p1) tt.AssertNoErr(t, err) err = c.Delete(ctx, UserPermissionsTable, map[string]interface{}{ "user_id": 2, - "post_id": 42, + "perm_id": 42, }) tt.AssertNoErr(t, err) @@ -1152,7 +1152,7 @@ func TestDelete(t *testing.T) { tt.AssertNoErr(t, err) tt.AssertEqual(t, len(userPerms), 1) tt.AssertEqual(t, userPerms[0].UserID, 2) - tt.AssertEqual(t, userPerms[0].PostID, 44) + tt.AssertEqual(t, userPerms[0].PermID, 44) }) }) @@ -1179,6 +1179,103 @@ func TestDelete(t *testing.T) { assert.NotEqual(t, nil, err) }) + t.Run("should report error if one of the ids is missing from the input", func(t *testing.T) { + t.Run("single id", func(t *testing.T) { + t.Run("struct with missing attr", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err := c.Delete(ctx, NewTable("users", "id"), &struct { + // Missing ID + Name string `ksql:"name"` + }{Name: "fake-name"}) + tt.AssertErrContains(t, err, "missing required", "id") + }) + + t.Run("struct with NULL attr", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err := c.Delete(ctx, NewTable("users", "id"), &struct { + // Null ID + ID *int `ksql:"id"` + Name string `ksql:"name"` + }{Name: "fake-name"}) + tt.AssertErrContains(t, err, "missing required", "id") + }) + + t.Run("struct with zero attr", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err := c.Delete(ctx, NewTable("users", "id"), &struct { + // Uninitialized ID + ID int `ksql:"id"` + Name string `ksql:"name"` + }{Name: "fake-name"}) + tt.AssertErrContains(t, err, "invalid value", "0", "id") + }) + }) + + t.Run("multiple ids", func(t *testing.T) { + t.Run("struct with missing attr", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err := c.Delete(ctx, NewTable("user_permissions", "user_id", "perm_id"), map[string]interface{}{ + // Missing PermID + "user_id": 1, + "name": "fake-name", + }) + tt.AssertErrContains(t, err, "missing required", "perm_id") + }) + + t.Run("struct with NULL attr", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err := c.Delete(ctx, NewTable("user_permissions", "user_id", "perm_id"), map[string]interface{}{ + // Null Perm ID + "user_id": 1, + "perm_id": nil, + "name": "fake-name", + }) + tt.AssertErrContains(t, err, "invalid value", "nil", "perm_id") + }) + + t.Run("struct with zero attr", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err := c.Delete(ctx, NewTable("user_permissions", "user_id", "perm_id"), map[string]interface{}{ + // Zero Perm ID + "user_id": 1, + "perm_id": 0, + "name": "fake-name", + }) + tt.AssertErrContains(t, err, "invalid value", "0", "perm_id") + }) + }) + }) + t.Run("should report error if table contains an empty ID name", func(t *testing.T) { db, closer := connectDB(t, config) defer closer.Close() @@ -1186,7 +1283,7 @@ func TestDelete(t *testing.T) { ctx := context.Background() c := newTestDB(db, config.driver) - err := c.Delete(ctx, NewTable("users", ""), &User{Name: "fake-name"}) + err := c.Delete(ctx, NewTable("users", ""), &User{ID: 42, Name: "fake-name"}) tt.AssertErrContains(t, err, "ksql.Table", "ID", "empty string") }) @@ -2222,31 +2319,31 @@ func createTables(driver string) error { switch driver { case "sqlite3": _, err = db.Exec(`CREATE TABLE user_permissions ( - id INTEGER PRIMARY KEY, - user_id INTEGER, - post_id INTEGER, - UNIQUE (user_id, post_id) + id INTEGER PRIMARY KEY, + user_id INTEGER, + perm_id INTEGER, + UNIQUE (user_id, perm_id) )`) case "postgres": _, err = db.Exec(`CREATE TABLE user_permissions ( - id serial PRIMARY KEY, + id serial PRIMARY KEY, user_id INT, - post_id INT, - UNIQUE (user_id, post_id) + perm_id INT, + UNIQUE (user_id, perm_id) )`) case "mysql": _, err = db.Exec(`CREATE TABLE user_permissions ( id INT AUTO_INCREMENT PRIMARY KEY, user_id INT, - post_id INT, - UNIQUE KEY (user_id, post_id) + perm_id INT, + UNIQUE KEY (user_id, perm_id) )`) case "sqlserver": _, err = db.Exec(`CREATE TABLE user_permissions ( id INT IDENTITY(1,1) PRIMARY KEY, user_id INT, - post_id INT, - CONSTRAINT unique_1 UNIQUE (user_id, post_id) + perm_id INT, + CONSTRAINT unique_1 UNIQUE (user_id, perm_id) )`) } if err != nil { @@ -2400,7 +2497,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results dialect := supportedDialects[driver] rows, err := db.QueryContext(context.TODO(), - `SELECT id, user_id, post_id FROM user_permissions WHERE user_id=`+dialect.Placeholder(0), + `SELECT id, user_id, perm_id FROM user_permissions WHERE user_id=`+dialect.Placeholder(0), userID, ) if err != nil { @@ -2410,7 +2507,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results for rows.Next() { var userPerm UserPermission - err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PostID) + err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PermID) if err != nil { return nil, err }