This commit is contained in:
Jack Christensen 2017-04-07 16:50:08 -05:00
parent a3e773c5c1
commit c39c29d895
4 changed files with 264 additions and 69 deletions

View File

@ -0,0 +1,48 @@
mode: set
github.com/jackc/pgx/internal/sanitize/sanitize.go:19.63,22.31 2 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:40.2,40.26 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:22.31,24.30 2 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:38.3,38.23 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:25.15,26.14 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:27.12,29.29 2 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:35.11,36.56 1 0
github.com/jackc/pgx/internal/sanitize/sanitize.go:30.13,31.44 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:32.16,33.27 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:43.43,49.23 2 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:53.2,55.19 2 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:49.23,51.3 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:58.37,60.2 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:72.36,73.6 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:73.6,77.12 3 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:78.13,79.27 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:80.12,81.27 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:82.12,84.42 2 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:91.23,92.25 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:96.4,96.14 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:84.42,85.26 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:88.5,89.28 2 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:85.26,87.6 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:92.25,95.5 2 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:101.44,102.6 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:102.6,106.12 3 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:107.13,109.24 2 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:112.4,112.18 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:113.23,114.25 1 0
github.com/jackc/pgx/internal/sanitize/sanitize.go:118.4,118.14 1 0
github.com/jackc/pgx/internal/sanitize/sanitize.go:109.24,111.5 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:114.25,117.5 2 0
github.com/jackc/pgx/internal/sanitize/sanitize.go:123.44,124.6 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:124.6,128.12 3 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:129.12,131.23 2 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:134.4,134.18 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:135.23,136.25 1 0
github.com/jackc/pgx/internal/sanitize/sanitize.go:140.4,140.14 1 0
github.com/jackc/pgx/internal/sanitize/sanitize.go:131.23,133.5 1 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:136.25,139.5 2 0
github.com/jackc/pgx/internal/sanitize/sanitize.go:147.44,150.6 2 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:150.6,154.27 3 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:154.27,157.4 2 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:157.4,162.4 4 1
github.com/jackc/pgx/internal/sanitize/sanitize.go:169.67,171.16 2 0
github.com/jackc/pgx/internal/sanitize/sanitize.go:174.2,174.32 1 0
github.com/jackc/pgx/internal/sanitize/sanitize.go:171.16,173.3 1 0

View File

@ -1,6 +1,10 @@
package sanitize
import (
"bytes"
"fmt"
"strconv"
"strings"
"unicode/utf8"
)
@ -13,7 +17,33 @@ type Query struct {
}
func (q *Query) Sanitize(args ...interface{}) (string, error) {
return "", nil
buf := &bytes.Buffer{}
for _, part := range q.Parts {
var str string
switch part := part.(type) {
case string:
str = part
case int:
argIdx := part - 1
if argIdx >= len(args) {
return "", fmt.Errorf("insufficient arguments")
}
arg := args[argIdx]
switch arg := arg.(type) {
case nil:
str = "null"
case int:
str = strconv.FormatInt(int64(arg), 10)
case string:
str = QuoteString(arg)
}
default:
return "", fmt.Errorf("invalid Part type: %T", part)
}
buf.WriteString(str)
}
return buf.String(), nil
}
func NewQuery(sql string) (*Query, error) {
@ -31,6 +61,10 @@ func NewQuery(sql string) (*Query, error) {
return query, nil
}
func QuoteString(str string) string {
return "'" + strings.Replace(str, "'", "''", -1) + "'"
}
type sqlLexer struct {
src string
start int
@ -47,6 +81,12 @@ func rawState(l *sqlLexer) stateFn {
l.pos += width
switch r {
case 'e', 'E':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '\'' {
l.pos += width
return escapeStringState
}
case '\'':
return singleQuoteState
case '"':
@ -135,6 +175,31 @@ func placeholderState(l *sqlLexer) stateFn {
}
}
func escapeStringState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.

View File

@ -0,0 +1,150 @@
package sanitize_test
import (
"testing"
"github.com/jackc/pgx/internal/sanitize"
)
func TestNewQuery(t *testing.T) {
successTests := []struct {
sql string
expected sanitize.Query
}{
{
sql: "select 42",
expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
},
{
sql: "select $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
},
{
sql: "select 'quoted $42', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}},
},
{
sql: `select "doubled quoted $42", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}},
},
{
sql: "select 'foo''bar', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}},
},
{
sql: `select "foo""bar", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}},
},
{
sql: "select '''', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}},
},
{
sql: `select """", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}},
},
{
sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11",
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}},
},
{
sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}},
},
{
sql: `select E'escape string\' $42', $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}},
},
{
sql: `select e'escape string\' $42', $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
},
}
for i, tt := range successTests {
query, err := sanitize.NewQuery(tt.sql)
if err != nil {
t.Errorf("%d. %v", i, err)
}
if len(query.Parts) == len(tt.expected.Parts) {
for j := range query.Parts {
if query.Parts[j] != tt.expected.Parts[j] {
t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j])
}
}
} else {
t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts)
}
}
}
func TestQuerySanitize(t *testing.T) {
successfulTests := []struct {
query sanitize.Query
args []interface{}
expected string
}{
{
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
args: []interface{}{},
expected: `select 42`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{42},
expected: `select 42`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{nil},
expected: `select null`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{"foobar"},
expected: `select 'foobar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{"foo'bar"},
expected: `select 'foo''bar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{`foo\'bar`},
expected: `select 'foo\''bar'`,
},
}
for i, tt := range successfulTests {
actual, err := tt.query.Sanitize(tt.args...)
if err != nil {
t.Errorf("%d. %v", i, err)
continue
}
if tt.expected != actual {
t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual)
}
}
errorTests := []struct {
query sanitize.Query
args []interface{}
expected string
}{
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
args: []interface{}{42},
expected: `insufficient arguments`,
},
}
for i, tt := range errorTests {
_, err := tt.query.Sanitize(tt.args...)
if err.Error() != tt.expected {
t.Errorf("%d. expected error %v, got %v", i, tt.expected, err)
}
}
}

View File

@ -1,68 +0,0 @@
package sanitize_test
import (
"testing"
"github.com/jackc/pgx/sanitize"
)
func TestNewQuery(t *testing.T) {
successTests := []struct {
sql string
expected sanitize.Query
}{
{
sql: "select 42",
expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
},
{
sql: "select $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
},
{
sql: "select 'quoted $42', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}},
},
{
sql: `select "doubled quoted $42", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}},
},
{
sql: "select 'foo''bar', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}},
},
{
sql: `select "foo""bar", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}},
},
{
sql: "select '''', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}},
},
{
sql: `select """", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}},
},
{
sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11",
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}},
},
}
for i, tt := range successTests {
query, err := sanitize.NewQuery(tt.sql)
if err != nil {
t.Errorf("%d. %v", i, err)
}
if len(query.Parts) == len(tt.expected.Parts) {
for j := range query.Parts {
if query.Parts[j] != tt.expected.Parts[j] {
t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j])
}
}
} else {
t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts)
}
}
}