mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
wip
This commit is contained in:
parent
7f7cb82e4f
commit
a3e773c5c1
@ -1,17 +1,9 @@
|
||||
package sanitize
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
rawState = iota
|
||||
singleQuoteState = iota
|
||||
doubleQuoteState = iota
|
||||
placeholderState = iota
|
||||
)
|
||||
|
||||
// Part is either a string or an int. A string is raw SQL. An int is a
|
||||
// argument placeholder.
|
||||
type Part interface{}
|
||||
@ -25,62 +17,122 @@ func (q *Query) Sanitize(args ...interface{}) (string, error) {
|
||||
}
|
||||
|
||||
func NewQuery(sql string) (*Query, error) {
|
||||
var start, pos int
|
||||
state := rawState
|
||||
|
||||
var query Query
|
||||
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(sql[pos:])
|
||||
pos += width
|
||||
|
||||
switch state {
|
||||
case rawState:
|
||||
switch r {
|
||||
case '\'':
|
||||
state = singleQuoteState
|
||||
case '"':
|
||||
state = doubleQuoteState
|
||||
case '$':
|
||||
if pos-start > 0 {
|
||||
query.Parts = append(query.Parts, sql[start:pos-1])
|
||||
}
|
||||
start = pos
|
||||
state = placeholderState
|
||||
case utf8.RuneError:
|
||||
if pos-start > 0 {
|
||||
query.Parts = append(query.Parts, sql[start:pos])
|
||||
}
|
||||
return &query, nil
|
||||
}
|
||||
case singleQuoteState:
|
||||
if r == '\'' || r == utf8.RuneError {
|
||||
state = rawState
|
||||
}
|
||||
case doubleQuoteState:
|
||||
if r == '"' || r == utf8.RuneError {
|
||||
state = rawState
|
||||
}
|
||||
case placeholderState:
|
||||
if r < '0' || r > '9' {
|
||||
pos -= width
|
||||
if start < pos {
|
||||
num, err := strconv.ParseInt(sql[start:pos], 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query.Parts = append(query.Parts, int(num))
|
||||
} else {
|
||||
query.Parts = append(query.Parts, "$")
|
||||
}
|
||||
|
||||
start = pos
|
||||
state = rawState
|
||||
}
|
||||
}
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
}
|
||||
|
||||
return &query, nil
|
||||
for l.stateFn != nil {
|
||||
l.stateFn = l.stateFn(l)
|
||||
}
|
||||
|
||||
query := &Query{Parts: l.parts}
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
stateFn stateFn
|
||||
parts []Part
|
||||
}
|
||||
|
||||
type stateFn func(*sqlLexer) stateFn
|
||||
|
||||
func rawState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\'':
|
||||
return singleQuoteState
|
||||
case '"':
|
||||
return doubleQuoteState
|
||||
case '$':
|
||||
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if '0' <= nextRune && nextRune <= '9' {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos-width])
|
||||
}
|
||||
l.start = l.pos
|
||||
return placeholderState
|
||||
}
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func singleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doubleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '"':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '"' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// placeholderState consumes a placeholder value. The $ must have already has
|
||||
// already been consumed. The first rune must be a digit.
|
||||
func placeholderState(l *sqlLexer) stateFn {
|
||||
num := 0
|
||||
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
if '0' <= r && r <= '9' {
|
||||
num *= 10
|
||||
num += int(r - '0')
|
||||
} else {
|
||||
l.parts = append(l.parts, num)
|
||||
l.pos -= width
|
||||
l.start = l.pos
|
||||
return rawState
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||
|
Loading…
x
Reference in New Issue
Block a user