Revert making query methods automatically escape arguments.

Must now call SanitizeSql explicitly.

This was necessary because go supports variadic arguments but not
totally optional arguments. So it would require something to
always be passed in.
This commit is contained in:
Jack Christensen 2013-04-16 19:55:01 -05:00
parent 8392883350
commit cbf03821e1
2 changed files with 45 additions and 46 deletions

65
conn.go
View File

@ -90,9 +90,8 @@ func (c *conn) Close() (err error) {
return return
} }
func (c *conn) query(sql string, params []interface{}, onDataRow func(*messageReader, []fieldDescription) error) (err error) { func (c *conn) query(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) {
sanitized_sql := c.SanitizeSql(sql, params...) if err = c.sendSimpleQuery(sql); err != nil {
if err = c.sendSimpleQuery(sanitized_sql); err != nil {
return return
} }
@ -129,28 +128,28 @@ func (c *conn) query(sql string, params []interface{}, onDataRow func(*messageRe
panic("Unreachable") panic("Unreachable")
} }
func (c *conn) Query(sql string, params ...interface{}) (rows []map[string]string, err error) { func (c *conn) Query(sql string) (rows []map[string]string, err error) {
rows = make([]map[string]string, 0, 8) rows = make([]map[string]string, 0, 8)
onDataRow := func(r *messageReader, fields []fieldDescription) error { onDataRow := func(r *messageReader, fields []fieldDescription) error {
rows = append(rows, c.rxDataRow(r, fields)) rows = append(rows, c.rxDataRow(r, fields))
return nil return nil
} }
err = c.query(sql, params, onDataRow) err = c.query(sql, onDataRow)
return return
} }
func (c *conn) SelectString(sql string, params ...interface{}) (s string, err error) { func (c *conn) SelectString(sql string) (s string, err error) {
onDataRow := func(r *messageReader, _ []fieldDescription) error { onDataRow := func(r *messageReader, _ []fieldDescription) error {
s = c.rxDataRowFirstValue(r) s = c.rxDataRowFirstValue(r)
return nil return nil
} }
err = c.query(sql, params, onDataRow) err = c.query(sql, onDataRow)
return return
} }
func (c *conn) selectInt(sql string, size int, params []interface{}) (i int64, err error) { func (c *conn) selectInt(sql string, size int) (i int64, err error) {
var s string var s string
s, err = c.SelectString(sql, params...) s, err = c.SelectString(sql)
if err != nil { if err != nil {
return return
} }
@ -159,27 +158,27 @@ func (c *conn) selectInt(sql string, size int, params []interface{}) (i int64, e
return return
} }
func (c *conn) SelectInt64(sql string, params ...interface{}) (i int64, err error) { func (c *conn) SelectInt64(sql string) (i int64, err error) {
return c.selectInt(sql, 64, params) return c.selectInt(sql, 64)
} }
func (c *conn) SelectInt32(sql string, params ...interface{}) (i int32, err error) { func (c *conn) SelectInt32(sql string) (i int32, err error) {
var i64 int64 var i64 int64
i64, err = c.selectInt(sql, 32, params) i64, err = c.selectInt(sql, 32)
i = int32(i64) i = int32(i64)
return return
} }
func (c *conn) SelectInt16(sql string, params ...interface{}) (i int16, err error) { func (c *conn) SelectInt16(sql string) (i int16, err error) {
var i64 int64 var i64 int64
i64, err = c.selectInt(sql, 16, params) i64, err = c.selectInt(sql, 16)
i = int16(i64) i = int16(i64)
return return
} }
func (c *conn) selectFloat(sql string, size int, params []interface{}) (f float64, err error) { func (c *conn) selectFloat(sql string, size int) (f float64, err error) {
var s string var s string
s, err = c.SelectString(sql, params...) s, err = c.SelectString(sql)
if err != nil { if err != nil {
return return
} }
@ -188,28 +187,28 @@ func (c *conn) selectFloat(sql string, size int, params []interface{}) (f float6
return return
} }
func (c *conn) SelectFloat64(sql string, params ...interface{}) (f float64, err error) { func (c *conn) SelectFloat64(sql string) (f float64, err error) {
return c.selectFloat(sql, 64, params) return c.selectFloat(sql, 64)
} }
func (c *conn) SelectFloat32(sql string, params ...interface{}) (f float32, err error) { func (c *conn) SelectFloat32(sql string) (f float32, err error) {
var f64 float64 var f64 float64
f64, err = c.selectFloat(sql, 32, params) f64, err = c.selectFloat(sql, 32)
f = float32(f64) f = float32(f64)
return return
} }
func (c *conn) SelectAllString(sql string, params ...interface{}) (strings []string, err error) { func (c *conn) SelectAllString(sql string) (strings []string, err error) {
strings = make([]string, 0, 8) strings = make([]string, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) error { onDataRow := func(r *messageReader, _ []fieldDescription) error {
strings = append(strings, c.rxDataRowFirstValue(r)) strings = append(strings, c.rxDataRowFirstValue(r))
return nil return nil
} }
err = c.query(sql, params, onDataRow) err = c.query(sql, onDataRow)
return return
} }
func (c *conn) SelectAllInt64(sql string, params ...interface{}) (ints []int64, err error) { func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) {
ints = make([]int64, 0, 8) ints = make([]int64, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
var i int64 var i int64
@ -217,11 +216,11 @@ func (c *conn) SelectAllInt64(sql string, params ...interface{}) (ints []int64,
ints = append(ints, i) ints = append(ints, i)
return return
} }
err = c.query(sql, params, onDataRow) err = c.query(sql, onDataRow)
return return
} }
func (c *conn) SelectAllInt32(sql string, params ...interface{}) (ints []int32, err error) { func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) {
ints = make([]int32, 0, 8) ints = make([]int32, 0, 8)
onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) { onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) {
var i int64 var i int64
@ -229,11 +228,11 @@ func (c *conn) SelectAllInt32(sql string, params ...interface{}) (ints []int32,
ints = append(ints, int32(i)) ints = append(ints, int32(i))
return return
} }
err = c.query(sql, params, onDataRow) err = c.query(sql, onDataRow)
return return
} }
func (c *conn) SelectAllInt16(sql string, params ...interface{}) (ints []int16, err error) { func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) {
ints = make([]int16, 0, 8) ints = make([]int16, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
var i int64 var i int64
@ -241,11 +240,11 @@ func (c *conn) SelectAllInt16(sql string, params ...interface{}) (ints []int16,
ints = append(ints, int16(i)) ints = append(ints, int16(i))
return return
} }
err = c.query(sql, params, onDataRow) err = c.query(sql, onDataRow)
return return
} }
func (c *conn) SelectAllFloat64(sql string, params ...interface{}) (floats []float64, err error) { func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) {
floats = make([]float64, 0, 8) floats = make([]float64, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
var f float64 var f float64
@ -253,11 +252,11 @@ func (c *conn) SelectAllFloat64(sql string, params ...interface{}) (floats []flo
floats = append(floats, f) floats = append(floats, f)
return return
} }
err = c.query(sql, params, onDataRow) err = c.query(sql, onDataRow)
return return
} }
func (c *conn) SelectAllFloat32(sql string, params ...interface{}) (floats []float32, err error) { func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) {
floats = make([]float32, 0, 8) floats = make([]float32, 0, 8)
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
var f float64 var f float64
@ -265,7 +264,7 @@ func (c *conn) SelectAllFloat32(sql string, params ...interface{}) (floats []flo
floats = append(floats, float32(f)) floats = append(floats, float32(f))
return return
} }
err = c.query(sql, params, onDataRow) err = c.query(sql, onDataRow)
return return
} }

