diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go
index 0e7ddf13..30a4a757 100644
--- a/internal/sanitize/sanitize.go
+++ b/internal/sanitize/sanitize.go
@@ -83,7 +83,7 @@ func NewQuery(sql string) (*Query, error) {
 }
 
 func QuoteString(str string) string {
-	return "'" + strings.Replace(str, "'", "''", -1) + "'"
+	return "'" + strings.ReplaceAll(str, "'", "''") + "'"
 }
 
 func QuoteBytes(buf []byte) string {
@@ -94,6 +94,7 @@ type sqlLexer struct {
 	src     string
 	start   int
 	pos     int
+	nested  int // multiline comment nesting level.
 	stateFn stateFn
 	parts   []Part
 }
@@ -125,6 +126,20 @@ func rawState(l *sqlLexer) stateFn {
 				l.start = l.pos
 				return placeholderState
 			}
+		case '-':
+			// "--" starts a comment that runs to the end of the line.
+			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+			if nextRune == '-' {
+				l.pos += width
+				return oneLineCommentState
+			}
+		case '/':
+			// "/*" starts a (possibly nested) multiline comment.
+			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+			if nextRune == '*' {
+				l.pos += width
+				return multilineCommentState
+			}
 		case utf8.RuneError:
 			if l.pos-l.start > 0 {
 				l.parts = append(l.parts, l.src[l.start:l.pos])
@@ -225,6 +240,75 @@ func escapeStringState(l *sqlLexer) stateFn {
 	}
 }
 
+// oneLineCommentState consumes the remainder of a "--" comment. l.start is
+// not moved, so the comment text stays inside the current raw SQL part and is
+// never scanned for placeholders or quotes.
+func oneLineCommentState(l *sqlLexer) stateFn {
+	for {
+		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+		l.pos += width
+
+		switch r {
+		case '\n', '\r':
+			// PostgreSQL ends a "--" comment at the first newline or carriage
+			// return; a backslash is not an escape character here.
+			return rawState
+		case utf8.RuneError:
+			// A zero width means end of input; a non-zero width is an invalid
+			// byte or a literal U+FFFD inside the comment, which is skipped.
+			if width == 0 {
+				if l.pos-l.start > 0 {
+					l.parts = append(l.parts, l.src[l.start:l.pos])
+					l.start = l.pos
+				}
+				return nil
+			}
+		}
+	}
+}
+
+// multilineCommentState consumes a "/* ... */" comment, including nested
+// comments, which PostgreSQL permits. l.nested counts the comments opened
+// inside the outermost one; the lexer returns to rawState only once the
+// outermost comment has been closed.
+func multilineCommentState(l *sqlLexer) stateFn {
+	for {
+		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+		l.pos += width
+
+		switch r {
+		case '/':
+			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+			if nextRune == '*' {
+				l.pos += width
+				l.nested++
+			}
+		case '*':
+			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+			if nextRune != '/' {
+				continue
+			}
+
+			l.pos += width
+			if l.nested == 0 {
+				return rawState
+			}
+			l.nested--
+
+		case utf8.RuneError:
+			// A zero width means end of input; a non-zero width is an invalid
+			// byte or a literal U+FFFD inside the comment, which is skipped.
+			if width == 0 {
+				if l.pos-l.start > 0 {
+					l.parts = append(l.parts, l.src[l.start:l.pos])
+					l.start = l.pos
+				}
+				return nil
+			}
+		}
+	}
+}
+
 // SanitizeSQL replaces placeholder values with args. It quotes and escapes args
 // as necessary. This function is only safe when standard_conforming_strings is
 // on.
diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go
index fff14896..344c46b0 100644
--- a/internal/sanitize/sanitize_test.go
+++ b/internal/sanitize/sanitize_test.go
@@ -60,6 +60,42 @@ func TestNewQuery(t *testing.T) {
 			sql:      `select e'escape string\' $42', $1`,
 			expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
 		},
+		{
+			sql:      `select /* a baby's toy */ 'barbie', $1`,
+			expected: sanitize.Query{Parts: []sanitize.Part{`select /* a baby's toy */ 'barbie', `, 1}},
+		},
+		{
+			sql:      `select /* *_* */ $1`,
+			expected: sanitize.Query{Parts: []sanitize.Part{`select /* *_* */ `, 1}},
+		},
+		{
+			sql:      `select 42 /* /* /* 42 */ */ */, $1`,
+			expected: sanitize.Query{Parts: []sanitize.Part{`select 42 /* /* /* 42 */ */ */, `, 1}},
+		},
+		{
+			sql:      "select -- a baby's toy\n'barbie', $1",
+			expected: sanitize.Query{Parts: []sanitize.Part{"select -- a baby's toy\n'barbie', ", 1}},
+		},
+		{
+			sql:      "select 42 -- is a Thinker's favorite number",
+			expected: sanitize.Query{Parts: []sanitize.Part{"select 42 -- is a Thinker's favorite number"}},
+		},
+		{
+			sql:      "select 42, -- \\nis a Thinker's favorite number\n$1",
+			expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Thinker's favorite number\n", 1}},
+		},
+		{
+			sql:      "select 42, -- trailing backslash \\\n$1",
+			expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- trailing backslash \\\n", 1}},
+		},
+		{
+			sql:      "select 42, -- carriage return\r$1",
+			expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- carriage return\r", 1}},
+		},
+		{
+			sql:      "select /* \uFFFD */ $1",
+			expected: sanitize.Query{Parts: []sanitize.Part{"select /* \uFFFD */ ", 1}},
+		},
 	}
 
 	for i, tt := range successTests {