mirror of https://github.com/jackc/pgx.git
Move bool to pgtype
parent
bb764d2129
commit
e5707023ca
|
@ -0,0 +1,82 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Bool bool
|
||||
|
||||
func (b *Bool) DecodeText(r io.Reader) error {
|
||||
size, err := pgio.ReadInt32(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if size != 1 {
|
||||
return fmt.Errorf("invalid length for bool: %v", size)
|
||||
}
|
||||
|
||||
byt, err := pgio.ReadByte(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*b = Bool(byt == 't')
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Bool) DecodeBinary(r io.Reader) error {
|
||||
size, err := pgio.ReadInt32(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if size != 1 {
|
||||
return fmt.Errorf("invalid length for bool: %v", size)
|
||||
}
|
||||
|
||||
byt, err := pgio.ReadByte(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*b = Bool(byt == 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b Bool) EncodeText(w io.Writer) error {
|
||||
_, err := pgio.WriteInt32(w, 1)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
if b {
|
||||
buf = []byte{'t'}
|
||||
} else {
|
||||
buf = []byte{'f'}
|
||||
}
|
||||
|
||||
_, err = w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b Bool) EncodeBinary(w io.Writer) error {
|
||||
_, err := pgio.WriteInt32(w, 1)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
if b {
|
||||
buf = []byte{1}
|
||||
} else {
|
||||
buf = []byte{0}
|
||||
}
|
||||
|
||||
_, err = w.Write(buf)
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
func TestBoolTranscode(t *testing.T) {
|
||||
conn := mustConnectPgx(t)
|
||||
defer mustClose(t, conn)
|
||||
|
||||
ps, err := conn.Prepare("test", "select $1::bool")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
result pgtype.Bool
|
||||
}{
|
||||
{result: pgtype.Bool(false)},
|
||||
{result: pgtype.Bool(true)},
|
||||
}
|
||||
|
||||
ps.FieldDescriptions[0].FormatCode = pgx.TextFormatCode
|
||||
for i, tt := range tests {
|
||||
inputBuf := &bytes.Buffer{}
|
||||
err = tt.result.EncodeText(inputBuf)
|
||||
if err != nil {
|
||||
t.Errorf("TextFormat %d: %v", i, err)
|
||||
}
|
||||
|
||||
var s string
|
||||
err := conn.QueryRow("test", string(inputBuf.Bytes()[4:])).Scan(&s)
|
||||
if err != nil {
|
||||
t.Errorf("TextFormat %d: %v", i, err)
|
||||
}
|
||||
|
||||
outputBuf := &bytes.Buffer{}
|
||||
pgio.WriteInt32(outputBuf, int32(len(s)))
|
||||
outputBuf.WriteString(s)
|
||||
var r pgtype.Bool
|
||||
err = r.DecodeText(outputBuf)
|
||||
if err != nil {
|
||||
t.Errorf("TextFormat %d: %v", i, err)
|
||||
}
|
||||
|
||||
if r != tt.result {
|
||||
t.Errorf("TextFormat %d: expected %v, got %v", i, tt.result, r)
|
||||
}
|
||||
}
|
||||
|
||||
ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode
|
||||
for i, tt := range tests {
|
||||
inputBuf := &bytes.Buffer{}
|
||||
err = tt.result.EncodeBinary(inputBuf)
|
||||
if err != nil {
|
||||
t.Errorf("BinaryFormat %d: %v", i, err)
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
err := conn.QueryRow("test", inputBuf.Bytes()[4:]).Scan(&buf)
|
||||
if err != nil {
|
||||
t.Errorf("BinaryFormat %d: %v", i, err)
|
||||
}
|
||||
|
||||
outputBuf := &bytes.Buffer{}
|
||||
pgio.WriteInt32(outputBuf, int32(len(buf)))
|
||||
outputBuf.Write(buf)
|
||||
var r pgtype.Bool
|
||||
err = r.DecodeBinary(outputBuf)
|
||||
if err != nil {
|
||||
t.Errorf("BinaryFormat %d: %v", i, err)
|
||||
}
|
||||
|
||||
if r != tt.result {
|
||||
t.Errorf("BinaryFormat %d: expected %v, got %v", i, tt.result, r)
|
||||
}
|
||||
}
|
||||
}
|
33
values.go
33
values.go
|
@ -1428,13 +1428,26 @@ func decodeBool(vr *ValueReader) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
if vr.Len() != 1 {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", vr.Len())))
|
||||
vr.err = errRewoundLen
|
||||
|
||||
var b pgtype.Bool
|
||||
var err error
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
err = b.DecodeText(&valueReader2{vr})
|
||||
case BinaryFormatCode:
|
||||
err = b.DecodeBinary(&valueReader2{vr})
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return false
|
||||
}
|
||||
|
||||
b := vr.ReadByte()
|
||||
return b != 0
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return false
|
||||
}
|
||||
|
||||
return bool(b)
|
||||
}
|
||||
|
||||
func encodeBool(w *WriteBuf, oid OID, value bool) error {
|
||||
|
@ -1442,16 +1455,8 @@ func encodeBool(w *WriteBuf, oid OID, value bool) error {
|
|||
return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid)
|
||||
}
|
||||
|
||||
w.WriteInt32(1)
|
||||
|
||||
var n byte
|
||||
if value {
|
||||
n = 1
|
||||
}
|
||||
|
||||
w.WriteByte(n)
|
||||
|
||||
return nil
|
||||
b := pgtype.Bool(value)
|
||||
return b.EncodeBinary(w)
|
||||
}
|
||||
|
||||
func decodeInt(vr *ValueReader) int64 {
|
||||
|
|
Loading…
Reference in New Issue