mirror of https://github.com/jackc/pgx.git
Merge pull request #2136 from ninedraft/optimize-sanitize
Reduce SQL sanitizer allocationspull/2216/head
commit
ca04098fab
|
@ -1417,5 +1417,4 @@ func TestErrNoRows(t *testing.T) {
|
|||
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
|
||||
|
||||
require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows")
|
||||
require.ErrorIs(t, pgx.ErrNoRows, pgx.ErrNoRows, "sql.ErrNowRows must match pgx.ErrNoRows")
|
||||
}
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
current_branch=$(git rev-parse --abbrev-ref HEAD)
|
||||
if [ "$current_branch" == "HEAD" ]; then
|
||||
current_branch=$(git rev-parse HEAD)
|
||||
fi
|
||||
|
||||
restore_branch() {
|
||||
echo "Restoring original branch/commit: $current_branch"
|
||||
git checkout "$current_branch"
|
||||
}
|
||||
trap restore_branch EXIT
|
||||
|
||||
# Check if there are uncommitted changes
|
||||
if ! git diff --quiet || ! git diff --cached --quiet; then
|
||||
echo "There are uncommitted changes. Please commit or stash them before running this script."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Ensure that at least one commit argument is passed
|
||||
if [ "$#" -lt 1 ]; then
|
||||
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
commits=("$@")
|
||||
benchmarks_dir=benchmarks
|
||||
|
||||
if ! mkdir -p "${benchmarks_dir}"; then
|
||||
echo "Unable to create dir for benchmarks data"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Benchmark results
|
||||
bench_files=()
|
||||
|
||||
# Run benchmark for each listed commit
|
||||
for i in "${!commits[@]}"; do
|
||||
commit="${commits[i]}"
|
||||
git checkout "$commit" || {
|
||||
echo "Failed to checkout $commit"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Sanitized commmit message
|
||||
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
|
||||
|
||||
# Benchmark data will go there
|
||||
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
|
||||
|
||||
if ! go test -bench=. -count=10 >"$bench_file"; then
|
||||
echo "Benchmarking failed for commit $commit"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
bench_files+=("$bench_file")
|
||||
done
|
||||
|
||||
# go install golang.org/x/perf/cmd/benchstat[@latest]
|
||||
benchstat "${bench_files[@]}"
|
|
@ -4,8 +4,10 @@ import (
|
|||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
@ -24,18 +26,33 @@ type Query struct {
|
|||
// https://github.com/jackc/pgx/issues/1380
|
||||
const replacementcharacterwidth = 3
|
||||
|
||||
const maxBufSize = 16384 // 16 Ki
|
||||
|
||||
var bufPool = &pool[*bytes.Buffer]{
|
||||
new: func() *bytes.Buffer {
|
||||
return &bytes.Buffer{}
|
||||
},
|
||||
reset: func(b *bytes.Buffer) bool {
|
||||
n := b.Len()
|
||||
b.Reset()
|
||||
return n < maxBufSize
|
||||
},
|
||||
}
|
||||
|
||||
var null = []byte("null")
|
||||
|
||||
func (q *Query) Sanitize(args ...any) (string, error) {
|
||||
argUse := make([]bool, len(args))
|
||||
buf := &bytes.Buffer{}
|
||||
buf := bufPool.get()
|
||||
defer bufPool.put(buf)
|
||||
|
||||
for _, part := range q.Parts {
|
||||
var str string
|
||||
switch part := part.(type) {
|
||||
case string:
|
||||
str = part
|
||||
buf.WriteString(part)
|
||||
case int:
|
||||
argIdx := part - 1
|
||||
|
||||
var p []byte
|
||||
if argIdx < 0 {
|
||||
return "", fmt.Errorf("first sql argument must be > 0")
|
||||
}
|
||||
|
@ -43,34 +60,41 @@ func (q *Query) Sanitize(args ...any) (string, error) {
|
|||
if argIdx >= len(args) {
|
||||
return "", fmt.Errorf("insufficient arguments")
|
||||
}
|
||||
|
||||
// Prevent SQL injection via Line Comment Creation
|
||||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
||||
buf.WriteByte(' ')
|
||||
|
||||
arg := args[argIdx]
|
||||
switch arg := arg.(type) {
|
||||
case nil:
|
||||
str = "null"
|
||||
p = null
|
||||
case int64:
|
||||
str = strconv.FormatInt(arg, 10)
|
||||
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
|
||||
case float64:
|
||||
str = strconv.FormatFloat(arg, 'f', -1, 64)
|
||||
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
|
||||
case bool:
|
||||
str = strconv.FormatBool(arg)
|
||||
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
|
||||
case []byte:
|
||||
str = QuoteBytes(arg)
|
||||
p = QuoteBytes(buf.AvailableBuffer(), arg)
|
||||
case string:
|
||||
str = QuoteString(arg)
|
||||
p = QuoteString(buf.AvailableBuffer(), arg)
|
||||
case time.Time:
|
||||
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
|
||||
p = arg.Truncate(time.Microsecond).
|
||||
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
|
||||
default:
|
||||
return "", fmt.Errorf("invalid arg type: %T", arg)
|
||||
}
|
||||
argUse[argIdx] = true
|
||||
|
||||
buf.Write(p)
|
||||
|
||||
// Prevent SQL injection via Line Comment Creation
|
||||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
||||
str = " " + str + " "
|
||||
buf.WriteByte(' ')
|
||||
default:
|
||||
return "", fmt.Errorf("invalid Part type: %T", part)
|
||||
}
|
||||
buf.WriteString(str)
|
||||
}
|
||||
|
||||
for i, used := range argUse {
|
||||
|
@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) {
|
|||
}
|
||||
|
||||
func NewQuery(sql string) (*Query, error) {
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
query := &Query{}
|
||||
query.init(sql)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
var sqlLexerPool = &pool[*sqlLexer]{
|
||||
new: func() *sqlLexer {
|
||||
return &sqlLexer{}
|
||||
},
|
||||
reset: func(sl *sqlLexer) bool {
|
||||
*sl = sqlLexer{}
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
func (q *Query) init(sql string) {
|
||||
parts := q.Parts[:0]
|
||||
if parts == nil {
|
||||
// dirty, but fast heuristic to preallocate for ~90% usecases
|
||||
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
|
||||
parts = make([]Part, 0, n)
|
||||
}
|
||||
|
||||
l := sqlLexerPool.get()
|
||||
defer sqlLexerPool.put(l)
|
||||
|
||||
l.src = sql
|
||||
l.stateFn = rawState
|
||||
l.parts = parts
|
||||
|
||||
for l.stateFn != nil {
|
||||
l.stateFn = l.stateFn(l)
|
||||
}
|
||||
|
||||
query := &Query{Parts: l.parts}
|
||||
|
||||
return query, nil
|
||||
q.Parts = l.parts
|
||||
}
|
||||
|
||||
func QuoteString(str string) string {
|
||||
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
|
||||
func QuoteString(dst []byte, str string) []byte {
|
||||
const quote = '\''
|
||||
|
||||
// Preallocate space for the worst case scenario
|
||||
dst = slices.Grow(dst, len(str)*2+2)
|
||||
|
||||
// Add opening quote
|
||||
dst = append(dst, quote)
|
||||
|
||||
// Iterate through the string without allocating
|
||||
for i := 0; i < len(str); i++ {
|
||||
if str[i] == quote {
|
||||
dst = append(dst, quote, quote)
|
||||
} else {
|
||||
dst = append(dst, str[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Add closing quote
|
||||
dst = append(dst, quote)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func QuoteBytes(buf []byte) string {
|
||||
return `'\x` + hex.EncodeToString(buf) + "'"
|
||||
func QuoteBytes(dst, buf []byte) []byte {
|
||||
if len(buf) == 0 {
|
||||
return append(dst, `'\x'`...)
|
||||
}
|
||||
|
||||
// Calculate required length
|
||||
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
|
||||
|
||||
// Ensure dst has enough capacity
|
||||
if cap(dst)-len(dst) < requiredLen {
|
||||
newDst := make([]byte, len(dst), len(dst)+requiredLen)
|
||||
copy(newDst, dst)
|
||||
dst = newDst
|
||||
}
|
||||
|
||||
// Record original length and extend slice
|
||||
origLen := len(dst)
|
||||
dst = dst[:origLen+requiredLen]
|
||||
|
||||
// Add prefix
|
||||
dst[origLen] = '\''
|
||||
dst[origLen+1] = '\\'
|
||||
dst[origLen+2] = 'x'
|
||||
|
||||
// Encode bytes directly into dst
|
||||
hex.Encode(dst[origLen+3:len(dst)-1], buf)
|
||||
|
||||
// Add suffix
|
||||
dst[len(dst)-1] = '\''
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
type sqlLexer struct {
|
||||
|
@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn {
|
|||
}
|
||||
}
|
||||
|
||||
var queryPool = &pool[*Query]{
|
||||
new: func() *Query {
|
||||
return &Query{}
|
||||
},
|
||||
reset: func(q *Query) bool {
|
||||
n := len(q.Parts)
|
||||
q.Parts = q.Parts[:0]
|
||||
return n < 64 // drop too large queries
|
||||
},
|
||||
}
|
||||
|
||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||
// as necessary. This function is only safe when standard_conforming_strings is
|
||||
// on.
|
||||
func SanitizeSQL(sql string, args ...any) (string, error) {
|
||||
query, err := NewQuery(sql)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
query := queryPool.get()
|
||||
query.init(sql)
|
||||
defer queryPool.put(query)
|
||||
|
||||
return query.Sanitize(args...)
|
||||
}
|
||||
|
||||
type pool[E any] struct {
|
||||
p sync.Pool
|
||||
new func() E
|
||||
reset func(E) bool
|
||||
}
|
||||
|
||||
func (pool *pool[E]) get() E {
|
||||
v, ok := pool.p.Get().(E)
|
||||
if !ok {
|
||||
v = pool.new()
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func (p *pool[E]) put(v E) {
|
||||
if p.reset(v) {
|
||||
p.p.Put(v)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
// sanitize_benchmark_test.go
|
||||
package sanitize_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||
)
|
||||
|
||||
var benchmarkSanitizeResult string
|
||||
|
||||
const benchmarkQuery = "" +
|
||||
`SELECT *
|
||||
FROM "water_containers"
|
||||
WHERE NOT "id" = $1 -- int64
|
||||
AND "tags" NOT IN $2 -- nil
|
||||
AND "volume" > $3 -- float64
|
||||
AND "transportable" = $4 -- bool
|
||||
AND position($5 IN "sign") -- bytes
|
||||
AND "label" LIKE $6 -- string
|
||||
AND "created_at" > $7; -- time.Time`
|
||||
|
||||
var benchmarkArgs = []any{
|
||||
int64(12345),
|
||||
nil,
|
||||
float64(500),
|
||||
true,
|
||||
[]byte("8BADF00D"),
|
||||
"kombucha's han'dy awokowa",
|
||||
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
func BenchmarkSanitize(b *testing.B) {
|
||||
query, err := sanitize.NewQuery(benchmarkQuery)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to create query: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to sanitize query: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var benchmarkNewSQLResult string
|
||||
|
||||
func BenchmarkSanitizeSQL(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
var err error
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to sanitize SQL: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
package sanitize_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||
)
|
||||
|
||||
func FuzzQuoteString(f *testing.F) {
|
||||
const prefix = "prefix"
|
||||
f.Add("new\nline")
|
||||
f.Add("sample text")
|
||||
f.Add("sample q'u'o't'e's")
|
||||
f.Add("select 'quoted $42', $1")
|
||||
|
||||
f.Fuzz(func(t *testing.T, input string) {
|
||||
got := string(sanitize.QuoteString([]byte(prefix), input))
|
||||
want := oldQuoteString(input)
|
||||
|
||||
quoted, ok := strings.CutPrefix(got, prefix)
|
||||
if !ok {
|
||||
t.Fatalf("result has no prefix")
|
||||
}
|
||||
|
||||
if want != quoted {
|
||||
t.Errorf("got %q", got)
|
||||
t.Fatalf("want %q", want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzQuoteBytes(f *testing.F) {
|
||||
const prefix = "prefix"
|
||||
f.Add([]byte(nil))
|
||||
f.Add([]byte("\n"))
|
||||
f.Add([]byte("sample text"))
|
||||
f.Add([]byte("sample q'u'o't'e's"))
|
||||
f.Add([]byte("select 'quoted $42', $1"))
|
||||
|
||||
f.Fuzz(func(t *testing.T, input []byte) {
|
||||
got := string(sanitize.QuoteBytes([]byte(prefix), input))
|
||||
want := oldQuoteBytes(input)
|
||||
|
||||
quoted, ok := strings.CutPrefix(got, prefix)
|
||||
if !ok {
|
||||
t.Fatalf("result has no prefix")
|
||||
}
|
||||
|
||||
if want != quoted {
|
||||
t.Errorf("got %q", got)
|
||||
t.Fatalf("want %q", want)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -1,6 +1,8 @@
|
|||
package sanitize_test
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -227,3 +229,55 @@ func TestQuerySanitize(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteString(t *testing.T) {
|
||||
tc := func(name, input string) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := string(sanitize.QuoteString(nil, input))
|
||||
want := oldQuoteString(input)
|
||||
|
||||
if got != want {
|
||||
t.Errorf("got: %s", got)
|
||||
t.Fatalf("want: %s", want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
tc("empty", "")
|
||||
tc("text", "abcd")
|
||||
tc("with quotes", `one's hat is always a cat`)
|
||||
}
|
||||
|
||||
// This function was used before optimizations.
|
||||
// You should keep for testing purposes - we want to ensure there are no breaking changes.
|
||||
func oldQuoteString(str string) string {
|
||||
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
|
||||
}
|
||||
|
||||
func TestQuoteBytes(t *testing.T) {
|
||||
tc := func(name string, input []byte) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := string(sanitize.QuoteBytes(nil, input))
|
||||
want := oldQuoteBytes(input)
|
||||
|
||||
if got != want {
|
||||
t.Errorf("got: %s", got)
|
||||
t.Fatalf("want: %s", want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
tc("nil", nil)
|
||||
tc("empty", []byte{})
|
||||
tc("text", []byte("abcd"))
|
||||
}
|
||||
|
||||
// This function was used before optimizations.
|
||||
// You should keep for testing purposes - we want to ensure there are no breaking changes.
|
||||
func oldQuoteBytes(buf []byte) string {
|
||||
return `'\x` + hex.EncodeToString(buf) + "'"
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue