support decoding of []time.Time and []bool

pull/53/head
Karl Seguin 2014-12-21 14:35:38 +07:00
parent be663f648c
commit 109b55f9de
3 changed files with 169 additions and 19 deletions

View File

@ -243,6 +243,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
*d = decodeFloat4(vr)
case *float64:
*d = decodeFloat8(vr)
case *[]bool:
*d = decodeBoolArray(vr)
case *[]int16:
*d = decodeInt2Array(vr)
case *[]int32:
@ -255,6 +257,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
*d = decodeFloat8Array(vr)
case *[]string:
*d = decodeTextArray(vr)
case *[]time.Time:
*d = decodeTimestampArray(vr)
case *time.Time:
switch vr.Type().DataType {
case DateOid:
@ -324,6 +328,8 @@ func (rows *Rows) Values() ([]interface{}, error) {
values = append(values, decodeFloat4(vr))
case Float8Oid:
values = append(values, decodeFloat8(vr))
case BoolArrayOid:
values = append(values, decodeBoolArray(vr))
case Int2ArrayOid:
values = append(values, decodeInt2Array(vr))
case Int4ArrayOid:

122
values.go
View File

@ -53,11 +53,13 @@ func init() {
DefaultTypeFormats = make(map[string]int16)
DefaultTypeFormats["_float4"] = BinaryFormatCode
DefaultTypeFormats["_float8"] = BinaryFormatCode
DefaultTypeFormats["_bool"] = BinaryFormatCode
DefaultTypeFormats["_int2"] = BinaryFormatCode
DefaultTypeFormats["_int4"] = BinaryFormatCode
DefaultTypeFormats["_int8"] = BinaryFormatCode
DefaultTypeFormats["_text"] = BinaryFormatCode
DefaultTypeFormats["_varchar"] = BinaryFormatCode
DefaultTypeFormats["_timestamp"] = BinaryFormatCode
DefaultTypeFormats["bool"] = BinaryFormatCode
DefaultTypeFormats["bytea"] = BinaryFormatCode
DefaultTypeFormats["date"] = BinaryFormatCode
@ -1195,6 +1197,66 @@ func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
return length, nil
}
func decodeBoolArray(vr *ValueReader) []bool {
if vr.Len() == -1 {
return nil
}
if vr.Type().DataType != BoolArrayOid {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []bool", vr.Type().DataType)))
return nil
}
if vr.Type().FormatCode != BinaryFormatCode {
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return nil
}
numElems, err := decode1dArrayHeader(vr)
if err != nil {
vr.Fatal(err)
return nil
}
a := make([]bool, int(numElems))
for i := 0; i < len(a); i++ {
elSize := vr.ReadInt32()
switch elSize {
case 1:
if vr.ReadByte() == 1 {
a[i] = true
}
case -1:
vr.Fatal(ProtocolError("Cannot decode null element"))
return nil
default:
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool element: %d", elSize)))
return nil
}
}
return a
}
func encodeBoolArray(w *WriteBuf, value interface{}) error {
slice, ok := value.([]bool)
if !ok {
return fmt.Errorf("Expected []bool, received %T", value)
}
encodeArrayHeader(w, BoolOid, len(slice), 5)
for _, v := range slice {
w.WriteInt32(1)
var b byte
if v {
b = 1
}
w.WriteByte(b)
}
return nil
}
func decodeInt2Array(vr *ValueReader) []int16 {
if vr.Len() == -1 {
return nil
@ -1234,25 +1296,6 @@ func decodeInt2Array(vr *ValueReader) []int16 {
return a
}
func encodeBoolArray(w *WriteBuf, value interface{}) error {
slice, ok := value.([]bool)
if !ok {
return fmt.Errorf("Expected []bool, received %T", value)
}
encodeArrayHeader(w, BoolOid, len(slice), 5)
for _, v := range slice {
w.WriteInt32(1)
var b byte
if v {
b = 1
}
w.WriteByte(b)
}
return nil
}
func encodeInt2Array(w *WriteBuf, value interface{}) error {
slice, ok := value.([]int16)
if !ok {
@ -1548,6 +1591,47 @@ func encodeTextArray(w *WriteBuf, value interface{}, elOid Oid) error {
return nil
}
func decodeTimestampArray(vr *ValueReader) []time.Time {
if vr.Len() == -1 {
return nil
}
if vr.Type().DataType != TimestampArrayOid {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType)))
return nil
}
if vr.Type().FormatCode != BinaryFormatCode {
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return nil
}
numElems, err := decode1dArrayHeader(vr)
if err != nil {
vr.Fatal(err)
return nil
}
a := make([]time.Time, int(numElems))
for i := 0; i < len(a); i++ {
elSize := vr.ReadInt32()
switch elSize {
case 8:
microsecSinceY2K := vr.ReadInt64()
microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
a[i] = time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
case -1:
vr.Fatal(ProtocolError("Cannot decode null element"))
return nil
default:
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an time.Time element: %d", elSize)))
return nil
}
}
return a
}
func encodeTimestampArray(w *WriteBuf, value interface{}, elOid Oid) error {
slice, ok := value.([]time.Time)
if !ok {

View File

@ -3,6 +3,7 @@ package pgx_test
import (
"fmt"
"github.com/jackc/pgx"
"reflect"
"strings"
"testing"
"time"
@ -159,6 +160,65 @@ func TestNullXMismatch(t *testing.T) {
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
tests := []struct {
sql string
query interface{}
scan interface{}
assert func(*testing.T, interface{}, interface{})
}{
{
"select $1::bool[]", []bool{true, false, true}, &[]bool{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]bool))) == false {
t.Errorf("failed to encode bool[]")
}
},
},
{
"select $1::int[]", []int32{2, 4, 484}, &[]int32{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]int32))) == false {
t.Errorf("failed to encode int[]")
}
},
},
{
"select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]string))) == false {
t.Errorf("failed to encode text[]")
}
},
},
{
"select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false {
t.Errorf("failed to encode time.Time[]")
}
},
},
}
for i, tt := range tests {
psName := fmt.Sprintf("ps%d", i)
mustPrepare(t, conn, psName, tt.sql)
err := conn.QueryRow(psName, tt.query).Scan(tt.scan)
if err != nil {
t.Errorf(`error reading array: %v`, err)
}
tt.assert(t, tt.query, tt.scan)
ensureConnValid(t, conn)
}
}
func TestArrayDecoding(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
type allTypes struct {
s pgx.NullString
i16 pgx.NullInt16