From ea1e13a660aa60e9e34f67277ccdd77c9a74e535 Mon Sep 17 00:00:00 2001 From: merlin Date: Tue, 1 Oct 2024 14:50:59 +0300 Subject: [PATCH] quoteString --- internal/sanitize/sanitize.go | 31 ++++++++++++++++++++++++++++-- internal/sanitize/sanitize_test.go | 25 ++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 91d6db58..d83633a7 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -81,7 +81,7 @@ func (q *Query) Sanitize(args ...any) (string, error) { case []byte: p = quoteBytes(buf.AvailableBuffer(), arg) case string: - p = []byte(QuoteString(arg)) + p = quoteString(buf.AvailableBuffer(), arg) case time.Time: p = arg.Truncate(time.Microsecond). AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") @@ -124,7 +124,34 @@ func NewQuery(sql string) (*Query, error) { } func QuoteString(str string) string { - return "'" + strings.ReplaceAll(str, "'", "''") + "'" + return string(quoteString(nil, str)) +} + +func quoteString(dst []byte, str string) []byte { + const quote = "'" + + n := strings.Count(str, quote) + + dst = append(dst, quote...) + + p := slices.Grow(dst[len(dst):], len(str)+2*n) + + for len(str) > 0 { + i := strings.Index(str, quote) + if i < 0 { + p = append(p, str...) + break + } + p = append(p, str[:i]...) + p = append(p, "''"...) + str = str[i+1:] + } + + dst = append(dst, p...) + + dst = append(dst, quote...) + + return dst } func QuoteBytes(buf []byte) string { diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index 76ae7a47..aafcd682 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -2,6 +2,7 @@ package sanitize_test import ( "encoding/hex" + "strings" "testing" "time" @@ -229,6 +230,30 @@ func TestQuerySanitize(t *testing.T) { } } +func TestQuoteString(t *testing.T) { + tc := func(name, input string) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := sanitize.QuoteString(input) + want := oldQuoteString(input) + + if got != want { + t.Errorf("got: %s", got) + t.Fatalf("want: %s", want) + } + }) + } + + tc("empty", "") + tc("text", "abcd") + tc("with quotes", `one's hat is always a cat`) +} + +func oldQuoteString(str string) string { + return "'" + strings.ReplaceAll(str, "'", "''") + "'" +} + func TestQuoteBytes(t *testing.T) { tc := func(name string, input []byte) { t.Run(name, func(t *testing.T) {