mirror of https://github.com/jackc/pgx.git
Merge pull request #2136 from ninedraft/optimize-sanitize
Reduce SQL sanitizer allocationspull/2216/head
commit
ca04098fab
|
@ -1417,5 +1417,4 @@ func TestErrNoRows(t *testing.T) {
|
||||||
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
|
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
|
||||||
|
|
||||||
require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows")
|
require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows")
|
||||||
require.ErrorIs(t, pgx.ErrNoRows, pgx.ErrNoRows, "sql.ErrNowRows must match pgx.ErrNoRows")
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
current_branch=$(git rev-parse --abbrev-ref HEAD)
|
||||||
|
if [ "$current_branch" == "HEAD" ]; then
|
||||||
|
current_branch=$(git rev-parse HEAD)
|
||||||
|
fi
|
||||||
|
|
||||||
|
restore_branch() {
|
||||||
|
echo "Restoring original branch/commit: $current_branch"
|
||||||
|
git checkout "$current_branch"
|
||||||
|
}
|
||||||
|
trap restore_branch EXIT
|
||||||
|
|
||||||
|
# Check if there are uncommitted changes
|
||||||
|
if ! git diff --quiet || ! git diff --cached --quiet; then
|
||||||
|
echo "There are uncommitted changes. Please commit or stash them before running this script."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Ensure that at least one commit argument is passed
|
||||||
|
if [ "$#" -lt 1 ]; then
|
||||||
|
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
commits=("$@")
|
||||||
|
benchmarks_dir=benchmarks
|
||||||
|
|
||||||
|
if ! mkdir -p "${benchmarks_dir}"; then
|
||||||
|
echo "Unable to create dir for benchmarks data"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Benchmark results
|
||||||
|
bench_files=()
|
||||||
|
|
||||||
|
# Run benchmark for each listed commit
|
||||||
|
for i in "${!commits[@]}"; do
|
||||||
|
commit="${commits[i]}"
|
||||||
|
git checkout "$commit" || {
|
||||||
|
echo "Failed to checkout $commit"
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# Sanitized commmit message
|
||||||
|
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
|
||||||
|
|
||||||
|
# Benchmark data will go there
|
||||||
|
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
|
||||||
|
|
||||||
|
if ! go test -bench=. -count=10 >"$bench_file"; then
|
||||||
|
echo "Benchmarking failed for commit $commit"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
bench_files+=("$bench_file")
|
||||||
|
done
|
||||||
|
|
||||||
|
# go install golang.org/x/perf/cmd/benchstat[@latest]
|
||||||
|
benchstat "${bench_files[@]}"
|
|
@ -4,8 +4,10 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
@ -24,18 +26,33 @@ type Query struct {
|
||||||
// https://github.com/jackc/pgx/issues/1380
|
// https://github.com/jackc/pgx/issues/1380
|
||||||
const replacementcharacterwidth = 3
|
const replacementcharacterwidth = 3
|
||||||
|
|
||||||
|
const maxBufSize = 16384 // 16 Ki
|
||||||
|
|
||||||
|
var bufPool = &pool[*bytes.Buffer]{
|
||||||
|
new: func() *bytes.Buffer {
|
||||||
|
return &bytes.Buffer{}
|
||||||
|
},
|
||||||
|
reset: func(b *bytes.Buffer) bool {
|
||||||
|
n := b.Len()
|
||||||
|
b.Reset()
|
||||||
|
return n < maxBufSize
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var null = []byte("null")
|
||||||
|
|
||||||
func (q *Query) Sanitize(args ...any) (string, error) {
|
func (q *Query) Sanitize(args ...any) (string, error) {
|
||||||
argUse := make([]bool, len(args))
|
argUse := make([]bool, len(args))
|
||||||
buf := &bytes.Buffer{}
|
buf := bufPool.get()
|
||||||
|
defer bufPool.put(buf)
|
||||||
|
|
||||||
for _, part := range q.Parts {
|
for _, part := range q.Parts {
|
||||||
var str string
|
|
||||||
switch part := part.(type) {
|
switch part := part.(type) {
|
||||||
case string:
|
case string:
|
||||||
str = part
|
buf.WriteString(part)
|
||||||
case int:
|
case int:
|
||||||
argIdx := part - 1
|
argIdx := part - 1
|
||||||
|
var p []byte
|
||||||
if argIdx < 0 {
|
if argIdx < 0 {
|
||||||
return "", fmt.Errorf("first sql argument must be > 0")
|
return "", fmt.Errorf("first sql argument must be > 0")
|
||||||
}
|
}
|
||||||
|
@ -43,34 +60,41 @@ func (q *Query) Sanitize(args ...any) (string, error) {
|
||||||
if argIdx >= len(args) {
|
if argIdx >= len(args) {
|
||||||
return "", fmt.Errorf("insufficient arguments")
|
return "", fmt.Errorf("insufficient arguments")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Prevent SQL injection via Line Comment Creation
|
||||||
|
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
||||||
|
buf.WriteByte(' ')
|
||||||
|
|
||||||
arg := args[argIdx]
|
arg := args[argIdx]
|
||||||
switch arg := arg.(type) {
|
switch arg := arg.(type) {
|
||||||
case nil:
|
case nil:
|
||||||
str = "null"
|
p = null
|
||||||
case int64:
|
case int64:
|
||||||
str = strconv.FormatInt(arg, 10)
|
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
|
||||||
case float64:
|
case float64:
|
||||||
str = strconv.FormatFloat(arg, 'f', -1, 64)
|
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
|
||||||
case bool:
|
case bool:
|
||||||
str = strconv.FormatBool(arg)
|
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
|
||||||
case []byte:
|
case []byte:
|
||||||
str = QuoteBytes(arg)
|
p = QuoteBytes(buf.AvailableBuffer(), arg)
|
||||||
case string:
|
case string:
|
||||||
str = QuoteString(arg)
|
p = QuoteString(buf.AvailableBuffer(), arg)
|
||||||
case time.Time:
|
case time.Time:
|
||||||
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
|
p = arg.Truncate(time.Microsecond).
|
||||||
|
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("invalid arg type: %T", arg)
|
return "", fmt.Errorf("invalid arg type: %T", arg)
|
||||||
}
|
}
|
||||||
argUse[argIdx] = true
|
argUse[argIdx] = true
|
||||||
|
|
||||||
|
buf.Write(p)
|
||||||
|
|
||||||
// Prevent SQL injection via Line Comment Creation
|
// Prevent SQL injection via Line Comment Creation
|
||||||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
||||||
str = " " + str + " "
|
buf.WriteByte(' ')
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("invalid Part type: %T", part)
|
return "", fmt.Errorf("invalid Part type: %T", part)
|
||||||
}
|
}
|
||||||
buf.WriteString(str)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, used := range argUse {
|
for i, used := range argUse {
|
||||||
|
@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewQuery(sql string) (*Query, error) {
|
func NewQuery(sql string) (*Query, error) {
|
||||||
l := &sqlLexer{
|
query := &Query{}
|
||||||
src: sql,
|
query.init(sql)
|
||||||
stateFn: rawState,
|
|
||||||
|
return query, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var sqlLexerPool = &pool[*sqlLexer]{
|
||||||
|
new: func() *sqlLexer {
|
||||||
|
return &sqlLexer{}
|
||||||
|
},
|
||||||
|
reset: func(sl *sqlLexer) bool {
|
||||||
|
*sl = sqlLexer{}
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Query) init(sql string) {
|
||||||
|
parts := q.Parts[:0]
|
||||||
|
if parts == nil {
|
||||||
|
// dirty, but fast heuristic to preallocate for ~90% usecases
|
||||||
|
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
|
||||||
|
parts = make([]Part, 0, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
l := sqlLexerPool.get()
|
||||||
|
defer sqlLexerPool.put(l)
|
||||||
|
|
||||||
|
l.src = sql
|
||||||
|
l.stateFn = rawState
|
||||||
|
l.parts = parts
|
||||||
|
|
||||||
for l.stateFn != nil {
|
for l.stateFn != nil {
|
||||||
l.stateFn = l.stateFn(l)
|
l.stateFn = l.stateFn(l)
|
||||||
}
|
}
|
||||||
|
|
||||||
query := &Query{Parts: l.parts}
|
q.Parts = l.parts
|
||||||
|
|
||||||
return query, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func QuoteString(str string) string {
|
func QuoteString(dst []byte, str string) []byte {
|
||||||
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
|
const quote = '\''
|
||||||
|
|
||||||
|
// Preallocate space for the worst case scenario
|
||||||
|
dst = slices.Grow(dst, len(str)*2+2)
|
||||||
|
|
||||||
|
// Add opening quote
|
||||||
|
dst = append(dst, quote)
|
||||||
|
|
||||||
|
// Iterate through the string without allocating
|
||||||
|
for i := 0; i < len(str); i++ {
|
||||||
|
if str[i] == quote {
|
||||||
|
dst = append(dst, quote, quote)
|
||||||
|
} else {
|
||||||
|
dst = append(dst, str[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add closing quote
|
||||||
|
dst = append(dst, quote)
|
||||||
|
|
||||||
|
return dst
|
||||||
}
|
}
|
||||||
|
|
||||||
func QuoteBytes(buf []byte) string {
|
func QuoteBytes(dst, buf []byte) []byte {
|
||||||
return `'\x` + hex.EncodeToString(buf) + "'"
|
if len(buf) == 0 {
|
||||||
|
return append(dst, `'\x'`...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate required length
|
||||||
|
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
|
||||||
|
|
||||||
|
// Ensure dst has enough capacity
|
||||||
|
if cap(dst)-len(dst) < requiredLen {
|
||||||
|
newDst := make([]byte, len(dst), len(dst)+requiredLen)
|
||||||
|
copy(newDst, dst)
|
||||||
|
dst = newDst
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record original length and extend slice
|
||||||
|
origLen := len(dst)
|
||||||
|
dst = dst[:origLen+requiredLen]
|
||||||
|
|
||||||
|
// Add prefix
|
||||||
|
dst[origLen] = '\''
|
||||||
|
dst[origLen+1] = '\\'
|
||||||
|
dst[origLen+2] = 'x'
|
||||||
|
|
||||||
|
// Encode bytes directly into dst
|
||||||
|
hex.Encode(dst[origLen+3:len(dst)-1], buf)
|
||||||
|
|
||||||
|
// Add suffix
|
||||||
|
dst[len(dst)-1] = '\''
|
||||||
|
|
||||||
|
return dst
|
||||||
}
|
}
|
||||||
|
|
||||||
type sqlLexer struct {
|
type sqlLexer struct {
|
||||||
|
@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var queryPool = &pool[*Query]{
|
||||||
|
new: func() *Query {
|
||||||
|
return &Query{}
|
||||||
|
},
|
||||||
|
reset: func(q *Query) bool {
|
||||||
|
n := len(q.Parts)
|
||||||
|
q.Parts = q.Parts[:0]
|
||||||
|
return n < 64 // drop too large queries
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||||
// as necessary. This function is only safe when standard_conforming_strings is
|
// as necessary. This function is only safe when standard_conforming_strings is
|
||||||
// on.
|
// on.
|
||||||
func SanitizeSQL(sql string, args ...any) (string, error) {
|
func SanitizeSQL(sql string, args ...any) (string, error) {
|
||||||
query, err := NewQuery(sql)
|
query := queryPool.get()
|
||||||
if err != nil {
|
query.init(sql)
|
||||||
return "", err
|
defer queryPool.put(query)
|
||||||
}
|
|
||||||
return query.Sanitize(args...)
|
return query.Sanitize(args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type pool[E any] struct {
|
||||||
|
p sync.Pool
|
||||||
|
new func() E
|
||||||
|
reset func(E) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pool *pool[E]) get() E {
|
||||||
|
v, ok := pool.p.Get().(E)
|
||||||
|
if !ok {
|
||||||
|
v = pool.new()
|
||||||
|
}
|
||||||
|
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pool[E]) put(v E) {
|
||||||
|
if p.reset(v) {
|
||||||
|
p.p.Put(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
// sanitize_benchmark_test.go
|
||||||
|
package sanitize_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||||
|
)
|
||||||
|
|
||||||
|
var benchmarkSanitizeResult string
|
||||||
|
|
||||||
|
const benchmarkQuery = "" +
|
||||||
|
`SELECT *
|
||||||
|
FROM "water_containers"
|
||||||
|
WHERE NOT "id" = $1 -- int64
|
||||||
|
AND "tags" NOT IN $2 -- nil
|
||||||
|
AND "volume" > $3 -- float64
|
||||||
|
AND "transportable" = $4 -- bool
|
||||||
|
AND position($5 IN "sign") -- bytes
|
||||||
|
AND "label" LIKE $6 -- string
|
||||||
|
AND "created_at" > $7; -- time.Time`
|
||||||
|
|
||||||
|
var benchmarkArgs = []any{
|
||||||
|
int64(12345),
|
||||||
|
nil,
|
||||||
|
float64(500),
|
||||||
|
true,
|
||||||
|
[]byte("8BADF00D"),
|
||||||
|
"kombucha's han'dy awokowa",
|
||||||
|
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSanitize(b *testing.B) {
|
||||||
|
query, err := sanitize.NewQuery(benchmarkQuery)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to create query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to sanitize query: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var benchmarkNewSQLResult string
|
||||||
|
|
||||||
|
func BenchmarkSanitizeSQL(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
var err error
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to sanitize SQL: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,55 @@
|
||||||
|
package sanitize_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||||
|
)
|
||||||
|
|
||||||
|
func FuzzQuoteString(f *testing.F) {
|
||||||
|
const prefix = "prefix"
|
||||||
|
f.Add("new\nline")
|
||||||
|
f.Add("sample text")
|
||||||
|
f.Add("sample q'u'o't'e's")
|
||||||
|
f.Add("select 'quoted $42', $1")
|
||||||
|
|
||||||
|
f.Fuzz(func(t *testing.T, input string) {
|
||||||
|
got := string(sanitize.QuoteString([]byte(prefix), input))
|
||||||
|
want := oldQuoteString(input)
|
||||||
|
|
||||||
|
quoted, ok := strings.CutPrefix(got, prefix)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("result has no prefix")
|
||||||
|
}
|
||||||
|
|
||||||
|
if want != quoted {
|
||||||
|
t.Errorf("got %q", got)
|
||||||
|
t.Fatalf("want %q", want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func FuzzQuoteBytes(f *testing.F) {
|
||||||
|
const prefix = "prefix"
|
||||||
|
f.Add([]byte(nil))
|
||||||
|
f.Add([]byte("\n"))
|
||||||
|
f.Add([]byte("sample text"))
|
||||||
|
f.Add([]byte("sample q'u'o't'e's"))
|
||||||
|
f.Add([]byte("select 'quoted $42', $1"))
|
||||||
|
|
||||||
|
f.Fuzz(func(t *testing.T, input []byte) {
|
||||||
|
got := string(sanitize.QuoteBytes([]byte(prefix), input))
|
||||||
|
want := oldQuoteBytes(input)
|
||||||
|
|
||||||
|
quoted, ok := strings.CutPrefix(got, prefix)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("result has no prefix")
|
||||||
|
}
|
||||||
|
|
||||||
|
if want != quoted {
|
||||||
|
t.Errorf("got %q", got)
|
||||||
|
t.Fatalf("want %q", want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,6 +1,8 @@
|
||||||
package sanitize_test
|
package sanitize_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -227,3 +229,55 @@ func TestQuerySanitize(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQuoteString(t *testing.T) {
|
||||||
|
tc := func(name, input string) {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got := string(sanitize.QuoteString(nil, input))
|
||||||
|
want := oldQuoteString(input)
|
||||||
|
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("got: %s", got)
|
||||||
|
t.Fatalf("want: %s", want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
tc("empty", "")
|
||||||
|
tc("text", "abcd")
|
||||||
|
tc("with quotes", `one's hat is always a cat`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function was used before optimizations.
|
||||||
|
// You should keep for testing purposes - we want to ensure there are no breaking changes.
|
||||||
|
func oldQuoteString(str string) string {
|
||||||
|
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuoteBytes(t *testing.T) {
|
||||||
|
tc := func(name string, input []byte) {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got := string(sanitize.QuoteBytes(nil, input))
|
||||||
|
want := oldQuoteBytes(input)
|
||||||
|
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("got: %s", got)
|
||||||
|
t.Fatalf("want: %s", want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
tc("nil", nil)
|
||||||
|
tc("empty", []byte{})
|
||||||
|
tc("text", []byte("abcd"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function was used before optimizations.
|
||||||
|
// You should keep for testing purposes - we want to ensure there are no breaking changes.
|
||||||
|
func oldQuoteBytes(buf []byte) string {
|
||||||
|
return `'\x` + hex.EncodeToString(buf) + "'"
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue