Merge pull request #2136 from ninedraft/optimize-sanitize

Reduce SQL sanitizer allocations
pull/2216/head
Jack Christensen 2024-12-30 20:24:13 -06:00 committed by GitHub
commit ca04098fab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 387 additions and 28 deletions

View File

@ -1417,5 +1417,4 @@ func TestErrNoRows(t *testing.T) {
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error()) require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows") require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows")
require.ErrorIs(t, pgx.ErrNoRows, pgx.ErrNoRows, "sql.ErrNowRows must match pgx.ErrNoRows")
} }

View File

@ -0,0 +1,60 @@
#!/usr/bin/env bash
current_branch=$(git rev-parse --abbrev-ref HEAD)
if [ "$current_branch" == "HEAD" ]; then
current_branch=$(git rev-parse HEAD)
fi
restore_branch() {
echo "Restoring original branch/commit: $current_branch"
git checkout "$current_branch"
}
trap restore_branch EXIT
# Check if there are uncommitted changes
if ! git diff --quiet || ! git diff --cached --quiet; then
echo "There are uncommitted changes. Please commit or stash them before running this script."
exit 1
fi
# Ensure that at least one commit argument is passed
if [ "$#" -lt 1 ]; then
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
exit 1
fi
commits=("$@")
benchmarks_dir=benchmarks
if ! mkdir -p "${benchmarks_dir}"; then
echo "Unable to create dir for benchmarks data"
exit 1
fi
# Benchmark results
bench_files=()
# Run benchmark for each listed commit
for i in "${!commits[@]}"; do
commit="${commits[i]}"
git checkout "$commit" || {
echo "Failed to checkout $commit"
exit 1
}
# Sanitized commmit message
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
# Benchmark data will go there
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
if ! go test -bench=. -count=10 >"$bench_file"; then
echo "Benchmarking failed for commit $commit"
exit 1
fi
bench_files+=("$bench_file")
done
# go install golang.org/x/perf/cmd/benchstat[@latest]
benchstat "${bench_files[@]}"

View File

@ -4,8 +4,10 @@ import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"unicode/utf8" "unicode/utf8"
) )
@ -24,18 +26,33 @@ type Query struct {
// https://github.com/jackc/pgx/issues/1380 // https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3 const replacementcharacterwidth = 3
const maxBufSize = 16384 // 16 Ki
var bufPool = &pool[*bytes.Buffer]{
new: func() *bytes.Buffer {
return &bytes.Buffer{}
},
reset: func(b *bytes.Buffer) bool {
n := b.Len()
b.Reset()
return n < maxBufSize
},
}
var null = []byte("null")
func (q *Query) Sanitize(args ...any) (string, error) { func (q *Query) Sanitize(args ...any) (string, error) {
argUse := make([]bool, len(args)) argUse := make([]bool, len(args))
buf := &bytes.Buffer{} buf := bufPool.get()
defer bufPool.put(buf)
for _, part := range q.Parts { for _, part := range q.Parts {
var str string
switch part := part.(type) { switch part := part.(type) {
case string: case string:
str = part buf.WriteString(part)
case int: case int:
argIdx := part - 1 argIdx := part - 1
var p []byte
if argIdx < 0 { if argIdx < 0 {
return "", fmt.Errorf("first sql argument must be > 0") return "", fmt.Errorf("first sql argument must be > 0")
} }
@ -43,34 +60,41 @@ func (q *Query) Sanitize(args ...any) (string, error) {
if argIdx >= len(args) { if argIdx >= len(args) {
return "", fmt.Errorf("insufficient arguments") return "", fmt.Errorf("insufficient arguments")
} }
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
buf.WriteByte(' ')
arg := args[argIdx] arg := args[argIdx]
switch arg := arg.(type) { switch arg := arg.(type) {
case nil: case nil:
str = "null" p = null
case int64: case int64:
str = strconv.FormatInt(arg, 10) p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
case float64: case float64:
str = strconv.FormatFloat(arg, 'f', -1, 64) p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
case bool: case bool:
str = strconv.FormatBool(arg) p = strconv.AppendBool(buf.AvailableBuffer(), arg)
case []byte: case []byte:
str = QuoteBytes(arg) p = QuoteBytes(buf.AvailableBuffer(), arg)
case string: case string:
str = QuoteString(arg) p = QuoteString(buf.AvailableBuffer(), arg)
case time.Time: case time.Time:
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") p = arg.Truncate(time.Microsecond).
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
default: default:
return "", fmt.Errorf("invalid arg type: %T", arg) return "", fmt.Errorf("invalid arg type: %T", arg)
} }
argUse[argIdx] = true argUse[argIdx] = true
buf.Write(p)
// Prevent SQL injection via Line Comment Creation // Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
str = " " + str + " " buf.WriteByte(' ')
default: default:
return "", fmt.Errorf("invalid Part type: %T", part) return "", fmt.Errorf("invalid Part type: %T", part)
} }
buf.WriteString(str)
} }
for i, used := range argUse { for i, used := range argUse {
@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) {
} }
func NewQuery(sql string) (*Query, error) { func NewQuery(sql string) (*Query, error) {
l := &sqlLexer{ query := &Query{}
src: sql, query.init(sql)
stateFn: rawState,
return query, nil
}
var sqlLexerPool = &pool[*sqlLexer]{
new: func() *sqlLexer {
return &sqlLexer{}
},
reset: func(sl *sqlLexer) bool {
*sl = sqlLexer{}
return true
},
}
func (q *Query) init(sql string) {
parts := q.Parts[:0]
if parts == nil {
// dirty, but fast heuristic to preallocate for ~90% usecases
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
parts = make([]Part, 0, n)
} }
l := sqlLexerPool.get()
defer sqlLexerPool.put(l)
l.src = sql
l.stateFn = rawState
l.parts = parts
for l.stateFn != nil { for l.stateFn != nil {
l.stateFn = l.stateFn(l) l.stateFn = l.stateFn(l)
} }
query := &Query{Parts: l.parts} q.Parts = l.parts
return query, nil
} }
func QuoteString(str string) string { func QuoteString(dst []byte, str string) []byte {
return "'" + strings.ReplaceAll(str, "'", "''") + "'" const quote = '\''
// Preallocate space for the worst case scenario
dst = slices.Grow(dst, len(str)*2+2)
// Add opening quote
dst = append(dst, quote)
// Iterate through the string without allocating
for i := 0; i < len(str); i++ {
if str[i] == quote {
dst = append(dst, quote, quote)
} else {
dst = append(dst, str[i])
}
}
// Add closing quote
dst = append(dst, quote)
return dst
} }
func QuoteBytes(buf []byte) string { func QuoteBytes(dst, buf []byte) []byte {
return `'\x` + hex.EncodeToString(buf) + "'" if len(buf) == 0 {
return append(dst, `'\x'`...)
}
// Calculate required length
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
// Ensure dst has enough capacity
if cap(dst)-len(dst) < requiredLen {
newDst := make([]byte, len(dst), len(dst)+requiredLen)
copy(newDst, dst)
dst = newDst
}
// Record original length and extend slice
origLen := len(dst)
dst = dst[:origLen+requiredLen]
// Add prefix
dst[origLen] = '\''
dst[origLen+1] = '\\'
dst[origLen+2] = 'x'
// Encode bytes directly into dst
hex.Encode(dst[origLen+3:len(dst)-1], buf)
// Add suffix
dst[len(dst)-1] = '\''
return dst
} }
type sqlLexer struct { type sqlLexer struct {
@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn {
} }
} }
var queryPool = &pool[*Query]{
new: func() *Query {
return &Query{}
},
reset: func(q *Query) bool {
n := len(q.Parts)
q.Parts = q.Parts[:0]
return n < 64 // drop too large queries
},
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args // SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is // as necessary. This function is only safe when standard_conforming_strings is
// on. // on.
func SanitizeSQL(sql string, args ...any) (string, error) { func SanitizeSQL(sql string, args ...any) (string, error) {
query, err := NewQuery(sql) query := queryPool.get()
if err != nil { query.init(sql)
return "", err defer queryPool.put(query)
}
return query.Sanitize(args...) return query.Sanitize(args...)
} }
type pool[E any] struct {
p sync.Pool
new func() E
reset func(E) bool
}
func (pool *pool[E]) get() E {
v, ok := pool.p.Get().(E)
if !ok {
v = pool.new()
}
return v
}
func (p *pool[E]) put(v E) {
if p.reset(v) {
p.p.Put(v)
}
}

View File

@ -0,0 +1,62 @@
// sanitize_benchmark_test.go
package sanitize_test
import (
"testing"
"time"
"github.com/jackc/pgx/v5/internal/sanitize"
)
var benchmarkSanitizeResult string
const benchmarkQuery = "" +
`SELECT *
FROM "water_containers"
WHERE NOT "id" = $1 -- int64
AND "tags" NOT IN $2 -- nil
AND "volume" > $3 -- float64
AND "transportable" = $4 -- bool
AND position($5 IN "sign") -- bytes
AND "label" LIKE $6 -- string
AND "created_at" > $7; -- time.Time`
var benchmarkArgs = []any{
int64(12345),
nil,
float64(500),
true,
[]byte("8BADF00D"),
"kombucha's han'dy awokowa",
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
}
func BenchmarkSanitize(b *testing.B) {
query, err := sanitize.NewQuery(benchmarkQuery)
if err != nil {
b.Fatalf("failed to create query: %v", err)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize query: %v", err)
}
}
}
var benchmarkNewSQLResult string
func BenchmarkSanitizeSQL(b *testing.B) {
b.ReportAllocs()
var err error
for i := 0; i < b.N; i++ {
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize SQL: %v", err)
}
}
}

View File

@ -0,0 +1,55 @@
package sanitize_test
import (
"strings"
"testing"
"github.com/jackc/pgx/v5/internal/sanitize"
)
func FuzzQuoteString(f *testing.F) {
const prefix = "prefix"
f.Add("new\nline")
f.Add("sample text")
f.Add("sample q'u'o't'e's")
f.Add("select 'quoted $42', $1")
f.Fuzz(func(t *testing.T, input string) {
got := string(sanitize.QuoteString([]byte(prefix), input))
want := oldQuoteString(input)
quoted, ok := strings.CutPrefix(got, prefix)
if !ok {
t.Fatalf("result has no prefix")
}
if want != quoted {
t.Errorf("got %q", got)
t.Fatalf("want %q", want)
}
})
}
func FuzzQuoteBytes(f *testing.F) {
const prefix = "prefix"
f.Add([]byte(nil))
f.Add([]byte("\n"))
f.Add([]byte("sample text"))
f.Add([]byte("sample q'u'o't'e's"))
f.Add([]byte("select 'quoted $42', $1"))
f.Fuzz(func(t *testing.T, input []byte) {
got := string(sanitize.QuoteBytes([]byte(prefix), input))
want := oldQuoteBytes(input)
quoted, ok := strings.CutPrefix(got, prefix)
if !ok {
t.Fatalf("result has no prefix")
}
if want != quoted {
t.Errorf("got %q", got)
t.Fatalf("want %q", want)
}
})
}

View File

@ -1,6 +1,8 @@
package sanitize_test package sanitize_test
import ( import (
"encoding/hex"
"strings"
"testing" "testing"
"time" "time"
@ -227,3 +229,55 @@ func TestQuerySanitize(t *testing.T) {
} }
} }
} }
func TestQuoteString(t *testing.T) {
tc := func(name, input string) {
t.Run(name, func(t *testing.T) {
t.Parallel()
got := string(sanitize.QuoteString(nil, input))
want := oldQuoteString(input)
if got != want {
t.Errorf("got: %s", got)
t.Fatalf("want: %s", want)
}
})
}
tc("empty", "")
tc("text", "abcd")
tc("with quotes", `one's hat is always a cat`)
}
// This function was used before optimizations.
// You should keep for testing purposes - we want to ensure there are no breaking changes.
func oldQuoteString(str string) string {
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
}
func TestQuoteBytes(t *testing.T) {
tc := func(name string, input []byte) {
t.Run(name, func(t *testing.T) {
t.Parallel()
got := string(sanitize.QuoteBytes(nil, input))
want := oldQuoteBytes(input)
if got != want {
t.Errorf("got: %s", got)
t.Fatalf("want: %s", want)
}
})
}
tc("nil", nil)
tc("empty", []byte{})
tc("text", []byte("abcd"))
}
// This function was used before optimizations.
// You should keep for testing purposes - we want to ensure there are no breaking changes.
func oldQuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
}