Add comment support when sanitizing SQL queries

pull/962/head
Rusakow Andrew 2021-03-13 17:07:37 +07:00 committed by Jack Christensen
parent 00704ce8b7
commit 292539a590
2 changed files with 93 additions and 1 deletions

View File

@ -83,7 +83,7 @@ func NewQuery(sql string) (*Query, error) {
}
func QuoteString(str string) string {
return "'" + strings.Replace(str, "'", "''", -1) + "'"
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
}
func QuoteBytes(buf []byte) string {
@ -94,6 +94,7 @@ type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []Part
}
@ -125,6 +126,18 @@ func rawState(l *sqlLexer) stateFn {
l.start = l.pos
return placeholderState
}
case '-':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '-' {
l.pos += width
return oneLineCommentState
}
case '/':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '*' {
l.pos += width
return multilineCommentState
}
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
@ -225,6 +238,61 @@ func escapeStringState(l *sqlLexer) stateFn {
}
}
func oneLineCommentState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\n':
return rawState
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func multilineCommentState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '/':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '*' {
l.pos += width
l.nested++
}
case '*':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '/' {
continue
}
l.pos += width
if l.nested == 0 {
return rawState
}
l.nested--
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.

View File

@ -60,6 +60,30 @@ func TestNewQuery(t *testing.T) {
sql: `select e'escape string\' $42', $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
},
{
sql: `select /* a baby's toy */ 'barbie', $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select /* a baby's toy */ 'barbie', `, 1}},
},
{
sql: `select /* *_* */ $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select /* *_* */ `, 1}},
},
{
sql: `select 42 /* /* /* 42 */ */ */, $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select 42 /* /* /* 42 */ */ */, `, 1}},
},
{
sql: "select -- a baby's toy\n'barbie', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select -- a baby's toy\n'barbie', ", 1}},
},
{
sql: "select 42 -- is a Thinker's favorite number",
expected: sanitize.Query{Parts: []sanitize.Part{"select 42 -- is a Thinker's favorite number"}},
},
{
sql: "select 42, -- \\nis a Thinker's favorite number\n$1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Thinker's favorite number\n", 1}},
},
}
for i, tt := range successTests {