From 623ba1eeb14c0b61c90e032c8ebe93c6b448c575 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 27 Apr 2016 08:26:59 -0500 Subject: [PATCH] Add scan to uint16 refs #138 --- query_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ values.go | 12 ++++++++++++ 2 files changed, 53 insertions(+) diff --git a/query_test.go b/query_test.go index 664d0bb6..8c5b191b 100644 --- a/query_test.go +++ b/query_test.go @@ -474,6 +474,47 @@ func TestQueryRowCoreTypes(t *testing.T) { } } +func TestQueryRowCoreUnsignedIntTypes(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + ui16 uint16 + ui32 uint32 + ui64 uint64 + } + + var actual, zero allTypes + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + expected allTypes + }{ + {"select $1::int2", []interface{}{uint16(42)}, []interface{}{&actual.ui16}, allTypes{ui16: 42}}, + {"select $1::int4", []interface{}{uint32(42)}, []interface{}{&actual.ui32}, allTypes{ui32: 42}}, + {"select $1::int8", []interface{}{uint64(42)}, []interface{}{&actual.ui64}, allTypes{ui64: 42}}, + } + + for i, tt := range tests { + actual = zero + + err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) + } + + if actual != tt.expected { + t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs) + } + + ensureConnValid(t, conn) + } +} + func TestQueryRowCoreByteSlice(t *testing.T) { t.Parallel() diff --git a/values.go b/values.go index 06f49a9c..80c136a8 100644 --- a/values.go +++ b/values.go @@ -693,6 +693,18 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeInt2(vr) case *int32: *v = decodeInt4(vr) + case *uint16: + var valInt int16 + switch vr.Type().DataType { + case Int2Oid: + valInt = int16(decodeInt2(vr)) + default: + return fmt.Errorf("Can't convert OID %v to uint16", vr.Type().DataType) + } + if valInt < 0 { + return fmt.Errorf("%d is less than zero for uint16", valInt) + } + *v = uint16(valInt) case *uint32: var valInt int32 switch vr.Type().DataType {