Add Status to pgtype.Int2

pgxtype-experiment2
Jack Christensen 2017-02-25 16:15:51 -06:00
parent 720451f06d
commit 001647c1da
4 changed files with 78 additions and 96 deletions

View File

@ -279,14 +279,13 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.doneChan = make(chan struct{})
c.closedChan = make(chan error)
i2 := pgtype.Int2(0)
i4 := pgtype.Int4(0)
i8 := pgtype.Int8(0)
c.oidPgtypeValues = map[OID]pgtype.Value{
BoolOID: &pgtype.Bool{},
DateOID: &pgtype.Date{},
Int2OID: &i2,
Int2OID: &pgtype.Int2{},
Int4OID: &i4,
Int8OID: &i8,
}

View File

@ -9,23 +9,26 @@ import (
"github.com/jackc/pgx/pgio"
)
type Int2 int16
type Int2 struct {
Int int16
Status Status
}
func (i *Int2) ConvertFrom(src interface{}) error {
switch value := src.(type) {
case Int2:
*i = value
case int8:
*i = Int2(value)
*i = Int2{Int: int16(value), Status: Present}
case uint8:
*i = Int2(value)
*i = Int2{Int: int16(value), Status: Present}
case int16:
*i = Int2(value)
*i = Int2{Int: int16(value), Status: Present}
case uint16:
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
*i = Int2(value)
*i = Int2{Int: int16(value), Status: Present}
case int32:
if value < math.MinInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
@ -33,12 +36,12 @@ func (i *Int2) ConvertFrom(src interface{}) error {
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
*i = Int2(value)
*i = Int2{Int: int16(value), Status: Present}
case uint32:
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
*i = Int2(value)
*i = Int2{Int: int16(value), Status: Present}
case int64:
if value < math.MinInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
@ -46,12 +49,12 @@ func (i *Int2) ConvertFrom(src interface{}) error {
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
*i = Int2(value)
*i = Int2{Int: int16(value), Status: Present}
case uint64:
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
*i = Int2(value)
*i = Int2{Int: int16(value), Status: Present}
case int:
if value < math.MinInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
@ -59,18 +62,18 @@ func (i *Int2) ConvertFrom(src interface{}) error {
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
*i = Int2(value)
*i = Int2{Int: int16(value), Status: Present}
case uint:
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
*i = Int2(value)
*i = Int2{Int: int16(value), Status: Present}
case string:
num, err := strconv.ParseInt(value, 10, 16)
if err != nil {
return err
}
*i = Int2(num)
*i = Int2{Int: int16(num), Status: Present}
default:
if originalSrc, ok := underlyingIntType(src); ok {
return i.ConvertFrom(originalSrc)
@ -92,7 +95,8 @@ func (i *Int2) DecodeText(r io.Reader) error {
}
if size == -1 {
return fmt.Errorf("invalid length for int2: %v", size)
*i = Int2{Status: Null}
return nil
}
buf := make([]byte, int(size))
@ -106,7 +110,7 @@ func (i *Int2) DecodeText(r io.Reader) error {
return err
}
*i = Int2(n)
*i = Int2{Int: int16(n), Status: Present}
return nil
}
@ -116,6 +120,11 @@ func (i *Int2) DecodeBinary(r io.Reader) error {
return err
}
if size == -1 {
*i = Int2{Status: Null}
return nil
}
if size != 2 {
return fmt.Errorf("invalid length for int2: %v", size)
}
@ -125,12 +134,16 @@ func (i *Int2) DecodeBinary(r io.Reader) error {
return err
}
*i = Int2(n)
*i = Int2{Int: int16(n), Status: Present}
return nil
}
func (i Int2) EncodeText(w io.Writer) error {
s := strconv.FormatInt(int64(i), 10)
if done, err := encodeNotPresent(w, i.Status); done {
return err
}
s := strconv.FormatInt(int64(i.Int), 10)
_, err := pgio.WriteInt32(w, int32(len(s)))
if err != nil {
return nil
@ -140,11 +153,15 @@ func (i Int2) EncodeText(w io.Writer) error {
}
func (i Int2) EncodeBinary(w io.Writer) error {
if done, err := encodeNotPresent(w, i.Status); done {
return err
}
_, err := pgio.WriteInt32(w, 2)
if err != nil {
return err
}
_, err = pgio.WriteInt16(w, int16(i))
_, err = pgio.WriteInt16(w, i.Int)
return err
}

View File

@ -1,12 +1,10 @@
package pgtype_test
import (
"bytes"
"math"
"testing"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgtype"
)
@ -22,66 +20,33 @@ func TestInt2Transcode(t *testing.T) {
tests := []struct {
result pgtype.Int2
}{
{result: pgtype.Int2(math.MinInt16)},
{result: pgtype.Int2(-1)},
{result: pgtype.Int2(0)},
{result: pgtype.Int2(1)},
{result: pgtype.Int2(math.MaxInt16)},
{result: pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}},
{result: pgtype.Int2{Int: -1, Status: pgtype.Present}},
{result: pgtype.Int2{Int: 0, Status: pgtype.Present}},
{result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
{result: pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}},
}
ps.FieldDescriptions[0].FormatCode = pgx.TextFormatCode
for i, tt := range tests {
inputBuf := &bytes.Buffer{}
err = tt.result.EncodeText(inputBuf)
if err != nil {
t.Errorf("TextFormat %d: %v", i, err)
}
var s string
err := conn.QueryRow("test", string(inputBuf.Bytes()[4:])).Scan(&s)
if err != nil {
t.Errorf("TextFormat %d: %v", i, err)
}
outputBuf := &bytes.Buffer{}
pgio.WriteInt32(outputBuf, int32(len(s)))
outputBuf.WriteString(s)
var r pgtype.Int2
err = r.DecodeText(outputBuf)
if err != nil {
t.Errorf("TextFormat %d: %v", i, err)
}
if r != tt.result {
t.Errorf("TextFormat %d: expected %v, got %v", i, tt.result, r)
}
formats := []struct {
name string
formatCode int16
}{
{name: "TextFormat", formatCode: pgx.TextFormatCode},
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
}
ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode
for i, tt := range tests {
inputBuf := &bytes.Buffer{}
err = tt.result.EncodeBinary(inputBuf)
if err != nil {
t.Errorf("BinaryFormat %d: %v", i, err)
}
for _, fc := range formats {
ps.FieldDescriptions[0].FormatCode = fc.formatCode
for i, tt := range tests {
var r pgtype.Int2
err := conn.QueryRow("test", tt.result).Scan(&r)
if err != nil {
t.Errorf("%v %d: %v", fc.name, i, err)
}
var buf []byte
err := conn.QueryRow("test", inputBuf.Bytes()[4:]).Scan(&buf)
if err != nil {
t.Errorf("BinaryFormat %d: %v", i, err)
}
outputBuf := &bytes.Buffer{}
pgio.WriteInt32(outputBuf, int32(len(buf)))
outputBuf.Write(buf)
var r pgtype.Int2
err = r.DecodeBinary(outputBuf)
if err != nil {
t.Errorf("BinaryFormat %d: %v", i, err)
}
if r != tt.result {
t.Errorf("BinaryFormat %d: expected %v, got %v", i, tt.result, r)
if r != tt.result {
t.Errorf("%v %d: expected %v, got %v", fc.name, i, tt.result, r)
}
}
}
}
@ -93,20 +58,20 @@ func TestInt2ConvertFrom(t *testing.T) {
source interface{}
result pgtype.Int2
}{
{source: int8(1), result: pgtype.Int2(1)},
{source: int16(1), result: pgtype.Int2(1)},
{source: int32(1), result: pgtype.Int2(1)},
{source: int64(1), result: pgtype.Int2(1)},
{source: int8(-1), result: pgtype.Int2(-1)},
{source: int16(-1), result: pgtype.Int2(-1)},
{source: int32(-1), result: pgtype.Int2(-1)},
{source: int64(-1), result: pgtype.Int2(-1)},
{source: uint8(1), result: pgtype.Int2(1)},
{source: uint16(1), result: pgtype.Int2(1)},
{source: uint32(1), result: pgtype.Int2(1)},
{source: uint64(1), result: pgtype.Int2(1)},
{source: "1", result: pgtype.Int2(1)},
{source: _int8(1), result: pgtype.Int2(1)},
{source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
{source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
{source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
{source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
{source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}},
{source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}},
{source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}},
{source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}},
{source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
{source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
{source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
{source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
{source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
{source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}},
}
for i, tt := range successfulTests {

View File

@ -503,7 +503,7 @@ func (n NullInt16) Encode(w *WriteBuf, oid OID) error {
return nil
}
return pgtype.Int2(n.Int16).EncodeBinary(w)
return pgtype.Int2{Int: n.Int16, Status: pgtype.Present}.EncodeBinary(w)
}
// NullInt32 represents an integer that may be null. NullInt32 implements the
@ -1515,10 +1515,6 @@ func decodeChar(vr *ValueReader) Char {
}
func decodeInt2(vr *ValueReader) int16 {
if vr.Len() == -1 {
vr.Fatal(ProtocolError("Cannot decode null into int16"))
return 0
}
if vr.Type().DataType != Int2OID {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType)))
@ -1544,7 +1540,12 @@ func decodeInt2(vr *ValueReader) int16 {
return 0
}
return int16(n)
if n.Status == pgtype.Null {
vr.Fatal(ProtocolError("Cannot decode null into int16"))
return 0
}
return n.Int
}
func encodeChar(w *WriteBuf, oid OID, value Char) error {