mirror of https://github.com/jackc/pgx.git
Use Scan and Encode* instead of ValueTranscoders
parent
43dcd47a92
commit
009cdfa0b1
95
conn.go
95
conn.go
|
@ -319,11 +319,9 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
|||
case rowDescription:
|
||||
ps.FieldDescriptions = c.rxRowDescription(r)
|
||||
for i := range ps.FieldDescriptions {
|
||||
oid := ps.FieldDescriptions[i].DataType
|
||||
vt := ValueTranscoders[oid]
|
||||
|
||||
if vt != nil {
|
||||
ps.FieldDescriptions[i].FormatCode = vt.DecodeFormat
|
||||
switch ps.FieldDescriptions[i].DataType {
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, DateOid, TimestampTzOid:
|
||||
ps.FieldDescriptions[i].FormatCode = BinaryFormatCode
|
||||
}
|
||||
}
|
||||
case noData:
|
||||
|
@ -342,7 +340,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
|||
// Deallocate released a prepared statement
|
||||
func (c *Conn) Deallocate(name string) (err error) {
|
||||
delete(c.preparedStatements, name)
|
||||
_, err = c.Exec("deallocate " + c.QuoteIdentifier(name))
|
||||
_, err = c.Exec("deallocate " + QuoteIdentifier(name))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -601,6 +599,14 @@ func (qr *QueryResult) Scan(dest ...interface{}) (err error) {
|
|||
} else {
|
||||
*d = decodeTimestampTz(qr, fd, size)
|
||||
}
|
||||
|
||||
case Scanner:
|
||||
err = d.Scan(qr, fd, size)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return errors.New("Unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -708,7 +714,7 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {
|
|||
|
||||
func (c *Conn) sendSimpleQuery(sql string, arguments ...interface{}) (err error) {
|
||||
if len(arguments) > 0 {
|
||||
sql, err = c.SanitizeSql(sql, arguments...)
|
||||
sql, err = SanitizeSql(sql, arguments...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -734,38 +740,71 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
wbuf.WriteCString(ps.Name)
|
||||
|
||||
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
|
||||
for _, oid := range ps.ParameterOids {
|
||||
transcoder := ValueTranscoders[oid]
|
||||
if transcoder == nil {
|
||||
transcoder = defaultTranscoder
|
||||
for i, oid := range ps.ParameterOids {
|
||||
switch oid {
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid:
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
case TextOid, VarcharOid, DateOid, TimestampTzOid:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
default:
|
||||
if _, ok := arguments[i].(BinaryEncoder); ok {
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
} else {
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
}
|
||||
}
|
||||
wbuf.WriteInt16(transcoder.EncodeFormat)
|
||||
}
|
||||
|
||||
wbuf.WriteInt16(int16(len(arguments)))
|
||||
for i, oid := range ps.ParameterOids {
|
||||
if arguments[i] != nil {
|
||||
transcoder := ValueTranscoders[oid]
|
||||
if transcoder == nil {
|
||||
transcoder = defaultTranscoder
|
||||
if arguments[i] == nil {
|
||||
wbuf.WriteInt32(-1)
|
||||
continue
|
||||
}
|
||||
|
||||
switch oid {
|
||||
case BoolOid:
|
||||
err = encodeBool(wbuf, arguments[i])
|
||||
case ByteaOid:
|
||||
err = encodeBytea(wbuf, arguments[i])
|
||||
case Int2Oid:
|
||||
err = encodeInt2(wbuf, arguments[i])
|
||||
case Int4Oid:
|
||||
err = encodeInt4(wbuf, arguments[i])
|
||||
case Int8Oid:
|
||||
err = encodeInt8(wbuf, arguments[i])
|
||||
case Float4Oid:
|
||||
err = encodeFloat4(wbuf, arguments[i])
|
||||
case Float8Oid:
|
||||
err = encodeFloat8(wbuf, arguments[i])
|
||||
case TextOid, VarcharOid:
|
||||
err = encodeText(wbuf, arguments[i])
|
||||
case DateOid:
|
||||
err = encodeDate(wbuf, arguments[i])
|
||||
case TimestampTzOid:
|
||||
err = encodeTimestampTz(wbuf, arguments[i])
|
||||
default:
|
||||
switch arg := arguments[i].(type) {
|
||||
case BinaryEncoder:
|
||||
err = arg.EncodeBinary(wbuf)
|
||||
case TextEncoder:
|
||||
var s string
|
||||
s, err = arg.EncodeText()
|
||||
wbuf.WriteInt32(int32(len(s)))
|
||||
wbuf.WriteBytes([]byte(s))
|
||||
default:
|
||||
return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder or BinaryEncoder", arg))
|
||||
}
|
||||
err = transcoder.EncodeTo(wbuf, arguments[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
wbuf.WriteInt32(int32(-1))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
wbuf.WriteInt16(int16(len(ps.FieldDescriptions)))
|
||||
for _, fd := range ps.FieldDescriptions {
|
||||
transcoder := ValueTranscoders[fd.DataType]
|
||||
if transcoder != nil {
|
||||
wbuf.WriteInt16(transcoder.DecodeFormat)
|
||||
} else {
|
||||
wbuf.WriteInt16(0)
|
||||
}
|
||||
wbuf.WriteInt16(fd.FormatCode)
|
||||
}
|
||||
|
||||
// execute
|
||||
|
|
114
conn_test.go
114
conn_test.go
|
@ -450,6 +450,120 @@ func TestConnQueryReadTooManyValues(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnQueryUnpreparedScanner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
qr, err := conn.Query("select null::int8, 1::int8")
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
|
||||
ok := qr.NextRow()
|
||||
if !ok {
|
||||
t.Fatal("qr.NextRow terminated early")
|
||||
}
|
||||
|
||||
var n, m pgx.NullInt64
|
||||
err = qr.Scan(&n, &m)
|
||||
if err != nil {
|
||||
t.Fatalf("qr.Scan failed: ", err)
|
||||
}
|
||||
qr.Close()
|
||||
|
||||
if n.Valid {
|
||||
t.Error("Null should not be valid, but it was")
|
||||
}
|
||||
|
||||
if !m.Valid {
|
||||
t.Error("1 should be valid, but it wasn't")
|
||||
}
|
||||
|
||||
if m.Int64 != 1 {
|
||||
t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnQueryPreparedScanner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustPrepare(t, conn, "scannerTest", "select null::int8, 1::int8")
|
||||
|
||||
qr, err := conn.Query("scannerTest")
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
|
||||
ok := qr.NextRow()
|
||||
if !ok {
|
||||
t.Fatal("qr.NextRow terminated early")
|
||||
}
|
||||
|
||||
var n, m pgx.NullInt64
|
||||
err = qr.Scan(&n, &m)
|
||||
if err != nil {
|
||||
t.Fatalf("qr.Scan failed: ", err)
|
||||
}
|
||||
qr.Close()
|
||||
|
||||
if n.Valid {
|
||||
t.Error("Null should not be valid, but it was")
|
||||
}
|
||||
|
||||
if !m.Valid {
|
||||
t.Error("1 should be valid, but it wasn't")
|
||||
}
|
||||
|
||||
if m.Int64 != 1 {
|
||||
t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnQueryUnpreparedEncoder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
n := pgx.NullInt64{Int64: 1, Valid: true}
|
||||
|
||||
qr, err := conn.Query("select $1::int8", &n)
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
|
||||
ok := qr.NextRow()
|
||||
if !ok {
|
||||
t.Fatal("qr.NextRow terminated early")
|
||||
}
|
||||
|
||||
var m pgx.NullInt64
|
||||
err = qr.Scan(&m)
|
||||
if err != nil {
|
||||
t.Fatalf("qr.Scan failed: ", err)
|
||||
}
|
||||
qr.Close()
|
||||
|
||||
if !m.Valid {
|
||||
t.Error("m should be valid, but it wasn't")
|
||||
}
|
||||
|
||||
if m.Int64 != 1 {
|
||||
t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestPrepare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -178,3 +178,26 @@ func (r *MsgReader) ReadString(count int32) string {
|
|||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// ReadBytes reads count bytes and returns as []byte
|
||||
func (r *MsgReader) ReadBytes(count int32) []byte {
|
||||
if r.err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= count
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
return nil
|
||||
}
|
||||
|
||||
b := make([]byte, int(count))
|
||||
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
|
53
sanitize.go
53
sanitize.go
|
@ -9,18 +9,34 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
type SerializationError string
|
||||
|
||||
func (e SerializationError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// TextEncoder is an interface used to encode values in text format for
|
||||
// transmission to the PostgreSQL server. It is used by unprepared
|
||||
// queries and for prepared queries when the type does not implement
|
||||
// BinaryEncoder
|
||||
type TextEncoder interface {
|
||||
// EncodeText MUST sanitize (and quote, if necessary) the returned string.
|
||||
// It will be interpolated directly into the SQL string.
|
||||
EncodeText() (string, error)
|
||||
}
|
||||
|
||||
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
|
||||
|
||||
// QuoteString escapes and quotes a string making it safe for interpolation
|
||||
// into an SQL string.
|
||||
func (c *Conn) QuoteString(input string) (output string) {
|
||||
func QuoteString(input string) (output string) {
|
||||
output = "'" + strings.Replace(input, "'", "''", -1) + "'"
|
||||
return
|
||||
}
|
||||
|
||||
// QuoteIdentifier escapes and quotes an identifier making it safe for
|
||||
// interpolation into an SQL string
|
||||
func (c *Conn) QuoteIdentifier(input string) (output string) {
|
||||
func QuoteIdentifier(input string) (output string) {
|
||||
output = `"` + strings.Replace(input, `"`, `""`, -1) + `"`
|
||||
return
|
||||
}
|
||||
|
@ -28,12 +44,21 @@ func (c *Conn) QuoteIdentifier(input string) (output string) {
|
|||
// SanitizeSql substitutely args positionaly into sql. Placeholder values are
|
||||
// $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as
|
||||
// appropriate.
|
||||
func (c *Conn) SanitizeSql(sql string, args ...interface{}) (output string, err error) {
|
||||
func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
|
||||
replacer := func(match string) (replacement string) {
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
n, _ := strconv.ParseInt(match[1:], 10, 0)
|
||||
if int(n-1) >= len(args) {
|
||||
err = fmt.Errorf("Cannot interpolate %v, only %d arguments provided", match, len(args))
|
||||
return
|
||||
}
|
||||
|
||||
switch arg := args[n-1].(type) {
|
||||
case string:
|
||||
return c.QuoteString(arg)
|
||||
return QuoteString(arg)
|
||||
case int:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int8:
|
||||
|
@ -45,7 +70,7 @@ func (c *Conn) SanitizeSql(sql string, args ...interface{}) (output string, err
|
|||
case int64:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case time.Time:
|
||||
return c.QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700"))
|
||||
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700"))
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint8:
|
||||
|
@ -64,22 +89,14 @@ func (c *Conn) SanitizeSql(sql string, args ...interface{}) (output string, err
|
|||
return strconv.FormatBool(arg)
|
||||
case []byte:
|
||||
return `E'\\x` + hex.EncodeToString(arg) + `'`
|
||||
case []int16:
|
||||
var s string
|
||||
s, err = int16SliceToArrayString(arg)
|
||||
return c.QuoteString(s)
|
||||
case []int32:
|
||||
var s string
|
||||
s, err = int32SliceToArrayString(arg)
|
||||
return c.QuoteString(s)
|
||||
case []int64:
|
||||
var s string
|
||||
s, err = int64SliceToArrayString(arg)
|
||||
return c.QuoteString(s)
|
||||
case nil:
|
||||
return "null"
|
||||
case TextEncoder:
|
||||
var s string
|
||||
s, err = arg.EncodeText()
|
||||
return s
|
||||
default:
|
||||
err = fmt.Errorf("Unable to sanitize type: %T", arg)
|
||||
err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
|
104
sanitize_test.go
104
sanitize_test.go
|
@ -1,20 +1,19 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestQuoteString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
if conn.QuoteString("test") != "'test'" {
|
||||
if pgx.QuoteString("test") != "'test'" {
|
||||
t.Error("Failed to quote string")
|
||||
}
|
||||
|
||||
if conn.QuoteString("Jack's") != "'Jack''s'" {
|
||||
if pgx.QuoteString("Jack's") != "'Jack''s'" {
|
||||
t.Error("Failed to quote and escape string with embedded quote")
|
||||
}
|
||||
}
|
||||
|
@ -22,70 +21,47 @@ func TestQuoteString(t *testing.T) {
|
|||
func TestSanitizeSql(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
if san, err := conn.SanitizeSql("select $1", nil); err != nil || san != "select null" {
|
||||
t.Errorf("Failed to translate nil to null: %v - %v", san, err)
|
||||
successTests := []struct {
|
||||
sql string
|
||||
args []interface{}
|
||||
output string
|
||||
}{
|
||||
{"select $1", []interface{}{nil}, "select null"},
|
||||
{"select $1", []interface{}{"Jack's"}, "select 'Jack''s'"},
|
||||
{"select $1", []interface{}{42}, "select 42"},
|
||||
{"select $1", []interface{}{1.23}, "select 1.23"},
|
||||
{"select $1", []interface{}{true}, "select true"},
|
||||
{"select $1, $2, $3", []interface{}{"Jack's", 42, 1.23}, "select 'Jack''s', 42, 1.23"},
|
||||
{"select $1", []interface{}{[]byte{0, 15, 255, 17}}, `select E'\\x000fff11'`},
|
||||
{"select $1", []interface{}{&pgx.NullInt64{Int64: 1, Valid: true}}, "select 1"},
|
||||
}
|
||||
|
||||
if san, err := conn.SanitizeSql("select $1", "Jack's"); err != nil || san != "select 'Jack''s'" {
|
||||
t.Errorf("Failed to sanitize string: %v - %v", san, err)
|
||||
for i, tt := range successTests {
|
||||
san, err := pgx.SanitizeSql(tt.sql, tt.args...)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v (sql -> %v, args -> %v)", i, err, tt.sql, tt.args)
|
||||
}
|
||||
if san != tt.output {
|
||||
t.Errorf("%d. Expected %v, got %v (sql -> %v, args -> %v)", i, tt.output, san, tt.sql, tt.args)
|
||||
}
|
||||
}
|
||||
|
||||
if san, err := conn.SanitizeSql("select $1", 42); err != nil || san != "select 42" {
|
||||
t.Errorf("Failed to pass through integer: %v - %v", san, err)
|
||||
errorTests := []struct {
|
||||
sql string
|
||||
args []interface{}
|
||||
err string
|
||||
}{
|
||||
{"select $1", []interface{}{t}, "is not a core type and it does not implement TextEncoder"},
|
||||
{"select $1, $2", []interface{}{}, "Cannot interpolate $1, only 0 arguments provided"},
|
||||
}
|
||||
|
||||
if san, err := conn.SanitizeSql("select $1", 1.23); err != nil || san != "select 1.23" {
|
||||
t.Errorf("Failed to pass through float: %v - %v", san, err)
|
||||
}
|
||||
|
||||
if san, err := conn.SanitizeSql("select $1", true); err != nil || san != "select true" {
|
||||
t.Errorf("Failed to pass through bool: %v - %v", san, err)
|
||||
}
|
||||
|
||||
if san, err := conn.SanitizeSql("select $1, $2, $3", "Jack's", 42, 1.23); err != nil || san != "select 'Jack''s', 42, 1.23" {
|
||||
t.Errorf("Failed to sanitize multiple params: %v - %v", san, err)
|
||||
}
|
||||
|
||||
bytea := make([]byte, 4)
|
||||
bytea[0] = 0 // 0x00
|
||||
bytea[1] = 15 // 0x0F
|
||||
bytea[2] = 255 // 0xFF
|
||||
bytea[3] = 17 // 0x11
|
||||
|
||||
if san, err := conn.SanitizeSql("select $1", bytea); err != nil || san != `select E'\\x000fff11'` {
|
||||
t.Errorf("Failed to sanitize []byte: %v - %v", san, err)
|
||||
}
|
||||
|
||||
int2a := make([]int16, 4)
|
||||
int2a[0] = 42
|
||||
int2a[1] = 0
|
||||
int2a[2] = -1
|
||||
int2a[3] = 32123
|
||||
|
||||
if san, err := conn.SanitizeSql("select $1::int2[]", int2a); err != nil || san != `select '{42,0,-1,32123}'::int2[]` {
|
||||
t.Errorf("Failed to sanitize []int16: %v - %v", san, err)
|
||||
}
|
||||
|
||||
int4a := make([]int32, 4)
|
||||
int4a[0] = 42
|
||||
int4a[1] = 0
|
||||
int4a[2] = -1
|
||||
int4a[3] = 32123
|
||||
|
||||
if san, err := conn.SanitizeSql("select $1::int4[]", int4a); err != nil || san != `select '{42,0,-1,32123}'::int4[]` {
|
||||
t.Errorf("Failed to sanitize []int32: %v - %v", san, err)
|
||||
}
|
||||
|
||||
int8a := make([]int64, 4)
|
||||
int8a[0] = 42
|
||||
int8a[1] = 0
|
||||
int8a[2] = -1
|
||||
int8a[3] = 32123
|
||||
|
||||
if san, err := conn.SanitizeSql("select $1::int8[]", int8a); err != nil || san != `select '{42,0,-1,32123}'::int8[]` {
|
||||
t.Errorf("Failed to sanitize []int64: %v - %v", san, err)
|
||||
for i, tt := range errorTests {
|
||||
_, err := pgx.SanitizeSql(tt.sql, tt.args...)
|
||||
if err == nil {
|
||||
t.Errorf("%d. Unexpected success (sql -> %v, args -> %v)", i, tt.sql, tt.args, err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.err) {
|
||||
t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, args -> %v)", i, tt.err, err, tt.sql, tt.args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math"
|
||||
|
@ -20,9 +19,6 @@ const (
|
|||
TextOid = 25
|
||||
Float4Oid = 700
|
||||
Float8Oid = 701
|
||||
Int2ArrayOid = 1005
|
||||
Int4ArrayOid = 1007
|
||||
Int8ArrayOid = 1016
|
||||
VarcharOid = 1043
|
||||
DateOid = 1082
|
||||
TimestampTzOid = 1184
|
||||
|
@ -33,134 +29,39 @@ const (
|
|||
BinaryFormatCode = 1
|
||||
)
|
||||
|
||||
// ValueTranscoder stores all the data necessary to encode and decode values from
|
||||
// a PostgreSQL server
|
||||
type ValueTranscoder struct {
|
||||
// Decode decodes values returned from the server
|
||||
Decode func(qr *QueryResult, fd *FieldDescription, size int32) interface{}
|
||||
// DecodeFormat is the preferred response format.
|
||||
// Allowed values: TextFormatCode, BinaryFormatCode
|
||||
DecodeFormat int16
|
||||
// EncodeTo encodes values to send to the server
|
||||
EncodeTo func(*WriteBuf, interface{}) error
|
||||
// EncodeFormat is the format values are encoded for transmission.
|
||||
// Allowed values: TextFormatCode, BinaryFormatCode
|
||||
EncodeFormat int16
|
||||
type Scanner interface {
|
||||
Scan(qr *QueryResult, fd *FieldDescription, size int32) error
|
||||
}
|
||||
|
||||
// ValueTranscoders is used to transcode values being sent to and received from
|
||||
// the PostgreSQL server. Additional types can be transcoded by adding a
|
||||
// *ValueTranscoder for the appropriate Oid to the map.
|
||||
var ValueTranscoders map[Oid]*ValueTranscoder
|
||||
// BinaryEncoder is an interface used to encode values in binary format for
|
||||
// transmission to the PostgreSQL server. It is used by prepared queries.
|
||||
type BinaryEncoder interface {
|
||||
// EncodeText MUST sanitize (and quote, if necessary) the returned string.
|
||||
// It will be interpolated directly into the SQL string.
|
||||
EncodeBinary(w *WriteBuf) error
|
||||
}
|
||||
|
||||
var defaultTranscoder *ValueTranscoder
|
||||
type NullInt64 struct {
|
||||
Int64 int64
|
||||
Valid bool // Valid is true if Int64 is not NULL
|
||||
}
|
||||
|
||||
func init() {
|
||||
ValueTranscoders = make(map[Oid]*ValueTranscoder)
|
||||
func (n *NullInt64) Scan(qr *QueryResult, fd *FieldDescription, size int32) error {
|
||||
if size == -1 {
|
||||
n.Int64, n.Valid = 0, false
|
||||
return nil
|
||||
}
|
||||
n.Valid = true
|
||||
n.Int64 = decodeInt8(qr, fd, size)
|
||||
return qr.Err()
|
||||
}
|
||||
|
||||
// bool
|
||||
ValueTranscoders[BoolOid] = &ValueTranscoder{
|
||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeBool(qr, fd, size) },
|
||||
DecodeFormat: BinaryFormatCode,
|
||||
EncodeTo: encodeBool,
|
||||
EncodeFormat: BinaryFormatCode}
|
||||
|
||||
// bytea
|
||||
ValueTranscoders[ByteaOid] = &ValueTranscoder{
|
||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeBytea(qr, fd, size) },
|
||||
DecodeFormat: TextFormatCode,
|
||||
EncodeTo: encodeBytea,
|
||||
EncodeFormat: BinaryFormatCode}
|
||||
|
||||
// int8
|
||||
ValueTranscoders[Int8Oid] = &ValueTranscoder{
|
||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeInt8(qr, fd, size) },
|
||||
DecodeFormat: BinaryFormatCode,
|
||||
EncodeTo: encodeInt8,
|
||||
EncodeFormat: BinaryFormatCode}
|
||||
|
||||
// int2
|
||||
ValueTranscoders[Int2Oid] = &ValueTranscoder{
|
||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeInt2(qr, fd, size) },
|
||||
DecodeFormat: BinaryFormatCode,
|
||||
EncodeTo: encodeInt2,
|
||||
EncodeFormat: BinaryFormatCode}
|
||||
|
||||
// int4
|
||||
ValueTranscoders[Int4Oid] = &ValueTranscoder{
|
||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeInt4(qr, fd, size) },
|
||||
DecodeFormat: BinaryFormatCode,
|
||||
EncodeTo: encodeInt4,
|
||||
EncodeFormat: BinaryFormatCode}
|
||||
|
||||
// text
|
||||
ValueTranscoders[TextOid] = &ValueTranscoder{
|
||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeText(qr, fd, size) },
|
||||
DecodeFormat: TextFormatCode,
|
||||
EncodeTo: encodeText,
|
||||
EncodeFormat: TextFormatCode}
|
||||
|
||||
// float4
|
||||
ValueTranscoders[Float4Oid] = &ValueTranscoder{
|
||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeFloat4(qr, fd, size) },
|
||||
EncodeTo: encodeFloat4,
|
||||
EncodeFormat: BinaryFormatCode}
|
||||
|
||||
// float8
|
||||
ValueTranscoders[Float8Oid] = &ValueTranscoder{
|
||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeFloat8(qr, fd, size) },
|
||||
DecodeFormat: BinaryFormatCode,
|
||||
EncodeTo: encodeFloat8,
|
||||
EncodeFormat: BinaryFormatCode}
|
||||
|
||||
// int2[]
|
||||
ValueTranscoders[Int2ArrayOid] = &ValueTranscoder{
|
||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} {
|
||||
return decodeInt2Array(qr, fd, size)
|
||||
},
|
||||
DecodeFormat: TextFormatCode,
|
||||
EncodeTo: encodeInt2Array,
|
||||
EncodeFormat: TextFormatCode}
|
||||
|
||||
// int4[]
|
||||
ValueTranscoders[Int4ArrayOid] = &ValueTranscoder{
|
||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} {
|
||||
return decodeInt4Array(qr, fd, size)
|
||||
},
|
||||
DecodeFormat: TextFormatCode,
|
||||
EncodeTo: encodeInt4Array,
|
||||
EncodeFormat: TextFormatCode}
|
||||
|
||||
// int8[]
|
||||
ValueTranscoders[Int8ArrayOid] = &ValueTranscoder{
|
||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} {
|
||||
return decodeInt8Array(qr, fd, size)
|
||||
},
|
||||
DecodeFormat: TextFormatCode,
|
||||
EncodeTo: encodeInt8Array,
|
||||
EncodeFormat: TextFormatCode}
|
||||
|
||||
// varchar -- same as text
|
||||
ValueTranscoders[VarcharOid] = ValueTranscoders[Oid(25)]
|
||||
|
||||
// date
|
||||
ValueTranscoders[DateOid] = &ValueTranscoder{
|
||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeDate(qr, fd, size) },
|
||||
DecodeFormat: BinaryFormatCode,
|
||||
EncodeTo: encodeDate,
|
||||
EncodeFormat: TextFormatCode}
|
||||
|
||||
// timestamptz
|
||||
ValueTranscoders[TimestampTzOid] = &ValueTranscoder{
|
||||
Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} {
|
||||
return decodeTimestampTz(qr, fd, size)
|
||||
},
|
||||
DecodeFormat: BinaryFormatCode,
|
||||
EncodeTo: encodeTimestampTz,
|
||||
EncodeFormat: TextFormatCode}
|
||||
|
||||
// use text transcoder for anything we don't understand
|
||||
defaultTranscoder = ValueTranscoders[TextOid]
|
||||
func (n *NullInt64) EncodeText() (string, error) {
|
||||
if n.Valid {
|
||||
return strconv.FormatInt(int64(n.Int64), 10), nil
|
||||
} else {
|
||||
return "null", nil
|
||||
}
|
||||
}
|
||||
|
||||
var arrayEl *regexp.Regexp = regexp.MustCompile(`[{,](?:"((?:[^"\\]|\\.)*)"|(NULL)|([^,}]+))`)
|
||||
|
@ -645,224 +546,3 @@ func encodeTimestampTz(w *WriteBuf, value interface{}) error {
|
|||
s := t.Format("2006-01-02 15:04:05.999999 -0700")
|
||||
return encodeText(w, s)
|
||||
}
|
||||
|
||||
func decodeInt2Array(qr *QueryResult, fd *FieldDescription, size int32) []int16 {
|
||||
if fd.DataType != Int2ArrayOid {
|
||||
qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read int2[] but received: %v", fd.DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
switch fd.FormatCode {
|
||||
case TextFormatCode:
|
||||
s := qr.mr.ReadString(size)
|
||||
|
||||
elements := SplitArrayText(s)
|
||||
|
||||
numbers := make([]int16, 0, len(elements))
|
||||
|
||||
for _, e := range elements {
|
||||
n, err := strconv.ParseInt(e, 10, 16)
|
||||
if err != nil {
|
||||
qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int2[]: %v", s)))
|
||||
return nil
|
||||
}
|
||||
numbers = append(numbers, int16(n))
|
||||
}
|
||||
|
||||
return numbers
|
||||
default:
|
||||
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func int16SliceToArrayString(nums []int16) (string, error) {
|
||||
w := &bytes.Buffer{}
|
||||
_, err := w.WriteString("{")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for i, n := range nums {
|
||||
if i > 0 {
|
||||
_, err = w.WriteString(",")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = w.WriteString(strconv.FormatInt(int64(n), 10))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = w.WriteString("}")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return w.String(), nil
|
||||
}
|
||||
|
||||
func encodeInt2Array(w *WriteBuf, value interface{}) error {
|
||||
v, ok := value.([]int16)
|
||||
if !ok {
|
||||
return fmt.Errorf("Expected []int16, received %T", value)
|
||||
}
|
||||
|
||||
s, err := int16SliceToArrayString(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to encode []int16: %v", err)
|
||||
}
|
||||
|
||||
return encodeText(w, s)
|
||||
}
|
||||
|
||||
func decodeInt4Array(qr *QueryResult, fd *FieldDescription, size int32) []int32 {
|
||||
if fd.DataType != Int4ArrayOid {
|
||||
qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read int4[] but received: %v", fd.DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
switch fd.FormatCode {
|
||||
case TextFormatCode:
|
||||
s := qr.mr.ReadString(size)
|
||||
|
||||
elements := SplitArrayText(s)
|
||||
|
||||
numbers := make([]int32, 0, len(elements))
|
||||
|
||||
for _, e := range elements {
|
||||
n, err := strconv.ParseInt(e, 10, 32)
|
||||
if err != nil {
|
||||
qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int4[]: %v", s)))
|
||||
return nil
|
||||
}
|
||||
numbers = append(numbers, int32(n))
|
||||
}
|
||||
|
||||
return numbers
|
||||
default:
|
||||
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func int32SliceToArrayString(nums []int32) (string, error) {
|
||||
w := &bytes.Buffer{}
|
||||
|
||||
_, err := w.WriteString("{")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for i, n := range nums {
|
||||
if i > 0 {
|
||||
_, err = w.WriteString(",")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = w.WriteString(strconv.FormatInt(int64(n), 10))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = w.WriteString("}")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return w.String(), nil
|
||||
}
|
||||
|
||||
func encodeInt4Array(w *WriteBuf, value interface{}) error {
|
||||
v, ok := value.([]int32)
|
||||
if !ok {
|
||||
return fmt.Errorf("Expected []int32, received %T", value)
|
||||
}
|
||||
|
||||
s, err := int32SliceToArrayString(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to encode []int32: %v", err)
|
||||
}
|
||||
|
||||
return encodeText(w, s)
|
||||
}
|
||||
|
||||
func decodeInt8Array(qr *QueryResult, fd *FieldDescription, size int32) []int64 {
|
||||
if fd.DataType != Int8ArrayOid {
|
||||
qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read int8[] but received: %v", fd.DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
switch fd.FormatCode {
|
||||
case TextFormatCode:
|
||||
s := qr.mr.ReadString(size)
|
||||
|
||||
elements := SplitArrayText(s)
|
||||
|
||||
numbers := make([]int64, 0, len(elements))
|
||||
|
||||
for _, e := range elements {
|
||||
n, err := strconv.ParseInt(e, 10, 64)
|
||||
if err != nil {
|
||||
qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int8[]: %v", s)))
|
||||
return nil
|
||||
}
|
||||
numbers = append(numbers, int64(n))
|
||||
}
|
||||
|
||||
return numbers
|
||||
default:
|
||||
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func int64SliceToArrayString(nums []int64) (string, error) {
|
||||
w := &bytes.Buffer{}
|
||||
|
||||
_, err := w.WriteString("{")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for i, n := range nums {
|
||||
if i > 0 {
|
||||
_, err = w.WriteString(",")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = w.WriteString(strconv.FormatInt(int64(n), 10))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = w.WriteString("}")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return w.String(), nil
|
||||
}
|
||||
|
||||
func encodeInt8Array(w *WriteBuf, value interface{}) error {
|
||||
v, ok := value.([]int64)
|
||||
if !ok {
|
||||
return fmt.Errorf("Expected []int64, received %T", value)
|
||||
}
|
||||
|
||||
s, err := int64SliceToArrayString(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to encode []int64: %v", err)
|
||||
}
|
||||
|
||||
return encodeText(w, s)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue