diff --git a/ksql.go b/ksql.go index 12c66cd..f74f699 100644 --- a/ksql.go +++ b/ksql.go @@ -182,14 +182,14 @@ func (c DB) Query( } } - if err := rows.Close(); err != nil { - return err - } - if rows.Err() != nil { return rows.Err() } + if err := rows.Close(); err != nil { + return err + } + // Update the original slice passed by reference: slicePtr.Elem().Set(slice) @@ -407,7 +407,7 @@ func (c DB) Insert( return err } - query, params, scanValues, err := buildInsertQuery(c.dialect, table.name, t, v, info, record, table.idColumns...) + query, params, scanValues, err := buildInsertQuery(c.dialect, table, t, v, info, record) if err != nil { return err } @@ -667,19 +667,26 @@ func (c DB) Update( func buildInsertQuery( dialect Dialect, - tableName string, + table Table, t reflect.Type, v reflect.Value, info structs.StructInfo, record interface{}, - idNames ...string, ) (query string, params []interface{}, scanValues []interface{}, err error) { recordMap, err := kstructs.StructToMap(record) if err != nil { return "", nil, nil, err } - for _, fieldName := range idNames { + if table.name == "" { + return "", nil, nil, fmt.Errorf("can't insert in ksql.Table: table name cannot be an empty string") + } + + for _, fieldName := range table.idColumns { + if fieldName == "" { + return "", nil, nil, fmt.Errorf("can't insert in ksql.Table: ID columns cannot be empty strings") + } + field, found := recordMap[fieldName] if !found { continue @@ -721,12 +728,12 @@ func buildInsertQuery( switch dialect.InsertMethod() { case insertWithReturning: escapedIDNames := []string{} - for _, id := range idNames { + for _, id := range table.idColumns { escapedIDNames = append(escapedIDNames, dialect.Escape(id)) } returningQuery = " RETURNING " + strings.Join(escapedIDNames, ", ") - for _, id := range idNames { + for _, id := range table.idColumns { scanValues = append( scanValues, v.Elem().Field(info.ByName(id).Index).Addr().Interface(), @@ -734,12 +741,12 @@ func buildInsertQuery( } case insertWithOutput: escapedIDNames := []string{} - for _, id := range idNames { + for _, id := range table.idColumns { escapedIDNames = append(escapedIDNames, "INSERTED."+dialect.Escape(id)) } outputQuery = " OUTPUT " + strings.Join(escapedIDNames, ", ") - for _, id := range idNames { + for _, id := range table.idColumns { scanValues = append( scanValues, v.Elem().Field(info.ByName(id).Index).Addr().Interface(), @@ -751,7 +758,7 @@ func buildInsertQuery( // on the selected driver, thus, they might be empty strings. query = fmt.Sprintf( "INSERT INTO %s (%s)%s VALUES (%s)%s", - dialect.Escape(tableName), + dialect.Escape(table.name), strings.Join(escapedColumnNames, ", "), outputQuery, strings.Join(valuesQuery, ", "), @@ -850,10 +857,10 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { } }() - ormCopy := c - ormCopy.db = tx + dbCopy := c + dbCopy.db = tx - err = fn(ormCopy) + err = fn(dbCopy) if err != nil { rollbackErr := tx.Rollback(ctx) if rollbackErr != nil { diff --git a/ksql_test.go b/ksql_test.go index f69ff6a..ed22392 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -47,9 +47,10 @@ type Post struct { Title string `ksql:"title"` } -var UserPermissionsTable = NewTable("user_permissions", "user_id", "post_id") +var UserPermissionsTable = NewTable("user_permissions", "id", "user_id", "post_id") type UserPermissions struct { + ID int `ksql:"id"` UserID int `ksql:"user_id"` PostID int `ksql:"post_id"` } @@ -794,9 +795,43 @@ func TestInsert(t *testing.T) { userPerms, err := getUserPermissionsByUser(db, config.driver, 1) tt.AssertNoErr(t, err) - tt.AssertEqual(t, userPerms, []UserPermissions{ - {UserID: 1, PostID: 42}, - }) + tt.AssertEqual(t, len(userPerms), 1) + tt.AssertEqual(t, userPerms[0].UserID, 1) + tt.AssertEqual(t, userPerms[0].PostID, 42) + }) + + t.Run("should accept partially provided values for composite key tables", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + permission := UserPermissions{ + UserID: 2, + PostID: 42, + } + err = c.Insert(ctx, UserPermissionsTable, &permission) + tt.AssertNoErr(t, err) + + userPerms, err := getUserPermissionsByUser(db, config.driver, 2) + tt.AssertNoErr(t, err) + + // Should retrieve the generated ID from the database, + // only if the database supports returning multiple values: + switch c.dialect.InsertMethod() { + case insertWithNoIDRetrieval, insertWithLastInsertID: + tt.AssertEqual(t, permission.ID, 0) + tt.AssertEqual(t, len(userPerms), 1) + tt.AssertEqual(t, userPerms[0].UserID, 2) + tt.AssertEqual(t, userPerms[0].PostID, 42) + case insertWithReturning, insertWithOutput: + tt.AssertNotEqual(t, permission.ID, 0) + tt.AssertEqual(t, len(userPerms), 1) + tt.AssertEqual(t, userPerms[0].ID, permission.ID) + tt.AssertEqual(t, userPerms[0].UserID, 2) + tt.AssertEqual(t, userPerms[0].PostID, 42) + } }) }) }) @@ -826,11 +861,11 @@ func TestInsert(t *testing.T) { }) assert.NotEqual(t, nil, err) - ifUserForgetToExpandList := []interface{}{ + cantInsertSlice := []interface{}{ &User{Name: "foo", Age: 22}, &User{Name: "bar", Age: 32}, } - err = c.Insert(ctx, UsersTable, ifUserForgetToExpandList) + err = c.Insert(ctx, UsersTable, cantInsertSlice) assert.NotEqual(t, nil, err) // We might want to support this in the future, but not for now: @@ -864,6 +899,28 @@ func TestInsert(t *testing.T) { assert.NotEqual(t, nil, err) }) + t.Run("should report error if table contains an empty ID name", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err := c.Insert(ctx, NewTable("users", ""), &User{Name: "fake-name"}) + tt.AssertErrContains(t, err, "ksql.Table", "ID", "empty string") + }) + + t.Run("should report error if ksql.Table.name is empty", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err := c.Insert(ctx, NewTable("", "id"), &User{Name: "fake-name"}) + tt.AssertErrContains(t, err, "ksql.Table", "table name", "empty string") + }) + t.Run("should not panic if a column doesn't exist in the database", func(t *testing.T) { db, closer := connectDB(t, config) defer closer.Close() @@ -2246,7 +2303,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results dialect := supportedDialects[driver] rows, err := db.QueryContext(context.TODO(), - `SELECT user_id, post_id FROM user_permissions WHERE user_id=`+dialect.Placeholder(0), + `SELECT id, user_id, post_id FROM user_permissions WHERE user_id=`+dialect.Placeholder(0), userID, ) if err != nil { @@ -2256,7 +2313,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results for rows.Next() { var userPerm UserPermissions - err := rows.Scan(&userPerm.UserID, &userPerm.PostID) + err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PostID) if err != nil { return nil, err }