This commit is contained in:
Jack Christensen 2017-04-07 15:15:36 -05:00
parent 7f7cb82e4f
commit a3e773c5c1

View File

@ -1,17 +1,9 @@
package sanitize package sanitize
import ( import (
"strconv"
"unicode/utf8" "unicode/utf8"
) )
const (
rawState = iota
singleQuoteState = iota
doubleQuoteState = iota
placeholderState = iota
)
// Part is either a string or an int. A string is raw SQL. An int is a // Part is either a string or an int. A string is raw SQL. An int is a
// argument placeholder. // argument placeholder.
type Part interface{} type Part interface{}
@ -25,64 +17,124 @@ func (q *Query) Sanitize(args ...interface{}) (string, error) {
} }
func NewQuery(sql string) (*Query, error) { func NewQuery(sql string) (*Query, error) {
var start, pos int l := &sqlLexer{
state := rawState src: sql,
stateFn: rawState,
}
var query Query for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
query := &Query{Parts: l.parts}
return query, nil
}
// sqlLexer carries the working state used while splitting a SQL string into
// parts. NewQuery creates one and repeatedly runs its current state function
// until that function returns nil.
type sqlLexer struct {
	// src is the SQL text being lexed.
	src string
	// start is the byte offset in src where the not-yet-emitted raw segment
	// begins; pos is the byte offset of the next rune to decode.
	start, pos int
	// stateFn is the state function to run next; nil means lexing is done.
	stateFn stateFn
	// parts collects the output: raw SQL as string, placeholders as int.
	parts []Part
}
// stateFn is a single lexer state. It consumes input from the lexer and
// returns the state to run next; returning nil stops the lexing loop.
type stateFn func(*sqlLexer) stateFn
func rawState(l *sqlLexer) stateFn {
for { for {
r, width := utf8.DecodeRuneInString(sql[pos:]) r, width := utf8.DecodeRuneInString(l.src[l.pos:])
pos += width l.pos += width
switch state {
case rawState:
switch r { switch r {
case '\'': case '\'':
state = singleQuoteState return singleQuoteState
case '"': case '"':
state = doubleQuoteState return doubleQuoteState
case '$': case '$':
if pos-start > 0 { nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
query.Parts = append(query.Parts, sql[start:pos-1]) if '0' <= nextRune && nextRune <= '9' {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}
l.start = l.pos
return placeholderState
} }
start = pos
state = placeholderState
case utf8.RuneError: case utf8.RuneError:
if pos-start > 0 { if l.pos-l.start > 0 {
query.Parts = append(query.Parts, sql[start:pos]) l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
} }
return &query, nil return nil
} }
case singleQuoteState:
if r == '\'' || r == utf8.RuneError {
state = rawState
} }
case doubleQuoteState:
if r == '"' || r == utf8.RuneError {
state = rawState
} }
case placeholderState:
if r < '0' || r > '9' { func singleQuoteState(l *sqlLexer) stateFn {
pos -= width for {
if start < pos { r, width := utf8.DecodeRuneInString(l.src[l.pos:])
num, err := strconv.ParseInt(sql[start:pos], 10, 32) l.pos += width
if err != nil {
return nil, err switch r {
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
} }
query.Parts = append(query.Parts, int(num)) l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
// doubleQuoteState lexes the inside of a double-quoted section. The opening
// quote must already have been consumed. A doubled quote ("") is treated as
// an escaped quote and lexing stays in this state; a lone quote is consumed
// and control returns to rawState. Nothing is emitted while inside the
// quotes, so the quoted text stays part of the pending raw segment that
// begins at l.start.
//
// It returns nil when the input is exhausted or an invalidly encoded byte is
// met, after appending the pending raw segment (if any) to l.parts.
func doubleQuoteState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += width

		switch r {
		case '"':
			// Peek at the following rune: only a second quote keeps us
			// inside the quoted section.
			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune != '"' {
				return rawState
			}
			l.pos += width
		case utf8.RuneError:
			// DecodeRuneInString reports RuneError with width 0 at end of
			// input and width 1 for an invalid byte. A correctly encoded
			// U+FFFD also decodes to RuneError, but with its full encoded
			// width; that is ordinary text and must not end lexing early.
			if width == utf8.RuneLen(utf8.RuneError) {
				continue
			}
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
	}
}
// placeholderState consumes a placeholder value. The $ must have already been
// consumed. The first rune must be a digit.
func placeholderState(l *sqlLexer) stateFn {
num := 0
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
if '0' <= r && r <= '9' {
num *= 10
num += int(r - '0')
} else { } else {
query.Parts = append(query.Parts, "$") l.parts = append(l.parts, num)
} l.pos -= width
l.start = l.pos
start = pos return rawState
state = rawState
} }
} }
} }
return &query, nil
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args // SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is // as necessary. This function is only safe when standard_conforming_strings is
// on. // on.