Add restricted by query label count when using query and sanitize empty label (value) text (#2335)

* Add restricted by query label count when using query and sanitize empty label (value) text
pull/3545/head
Darko Draskovic 2024-08-01 15:38:21 +00:00 committed by Harness
parent 9d5071b45c
commit f392fb1015
5 changed files with 52 additions and 22 deletions

View File

@ -214,7 +214,7 @@ func (s *Service) list(
filter *types.LabelFilter, filter *types.LabelFilter,
) ([]*types.Label, int64, error) { ) ([]*types.Label, int64, error) {
if repoID != nil { if repoID != nil {
total, err := s.labelStore.CountInRepo(ctx, *repoID) total, err := s.labelStore.CountInRepo(ctx, *repoID, filter)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -226,7 +226,7 @@ func (s *Service) list(
return labels, total, nil return labels, total, nil
} }
count, err := s.labelStore.CountInSpace(ctx, *spaceID) count, err := s.labelStore.CountInSpace(ctx, *spaceID, filter)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -258,7 +258,7 @@ func (s *Service) listInScopes(
} }
} }
total, err := s.labelStore.CountInScopes(ctx, repoIDVal, spaceIDs) total, err := s.labelStore.CountInScopes(ctx, repoIDVal, spaceIDs, filter)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }

View File

@ -181,7 +181,11 @@ func (s *Service) ListPullReqLabels(
return createScopeLabels(sortedAssignments, scopeLabelsMap), 0, nil return createScopeLabels(sortedAssignments, scopeLabelsMap), 0, nil
} }
total, err := s.labelStore.CountInScopes(ctx, repo.ID, spaceIDs) total, err := s.labelStore.CountInScopes(ctx, repo.ID, spaceIDs, &types.LabelFilter{
ListQueryFilter: types.ListQueryFilter{
Query: filter.Query,
},
})
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to count labels in scopes: %w", err) return nil, 0, fmt.Errorf("failed to count labels in scopes: %w", err)
} }

View File

