diff --git a/conn.go b/conn.go index 78a810a2..8d27b84a 100644 --- a/conn.go +++ b/conn.go @@ -319,11 +319,9 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { case rowDescription: ps.FieldDescriptions = c.rxRowDescription(r) for i := range ps.FieldDescriptions { - oid := ps.FieldDescriptions[i].DataType - vt := ValueTranscoders[oid] - - if vt != nil { - ps.FieldDescriptions[i].FormatCode = vt.DecodeFormat + switch ps.FieldDescriptions[i].DataType { + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, DateOid, TimestampTzOid: + ps.FieldDescriptions[i].FormatCode = BinaryFormatCode } } case noData: @@ -342,7 +340,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // Deallocate released a prepared statement func (c *Conn) Deallocate(name string) (err error) { delete(c.preparedStatements, name) - _, err = c.Exec("deallocate " + c.QuoteIdentifier(name)) + _, err = c.Exec("deallocate " + QuoteIdentifier(name)) return } @@ -601,6 +599,14 @@ func (qr *QueryResult) Scan(dest ...interface{}) (err error) { } else { *d = decodeTimestampTz(qr, fd, size) } + + case Scanner: + err = d.Scan(qr, fd, size) + if err != nil { + return err + } + default: + return errors.New("Unknown type") } } @@ -708,7 +714,7 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) { func (c *Conn) sendSimpleQuery(sql string, arguments ...interface{}) (err error) { if len(arguments) > 0 { - sql, err = c.SanitizeSql(sql, arguments...) + sql, err = SanitizeSql(sql, arguments...) if err != nil { return } @@ -734,38 +740,71 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteCString(ps.Name) wbuf.WriteInt16(int16(len(ps.ParameterOids))) - for _, oid := range ps.ParameterOids { - transcoder := ValueTranscoders[oid] - if transcoder == nil { - transcoder = defaultTranscoder + for i, oid := range ps.ParameterOids { + switch oid { + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid: + wbuf.WriteInt16(BinaryFormatCode) + case TextOid, VarcharOid, DateOid, TimestampTzOid: + wbuf.WriteInt16(TextFormatCode) + default: + if _, ok := arguments[i].(BinaryEncoder); ok { + wbuf.WriteInt16(BinaryFormatCode) + } else { + wbuf.WriteInt16(TextFormatCode) + } } - wbuf.WriteInt16(transcoder.EncodeFormat) } wbuf.WriteInt16(int16(len(arguments))) for i, oid := range ps.ParameterOids { - if arguments[i] != nil { - transcoder := ValueTranscoders[oid] - if transcoder == nil { - transcoder = defaultTranscoder + if arguments[i] == nil { + wbuf.WriteInt32(-1) + continue + } + + switch oid { + case BoolOid: + err = encodeBool(wbuf, arguments[i]) + case ByteaOid: + err = encodeBytea(wbuf, arguments[i]) + case Int2Oid: + err = encodeInt2(wbuf, arguments[i]) + case Int4Oid: + err = encodeInt4(wbuf, arguments[i]) + case Int8Oid: + err = encodeInt8(wbuf, arguments[i]) + case Float4Oid: + err = encodeFloat4(wbuf, arguments[i]) + case Float8Oid: + err = encodeFloat8(wbuf, arguments[i]) + case TextOid, VarcharOid: + err = encodeText(wbuf, arguments[i]) + case DateOid: + err = encodeDate(wbuf, arguments[i]) + case TimestampTzOid: + err = encodeTimestampTz(wbuf, arguments[i]) + default: + switch arg := arguments[i].(type) { + case BinaryEncoder: + err = arg.EncodeBinary(wbuf) + case TextEncoder: + var s string + s, err = arg.EncodeText() + wbuf.WriteInt32(int32(len(s))) + wbuf.WriteBytes([]byte(s)) + default: + return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder or BinaryEncoder", arg)) } - err = transcoder.EncodeTo(wbuf, arguments[i]) - if err != nil { - return err - } - } else { - wbuf.WriteInt32(int32(-1)) + } + + if err != nil { + return err } } wbuf.WriteInt16(int16(len(ps.FieldDescriptions))) for _, fd := range ps.FieldDescriptions { - transcoder := ValueTranscoders[fd.DataType] - if transcoder != nil { - wbuf.WriteInt16(transcoder.DecodeFormat) - } else { - wbuf.WriteInt16(0) - } + wbuf.WriteInt16(fd.FormatCode) } // execute diff --git a/conn_test.go b/conn_test.go index bddf49c7..4a5e3a29 100644 --- a/conn_test.go +++ b/conn_test.go @@ -450,6 +450,120 @@ func TestConnQueryReadTooManyValues(t *testing.T) { ensureConnValid(t, conn) } +func TestConnQueryUnpreparedScanner(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + qr, err := conn.Query("select null::int8, 1::int8") + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + + ok := qr.NextRow() + if !ok { + t.Fatal("qr.NextRow terminated early") + } + + var n, m pgx.NullInt64 + err = qr.Scan(&n, &m) + if err != nil { + t.Fatalf("qr.Scan failed: ", err) + } + qr.Close() + + if n.Valid { + t.Error("Null should not be valid, but it was") + } + + if !m.Valid { + t.Error("1 should be valid, but it wasn't") + } + + if m.Int64 != 1 { + t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) + } + + ensureConnValid(t, conn) +} + +func TestConnQueryPreparedScanner(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustPrepare(t, conn, "scannerTest", "select null::int8, 1::int8") + + qr, err := conn.Query("scannerTest") + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + + ok := qr.NextRow() + if !ok { + t.Fatal("qr.NextRow terminated early") + } + + var n, m pgx.NullInt64 + err = qr.Scan(&n, &m) + if err != nil { + t.Fatalf("qr.Scan failed: ", err) + } + qr.Close() + + if n.Valid { + t.Error("Null should not be valid, but it was") + } + + if !m.Valid { + t.Error("1 should be valid, but it wasn't") + } + + if m.Int64 != 1 { + t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) + } + + ensureConnValid(t, conn) +} + +func TestConnQueryUnpreparedEncoder(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + n := pgx.NullInt64{Int64: 1, Valid: true} + + qr, err := conn.Query("select $1::int8", &n) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + + ok := qr.NextRow() + if !ok { + t.Fatal("qr.NextRow terminated early") + } + + var m pgx.NullInt64 + err = qr.Scan(&m) + if err != nil { + t.Fatalf("qr.Scan failed: ", err) + } + qr.Close() + + if !m.Valid { + t.Error("m should be valid, but it wasn't") + } + + if m.Int64 != 1 { + t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) + } + + ensureConnValid(t, conn) +} + func TestPrepare(t *testing.T) { t.Parallel() diff --git a/msg_reader.go b/msg_reader.go index 5bc1170d..32baac75 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -178,3 +178,26 @@ func (r *MsgReader) ReadString(count int32) string { return string(b) } + +// ReadBytes reads count bytes and returns as []byte +func (r *MsgReader) ReadBytes(count int32) []byte { + if r.err != nil { + return nil + } + + r.msgBytesRemaining -= count + if r.msgBytesRemaining < 0 { + r.Fatal(errors.New("read past end of message")) + return nil + } + + b := make([]byte, int(count)) + + _, err := io.ReadFull(r.reader, b) + if err != nil { + r.Fatal(err) + return nil + } + + return b +} diff --git a/sanitize.go b/sanitize.go index b7cc95f0..320af55b 100644 --- a/sanitize.go +++ b/sanitize.go @@ -9,18 +9,34 @@ import ( "time" ) +type SerializationError string + +func (e SerializationError) Error() string { + return string(e) +} + +// TextEncoder is an interface used to encode values in text format for +// transmission to the PostgreSQL server. It is used by unprepared +// queries and for prepared queries when the type does not implement +// BinaryEncoder +type TextEncoder interface { + // EncodeText MUST sanitize (and quote, if necessary) the returned string. + // It will be interpolated directly into the SQL string. + EncodeText() (string, error) +} + var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`) // QuoteString escapes and quotes a string making it safe for interpolation // into an SQL string. -func (c *Conn) QuoteString(input string) (output string) { +func QuoteString(input string) (output string) { output = "'" + strings.Replace(input, "'", "''", -1) + "'" return } // QuoteIdentifier escapes and quotes an identifier making it safe for // interpolation into an SQL string -func (c *Conn) QuoteIdentifier(input string) (output string) { +func QuoteIdentifier(input string) (output string) { output = `"` + strings.Replace(input, `"`, `""`, -1) + `"` return } @@ -28,12 +44,21 @@ func (c *Conn) QuoteIdentifier(input string) (output string) { // SanitizeSql substitutely args positionaly into sql. Placeholder values are // $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as // appropriate. -func (c *Conn) SanitizeSql(sql string, args ...interface{}) (output string, err error) { +func SanitizeSql(sql string, args ...interface{}) (output string, err error) { replacer := func(match string) (replacement string) { + if err != nil { + return "" + } + n, _ := strconv.ParseInt(match[1:], 10, 0) + if int(n-1) >= len(args) { + err = fmt.Errorf("Cannot interpolate %v, only %d arguments provided", match, len(args)) + return + } + switch arg := args[n-1].(type) { case string: - return c.QuoteString(arg) + return QuoteString(arg) case int: return strconv.FormatInt(int64(arg), 10) case int8: @@ -45,7 +70,7 @@ func (c *Conn) SanitizeSql(sql string, args ...interface{}) (output string, err case int64: return strconv.FormatInt(int64(arg), 10) case time.Time: - return c.QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700")) + return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700")) case uint: return strconv.FormatUint(uint64(arg), 10) case uint8: @@ -64,22 +89,14 @@ func (c *Conn) SanitizeSql(sql string, args ...interface{}) (output string, err return strconv.FormatBool(arg) case []byte: return `E'\\x` + hex.EncodeToString(arg) + `'` - case []int16: - var s string - s, err = int16SliceToArrayString(arg) - return c.QuoteString(s) - case []int32: - var s string - s, err = int32SliceToArrayString(arg) - return c.QuoteString(s) - case []int64: - var s string - s, err = int64SliceToArrayString(arg) - return c.QuoteString(s) case nil: return "null" + case TextEncoder: + var s string + s, err = arg.EncodeText() + return s default: - err = fmt.Errorf("Unable to sanitize type: %T", arg) + err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg)) return "" } } diff --git a/sanitize_test.go b/sanitize_test.go index fd5f035b..d3134dc3 100644 --- a/sanitize_test.go +++ b/sanitize_test.go @@ -1,20 +1,19 @@ package pgx_test import ( + "github.com/jackc/pgx" + "strings" "testing" ) func TestQuoteString(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - if conn.QuoteString("test") != "'test'" { + if pgx.QuoteString("test") != "'test'" { t.Error("Failed to quote string") } - if conn.QuoteString("Jack's") != "'Jack''s'" { + if pgx.QuoteString("Jack's") != "'Jack''s'" { t.Error("Failed to quote and escape string with embedded quote") } } @@ -22,70 +21,47 @@ func TestQuoteString(t *testing.T) { func TestSanitizeSql(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - if san, err := conn.SanitizeSql("select $1", nil); err != nil || san != "select null" { - t.Errorf("Failed to translate nil to null: %v - %v", san, err) + successTests := []struct { + sql string + args []interface{} + output string + }{ + {"select $1", []interface{}{nil}, "select null"}, + {"select $1", []interface{}{"Jack's"}, "select 'Jack''s'"}, + {"select $1", []interface{}{42}, "select 42"}, + {"select $1", []interface{}{1.23}, "select 1.23"}, + {"select $1", []interface{}{true}, "select true"}, + {"select $1, $2, $3", []interface{}{"Jack's", 42, 1.23}, "select 'Jack''s', 42, 1.23"}, + {"select $1", []interface{}{[]byte{0, 15, 255, 17}}, `select E'\\x000fff11'`}, + {"select $1", []interface{}{&pgx.NullInt64{Int64: 1, Valid: true}}, "select 1"}, } - if san, err := conn.SanitizeSql("select $1", "Jack's"); err != nil || san != "select 'Jack''s'" { - t.Errorf("Failed to sanitize string: %v - %v", san, err) + for i, tt := range successTests { + san, err := pgx.SanitizeSql(tt.sql, tt.args...) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, args -> %v)", i, err, tt.sql, tt.args) + } + if san != tt.output { + t.Errorf("%d. Expected %v, got %v (sql -> %v, args -> %v)", i, tt.output, san, tt.sql, tt.args) + } } - if san, err := conn.SanitizeSql("select $1", 42); err != nil || san != "select 42" { - t.Errorf("Failed to pass through integer: %v - %v", san, err) + errorTests := []struct { + sql string + args []interface{} + err string + }{ + {"select $1", []interface{}{t}, "is not a core type and it does not implement TextEncoder"}, + {"select $1, $2", []interface{}{}, "Cannot interpolate $1, only 0 arguments provided"}, } - if san, err := conn.SanitizeSql("select $1", 1.23); err != nil || san != "select 1.23" { - t.Errorf("Failed to pass through float: %v - %v", san, err) - } - - if san, err := conn.SanitizeSql("select $1", true); err != nil || san != "select true" { - t.Errorf("Failed to pass through bool: %v - %v", san, err) - } - - if san, err := conn.SanitizeSql("select $1, $2, $3", "Jack's", 42, 1.23); err != nil || san != "select 'Jack''s', 42, 1.23" { - t.Errorf("Failed to sanitize multiple params: %v - %v", san, err) - } - - bytea := make([]byte, 4) - bytea[0] = 0 // 0x00 - bytea[1] = 15 // 0x0F - bytea[2] = 255 // 0xFF - bytea[3] = 17 // 0x11 - - if san, err := conn.SanitizeSql("select $1", bytea); err != nil || san != `select E'\\x000fff11'` { - t.Errorf("Failed to sanitize []byte: %v - %v", san, err) - } - - int2a := make([]int16, 4) - int2a[0] = 42 - int2a[1] = 0 - int2a[2] = -1 - int2a[3] = 32123 - - if san, err := conn.SanitizeSql("select $1::int2[]", int2a); err != nil || san != `select '{42,0,-1,32123}'::int2[]` { - t.Errorf("Failed to sanitize []int16: %v - %v", san, err) - } - - int4a := make([]int32, 4) - int4a[0] = 42 - int4a[1] = 0 - int4a[2] = -1 - int4a[3] = 32123 - - if san, err := conn.SanitizeSql("select $1::int4[]", int4a); err != nil || san != `select '{42,0,-1,32123}'::int4[]` { - t.Errorf("Failed to sanitize []int32: %v - %v", san, err) - } - - int8a := make([]int64, 4) - int8a[0] = 42 - int8a[1] = 0 - int8a[2] = -1 - int8a[3] = 32123 - - if san, err := conn.SanitizeSql("select $1::int8[]", int8a); err != nil || san != `select '{42,0,-1,32123}'::int8[]` { - t.Errorf("Failed to sanitize []int64: %v - %v", san, err) + for i, tt := range errorTests { + _, err := pgx.SanitizeSql(tt.sql, tt.args...) + if err == nil { + t.Errorf("%d. Unexpected success (sql -> %v, args -> %v)", i, tt.sql, tt.args, err) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, args -> %v)", i, tt.err, err, tt.sql, tt.args) + } } } diff --git a/value_transcoder.go b/value_transcoder.go index 357fd822..3a9afc17 100644 --- a/value_transcoder.go +++ b/value_transcoder.go @@ -1,7 +1,6 @@ package pgx import ( - "bytes" "encoding/hex" "fmt" "math" @@ -20,9 +19,6 @@ const ( TextOid = 25 Float4Oid = 700 Float8Oid = 701 - Int2ArrayOid = 1005 - Int4ArrayOid = 1007 - Int8ArrayOid = 1016 VarcharOid = 1043 DateOid = 1082 TimestampTzOid = 1184 @@ -33,134 +29,39 @@ const ( BinaryFormatCode = 1 ) -// ValueTranscoder stores all the data necessary to encode and decode values from -// a PostgreSQL server -type ValueTranscoder struct { - // Decode decodes values returned from the server - Decode func(qr *QueryResult, fd *FieldDescription, size int32) interface{} - // DecodeFormat is the preferred response format. - // Allowed values: TextFormatCode, BinaryFormatCode - DecodeFormat int16 - // EncodeTo encodes values to send to the server - EncodeTo func(*WriteBuf, interface{}) error - // EncodeFormat is the format values are encoded for transmission. - // Allowed values: TextFormatCode, BinaryFormatCode - EncodeFormat int16 +type Scanner interface { + Scan(qr *QueryResult, fd *FieldDescription, size int32) error } -// ValueTranscoders is used to transcode values being sent to and received from -// the PostgreSQL server. Additional types can be transcoded by adding a -// *ValueTranscoder for the appropriate Oid to the map. -var ValueTranscoders map[Oid]*ValueTranscoder +// BinaryEncoder is an interface used to encode values in binary format for +// transmission to the PostgreSQL server. It is used by prepared queries. +type BinaryEncoder interface { + // EncodeText MUST sanitize (and quote, if necessary) the returned string. + // It will be interpolated directly into the SQL string. + EncodeBinary(w *WriteBuf) error +} -var defaultTranscoder *ValueTranscoder +type NullInt64 struct { + Int64 int64 + Valid bool // Valid is true if Int64 is not NULL +} -func init() { - ValueTranscoders = make(map[Oid]*ValueTranscoder) +func (n *NullInt64) Scan(qr *QueryResult, fd *FieldDescription, size int32) error { + if size == -1 { + n.Int64, n.Valid = 0, false + return nil + } + n.Valid = true + n.Int64 = decodeInt8(qr, fd, size) + return qr.Err() +} - // bool - ValueTranscoders[BoolOid] = &ValueTranscoder{ - Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeBool(qr, fd, size) }, - DecodeFormat: BinaryFormatCode, - EncodeTo: encodeBool, - EncodeFormat: BinaryFormatCode} - - // bytea - ValueTranscoders[ByteaOid] = &ValueTranscoder{ - Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeBytea(qr, fd, size) }, - DecodeFormat: TextFormatCode, - EncodeTo: encodeBytea, - EncodeFormat: BinaryFormatCode} - - // int8 - ValueTranscoders[Int8Oid] = &ValueTranscoder{ - Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeInt8(qr, fd, size) }, - DecodeFormat: BinaryFormatCode, - EncodeTo: encodeInt8, - EncodeFormat: BinaryFormatCode} - - // int2 - ValueTranscoders[Int2Oid] = &ValueTranscoder{ - Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeInt2(qr, fd, size) }, - DecodeFormat: BinaryFormatCode, - EncodeTo: encodeInt2, - EncodeFormat: BinaryFormatCode} - - // int4 - ValueTranscoders[Int4Oid] = &ValueTranscoder{ - Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeInt4(qr, fd, size) }, - DecodeFormat: BinaryFormatCode, - EncodeTo: encodeInt4, - EncodeFormat: BinaryFormatCode} - - // text - ValueTranscoders[TextOid] = &ValueTranscoder{ - Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeText(qr, fd, size) }, - DecodeFormat: TextFormatCode, - EncodeTo: encodeText, - EncodeFormat: TextFormatCode} - - // float4 - ValueTranscoders[Float4Oid] = &ValueTranscoder{ - Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeFloat4(qr, fd, size) }, - EncodeTo: encodeFloat4, - EncodeFormat: BinaryFormatCode} - - // float8 - ValueTranscoders[Float8Oid] = &ValueTranscoder{ - Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeFloat8(qr, fd, size) }, - DecodeFormat: BinaryFormatCode, - EncodeTo: encodeFloat8, - EncodeFormat: BinaryFormatCode} - - // int2[] - ValueTranscoders[Int2ArrayOid] = &ValueTranscoder{ - Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { - return decodeInt2Array(qr, fd, size) - }, - DecodeFormat: TextFormatCode, - EncodeTo: encodeInt2Array, - EncodeFormat: TextFormatCode} - - // int4[] - ValueTranscoders[Int4ArrayOid] = &ValueTranscoder{ - Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { - return decodeInt4Array(qr, fd, size) - }, - DecodeFormat: TextFormatCode, - EncodeTo: encodeInt4Array, - EncodeFormat: TextFormatCode} - - // int8[] - ValueTranscoders[Int8ArrayOid] = &ValueTranscoder{ - Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { - return decodeInt8Array(qr, fd, size) - }, - DecodeFormat: TextFormatCode, - EncodeTo: encodeInt8Array, - EncodeFormat: TextFormatCode} - - // varchar -- same as text - ValueTranscoders[VarcharOid] = ValueTranscoders[Oid(25)] - - // date - ValueTranscoders[DateOid] = &ValueTranscoder{ - Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeDate(qr, fd, size) }, - DecodeFormat: BinaryFormatCode, - EncodeTo: encodeDate, - EncodeFormat: TextFormatCode} - - // timestamptz - ValueTranscoders[TimestampTzOid] = &ValueTranscoder{ - Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { - return decodeTimestampTz(qr, fd, size) - }, - DecodeFormat: BinaryFormatCode, - EncodeTo: encodeTimestampTz, - EncodeFormat: TextFormatCode} - - // use text transcoder for anything we don't understand - defaultTranscoder = ValueTranscoders[TextOid] +func (n *NullInt64) EncodeText() (string, error) { + if n.Valid { + return strconv.FormatInt(int64(n.Int64), 10), nil + } else { + return "null", nil + } } var arrayEl *regexp.Regexp = regexp.MustCompile(`[{,](?:"((?:[^"\\]|\\.)*)"|(NULL)|([^,}]+))`) @@ -645,224 +546,3 @@ func encodeTimestampTz(w *WriteBuf, value interface{}) error { s := t.Format("2006-01-02 15:04:05.999999 -0700") return encodeText(w, s) } - -func decodeInt2Array(qr *QueryResult, fd *FieldDescription, size int32) []int16 { - if fd.DataType != Int2ArrayOid { - qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read int2[] but received: %v", fd.DataType))) - return nil - } - - switch fd.FormatCode { - case TextFormatCode: - s := qr.mr.ReadString(size) - - elements := SplitArrayText(s) - - numbers := make([]int16, 0, len(elements)) - - for _, e := range elements { - n, err := strconv.ParseInt(e, 10, 16) - if err != nil { - qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int2[]: %v", s))) - return nil - } - numbers = append(numbers, int16(n)) - } - - return numbers - default: - qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) - return nil - } -} - -func int16SliceToArrayString(nums []int16) (string, error) { - w := &bytes.Buffer{} - _, err := w.WriteString("{") - if err != nil { - return "", err - } - - for i, n := range nums { - if i > 0 { - _, err = w.WriteString(",") - if err != nil { - return "", err - } - } - - _, err = w.WriteString(strconv.FormatInt(int64(n), 10)) - if err != nil { - return "", err - } - } - - _, err = w.WriteString("}") - if err != nil { - return "", err - } - - return w.String(), nil -} - -func encodeInt2Array(w *WriteBuf, value interface{}) error { - v, ok := value.([]int16) - if !ok { - return fmt.Errorf("Expected []int16, received %T", value) - } - - s, err := int16SliceToArrayString(v) - if err != nil { - return fmt.Errorf("Failed to encode []int16: %v", err) - } - - return encodeText(w, s) -} - -func decodeInt4Array(qr *QueryResult, fd *FieldDescription, size int32) []int32 { - if fd.DataType != Int4ArrayOid { - qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read int4[] but received: %v", fd.DataType))) - return nil - } - - switch fd.FormatCode { - case TextFormatCode: - s := qr.mr.ReadString(size) - - elements := SplitArrayText(s) - - numbers := make([]int32, 0, len(elements)) - - for _, e := range elements { - n, err := strconv.ParseInt(e, 10, 32) - if err != nil { - qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int4[]: %v", s))) - return nil - } - numbers = append(numbers, int32(n)) - } - - return numbers - default: - qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) - return nil - } -} - -func int32SliceToArrayString(nums []int32) (string, error) { - w := &bytes.Buffer{} - - _, err := w.WriteString("{") - if err != nil { - return "", err - } - - for i, n := range nums { - if i > 0 { - _, err = w.WriteString(",") - if err != nil { - return "", err - } - } - - _, err = w.WriteString(strconv.FormatInt(int64(n), 10)) - if err != nil { - return "", err - } - } - - _, err = w.WriteString("}") - if err != nil { - return "", err - } - - return w.String(), nil -} - -func encodeInt4Array(w *WriteBuf, value interface{}) error { - v, ok := value.([]int32) - if !ok { - return fmt.Errorf("Expected []int32, received %T", value) - } - - s, err := int32SliceToArrayString(v) - if err != nil { - return fmt.Errorf("Failed to encode []int32: %v", err) - } - - return encodeText(w, s) -} - -func decodeInt8Array(qr *QueryResult, fd *FieldDescription, size int32) []int64 { - if fd.DataType != Int8ArrayOid { - qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read int8[] but received: %v", fd.DataType))) - return nil - } - - switch fd.FormatCode { - case TextFormatCode: - s := qr.mr.ReadString(size) - - elements := SplitArrayText(s) - - numbers := make([]int64, 0, len(elements)) - - for _, e := range elements { - n, err := strconv.ParseInt(e, 10, 64) - if err != nil { - qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int8[]: %v", s))) - return nil - } - numbers = append(numbers, int64(n)) - } - - return numbers - default: - qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) - return nil - } -} - -func int64SliceToArrayString(nums []int64) (string, error) { - w := &bytes.Buffer{} - - _, err := w.WriteString("{") - if err != nil { - return "", err - } - - for i, n := range nums { - if i > 0 { - _, err = w.WriteString(",") - if err != nil { - return "", err - } - } - - _, err = w.WriteString(strconv.FormatInt(int64(n), 10)) - if err != nil { - return "", err - } - } - - _, err = w.WriteString("}") - if err != nil { - return "", err - } - - return w.String(), nil -} - -func encodeInt8Array(w *WriteBuf, value interface{}) error { - v, ok := value.([]int64) - if !ok { - return fmt.Errorf("Expected []int64, received %T", value) - } - - s, err := int64SliceToArrayString(v) - if err != nil { - return fmt.Errorf("Failed to encode []int64: %v", err) - } - - return encodeText(w, s) -}