mirror of https://github.com/jackc/pgx.git
Merge branch 'karlseguin-time_array'
commit
8b7af157a3
6
conn.go
6
conn.go
|
@ -549,7 +549,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
wbuf.WriteInt16(TextFormatCode)
|
||||
default:
|
||||
switch oid {
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid:
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid:
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
default:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
|
@ -593,6 +593,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
err = encodeTimestampTz(wbuf, arguments[i])
|
||||
case TimestampOid:
|
||||
err = encodeTimestamp(wbuf, arguments[i])
|
||||
case BoolArrayOid:
|
||||
err = encodeBoolArray(wbuf, arguments[i])
|
||||
case Int2ArrayOid:
|
||||
err = encodeInt2Array(wbuf, arguments[i])
|
||||
case Int4ArrayOid:
|
||||
|
@ -607,6 +609,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
err = encodeTextArray(wbuf, arguments[i], TextOid)
|
||||
case VarcharArrayOid:
|
||||
err = encodeTextArray(wbuf, arguments[i], VarcharOid)
|
||||
case TimestampArrayOid:
|
||||
err = encodeTimestampArray(wbuf, arguments[i], VarcharOid)
|
||||
case OidOid:
|
||||
err = encodeOid(wbuf, arguments[i])
|
||||
default:
|
||||
|
|
32
conn_test.go
32
conn_test.go
|
@ -515,3 +515,35 @@ func TestCommandTag(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertBoolArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results != "CREATE TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
|
||||
// Accept parameters
|
||||
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results != "INSERT 0 1" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertTimestampArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results != "CREATE TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
|
||||
// Accept parameters
|
||||
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results != "INSERT 0 1" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
}
|
||||
|
|
6
query.go
6
query.go
|
@ -243,6 +243,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||
*d = decodeFloat4(vr)
|
||||
case *float64:
|
||||
*d = decodeFloat8(vr)
|
||||
case *[]bool:
|
||||
*d = decodeBoolArray(vr)
|
||||
case *[]int16:
|
||||
*d = decodeInt2Array(vr)
|
||||
case *[]int32:
|
||||
|
@ -255,6 +257,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||
*d = decodeFloat8Array(vr)
|
||||
case *[]string:
|
||||
*d = decodeTextArray(vr)
|
||||
case *[]time.Time:
|
||||
*d = decodeTimestampArray(vr)
|
||||
case *time.Time:
|
||||
switch vr.Type().DataType {
|
||||
case DateOid:
|
||||
|
@ -324,6 +328,8 @@ func (rows *Rows) Values() ([]interface{}, error) {
|
|||
values = append(values, decodeFloat4(vr))
|
||||
case Float8Oid:
|
||||
values = append(values, decodeFloat8(vr))
|
||||
case BoolArrayOid:
|
||||
values = append(values, decodeBoolArray(vr))
|
||||
case Int2ArrayOid:
|
||||
values = append(values, decodeInt2Array(vr))
|
||||
case Int4ArrayOid:
|
||||
|
|
224
values.go
224
values.go
|
@ -12,26 +12,28 @@ import (
|
|||
|
||||
// PostgreSQL oids for common types
|
||||
const (
|
||||
BoolOid = 16
|
||||
ByteaOid = 17
|
||||
Int8Oid = 20
|
||||
Int2Oid = 21
|
||||
Int4Oid = 23
|
||||
TextOid = 25
|
||||
OidOid = 26
|
||||
Float4Oid = 700
|
||||
Float8Oid = 701
|
||||
Int2ArrayOid = 1005
|
||||
Int4ArrayOid = 1007
|
||||
TextArrayOid = 1009
|
||||
VarcharArrayOid = 1015
|
||||
Int8ArrayOid = 1016
|
||||
Float4ArrayOid = 1021
|
||||
Float8ArrayOid = 1022
|
||||
VarcharOid = 1043
|
||||
DateOid = 1082
|
||||
TimestampOid = 1114
|
||||
TimestampTzOid = 1184
|
||||
BoolOid = 16
|
||||
ByteaOid = 17
|
||||
Int8Oid = 20
|
||||
Int2Oid = 21
|
||||
Int4Oid = 23
|
||||
TextOid = 25
|
||||
OidOid = 26
|
||||
Float4Oid = 700
|
||||
Float8Oid = 701
|
||||
BoolArrayOid = 1000
|
||||
Int2ArrayOid = 1005
|
||||
Int4ArrayOid = 1007
|
||||
TextArrayOid = 1009
|
||||
VarcharArrayOid = 1015
|
||||
Int8ArrayOid = 1016
|
||||
Float4ArrayOid = 1021
|
||||
Float8ArrayOid = 1022
|
||||
VarcharOid = 1043
|
||||
DateOid = 1082
|
||||
TimestampOid = 1114
|
||||
TimestampArrayOid = 1115
|
||||
TimestampTzOid = 1184
|
||||
)
|
||||
|
||||
// PostgreSQL format codes
|
||||
|
@ -51,11 +53,13 @@ func init() {
|
|||
DefaultTypeFormats = make(map[string]int16)
|
||||
DefaultTypeFormats["_float4"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_float8"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_bool"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_int2"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_int4"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_int8"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_text"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_varchar"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_timestamp"] = BinaryFormatCode
|
||||
DefaultTypeFormats["bool"] = BinaryFormatCode
|
||||
DefaultTypeFormats["bytea"] = BinaryFormatCode
|
||||
DefaultTypeFormats["date"] = BinaryFormatCode
|
||||
|
@ -1193,6 +1197,66 @@ func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
|
|||
return length, nil
|
||||
}
|
||||
|
||||
func decodeBoolArray(vr *ValueReader) []bool {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().DataType != BoolArrayOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []bool", vr.Type().DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return nil
|
||||
}
|
||||
|
||||
numElems, err := decode1dArrayHeader(vr)
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
a := make([]bool, int(numElems))
|
||||
for i := 0; i < len(a); i++ {
|
||||
elSize := vr.ReadInt32()
|
||||
switch elSize {
|
||||
case 1:
|
||||
if vr.ReadByte() == 1 {
|
||||
a[i] = true
|
||||
}
|
||||
case -1:
|
||||
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||
return nil
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool element: %d", elSize)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func encodeBoolArray(w *WriteBuf, value interface{}) error {
|
||||
slice, ok := value.([]bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("Expected []bool, received %T", value)
|
||||
}
|
||||
|
||||
encodeArrayHeader(w, BoolOid, len(slice), 5)
|
||||
for _, v := range slice {
|
||||
w.WriteInt32(1)
|
||||
var b byte
|
||||
if v {
|
||||
b = 1
|
||||
}
|
||||
w.WriteByte(b)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeInt2Array(vr *ValueReader) []int16 {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
|
@ -1238,15 +1302,7 @@ func encodeInt2Array(w *WriteBuf, value interface{}) error {
|
|||
return fmt.Errorf("Expected []int16, received %T", value)
|
||||
}
|
||||
|
||||
size := 20 + len(slice)*6
|
||||
w.WriteInt32(int32(size))
|
||||
|
||||
w.WriteInt32(1) // number of dimensions
|
||||
w.WriteInt32(0) // no nulls
|
||||
w.WriteInt32(Int2Oid) // type of elements
|
||||
w.WriteInt32(int32(len(slice))) // number of elements
|
||||
w.WriteInt32(1) // index of first element
|
||||
|
||||
encodeArrayHeader(w, Int2Oid, len(slice), 6)
|
||||
for _, v := range slice {
|
||||
w.WriteInt32(2)
|
||||
w.WriteInt16(v)
|
||||
|
@ -1300,15 +1356,7 @@ func encodeInt4Array(w *WriteBuf, value interface{}) error {
|
|||
return fmt.Errorf("Expected []int32, received %T", value)
|
||||
}
|
||||
|
||||
size := 20 + len(slice)*8
|
||||
w.WriteInt32(int32(size))
|
||||
|
||||
w.WriteInt32(1) // number of dimensions
|
||||
w.WriteInt32(0) // no nulls
|
||||
w.WriteInt32(Int4Oid) // type of elements
|
||||
w.WriteInt32(int32(len(slice))) // number of elements
|
||||
w.WriteInt32(1) // index of first element
|
||||
|
||||
encodeArrayHeader(w, Int4Oid, len(slice), 8)
|
||||
for _, v := range slice {
|
||||
w.WriteInt32(4)
|
||||
w.WriteInt32(v)
|
||||
|
@ -1362,15 +1410,7 @@ func encodeInt8Array(w *WriteBuf, value interface{}) error {
|
|||
return fmt.Errorf("Expected []int64, received %T", value)
|
||||
}
|
||||
|
||||
size := 20 + len(slice)*12
|
||||
w.WriteInt32(int32(size))
|
||||
|
||||
w.WriteInt32(1) // number of dimensions
|
||||
w.WriteInt32(0) // no nulls
|
||||
w.WriteInt32(Int8Oid) // type of elements
|
||||
w.WriteInt32(int32(len(slice))) // number of elements
|
||||
w.WriteInt32(1) // index of first element
|
||||
|
||||
encodeArrayHeader(w, Int8Oid, len(slice), 12)
|
||||
for _, v := range slice {
|
||||
w.WriteInt32(8)
|
||||
w.WriteInt64(v)
|
||||
|
@ -1424,19 +1464,9 @@ func encodeFloat4Array(w *WriteBuf, value interface{}) error {
|
|||
if !ok {
|
||||
return fmt.Errorf("Expected []float32, received %T", value)
|
||||
}
|
||||
|
||||
size := 20 + len(slice)*8
|
||||
w.WriteInt32(int32(size))
|
||||
|
||||
w.WriteInt32(1) // number of dimensions
|
||||
w.WriteInt32(0) // no nulls
|
||||
w.WriteInt32(Float4Oid) // type of elements
|
||||
w.WriteInt32(int32(len(slice))) // number of elements
|
||||
w.WriteInt32(1) // index of first element
|
||||
|
||||
encodeArrayHeader(w, Float4Oid, len(slice), 8)
|
||||
for _, v := range slice {
|
||||
w.WriteInt32(4)
|
||||
|
||||
w.WriteInt32(int32(math.Float32bits(v)))
|
||||
}
|
||||
|
||||
|
@ -1489,18 +1519,9 @@ func encodeFloat8Array(w *WriteBuf, value interface{}) error {
|
|||
return fmt.Errorf("Expected []float64, received %T", value)
|
||||
}
|
||||
|
||||
size := 20 + len(slice)*12
|
||||
w.WriteInt32(int32(size))
|
||||
|
||||
w.WriteInt32(1) // number of dimensions
|
||||
w.WriteInt32(0) // no nulls
|
||||
w.WriteInt32(Float8Oid) // type of elements
|
||||
w.WriteInt32(int32(len(slice))) // number of elements
|
||||
w.WriteInt32(1) // index of first element
|
||||
|
||||
encodeArrayHeader(w, Float8Oid, len(slice), 12)
|
||||
for _, v := range slice {
|
||||
w.WriteInt32(8)
|
||||
|
||||
w.WriteInt64(int64(math.Float64bits(v)))
|
||||
}
|
||||
|
||||
|
@ -1569,3 +1590,70 @@ func encodeTextArray(w *WriteBuf, value interface{}, elOid Oid) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeTimestampArray(vr *ValueReader) []time.Time {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().DataType != TimestampArrayOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return nil
|
||||
}
|
||||
|
||||
numElems, err := decode1dArrayHeader(vr)
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
a := make([]time.Time, int(numElems))
|
||||
for i := 0; i < len(a); i++ {
|
||||
elSize := vr.ReadInt32()
|
||||
switch elSize {
|
||||
case 8:
|
||||
microsecSinceY2K := vr.ReadInt64()
|
||||
microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
|
||||
a[i] = time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
|
||||
case -1:
|
||||
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||
return nil
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an time.Time element: %d", elSize)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func encodeTimestampArray(w *WriteBuf, value interface{}, elOid Oid) error {
|
||||
slice, ok := value.([]time.Time)
|
||||
if !ok {
|
||||
return fmt.Errorf("Expected []time.Time, received %T", value)
|
||||
}
|
||||
|
||||
encodeArrayHeader(w, TimestampOid, len(slice), 12)
|
||||
for _, t := range slice {
|
||||
w.WriteInt32(8)
|
||||
microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000
|
||||
microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K
|
||||
w.WriteInt64(microsecSinceY2K)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) {
|
||||
w.WriteInt32(int32(20 + length*sizePerItem))
|
||||
w.WriteInt32(1) // number of dimensions
|
||||
w.WriteInt32(0) // no nulls
|
||||
w.WriteInt32(int32(oid)) // type of elements
|
||||
w.WriteInt32(int32(length)) // number of elements
|
||||
w.WriteInt32(1) // index of first element
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package pgx_test
|
|||
import (
|
||||
"fmt"
|
||||
"github.com/jackc/pgx"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -159,6 +160,65 @@ func TestNullXMismatch(t *testing.T) {
|
|||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
query interface{}
|
||||
scan interface{}
|
||||
assert func(*testing.T, interface{}, interface{})
|
||||
}{
|
||||
{
|
||||
"select $1::bool[]", []bool{true, false, true}, &[]bool{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
if reflect.DeepEqual(query, *(scan.(*[]bool))) == false {
|
||||
t.Errorf("failed to encode bool[]")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"select $1::int[]", []int32{2, 4, 484}, &[]int32{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
if reflect.DeepEqual(query, *(scan.(*[]int32))) == false {
|
||||
t.Errorf("failed to encode int[]")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
if reflect.DeepEqual(query, *(scan.(*[]string))) == false {
|
||||
t.Errorf("failed to encode text[]")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false {
|
||||
t.Errorf("failed to encode time.Time[]")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
psName := fmt.Sprintf("ps%d", i)
|
||||
mustPrepare(t, conn, psName, tt.sql)
|
||||
|
||||
err := conn.QueryRow(psName, tt.query).Scan(tt.scan)
|
||||
if err != nil {
|
||||
t.Errorf(`error reading array: %v`, err)
|
||||
}
|
||||
tt.assert(t, tt.query, tt.scan)
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArrayDecoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
type allTypes struct {
|
||||
s pgx.NullString
|
||||
i16 pgx.NullInt16
|
||||
|
|
Loading…
Reference in New Issue