mirror of
https://github.com/jackc/pgx.git
synced 2025-04-27 21:25:53 +00:00
Use Scan and Encode* instead of ValueTranscoders
This commit is contained in:
parent
43dcd47a92
commit
009cdfa0b1
95
conn.go
95
conn.go
@ -319,11 +319,9 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
|||||||
case rowDescription:
|
case rowDescription:
|
||||||
ps.FieldDescriptions = c.rxRowDescription(r)
|
ps.FieldDescriptions = c.rxRowDescription(r)
|
||||||
for i := range ps.FieldDescriptions {
|
for i := range ps.FieldDescriptions {
|
||||||
oid := ps.FieldDescriptions[i].DataType
|
switch ps.FieldDescriptions[i].DataType {
|
||||||
vt := ValueTranscoders[oid]
|
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, DateOid, TimestampTzOid:
|
||||||
|
ps.FieldDescriptions[i].FormatCode = BinaryFormatCode
|
||||||
if vt != nil {
|
|
||||||
ps.FieldDescriptions[i].FormatCode = vt.DecodeFormat
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case noData:
|
case noData:
|
||||||
@ -342,7 +340,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
|||||||
// Deallocate released a prepared statement
|
// Deallocate released a prepared statement
|
||||||
func (c *Conn) Deallocate(name string) (err error) {
|
func (c *Conn) Deallocate(name string) (err error) {
|
||||||
delete(c.preparedStatements, name)
|
delete(c.preparedStatements, name)
|
||||||
_, err = c.Exec("deallocate " + c.QuoteIdentifier(name))
|
_, err = c.Exec("deallocate " + QuoteIdentifier(name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -601,6 +599,14 @@ func (qr *QueryResult) Scan(dest ...interface{}) (err error) {
|
|||||||
} else {
|
} else {
|
||||||
*d = decodeTimestampTz(qr, fd, size)
|
*d = decodeTimestampTz(qr, fd, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case Scanner:
|
||||||
|
err = d.Scan(qr, fd, size)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return errors.New("Unknown type")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -708,7 +714,7 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {
|
|||||||
|
|
||||||
func (c *Conn) sendSimpleQuery(sql string, arguments ...interface{}) (err error) {
|
func (c *Conn) sendSimpleQuery(sql string, arguments ...interface{}) (err error) {
|
||||||
if len(arguments) > 0 {
|
if len(arguments) > 0 {
|
||||||
sql, err = c.SanitizeSql(sql, arguments...)
|
sql, err = SanitizeSql(sql, arguments...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -734,38 +740,71 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
wbuf.WriteCString(ps.Name)
|
wbuf.WriteCString(ps.Name)
|
||||||
|
|
||||||
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
|
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
|
||||||
for _, oid := range ps.ParameterOids {
|
for i, oid := range ps.ParameterOids {
|
||||||
transcoder := ValueTranscoders[oid]
|
switch oid {
|
||||||
if transcoder == nil {
|
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid:
|
||||||
transcoder = defaultTranscoder
|
wbuf.WriteInt16(BinaryFormatCode)
|
||||||
|
case TextOid, VarcharOid, DateOid, TimestampTzOid:
|
||||||
|
wbuf.WriteInt16(TextFormatCode)
|
||||||
|
default:
|
||||||
|
if _, ok := arguments[i].(BinaryEncoder); ok {
|
||||||
|
wbuf.WriteInt16(BinaryFormatCode)
|
||||||
|
} else {
|
||||||
|
wbuf.WriteInt16(TextFormatCode)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
wbuf.WriteInt16(transcoder.EncodeFormat)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
wbuf.WriteInt16(int16(len(arguments)))
|
wbuf.WriteInt16(int16(len(arguments)))
|
||||||
for i, oid := range ps.ParameterOids {
|
for i, oid := range ps.ParameterOids {
|
||||||
if arguments[i] != nil {
|
if arguments[i] == nil {
|
||||||
transcoder := ValueTranscoders[oid]
|
wbuf.WriteInt32(-1)
|
||||||
if transcoder == nil {
|
continue
|
||||||
transcoder = defaultTranscoder
|
}
|
||||||
|
|
||||||
|
switch oid {
|
||||||
|
case BoolOid:
|
||||||
|
err = encodeBool(wbuf, arguments[i])
|
||||||
|
case ByteaOid:
|
||||||
|
err = encodeBytea(wbuf, arguments[i])
|
||||||
|
case Int2Oid:
|
||||||
|
err = encodeInt2(wbuf, arguments[i])
|
||||||
|
case Int4Oid:
|
||||||
|
err = encodeInt4(wbuf, arguments[i])
|
||||||
|
case Int8Oid:
|
||||||
|
err = encodeInt8(wbuf, arguments[i])
|
||||||
|
case Float4Oid:
|
||||||
|
err = encodeFloat4(wbuf, arguments[i])
|
||||||
|
case Float8Oid:
|
||||||
|
err = encodeFloat8(wbuf, arguments[i])
|
||||||
|
case TextOid, VarcharOid:
|
||||||
|
err = encodeText(wbuf, arguments[i])
|
||||||
|
case DateOid:
|
||||||
|
err = encodeDate(wbuf, arguments[i])
|
||||||
|
case TimestampTzOid:
|
||||||
|
err = encodeTimestampTz(wbuf, arguments[i])
|
||||||
|
default:
|
||||||
|
switch arg := arguments[i].(type) {
|
||||||
|
case BinaryEncoder:
|
||||||
|
err = arg.EncodeBinary(wbuf)
|
||||||
|
case TextEncoder:
|
||||||
|
var s string
|
||||||
|
s, err = arg.EncodeText()
|
||||||
|
wbuf.WriteInt32(int32(len(s)))
|
||||||
|
wbuf.WriteBytes([]byte(s))
|
||||||
|
default:
|
||||||
|
return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder or BinaryEncoder", arg))
|
||||||
}
|
}
|
||||||
err = transcoder.EncodeTo(wbuf, arguments[i])
|
}
|
||||||
if err != nil {
|
|
||||||
return err
|
if err != nil {
|
||||||
}
|
return err
|
||||||
} else {
|
|
||||||
wbuf.WriteInt32(int32(-1))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
wbuf.WriteInt16(int16(len(ps.FieldDescriptions)))
|
wbuf.WriteInt16(int16(len(ps.FieldDescriptions)))
|
||||||
for _, fd := range ps.FieldDescriptions {
|
for _, fd := range ps.FieldDescriptions {
|
||||||
transcoder := ValueTranscoders[fd.DataType]
|
wbuf.WriteInt16(fd.FormatCode)
|
||||||
if transcoder != nil {
|
|
||||||
wbuf.WriteInt16(transcoder.DecodeFormat)
|
|
||||||
} else {
|
|
||||||
wbuf.WriteInt16(0)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute
|
// execute
|
||||||
|
114
conn_test.go
114
conn_test.go
@ -450,6 +450,120 @@ func TestConnQueryReadTooManyValues(t *testing.T) {
|
|||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnQueryUnpreparedScanner(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
qr, err := conn.Query("select null::int8, 1::int8")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn.Query failed: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok := qr.NextRow()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("qr.NextRow terminated early")
|
||||||
|
}
|
||||||
|
|
||||||
|
var n, m pgx.NullInt64
|
||||||
|
err = qr.Scan(&n, &m)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("qr.Scan failed: ", err)
|
||||||
|
}
|
||||||
|
qr.Close()
|
||||||
|
|
||||||
|
if n.Valid {
|
||||||
|
t.Error("Null should not be valid, but it was")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.Valid {
|
||||||
|
t.Error("1 should be valid, but it wasn't")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Int64 != 1 {
|
||||||
|
t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnQueryPreparedScanner(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
mustPrepare(t, conn, "scannerTest", "select null::int8, 1::int8")
|
||||||
|
|
||||||
|
qr, err := conn.Query("scannerTest")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn.Query failed: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok := qr.NextRow()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("qr.NextRow terminated early")
|
||||||
|
}
|
||||||
|
|
||||||
|
var n, m pgx.NullInt64
|
||||||
|
err = qr.Scan(&n, &m)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("qr.Scan failed: ", err)
|
||||||
|
}
|
||||||
|
qr.Close()
|
||||||
|
|
||||||
|
if n.Valid {
|
||||||
|
t.Error("Null should not be valid, but it was")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.Valid {
|
||||||
|
t.Error("1 should be valid, but it wasn't")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Int64 != 1 {
|
||||||
|
t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnQueryUnpreparedEncoder(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
n := pgx.NullInt64{Int64: 1, Valid: true}
|
||||||
|
|
||||||
|
qr, err := conn.Query("select $1::int8", &n)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn.Query failed: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok := qr.NextRow()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("qr.NextRow terminated early")
|
||||||
|
}
|
||||||
|
|
||||||
|
var m pgx.NullInt64
|
||||||
|
err = qr.Scan(&m)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("qr.Scan failed: ", err)
|
||||||
|
}
|
||||||
|
qr.Close()
|
||||||
|
|
||||||
|
if !m.Valid {
|
||||||
|
t.Error("m should be valid, but it wasn't")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Int64 != 1 {
|
||||||
|
t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestPrepare(t *testing.T) {
|
func TestPrepare(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -178,3 +178,26 @@ func (r *MsgReader) ReadString(count int32) string {
|
|||||||
|
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadBytes reads count bytes and returns as []byte
|
||||||
|
func (r *MsgReader) ReadBytes(count int32) []byte {
|
||||||
|
if r.err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
r.msgBytesRemaining -= count
|
||||||
|
if r.msgBytesRemaining < 0 {
|
||||||
|
r.Fatal(errors.New("read past end of message"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
b := make([]byte, int(count))
|
||||||
|
|
||||||
|
_, err := io.ReadFull(r.reader, b)
|
||||||
|
if err != nil {
|
||||||
|
r.Fatal(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
53
sanitize.go
53
sanitize.go
@ -9,18 +9,34 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type SerializationError string
|
||||||
|
|
||||||
|
func (e SerializationError) Error() string {
|
||||||
|
return string(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextEncoder is an interface used to encode values in text format for
|
||||||
|
// transmission to the PostgreSQL server. It is used by unprepared
|
||||||
|
// queries and for prepared queries when the type does not implement
|
||||||
|
// BinaryEncoder
|
||||||
|
type TextEncoder interface {
|
||||||
|
// EncodeText MUST sanitize (and quote, if necessary) the returned string.
|
||||||
|
// It will be interpolated directly into the SQL string.
|
||||||
|
EncodeText() (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
|
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
|
||||||
|
|
||||||
// QuoteString escapes and quotes a string making it safe for interpolation
|
// QuoteString escapes and quotes a string making it safe for interpolation
|
||||||
// into an SQL string.
|
// into an SQL string.
|
||||||
func (c *Conn) QuoteString(input string) (output string) {
|
func QuoteString(input string) (output string) {
|
||||||
output = "'" + strings.Replace(input, "'", "''", -1) + "'"
|
output = "'" + strings.Replace(input, "'", "''", -1) + "'"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// QuoteIdentifier escapes and quotes an identifier making it safe for
|
// QuoteIdentifier escapes and quotes an identifier making it safe for
|
||||||
// interpolation into an SQL string
|
// interpolation into an SQL string
|
||||||
func (c *Conn) QuoteIdentifier(input string) (output string) {
|
func QuoteIdentifier(input string) (output string) {
|
||||||
output = `"` + strings.Replace(input, `"`, `""`, -1) + `"`
|
output = `"` + strings.Replace(input, `"`, `""`, -1) + `"`
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -28,12 +44,21 @@ func (c *Conn) QuoteIdentifier(input string) (output string) {
|
|||||||
// SanitizeSql substitutely args positionaly into sql. Placeholder values are
|
// SanitizeSql substitutely args positionaly into sql. Placeholder values are
|
||||||
// $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as
|
// $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as
|
||||||
// appropriate.
|
// appropriate.
|
||||||
func (c *Conn) SanitizeSql(sql string, args ...interface{}) (output string, err error) {
|
func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
|
||||||
replacer := func(match string) (replacement string) {
|
replacer := func(match string) (replacement string) {
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
n, _ := strconv.ParseInt(match[1:], 10, 0)
|
n, _ := strconv.ParseInt(match[1:], 10, 0)
|
||||||
|
if int(n-1) >= len(args) {
|
||||||
|
err = fmt.Errorf("Cannot interpolate %v, only %d arguments provided", match, len(args))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
switch arg := args[n-1].(type) {
|
switch arg := args[n-1].(type) {
|
||||||
case string:
|
case string:
|
||||||
return c.QuoteString(arg)
|
return QuoteString(arg)
|
||||||
case int:
|
case int:
|
||||||
return strconv.FormatInt(int64(arg), 10)
|
return strconv.FormatInt(int64(arg), 10)
|
||||||
case int8:
|
case int8:
|
||||||
@ -45,7 +70,7 @@ func (c *Conn) SanitizeSql(sql string, args ...interface{}) (output string, err
|
|||||||
case int64:
|
case int64:
|
||||||
return strconv.FormatInt(int64(arg), 10)
|
return strconv.FormatInt(int64(arg), 10)
|
||||||
case time.Time:
|
case time.Time:
|
||||||
return c.QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700"))
|
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700"))
|
||||||
case uint:
|
case uint:
|
||||||
return strconv.FormatUint(uint64(arg), 10)
|
return strconv.FormatUint(uint64(arg), 10)
|
||||||
case uint8:
|
case uint8:
|
||||||
@ -64,22 +89,14 @@ func (c *Conn) SanitizeSql(sql string, args ...interface{}) (output string, err
|
|||||||
return strconv.FormatBool(arg)
|
return strconv.FormatBool(arg)
|
||||||
case []byte:
|
case []byte:
|
||||||
return `E'\\x` + hex.EncodeToString(arg) + `'`
|
return `E'\\x` + hex.EncodeToString(arg) + `'`
|
||||||
case []int16:
|
|
||||||
var s string
|
|
||||||
s, err = int16SliceToArrayString(arg)
|
|
||||||
return c.QuoteString(s)
|
|
||||||
case []int32:
|
|
||||||
var s string
|
|
||||||
s, err = int32SliceToArrayString(arg)
|
|
||||||
return c.QuoteString(s)
|
|
||||||
case []int64:
|
|
||||||
var s string
|
|
||||||
s, err = int64SliceToArrayString(arg)
|
|
||||||
return c.QuoteString(s)
|
|
||||||
case nil:
|
case nil:
|
||||||
return "null"
|
return "null"
|
||||||
|
case TextEncoder:
|
||||||
|
var s string
|
||||||
|
s, err = arg.EncodeText()
|
||||||
|
return s
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("Unable to sanitize type: %T", arg)
|
err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
104
sanitize_test.go
104
sanitize_test.go
@ -1,20 +1,19 @@
|
|||||||
package pgx_test
|
package pgx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/jackc/pgx"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestQuoteString(t *testing.T) {
|
func TestQuoteString(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnect(t, *defaultConnConfig)
|
if pgx.QuoteString("test") != "'test'" {
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
if conn.QuoteString("test") != "'test'" {
|
|
||||||
t.Error("Failed to quote string")
|
t.Error("Failed to quote string")
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.QuoteString("Jack's") != "'Jack''s'" {
|
if pgx.QuoteString("Jack's") != "'Jack''s'" {
|
||||||
t.Error("Failed to quote and escape string with embedded quote")
|
t.Error("Failed to quote and escape string with embedded quote")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -22,70 +21,47 @@ func TestQuoteString(t *testing.T) {
|
|||||||
func TestSanitizeSql(t *testing.T) {
|
func TestSanitizeSql(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnect(t, *defaultConnConfig)
|
successTests := []struct {
|
||||||
defer closeConn(t, conn)
|
sql string
|
||||||
|
args []interface{}
|
||||||
if san, err := conn.SanitizeSql("select $1", nil); err != nil || san != "select null" {
|
output string
|
||||||
t.Errorf("Failed to translate nil to null: %v - %v", san, err)
|
}{
|
||||||
|
{"select $1", []interface{}{nil}, "select null"},
|
||||||
|
{"select $1", []interface{}{"Jack's"}, "select 'Jack''s'"},
|
||||||
|
{"select $1", []interface{}{42}, "select 42"},
|
||||||
|
{"select $1", []interface{}{1.23}, "select 1.23"},
|
||||||
|
{"select $1", []interface{}{true}, "select true"},
|
||||||
|
{"select $1, $2, $3", []interface{}{"Jack's", 42, 1.23}, "select 'Jack''s', 42, 1.23"},
|
||||||
|
{"select $1", []interface{}{[]byte{0, 15, 255, 17}}, `select E'\\x000fff11'`},
|
||||||
|
{"select $1", []interface{}{&pgx.NullInt64{Int64: 1, Valid: true}}, "select 1"},
|
||||||
}
|
}
|
||||||
|
|
||||||
if san, err := conn.SanitizeSql("select $1", "Jack's"); err != nil || san != "select 'Jack''s'" {
|
for i, tt := range successTests {
|
||||||
t.Errorf("Failed to sanitize string: %v - %v", san, err)
|
san, err := pgx.SanitizeSql(tt.sql, tt.args...)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v (sql -> %v, args -> %v)", i, err, tt.sql, tt.args)
|
||||||
|
}
|
||||||
|
if san != tt.output {
|
||||||
|
t.Errorf("%d. Expected %v, got %v (sql -> %v, args -> %v)", i, tt.output, san, tt.sql, tt.args)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if san, err := conn.SanitizeSql("select $1", 42); err != nil || san != "select 42" {
|
errorTests := []struct {
|
||||||
t.Errorf("Failed to pass through integer: %v - %v", san, err)
|
sql string
|
||||||
|
args []interface{}
|
||||||
|
err string
|
||||||
|
}{
|
||||||
|
{"select $1", []interface{}{t}, "is not a core type and it does not implement TextEncoder"},
|
||||||
|
{"select $1, $2", []interface{}{}, "Cannot interpolate $1, only 0 arguments provided"},
|
||||||
}
|
}
|
||||||
|
|
||||||
if san, err := conn.SanitizeSql("select $1", 1.23); err != nil || san != "select 1.23" {
|
for i, tt := range errorTests {
|
||||||
t.Errorf("Failed to pass through float: %v - %v", san, err)
|
_, err := pgx.SanitizeSql(tt.sql, tt.args...)
|
||||||
}
|
if err == nil {
|
||||||
|
t.Errorf("%d. Unexpected success (sql -> %v, args -> %v)", i, tt.sql, tt.args, err)
|
||||||
if san, err := conn.SanitizeSql("select $1", true); err != nil || san != "select true" {
|
}
|
||||||
t.Errorf("Failed to pass through bool: %v - %v", san, err)
|
if !strings.Contains(err.Error(), tt.err) {
|
||||||
}
|
t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, args -> %v)", i, tt.err, err, tt.sql, tt.args)
|
||||||
|
}
|
||||||
if san, err := conn.SanitizeSql("select $1, $2, $3", "Jack's", 42, 1.23); err != nil || san != "select 'Jack''s', 42, 1.23" {
|
|
||||||
t.Errorf("Failed to sanitize multiple params: %v - %v", san, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bytea := make([]byte, 4)
|
|
||||||
bytea[0] = 0 // 0x00
|
|
||||||
bytea[1] = 15 // 0x0F
|
|
||||||
bytea[2] = 255 // 0xFF
|
|
||||||
bytea[3] = 17 // 0x11
|
|
||||||
|
|
||||||
if san, err := conn.SanitizeSql("select $1", bytea); err != nil || san != `select E'\\x000fff11'` {
|
|
||||||
t.Errorf("Failed to sanitize []byte: %v - %v", san, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
int2a := make([]int16, 4)
|
|
||||||
int2a[0] = 42
|
|
||||||
int2a[1] = 0
|
|
||||||
int2a[2] = -1
|
|
||||||
int2a[3] = 32123
|
|
||||||
|
|
||||||
if san, err := conn.SanitizeSql("select $1::int2[]", int2a); err != nil || san != `select '{42,0,-1,32123}'::int2[]` {
|
|
||||||
t.Errorf("Failed to sanitize []int16: %v - %v", san, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
int4a := make([]int32, 4)
|
|
||||||
int4a[0] = 42
|
|
||||||
int4a[1] = 0
|
|
||||||
int4a[2] = -1
|
|
||||||
int4a[3] = 32123
|
|
||||||
|
|
||||||
if san, err := conn.SanitizeSql("select $1::int4[]", int4a); err != nil || san != `select '{42,0,-1,32123}'::int4[]` {
|
|
||||||
t.Errorf("Failed to sanitize []int32: %v - %v", san, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
int8a := make([]int64, 4)
|
|
||||||
int8a[0] = 42
|
|
||||||
int8a[1] = 0
|
|
||||||
int8a[2] = -1
|
|
||||||
int8a[3] = 32123
|
|
||||||
|
|
||||||
if san, err := conn.SanitizeSql("select $1::int8[]", int8a); err != nil || san != `select '{42,0,-1,32123}'::int8[]` {
|
|
||||||
t.Errorf("Failed to sanitize []int64: %v - %v", san, err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package pgx
|
package pgx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
@ -20,9 +19,6 @@ const (
|
|||||||
TextOid = 25
|
TextOid = 25
|
||||||
Float4Oid = 700
|
Float4Oid = 700
|
||||||
Float8Oid = 701
|
Float8Oid = 701
|
||||||
Int2ArrayOid = 1005
|
|
||||||
Int4ArrayOid = 1007
|
|
||||||
Int8ArrayOid = 1016
|
|
||||||
VarcharOid = 1043
|
VarcharOid = 1043
|
||||||
DateOid = 1082
|
DateOid = 1082
|
||||||
TimestampTzOid = 1184
|
TimestampTzOid = 1184
|
||||||
@ -33,134 +29,39 @@ const (
|
|||||||
BinaryFormatCode = 1
|
BinaryFormatCode = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValueTranscoder stores all the data necessary to encode and decode values from
|
type Scanner interface {
|
||||||
// a PostgreSQL server
|
Scan(qr *QueryResult, fd *FieldDescription, size int32) error
|
||||||
type ValueTranscoder struct {
|
|
||||||
// Decode decodes values returned from the server
|
|
||||||
Decode func(qr *QueryResult, fd *FieldDescription, size int32) interface{}
|
|
||||||
// DecodeFormat is the preferred response format.
|
|
||||||
// Allowed values: TextFormatCode, BinaryFormatCode
|
|
||||||
DecodeFormat int16
|
|
||||||
// EncodeTo encodes values to send to the server
|
|
||||||
EncodeTo func(*WriteBuf, interface{}) error
|
|
||||||
// EncodeFormat is the format values are encoded for transmission.
|
|
||||||
// Allowed values: TextFormatCode, BinaryFormatCode
|
|
||||||
EncodeFormat int16
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValueTranscoders is used to transcode values being sent to and received from
|
// BinaryEncoder is an interface used to encode values in binary format for
|
||||||
// the PostgreSQL server. Additional types can be transcoded by adding a
|
// transmission to the PostgreSQL server. It is used by prepared queries.
|
||||||
// *ValueTranscoder for the appropriate Oid to the map.
|
type BinaryEncoder interface {
|
||||||
var ValueTranscoders map[Oid]*ValueTranscoder
|
// EncodeText MUST sanitize (and quote, if necessary) the returned string.
|
||||||
|
// It will be interpolated directly into the SQL string.
|
||||||
|
EncodeBinary(w *WriteBuf) error
|
||||||
|
}
|
||||||
|
|
||||||
var defaultTranscoder *ValueTranscoder
|
type NullInt64 struct {
|
||||||
|
Int64 int64
|
||||||
|
Valid bool // Valid is true if Int64 is not NULL
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func (n *NullInt64) Scan(qr *QueryResult, fd *FieldDescription, size int32) error {
|
||||||
ValueTranscoders = make(map[Oid]*ValueTranscoder)
|
if size == -1 {
|
||||||
|
n.Int64, n.Valid = 0, false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
n.Valid = true
|
||||||
|
n.Int64 = decodeInt8(qr, fd, size)
|
||||||
|
return qr.Err()
|
||||||
|
}
|
||||||
|
|
||||||
// bool
|
func (n *NullInt64) EncodeText() (string, error) {
|
||||||
ValueTranscoders[BoolOid] = &ValueTranscoder{
|
if n.Valid {
|
||||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeBool(qr, fd, size) },
|
return strconv.FormatInt(int64(n.Int64), 10), nil
|
||||||
DecodeFormat: BinaryFormatCode,
|
} else {
|
||||||
EncodeTo: encodeBool,
|
return "null", nil
|
||||||
EncodeFormat: BinaryFormatCode}
|
}
|
||||||
|
|
||||||
// bytea
|
|
||||||
ValueTranscoders[ByteaOid] = &ValueTranscoder{
|
|
||||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeBytea(qr, fd, size) },
|
|
||||||
DecodeFormat: TextFormatCode,
|
|
||||||
EncodeTo: encodeBytea,
|
|
||||||
EncodeFormat: BinaryFormatCode}
|
|
||||||
|
|
||||||
// int8
|
|
||||||
ValueTranscoders[Int8Oid] = &ValueTranscoder{
|
|
||||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeInt8(qr, fd, size) },
|
|
||||||
DecodeFormat: BinaryFormatCode,
|
|
||||||
EncodeTo: encodeInt8,
|
|
||||||
EncodeFormat: BinaryFormatCode}
|
|
||||||
|
|
||||||
// int2
|
|
||||||
ValueTranscoders[Int2Oid] = &ValueTranscoder{
|
|
||||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeInt2(qr, fd, size) },
|
|
||||||
DecodeFormat: BinaryFormatCode,
|
|
||||||
EncodeTo: encodeInt2,
|
|
||||||
EncodeFormat: BinaryFormatCode}
|
|
||||||
|
|
||||||
// int4
|
|
||||||
ValueTranscoders[Int4Oid] = &ValueTranscoder{
|
|
||||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeInt4(qr, fd, size) },
|
|
||||||
DecodeFormat: BinaryFormatCode,
|
|
||||||
EncodeTo: encodeInt4,
|
|
||||||
EncodeFormat: BinaryFormatCode}
|
|
||||||
|
|
||||||
// text
|
|
||||||
ValueTranscoders[TextOid] = &ValueTranscoder{
|
|
||||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeText(qr, fd, size) },
|
|
||||||
DecodeFormat: TextFormatCode,
|
|
||||||
EncodeTo: encodeText,
|
|
||||||
EncodeFormat: TextFormatCode}
|
|
||||||
|
|
||||||
// float4
|
|
||||||
ValueTranscoders[Float4Oid] = &ValueTranscoder{
|
|
||||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeFloat4(qr, fd, size) },
|
|
||||||
EncodeTo: encodeFloat4,
|
|
||||||
EncodeFormat: BinaryFormatCode}
|
|
||||||
|
|
||||||
// float8
|
|
||||||
ValueTranscoders[Float8Oid] = &ValueTranscoder{
|
|
||||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeFloat8(qr, fd, size) },
|
|
||||||
DecodeFormat: BinaryFormatCode,
|
|
||||||
EncodeTo: encodeFloat8,
|
|
||||||
EncodeFormat: BinaryFormatCode}
|
|
||||||
|
|
||||||
// int2[]
|
|
||||||
ValueTranscoders[Int2ArrayOid] = &ValueTranscoder{
|
|
||||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} {
|
|
||||||
return decodeInt2Array(qr, fd, size)
|
|
||||||
},
|
|
||||||
DecodeFormat: TextFormatCode,
|
|
||||||
EncodeTo: encodeInt2Array,
|
|
||||||
EncodeFormat: TextFormatCode}
|
|
||||||
|
|
||||||
// int4[]
|
|
||||||
ValueTranscoders[Int4ArrayOid] = &ValueTranscoder{
|
|
||||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} {
|
|
||||||
return decodeInt4Array(qr, fd, size)
|
|
||||||
},
|
|
||||||
DecodeFormat: TextFormatCode,
|
|
||||||
EncodeTo: encodeInt4Array,
|
|
||||||
EncodeFormat: TextFormatCode}
|
|
||||||
|
|
||||||
// int8[]
|
|
||||||
ValueTranscoders[Int8ArrayOid] = &ValueTranscoder{
|
|
||||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} {
|
|
||||||
return decodeInt8Array(qr, fd, size)
|
|
||||||
},
|
|
||||||
DecodeFormat: TextFormatCode,
|
|
||||||
EncodeTo: encodeInt8Array,
|
|
||||||
EncodeFormat: TextFormatCode}
|
|
||||||
|
|
||||||
// varchar -- same as text
|
|
||||||
ValueTranscoders[VarcharOid] = ValueTranscoders[Oid(25)]
|
|
||||||
|
|
||||||
// date
|
|
||||||
ValueTranscoders[DateOid] = &ValueTranscoder{
|
|
||||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeDate(qr, fd, size) },
|
|
||||||
DecodeFormat: BinaryFormatCode,
|
|
||||||
EncodeTo: encodeDate,
|
|
||||||
EncodeFormat: TextFormatCode}
|
|
||||||
|
|
||||||
// timestamptz
|
|
||||||
ValueTranscoders[TimestampTzOid] = &ValueTranscoder{
|
|
||||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} {
|
|
||||||
return decodeTimestampTz(qr, fd, size)
|
|
||||||
},
|
|
||||||
DecodeFormat: BinaryFormatCode,
|
|
||||||
EncodeTo: encodeTimestampTz,
|
|
||||||
EncodeFormat: TextFormatCode}
|
|
||||||
|
|
||||||
// use text transcoder for anything we don't understand
|
|
||||||
defaultTranscoder = ValueTranscoders[TextOid]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var arrayEl *regexp.Regexp = regexp.MustCompile(`[{,](?:"((?:[^"\\]|\\.)*)"|(NULL)|([^,}]+))`)
|
var arrayEl *regexp.Regexp = regexp.MustCompile(`[{,](?:"((?:[^"\\]|\\.)*)"|(NULL)|([^,}]+))`)
|
||||||
@ -645,224 +546,3 @@ func encodeTimestampTz(w *WriteBuf, value interface{}) error {
|
|||||||
s := t.Format("2006-01-02 15:04:05.999999 -0700")
|
s := t.Format("2006-01-02 15:04:05.999999 -0700")
|
||||||
return encodeText(w, s)
|
return encodeText(w, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeInt2Array(qr *QueryResult, fd *FieldDescription, size int32) []int16 {
|
|
||||||
if fd.DataType != Int2ArrayOid {
|
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read int2[] but received: %v", fd.DataType)))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
switch fd.FormatCode {
|
|
||||||
case TextFormatCode:
|
|
||||||
s := qr.mr.ReadString(size)
|
|
||||||
|
|
||||||
elements := SplitArrayText(s)
|
|
||||||
|
|
||||||
numbers := make([]int16, 0, len(elements))
|
|
||||||
|
|
||||||
for _, e := range elements {
|
|
||||||
n, err := strconv.ParseInt(e, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int2[]: %v", s)))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
numbers = append(numbers, int16(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
return numbers
|
|
||||||
default:
|
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func int16SliceToArrayString(nums []int16) (string, error) {
|
|
||||||
w := &bytes.Buffer{}
|
|
||||||
_, err := w.WriteString("{")
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, n := range nums {
|
|
||||||
if i > 0 {
|
|
||||||
_, err = w.WriteString(",")
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = w.WriteString(strconv.FormatInt(int64(n), 10))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = w.WriteString("}")
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeInt2Array(w *WriteBuf, value interface{}) error {
|
|
||||||
v, ok := value.([]int16)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("Expected []int16, received %T", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := int16SliceToArrayString(v)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Failed to encode []int16: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return encodeText(w, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeInt4Array(qr *QueryResult, fd *FieldDescription, size int32) []int32 {
|
|
||||||
if fd.DataType != Int4ArrayOid {
|
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read int4[] but received: %v", fd.DataType)))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
switch fd.FormatCode {
|
|
||||||
case TextFormatCode:
|
|
||||||
s := qr.mr.ReadString(size)
|
|
||||||
|
|
||||||
elements := SplitArrayText(s)
|
|
||||||
|
|
||||||
numbers := make([]int32, 0, len(elements))
|
|
||||||
|
|
||||||
for _, e := range elements {
|
|
||||||
n, err := strconv.ParseInt(e, 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int4[]: %v", s)))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
numbers = append(numbers, int32(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
return numbers
|
|
||||||
default:
|
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func int32SliceToArrayString(nums []int32) (string, error) {
|
|
||||||
w := &bytes.Buffer{}
|
|
||||||
|
|
||||||
_, err := w.WriteString("{")
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, n := range nums {
|
|
||||||
if i > 0 {
|
|
||||||
_, err = w.WriteString(",")
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = w.WriteString(strconv.FormatInt(int64(n), 10))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = w.WriteString("}")
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeInt4Array(w *WriteBuf, value interface{}) error {
|
|
||||||
v, ok := value.([]int32)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("Expected []int32, received %T", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := int32SliceToArrayString(v)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Failed to encode []int32: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return encodeText(w, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeInt8Array(qr *QueryResult, fd *FieldDescription, size int32) []int64 {
|
|
||||||
if fd.DataType != Int8ArrayOid {
|
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read int8[] but received: %v", fd.DataType)))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
switch fd.FormatCode {
|
|
||||||
case TextFormatCode:
|
|
||||||
s := qr.mr.ReadString(size)
|
|
||||||
|
|
||||||
elements := SplitArrayText(s)
|
|
||||||
|
|
||||||
numbers := make([]int64, 0, len(elements))
|
|
||||||
|
|
||||||
for _, e := range elements {
|
|
||||||
n, err := strconv.ParseInt(e, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int8[]: %v", s)))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
numbers = append(numbers, int64(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
return numbers
|
|
||||||
default:
|
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func int64SliceToArrayString(nums []int64) (string, error) {
|
|
||||||
w := &bytes.Buffer{}
|
|
||||||
|
|
||||||
_, err := w.WriteString("{")
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, n := range nums {
|
|
||||||
if i > 0 {
|
|
||||||
_, err = w.WriteString(",")
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = w.WriteString(strconv.FormatInt(int64(n), 10))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = w.WriteString("}")
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeInt8Array(w *WriteBuf, value interface{}) error {
|
|
||||||
v, ok := value.([]int64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("Expected []int64, received %T", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := int64SliceToArrayString(v)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Failed to encode []int64: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return encodeText(w, s)
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user