Use pgtype.Int2Array in pgx

pgxtype-experiment2
Jack Christensen 2017-03-02 21:19:07 -06:00
parent 36da5cc217
commit cc3d1e4af8
4 changed files with 95 additions and 58 deletions

11
conn.go
View File

@ -280,11 +280,12 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.closedChan = make(chan error)
c.oidPgtypeValues = map[OID]pgtype.Value{
BoolOID: &pgtype.Bool{},
DateOID: &pgtype.Date{},
Int2OID: &pgtype.Int2{},
Int4OID: &pgtype.Int4{},
Int8OID: &pgtype.Int8{},
BoolOID: &pgtype.Bool{},
DateOID: &pgtype.Date{},
Int2OID: &pgtype.Int2{},
Int2ArrayOID: &pgtype.Int2Array{},
Int4OID: &pgtype.Int4{},
Int8OID: &pgtype.Int8{},
}
if tlsConfig != nil {

View File

@ -93,3 +93,25 @@ func underlyingTimeType(val interface{}) (interface{}, bool) {
return time.Time{}, false
}
// underlyingSliceType gets the underlying slice type
func underlyingSliceType(val interface{}) (interface{}, bool) {
refVal := reflect.ValueOf(val)
switch refVal.Kind() {
case reflect.Ptr:
if refVal.IsNil() {
return nil, false
}
convVal := refVal.Elem().Interface()
return convVal, true
case reflect.Slice:
baseSliceType := reflect.SliceOf(refVal.Type().Elem())
if refVal.Type().ConvertibleTo(baseSliceType) {
convVal := refVal.Convert(baseSliceType)
return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type()
}
}
return nil, false
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"fmt"
"io"
"github.com/jackc/pgx/pgio"
@ -14,6 +15,52 @@ type Int2Array struct {
}
func (a *Int2Array) ConvertFrom(src interface{}) error {
switch value := src.(type) {
case Int2Array:
*a = value
case []int16:
if value == nil {
*a = Int2Array{Status: Null}
} else if len(value) == 0 {
*a = Int2Array{Status: Present}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].ConvertFrom(value[i]); err != nil {
return err
}
}
*a = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Status: Present,
}
}
case []uint16:
if value == nil {
*a = Int2Array{Status: Null}
} else if len(value) == 0 {
*a = Int2Array{Status: Present}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].ConvertFrom(value[i]); err != nil {
return err
}
}
*a = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Status: Present,
}
}
default:
if originalSrc, ok := underlyingSliceType(src); ok {
return a.ConvertFrom(originalSrc)
}
return fmt.Errorf("cannot convert %v to Int2", value)
}
return nil
}

View File

@ -1087,10 +1087,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error {
// The name data type goes over the wire using the same format as string,
// so just cast to string and use encodeString
return encodeString(wbuf, oid, string(arg))
case []int16:
return encodeInt16Slice(wbuf, oid, arg)
case []uint16:
return encodeUInt16Slice(wbuf, oid, arg)
case []int32:
return encodeInt32Slice(wbuf, oid, arg)
case []uint32:
@ -2410,42 +2406,45 @@ func encodeByteSliceSlice(w *WriteBuf, oid OID, value [][]byte) error {
}
func decodeInt2Array(vr *ValueReader) []int16 {
if vr.Len() == -1 {
return nil
}
if vr.Type().DataType != Int2ArrayOID {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType)))
return nil
}
if vr.Type().FormatCode != BinaryFormatCode {
vr.err = errRewoundLen
var a pgtype.Int2Array
var err error
switch vr.Type().FormatCode {
case TextFormatCode:
err = a.DecodeText(&valueReader2{vr})
case BinaryFormatCode:
err = a.DecodeBinary(&valueReader2{vr})
default:
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return nil
}
numElems, err := decode1dArrayHeader(vr)
if err != nil {
vr.Fatal(err)
return nil
}
a := make([]int16, int(numElems))
for i := 0; i < len(a); i++ {
elSize := vr.ReadInt32()
switch elSize {
case 2:
a[i] = vr.ReadInt16()
case -1:
if a.Status == pgtype.Null {
return nil
}
rawArray := make([]int16, len(a.Elements))
for i := range a.Elements {
if a.Elements[i].Status == pgtype.Present {
rawArray[i] = a.Elements[i].Int
} else {
vr.Fatal(ProtocolError("Cannot decode null element"))
return nil
default:
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize)))
return nil
}
}
return a
return rawArray
}
func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 {
@ -2492,38 +2491,6 @@ func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 {
return a
}
func encodeInt16Slice(w *WriteBuf, oid OID, slice []int16) error {
if oid != Int2ArrayOID {
return fmt.Errorf("cannot encode Go %s into oid %d", "[]int16", oid)
}
encodeArrayHeader(w, Int2OID, len(slice), 6)
for _, v := range slice {
w.WriteInt32(2)
w.WriteInt16(v)
}
return nil
}
func encodeUInt16Slice(w *WriteBuf, oid OID, slice []uint16) error {
if oid != Int2ArrayOID {
return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint16", oid)
}
encodeArrayHeader(w, Int2OID, len(slice), 6)
for _, v := range slice {
if v <= math.MaxInt16 {
w.WriteInt32(2)
w.WriteInt16(int16(v))
} else {
return fmt.Errorf("%d is greater than max smallint %d", v, math.MaxInt16)
}
}
return nil
}
func decodeInt4Array(vr *ValueReader) []int32 {
if vr.Len() == -1 {
return nil