diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index c7c8acd5..1e0b20ac 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -47,16 +47,14 @@ func (q *Query) Sanitize(args ...any) (string, error) { argUse := make([]bool, len(args)) buf := getBuf() defer putBuf(buf) - var p []byte for _, part := range q.Parts { - p = p[:0] switch part := part.(type) { case string: buf.WriteString(part) case int: argIdx := part - 1 - + var p []byte if argIdx < 0 { return "", fmt.Errorf("first sql argument must be > 0") } @@ -64,22 +62,23 @@ func (q *Query) Sanitize(args ...any) (string, error) { if argIdx >= len(args) { return "", fmt.Errorf("insufficient arguments") } + buf.WriteByte(' ') arg := args[argIdx] switch arg := arg.(type) { case nil: p = null case int64: - p = strconv.AppendInt(p, arg, 10) + p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10) case float64: - p = strconv.AppendFloat(p, arg, 'f', -1, 64) + p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64) case bool: - p = strconv.AppendBool(p, arg) + p = strconv.AppendBool(buf.AvailableBuffer(), arg) case []byte: p = []byte(QuoteBytes(arg)) case string: p = []byte(QuoteString(arg)) case time.Time: - p = arg.Truncate(time.Microsecond).AppendFormat(p, "'2006-01-02 15:04:05.999999999Z07:00:00'") + p = arg.Truncate(time.Microsecond).AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") default: return "", fmt.Errorf("invalid arg type: %T", arg) } @@ -87,7 +86,6 @@ func (q *Query) Sanitize(args ...any) (string, error) { // Prevent SQL injection via Line Comment Creation // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p - buf.WriteByte(' ') buf.Write(p) buf.WriteByte(' ') default: