mirror of https://github.com/jackc/pgx.git
Add comment support when sanitizing SQL queries
parent
00704ce8b7
commit
292539a590
|
@ -83,7 +83,7 @@ func NewQuery(sql string) (*Query, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func QuoteString(str string) string {
|
func QuoteString(str string) string {
|
||||||
return "'" + strings.Replace(str, "'", "''", -1) + "'"
|
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
|
||||||
}
|
}
|
||||||
|
|
||||||
func QuoteBytes(buf []byte) string {
|
func QuoteBytes(buf []byte) string {
|
||||||
|
@ -94,6 +94,7 @@ type sqlLexer struct {
|
||||||
src string
|
src string
|
||||||
start int
|
start int
|
||||||
pos int
|
pos int
|
||||||
|
nested int // multiline comment nesting level.
|
||||||
stateFn stateFn
|
stateFn stateFn
|
||||||
parts []Part
|
parts []Part
|
||||||
}
|
}
|
||||||
|
@ -125,6 +126,18 @@ func rawState(l *sqlLexer) stateFn {
|
||||||
l.start = l.pos
|
l.start = l.pos
|
||||||
return placeholderState
|
return placeholderState
|
||||||
}
|
}
|
||||||
|
case '-':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune == '-' {
|
||||||
|
l.pos += width
|
||||||
|
return oneLineCommentState
|
||||||
|
}
|
||||||
|
case '/':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune == '*' {
|
||||||
|
l.pos += width
|
||||||
|
return multilineCommentState
|
||||||
|
}
|
||||||
case utf8.RuneError:
|
case utf8.RuneError:
|
||||||
if l.pos-l.start > 0 {
|
if l.pos-l.start > 0 {
|
||||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
@ -225,6 +238,61 @@ func escapeStringState(l *sqlLexer) stateFn {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func oneLineCommentState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '\\':
|
||||||
|
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
case '\n':
|
||||||
|
return rawState
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func multilineCommentState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '/':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune == '*' {
|
||||||
|
l.pos += width
|
||||||
|
l.nested++
|
||||||
|
}
|
||||||
|
case '*':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune != '/' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
l.pos += width
|
||||||
|
if l.nested == 0 {
|
||||||
|
return rawState
|
||||||
|
}
|
||||||
|
l.nested--
|
||||||
|
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||||
// as necessary. This function is only safe when standard_conforming_strings is
|
// as necessary. This function is only safe when standard_conforming_strings is
|
||||||
// on.
|
// on.
|
||||||
|
|
|
@ -60,6 +60,30 @@ func TestNewQuery(t *testing.T) {
|
||||||
sql: `select e'escape string\' $42', $1`,
|
sql: `select e'escape string\' $42', $1`,
|
||||||
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
|
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
sql: `select /* a baby's toy */ 'barbie', $1`,
|
||||||
|
expected: sanitize.Query{Parts: []sanitize.Part{`select /* a baby's toy */ 'barbie', `, 1}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: `select /* *_* */ $1`,
|
||||||
|
expected: sanitize.Query{Parts: []sanitize.Part{`select /* *_* */ `, 1}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: `select 42 /* /* /* 42 */ */ */, $1`,
|
||||||
|
expected: sanitize.Query{Parts: []sanitize.Part{`select 42 /* /* /* 42 */ */ */, `, 1}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "select -- a baby's toy\n'barbie', $1",
|
||||||
|
expected: sanitize.Query{Parts: []sanitize.Part{"select -- a baby's toy\n'barbie', ", 1}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "select 42 -- is a Thinker's favorite number",
|
||||||
|
expected: sanitize.Query{Parts: []sanitize.Part{"select 42 -- is a Thinker's favorite number"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "select 42, -- \\nis a Thinker's favorite number\n$1",
|
||||||
|
expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Thinker's favorite number\n", 1}},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range successTests {
|
for i, tt := range successTests {
|
||||||
|
|
Loading…
Reference in New Issue