diff --git a/path.go b/path.go index b188a41c..67b4457f 100644 --- a/path.go +++ b/path.go @@ -672,12 +672,18 @@ func getParamConstraintType(constraintPart string) TypeConstraint { } } -//nolint:errcheck // TODO: Properly check _all_ errors in here, log them & immediately return +// CheckConstraint validates if a param matches the given constraint +// Returns true if the param passes the constraint check, false otherwise func (c *Constraint) CheckConstraint(param string) bool { - var err error - var num int + // First check if there's a custom constraint with the same name + // This allows custom constraints to override built-in constraints + for _, cc := range c.customConstraints { + if cc.Name() == c.Name { + return cc.Execute(param, c.Data...) + } + } - // check data exists + // Validate constraint has required data needOneData := []TypeConstraint{minLenConstraint, maxLenConstraint, lenConstraint, minConstraint, maxConstraint, datetimeConstraint, regexConstraint} needTwoData := []TypeConstraint{betweenLenConstraint, rangeConstraint} @@ -693,20 +699,23 @@ func (c *Constraint) CheckConstraint(param string) bool { } } - // check constraints + // Check constraints switch c.ID { case noConstraint: - for _, cc := range c.customConstraints { - if cc.Name() == c.Name { - return cc.Execute(param, c.Data...) - } - } + // If we reach here with noConstraint, it means we didn't find a matching custom constraint above + return false case intConstraint: - _, err = strconv.Atoi(param) + if _, err := strconv.Atoi(param); err != nil { + return false + } case boolConstraint: - _, err = strconv.ParseBool(param) + if _, err := strconv.ParseBool(param); err != nil { + return false + } case floatConstraint: - _, err = strconv.ParseFloat(param, 32) + if _, err := strconv.ParseFloat(param, 32); err != nil { + return false + } case alphaConstraint: for _, r := range param { if !unicode.IsLetter(r) { @@ -714,61 +723,98 @@ func (c *Constraint) CheckConstraint(param string) bool { } } case guidConstraint: - _, err = uuid.Parse(param) + if _, err := uuid.Parse(param); err != nil { + return false + } case minLenConstraint: - data, _ := strconv.Atoi(c.Data[0]) - + data, err := strconv.Atoi(c.Data[0]) + if err != nil { + return false + } if len(param) < data { return false } case maxLenConstraint: - data, _ := strconv.Atoi(c.Data[0]) - + data, err := strconv.Atoi(c.Data[0]) + if err != nil { + return false + } if len(param) > data { return false } case lenConstraint: - data, _ := strconv.Atoi(c.Data[0]) - + data, err := strconv.Atoi(c.Data[0]) + if err != nil { + return false + } if len(param) != data { return false } case betweenLenConstraint: - data, _ := strconv.Atoi(c.Data[0]) - data2, _ := strconv.Atoi(c.Data[1]) + data, err := strconv.Atoi(c.Data[0]) + if err != nil { + return false + } + data2, err := strconv.Atoi(c.Data[1]) + if err != nil { + return false + } length := len(param) if length < data || length > data2 { return false } case minConstraint: - data, _ := strconv.Atoi(c.Data[0]) - num, err = strconv.Atoi(param) - + data, err := strconv.Atoi(c.Data[0]) + if err != nil { + return false + } + num, err := strconv.Atoi(param) + if err != nil { + return false + } if num < data { return false } case maxConstraint: - data, _ := strconv.Atoi(c.Data[0]) - num, err = strconv.Atoi(param) - + data, err := strconv.Atoi(c.Data[0]) + if err != nil { + return false + } + num, err := strconv.Atoi(param) + if err != nil { + return false + } if num > data { return false } case rangeConstraint: - data, _ := strconv.Atoi(c.Data[0]) - data2, _ := strconv.Atoi(c.Data[1]) - num, err = strconv.Atoi(param) - + data, err := strconv.Atoi(c.Data[0]) + if err != nil { + return false + } + data2, err := strconv.Atoi(c.Data[1]) + if err != nil { + return false + } + num, err := strconv.Atoi(param) + if err != nil { + return false + } if num < data || num > data2 { return false } case datetimeConstraint: - _, err = time.Parse(c.Data[0], param) + if _, err := time.Parse(c.Data[0], param); err != nil { + return false + } case regexConstraint: + if c.RegexCompiler == nil { + return false + } if match := c.RegexCompiler.MatchString(param); !match { return false } } - return err == nil + return true }