mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
parent
2ed4f46454
commit
8392883350
65
conn.go
65
conn.go
@ -90,8 +90,9 @@ func (c *conn) Close() (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) query(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) {
|
func (c *conn) query(sql string, params []interface{}, onDataRow func(*messageReader, []fieldDescription) error) (err error) {
|
||||||
if err = c.sendSimpleQuery(sql); err != nil {
|
sanitized_sql := c.SanitizeSql(sql, params...)
|
||||||
|
if err = c.sendSimpleQuery(sanitized_sql); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -128,28 +129,28 @@ func (c *conn) query(sql string, onDataRow func(*messageReader, []fieldDescripti
|
|||||||
panic("Unreachable")
|
panic("Unreachable")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) Query(sql string) (rows []map[string]string, err error) {
|
func (c *conn) Query(sql string, params ...interface{}) (rows []map[string]string, err error) {
|
||||||
rows = make([]map[string]string, 0, 8)
|
rows = make([]map[string]string, 0, 8)
|
||||||
onDataRow := func(r *messageReader, fields []fieldDescription) error {
|
onDataRow := func(r *messageReader, fields []fieldDescription) error {
|
||||||
rows = append(rows, c.rxDataRow(r, fields))
|
rows = append(rows, c.rxDataRow(r, fields))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
err = c.query(sql, onDataRow)
|
err = c.query(sql, params, onDataRow)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SelectString(sql string) (s string, err error) {
|
func (c *conn) SelectString(sql string, params ...interface{}) (s string, err error) {
|
||||||
onDataRow := func(r *messageReader, _ []fieldDescription) error {
|
onDataRow := func(r *messageReader, _ []fieldDescription) error {
|
||||||
s = c.rxDataRowFirstValue(r)
|
s = c.rxDataRowFirstValue(r)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
err = c.query(sql, onDataRow)
|
err = c.query(sql, params, onDataRow)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) selectInt(sql string, size int) (i int64, err error) {
|
func (c *conn) selectInt(sql string, size int, params []interface{}) (i int64, err error) {
|
||||||
var s string
|
var s string
|
||||||
s, err = c.SelectString(sql)
|
s, err = c.SelectString(sql, params...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -158,27 +159,27 @@ func (c *conn) selectInt(sql string, size int) (i int64, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SelectInt64(sql string) (i int64, err error) {
|
func (c *conn) SelectInt64(sql string, params ...interface{}) (i int64, err error) {
|
||||||
return c.selectInt(sql, 64)
|
return c.selectInt(sql, 64, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SelectInt32(sql string) (i int32, err error) {
|
func (c *conn) SelectInt32(sql string, params ...interface{}) (i int32, err error) {
|
||||||
var i64 int64
|
var i64 int64
|
||||||
i64, err = c.selectInt(sql, 32)
|
i64, err = c.selectInt(sql, 32, params)
|
||||||
i = int32(i64)
|
i = int32(i64)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SelectInt16(sql string) (i int16, err error) {
|
func (c *conn) SelectInt16(sql string, params ...interface{}) (i int16, err error) {
|
||||||
var i64 int64
|
var i64 int64
|
||||||
i64, err = c.selectInt(sql, 16)
|
i64, err = c.selectInt(sql, 16, params)
|
||||||
i = int16(i64)
|
i = int16(i64)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) selectFloat(sql string, size int) (f float64, err error) {
|
func (c *conn) selectFloat(sql string, size int, params []interface{}) (f float64, err error) {
|
||||||
var s string
|
var s string
|
||||||
s, err = c.SelectString(sql)
|
s, err = c.SelectString(sql, params...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -187,28 +188,28 @@ func (c *conn) selectFloat(sql string, size int) (f float64, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SelectFloat64(sql string) (f float64, err error) {
|
func (c *conn) SelectFloat64(sql string, params ...interface{}) (f float64, err error) {
|
||||||
return c.selectFloat(sql, 64)
|
return c.selectFloat(sql, 64, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SelectFloat32(sql string) (f float32, err error) {
|
func (c *conn) SelectFloat32(sql string, params ...interface{}) (f float32, err error) {
|
||||||
var f64 float64
|
var f64 float64
|
||||||
f64, err = c.selectFloat(sql, 32)
|
f64, err = c.selectFloat(sql, 32, params)
|
||||||
f = float32(f64)
|
f = float32(f64)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SelectAllString(sql string) (strings []string, err error) {
|
func (c *conn) SelectAllString(sql string, params ...interface{}) (strings []string, err error) {
|
||||||
strings = make([]string, 0, 8)
|
strings = make([]string, 0, 8)
|
||||||
onDataRow := func(r *messageReader, _ []fieldDescription) error {
|
onDataRow := func(r *messageReader, _ []fieldDescription) error {
|
||||||
strings = append(strings, c.rxDataRowFirstValue(r))
|
strings = append(strings, c.rxDataRowFirstValue(r))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
err = c.query(sql, onDataRow)
|
err = c.query(sql, params, onDataRow)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) {
|
func (c *conn) SelectAllInt64(sql string, params ...interface{}) (ints []int64, err error) {
|
||||||
ints = make([]int64, 0, 8)
|
ints = make([]int64, 0, 8)
|
||||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||||
var i int64
|
var i int64
|
||||||
@ -216,11 +217,11 @@ func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) {
|
|||||||
ints = append(ints, i)
|
ints = append(ints, i)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = c.query(sql, onDataRow)
|
err = c.query(sql, params, onDataRow)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) {
|
func (c *conn) SelectAllInt32(sql string, params ...interface{}) (ints []int32, err error) {
|
||||||
ints = make([]int32, 0, 8)
|
ints = make([]int32, 0, 8)
|
||||||
onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) {
|
onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) {
|
||||||
var i int64
|
var i int64
|
||||||
@ -228,11 +229,11 @@ func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) {
|
|||||||
ints = append(ints, int32(i))
|
ints = append(ints, int32(i))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = c.query(sql, onDataRow)
|
err = c.query(sql, params, onDataRow)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) {
|
func (c *conn) SelectAllInt16(sql string, params ...interface{}) (ints []int16, err error) {
|
||||||
ints = make([]int16, 0, 8)
|
ints = make([]int16, 0, 8)
|
||||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||||
var i int64
|
var i int64
|
||||||
@ -240,11 +241,11 @@ func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) {
|
|||||||
ints = append(ints, int16(i))
|
ints = append(ints, int16(i))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = c.query(sql, onDataRow)
|
err = c.query(sql, params, onDataRow)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) {
|
func (c *conn) SelectAllFloat64(sql string, params ...interface{}) (floats []float64, err error) {
|
||||||
floats = make([]float64, 0, 8)
|
floats = make([]float64, 0, 8)
|
||||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||||
var f float64
|
var f float64
|
||||||
@ -252,11 +253,11 @@ func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) {
|
|||||||
floats = append(floats, f)
|
floats = append(floats, f)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = c.query(sql, onDataRow)
|
err = c.query(sql, params, onDataRow)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) {
|
func (c *conn) SelectAllFloat32(sql string, params ...interface{}) (floats []float32, err error) {
|
||||||
floats = make([]float32, 0, 8)
|
floats = make([]float32, 0, 8)
|
||||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||||
var f float64
|
var f float64
|
||||||
@ -264,7 +265,7 @@ func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) {
|
|||||||
floats = append(floats, float32(f))
|
floats = append(floats, float32(f))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = c.query(sql, onDataRow)
|
err = c.query(sql, params, onDataRow)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
26
conn_test.go
26
conn_test.go
@ -89,7 +89,7 @@ func TestConnectWithMD5Password(t *testing.T) {
|
|||||||
func TestQuery(t *testing.T) {
|
func TestQuery(t *testing.T) {
|
||||||
conn := getSharedConn()
|
conn := getSharedConn()
|
||||||
|
|
||||||
rows, err := conn.Query("select 'Jack' as name")
|
rows, err := conn.Query("select $1 as name", "Jack")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Query failed")
|
t.Fatal("Query failed")
|
||||||
}
|
}
|
||||||
@ -106,7 +106,7 @@ func TestQuery(t *testing.T) {
|
|||||||
func TestSelectString(t *testing.T) {
|
func TestSelectString(t *testing.T) {
|
||||||
conn := getSharedConn()
|
conn := getSharedConn()
|
||||||
|
|
||||||
s, err := conn.SelectString("select 'foo'")
|
s, err := conn.SelectString("select $1", "foo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to select string: " + err.Error())
|
t.Fatal("Unable to select string: " + err.Error())
|
||||||
}
|
}
|
||||||
@ -119,7 +119,7 @@ func TestSelectString(t *testing.T) {
|
|||||||
func TestSelectInt64(t *testing.T) {
|
func TestSelectInt64(t *testing.T) {
|
||||||
conn := getSharedConn()
|
conn := getSharedConn()
|
||||||
|
|
||||||
i, err := conn.SelectInt64("select 1")
|
i, err := conn.SelectInt64("select $1", 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to select int64: " + err.Error())
|
t.Fatal("Unable to select int64: " + err.Error())
|
||||||
}
|
}
|
||||||
@ -142,7 +142,7 @@ func TestSelectInt64(t *testing.T) {
|
|||||||
func TestSelectInt32(t *testing.T) {
|
func TestSelectInt32(t *testing.T) {
|
||||||
conn := getSharedConn()
|
conn := getSharedConn()
|
||||||
|
|
||||||
i, err := conn.SelectInt32("select 1")
|
i, err := conn.SelectInt32("select $1", 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to select int32: " + err.Error())
|
t.Fatal("Unable to select int32: " + err.Error())
|
||||||
}
|
}
|
||||||
@ -165,7 +165,7 @@ func TestSelectInt32(t *testing.T) {
|
|||||||
func TestSelectInt16(t *testing.T) {
|
func TestSelectInt16(t *testing.T) {
|
||||||
conn := getSharedConn()
|
conn := getSharedConn()
|
||||||
|
|
||||||
i, err := conn.SelectInt16("select 1")
|
i, err := conn.SelectInt16("select $1", 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to select int16: " + err.Error())
|
t.Fatal("Unable to select int16: " + err.Error())
|
||||||
}
|
}
|
||||||
@ -188,7 +188,7 @@ func TestSelectInt16(t *testing.T) {
|
|||||||
func TestSelectFloat64(t *testing.T) {
|
func TestSelectFloat64(t *testing.T) {
|
||||||
conn := getSharedConn()
|
conn := getSharedConn()
|
||||||
|
|
||||||
f, err := conn.SelectFloat64("select 1.23")
|
f, err := conn.SelectFloat64("select $1", 1.23)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to select float64: " + err.Error())
|
t.Fatal("Unable to select float64: " + err.Error())
|
||||||
}
|
}
|
||||||
@ -201,7 +201,7 @@ func TestSelectFloat64(t *testing.T) {
|
|||||||
func TestSelectFloat32(t *testing.T) {
|
func TestSelectFloat32(t *testing.T) {
|
||||||
conn := getSharedConn()
|
conn := getSharedConn()
|
||||||
|
|
||||||
f, err := conn.SelectFloat32("select 1.23")
|
f, err := conn.SelectFloat32("select $1", 1.23)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to select float32: " + err.Error())
|
t.Fatal("Unable to select float32: " + err.Error())
|
||||||
}
|
}
|
||||||
@ -214,7 +214,7 @@ func TestSelectFloat32(t *testing.T) {
|
|||||||
func TestSelectAllString(t *testing.T) {
|
func TestSelectAllString(t *testing.T) {
|
||||||
conn := getSharedConn()
|
conn := getSharedConn()
|
||||||
|
|
||||||
s, err := conn.SelectAllString("select * from (values ('Matthew'), ('Mark'), ('Luke'), ('John')) t")
|
s, err := conn.SelectAllString("select * from (values ($1), ($2), ($3), ($4)) t", "Matthew", "Mark", "Luke", "John")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to select all strings: " + err.Error())
|
t.Fatal("Unable to select all strings: " + err.Error())
|
||||||
}
|
}
|
||||||
@ -227,7 +227,7 @@ func TestSelectAllString(t *testing.T) {
|
|||||||
func TestSelectAllInt64(t *testing.T) {
|
func TestSelectAllInt64(t *testing.T) {
|
||||||
conn := getSharedConn()
|
conn := getSharedConn()
|
||||||
|
|
||||||
i, err := conn.SelectAllInt64("select * from (values (1), (2)) t")
|
i, err := conn.SelectAllInt64("select * from (values ($1), ($2)) t", 1, 2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to select all int64: " + err.Error())
|
t.Fatal("Unable to select all int64: " + err.Error())
|
||||||
}
|
}
|
||||||
@ -250,7 +250,7 @@ func TestSelectAllInt64(t *testing.T) {
|
|||||||
func TestSelectAllInt32(t *testing.T) {
|
func TestSelectAllInt32(t *testing.T) {
|
||||||
conn := getSharedConn()
|
conn := getSharedConn()
|
||||||
|
|
||||||
i, err := conn.SelectAllInt32("select * from (values (1), (2)) t")
|
i, err := conn.SelectAllInt32("select * from (values ($1), ($2)) t", 1, 2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to select all int32: " + err.Error())
|
t.Fatal("Unable to select all int32: " + err.Error())
|
||||||
}
|
}
|
||||||
@ -273,7 +273,7 @@ func TestSelectAllInt32(t *testing.T) {
|
|||||||
func TestSelectAllInt16(t *testing.T) {
|
func TestSelectAllInt16(t *testing.T) {
|
||||||
conn := getSharedConn()
|
conn := getSharedConn()
|
||||||
|
|
||||||
i, err := conn.SelectAllInt16("select * from (values (1), (2)) t")
|
i, err := conn.SelectAllInt16("select * from (values ($1), ($2)) t", 1, 2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to select all int16: " + err.Error())
|
t.Fatal("Unable to select all int16: " + err.Error())
|
||||||
}
|
}
|
||||||
@ -296,7 +296,7 @@ func TestSelectAllInt16(t *testing.T) {
|
|||||||
func TestSelectAllFloat64(t *testing.T) {
|
func TestSelectAllFloat64(t *testing.T) {
|
||||||
conn := getSharedConn()
|
conn := getSharedConn()
|
||||||
|
|
||||||
f, err := conn.SelectAllFloat64("select * from (values (1.23), (4.56)) t")
|
f, err := conn.SelectAllFloat64("select * from (values ($1), ($2)) t", 1.23, 4.56)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to select all float64: " + err.Error())
|
t.Fatal("Unable to select all float64: " + err.Error())
|
||||||
}
|
}
|
||||||
@ -309,7 +309,7 @@ func TestSelectAllFloat64(t *testing.T) {
|
|||||||
func TestSelectAllFloat32(t *testing.T) {
|
func TestSelectAllFloat32(t *testing.T) {
|
||||||
conn := getSharedConn()
|
conn := getSharedConn()
|
||||||
|
|
||||||
f, err := conn.SelectAllFloat32("select * from (values (1.23), (4.56)) t")
|
f, err := conn.SelectAllFloat32("select * from (values ($1), ($2)) t", 1.23, 4.56)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to select all float32: " + err.Error())
|
t.Fatal("Unable to select all float32: " + err.Error())
|
||||||
}
|
}
|
||||||
|
55
sanitize.go
Normal file
55
sanitize.go
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
package pgx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
|
||||||
|
|
||||||
|
func (c *conn) QuoteString(input string) (output string) {
|
||||||
|
output = "'" + strings.Replace(input, "'", "''", -1) + "'"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) SanitizeSql(sql string, args ...interface{}) (output string) {
|
||||||
|
replacer := func(match string) (replacement string) {
|
||||||
|
n, _ := strconv.ParseInt(match[1:], 10, 0)
|
||||||
|
switch arg := args[n-1].(type) {
|
||||||
|
case string:
|
||||||
|
return c.QuoteString(arg)
|
||||||
|
case int:
|
||||||
|
return strconv.FormatInt(int64(arg), 10)
|
||||||
|
case int8:
|
||||||
|
return strconv.FormatInt(int64(arg), 10)
|
||||||
|
case int16:
|
||||||
|
return strconv.FormatInt(int64(arg), 10)
|
||||||
|
case int32:
|
||||||
|
return strconv.FormatInt(int64(arg), 10)
|
||||||
|
case int64:
|
||||||
|
return strconv.FormatInt(int64(arg), 10)
|
||||||
|
case uint:
|
||||||
|
return strconv.FormatUint(uint64(arg), 10)
|
||||||
|
case uint8:
|
||||||
|
return strconv.FormatUint(uint64(arg), 10)
|
||||||
|
case uint16:
|
||||||
|
return strconv.FormatUint(uint64(arg), 10)
|
||||||
|
case uint32:
|
||||||
|
return strconv.FormatUint(uint64(arg), 10)
|
||||||
|
case uint64:
|
||||||
|
return strconv.FormatUint(uint64(arg), 10)
|
||||||
|
case float32:
|
||||||
|
return strconv.FormatFloat(float64(arg), 'f', -1, 32)
|
||||||
|
case float64:
|
||||||
|
return strconv.FormatFloat(arg, 'f', -1, 64)
|
||||||
|
default:
|
||||||
|
panic("Unable to sanitize type: " + reflect.TypeOf(arg).String())
|
||||||
|
}
|
||||||
|
panic("Unreachable")
|
||||||
|
}
|
||||||
|
|
||||||
|
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
|
||||||
|
return
|
||||||
|
}
|
25
sanitize_test.go
Normal file
25
sanitize_test.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package pgx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestQuoteString(t *testing.T) {
|
||||||
|
conn := getSharedConn()
|
||||||
|
|
||||||
|
if conn.QuoteString("test") != "'test'" {
|
||||||
|
t.Error("Failed to quote string")
|
||||||
|
}
|
||||||
|
|
||||||
|
if conn.QuoteString("Jack's") != "'Jack''s'" {
|
||||||
|
t.Error("Failed to quote and escape string with embedded quote")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeSql(t *testing.T) {
|
||||||
|
conn := getSharedConn()
|
||||||
|
|
||||||
|
if conn.SanitizeSql("select $1, $2, $3", "Jack's", 42, 1.23) != "select 'Jack''s', 42, 1.23" {
|
||||||
|
t.Error("Failed to sanitize sql")
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user