From aac1fd86a422b9c94cbd9d0fc93bd993bf59c585 Mon Sep 17 00:00:00 2001 From: "m1kc (Max Musatov)" Date: Thu, 17 Mar 2016 10:30:30 +0300 Subject: [PATCH] []uint16, []uint32, and []uint64 encoding and decoding. --- values.go | 198 +++++++++++++++++++++++++++++++++++++++++++++++++ values_test.go | 40 ++++++++++ 2 files changed, 238 insertions(+) diff --git a/values.go b/values.go index 5a3a7d7a..dd5da146 100644 --- a/values.go +++ b/values.go @@ -632,18 +632,24 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return encodeInt16Slice(wbuf, oid, arg) case uint16: return encodeUInt16(wbuf, oid, arg) + case []uint16: + return encodeUInt16Slice(wbuf, oid, arg) case int32: return encodeInt32(wbuf, oid, arg) case []int32: return encodeInt32Slice(wbuf, oid, arg) case uint32: return encodeUInt32(wbuf, oid, arg) + case []uint32: + return encodeUInt32Slice(wbuf, oid, arg) case int64: return encodeInt64(wbuf, oid, arg) case []int64: return encodeInt64Slice(wbuf, oid, arg) case uint64: return encodeUInt64(wbuf, oid, arg) + case []uint64: + return encodeUInt64Slice(wbuf, oid, arg) case int: return encodeInt(wbuf, oid, arg) case float32: @@ -728,10 +734,16 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeBoolArray(vr) case *[]int16: *v = decodeInt2Array(vr) + case *[]uint16: + *v = decodeInt2ArrayToUInt(vr) case *[]int32: *v = decodeInt4Array(vr) + case *[]uint32: + *v = decodeInt4ArrayToUInt(vr) case *[]int64: *v = decodeInt8Array(vr) + case *[]uint64: + *v = decodeInt8ArrayToUInt(vr) case *[]float32: *v = decodeFloat4Array(vr) case *[]float64: @@ -1607,6 +1619,50 @@ func decodeInt2Array(vr *ValueReader) []int16 { return a } +func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Int2ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint16", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]uint16, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 2: + tmp := vr.ReadInt16() + if tmp < 0 { + vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint16", tmp))) + return nil + } + a[i] = uint16(tmp) + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize))) + return nil + } + } + + return a +} + func encodeInt16Slice(w *WriteBuf, oid Oid, slice []int16) error { if oid != Int2ArrayOid { return fmt.Errorf("cannot encode Go %s into oid %d", "[]int16", oid) @@ -1621,6 +1677,24 @@ func encodeInt16Slice(w *WriteBuf, oid Oid, slice []int16) error { return nil } +func encodeUInt16Slice(w *WriteBuf, oid Oid, slice []uint16) error { + if oid != Int2ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint16", oid) + } + + encodeArrayHeader(w, Int2Oid, len(slice), 6) + for _, v := range slice { + if v <= math.MaxInt16 { + w.WriteInt32(2) + w.WriteInt16(int16(v)) + } else { + return fmt.Errorf("%d is larger than max smallint %d", v, math.MaxInt16) + } + } + + return nil +} + func decodeInt4Array(vr *ValueReader) []int32 { if vr.Len() == -1 { return nil @@ -1660,6 +1734,50 @@ func decodeInt4Array(vr *ValueReader) []int32 { return a } +func decodeInt4ArrayToUInt(vr *ValueReader) []uint32 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Int4ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint32", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]uint32, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 4: + tmp := vr.ReadInt32() + if tmp < 0 { + vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint32", tmp))) + return nil + } + a[i] = uint32(tmp) + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize))) + return nil + } + } + + return a +} + func encodeInt32Slice(w *WriteBuf, oid Oid, slice []int32) error { if oid != Int4ArrayOid { return fmt.Errorf("cannot encode Go %s into oid %d", "[]int32", oid) @@ -1674,6 +1792,24 @@ func encodeInt32Slice(w *WriteBuf, oid Oid, slice []int32) error { return nil } +func encodeUInt32Slice(w *WriteBuf, oid Oid, slice []uint32) error { + if oid != Int4ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint32", oid) + } + + encodeArrayHeader(w, Int4Oid, len(slice), 8) + for _, v := range slice { + if v <= math.MaxInt32 { + w.WriteInt32(4) + w.WriteInt32(int32(v)) + } else { + return fmt.Errorf("%d is larger than max integer %d", v, math.MaxInt32) + } + } + + return nil +} + func decodeInt8Array(vr *ValueReader) []int64 { if vr.Len() == -1 { return nil @@ -1713,6 +1849,50 @@ func decodeInt8Array(vr *ValueReader) []int64 { return a } +func decodeInt8ArrayToUInt(vr *ValueReader) []uint64 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Int8ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint64", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]uint64, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 8: + tmp := vr.ReadInt64() + if tmp < 0 { + vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint64", tmp))) + return nil + } + a[i] = uint64(tmp) + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize))) + return nil + } + } + + return a +} + func encodeInt64Slice(w *WriteBuf, oid Oid, slice []int64) error { if oid != Int8ArrayOid { return fmt.Errorf("cannot encode Go %s into oid %d", "[]int64", oid) @@ -1727,6 +1907,24 @@ func encodeInt64Slice(w *WriteBuf, oid Oid, slice []int64) error { return nil } +func encodeUInt64Slice(w *WriteBuf, oid Oid, slice []uint64) error { + if oid != Int8ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint64", oid) + } + + encodeArrayHeader(w, Int8Oid, len(slice), 12) + for _, v := range slice { + if v <= math.MaxInt64 { + w.WriteInt32(8) + w.WriteInt64(int64(v)) + } else { + return fmt.Errorf("%d is larger than max bigint %d", v, math.MaxInt64) + } + } + + return nil +} + func decodeFloat4Array(vr *ValueReader) []float32 { if vr.Len() == -1 { return nil diff --git a/values_test.go b/values_test.go index 8a54421e..f6ddc623 100644 --- a/values_test.go +++ b/values_test.go @@ -492,6 +492,22 @@ func TestArrayDecoding(t *testing.T) { } }, }, + { + "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]int16))) == false { + t.Errorf("failed to encode smallint[]") + } + }, + }, + { + "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]uint16))) == false { + t.Errorf("failed to encode smallint[]") + } + }, + }, { "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, func(t *testing.T, query, scan interface{}) { @@ -500,6 +516,30 @@ func TestArrayDecoding(t *testing.T) { } }, }, + { + "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]uint32))) == false { + t.Errorf("failed to encode int[]") + } + }, + }, + { + "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]int64))) == false { + t.Errorf("failed to encode bigint[]") + } + }, + }, + { + "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]uint64))) == false { + t.Errorf("failed to encode bigint[]") + } + }, + }, { "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, func(t *testing.T, query, scan interface{}) {