diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index d83633a7..4aca2fb9 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -26,28 +26,19 @@ type Query struct { // https://github.com/jackc/pgx/issues/1380 const replacementcharacterwidth = 3 -var bufPool = &sync.Pool{} - -func getBuf() *bytes.Buffer { - buf, _ := bufPool.Get().(*bytes.Buffer) - if buf == nil { - buf = &bytes.Buffer{} - } - - return buf -} - -func putBuf(buf *bytes.Buffer) { - buf.Reset() - bufPool.Put(buf) +var bufPool = &pool[*bytes.Buffer]{ + new: func() *bytes.Buffer { + return &bytes.Buffer{} + }, + reset: (*bytes.Buffer).Reset, } var null = []byte("null") func (q *Query) Sanitize(args ...any) (string, error) { argUse := make([]bool, len(args)) - buf := getBuf() - defer putBuf(buf) + buf := bufPool.get() + defer bufPool.put(buf) for _, part := range q.Parts { switch part := part.(type) { @@ -109,18 +100,39 @@ func (q *Query) Sanitize(args ...any) (string, error) { } func NewQuery(sql string) (*Query, error) { - l := &sqlLexer{ - src: sql, - stateFn: rawState, + query := &Query{} + query.init(sql) + + return query, nil +} + +var sqlLexerPool = &pool[*sqlLexer]{ + new: func() *sqlLexer { + return &sqlLexer{} + }, + reset: func(sl *sqlLexer) { + *sl = sqlLexer{} + }, +} + +func (q *Query) init(sql string) { + parts := q.Parts[:0] + if parts == nil { + n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1 + parts = make([]Part, 0, n) } + l := sqlLexerPool.get() + defer sqlLexerPool.put(l) + l.src = sql + l.stateFn = rawState + l.parts = parts + for l.stateFn != nil { l.stateFn = l.stateFn(l) } - query := &Query{Parts: l.parts} - - return query, nil + q.Parts = l.parts } func QuoteString(str string) string { @@ -385,13 +397,42 @@ func multilineCommentState(l *sqlLexer) stateFn { } } +var queryPool = &pool[*Query]{ + new: func() *Query { + return &Query{} + }, + reset: func(q *Query) { + q.Parts = q.Parts[:0] + }, +} + // SanitizeSQL replaces placeholder values with args. It quotes and escapes args // as necessary. This function is only safe when standard_conforming_strings is // on. func SanitizeSQL(sql string, args ...any) (string, error) { - query, err := NewQuery(sql) - if err != nil { - return "", err - } + query := queryPool.get() + query.init(sql) + defer queryPool.put(query) + return query.Sanitize(args...) } + +type pool[E any] struct { + p sync.Pool + new func() E + reset func(E) +} + +func (pool *pool[E]) get() E { + v, ok := pool.p.Get().(E) + if !ok { + v = pool.new() + } + + return v +} + +func (p *pool[E]) put(v E) { + p.reset(v) + p.p.Put(v) +}