Added SQL parameter sanitization

fixes #3
This commit is contained in:
Jack Christensen 2013-04-15 20:22:28 -05:00
parent 2ed4f46454
commit 8392883350
4 changed files with 126 additions and 45 deletions

65
conn.go
View File

@ -90,8 +90,9 @@ func (c *conn) Close() (err error) {
return
}
func (c *conn) query(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) {
if err = c.sendSimpleQuery(sql); err != nil {
func (c *conn) query(sql string, params []interface{}, onDataRow func(*messageReader, []fieldDescription) error) (err error) {
sanitized_sql := c.SanitizeSql(sql, params...)
if err = c.sendSimpleQuery(sanitized_sql); err != nil {
return
}
@ -128,28 +129,28 @@ func (c *conn) query(sql string, onDataRow func(*messageReader, []fieldDescripti
panic("Unreachable")
}
func (c *conn) Query(sql string) (rows []map[string]string, err error) {
func (c *conn) Query(sql string, params ...interface{}) (rows []map[string]string, err error) {
rows = make([]map[string]string, 0, 8)
onDataRow := func(r *messageReader, fields []fieldDescription) error {
rows = append(rows, c.rxDataRow(r, fields))
return nil
}
err = c.query(sql, onDataRow)
err = c.query(sql, params, onDataRow)
return
}
func (c *conn) SelectString(sql string) (s string, err error) {
func (c *conn) SelectString(sql string, params ...interface{}) (s string, err error) {
onDataRow := func(r *messageReader, _ []fieldDescription) error {
s = c.rxDataRowFirstValue(r)
return nil
}
err = c.query(sql, onDataRow)
err = c.query(sql, params, onDataRow)
return
}
func (c *conn) selectInt(sql string, size int) (i int64, err error) {
func (c *conn) selectInt(sql string, size int, params []interface{}) (i int64, err error) {
var s string
s, err = c.SelectString(sql)
s, err = c.SelectString(sql, params...)
if err != nil {
return
}
@ -158,27 +159,27 @@ func (c *conn) selectInt(sql string, size int) (i int64, err error) {
return
}
func (c *conn) SelectInt64(sql string) (i int64, err error) {
return c.selectInt(sql, 64)
func (c *conn) SelectInt64(sql string, params ...interface{}) (i int64, err error) {
return c.selectInt(sql, 64, params)
}
func (c *conn) SelectInt32(sql string) (i int32, err error) {
func (c *conn) SelectInt32(sql string, params ...interface{}) (i int32, err error) {
var i64 int64
i64, err = c.selectInt(sql, 32)
i64, err = c.selectInt(sql, 32, params)
i = int32(i64)
return
}
func (c *conn) SelectInt16(sql string) (i int16, err error) {
func (c *conn) SelectInt16(sql string, params ...interface{}) (i int16, err error) {
var i64 int64
i64, err = c.selectInt(sql, 16)
i64, err = c.selectInt(sql, 16, params)
i = int16(i64)
return
}
func (c *conn) selectFloat(sql string, size int) (f float64, err error) {
func (c *conn) selectFloat(sql string, size int, params []interface{}) (f float64, err error) {
var s string
s, err = c.SelectString(sql)
s, err = c.SelectString(sql, params...)
if err != nil {
return
}
@ -187,28 +188,28 @@ func (c *conn) selectFloat(sql string, size int) (f float64, err error) {
return
}
func (c *conn) SelectFloat64(sql string) (f float64, err error) {
return c.selectFloat(sql, 64)
func (c *conn) SelectFloat64(sql string, params ...interface{}) (f float64, err error) {
return c.selectFloat(sql, 64, params)
}
func (c *conn) SelectFloat32(sql string) (f float32, err error) {
func (c *conn) SelectFloat32(sql string, params ...interface{}) (f float32, err error) {
var f64 float64
f64, err = c.selectFloat(sql, 32)
f64, err = c.selectFloat(sql, 32, params)
f = float32(f64)
return
}
func (c *conn) SelectAllString(sql string) (strings []string, err error) {
func (c *conn) SelectAllString(sql string, params ...interface{}) (strings []string, err error) {
strings = make([]string, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) error {
strings = append(strings, c.rxDataRowFirstValue(r))
return nil
}
err = c.query(sql, onDataRow)
err = c.query(sql, params, onDataRow)
return
}
func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) {
func (c *conn) SelectAllInt64(sql string, params ...interface{}) (ints []int64, err error) {
ints = make([]int64, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
var i int64
@ -216,11 +217,11 @@ func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) {
ints = append(ints, i)
return
}
err = c.query(sql, onDataRow)
err = c.query(sql, params, onDataRow)
return
}
func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) {
func (c *conn) SelectAllInt32(sql string, params ...interface{}) (ints []int32, err error) {
ints = make([]int32, 0, 8)
onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) {
var i int64
@ -228,11 +229,11 @@ func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) {
ints = append(ints, int32(i))
return
}
err = c.query(sql, onDataRow)
err = c.query(sql, params, onDataRow)
return
}
func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) {
func (c *conn) SelectAllInt16(sql string, params ...interface{}) (ints []int16, err error) {
ints = make([]int16, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
var i int64
@ -240,11 +241,11 @@ func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) {
ints = append(ints, int16(i))
return
}
err = c.query(sql, onDataRow)
err = c.query(sql, params, onDataRow)
return
}
func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) {
func (c *conn) SelectAllFloat64(sql string, params ...interface{}) (floats []float64, err error) {
floats = make([]float64, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
var f float64
@ -252,11 +253,11 @@ func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) {
floats = append(floats, f)
return
}
err = c.query(sql, onDataRow)
err = c.query(sql, params, onDataRow)
return
}
func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) {
func (c *conn) SelectAllFloat32(sql string, params ...interface{}) (floats []float32, err error) {
floats = make([]float32, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
var f float64
@ -264,7 +265,7 @@ func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) {
floats = append(floats, float32(f))
return
}
err = c.query(sql, onDataRow)
err = c.query(sql, params, onDataRow)
return
}

View File

@ -89,7 +89,7 @@ func TestConnectWithMD5Password(t *testing.T) {
func TestQuery(t *testing.T) {
conn := getSharedConn()
rows, err := conn.Query("select 'Jack' as name")
rows, err := conn.Query("select $1 as name", "Jack")
if err != nil {
t.Fatal("Query failed")
}
@ -106,7 +106,7 @@ func TestQuery(t *testing.T) {
func TestSelectString(t *testing.T) {
conn := getSharedConn()
s, err := conn.SelectString("select 'foo'")
s, err := conn.SelectString("select $1", "foo")
if err != nil {
t.Fatal("Unable to select string: " + err.Error())
}
@ -119,7 +119,7 @@ func TestSelectString(t *testing.T) {
func TestSelectInt64(t *testing.T) {
conn := getSharedConn()
i, err := conn.SelectInt64("select 1")
i, err := conn.SelectInt64("select $1", 1)
if err != nil {
t.Fatal("Unable to select int64: " + err.Error())
}
@ -142,7 +142,7 @@ func TestSelectInt64(t *testing.T) {
func TestSelectInt32(t *testing.T) {
conn := getSharedConn()
i, err := conn.SelectInt32("select 1")
i, err := conn.SelectInt32("select $1", 1)
if err != nil {
t.Fatal("Unable to select int32: " + err.Error())
}
@ -165,7 +165,7 @@ func TestSelectInt32(t *testing.T) {
func TestSelectInt16(t *testing.T) {
conn := getSharedConn()
i, err := conn.SelectInt16("select 1")
i, err := conn.SelectInt16("select $1", 1)
if err != nil {
t.Fatal("Unable to select int16: " + err.Error())
}
@ -188,7 +188,7 @@ func TestSelectInt16(t *testing.T) {
func TestSelectFloat64(t *testing.T) {
conn := getSharedConn()
f, err := conn.SelectFloat64("select 1.23")
f, err := conn.SelectFloat64("select $1", 1.23)
if err != nil {
t.Fatal("Unable to select float64: " + err.Error())
}
@ -201,7 +201,7 @@ func TestSelectFloat64(t *testing.T) {
func TestSelectFloat32(t *testing.T) {
conn := getSharedConn()
f, err := conn.SelectFloat32("select 1.23")
f, err := conn.SelectFloat32("select $1", 1.23)
if err != nil {
t.Fatal("Unable to select float32: " + err.Error())
}
@ -214,7 +214,7 @@ func TestSelectFloat32(t *testing.T) {
func TestSelectAllString(t *testing.T) {
conn := getSharedConn()
s, err := conn.SelectAllString("select * from (values ('Matthew'), ('Mark'), ('Luke'), ('John')) t")
s, err := conn.SelectAllString("select * from (values ($1), ($2), ($3), ($4)) t", "Matthew", "Mark", "Luke", "John")
if err != nil {
t.Fatal("Unable to select all strings: " + err.Error())
}
@ -227,7 +227,7 @@ func TestSelectAllString(t *testing.T) {
func TestSelectAllInt64(t *testing.T) {
conn := getSharedConn()
i, err := conn.SelectAllInt64("select * from (values (1), (2)) t")
i, err := conn.SelectAllInt64("select * from (values ($1), ($2)) t", 1, 2)
if err != nil {
t.Fatal("Unable to select all int64: " + err.Error())
}
@ -250,7 +250,7 @@ func TestSelectAllInt64(t *testing.T) {
func TestSelectAllInt32(t *testing.T) {
conn := getSharedConn()
i, err := conn.SelectAllInt32("select * from (values (1), (2)) t")
i, err := conn.SelectAllInt32("select * from (values ($1), ($2)) t", 1, 2)
if err != nil {
t.Fatal("Unable to select all int32: " + err.Error())
}
@ -273,7 +273,7 @@ func TestSelectAllInt32(t *testing.T) {
func TestSelectAllInt16(t *testing.T) {
conn := getSharedConn()
i, err := conn.SelectAllInt16("select * from (values (1), (2)) t")
i, err := conn.SelectAllInt16("select * from (values ($1), ($2)) t", 1, 2)
if err != nil {
t.Fatal("Unable to select all int16: " + err.Error())
}
@ -296,7 +296,7 @@ func TestSelectAllInt16(t *testing.T) {
func TestSelectAllFloat64(t *testing.T) {
conn := getSharedConn()
f, err := conn.SelectAllFloat64("select * from (values (1.23), (4.56)) t")
f, err := conn.SelectAllFloat64("select * from (values ($1), ($2)) t", 1.23, 4.56)
if err != nil {
t.Fatal("Unable to select all float64: " + err.Error())
}
@ -309,7 +309,7 @@ func TestSelectAllFloat64(t *testing.T) {
func TestSelectAllFloat32(t *testing.T) {
conn := getSharedConn()
f, err := conn.SelectAllFloat32("select * from (values (1.23), (4.56)) t")
f, err := conn.SelectAllFloat32("select * from (values ($1), ($2)) t", 1.23, 4.56)
if err != nil {
t.Fatal("Unable to select all float32: " + err.Error())
}

55
sanitize.go Normal file
View File

@ -0,0 +1,55 @@
package pgx
import (
"reflect"
"regexp"
"strconv"
"strings"
)
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
func (c *conn) QuoteString(input string) (output string) {
output = "'" + strings.Replace(input, "'", "''", -1) + "'"
return
}
func (c *conn) SanitizeSql(sql string, args ...interface{}) (output string) {
replacer := func(match string) (replacement string) {
n, _ := strconv.ParseInt(match[1:], 10, 0)
switch arg := args[n-1].(type) {
case string:
return c.QuoteString(arg)
case int:
return strconv.FormatInt(int64(arg), 10)
case int8:
return strconv.FormatInt(int64(arg), 10)
case int16:
return strconv.FormatInt(int64(arg), 10)
case int32:
return strconv.FormatInt(int64(arg), 10)
case int64:
return strconv.FormatInt(int64(arg), 10)
case uint:
return strconv.FormatUint(uint64(arg), 10)
case uint8:
return strconv.FormatUint(uint64(arg), 10)
case uint16:
return strconv.FormatUint(uint64(arg), 10)
case uint32:
return strconv.FormatUint(uint64(arg), 10)
case uint64:
return strconv.FormatUint(uint64(arg), 10)
case float32:
return strconv.FormatFloat(float64(arg), 'f', -1, 32)
case float64:
return strconv.FormatFloat(arg, 'f', -1, 64)
default:
panic("Unable to sanitize type: " + reflect.TypeOf(arg).String())
}
panic("Unreachable")
}
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
return
}

25
sanitize_test.go Normal file
View File

@ -0,0 +1,25 @@
package pgx
import (
"testing"
)
func TestQuoteString(t *testing.T) {
conn := getSharedConn()
if conn.QuoteString("test") != "'test'" {
t.Error("Failed to quote string")
}
if conn.QuoteString("Jack's") != "'Jack''s'" {
t.Error("Failed to quote and escape string with embedded quote")
}
}
func TestSanitizeSql(t *testing.T) {
conn := getSharedConn()
if conn.SanitizeSql("select $1, $2, $3", "Jack's", 42, 1.23) != "select 'Jack''s', 42, 1.23" {
t.Error("Failed to sanitize sql")
}
}