diff --git a/app/api/controller/repo/controller.go b/app/api/controller/repo/controller.go index 1fb2bf094..61ea3c564 100644 --- a/app/api/controller/repo/controller.go +++ b/app/api/controller/repo/controller.go @@ -81,6 +81,7 @@ type Controller struct { ruleStore store.RuleStore settings *settings.Service principalInfoCache store.PrincipalInfoCache + userGroupStore store.UserGroupStore protectionManager *protection.Manager git git.Interface importer *importer.Repository @@ -125,6 +126,7 @@ func NewController( publicAccess publicaccess.Service, labelSvc *label.Service, instrumentation instrument.Service, + userGroupStore store.UserGroupStore, ) *Controller { return &Controller{ defaultBranch: config.Git.DefaultBranch, @@ -153,6 +155,7 @@ func NewController( publicAccess: publicAccess, labelSvc: labelSvc, instrumentation: instrumentation, + userGroupStore: userGroupStore, } } @@ -252,3 +255,31 @@ func (c *Controller) getRuleUsers(ctx context.Context, r *types.Rule) (map[int64 return userMap, nil } + +func (c *Controller) getRuleUserGroups(ctx context.Context, r *types.Rule) (map[int64]*types.UserGroupInfo, error) { + rule, err := c.protectionManager.FromJSON(r.Type, r.Definition, false) + if err != nil { + return nil, fmt.Errorf("failed to parse json rule definition: %w", err) + } + + groupIDs, err := rule.UserGroupIDs() + if err != nil { + return nil, fmt.Errorf("failed to get group IDs from rule: %w", err) + } + + userGroupInfoMap := make(map[int64]*types.UserGroupInfo) + + if len(groupIDs) == 0 { + return userGroupInfoMap, nil + } + + groupMap, err := c.userGroupStore.Map(ctx, groupIDs) + if err != nil { + return nil, fmt.Errorf("failed to get userGroup infos: %w", err) + } + + for k, v := range groupMap { + userGroupInfoMap[k] = v.ToUserGroupInfo() + } + return userGroupInfoMap, nil +} diff --git a/app/api/controller/repo/rule_create.go b/app/api/controller/repo/rule_create.go index 8d9315287..277d52569 100644 --- a/app/api/controller/repo/rule_create.go +++ b/app/api/controller/repo/rule_create.go @@ -147,5 +147,10 @@ func (c *Controller) RuleCreate(ctx context.Context, return nil, err } + r.UserGroups, err = c.getRuleUserGroups(ctx, r) + if err != nil { + return nil, fmt.Errorf("failed to get rule user groups: %w", err) + } + return r, nil } diff --git a/app/api/controller/repo/rule_find.go b/app/api/controller/repo/rule_find.go index 53d394964..a80b13632 100644 --- a/app/api/controller/repo/rule_find.go +++ b/app/api/controller/repo/rule_find.go @@ -44,5 +44,10 @@ func (c *Controller) RuleFind(ctx context.Context, return nil, err } + r.UserGroups, err = c.getRuleUserGroups(ctx, r) + if err != nil { + return nil, fmt.Errorf("failed to get rule user groups: %w", err) + } + return r, nil } diff --git a/app/api/controller/repo/rule_update.go b/app/api/controller/repo/rule_update.go index b0c23e527..fb666782f 100644 --- a/app/api/controller/repo/rule_update.go +++ b/app/api/controller/repo/rule_update.go @@ -106,6 +106,10 @@ func (c *Controller) RuleUpdate(ctx context.Context, if err != nil { return nil, err } + r.UserGroups, err = c.getRuleUserGroups(ctx, r) + if err != nil { + return nil, fmt.Errorf("failed to get rule user groups: %w", err) + } return r, nil } @@ -130,7 +134,12 @@ func (c *Controller) RuleUpdate(ctx context.Context, r.Users, err = c.getRuleUsers(ctx, r) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get rule users: %w", err) + } + + r.UserGroups, err = c.getRuleUserGroups(ctx, r) + if err != nil { + return nil, fmt.Errorf("failed to get rule user groups: %w", err) } err = c.ruleStore.Update(ctx, r) diff --git a/app/api/controller/repo/wire.go b/app/api/controller/repo/wire.go index 8ef541df3..b4dd34d1e 100644 --- a/app/api/controller/repo/wire.go +++ b/app/api/controller/repo/wire.go @@ -71,13 +71,14 @@ func ProvideController( publicAccess publicaccess.Service, labelSvc *label.Service, instrumentation instrument.Service, + userGroupStore store.UserGroupStore, ) *Controller { return NewController(config, tx, urlProvider, authorizer, repoStore, spaceStore, pipelineStore, principalStore, ruleStore, settings, principalInfoCache, protectionManager, rpcClient, importer, codeOwners, reporeporter, indexer, limiter, locker, auditService, mtxManager, identifierCheck, - repoChecks, publicAccess, labelSvc, instrumentation) + repoChecks, publicAccess, labelSvc, instrumentation, userGroupStore) } func ProvideRepoCheck() Check { diff --git a/app/services/protection/bypass.go b/app/services/protection/bypass.go index cf30ed441..991920431 100644 --- a/app/services/protection/bypass.go +++ b/app/services/protection/bypass.go @@ -23,8 +23,9 @@ import ( ) type DefBypass struct { - UserIDs []int64 `json:"user_ids,omitempty"` - RepoOwners bool `json:"repo_owners,omitempty"` + UserIDs []int64 `json:"user_ids,omitempty"` + UserGroupIDs []int64 `json:"user_group_ids,omitempty"` + RepoOwners bool `json:"repo_owners,omitempty"` } func (v DefBypass) matches(actor *types.Principal, isRepoOwner bool) bool { diff --git a/app/services/protection/rule_branch.go b/app/services/protection/rule_branch.go index 36c65a4a7..6b0357ef4 100644 --- a/app/services/protection/rule_branch.go +++ b/app/services/protection/rule_branch.go @@ -109,6 +109,10 @@ func (v *Branch) UserIDs() ([]int64, error) { return v.Bypass.UserIDs, nil } +func (v *Branch) UserGroupIDs() ([]int64, error) { + return v.Bypass.UserGroupIDs, nil +} + func (v *Branch) Sanitize() error { if err := v.Bypass.Sanitize(); err != nil { return fmt.Errorf("bypass: %w", err) diff --git a/app/services/protection/service.go b/app/services/protection/service.go index 9f7a46f17..ad2151721 100644 --- a/app/services/protection/service.go +++ b/app/services/protection/service.go @@ -35,8 +35,8 @@ type ( Protection interface { MergeVerifier RefChangeVerifier - UserIDs() ([]int64, error) + UserGroupIDs() ([]int64, error) } Definition interface { diff --git a/app/services/protection/set.go b/app/services/protection/set.go index 1e0045a95..a43249870 100644 --- a/app/services/protection/set.go +++ b/app/services/protection/set.go @@ -156,6 +156,32 @@ func (s ruleSet) UserIDs() ([]int64, error) { return result, nil } +func (s ruleSet) UserGroupIDs() ([]int64, error) { + mapIDs := make(map[int64]struct{}) + err := s.forEachRule(func(_ *types.RuleInfoInternal, p Protection) error { + userGroupIDs, err := p.UserGroupIDs() + if err != nil { + return err + } + + for _, userGroupID := range userGroupIDs { + mapIDs[userGroupID] = struct{}{} + } + + return nil + }) + if err != nil { + return nil, err + } + + result := make([]int64, 0, len(mapIDs)) + for userGroupID := range mapIDs { + result = append(result, userGroupID) + } + + return result, nil +} + func (s ruleSet) forEachRule( fn func(r *types.RuleInfoInternal, p Protection) error, ) error { diff --git a/cmd/gitness/wire_gen.go b/cmd/gitness/wire_gen.go index b0b96f15d..508859590 100644 --- a/cmd/gitness/wire_gen.go +++ b/cmd/gitness/wire_gen.go @@ -243,7 +243,8 @@ func initSystem(ctx context.Context, config *types.Config) (*server.System, erro pullReqLabelAssignmentStore := database.ProvidePullReqLabelStore(db) labelService := label.ProvideLabel(transactor, spaceStore, labelStore, labelValueStore, pullReqLabelAssignmentStore) instrumentService := instrument.ProvideService() - repoController := repo.ProvideController(config, transactor, provider, authorizer, repoStore, spaceStore, pipelineStore, principalStore, ruleStore, settingsService, principalInfoCache, protectionManager, gitInterface, repository, codeownersService, reporter, indexer, resourceLimiter, lockerLocker, auditService, mutexManager, repoIdentifier, repoCheck, publicaccessService, labelService, instrumentService) + userGroupStore := database.ProvideUserGroupStore(db) + repoController := repo.ProvideController(config, transactor, provider, authorizer, repoStore, spaceStore, pipelineStore, principalStore, ruleStore, settingsService, principalInfoCache, protectionManager, gitInterface, repository, codeownersService, reporter, indexer, resourceLimiter, lockerLocker, auditService, mutexManager, repoIdentifier, repoCheck, publicaccessService, labelService, instrumentService, userGroupStore) reposettingsController := reposettings.ProvideController(authorizer, repoStore, settingsService, auditService) executionStore := database.ProvideExecutionStore(db) checkStore := database.ProvideCheckStore(db, principalInfoCache) @@ -310,7 +311,6 @@ func initSystem(ctx context.Context, config *types.Config) (*server.System, erro codeCommentView := database.ProvideCodeCommentView(db) pullReqReviewStore := database.ProvidePullReqReviewStore(db) pullReqReviewerStore := database.ProvidePullReqReviewerStore(db, principalInfoCache) - userGroupStore := database.ProvideUserGroupStore(db) userGroupReviewersStore := database.ProvideUserGroupReviewerStore(db, principalInfoCache, userGroupStore) pullReqFileViewStore := database.ProvidePullReqFileViewStore(db) reporter3, err := events5.ProvideReporter(eventsSystem) diff --git a/types/rule.go b/types/rule.go index b3545516b..9bd56e22d 100644 --- a/types/rule.go +++ b/types/rule.go @@ -45,7 +45,8 @@ type Rule struct { CreatedByInfo PrincipalInfo `json:"created_by"` - Users map[int64]*PrincipalInfo `json:"users"` + Users map[int64]*PrincipalInfo `json:"users"` + UserGroups map[int64]*UserGroupInfo `json:"user_groups"` } // TODO [CODE-1363]: remove after identifier migration.