Add status to pgtype.Bool

pgxtype-experiment2
Jack Christensen 2017-02-25 15:56:44 -06:00
parent 325f700b6e
commit 720451f06d
6 changed files with 108 additions and 80 deletions

View File

@ -279,13 +279,12 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.doneChan = make(chan struct{})
c.closedChan = make(chan error)
b := pgtype.Bool(false)
i2 := pgtype.Int2(0)
i4 := pgtype.Int4(0)
i8 := pgtype.Int8(0)
c.oidPgtypeValues = map[OID]pgtype.Value{
BoolOID: &b,
BoolOID: &pgtype.Bool{},
DateOID: &pgtype.Date{},
Int2OID: &i2,
Int4OID: &i4,
@ -978,6 +977,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
switch arg := arguments[i].(type) {
case Encoder:
wbuf.WriteInt16(arg.FormatCode())
case pgtype.BinaryEncoder:
wbuf.WriteInt16(BinaryFormatCode)
case pgtype.TextEncoder:
wbuf.WriteInt16(TextFormatCode)
case string, *string:
wbuf.WriteInt16(TextFormatCode)
default:

View File

@ -8,20 +8,23 @@ import (
"github.com/jackc/pgx/pgio"
)
type Bool bool
type Bool struct {
Bool bool
Status Status
}
func (b *Bool) ConvertFrom(src interface{}) error {
switch value := src.(type) {
case Bool:
*b = value
case bool:
*b = Bool(value)
*b = Bool{Bool: value, Status: Present}
case string:
bb, err := strconv.ParseBool(value)
if err != nil {
return err
}
*b = Bool(bb)
*b = Bool{Bool: bb, Status: Present}
default:
if originalSrc, ok := underlyingBoolType(src); ok {
return b.ConvertFrom(originalSrc)
@ -42,6 +45,11 @@ func (b *Bool) DecodeText(r io.Reader) error {
return err
}
if size == -1 {
*b = Bool{Status: Null}
return nil
}
if size != 1 {
return fmt.Errorf("invalid length for bool: %v", size)
}
@ -51,7 +59,7 @@ func (b *Bool) DecodeText(r io.Reader) error {
return err
}
*b = Bool(byt == 't')
*b = Bool{Bool: byt == 't', Status: Present}
return nil
}
@ -61,6 +69,11 @@ func (b *Bool) DecodeBinary(r io.Reader) error {
return err
}
if size == -1 {
*b = Bool{Status: Null}
return nil
}
if size != 1 {
return fmt.Errorf("invalid length for bool: %v", size)
}
@ -70,18 +83,22 @@ func (b *Bool) DecodeBinary(r io.Reader) error {
return err
}
*b = Bool(byt == 1)
*b = Bool{Bool: byt == 1, Status: Present}
return nil
}
func (b Bool) EncodeText(w io.Writer) error {
if done, err := encodeNotPresent(w, b.Status); done {
return err
}
_, err := pgio.WriteInt32(w, 1)
if err != nil {
return nil
}
var buf []byte
if b {
if b.Bool {
buf = []byte{'t'}
} else {
buf = []byte{'f'}
@ -92,13 +109,17 @@ func (b Bool) EncodeText(w io.Writer) error {
}
func (b Bool) EncodeBinary(w io.Writer) error {
if done, err := encodeNotPresent(w, b.Status); done {
return err
}
_, err := pgio.WriteInt32(w, 1)
if err != nil {
return nil
}
var buf []byte
if b {
if b.Bool {
buf = []byte{1}
} else {
buf = []byte{0}

View File

@ -1,11 +1,9 @@
package pgtype_test
import (
"bytes"
"testing"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgtype"
)
@ -21,63 +19,32 @@ func TestBoolTranscode(t *testing.T) {
tests := []struct {
result pgtype.Bool
}{
{result: pgtype.Bool(false)},
{result: pgtype.Bool(true)},
{result: pgtype.Bool{Bool: false, Status: pgtype.Present}},
{result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
{result: pgtype.Bool{Bool: false, Status: pgtype.Null}},
}
ps.FieldDescriptions[0].FormatCode = pgx.TextFormatCode
for i, tt := range tests {
inputBuf := &bytes.Buffer{}
err = tt.result.EncodeText(inputBuf)
if err != nil {
t.Errorf("TextFormat %d: %v", i, err)
}
var s string
err := conn.QueryRow("test", string(inputBuf.Bytes()[4:])).Scan(&s)
if err != nil {
t.Errorf("TextFormat %d: %v", i, err)
}
outputBuf := &bytes.Buffer{}
pgio.WriteInt32(outputBuf, int32(len(s)))
outputBuf.WriteString(s)
var r pgtype.Bool
err = r.DecodeText(outputBuf)
if err != nil {
t.Errorf("TextFormat %d: %v", i, err)
}
if r != tt.result {
t.Errorf("TextFormat %d: expected %v, got %v", i, tt.result, r)
}
formats := []struct {
name string
formatCode int16
}{
{name: "TextFormat", formatCode: pgx.TextFormatCode},
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
}
ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode
for i, tt := range tests {
inputBuf := &bytes.Buffer{}
err = tt.result.EncodeBinary(inputBuf)
if err != nil {
t.Errorf("BinaryFormat %d: %v", i, err)
}
for _, fc := range formats {
ps.FieldDescriptions[0].FormatCode = fc.formatCode
var buf []byte
err := conn.QueryRow("test", inputBuf.Bytes()[4:]).Scan(&buf)
if err != nil {
t.Errorf("BinaryFormat %d: %v", i, err)
}
for i, tt := range tests {
var r pgtype.Bool
err := conn.QueryRow("test", tt.result).Scan(&r)
if err != nil {
t.Errorf("%v %d: %v", fc.name, i, err)
}
outputBuf := &bytes.Buffer{}
pgio.WriteInt32(outputBuf, int32(len(buf)))
outputBuf.Write(buf)
var r pgtype.Bool
err = r.DecodeBinary(outputBuf)
if err != nil {
t.Errorf("BinaryFormat %d: %v", i, err)
}
if r != tt.result {
t.Errorf("BinaryFormat %d: expected %v, got %v", i, tt.result, r)
if r != tt.result {
t.Errorf("%v %d: expected %v, got %v", fc.name, i, tt.result, r)
}
}
}
}
@ -89,12 +56,12 @@ func TestBoolConvertFrom(t *testing.T) {
source interface{}
result pgtype.Bool
}{
{source: true, result: pgtype.Bool(true)},
{source: false, result: pgtype.Bool(false)},
{source: "true", result: pgtype.Bool(true)},
{source: "false", result: pgtype.Bool(false)},
{source: "t", result: pgtype.Bool(true)},
{source: "f", result: pgtype.Bool(false)},
{source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
{source: false, result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
{source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
{source: "false", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
{source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
{source: "f", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
}
for i, tt := range successfulTests {

View File

@ -1,7 +1,18 @@
package pgtype
import (
"errors"
"io"
"github.com/jackc/pgx/pgio"
)
type Status byte
const (
Undefined Status = iota
Null
Present
)
type Value interface {
@ -24,3 +35,16 @@ type BinaryEncoder interface {
type TextEncoder interface {
EncodeText(w io.Writer) error
}
var errUndefined = errors.New("cannot encode status undefined")
func encodeNotPresent(w io.Writer, status Status) (done bool, err error) {
switch status {
case Undefined:
return true, errUndefined
case Null:
_, err = pgio.WriteInt32(w, -1)
return true, err
}
return false, nil
}

View File

@ -6,6 +6,8 @@ import (
"fmt"
"golang.org/x/net/context"
"time"
"github.com/jackc/pgx/pgtype"
)
// Row is a convenience wrapper over Rows that is returned by QueryRow.
@ -228,6 +230,18 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
if err != nil {
rows.Fatal(scanArgError{col: i, err: err})
}
} else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode {
vr.err = errRewoundLen
err = s.DecodeBinary(&valueReader2{vr})
if err != nil {
rows.Fatal(scanArgError{col: i, err: err})
}
} else if s, ok := d.(pgtype.TextDecoder); ok && vr.Type().FormatCode == TextFormatCode {
vr.err = errRewoundLen
err = s.DecodeText(&valueReader2{vr})
if err != nil {
rows.Fatal(scanArgError{col: i, err: err})
}
} else if s, ok := d.(sql.Scanner); ok {
var val interface{}
if 0 <= vr.Len() {

View File

@ -1026,6 +1026,10 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error {
switch arg := arg.(type) {
case Encoder:
return arg.Encode(wbuf, oid)
case pgtype.BinaryEncoder:
return arg.EncodeBinary(wbuf)
case pgtype.TextEncoder:
return arg.EncodeText(wbuf)
case driver.Valuer:
v, err := arg.Value()
if err != nil {
@ -1398,21 +1402,11 @@ func Decode(vr *ValueReader, d interface{}) error {
}
func decodeBool(vr *ValueReader) bool {
if vr.Len() == -1 {
vr.Fatal(ProtocolError("Cannot decode null into bool"))
return false
}
if vr.Type().DataType != BoolOID {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType)))
return false
}
if vr.Type().FormatCode != BinaryFormatCode {
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return false
}
vr.err = errRewoundLen
var b pgtype.Bool
@ -1432,7 +1426,12 @@ func decodeBool(vr *ValueReader) bool {
return false
}
return bool(b)
if b.Status != pgtype.Present {
vr.Fatal(fmt.Errorf("Cannot decode null into bool"))
return false
}
return b.Bool
}
func encodeBool(w *WriteBuf, oid OID, value bool) error {
@ -1440,7 +1439,7 @@ func encodeBool(w *WriteBuf, oid OID, value bool) error {
return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid)
}
b := pgtype.Bool(value)
b := pgtype.Bool{Bool: value, Status: pgtype.Present}
return b.EncodeBinary(w)
}