mirror of https://github.com/jackc/pgx.git
Use pgtype.Int2Array in pgx
parent
36da5cc217
commit
cc3d1e4af8
11
conn.go
11
conn.go
|
@ -280,11 +280,12 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
||||||
c.closedChan = make(chan error)
|
c.closedChan = make(chan error)
|
||||||
|
|
||||||
c.oidPgtypeValues = map[OID]pgtype.Value{
|
c.oidPgtypeValues = map[OID]pgtype.Value{
|
||||||
BoolOID: &pgtype.Bool{},
|
BoolOID: &pgtype.Bool{},
|
||||||
DateOID: &pgtype.Date{},
|
DateOID: &pgtype.Date{},
|
||||||
Int2OID: &pgtype.Int2{},
|
Int2OID: &pgtype.Int2{},
|
||||||
Int4OID: &pgtype.Int4{},
|
Int2ArrayOID: &pgtype.Int2Array{},
|
||||||
Int8OID: &pgtype.Int8{},
|
Int4OID: &pgtype.Int4{},
|
||||||
|
Int8OID: &pgtype.Int8{},
|
||||||
}
|
}
|
||||||
|
|
||||||
if tlsConfig != nil {
|
if tlsConfig != nil {
|
||||||
|
|
|
@ -93,3 +93,25 @@ func underlyingTimeType(val interface{}) (interface{}, bool) {
|
||||||
|
|
||||||
return time.Time{}, false
|
return time.Time{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// underlyingSliceType gets the underlying slice type
|
||||||
|
func underlyingSliceType(val interface{}) (interface{}, bool) {
|
||||||
|
refVal := reflect.ValueOf(val)
|
||||||
|
|
||||||
|
switch refVal.Kind() {
|
||||||
|
case reflect.Ptr:
|
||||||
|
if refVal.IsNil() {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
convVal := refVal.Elem().Interface()
|
||||||
|
return convVal, true
|
||||||
|
case reflect.Slice:
|
||||||
|
baseSliceType := reflect.SliceOf(refVal.Type().Elem())
|
||||||
|
if refVal.Type().ConvertibleTo(baseSliceType) {
|
||||||
|
convVal := refVal.Convert(baseSliceType)
|
||||||
|
return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package pgtype
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgx/pgio"
|
||||||
|
@ -14,6 +15,52 @@ type Int2Array struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Int2Array) ConvertFrom(src interface{}) error {
|
func (a *Int2Array) ConvertFrom(src interface{}) error {
|
||||||
|
switch value := src.(type) {
|
||||||
|
case Int2Array:
|
||||||
|
*a = value
|
||||||
|
case []int16:
|
||||||
|
if value == nil {
|
||||||
|
*a = Int2Array{Status: Null}
|
||||||
|
} else if len(value) == 0 {
|
||||||
|
*a = Int2Array{Status: Present}
|
||||||
|
} else {
|
||||||
|
elements := make([]Int2, len(value))
|
||||||
|
for i := range value {
|
||||||
|
if err := elements[i].ConvertFrom(value[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*a = Int2Array{
|
||||||
|
Elements: elements,
|
||||||
|
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case []uint16:
|
||||||
|
if value == nil {
|
||||||
|
*a = Int2Array{Status: Null}
|
||||||
|
} else if len(value) == 0 {
|
||||||
|
*a = Int2Array{Status: Present}
|
||||||
|
} else {
|
||||||
|
elements := make([]Int2, len(value))
|
||||||
|
for i := range value {
|
||||||
|
if err := elements[i].ConvertFrom(value[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*a = Int2Array{
|
||||||
|
Elements: elements,
|
||||||
|
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if originalSrc, ok := underlyingSliceType(src); ok {
|
||||||
|
return a.ConvertFrom(originalSrc)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("cannot convert %v to Int2", value)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
73
values.go
73
values.go
|
@ -1087,10 +1087,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error {
|
||||||
// The name data type goes over the wire using the same format as string,
|
// The name data type goes over the wire using the same format as string,
|
||||||
// so just cast to string and use encodeString
|
// so just cast to string and use encodeString
|
||||||
return encodeString(wbuf, oid, string(arg))
|
return encodeString(wbuf, oid, string(arg))
|
||||||
case []int16:
|
|
||||||
return encodeInt16Slice(wbuf, oid, arg)
|
|
||||||
case []uint16:
|
|
||||||
return encodeUInt16Slice(wbuf, oid, arg)
|
|
||||||
case []int32:
|
case []int32:
|
||||||
return encodeInt32Slice(wbuf, oid, arg)
|
return encodeInt32Slice(wbuf, oid, arg)
|
||||||
case []uint32:
|
case []uint32:
|
||||||
|
@ -2410,42 +2406,45 @@ func encodeByteSliceSlice(w *WriteBuf, oid OID, value [][]byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeInt2Array(vr *ValueReader) []int16 {
|
func decodeInt2Array(vr *ValueReader) []int16 {
|
||||||
if vr.Len() == -1 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if vr.Type().DataType != Int2ArrayOID {
|
if vr.Type().DataType != Int2ArrayOID {
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType)))
|
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType)))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if vr.Type().FormatCode != BinaryFormatCode {
|
vr.err = errRewoundLen
|
||||||
|
|
||||||
|
var a pgtype.Int2Array
|
||||||
|
var err error
|
||||||
|
switch vr.Type().FormatCode {
|
||||||
|
case TextFormatCode:
|
||||||
|
err = a.DecodeText(&valueReader2{vr})
|
||||||
|
case BinaryFormatCode:
|
||||||
|
err = a.DecodeBinary(&valueReader2{vr})
|
||||||
|
default:
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
numElems, err := decode1dArrayHeader(vr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
vr.Fatal(err)
|
vr.Fatal(err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
a := make([]int16, int(numElems))
|
if a.Status == pgtype.Null {
|
||||||
for i := 0; i < len(a); i++ {
|
return nil
|
||||||
elSize := vr.ReadInt32()
|
}
|
||||||
switch elSize {
|
|
||||||
case 2:
|
rawArray := make([]int16, len(a.Elements))
|
||||||
a[i] = vr.ReadInt16()
|
for i := range a.Elements {
|
||||||
case -1:
|
if a.Elements[i].Status == pgtype.Present {
|
||||||
|
rawArray[i] = a.Elements[i].Int
|
||||||
|
} else {
|
||||||
vr.Fatal(ProtocolError("Cannot decode null element"))
|
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||||
return nil
|
return nil
|
||||||
default:
|
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize)))
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return a
|
return rawArray
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 {
|
func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 {
|
||||||
|
@ -2492,38 +2491,6 @@ func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 {
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodeInt16Slice(w *WriteBuf, oid OID, slice []int16) error {
|
|
||||||
if oid != Int2ArrayOID {
|
|
||||||
return fmt.Errorf("cannot encode Go %s into oid %d", "[]int16", oid)
|
|
||||||
}
|
|
||||||
|
|
||||||
encodeArrayHeader(w, Int2OID, len(slice), 6)
|
|
||||||
for _, v := range slice {
|
|
||||||
w.WriteInt32(2)
|
|
||||||
w.WriteInt16(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeUInt16Slice(w *WriteBuf, oid OID, slice []uint16) error {
|
|
||||||
if oid != Int2ArrayOID {
|
|
||||||
return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint16", oid)
|
|
||||||
}
|
|
||||||
|
|
||||||
encodeArrayHeader(w, Int2OID, len(slice), 6)
|
|
||||||
for _, v := range slice {
|
|
||||||
if v <= math.MaxInt16 {
|
|
||||||
w.WriteInt32(2)
|
|
||||||
w.WriteInt16(int16(v))
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("%d is greater than max smallint %d", v, math.MaxInt16)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeInt4Array(vr *ValueReader) []int32 {
|
func decodeInt4Array(vr *ValueReader) []int32 {
|
||||||
if vr.Len() == -1 {
|
if vr.Len() == -1 {
|
||||||
return nil
|
return nil
|
||||||
|
|
Loading…
Reference in New Issue