feat: [CODE-2232]: Branch Rules: UserGroup support: Create and List (#2640)

* feat: [CODE-2327]: add usergroup change to rebase api
* Merge branch 'main' into akp/CODE-2327
* feat: [CODE-2327]: wrap error
* feat: [CODE-2327]: move resolver inside bypass.go and update tests
* feat: [CODE-2327]: newline
* feat: [CODE-2327]: update wire
* Merge remote-tracking branch 'origin' into akp/CODE-2327
* feat: [CODE-2327]: update tests
* feat: [CODE-2327]: fix build
* feat: [CODE-2327]: merge main
* feat: [CODE-2327]: add and update unit tests
* feat: [CODE-2327]: fix
* feat: [CODE-2312]: introduce parser and enclosing method
* Apply suggestion from code review
* feat: [CODE-2327]: annotate error
* Merge branch 'main' into akp/CODE-2327
* feat: [CODE-2327]: usergroup resolver via controller
* feat: [CODE-2312]: update interface in protection
* feat: [CODE-2312]: export and update deduplication use
* feat: [CODE-2327]: Branch Rules: Allow Group bypass.go
* feat: [CODE-2232]: Branch Rules: UserGroup support: Create and List
CODE-2402
Akhilesh Pandey 2024-09-11 08:04:31 +00:00 committed by Harness
parent 16cb3d5bc1
commit a2cea52155
21 changed files with 189 additions and 100 deletions

View File

@ -55,6 +55,7 @@ func (c *Controller) ListChecks(
}
reqChecks, err := protectionRules.RequiredChecks(ctx, protection.RequiredChecksInput{
ResolveUserGroupID: c.userGroupService.ListUserIDsByGroupIDs,
Actor: &session.Principal,
IsRepoOwner: isRepoOwner,
Repo: repo,

View File

@ -204,6 +204,7 @@ func (c *Controller) Merge(
}
ruleOut, violations, err := protectionRules.MergeVerify(ctx, protection.MergeVerifyInput{
ResolveUserGroupID: c.userGroupService.ListUserIDsByGroupIDs,
Actor: &session.Principal,
AllowBypass: in.BypassRules,
IsRepoOwner: isRepoOwner,

View File

@ -83,6 +83,7 @@ func (c *Controller) CommitFiles(ctx context.Context,
}
violations, err := rules.RefChangeVerify(ctx, protection.RefChangeVerifyInput{
ResolveUserGroupID: c.userGroupService.ListUserIDsByGroupIDs,
Actor: &session.Principal,
AllowBypass: in.BypassRules,
IsRepoOwner: isRepoOwner,

View File

@ -36,6 +36,7 @@ import (
"github.com/harness/gitness/app/services/protection"
"github.com/harness/gitness/app/services/publicaccess"
"github.com/harness/gitness/app/services/settings"
"github.com/harness/gitness/app/services/usergroup"
"github.com/harness/gitness/app/store"
"github.com/harness/gitness/app/url"
"github.com/harness/gitness/audit"
@ -82,6 +83,7 @@ type Controller struct {
settings *settings.Service
principalInfoCache store.PrincipalInfoCache
userGroupStore store.UserGroupStore
userGroupService usergroup.SearchService
protectionManager *protection.Manager
git git.Interface
importer *importer.Repository
@ -127,6 +129,7 @@ func NewController(
labelSvc *label.Service,
instrumentation instrument.Service,
userGroupStore store.UserGroupStore,
userGroupService usergroup.SearchService,
) *Controller {
return &Controller{
defaultBranch: config.Git.DefaultBranch,
@ -156,6 +159,7 @@ func NewController(
labelSvc: labelSvc,
instrumentation: instrumentation,
userGroupStore: userGroupStore,
userGroupService: userGroupService,
}
}
@ -237,12 +241,31 @@ func (c *Controller) fetchRules(
return protectionRules, isRepoOwner, nil
}
func (c *Controller) getRuleUsers(ctx context.Context, r *types.Rule) (map[int64]*types.PrincipalInfo, error) {
rule, err := c.protectionManager.FromJSON(r.Type, r.Definition, false)
func (c *Controller) getRuleUserAndUserGroups(
ctx context.Context,
r *types.Rule,
) (map[int64]*types.PrincipalInfo, map[int64]*types.UserGroupInfo, error) {
rule, err := c.parseRule(r)
if err != nil {
return nil, fmt.Errorf("failed to parse json rule definition: %w", err)
return nil, nil, fmt.Errorf("failed to parse rule: %w", err)
}
userMap, err := c.getRuleUsers(ctx, rule)
if err != nil {
return nil, nil, fmt.Errorf("failed to get rule users: %w", err)
}
userGroupMap, err := c.getRuleUserGroups(ctx, rule)
if err != nil {
return nil, nil, fmt.Errorf("failed to get rule user groups: %w", err)
}
return userMap, userGroupMap, nil
}
func (c *Controller) getRuleUsers(
ctx context.Context,
rule protection.Protection,
) (map[int64]*types.PrincipalInfo, error) {
userIDs, err := rule.UserIDs()
if err != nil {
return nil, fmt.Errorf("failed to get user ID from rule: %w", err)
@ -256,12 +279,10 @@ func (c *Controller) getRuleUsers(ctx context.Context, r *types.Rule) (map[int64
return userMap, nil
}
func (c *Controller) getRuleUserGroups(ctx context.Context, r *types.Rule) (map[int64]*types.UserGroupInfo, error) {
rule, err := c.protectionManager.FromJSON(r.Type, r.Definition, false)
if err != nil {
return nil, fmt.Errorf("failed to parse json rule definition: %w", err)
}
func (c *Controller) getRuleUserGroups(
ctx context.Context,
rule protection.Protection,
) (map[int64]*types.UserGroupInfo, error) {
groupIDs, err := rule.UserGroupIDs()
if err != nil {
return nil, fmt.Errorf("failed to get group IDs from rule: %w", err)
@ -283,3 +304,12 @@ func (c *Controller) getRuleUserGroups(ctx context.Context, r *types.Rule) (map[
}
return userGroupInfoMap, nil
}
func (c *Controller) parseRule(r *types.Rule) (protection.Protection, error) {
rule, err := c.protectionManager.FromJSON(r.Type, r.Definition, false)
if err != nil {
return nil, fmt.Errorf("failed to parse json rule definition: %w", err)
}
return rule, nil
}

View File

@ -82,6 +82,7 @@ func (c *Controller) Rebase(
}
violations, err := protectionRules.RefChangeVerify(ctx, protection.RefChangeVerifyInput{
ResolveUserGroupID: c.userGroupService.ListUserIDsByGroupIDs,
Actor: &session.Principal,
AllowBypass: in.BypassRules,
IsRepoOwner: isRepoOwner,

View File

@ -142,15 +142,13 @@ func (c *Controller) RuleCreate(ctx context.Context,
log.Ctx(ctx).Warn().Msgf("failed to insert instrumentation record for create branch rule operation: %s", err)
}
r.Users, err = c.getRuleUsers(ctx, r)
userMap, userGroupMap, err := c.getRuleUserAndUserGroups(ctx, r)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get rule users and user groups: %w", err)
}
r.UserGroups, err = c.getRuleUserGroups(ctx, r)
if err != nil {
return nil, fmt.Errorf("failed to get rule user groups: %w", err)
}
r.Users = userMap
r.UserGroups = userGroupMap
return r, nil
}

View File

@ -39,15 +39,13 @@ func (c *Controller) RuleFind(ctx context.Context,
return nil, fmt.Errorf("failed to find repository-level protection rule by identifier: %w", err)
}
r.Users, err = c.getRuleUsers(ctx, r)
userMap, userGroupMap, err := c.getRuleUserAndUserGroups(ctx, r)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get rule users and user groups: %w", err)
}
r.UserGroups, err = c.getRuleUserGroups(ctx, r)
if err != nil {
return nil, fmt.Errorf("failed to get rule user groups: %w", err)
}
r.Users = userMap
r.UserGroups = userGroupMap
return r, nil
}

View File

@ -61,7 +61,7 @@ func (c *Controller) RuleList(ctx context.Context,
}
for i := range list {
list[i].Users, err = c.getRuleUsers(ctx, &list[i])
list[i].Users, list[i].UserGroups, err = c.getRuleUserAndUserGroups(ctx, &list[i])
if err != nil {
return nil, 0, err
}

View File

@ -102,14 +102,14 @@ func (c *Controller) RuleUpdate(ctx context.Context,
}
oldRule := r.Clone()
if in.isEmpty() {
r.Users, err = c.getRuleUsers(ctx, r)
userMap, userGroupMap, err := c.getRuleUserAndUserGroups(ctx, r)
if err != nil {
return nil, err
}
r.UserGroups, err = c.getRuleUserGroups(ctx, r)
if err != nil {
return nil, fmt.Errorf("failed to get rule user groups: %w", err)
return nil, fmt.Errorf("failed to get rule users and user groups: %w", err)
}
r.Users = userMap
r.UserGroups = userGroupMap
return r, nil
}
@ -132,15 +132,13 @@ func (c *Controller) RuleUpdate(ctx context.Context,
}
}
r.Users, err = c.getRuleUsers(ctx, r)
userMap, userGroupMap, err := c.getRuleUserAndUserGroups(ctx, r)
if err != nil {
return nil, fmt.Errorf("failed to get rule users: %w", err)
return nil, fmt.Errorf("failed to get rule users and user groups: %w", err)
}
r.UserGroups, err = c.getRuleUserGroups(ctx, r)
if err != nil {
return nil, fmt.Errorf("failed to get rule user groups: %w", err)
}
r.Users = userMap
r.UserGroups = userGroupMap
err = c.ruleStore.Update(ctx, r)
if err != nil {

View File

@ -27,6 +27,7 @@ import (
"github.com/harness/gitness/app/services/protection"
"github.com/harness/gitness/app/services/publicaccess"
"github.com/harness/gitness/app/services/settings"
"github.com/harness/gitness/app/services/usergroup"
"github.com/harness/gitness/app/store"
"github.com/harness/gitness/app/url"
"github.com/harness/gitness/audit"
@ -72,13 +73,14 @@ func ProvideController(
labelSvc *label.Service,
instrumentation instrument.Service,
userGroupStore store.UserGroupStore,
userGroupService usergroup.SearchService,
) *Controller {
return NewController(config, tx, urlProvider,
authorizer,
repoStore, spaceStore, pipelineStore,
principalStore, ruleStore, settings, principalInfoCache, protectionManager, rpcClient, importer,
codeOwners, reporeporter, indexer, limiter, locker, auditService, mtxManager, identifierCheck,
repoChecks, publicAccess, labelSvc, instrumentation, userGroupStore)
repoChecks, publicAccess, labelSvc, instrumentation, userGroupStore, userGroupService)
}
func ProvideRepoCheck() Check {

View File

@ -15,8 +15,10 @@
package protection
import (
"context"
"fmt"
"github.com/harness/gitness/cache"
"github.com/harness/gitness/types"
"golang.org/x/exp/slices"
@ -28,7 +30,21 @@ type DefBypass struct {
RepoOwners bool `json:"repo_owners,omitempty"`
}
func (v DefBypass) matches(actor *types.Principal, isRepoOwner bool) bool {
func (v DefBypass) matches(
ctx context.Context,
actor *types.Principal,
isRepoOwner bool,
userGroupResolverFn func(context.Context, []int64) ([]int64, error),
) bool {
if userGroupResolverFn != nil {
userIDs, err := userGroupResolverFn(ctx, v.UserGroupIDs)
if err != nil {
return false
}
v.UserIDs = append(v.UserIDs, userIDs...)
v.UserIDs = cache.Deduplicate(v.UserIDs)
}
return actor != nil &&
(v.RepoOwners && isRepoOwner ||
slices.Contains(v.UserIDs, actor.ID))

View File

@ -15,6 +15,7 @@
package protection
import (
"context"
"testing"
"github.com/harness/gitness/types"
@ -78,7 +79,7 @@ func TestBranch_matches(t *testing.T) {
t.Errorf("invalid: %s", err.Error())
}
if want, got := test.exp, test.bypass.matches(test.actor, test.owner); want != got {
if want, got := test.exp, test.bypass.matches(context.TODO(), test.actor, test.owner, nil); want != got {
t.Errorf("want=%t got=%t", want, got)
}
})

View File

@ -41,10 +41,10 @@ func (v *Branch) MergeVerify(
) (out MergeVerifyOutput, violations []types.RuleViolations, err error) {
out, violations, err = v.PullReq.MergeVerify(ctx, in)
if err != nil {
return
return out, violations, fmt.Errorf("merge verify error: %w", err)
}
bypassable := v.Bypass.matches(in.Actor, in.IsRepoOwner)
bypassable := v.Bypass.matches(ctx, in.Actor, in.IsRepoOwner, in.ResolveUserGroupID)
bypassed := in.AllowBypass && bypassable
for i := range violations {
violations[i].Bypassable = bypassable
@ -73,7 +73,7 @@ func (v *Branch) RequiredChecks(
bypassableIDs map[string]struct{}
)
if bypassable := v.Bypass.matches(in.Actor, in.IsRepoOwner); bypassable {
if bypassable := v.Bypass.matches(ctx, in.Actor, in.IsRepoOwner, in.ResolveUserGroupID); bypassable {
bypassableIDs = ids
} else {
requiredIDs = ids
@ -94,8 +94,11 @@ func (v *Branch) RefChangeVerify(
}
violations, err = v.Lifecycle.RefChangeVerify(ctx, in)
if err != nil {
return nil, fmt.Errorf("lifecycle error: %w", err)
}
bypassable := v.Bypass.matches(in.Actor, in.IsRepoOwner)
bypassable := v.Bypass.matches(ctx, in.Actor, in.IsRepoOwner, in.ResolveUserGroupID)
bypassed := in.AllowBypass && bypassable
for i := range violations {
violations[i].Bypassable = bypassable

View File

@ -183,6 +183,7 @@ func TestBranch_MergeVerify(t *testing.T) {
},
in: MergeVerifyInput{
Actor: user,
ResolveUserGroupID: mockUserGroupResolver,
CodeOwners: &codeowners.Evaluation{},
PullReq: &types.PullReq{},
Reviewers: []*types.PullReqReviewer{},
@ -480,6 +481,31 @@ func TestBranch_RefChangeVerify(t *testing.T) {
},
},
},
{
name: "usergroup-bypass",
branch: Branch{
Bypass: DefBypass{RepoOwners: true},
Lifecycle: DefLifecycle{DeleteForbidden: true},
},
in: RefChangeVerifyInput{
Actor: &types.Principal{ID: 43},
ResolveUserGroupID: mockUserGroupResolver,
AllowBypass: true,
IsRepoOwner: false,
RefAction: RefActionDelete,
RefType: RefTypeBranch,
RefNames: []string{"abc"},
},
expVs: []types.RuleViolations{
{
Bypassable: true,
Bypassed: true,
Violations: []types.Violation{
{Code: codeLifecycleDelete},
},
},
},
},
}
ctx := context.Background()
@ -527,3 +553,7 @@ func TestBranch_RefChangeVerify(t *testing.T) {
})
}
}
func mockUserGroupResolver(_ context.Context, _ []int64) ([]int64, error) {
return []int64{43}, nil
}

View File

@ -26,6 +26,7 @@ type (
}
RefChangeVerifyInput struct {
ResolveUserGroupID func(ctx context.Context, userGroupIDs []int64) ([]int64, error)
Actor *types.Principal
AllowBypass bool
IsRepoOwner bool

View File

@ -35,6 +35,7 @@ type (
}
MergeVerifyInput struct {
ResolveUserGroupID func(ctx context.Context, userGroupIDs []int64) ([]int64, error)
Actor *types.Principal
AllowBypass bool
IsRepoOwner bool
@ -59,6 +60,7 @@ type (
}
RequiredChecksInput struct {
ResolveUserGroupID func(ctx context.Context, userGroupIDs []int64) ([]int64, error)
Actor *types.Principal
IsRepoOwner bool
Repo *types.Repository

View File

@ -29,3 +29,7 @@ func (s *searchService) ListUsers(
) ([]string, error) {
return nil, fmt.Errorf("not implemented")
}
func (s *searchService) ListUserIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
return nil, fmt.Errorf("not implemented")
}

View File

@ -32,4 +32,6 @@ type SearchService interface {
session *auth.Session,
userGroup *types.UserGroup,
) ([]string, error)
ListUserIDsByGroupIDs(ctx context.Context, userGroupIDs []int64) ([]int64, error)
}

2
cache/cache_test.go vendored
View File

@ -59,7 +59,7 @@ func TestDeduplicate(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
test.input = deduplicate(test.input)
test.input = Deduplicate(test.input)
if want, got := test.expected, test.input; !reflect.DeepEqual(want, got) {
t.Errorf("failed - want=%v, got=%v", want, got)
return

6
cache/ttl_cache.go vendored
View File

@ -140,7 +140,7 @@ func (c *ExtendedTTLCache[K, V]) Map(ctx context.Context, keys []K) (map[K]V, er
m := make(map[K]V)
now := time.Now()
keys = deduplicate(keys)
keys = Deduplicate(keys)
// Check what's already available in the cache.
@ -211,8 +211,8 @@ func (c *TTLCache[K, V]) Get(ctx context.Context, key K) (V, error) {
return item, nil
}
// deduplicate is a utility function that removes duplicates from slice.
func deduplicate[V constraints.Ordered](slice []V) []V {
// Deduplicate is a utility function that removes duplicates from slice.
func Deduplicate[V constraints.Ordered](slice []V) []V {
if len(slice) <= 1 {
return slice
}

View File

@ -245,7 +245,8 @@ func initSystem(ctx context.Context, config *types.Config) (*server.System, erro
labelService := label.ProvideLabel(transactor, spaceStore, labelStore, labelValueStore, pullReqLabelAssignmentStore)
instrumentService := instrument.ProvideService()
userGroupStore := database.ProvideUserGroupStore(db)
repoController := repo.ProvideController(config, transactor, provider, authorizer, repoStore, spaceStore, pipelineStore, principalStore, ruleStore, settingsService, principalInfoCache, protectionManager, gitInterface, repository, codeownersService, reporter, indexer, resourceLimiter, lockerLocker, auditService, mutexManager, repoIdentifier, repoCheck, publicaccessService, labelService, instrumentService, userGroupStore)
searchService := usergroup.ProvideSearchService()
repoController := repo.ProvideController(config, transactor, provider, authorizer, repoStore, spaceStore, pipelineStore, principalStore, ruleStore, settingsService, principalInfoCache, protectionManager, gitInterface, repository, codeownersService, reporter, indexer, resourceLimiter, lockerLocker, auditService, mutexManager, repoIdentifier, repoCheck, publicaccessService, labelService, instrumentService, userGroupStore, searchService)
reposettingsController := reposettings.ProvideController(authorizer, repoStore, settingsService, auditService)
executionStore := database.ProvideExecutionStore(db)
checkStore := database.ProvideCheckStore(db, principalInfoCache)
@ -332,7 +333,6 @@ func initSystem(ctx context.Context, config *types.Config) (*server.System, erro
return nil, err
}
pullReq := migrate.ProvidePullReqImporter(provider, gitInterface, principalStore, repoStore, pullReqStore, pullReqActivityStore, transactor)
searchService := usergroup.ProvideSearchService()
pullreqController := pullreq2.ProvideController(transactor, provider, authorizer, pullReqStore, pullReqActivityStore, codeCommentView, pullReqReviewStore, pullReqReviewerStore, repoStore, principalStore, userGroupStore, userGroupReviewersStore, principalInfoCache, pullReqFileViewStore, membershipStore, checkStore, gitInterface, reporter3, migrator, pullreqService, listService, protectionManager, streamer, codeownersService, lockerLocker, pullReq, labelService, instrumentService, searchService)
webhookConfig := server.ProvideWebhookConfig(config)
webhookStore := database.ProvideWebhookStore(db)