diff --git a/conn.go b/conn.go index 6b210043..a976e2a5 100644 --- a/conn.go +++ b/conn.go @@ -90,8 +90,9 @@ func (c *conn) Close() (err error) { return } -func (c *conn) query(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) { - if err = c.sendSimpleQuery(sql); err != nil { +func (c *conn) query(sql string, params []interface{}, onDataRow func(*messageReader, []fieldDescription) error) (err error) { + sanitized_sql := c.SanitizeSql(sql, params...) + if err = c.sendSimpleQuery(sanitized_sql); err != nil { return } @@ -128,28 +129,28 @@ func (c *conn) query(sql string, onDataRow func(*messageReader, []fieldDescripti panic("Unreachable") } -func (c *conn) Query(sql string) (rows []map[string]string, err error) { +func (c *conn) Query(sql string, params ...interface{}) (rows []map[string]string, err error) { rows = make([]map[string]string, 0, 8) onDataRow := func(r *messageReader, fields []fieldDescription) error { rows = append(rows, c.rxDataRow(r, fields)) return nil } - err = c.query(sql, onDataRow) + err = c.query(sql, params, onDataRow) return } -func (c *conn) SelectString(sql string) (s string, err error) { +func (c *conn) SelectString(sql string, params ...interface{}) (s string, err error) { onDataRow := func(r *messageReader, _ []fieldDescription) error { s = c.rxDataRowFirstValue(r) return nil } - err = c.query(sql, onDataRow) + err = c.query(sql, params, onDataRow) return } -func (c *conn) selectInt(sql string, size int) (i int64, err error) { +func (c *conn) selectInt(sql string, size int, params []interface{}) (i int64, err error) { var s string - s, err = c.SelectString(sql) + s, err = c.SelectString(sql, params...) if err != nil { return } @@ -158,27 +159,27 @@ func (c *conn) selectInt(sql string, size int) (i int64, err error) { return } -func (c *conn) SelectInt64(sql string) (i int64, err error) { - return c.selectInt(sql, 64) +func (c *conn) SelectInt64(sql string, params ...interface{}) (i int64, err error) { + return c.selectInt(sql, 64, params) } -func (c *conn) SelectInt32(sql string) (i int32, err error) { +func (c *conn) SelectInt32(sql string, params ...interface{}) (i int32, err error) { var i64 int64 - i64, err = c.selectInt(sql, 32) + i64, err = c.selectInt(sql, 32, params) i = int32(i64) return } -func (c *conn) SelectInt16(sql string) (i int16, err error) { +func (c *conn) SelectInt16(sql string, params ...interface{}) (i int16, err error) { var i64 int64 - i64, err = c.selectInt(sql, 16) + i64, err = c.selectInt(sql, 16, params) i = int16(i64) return } -func (c *conn) selectFloat(sql string, size int) (f float64, err error) { +func (c *conn) selectFloat(sql string, size int, params []interface{}) (f float64, err error) { var s string - s, err = c.SelectString(sql) + s, err = c.SelectString(sql, params...) if err != nil { return } @@ -187,28 +188,28 @@ func (c *conn) selectFloat(sql string, size int) (f float64, err error) { return } -func (c *conn) SelectFloat64(sql string) (f float64, err error) { - return c.selectFloat(sql, 64) +func (c *conn) SelectFloat64(sql string, params ...interface{}) (f float64, err error) { + return c.selectFloat(sql, 64, params) } -func (c *conn) SelectFloat32(sql string) (f float32, err error) { +func (c *conn) SelectFloat32(sql string, params ...interface{}) (f float32, err error) { var f64 float64 - f64, err = c.selectFloat(sql, 32) + f64, err = c.selectFloat(sql, 32, params) f = float32(f64) return } -func (c *conn) SelectAllString(sql string) (strings []string, err error) { +func (c *conn) SelectAllString(sql string, params ...interface{}) (strings []string, err error) { strings = make([]string, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) error { strings = append(strings, c.rxDataRowFirstValue(r)) return nil } - err = c.query(sql, onDataRow) + err = c.query(sql, params, onDataRow) return } -func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) { +func (c *conn) SelectAllInt64(sql string, params ...interface{}) (ints []int64, err error) { ints = make([]int64, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { var i int64 @@ -216,11 +217,11 @@ func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) { ints = append(ints, i) return } - err = c.query(sql, onDataRow) + err = c.query(sql, params, onDataRow) return } -func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) { +func (c *conn) SelectAllInt32(sql string, params ...interface{}) (ints []int32, err error) { ints = make([]int32, 0, 8) onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) { var i int64 @@ -228,11 +229,11 @@ func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) { ints = append(ints, int32(i)) return } - err = c.query(sql, onDataRow) + err = c.query(sql, params, onDataRow) return } -func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) { +func (c *conn) SelectAllInt16(sql string, params ...interface{}) (ints []int16, err error) { ints = make([]int16, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { var i int64 @@ -240,11 +241,11 @@ func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) { ints = append(ints, int16(i)) return } - err = c.query(sql, onDataRow) + err = c.query(sql, params, onDataRow) return } -func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) { +func (c *conn) SelectAllFloat64(sql string, params ...interface{}) (floats []float64, err error) { floats = make([]float64, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { var f float64 @@ -252,11 +253,11 @@ func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) { floats = append(floats, f) return } - err = c.query(sql, onDataRow) + err = c.query(sql, params, onDataRow) return } -func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) { +func (c *conn) SelectAllFloat32(sql string, params ...interface{}) (floats []float32, err error) { floats = make([]float32, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { var f float64 @@ -264,7 +265,7 @@ func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) { floats = append(floats, float32(f)) return } - err = c.query(sql, onDataRow) + err = c.query(sql, params, onDataRow) return } diff --git a/conn_test.go b/conn_test.go index 4fc4ea9a..3ddb6ac3 100644 --- a/conn_test.go +++ b/conn_test.go @@ -89,7 +89,7 @@ func TestConnectWithMD5Password(t *testing.T) { func TestQuery(t *testing.T) { conn := getSharedConn() - rows, err := conn.Query("select 'Jack' as name") + rows, err := conn.Query("select $1 as name", "Jack") if err != nil { t.Fatal("Query failed") } @@ -106,7 +106,7 @@ func TestQuery(t *testing.T) { func TestSelectString(t *testing.T) { conn := getSharedConn() - s, err := conn.SelectString("select 'foo'") + s, err := conn.SelectString("select $1", "foo") if err != nil { t.Fatal("Unable to select string: " + err.Error()) } @@ -119,7 +119,7 @@ func TestSelectString(t *testing.T) { func TestSelectInt64(t *testing.T) { conn := getSharedConn() - i, err := conn.SelectInt64("select 1") + i, err := conn.SelectInt64("select $1", 1) if err != nil { t.Fatal("Unable to select int64: " + err.Error()) } @@ -142,7 +142,7 @@ func TestSelectInt64(t *testing.T) { func TestSelectInt32(t *testing.T) { conn := getSharedConn() - i, err := conn.SelectInt32("select 1") + i, err := conn.SelectInt32("select $1", 1) if err != nil { t.Fatal("Unable to select int32: " + err.Error()) } @@ -165,7 +165,7 @@ func TestSelectInt32(t *testing.T) { func TestSelectInt16(t *testing.T) { conn := getSharedConn() - i, err := conn.SelectInt16("select 1") + i, err := conn.SelectInt16("select $1", 1) if err != nil { t.Fatal("Unable to select int16: " + err.Error()) } @@ -188,7 +188,7 @@ func TestSelectInt16(t *testing.T) { func TestSelectFloat64(t *testing.T) { conn := getSharedConn() - f, err := conn.SelectFloat64("select 1.23") + f, err := conn.SelectFloat64("select $1", 1.23) if err != nil { t.Fatal("Unable to select float64: " + err.Error()) } @@ -201,7 +201,7 @@ func TestSelectFloat64(t *testing.T) { func TestSelectFloat32(t *testing.T) { conn := getSharedConn() - f, err := conn.SelectFloat32("select 1.23") + f, err := conn.SelectFloat32("select $1", 1.23) if err != nil { t.Fatal("Unable to select float32: " + err.Error()) } @@ -214,7 +214,7 @@ func TestSelectFloat32(t *testing.T) { func TestSelectAllString(t *testing.T) { conn := getSharedConn() - s, err := conn.SelectAllString("select * from (values ('Matthew'), ('Mark'), ('Luke'), ('John')) t") + s, err := conn.SelectAllString("select * from (values ($1), ($2), ($3), ($4)) t", "Matthew", "Mark", "Luke", "John") if err != nil { t.Fatal("Unable to select all strings: " + err.Error()) } @@ -227,7 +227,7 @@ func TestSelectAllString(t *testing.T) { func TestSelectAllInt64(t *testing.T) { conn := getSharedConn() - i, err := conn.SelectAllInt64("select * from (values (1), (2)) t") + i, err := conn.SelectAllInt64("select * from (values ($1), ($2)) t", 1, 2) if err != nil { t.Fatal("Unable to select all int64: " + err.Error()) } @@ -250,7 +250,7 @@ func TestSelectAllInt64(t *testing.T) { func TestSelectAllInt32(t *testing.T) { conn := getSharedConn() - i, err := conn.SelectAllInt32("select * from (values (1), (2)) t") + i, err := conn.SelectAllInt32("select * from (values ($1), ($2)) t", 1, 2) if err != nil { t.Fatal("Unable to select all int32: " + err.Error()) } @@ -273,7 +273,7 @@ func TestSelectAllInt32(t *testing.T) { func TestSelectAllInt16(t *testing.T) { conn := getSharedConn() - i, err := conn.SelectAllInt16("select * from (values (1), (2)) t") + i, err := conn.SelectAllInt16("select * from (values ($1), ($2)) t", 1, 2) if err != nil { t.Fatal("Unable to select all int16: " + err.Error()) } @@ -296,7 +296,7 @@ func TestSelectAllInt16(t *testing.T) { func TestSelectAllFloat64(t *testing.T) { conn := getSharedConn() - f, err := conn.SelectAllFloat64("select * from (values (1.23), (4.56)) t") + f, err := conn.SelectAllFloat64("select * from (values ($1), ($2)) t", 1.23, 4.56) if err != nil { t.Fatal("Unable to select all float64: " + err.Error()) } @@ -309,7 +309,7 @@ func TestSelectAllFloat64(t *testing.T) { func TestSelectAllFloat32(t *testing.T) { conn := getSharedConn() - f, err := conn.SelectAllFloat32("select * from (values (1.23), (4.56)) t") + f, err := conn.SelectAllFloat32("select * from (values ($1), ($2)) t", 1.23, 4.56) if err != nil { t.Fatal("Unable to select all float32: " + err.Error()) } diff --git a/sanitize.go b/sanitize.go new file mode 100644 index 00000000..871b00cd --- /dev/null +++ b/sanitize.go @@ -0,0 +1,55 @@ +package pgx + +import ( + "reflect" + "regexp" + "strconv" + "strings" +) + +var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`) + +func (c *conn) QuoteString(input string) (output string) { + output = "'" + strings.Replace(input, "'", "''", -1) + "'" + return +} + +func (c *conn) SanitizeSql(sql string, args ...interface{}) (output string) { + replacer := func(match string) (replacement string) { + n, _ := strconv.ParseInt(match[1:], 10, 0) + switch arg := args[n-1].(type) { + case string: + return c.QuoteString(arg) + case int: + return strconv.FormatInt(int64(arg), 10) + case int8: + return strconv.FormatInt(int64(arg), 10) + case int16: + return strconv.FormatInt(int64(arg), 10) + case int32: + return strconv.FormatInt(int64(arg), 10) + case int64: + return strconv.FormatInt(int64(arg), 10) + case uint: + return strconv.FormatUint(uint64(arg), 10) + case uint8: + return strconv.FormatUint(uint64(arg), 10) + case uint16: + return strconv.FormatUint(uint64(arg), 10) + case uint32: + return strconv.FormatUint(uint64(arg), 10) + case uint64: + return strconv.FormatUint(uint64(arg), 10) + case float32: + return strconv.FormatFloat(float64(arg), 'f', -1, 32) + case float64: + return strconv.FormatFloat(arg, 'f', -1, 64) + default: + panic("Unable to sanitize type: " + reflect.TypeOf(arg).String()) + } + panic("Unreachable") + } + + output = literalPattern.ReplaceAllStringFunc(sql, replacer) + return +} diff --git a/sanitize_test.go b/sanitize_test.go new file mode 100644 index 00000000..06282eb7 --- /dev/null +++ b/sanitize_test.go @@ -0,0 +1,25 @@ +package pgx + +import ( + "testing" +) + +func TestQuoteString(t *testing.T) { + conn := getSharedConn() + + if conn.QuoteString("test") != "'test'" { + t.Error("Failed to quote string") + } + + if conn.QuoteString("Jack's") != "'Jack''s'" { + t.Error("Failed to quote and escape string with embedded quote") + } +} + +func TestSanitizeSql(t *testing.T) { + conn := getSharedConn() + + if conn.SanitizeSql("select $1, $2, $3", "Jack's", 42, 1.23) != "select 'Jack''s', 42, 1.23" { + t.Error("Failed to sanitize sql") + } +}