@ -986,13 +986,18 @@ type (
IncrementValueCount(ctx context.Context, labelID int64, increment int) (int64, error) IncrementValueCount(ctx context.Context, labelID int64, increment int) (int64, error)
// CountInSpace counts the number of labels defined in a specified space. // CountInSpace counts the number of labels defined in a specified space.
CountInSpace(ctx context.Context, spaceID int64) (int64, error) CountInSpace(ctx context.Context, spaceID int64, filter *types.LabelFilter) (int64, error)
// CountInRepo counts the number of labels defined in a specified repository. // CountInRepo counts the number of labels defined in a specified repository.
CountInRepo(ctx context.Context, repoID int64) (int64, error) CountInRepo(ctx context.Context, repoID int64, filter *types.LabelFilter) (int64, error)
// CountInScopes counts the number of labels defined in specified repo/spaces. // CountInScopes counts the number of labels defined in specified repo/spaces.
CountInScopes(ctx context.Context, repoID int64, spaceIDs []int64) (int64, error) CountInScopes(
ctx context.Context,
repoID int64,
spaceIDs []int64,
filter *types.LabelFilter,
) (int64, error)
} }
LabelValueStore interface { LabelValueStore interface {

View File

@ -327,22 +327,38 @@ func (s *labelStore) ListInfosInScopes(
return mapLabelInfos(dst), nil return mapLabelInfos(dst), nil
} }
func (s *labelStore) CountInSpace(ctx context.Context, spaceID int64) (int64, error) { func (s *labelStore) CountInSpace(
ctx context.Context,
spaceID int64,
filter *types.LabelFilter,
) (int64, error) {
const sqlQuery = `SELECT COUNT(*) FROM labels WHERE label_space_id = $1` const sqlQuery = `SELECT COUNT(*) FROM labels WHERE label_space_id = $1`
return s.count(ctx, sqlQuery, spaceID) return s.count(ctx, sqlQuery, spaceID, filter)
} }
func (s *labelStore) CountInRepo(ctx context.Context, repoID int64) (int64, error) { func (s *labelStore) CountInRepo(
ctx context.Context,
repoID int64,
filter *types.LabelFilter,
) (int64, error) {
const sqlQuery = `SELECT COUNT(*) FROM labels WHERE label_repo_id = $1` const sqlQuery = `SELECT COUNT(*) FROM labels WHERE label_repo_id = $1`
return s.count(ctx, sqlQuery, repoID) return s.count(ctx, sqlQuery, repoID, filter)
} }
func (s labelStore) count(ctx context.Context, sqlQuery string, scopeID int64) (int64, error) { func (s labelStore) count(
ctx context.Context,
sqlQuery string,
scopeID int64,
filter *types.LabelFilter,
) (int64, error) {
sqlQuery += `
AND LOWER(label_key) LIKE '%' || LOWER($2) || '%'`
db := dbtx.GetAccessor(ctx, s.db) db := dbtx.GetAccessor(ctx, s.db)
var count int64 var count int64
if err := db.QueryRowContext(ctx, sqlQuery, scopeID).Scan(&count); err != nil { if err := db.QueryRowContext(ctx, sqlQuery, scopeID, filter.Query).Scan(&count); err != nil {
return 0, database.ProcessSQLErrorf(ctx, err, "Failed to count labels") return 0, database.ProcessSQLErrorf(ctx, err, "Failed to count labels")
} }
@ -353,13 +369,15 @@ func (s *labelStore) CountInScopes(
ctx context.Context, ctx context.Context,
repoID int64, repoID int64,
spaceIDs []int64, spaceIDs []int64,
filter *types.LabelFilter,
) (int64, error) { ) (int64, error) {
stmt := database.Builder.Select("COUNT(*)"). stmt := database.Builder.Select("COUNT(*)").
From("labels"). From("labels").
Where(squirrel.Or{ Where(squirrel.Or{
squirrel.Eq{"label_space_id": spaceIDs}, squirrel.Eq{"label_space_id": spaceIDs},
squirrel.Eq{"label_repo_id": repoID}, squirrel.Eq{"label_repo_id": repoID},
}) }).
Where("LOWER(label_key) LIKE '%' || LOWER(?) || '%'", filter.Query)
sql, args, err := stmt.ToSql() sql, args, err := stmt.ToSql()
if err != nil { if err != nil {

View File

@ -15,6 +15,7 @@
package types package types
import ( import (
"strings"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
@ -124,7 +125,7 @@ type DefineLabelInput struct {
} }
func (in DefineLabelInput) Validate() error { func (in DefineLabelInput) Validate() error {
if err := validateLabelText(in.Key, "key"); err != nil { if err := validateLabelText(&in.Key, "key"); err != nil {
return err return err
} }
@ -149,7 +150,7 @@ type UpdateLabelInput struct {
func (in UpdateLabelInput) Validate() error { func (in UpdateLabelInput) Validate() error {
if in.Key != nil { if in.Key != nil {
if err := validateLabelText(*in.Key, "key"); err != nil { if err := validateLabelText(in.Key, "key"); err != nil {
return err return err
} }
} }
@ -176,7 +177,7 @@ type DefineValueInput struct {
} }
func (in DefineValueInput) Validate() error { func (in DefineValueInput) Validate() error {
if err := validateLabelText(in.Value, "value"); err != nil { if err := validateLabelText(&in.Value, "value"); err != nil {
return err return err
} }
@ -194,7 +195,7 @@ type UpdateValueInput struct {
func (in UpdateValueInput) Validate() error { func (in UpdateValueInput) Validate() error {
if in.Value != nil { if in.Value != nil {
if err := validateLabelText(*in.Value, "value"); err != nil { if err := validateLabelText(in.Value, "value"); err != nil {
return err return err
} }
} }
@ -255,16 +256,18 @@ func (in *SaveInput) Validate() error {
var labelTypes, _ = enum.GetAllLabelTypes() var labelTypes, _ = enum.GetAllLabelTypes()
func validateLabelText(text string, typ string) error { func validateLabelText(text *string, typ string) error {
if len(text) == 0 { *text = strings.TrimSpace(*text)
if len(*text) == 0 {
return errors.InvalidArgument("%s must be a non-empty string", typ) return errors.InvalidArgument("%s must be a non-empty string", typ)
} }
if utf8.RuneCountInString(text) > maxLabelLength { if utf8.RuneCountInString(*text) > maxLabelLength {
return errors.InvalidArgument("%s can have at most %d characters", typ, maxLabelLength) return errors.InvalidArgument("%s can have at most %d characters", typ, maxLabelLength)
} }
for _, ch := range text { for _, ch := range *text {
if unicode.IsControl(ch) { if unicode.IsControl(ch) {
return errors.InvalidArgument("%s cannot contain control characters", typ) return errors.InvalidArgument("%s cannot contain control characters", typ)
} }