mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
wip
This commit is contained in:
parent
7f7cb82e4f
commit
a3e773c5c1
@ -1,17 +1,9 @@
|
|||||||
package sanitize
|
package sanitize
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
rawState = iota
|
|
||||||
singleQuoteState = iota
|
|
||||||
doubleQuoteState = iota
|
|
||||||
placeholderState = iota
|
|
||||||
)
|
|
||||||
|
|
||||||
// Part is either a string or an int. A string is raw SQL. An int is a
|
// Part is either a string or an int. A string is raw SQL. An int is a
|
||||||
// argument placeholder.
|
// argument placeholder.
|
||||||
type Part interface{}
|
type Part interface{}
|
||||||
@ -25,62 +17,122 @@ func (q *Query) Sanitize(args ...interface{}) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewQuery(sql string) (*Query, error) {
|
func NewQuery(sql string) (*Query, error) {
|
||||||
var start, pos int
|
l := &sqlLexer{
|
||||||
state := rawState
|
src: sql,
|
||||||
|
stateFn: rawState,
|
||||||
var query Query
|
|
||||||
|
|
||||||
for {
|
|
||||||
r, width := utf8.DecodeRuneInString(sql[pos:])
|
|
||||||
pos += width
|
|
||||||
|
|
||||||
switch state {
|
|
||||||
case rawState:
|
|
||||||
switch r {
|
|
||||||
case '\'':
|
|
||||||
state = singleQuoteState
|
|
||||||
case '"':
|
|
||||||
state = doubleQuoteState
|
|
||||||
case '$':
|
|
||||||
if pos-start > 0 {
|
|
||||||
query.Parts = append(query.Parts, sql[start:pos-1])
|
|
||||||
}
|
|
||||||
start = pos
|
|
||||||
state = placeholderState
|
|
||||||
case utf8.RuneError:
|
|
||||||
if pos-start > 0 {
|
|
||||||
query.Parts = append(query.Parts, sql[start:pos])
|
|
||||||
}
|
|
||||||
return &query, nil
|
|
||||||
}
|
|
||||||
case singleQuoteState:
|
|
||||||
if r == '\'' || r == utf8.RuneError {
|
|
||||||
state = rawState
|
|
||||||
}
|
|
||||||
case doubleQuoteState:
|
|
||||||
if r == '"' || r == utf8.RuneError {
|
|
||||||
state = rawState
|
|
||||||
}
|
|
||||||
case placeholderState:
|
|
||||||
if r < '0' || r > '9' {
|
|
||||||
pos -= width
|
|
||||||
if start < pos {
|
|
||||||
num, err := strconv.ParseInt(sql[start:pos], 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
query.Parts = append(query.Parts, int(num))
|
|
||||||
} else {
|
|
||||||
query.Parts = append(query.Parts, "$")
|
|
||||||
}
|
|
||||||
|
|
||||||
start = pos
|
|
||||||
state = rawState
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &query, nil
|
for l.stateFn != nil {
|
||||||
|
l.stateFn = l.stateFn(l)
|
||||||
|
}
|
||||||
|
|
||||||
|
query := &Query{Parts: l.parts}
|
||||||
|
|
||||||
|
return query, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type sqlLexer struct {
|
||||||
|
src string
|
||||||
|
start int
|
||||||
|
pos int
|
||||||
|
stateFn stateFn
|
||||||
|
parts []Part
|
||||||
|
}
|
||||||
|
|
||||||
|
type stateFn func(*sqlLexer) stateFn
|
||||||
|
|
||||||
|
func rawState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '\'':
|
||||||
|
return singleQuoteState
|
||||||
|
case '"':
|
||||||
|
return doubleQuoteState
|
||||||
|
case '$':
|
||||||
|
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if '0' <= nextRune && nextRune <= '9' {
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos-width])
|
||||||
|
}
|
||||||
|
l.start = l.pos
|
||||||
|
return placeholderState
|
||||||
|
}
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func singleQuoteState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '\'':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune != '\'' {
|
||||||
|
return rawState
|
||||||
|
}
|
||||||
|
l.pos += width
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func doubleQuoteState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '"':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune != '"' {
|
||||||
|
return rawState
|
||||||
|
}
|
||||||
|
l.pos += width
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// placeholderState consumes a placeholder value. The $ must have already has
|
||||||
|
// already been consumed. The first rune must be a digit.
|
||||||
|
func placeholderState(l *sqlLexer) stateFn {
|
||||||
|
num := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
if '0' <= r && r <= '9' {
|
||||||
|
num *= 10
|
||||||
|
num += int(r - '0')
|
||||||
|
} else {
|
||||||
|
l.parts = append(l.parts, num)
|
||||||
|
l.pos -= width
|
||||||
|
l.start = l.pos
|
||||||
|
return rawState
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||||
|
Loading…
x
Reference in New Issue
Block a user