diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 3414d6d1..91d6db58 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/hex" "fmt" + "slices" "strconv" "strings" "sync" @@ -78,7 +79,7 @@ func (q *Query) Sanitize(args ...any) (string, error) { case bool: p = strconv.AppendBool(buf.AvailableBuffer(), arg) case []byte: - p = []byte(QuoteBytes(arg)) + p = quoteBytes(buf.AvailableBuffer(), arg) case string: p = []byte(QuoteString(arg)) case time.Time: @@ -127,7 +128,19 @@ func QuoteString(str string) string { } func QuoteBytes(buf []byte) string { - return `'\x` + hex.EncodeToString(buf) + "'" + return string(quoteBytes(nil, buf)) +} + +func quoteBytes(dst, buf []byte) []byte { + dst = append(dst, `'\x`...) + + n := hex.EncodedLen(len(buf)) + p := slices.Grow(dst[len(dst):], n)[:n] + hex.Encode(p, buf) + dst = append(dst, p...) + + dst = append(dst, `'`...) + return dst } type sqlLexer struct { diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index 1deff3fb..76ae7a47 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -1,6 +1,7 @@ package sanitize_test import ( + "encoding/hex" "testing" "time" @@ -227,3 +228,27 @@ func TestQuerySanitize(t *testing.T) { } } } + +func TestQuoteBytes(t *testing.T) { + tc := func(name string, input []byte) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := sanitize.QuoteBytes(input) + want := oldQuoteBytes(input) + + if got != want { + t.Errorf("got: %s", got) + t.Fatalf("want: %s", want) + } + }) + } + + tc("nil", nil) + tc("empty", []byte{}) + tc("text", []byte("abcd")) +} + +func oldQuoteBytes(buf []byte) string { + return `'\x` + hex.EncodeToString(buf) + "'" +}