This commit is contained in:
Jack Christensen 2017-04-07 15:15:36 -05:00
parent 7f7cb82e4f
commit a3e773c5c1

View File

@ -1,17 +1,9 @@
package sanitize
import (
"strconv"
"unicode/utf8"
)
const (
rawState = iota
singleQuoteState = iota
doubleQuoteState = iota
placeholderState = iota
)
// Part is either a string or an int. A string is raw SQL. An int is a
// argument placeholder.
type Part interface{}
@ -25,62 +17,122 @@ func (q *Query) Sanitize(args ...interface{}) (string, error) {
}
func NewQuery(sql string) (*Query, error) {
var start, pos int
state := rawState
var query Query
for {
r, width := utf8.DecodeRuneInString(sql[pos:])
pos += width
switch state {
case rawState:
switch r {
case '\'':
state = singleQuoteState
case '"':
state = doubleQuoteState
case '$':
if pos-start > 0 {
query.Parts = append(query.Parts, sql[start:pos-1])
}
start = pos
state = placeholderState
case utf8.RuneError:
if pos-start > 0 {
query.Parts = append(query.Parts, sql[start:pos])
}
return &query, nil
}
case singleQuoteState:
if r == '\'' || r == utf8.RuneError {
state = rawState
}
case doubleQuoteState:
if r == '"' || r == utf8.RuneError {
state = rawState
}
case placeholderState:
if r < '0' || r > '9' {
pos -= width
if start < pos {
num, err := strconv.ParseInt(sql[start:pos], 10, 32)
if err != nil {
return nil, err
}
query.Parts = append(query.Parts, int(num))
} else {
query.Parts = append(query.Parts, "$")
}
start = pos
state = rawState
}
}
l := &sqlLexer{
src: sql,
stateFn: rawState,
}
return &query, nil
for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
query := &Query{Parts: l.parts}
return query, nil
}
type sqlLexer struct {
src string
start int
pos int
stateFn stateFn
parts []Part
}
type stateFn func(*sqlLexer) stateFn
func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\'':
return singleQuoteState
case '"':
return doubleQuoteState
case '$':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if '0' <= nextRune && nextRune <= '9' {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}
l.start = l.pos
return placeholderState
}
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func singleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func doubleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '"':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '"' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
// placeholderState consumes a placeholder value. The $ must have already has
// already been consumed. The first rune must be a digit.
func placeholderState(l *sqlLexer) stateFn {
num := 0
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
if '0' <= r && r <= '9' {
num *= 10
num += int(r - '0')
} else {
l.parts = append(l.parts, num)
l.pos -= width
l.start = l.pos
return rawState
}
}
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args