From 001647c1da03f796014cf21f41c9a7fd2cfadfde Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Sat, 25 Feb 2017 16:15:51 -0600
Subject: [PATCH] Add Status to pgtype.Int2

---
 conn.go             |   3 +-
 pgtype/int2.go      |  51 ++++++++++++++-------
 pgtype/int2_test.go | 107 +++++++++++++++-----------------------------
 values.go           |  13 +++---
 4 files changed, 78 insertions(+), 96 deletions(-)

diff --git a/conn.go b/conn.go
index 750aa7f5..794e6427 100644
--- a/conn.go
+++ b/conn.go
@@ -279,14 +279,13 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
 	c.doneChan = make(chan struct{})
 	c.closedChan = make(chan error)
 
-	i2 := pgtype.Int2(0)
 	i4 := pgtype.Int4(0)
 	i8 := pgtype.Int8(0)
 
 	c.oidPgtypeValues = map[OID]pgtype.Value{
 		BoolOID: &pgtype.Bool{},
 		DateOID: &pgtype.Date{},
-		Int2OID: &i2,
+		Int2OID: &pgtype.Int2{},
 		Int4OID: &i4,
 		Int8OID: &i8,
 	}
diff --git a/pgtype/int2.go b/pgtype/int2.go
index 38dac534..636ea1f1 100644
--- a/pgtype/int2.go
+++ b/pgtype/int2.go
@@ -9,23 +9,26 @@ import (
 	"github.com/jackc/pgx/pgio"
 )
 
