mirror of https://github.com/jackc/pgx.git
Use pgtype.Int2Array in pgx
parent
36da5cc217
commit
cc3d1e4af8
11
conn.go
11
conn.go
|
@ -280,11 +280,12 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
c.closedChan = make(chan error)
|
||||
|
||||
c.oidPgtypeValues = map[OID]pgtype.Value{
|
||||
BoolOID: &pgtype.Bool{},
|
||||
DateOID: &pgtype.Date{},
|
||||
Int2OID: &pgtype.Int2{},
|
||||
Int4OID: &pgtype.Int4{},
|
||||
Int8OID: &pgtype.Int8{},
|
||||
BoolOID: &pgtype.Bool{},
|
||||
DateOID: &pgtype.Date{},
|
||||
Int2OID: &pgtype.Int2{},
|
||||
Int2ArrayOID: &pgtype.Int2Array{},
|
||||
Int4OID: &pgtype.Int4{},
|
||||
Int8OID: &pgtype.Int8{},
|
||||
}
|
||||
|
||||
if tlsConfig != nil {
|
||||
|
|
|
@ -93,3 +93,25 @@ func underlyingTimeType(val interface{}) (interface{}, bool) {
|
|||
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
// underlyingSliceType gets the underlying slice type
|
||||
func underlyingSliceType(val interface{}) (interface{}, bool) {
|
||||
refVal := reflect.ValueOf(val)
|
||||
|
||||
switch refVal.Kind() {
|
||||
case reflect.Ptr:
|
||||
if refVal.IsNil() {
|
||||
return nil, false
|
||||
}
|
||||
convVal := refVal.Elem().Interface()
|
||||
return convVal, true
|
||||
case reflect.Slice:
|
||||
baseSliceType := reflect.SliceOf(refVal.Type().Elem())
|
||||
if refVal.Type().ConvertibleTo(baseSliceType) {
|
||||
convVal := refVal.Convert(baseSliceType)
|
||||
return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type()
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package pgtype
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
|
@ -14,6 +15,52 @@ type Int2Array struct {
|
|||
}
|
||||
|
||||
func (a *Int2Array) ConvertFrom(src interface{}) error {
|
||||
switch value := src.(type) {
|
||||
case Int2Array:
|
||||
*a = value
|
||||
case []int16:
|
||||
if value == nil {
|
||||
*a = Int2Array{Status: Null}
|
||||
} else if len(value) == 0 {
|
||||
*a = Int2Array{Status: Present}
|
||||
} else {
|
||||
elements := make([]Int2, len(value))
|
||||
for i := range value {
|
||||
if err := elements[i].ConvertFrom(value[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*a = Int2Array{
|
||||
Elements: elements,
|
||||
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||
Status: Present,
|
||||
}
|
||||
}
|
||||
case []uint16:
|
||||
if value == nil {
|
||||
*a = Int2Array{Status: Null}
|
||||
} else if len(value) == 0 {
|
||||
*a = Int2Array{Status: Present}
|
||||
} else {
|
||||
elements := make([]Int2, len(value))
|
||||
for i := range value {
|
||||
if err := elements[i].ConvertFrom(value[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*a = Int2Array{
|
||||
Elements: elements,
|
||||
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||
Status: Present,
|
||||
}
|
||||
}
|
||||
default:
|
||||
if originalSrc, ok := underlyingSliceType(src); ok {
|
||||
return a.ConvertFrom(originalSrc)
|
||||
}
|
||||
return fmt.Errorf("cannot convert %v to Int2", value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
73
values.go
73
values.go
|
@ -1087,10 +1087,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error {
|
|||
// The name data type goes over the wire using the same format as string,
|
||||
// so just cast to string and use encodeString
|
||||
return encodeString(wbuf, oid, string(arg))
|
||||
case []int16:
|
||||
return encodeInt16Slice(wbuf, oid, arg)
|
||||
case []uint16:
|
||||
return encodeUInt16Slice(wbuf, oid, arg)
|
||||
case []int32:
|
||||
return encodeInt32Slice(wbuf, oid, arg)
|
||||
case []uint32:
|
||||
|
@ -2410,42 +2406,45 @@ func encodeByteSliceSlice(w *WriteBuf, oid OID, value [][]byte) error {
|
|||
}
|
||||
|
||||
func decodeInt2Array(vr *ValueReader) []int16 {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().DataType != Int2ArrayOID {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.err = errRewoundLen
|
||||
|
||||
var a pgtype.Int2Array
|
||||
var err error
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
err = a.DecodeText(&valueReader2{vr})
|
||||
case BinaryFormatCode:
|
||||
err = a.DecodeBinary(&valueReader2{vr})
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return nil
|
||||
}
|
||||
|
||||
numElems, err := decode1dArrayHeader(vr)
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
a := make([]int16, int(numElems))
|
||||
for i := 0; i < len(a); i++ {
|
||||
elSize := vr.ReadInt32()
|
||||
switch elSize {
|
||||
case 2:
|
||||
a[i] = vr.ReadInt16()
|
||||
case -1:
|
||||
if a.Status == pgtype.Null {
|
||||
return nil
|
||||
}
|
||||
|
||||
rawArray := make([]int16, len(a.Elements))
|
||||
for i := range a.Elements {
|
||||
if a.Elements[i].Status == pgtype.Present {
|
||||
rawArray[i] = a.Elements[i].Int
|
||||
} else {
|
||||
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||
return nil
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return a
|
||||
return rawArray
|
||||
}
|
||||
|
||||
func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 {
|
||||
|
@ -2492,38 +2491,6 @@ func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 {
|
|||
return a
|
||||
}
|
||||
|
||||
func encodeInt16Slice(w *WriteBuf, oid OID, slice []int16) error {
|
||||
if oid != Int2ArrayOID {
|
||||
return fmt.Errorf("cannot encode Go %s into oid %d", "[]int16", oid)
|
||||
}
|
||||
|
||||
encodeArrayHeader(w, Int2OID, len(slice), 6)
|
||||
for _, v := range slice {
|
||||
w.WriteInt32(2)
|
||||
w.WriteInt16(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeUInt16Slice(w *WriteBuf, oid OID, slice []uint16) error {
|
||||
if oid != Int2ArrayOID {
|
||||
return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint16", oid)
|
||||
}
|
||||
|
||||
encodeArrayHeader(w, Int2OID, len(slice), 6)
|
||||
for _, v := range slice {
|
||||
if v <= math.MaxInt16 {
|
||||
w.WriteInt32(2)
|
||||
w.WriteInt16(int16(v))
|
||||
} else {
|
||||
return fmt.Errorf("%d is greater than max smallint %d", v, math.MaxInt16)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeInt4Array(vr *ValueReader) []int32 {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
|
|
Loading…
Reference in New Issue