mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
151 lines
4.0 KiB
Go
151 lines
4.0 KiB
Go
package sanitize_test
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/jackc/pgx/internal/sanitize"
|
|
)
|
|
|
|
func TestNewQuery(t *testing.T) {
|
|
successTests := []struct {
|
|
sql string
|
|
expected sanitize.Query
|
|
}{
|
|
{
|
|
sql: "select 42",
|
|
expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
|
|
},
|
|
{
|
|
sql: "select $1",
|
|
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
|
},
|
|
{
|
|
sql: "select 'quoted $42', $1",
|
|
expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}},
|
|
},
|
|
{
|
|
sql: `select "doubled quoted $42", $1`,
|
|
expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}},
|
|
},
|
|
{
|
|
sql: "select 'foo''bar', $1",
|
|
expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}},
|
|
},
|
|
{
|
|
sql: `select "foo""bar", $1`,
|
|
expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}},
|
|
},
|
|
{
|
|
sql: "select '''', $1",
|
|
expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}},
|
|
},
|
|
{
|
|
sql: `select """", $1`,
|
|
expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}},
|
|
},
|
|
{
|
|
sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11",
|
|
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}},
|
|
},
|
|
{
|
|
sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`,
|
|
expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}},
|
|
},
|
|
{
|
|
sql: `select E'escape string\' $42', $1`,
|
|
expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}},
|
|
},
|
|
{
|
|
sql: `select e'escape string\' $42', $1`,
|
|
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
|
|
},
|
|
}
|
|
|
|
for i, tt := range successTests {
|
|
query, err := sanitize.NewQuery(tt.sql)
|
|
if err != nil {
|
|
t.Errorf("%d. %v", i, err)
|
|
}
|
|
|
|
if len(query.Parts) == len(tt.expected.Parts) {
|
|
for j := range query.Parts {
|
|
if query.Parts[j] != tt.expected.Parts[j] {
|
|
t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j])
|
|
}
|
|
}
|
|
} else {
|
|
t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestQuerySanitize(t *testing.T) {
|
|
successfulTests := []struct {
|
|
query sanitize.Query
|
|
args []interface{}
|
|
expected string
|
|
}{
|
|
{
|
|
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
|
|
args: []interface{}{},
|
|
expected: `select 42`,
|
|
},
|
|
{
|
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
|
args: []interface{}{42},
|
|
expected: `select 42`,
|
|
},
|
|
{
|
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
|
args: []interface{}{nil},
|
|
expected: `select null`,
|
|
},
|
|
{
|
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
|
args: []interface{}{"foobar"},
|
|
expected: `select 'foobar'`,
|
|
},
|
|
{
|
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
|
args: []interface{}{"foo'bar"},
|
|
expected: `select 'foo''bar'`,
|
|
},
|
|
{
|
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
|
args: []interface{}{`foo\'bar`},
|
|
expected: `select 'foo\''bar'`,
|
|
},
|
|
}
|
|
|
|
for i, tt := range successfulTests {
|
|
actual, err := tt.query.Sanitize(tt.args...)
|
|
if err != nil {
|
|
t.Errorf("%d. %v", i, err)
|
|
continue
|
|
}
|
|
|
|
if tt.expected != actual {
|
|
t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual)
|
|
}
|
|
}
|
|
|
|
errorTests := []struct {
|
|
query sanitize.Query
|
|
args []interface{}
|
|
expected string
|
|
}{
|
|
{
|
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
|
|
args: []interface{}{42},
|
|
expected: `insufficient arguments`,
|
|
},
|
|
}
|
|
|
|
for i, tt := range errorTests {
|
|
_, err := tt.query.Sanitize(tt.args...)
|
|
if err.Error() != tt.expected {
|
|
t.Errorf("%d. expected error %v, got %v", i, tt.expected, err)
|
|
}
|
|
}
|
|
}
|