View File

@ -89,7 +89,7 @@ func TestConnectWithMD5Password(t *testing.T) {
func TestQuery(t *testing.T) { func TestQuery(t *testing.T) {
conn := getSharedConn() conn := getSharedConn()
rows, err := conn.Query("select $1 as name", "Jack") rows, err := conn.Query("select 'Jack' as name")
if err != nil { if err != nil {
t.Fatal("Query failed") t.Fatal("Query failed")
} }
@ -106,7 +106,7 @@ func TestQuery(t *testing.T) {
func TestSelectString(t *testing.T) { func TestSelectString(t *testing.T) {
conn := getSharedConn() conn := getSharedConn()
s, err := conn.SelectString("select $1", "foo") s, err := conn.SelectString("select 'foo'")
if err != nil { if err != nil {
t.Fatal("Unable to select string: " + err.Error()) t.Fatal("Unable to select string: " + err.Error())
} }
@ -119,7 +119,7 @@ func TestSelectString(t *testing.T) {
func TestSelectInt64(t *testing.T) { func TestSelectInt64(t *testing.T) {
conn := getSharedConn() conn := getSharedConn()
i, err := conn.SelectInt64("select $1", 1) i, err := conn.SelectInt64("select 1")
if err != nil { if err != nil {
t.Fatal("Unable to select int64: " + err.Error()) t.Fatal("Unable to select int64: " + err.Error())
} }
@ -142,7 +142,7 @@ func TestSelectInt64(t *testing.T) {
func TestSelectInt32(t *testing.T) { func TestSelectInt32(t *testing.T) {
conn := getSharedConn() conn := getSharedConn()
i, err := conn.SelectInt32("select $1", 1) i, err := conn.SelectInt32("select 1")
if err != nil { if err != nil {
t.Fatal("Unable to select int32: " + err.Error()) t.Fatal("Unable to select int32: " + err.Error())
} }
@ -165,7 +165,7 @@ func TestSelectInt32(t *testing.T) {
func TestSelectInt16(t *testing.T) { func TestSelectInt16(t *testing.T) {
conn := getSharedConn() conn := getSharedConn()
i, err := conn.SelectInt16("select $1", 1) i, err := conn.SelectInt16("select 1")
if err != nil { if err != nil {
t.Fatal("Unable to select int16: " + err.Error()) t.Fatal("Unable to select int16: " + err.Error())
} }
@ -188,7 +188,7 @@ func TestSelectInt16(t *testing.T) {
func TestSelectFloat64(t *testing.T) { func TestSelectFloat64(t *testing.T) {
conn := getSharedConn() conn := getSharedConn()
f, err := conn.SelectFloat64("select $1", 1.23) f, err := conn.SelectFloat64("select 1.23")
if err != nil { if err != nil {
t.Fatal("Unable to select float64: " + err.Error()) t.Fatal("Unable to select float64: " + err.Error())
} }
@ -201,7 +201,7 @@ func TestSelectFloat64(t *testing.T) {
func TestSelectFloat32(t *testing.T) { func TestSelectFloat32(t *testing.T) {
conn := getSharedConn() conn := getSharedConn()
f, err := conn.SelectFloat32("select $1", 1.23) f, err := conn.SelectFloat32("select 1.23")
if err != nil { if err != nil {
t.Fatal("Unable to select float32: " + err.Error()) t.Fatal("Unable to select float32: " + err.Error())
} }
@ -214,7 +214,7 @@ func TestSelectFloat32(t *testing.T) {
func TestSelectAllString(t *testing.T) { func TestSelectAllString(t *testing.T) {
conn := getSharedConn() conn := getSharedConn()
s, err := conn.SelectAllString("select * from (values ($1), ($2), ($3), ($4)) t", "Matthew", "Mark", "Luke", "John") s, err := conn.SelectAllString("select * from (values ('Matthew'), ('Mark'), ('Luke'), ('John')) t")
if err != nil { if err != nil {
t.Fatal("Unable to select all strings: " + err.Error()) t.Fatal("Unable to select all strings: " + err.Error())
} }
@ -227,7 +227,7 @@ func TestSelectAllString(t *testing.T) {
func TestSelectAllInt64(t *testing.T) { func TestSelectAllInt64(t *testing.T) {
conn := getSharedConn() conn := getSharedConn()
i, err := conn.SelectAllInt64("select * from (values ($1), ($2)) t", 1, 2) i, err := conn.SelectAllInt64("select * from (values (1), (2)) t")
if err != nil { if err != nil {
t.Fatal("Unable to select all int64: " + err.Error()) t.Fatal("Unable to select all int64: " + err.Error())
} }
@ -250,7 +250,7 @@ func TestSelectAllInt64(t *testing.T) {
func TestSelectAllInt32(t *testing.T) { func TestSelectAllInt32(t *testing.T) {
conn := getSharedConn() conn := getSharedConn()
i, err := conn.SelectAllInt32("select * from (values ($1), ($2)) t", 1, 2) i, err := conn.SelectAllInt32("select * from (values (1), (2)) t")
if err != nil { if err != nil {
t.Fatal("Unable to select all int32: " + err.Error()) t.Fatal("Unable to select all int32: " + err.Error())
} }
@ -273,7 +273,7 @@ func TestSelectAllInt32(t *testing.T) {
func TestSelectAllInt16(t *testing.T) { func TestSelectAllInt16(t *testing.T) {
conn := getSharedConn() conn := getSharedConn()
i, err := conn.SelectAllInt16("select * from (values ($1), ($2)) t", 1, 2) i, err := conn.SelectAllInt16("select * from (values (1), (2)) t")
if err != nil { if err != nil {
t.Fatal("Unable to select all int16: " + err.Error()) t.Fatal("Unable to select all int16: " + err.Error())
} }
@ -296,7 +296,7 @@ func TestSelectAllInt16(t *testing.T) {
func TestSelectAllFloat64(t *testing.T) { func TestSelectAllFloat64(t *testing.T) {
conn := getSharedConn() conn := getSharedConn()
f, err := conn.SelectAllFloat64("select * from (values ($1), ($2)) t", 1.23, 4.56) f, err := conn.SelectAllFloat64("select * from (values (1.23), (4.56)) t")
if err != nil { if err != nil {
t.Fatal("Unable to select all float64: " + err.Error()) t.Fatal("Unable to select all float64: " + err.Error())
} }
@ -309,7 +309,7 @@ func TestSelectAllFloat64(t *testing.T) {
func TestSelectAllFloat32(t *testing.T) { func TestSelectAllFloat32(t *testing.T) {
conn := getSharedConn() conn := getSharedConn()
f, err := conn.SelectAllFloat32("select * from (values ($1), ($2)) t", 1.23, 4.56) f, err := conn.SelectAllFloat32("select * from (values (1.23), (4.56)) t")
if err != nil { if err != nil {
t.Fatal("Unable to select all float32: " + err.Error()) t.Fatal("Unable to select all float32: " + err.Error())
} }