mirror of https://github.com/jackc/pgx.git
443 lines
8.8 KiB
Go
443 lines
8.8 KiB
Go
package sanitize
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
"unicode/utf8"
|
|
)
|
|
|
|
// Part is either a string or an int. A string is raw SQL. An int is a
|
|
// argument placeholder.
|
|
type Part any
|
|
|
|
type Query struct {
|
|
Parts []Part
|
|
}
|
|
|
|
// utf.DecodeRune returns the utf8.RuneError for errors. But that is actually rune U+FFFD -- the unicode replacement
|
|
// character. utf8.RuneError is not an error if it is also width 3.
|
|
//
|
|
// https://github.com/jackc/pgx/issues/1380
|
|
const replacementcharacterwidth = 3
|
|
|
|
const maxBufSize = 16384 // 16 Ki
|
|
|
|
var bufPool = &pool[*bytes.Buffer]{
|
|
new: func() *bytes.Buffer {
|
|
return &bytes.Buffer{}
|
|
},
|
|
reset: func(b *bytes.Buffer) bool {
|
|
n := b.Len()
|
|
b.Reset()
|
|
return n < maxBufSize
|
|
},
|
|
}
|
|
|
|
var null = []byte("null")
|
|
|
|
func (q *Query) Sanitize(args ...any) (string, error) {
|
|
argUse := make([]bool, len(args))
|
|
buf := bufPool.get()
|
|
defer bufPool.put(buf)
|
|
|
|
for _, part := range q.Parts {
|
|
switch part := part.(type) {
|
|
case string:
|
|
buf.WriteString(part)
|
|
case int:
|
|
argIdx := part - 1
|
|
var p []byte
|
|
if argIdx < 0 {
|
|
return "", fmt.Errorf("first sql argument must be > 0")
|
|
}
|
|
|
|
if argIdx >= len(args) {
|
|
return "", fmt.Errorf("insufficient arguments")
|
|
}
|
|
|
|
// Prevent SQL injection via Line Comment Creation
|
|
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
|
buf.WriteByte(' ')
|
|
|
|
arg := args[argIdx]
|
|
switch arg := arg.(type) {
|
|
case nil:
|
|
p = null
|
|
case int64:
|
|
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
|
|
case float64:
|
|
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
|
|
case bool:
|
|
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
|
|
case []byte:
|
|
p = QuoteBytes(buf.AvailableBuffer(), arg)
|
|
case string:
|
|
p = QuoteString(buf.AvailableBuffer(), arg)
|
|
case time.Time:
|
|
p = arg.Truncate(time.Microsecond).
|
|
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
|
|
default:
|
|
return "", fmt.Errorf("invalid arg type: %T", arg)
|
|
}
|
|
argUse[argIdx] = true
|
|
|
|
buf.Write(p)
|
|
|
|
// Prevent SQL injection via Line Comment Creation
|
|
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
|
buf.WriteByte(' ')
|
|
default:
|
|
return "", fmt.Errorf("invalid Part type: %T", part)
|
|
}
|
|
}
|
|
|
|
for i, used := range argUse {
|
|
if !used {
|
|
return "", fmt.Errorf("unused argument: %d", i)
|
|
}
|
|
}
|
|
return buf.String(), nil
|
|
}
|
|
|
|
func NewQuery(sql string) (*Query, error) {
|
|
query := &Query{}
|
|
query.init(sql)
|
|
|
|
return query, nil
|
|
}
|
|
|
|
var sqlLexerPool = &pool[*sqlLexer]{
|
|
new: func() *sqlLexer {
|
|
return &sqlLexer{}
|
|
},
|
|
reset: func(sl *sqlLexer) bool {
|
|
*sl = sqlLexer{}
|
|
return true
|
|
},
|
|
}
|
|
|
|
func (q *Query) init(sql string) {
|
|
parts := q.Parts[:0]
|
|
if parts == nil {
|
|
// dirty, but fast heuristic to preallocate for ~90% usecases
|
|
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
|
|
parts = make([]Part, 0, n)
|
|
}
|
|
|
|
l := sqlLexerPool.get()
|
|
defer sqlLexerPool.put(l)
|
|
|
|
l.src = sql
|
|
l.stateFn = rawState
|
|
l.parts = parts
|
|
|
|
for l.stateFn != nil {
|
|
l.stateFn = l.stateFn(l)
|
|
}
|
|
|
|
q.Parts = l.parts
|
|
}
|
|
|
|
func QuoteString(dst []byte, str string) []byte {
|
|
const quote = "'"
|
|
|
|
n := strings.Count(str, quote)
|
|
|
|
dst = append(dst, quote...)
|
|
|
|
p := slices.Grow(dst[len(dst):], 2*len(quote)+len(str)+2*n)
|
|
|
|
for len(str) > 0 {
|
|
i := strings.Index(str, quote)
|
|
if i < 0 {
|
|
p = append(p, str...)
|
|
break
|
|
}
|
|
p = append(p, str[:i]...)
|
|
p = append(p, "''"...)
|
|
str = str[i+1:]
|
|
}
|
|
|
|
dst = append(dst, p...)
|
|
|
|
dst = append(dst, quote...)
|
|
|
|
return dst
|
|
}
|
|
|
|
func QuoteBytes(dst, buf []byte) []byte {
|
|
dst = append(dst, `'\x`...)
|
|
|
|
n := hex.EncodedLen(len(buf))
|
|
p := slices.Grow(dst[len(dst):], n)[:n]
|
|
hex.Encode(p, buf)
|
|
dst = append(dst, p...)
|
|
|
|
dst = append(dst, `'`...)
|
|
return dst
|
|
}
|
|
|
|
type sqlLexer struct {
|
|
src string
|
|
start int
|
|
pos int
|
|
nested int // multiline comment nesting level.
|
|
stateFn stateFn
|
|
parts []Part
|
|
}
|
|
|
|
type stateFn func(*sqlLexer) stateFn
|
|
|
|
func rawState(l *sqlLexer) stateFn {
|
|
for {
|
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
|
|
switch r {
|
|
case 'e', 'E':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune == '\'' {
|
|
l.pos += width
|
|
return escapeStringState
|
|
}
|
|
case '\'':
|
|
return singleQuoteState
|
|
case '"':
|
|
return doubleQuoteState
|
|
case '$':
|
|
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if '0' <= nextRune && nextRune <= '9' {
|
|
if l.pos-l.start > 0 {
|
|
l.parts = append(l.parts, l.src[l.start:l.pos-width])
|
|
}
|
|
l.start = l.pos
|
|
return placeholderState
|
|
}
|
|
case '-':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune == '-' {
|
|
l.pos += width
|
|
return oneLineCommentState
|
|
}
|
|
case '/':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune == '*' {
|
|
l.pos += width
|
|
return multilineCommentState
|
|
}
|
|
case utf8.RuneError:
|
|
if width != replacementcharacterwidth {
|
|
if l.pos-l.start > 0 {
|
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
|
l.start = l.pos
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func singleQuoteState(l *sqlLexer) stateFn {
|
|
for {
|
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
|
|
switch r {
|
|
case '\'':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune != '\'' {
|
|
return rawState
|
|
}
|
|
l.pos += width
|
|
case utf8.RuneError:
|
|
if width != replacementcharacterwidth {
|
|
if l.pos-l.start > 0 {
|
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
|
l.start = l.pos
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func doubleQuoteState(l *sqlLexer) stateFn {
|
|
for {
|
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
|
|
switch r {
|
|
case '"':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune != '"' {
|
|
return rawState
|
|
}
|
|
l.pos += width
|
|
case utf8.RuneError:
|
|
if width != replacementcharacterwidth {
|
|
if l.pos-l.start > 0 {
|
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
|
l.start = l.pos
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// placeholderState consumes a placeholder value. The $ must have already has
|
|
// already been consumed. The first rune must be a digit.
|
|
func placeholderState(l *sqlLexer) stateFn {
|
|
num := 0
|
|
|
|
for {
|
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
|
|
if '0' <= r && r <= '9' {
|
|
num *= 10
|
|
num += int(r - '0')
|
|
} else {
|
|
l.parts = append(l.parts, num)
|
|
l.pos -= width
|
|
l.start = l.pos
|
|
return rawState
|
|
}
|
|
}
|
|
}
|
|
|
|
func escapeStringState(l *sqlLexer) stateFn {
|
|
for {
|
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
|
|
switch r {
|
|
case '\\':
|
|
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
case '\'':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune != '\'' {
|
|
return rawState
|
|
}
|
|
l.pos += width
|
|
case utf8.RuneError:
|
|
if width != replacementcharacterwidth {
|
|
if l.pos-l.start > 0 {
|
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
|
l.start = l.pos
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func oneLineCommentState(l *sqlLexer) stateFn {
|
|
for {
|
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
|
|
switch r {
|
|
case '\\':
|
|
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
case '\n', '\r':
|
|
return rawState
|
|
case utf8.RuneError:
|
|
if width != replacementcharacterwidth {
|
|
if l.pos-l.start > 0 {
|
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
|
l.start = l.pos
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func multilineCommentState(l *sqlLexer) stateFn {
|
|
for {
|
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
|
|
switch r {
|
|
case '/':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune == '*' {
|
|
l.pos += width
|
|
l.nested++
|
|
}
|
|
case '*':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune != '/' {
|
|
continue
|
|
}
|
|
|
|
l.pos += width
|
|
if l.nested == 0 {
|
|
return rawState
|
|
}
|
|
l.nested--
|
|
|
|
case utf8.RuneError:
|
|
if width != replacementcharacterwidth {
|
|
if l.pos-l.start > 0 {
|
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
|
l.start = l.pos
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
var queryPool = &pool[*Query]{
|
|
new: func() *Query {
|
|
return &Query{}
|
|
},
|
|
reset: func(q *Query) bool {
|
|
n := len(q.Parts)
|
|
q.Parts = q.Parts[:0]
|
|
return n < 64 // drop too large queries
|
|
},
|
|
}
|
|
|
|
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
|
// as necessary. This function is only safe when standard_conforming_strings is
|
|
// on.
|
|
func SanitizeSQL(sql string, args ...any) (string, error) {
|
|
query := queryPool.get()
|
|
query.init(sql)
|
|
defer queryPool.put(query)
|
|
|
|
return query.Sanitize(args...)
|
|
}
|
|
|
|
type pool[E any] struct {
|
|
p sync.Pool
|
|
new func() E
|
|
reset func(E) bool
|
|
}
|
|
|
|
func (pool *pool[E]) get() E {
|
|
v, ok := pool.p.Get().(E)
|
|
if !ok {
|
|
v = pool.new()
|
|
}
|
|
|
|
return v
|
|
}
|
|
|
|
func (p *pool[E]) put(v E) {
|
|
if p.reset(v) {
|
|
p.p.Put(v)
|
|
}
|
|
}
|