mirror of
https://github.com/jackc/pgx.git
synced 2025-05-28 18:22:15 +00:00
parent
2ed4f46454
commit
8392883350
65
conn.go
65
conn.go
@ -90,8 +90,9 @@ func (c *conn) Close() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) query(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) {
|
||||
if err = c.sendSimpleQuery(sql); err != nil {
|
||||
func (c *conn) query(sql string, params []interface{}, onDataRow func(*messageReader, []fieldDescription) error) (err error) {
|
||||
sanitized_sql := c.SanitizeSql(sql, params...)
|
||||
if err = c.sendSimpleQuery(sanitized_sql); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -128,28 +129,28 @@ func (c *conn) query(sql string, onDataRow func(*messageReader, []fieldDescripti
|
||||
panic("Unreachable")
|
||||
}
|
||||
|
||||
func (c *conn) Query(sql string) (rows []map[string]string, err error) {
|
||||
func (c *conn) Query(sql string, params ...interface{}) (rows []map[string]string, err error) {
|
||||
rows = make([]map[string]string, 0, 8)
|
||||
onDataRow := func(r *messageReader, fields []fieldDescription) error {
|
||||
rows = append(rows, c.rxDataRow(r, fields))
|
||||
return nil
|
||||
}
|
||||
err = c.query(sql, onDataRow)
|
||||
err = c.query(sql, params, onDataRow)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectString(sql string) (s string, err error) {
|
||||
func (c *conn) SelectString(sql string, params ...interface{}) (s string, err error) {
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) error {
|
||||
s = c.rxDataRowFirstValue(r)
|
||||
return nil
|
||||
}
|
||||
err = c.query(sql, onDataRow)
|
||||
err = c.query(sql, params, onDataRow)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) selectInt(sql string, size int) (i int64, err error) {
|
||||
func (c *conn) selectInt(sql string, size int, params []interface{}) (i int64, err error) {
|
||||
var s string
|
||||
s, err = c.SelectString(sql)
|
||||
s, err = c.SelectString(sql, params...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -158,27 +159,27 @@ func (c *conn) selectInt(sql string, size int) (i int64, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectInt64(sql string) (i int64, err error) {
|
||||
return c.selectInt(sql, 64)
|
||||
func (c *conn) SelectInt64(sql string, params ...interface{}) (i int64, err error) {
|
||||
return c.selectInt(sql, 64, params)
|
||||
}
|
||||
|
||||
func (c *conn) SelectInt32(sql string) (i int32, err error) {
|
||||
func (c *conn) SelectInt32(sql string, params ...interface{}) (i int32, err error) {
|
||||
var i64 int64
|
||||
i64, err = c.selectInt(sql, 32)
|
||||
i64, err = c.selectInt(sql, 32, params)
|
||||
i = int32(i64)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectInt16(sql string) (i int16, err error) {
|
||||
func (c *conn) SelectInt16(sql string, params ...interface{}) (i int16, err error) {
|
||||
var i64 int64
|
||||
i64, err = c.selectInt(sql, 16)
|
||||
i64, err = c.selectInt(sql, 16, params)
|
||||
i = int16(i64)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) selectFloat(sql string, size int) (f float64, err error) {
|
||||
func (c *conn) selectFloat(sql string, size int, params []interface{}) (f float64, err error) {
|
||||
var s string
|
||||
s, err = c.SelectString(sql)
|
||||
s, err = c.SelectString(sql, params...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -187,28 +188,28 @@ func (c *conn) selectFloat(sql string, size int) (f float64, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectFloat64(sql string) (f float64, err error) {
|
||||
return c.selectFloat(sql, 64)
|
||||
func (c *conn) SelectFloat64(sql string, params ...interface{}) (f float64, err error) {
|
||||
return c.selectFloat(sql, 64, params)
|
||||
}
|
||||
|
||||
func (c *conn) SelectFloat32(sql string) (f float32, err error) {
|
||||
func (c *conn) SelectFloat32(sql string, params ...interface{}) (f float32, err error) {
|
||||
var f64 float64
|
||||
f64, err = c.selectFloat(sql, 32)
|
||||
f64, err = c.selectFloat(sql, 32, params)
|
||||
f = float32(f64)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectAllString(sql string) (strings []string, err error) {
|
||||
func (c *conn) SelectAllString(sql string, params ...interface{}) (strings []string, err error) {
|
||||
strings = make([]string, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) error {
|
||||
strings = append(strings, c.rxDataRowFirstValue(r))
|
||||
return nil
|
||||
}
|
||||
err = c.query(sql, onDataRow)
|
||||
err = c.query(sql, params, onDataRow)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) {
|
||||
func (c *conn) SelectAllInt64(sql string, params ...interface{}) (ints []int64, err error) {
|
||||
ints = make([]int64, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||
var i int64
|
||||
@ -216,11 +217,11 @@ func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) {
|
||||
ints = append(ints, i)
|
||||
return
|
||||
}
|
||||
err = c.query(sql, onDataRow)
|
||||
err = c.query(sql, params, onDataRow)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) {
|
||||
func (c *conn) SelectAllInt32(sql string, params ...interface{}) (ints []int32, err error) {
|
||||
ints = make([]int32, 0, 8)
|
||||
onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) {
|
||||
var i int64
|
||||
@ -228,11 +229,11 @@ func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) {
|
||||
ints = append(ints, int32(i))
|
||||
return
|
||||
}
|
||||
err = c.query(sql, onDataRow)
|
||||
err = c.query(sql, params, onDataRow)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) {
|
||||
func (c *conn) SelectAllInt16(sql string, params ...interface{}) (ints []int16, err error) {
|
||||
ints = make([]int16, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||
var i int64
|
||||
@ -240,11 +241,11 @@ func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) {
|
||||
ints = append(ints, int16(i))
|
||||
return
|
||||
}
|
||||
err = c.query(sql, onDataRow)
|
||||
err = c.query(sql, params, onDataRow)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) {
|
||||
func (c *conn) SelectAllFloat64(sql string, params ...interface{}) (floats []float64, err error) {
|
||||
floats = make([]float64, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||
var f float64
|
||||
@ -252,11 +253,11 @@ func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) {
|
||||
floats = append(floats, f)
|
||||
return
|
||||
}
|
||||
err = c.query(sql, onDataRow)
|
||||
err = c.query(sql, params, onDataRow)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) {
|
||||
func (c *conn) SelectAllFloat32(sql string, params ...interface{}) (floats []float32, err error) {
|
||||
floats = make([]float32, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||
var f float64
|
||||
@ -264,7 +265,7 @@ func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) {
|
||||
floats = append(floats, float32(f))
|
||||
return
|
||||
}
|
||||
err = c.query(sql, onDataRow)
|
||||
err = c.query(sql, params, onDataRow)
|
||||
return
|
||||
}
|
||||
|
||||
|
26
conn_test.go
26
conn_test.go
@ -89,7 +89,7 @@ func TestConnectWithMD5Password(t *testing.T) {
|
||||
func TestQuery(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
rows, err := conn.Query("select 'Jack' as name")
|
||||
rows, err := conn.Query("select $1 as name", "Jack")
|
||||
if err != nil {
|
||||
t.Fatal("Query failed")
|
||||
}
|
||||
@ -106,7 +106,7 @@ func TestQuery(t *testing.T) {
|
||||
func TestSelectString(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
s, err := conn.SelectString("select 'foo'")
|
||||
s, err := conn.SelectString("select $1", "foo")
|
||||
if err != nil {
|
||||
t.Fatal("Unable to select string: " + err.Error())
|
||||
}
|
||||
@ -119,7 +119,7 @@ func TestSelectString(t *testing.T) {
|
||||
func TestSelectInt64(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
i, err := conn.SelectInt64("select 1")
|
||||
i, err := conn.SelectInt64("select $1", 1)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to select int64: " + err.Error())
|
||||
}
|
||||
@ -142,7 +142,7 @@ func TestSelectInt64(t *testing.T) {
|
||||
func TestSelectInt32(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
i, err := conn.SelectInt32("select 1")
|
||||
i, err := conn.SelectInt32("select $1", 1)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to select int32: " + err.Error())
|
||||
}
|
||||
@ -165,7 +165,7 @@ func TestSelectInt32(t *testing.T) {
|
||||
func TestSelectInt16(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
i, err := conn.SelectInt16("select 1")
|
||||
i, err := conn.SelectInt16("select $1", 1)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to select int16: " + err.Error())
|
||||
}
|
||||
@ -188,7 +188,7 @@ func TestSelectInt16(t *testing.T) {
|
||||
func TestSelectFloat64(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
f, err := conn.SelectFloat64("select 1.23")
|
||||
f, err := conn.SelectFloat64("select $1", 1.23)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to select float64: " + err.Error())
|
||||
}
|
||||
@ -201,7 +201,7 @@ func TestSelectFloat64(t *testing.T) {
|
||||
func TestSelectFloat32(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
f, err := conn.SelectFloat32("select 1.23")
|
||||
f, err := conn.SelectFloat32("select $1", 1.23)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to select float32: " + err.Error())
|
||||
}
|
||||
@ -214,7 +214,7 @@ func TestSelectFloat32(t *testing.T) {
|
||||
func TestSelectAllString(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
s, err := conn.SelectAllString("select * from (values ('Matthew'), ('Mark'), ('Luke'), ('John')) t")
|
||||
s, err := conn.SelectAllString("select * from (values ($1), ($2), ($3), ($4)) t", "Matthew", "Mark", "Luke", "John")
|
||||
if err != nil {
|
||||
t.Fatal("Unable to select all strings: " + err.Error())
|
||||
}
|
||||
@ -227,7 +227,7 @@ func TestSelectAllString(t *testing.T) {
|
||||
func TestSelectAllInt64(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
i, err := conn.SelectAllInt64("select * from (values (1), (2)) t")
|
||||
i, err := conn.SelectAllInt64("select * from (values ($1), ($2)) t", 1, 2)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to select all int64: " + err.Error())
|
||||
}
|
||||
@ -250,7 +250,7 @@ func TestSelectAllInt64(t *testing.T) {
|
||||
func TestSelectAllInt32(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
i, err := conn.SelectAllInt32("select * from (values (1), (2)) t")
|
||||
i, err := conn.SelectAllInt32("select * from (values ($1), ($2)) t", 1, 2)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to select all int32: " + err.Error())
|
||||
}
|
||||
@ -273,7 +273,7 @@ func TestSelectAllInt32(t *testing.T) {
|
||||
func TestSelectAllInt16(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
i, err := conn.SelectAllInt16("select * from (values (1), (2)) t")
|
||||
i, err := conn.SelectAllInt16("select * from (values ($1), ($2)) t", 1, 2)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to select all int16: " + err.Error())
|
||||
}
|
||||
@ -296,7 +296,7 @@ func TestSelectAllInt16(t *testing.T) {
|
||||
func TestSelectAllFloat64(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
f, err := conn.SelectAllFloat64("select * from (values (1.23), (4.56)) t")
|
||||
f, err := conn.SelectAllFloat64("select * from (values ($1), ($2)) t", 1.23, 4.56)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to select all float64: " + err.Error())
|
||||
}
|
||||
@ -309,7 +309,7 @@ func TestSelectAllFloat64(t *testing.T) {
|
||||
func TestSelectAllFloat32(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
f, err := conn.SelectAllFloat32("select * from (values (1.23), (4.56)) t")
|
||||
f, err := conn.SelectAllFloat32("select * from (values ($1), ($2)) t", 1.23, 4.56)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to select all float32: " + err.Error())
|
||||
}
|
||||
|
55
sanitize.go
Normal file
55
sanitize.go
Normal file
@ -0,0 +1,55 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
|
||||
|
||||
func (c *conn) QuoteString(input string) (output string) {
|
||||
output = "'" + strings.Replace(input, "'", "''", -1) + "'"
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) SanitizeSql(sql string, args ...interface{}) (output string) {
|
||||
replacer := func(match string) (replacement string) {
|
||||
n, _ := strconv.ParseInt(match[1:], 10, 0)
|
||||
switch arg := args[n-1].(type) {
|
||||
case string:
|
||||
return c.QuoteString(arg)
|
||||
case int:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int64:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint64:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(arg), 'f', -1, 32)
|
||||
case float64:
|
||||
return strconv.FormatFloat(arg, 'f', -1, 64)
|
||||
default:
|
||||
panic("Unable to sanitize type: " + reflect.TypeOf(arg).String())
|
||||
}
|
||||
panic("Unreachable")
|
||||
}
|
||||
|
||||
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
|
||||
return
|
||||
}
|
25
sanitize_test.go
Normal file
25
sanitize_test.go
Normal file
@ -0,0 +1,25 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestQuoteString(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
if conn.QuoteString("test") != "'test'" {
|
||||
t.Error("Failed to quote string")
|
||||
}
|
||||
|
||||
if conn.QuoteString("Jack's") != "'Jack''s'" {
|
||||
t.Error("Failed to quote and escape string with embedded quote")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSql(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
|
||||
if conn.SanitizeSql("select $1, $2, $3", "Jack's", 42, 1.23) != "select 'Jack''s', 42, 1.23" {
|
||||
t.Error("Failed to sanitize sql")
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user