diff --git a/conn.go b/conn.go index f347a7d9..c284c0f9 100644 --- a/conn.go +++ b/conn.go @@ -280,11 +280,12 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.closedChan = make(chan error) c.oidPgtypeValues = map[OID]pgtype.Value{ - BoolOID: &pgtype.Bool{}, - DateOID: &pgtype.Date{}, - Int2OID: &pgtype.Int2{}, - Int4OID: &pgtype.Int4{}, - Int8OID: &pgtype.Int8{}, + BoolOID: &pgtype.Bool{}, + DateOID: &pgtype.Date{}, + Int2OID: &pgtype.Int2{}, + Int2ArrayOID: &pgtype.Int2Array{}, + Int4OID: &pgtype.Int4{}, + Int8OID: &pgtype.Int8{}, } if tlsConfig != nil { diff --git a/pgtype/convert.go b/pgtype/convert.go index 26f827cf..5d26669d 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -93,3 +93,25 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { return time.Time{}, false } + +// underlyingSliceType gets the underlying slice type +func underlyingSliceType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + baseSliceType := reflect.SliceOf(refVal.Type().Elem()) + if refVal.Type().ConvertibleTo(baseSliceType) { + convVal := refVal.Convert(baseSliceType) + return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() + } + } + + return nil, false +} diff --git a/pgtype/int2array.go b/pgtype/int2array.go index 7345305f..43bbccbd 100644 --- a/pgtype/int2array.go +++ b/pgtype/int2array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "fmt" "io" "github.com/jackc/pgx/pgio" @@ -14,6 +15,52 @@ type Int2Array struct { } func (a *Int2Array) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int2Array: + *a = value + case []int16: + if value == nil { + *a = Int2Array{Status: Null} + } else if len(value) == 0 { + *a = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *a = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint16: + if value == nil { + *a = Int2Array{Status: Null} + } else if len(value) == 0 { + *a = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *a = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return a.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int2", value) + } + return nil } diff --git a/values.go b/values.go index 90391f29..a9c4c209 100644 --- a/values.go +++ b/values.go @@ -1087,10 +1087,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { // The name data type goes over the wire using the same format as string, // so just cast to string and use encodeString return encodeString(wbuf, oid, string(arg)) - case []int16: - return encodeInt16Slice(wbuf, oid, arg) - case []uint16: - return encodeUInt16Slice(wbuf, oid, arg) case []int32: return encodeInt32Slice(wbuf, oid, arg) case []uint32: @@ -2410,42 +2406,45 @@ func encodeByteSliceSlice(w *WriteBuf, oid OID, value [][]byte) error { } func decodeInt2Array(vr *ValueReader) []int16 { - if vr.Len() == -1 { - return nil - } - if vr.Type().DataType != Int2ArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType))) return nil } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var a pgtype.Int2Array + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = a.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = a.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return nil } - numElems, err := decode1dArrayHeader(vr) if err != nil { vr.Fatal(err) return nil } - a := make([]int16, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 2: - a[i] = vr.ReadInt16() - case -1: + if a.Status == pgtype.Null { + return nil + } + + rawArray := make([]int16, len(a.Elements)) + for i := range a.Elements { + if a.Elements[i].Status == pgtype.Present { + rawArray[i] = a.Elements[i].Int + } else { vr.Fatal(ProtocolError("Cannot decode null element")) return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize))) - return nil } } - return a + return rawArray } func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 { @@ -2492,38 +2491,6 @@ func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 { return a } -func encodeInt16Slice(w *WriteBuf, oid OID, slice []int16) error { - if oid != Int2ArrayOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]int16", oid) - } - - encodeArrayHeader(w, Int2OID, len(slice), 6) - for _, v := range slice { - w.WriteInt32(2) - w.WriteInt16(v) - } - - return nil -} - -func encodeUInt16Slice(w *WriteBuf, oid OID, slice []uint16) error { - if oid != Int2ArrayOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint16", oid) - } - - encodeArrayHeader(w, Int2OID, len(slice), 6) - for _, v := range slice { - if v <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(v)) - } else { - return fmt.Errorf("%d is greater than max smallint %d", v, math.MaxInt16) - } - } - - return nil -} - func decodeInt4Array(vr *ValueReader) []int32 { if vr.Len() == -1 { return nil