-type Int2 int16
+type Int2 struct {
+	Int    int16
+	Status Status
+}
 
 func (i *Int2) ConvertFrom(src interface{}) error {
 	switch value := src.(type) {
 	case Int2:
 		*i = value
 	case int8:
-		*i = Int2(value)
+		*i = Int2{Int: int16(value), Status: Present}
 	case uint8:
-		*i = Int2(value)
+		*i = Int2{Int: int16(value), Status: Present}
 	case int16:
-		*i = Int2(value)
+		*i = Int2{Int: int16(value), Status: Present}
 	case uint16:
 		if value > math.MaxInt16 {
 			return fmt.Errorf("%d is greater than maximum value for Int2", value)
 		}
-		*i = Int2(value)
+		*i = Int2{Int: int16(value), Status: Present}
 	case int32:
 		if value < math.MinInt16 {
 			return fmt.Errorf("%d is greater than maximum value for Int2", value)
@@ -33,12 +36,12 @@ func (i *Int2) ConvertFrom(src interface{}) error {
 		if value > math.MaxInt16 {
 			return fmt.Errorf("%d is greater than maximum value for Int2", value)
 		}
-		*i = Int2(value)
+		*i = Int2{Int: int16(value), Status: Present}
 	case uint32:
 		if value > math.MaxInt16 {
 			return fmt.Errorf("%d is greater than maximum value for Int2", value)
 		}
-		*i = Int2(value)
+		*i = Int2{Int: int16(value), Status: Present}
 	case int64:
 		if value < math.MinInt16 {
 			return fmt.Errorf("%d is greater than maximum value for Int2", value)
@@ -46,12 +49,12 @@ func (i *Int2) ConvertFrom(src interface{}) error {
 		if value > math.MaxInt16 {
 			return fmt.Errorf("%d is greater than maximum value for Int2", value)
 		}
-		*i = Int2(value)
+		*i = Int2{Int: int16(value), Status: Present}
 	case uint64:
 		if value > math.MaxInt16 {
 			return fmt.Errorf("%d is greater than maximum value for Int2", value)
 		}
-		*i = Int2(value)
+		*i = Int2{Int: int16(value), Status: Present}
 	case int:
 		if value < math.MinInt16 {
 			return fmt.Errorf("%d is greater than maximum value for Int2", value)
@@ -59,18 +62,18 @@ func (i *Int2) ConvertFrom(src interface{}) error {
 		if value > math.MaxInt16 {
 			return fmt.Errorf("%d is greater than maximum value for Int2", value)
 		}
-		*i = Int2(value)
+		*i = Int2{Int: int16(value), Status: Present}
 	case uint:
 		if value > math.MaxInt16 {
 			return fmt.Errorf("%d is greater than maximum value for Int2", value)
 		}
-		*i = Int2(value)
+		*i = Int2{Int: int16(value), Status: Present}
 	case string:
 		num, err := strconv.ParseInt(value, 10, 16)
 		if err != nil {
 			return err
 		}
-		*i = Int2(num)
+		*i = Int2{Int: int16(num), Status: Present}
 	default:
 		if originalSrc, ok := underlyingIntType(src); ok {
 			return i.ConvertFrom(originalSrc)
@@ -92,7 +95,8 @@ func (i *Int2) DecodeText(r io.Reader) error {
 	}
 
 	if size == -1 {
-		return fmt.Errorf("invalid length for int2: %v", size)
+		*i = Int2{Status: Null}
+		return nil
 	}
 
 	buf := make([]byte, int(size))
@@ -106,7 +110,7 @@ func (i *Int2) DecodeText(r io.Reader) error {
 		return err
 	}
 
-	*i = Int2(n)
+	*i = Int2{Int: int16(n), Status: Present}
 	return nil
 }
 
@@ -116,6 +120,11 @@ func (i *Int2) DecodeBinary(r io.Reader) error {
 		return err
 	}
 
+	if size == -1 {
+		*i = Int2{Status: Null}
+		return nil
+	}
+
 	if size != 2 {
 		return fmt.Errorf("invalid length for int2: %v", size)
 	}
@@ -125,12 +134,16 @@ func (i *Int2) DecodeBinary(r io.Reader) error {
 		return err
 	}
 
-	*i = Int2(n)
+	*i = Int2{Int: int16(n), Status: Present}
 	return nil
 }
 
 func (i Int2) EncodeText(w io.Writer) error {
-	s := strconv.FormatInt(int64(i), 10)
+	if done, err := encodeNotPresent(w, i.Status); done {
+		return err
+	}
+
+	s := strconv.FormatInt(int64(i.Int), 10)
 	_, err := pgio.WriteInt32(w, int32(len(s)))
 	if err != nil {
 		return nil
@@ -140,11 +153,15 @@ func (i Int2) EncodeText(w io.Writer) error {
 }
 
 func (i Int2) EncodeBinary(w io.Writer) error {
+	if done, err := encodeNotPresent(w, i.Status); done {
+		return err
+	}
+
 	_, err := pgio.WriteInt32(w, 2)
 	if err != nil {
 		return err
 	}
 
-	_, err = pgio.WriteInt16(w, int16(i))
+	_, err = pgio.WriteInt16(w, i.Int)
 	return err
 }
diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go
index b40262e5..0abde920 100644
--- a/pgtype/int2_test.go
+++ b/pgtype/int2_test.go
@@ -1,12 +1,10 @@
 package pgtype_test
 
 import (
-	"bytes"
 	"math"
 	"testing"
 
 	"github.com/jackc/pgx"
-	"github.com/jackc/pgx/pgio"
 	"github.com/jackc/pgx/pgtype"
 )
 
@@ -22,66 +20,33 @@ func TestInt2Transcode(t *testing.T) {
 	tests := []struct {
 		result pgtype.Int2
 	}{
-		{result: pgtype.Int2(math.MinInt16)},
-		{result: pgtype.Int2(-1)},
-		{result: pgtype.Int2(0)},
-		{result: pgtype.Int2(1)},
-		{result: pgtype.Int2(math.MaxInt16)},
+		{result: pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}},
+		{result: pgtype.Int2{Int: -1, Status: pgtype.Present}},
+		{result: pgtype.Int2{Int: 0, Status: pgtype.Present}},
+		{result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
+		{result: pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}},
 	}
 
-	ps.FieldDescriptions[0].FormatCode = pgx.TextFormatCode
-	for i, tt := range tests {
-		inputBuf := &bytes.Buffer{}
-		err = tt.result.EncodeText(inputBuf)
-		if err != nil {
-			t.Errorf("TextFormat %d: %v", i, err)
-		}
-
-		var s string
-		err := conn.QueryRow("test", string(inputBuf.Bytes()[4:])).Scan(&s)
-		if err != nil {
-			t.Errorf("TextFormat %d: %v", i, err)
-		}
-
-		outputBuf := &bytes.Buffer{}
-		pgio.WriteInt32(outputBuf, int32(len(s)))
-		outputBuf.WriteString(s)
-		var r pgtype.Int2
-		err = r.DecodeText(outputBuf)
-		if err != nil {
-			t.Errorf("TextFormat %d: %v", i, err)
-		}
-
-		if r != tt.result {
-			t.Errorf("TextFormat %d: expected %v, got %v", i, tt.result, r)
-		}
+	formats := []struct {
+		name       string
+		formatCode int16
+	}{
+		{name: "TextFormat", formatCode: pgx.TextFormatCode},
+		{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
 	}
 
-	ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode
-	for i, tt := range tests {
-		inputBuf := &bytes.Buffer{}
-		err = tt.result.EncodeBinary(inputBuf)
-		if err != nil {
-			t.Errorf("BinaryFormat %d: %v", i, err)
-		}
+	for _, fc := range formats {
+		ps.FieldDescriptions[0].FormatCode = fc.formatCode
+		for i, tt := range tests {
+			var r pgtype.Int2
+			err := conn.QueryRow("test", tt.result).Scan(&r)
+			if err != nil {
+				t.Errorf("%v %d: %v", fc.name, i, err)
+			}
 
-		var buf []byte
-		err := conn.QueryRow("test", inputBuf.Bytes()[4:]).Scan(&buf)
-		if err != nil {
-			t.Errorf("BinaryFormat %d: %v", i, err)
-		}
-
-		outputBuf := &bytes.Buffer{}
-		pgio.WriteInt32(outputBuf, int32(len(buf)))
-		outputBuf.Write(buf)
-		var r pgtype.Int2
-		err = r.DecodeBinary(outputBuf)
-		if err != nil {
-			t.Errorf("BinaryFormat %d: %v", i, err)
-		}
-
-		if r != tt.result {
-			t.Errorf("BinaryFormat %d: expected %v, got %v", i, tt.result, r)
+			if r != tt.result {
+				t.Errorf("%v %d: expected %v, got %v", fc.name, i, tt.result, r)
+			}
 		}
 	}
 }
@@ -93,20 +58,20 @@ func TestInt2ConvertFrom(t *testing.T) {
 		source interface{}
 		result pgtype.Int2
 	}{
-		{source: int8(1), result: pgtype.Int2(1)},
-		{source: int16(1), result: pgtype.Int2(1)},
-		{source: int32(1), result: pgtype.Int2(1)},
-		{source: int64(1), result: pgtype.Int2(1)},
-		{source: int8(-1), result: pgtype.Int2(-1)},
-		{source: int16(-1), result: pgtype.Int2(-1)},
-		{source: int32(-1), result: pgtype.Int2(-1)},
-		{source: int64(-1), result: pgtype.Int2(-1)},
-		{source: uint8(1), result: pgtype.Int2(1)},
-		{source: uint16(1), result: pgtype.Int2(1)},
-		{source: uint32(1), result: pgtype.Int2(1)},
-		{source: uint64(1), result: pgtype.Int2(1)},
-		{source: "1", result: pgtype.Int2(1)},
-		{source: _int8(1), result: pgtype.Int2(1)},
+		{source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
+		{source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
+		{source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
+		{source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
+		{source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}},
+		{source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}},
+		{source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}},
+		{source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}},
+		{source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
+		{source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
+		{source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
+		{source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
+		{source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
+		{source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
 	}
 
 	for i, tt := range successfulTests {
diff --git a/values.go b/values.go
index 0dea734a..469df006 100644
--- a/values.go
+++ b/values.go
@@ -503,7 +503,7 @@ func (n NullInt16) Encode(w *WriteBuf, oid OID) error {
 		return nil
 	}
 
-	return pgtype.Int2(n.Int16).EncodeBinary(w)
+	return pgtype.Int2{Int: n.Int16, Status: pgtype.Present}.EncodeBinary(w)
 }
 
 // NullInt32 represents an integer that may be null. NullInt32 implements the
@@ -1515,10 +1515,6 @@ func decodeChar(vr *ValueReader) Char {
 }
 
 func decodeInt2(vr *ValueReader) int16 {
-	if vr.Len() == -1 {
-		vr.Fatal(ProtocolError("Cannot decode null into int16"))
-		return 0
-	}
 
 	if vr.Type().DataType != Int2OID {
 		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType)))
@@ -1544,7 +1540,12 @@ func decodeInt2(vr *ValueReader) int16 {
 		return 0
 	}
 
-	return int16(n)
+	if n.Status == pgtype.Null {
+		vr.Fatal(ProtocolError("Cannot decode null into int16"))
+		return 0
+	}
+
+	return n.Int
 }
 
 func encodeChar(w *WriteBuf, oid OID, value Char) error {