From e5707023cac7c07342b8c910e480d09a1caaf6ee Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 24 Feb 2017 14:10:56 -0600 Subject: [PATCH] Move bool to pgtype --- pgtype/bool.go | 82 ++++++++++++++++++++++++++++++++++++++++++++ pgtype/bool_test.go | 83 +++++++++++++++++++++++++++++++++++++++++++++ values.go | 33 ++++++++++-------- 3 files changed, 184 insertions(+), 14 deletions(-) create mode 100644 pgtype/bool.go create mode 100644 pgtype/bool_test.go diff --git a/pgtype/bool.go b/pgtype/bool.go new file mode 100644 index 00000000..6ab07485 --- /dev/null +++ b/pgtype/bool.go @@ -0,0 +1,82 @@ +package pgtype + +import ( + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Bool bool + +func (b *Bool) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size != 1 { + return fmt.Errorf("invalid length for bool: %v", size) + } + + byt, err := pgio.ReadByte(r) + if err != nil { + return err + } + + *b = Bool(byt == 't') + return nil +} + +func (b *Bool) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size != 1 { + return fmt.Errorf("invalid length for bool: %v", size) + } + + byt, err := pgio.ReadByte(r) + if err != nil { + return err + } + + *b = Bool(byt == 1) + return nil +} + +func (b Bool) EncodeText(w io.Writer) error { + _, err := pgio.WriteInt32(w, 1) + if err != nil { + return nil + } + + var buf []byte + if b { + buf = []byte{'t'} + } else { + buf = []byte{'f'} + } + + _, err = w.Write(buf) + return err +} + +func (b Bool) EncodeBinary(w io.Writer) error { + _, err := pgio.WriteInt32(w, 1) + if err != nil { + return nil + } + + var buf []byte + if b { + buf = []byte{1} + } else { + buf = []byte{0} + } + + _, err = w.Write(buf) + return err +} diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go new file mode 100644 index 00000000..2ed5d75d --- /dev/null +++ b/pgtype/bool_test.go @@ -0,0 +1,83 @@ +package pgtype_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/pgtype" +) + +func TestBoolTranscode(t *testing.T) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + ps, err := conn.Prepare("test", "select $1::bool") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + result pgtype.Bool + }{ + {result: pgtype.Bool(false)}, + {result: pgtype.Bool(true)}, + } + + ps.FieldDescriptions[0].FormatCode = pgx.TextFormatCode + for i, tt := range tests { + inputBuf := &bytes.Buffer{} + err = tt.result.EncodeText(inputBuf) + if err != nil { + t.Errorf("TextFormat %d: %v", i, err) + } + + var s string + err := conn.QueryRow("test", string(inputBuf.Bytes()[4:])).Scan(&s) + if err != nil { + t.Errorf("TextFormat %d: %v", i, err) + } + + outputBuf := &bytes.Buffer{} + pgio.WriteInt32(outputBuf, int32(len(s))) + outputBuf.WriteString(s) + var r pgtype.Bool + err = r.DecodeText(outputBuf) + if err != nil { + t.Errorf("TextFormat %d: %v", i, err) + } + + if r != tt.result { + t.Errorf("TextFormat %d: expected %v, got %v", i, tt.result, r) + } + } + + ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode + for i, tt := range tests { + inputBuf := &bytes.Buffer{} + err = tt.result.EncodeBinary(inputBuf) + if err != nil { + t.Errorf("BinaryFormat %d: %v", i, err) + } + + var buf []byte + err := conn.QueryRow("test", inputBuf.Bytes()[4:]).Scan(&buf) + if err != nil { + t.Errorf("BinaryFormat %d: %v", i, err) + } + + outputBuf := &bytes.Buffer{} + pgio.WriteInt32(outputBuf, int32(len(buf))) + outputBuf.Write(buf) + var r pgtype.Bool + err = r.DecodeBinary(outputBuf) + if err != nil { + t.Errorf("BinaryFormat %d: %v", i, err) + } + + if r != tt.result { + t.Errorf("BinaryFormat %d: expected %v, got %v", i, tt.result, r) + } + } +} diff --git a/values.go b/values.go index 69a5daee..081d4670 100644 --- a/values.go +++ b/values.go @@ -1428,13 +1428,26 @@ func decodeBool(vr *ValueReader) bool { return false } - if vr.Len() != 1 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", vr.Len()))) + vr.err = errRewoundLen + + var b pgtype.Bool + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = b.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = b.DecodeBinary(&valueReader2{vr}) + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return false } - b := vr.ReadByte() - return b != 0 + if err != nil { + vr.Fatal(err) + return false + } + + return bool(b) } func encodeBool(w *WriteBuf, oid OID, value bool) error { @@ -1442,16 +1455,8 @@ func encodeBool(w *WriteBuf, oid OID, value bool) error { return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid) } - w.WriteInt32(1) - - var n byte - if value { - n = 1 - } - - w.WriteByte(n) - - return nil + b := pgtype.Bool(value) + return b.EncodeBinary(w) } func decodeInt(vr *ValueReader) int64 {