Move bool to pgtype

pgxtype-experiment2
Jack Christensen 2017-02-24 14:10:56 -06:00
parent bb764d2129
commit e5707023ca
3 changed files with 184 additions and 14 deletions

82
pgtype/bool.go Normal file
View File

@ -0,0 +1,82 @@
package pgtype
import (
"fmt"
"io"
"github.com/jackc/pgx/pgio"
)
type Bool bool
func (b *Bool) DecodeText(r io.Reader) error {
size, err := pgio.ReadInt32(r)
if err != nil {
return err
}
if size != 1 {
return fmt.Errorf("invalid length for bool: %v", size)
}
byt, err := pgio.ReadByte(r)
if err != nil {
return err
}
*b = Bool(byt == 't')
return nil
}
func (b *Bool) DecodeBinary(r io.Reader) error {
size, err := pgio.ReadInt32(r)
if err != nil {
return err
}
if size != 1 {
return fmt.Errorf("invalid length for bool: %v", size)
}
byt, err := pgio.ReadByte(r)
if err != nil {
return err
}
*b = Bool(byt == 1)
return nil
}
func (b Bool) EncodeText(w io.Writer) error {
_, err := pgio.WriteInt32(w, 1)
if err != nil {
return nil
}
var buf []byte
if b {
buf = []byte{'t'}
} else {
buf = []byte{'f'}
}
_, err = w.Write(buf)
return err
}
func (b Bool) EncodeBinary(w io.Writer) error {
_, err := pgio.WriteInt32(w, 1)
if err != nil {
return nil
}
var buf []byte
if b {
buf = []byte{1}
} else {
buf = []byte{0}
}
_, err = w.Write(buf)
return err
}

83
pgtype/bool_test.go Normal file
View File

@ -0,0 +1,83 @@
package pgtype_test
import (
"bytes"
"testing"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgtype"
)
func TestBoolTranscode(t *testing.T) {
conn := mustConnectPgx(t)
defer mustClose(t, conn)
ps, err := conn.Prepare("test", "select $1::bool")
if err != nil {
t.Fatal(err)
}
tests := []struct {
result pgtype.Bool
}{
{result: pgtype.Bool(false)},
{result: pgtype.Bool(true)},
}
ps.FieldDescriptions[0].FormatCode = pgx.TextFormatCode
for i, tt := range tests {
inputBuf := &bytes.Buffer{}
err = tt.result.EncodeText(inputBuf)
if err != nil {
t.Errorf("TextFormat %d: %v", i, err)
}
var s string
err := conn.QueryRow("test", string(inputBuf.Bytes()[4:])).Scan(&s)
if err != nil {
t.Errorf("TextFormat %d: %v", i, err)
}
outputBuf := &bytes.Buffer{}
pgio.WriteInt32(outputBuf, int32(len(s)))
outputBuf.WriteString(s)
var r pgtype.Bool
err = r.DecodeText(outputBuf)
if err != nil {
t.Errorf("TextFormat %d: %v", i, err)
}
if r != tt.result {
t.Errorf("TextFormat %d: expected %v, got %v", i, tt.result, r)
}
}
ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode
for i, tt := range tests {
inputBuf := &bytes.Buffer{}
err = tt.result.EncodeBinary(inputBuf)
if err != nil {
t.Errorf("BinaryFormat %d: %v", i, err)
}
var buf []byte
err := conn.QueryRow("test", inputBuf.Bytes()[4:]).Scan(&buf)
if err != nil {
t.Errorf("BinaryFormat %d: %v", i, err)
}
outputBuf := &bytes.Buffer{}
pgio.WriteInt32(outputBuf, int32(len(buf)))
outputBuf.Write(buf)
var r pgtype.Bool
err = r.DecodeBinary(outputBuf)
if err != nil {
t.Errorf("BinaryFormat %d: %v", i, err)
}
if r != tt.result {
t.Errorf("BinaryFormat %d: expected %v, got %v", i, tt.result, r)
}
}
}

View File

@ -1428,13 +1428,26 @@ func decodeBool(vr *ValueReader) bool {
return false
}
if vr.Len() != 1 {
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", vr.Len())))
vr.err = errRewoundLen
var b pgtype.Bool
var err error
switch vr.Type().FormatCode {
case TextFormatCode:
err = b.DecodeText(&valueReader2{vr})
case BinaryFormatCode:
err = b.DecodeBinary(&valueReader2{vr})
default:
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return false
}
b := vr.ReadByte()
return b != 0
if err != nil {
vr.Fatal(err)
return false
}
return bool(b)
}
func encodeBool(w *WriteBuf, oid OID, value bool) error {
@ -1442,16 +1455,8 @@ func encodeBool(w *WriteBuf, oid OID, value bool) error {
return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid)
}
w.WriteInt32(1)
var n byte
if value {
n = 1
}
w.WriteByte(n)
return nil
b := pgtype.Bool(value)
return b.EncodeBinary(w)
}
func decodeInt(vr *ValueReader) int64 {