Add status to pgtype.Bool

pgxtype-experiment2
Jack Christensen 2017-02-25 15:56:44 -06:00
parent 325f700b6e
commit 720451f06d
6 changed files with 108 additions and 80 deletions

View File

@ -279,13 +279,12 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.doneChan = make(chan struct{}) c.doneChan = make(chan struct{})
c.closedChan = make(chan error) c.closedChan = make(chan error)
b := pgtype.Bool(false)
i2 := pgtype.Int2(0) i2 := pgtype.Int2(0)
i4 := pgtype.Int4(0) i4 := pgtype.Int4(0)
i8 := pgtype.Int8(0) i8 := pgtype.Int8(0)
c.oidPgtypeValues = map[OID]pgtype.Value{ c.oidPgtypeValues = map[OID]pgtype.Value{
BoolOID: &b, BoolOID: &pgtype.Bool{},
DateOID: &pgtype.Date{}, DateOID: &pgtype.Date{},
Int2OID: &i2, Int2OID: &i2,
Int4OID: &i4, Int4OID: &i4,
@ -978,6 +977,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
switch arg := arguments[i].(type) { switch arg := arguments[i].(type) {
case Encoder: case Encoder:
wbuf.WriteInt16(arg.FormatCode()) wbuf.WriteInt16(arg.FormatCode())
case pgtype.BinaryEncoder:
wbuf.WriteInt16(BinaryFormatCode)
case pgtype.TextEncoder:
wbuf.WriteInt16(TextFormatCode)
case string, *string: case string, *string:
wbuf.WriteInt16(TextFormatCode) wbuf.WriteInt16(TextFormatCode)
default: default:

View File

@ -8,20 +8,23 @@ import (
"github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgio"
) )
type Bool bool type Bool struct {
Bool bool
Status Status
}
func (b *Bool) ConvertFrom(src interface{}) error { func (b *Bool) ConvertFrom(src interface{}) error {
switch value := src.(type) { switch value := src.(type) {
case Bool: case Bool:
*b = value *b = value
case bool: case bool:
*b = Bool(value) *b = Bool{Bool: value, Status: Present}
case string: case string:
bb, err := strconv.ParseBool(value) bb, err := strconv.ParseBool(value)
if err != nil { if err != nil {
return err return err
} }
*b = Bool(bb) *b = Bool{Bool: bb, Status: Present}
default: default:
if originalSrc, ok := underlyingBoolType(src); ok { if originalSrc, ok := underlyingBoolType(src); ok {
return b.ConvertFrom(originalSrc) return b.ConvertFrom(originalSrc)
@ -42,6 +45,11 @@ func (b *Bool) DecodeText(r io.Reader) error {
return err return err
} }
if size == -1 {
*b = Bool{Status: Null}
return nil
}
if size != 1 { if size != 1 {
return fmt.Errorf("invalid length for bool: %v", size) return fmt.Errorf("invalid length for bool: %v", size)
} }
@ -51,7 +59,7 @@ func (b *Bool) DecodeText(r io.Reader) error {
return err return err
} }
*b = Bool(byt == 't') *b = Bool{Bool: byt == 't', Status: Present}
return nil return nil
} }
@ -61,6 +69,11 @@ func (b *Bool) DecodeBinary(r io.Reader) error {
return err return err
} }
if size == -1 {
*b = Bool{Status: Null}
return nil
}
if size != 1 { if size != 1 {
return fmt.Errorf("invalid length for bool: %v", size) return fmt.Errorf("invalid length for bool: %v", size)
} }
@ -70,18 +83,22 @@ func (b *Bool) DecodeBinary(r io.Reader) error {
return err return err
} }
*b = Bool(byt == 1) *b = Bool{Bool: byt == 1, Status: Present}
return nil return nil
} }
func (b Bool) EncodeText(w io.Writer) error { func (b Bool) EncodeText(w io.Writer) error {
if done, err := encodeNotPresent(w, b.Status); done {
return err
}
_, err := pgio.WriteInt32(w, 1) _, err := pgio.WriteInt32(w, 1)
if err != nil { if err != nil {
return nil return nil
} }
var buf []byte var buf []byte
if b { if b.Bool {
buf = []byte{'t'} buf = []byte{'t'}
} else { } else {
buf = []byte{'f'} buf = []byte{'f'}
@ -92,13 +109,17 @@ func (b Bool) EncodeText(w io.Writer) error {
} }
func (b Bool) EncodeBinary(w io.Writer) error { func (b Bool) EncodeBinary(w io.Writer) error {
if done, err := encodeNotPresent(w, b.Status); done {
return err
}
_, err := pgio.WriteInt32(w, 1) _, err := pgio.WriteInt32(w, 1)
if err != nil { if err != nil {
return nil return nil
} }
var buf []byte var buf []byte
if b { if b.Bool {
buf = []byte{1} buf = []byte{1}
} else { } else {
buf = []byte{0} buf = []byte{0}

View File

@ -1,11 +1,9 @@
package pgtype_test package pgtype_test
import ( import (
"bytes"
"testing" "testing"
"github.com/jackc/pgx" "github.com/jackc/pgx"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgtype" "github.com/jackc/pgx/pgtype"
) )
@ -21,64 +19,33 @@ func TestBoolTranscode(t *testing.T) {
tests := []struct { tests := []struct {
result pgtype.Bool result pgtype.Bool
}{ }{
{result: pgtype.Bool(false)}, {result: pgtype.Bool{Bool: false, Status: pgtype.Present}},
{result: pgtype.Bool(true)}, {result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
{result: pgtype.Bool{Bool: false, Status: pgtype.Null}},
} }
ps.FieldDescriptions[0].FormatCode = pgx.TextFormatCode formats := []struct {
name string
formatCode int16
}{
{name: "TextFormat", formatCode: pgx.TextFormatCode},
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
}
for _, fc := range formats {
ps.FieldDescriptions[0].FormatCode = fc.formatCode
for i, tt := range tests { for i, tt := range tests {
inputBuf := &bytes.Buffer{}
err = tt.result.EncodeText(inputBuf)
if err != nil {
t.Errorf("TextFormat %d: %v", i, err)
}
var s string
err := conn.QueryRow("test", string(inputBuf.Bytes()[4:])).Scan(&s)
if err != nil {
t.Errorf("TextFormat %d: %v", i, err)
}
outputBuf := &bytes.Buffer{}
pgio.WriteInt32(outputBuf, int32(len(s)))
outputBuf.WriteString(s)
var r pgtype.Bool var r pgtype.Bool
err = r.DecodeText(outputBuf) err := conn.QueryRow("test", tt.result).Scan(&r)
if err != nil { if err != nil {
t.Errorf("TextFormat %d: %v", i, err) t.Errorf("%v %d: %v", fc.name, i, err)
} }
if r != tt.result { if r != tt.result {
t.Errorf("TextFormat %d: expected %v, got %v", i, tt.result, r) t.Errorf("%v %d: expected %v, got %v", fc.name, i, tt.result, r)
} }
} }
ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode
for i, tt := range tests {
inputBuf := &bytes.Buffer{}
err = tt.result.EncodeBinary(inputBuf)
if err != nil {
t.Errorf("BinaryFormat %d: %v", i, err)
}
var buf []byte
err := conn.QueryRow("test", inputBuf.Bytes()[4:]).Scan(&buf)
if err != nil {
t.Errorf("BinaryFormat %d: %v", i, err)
}
outputBuf := &bytes.Buffer{}
pgio.WriteInt32(outputBuf, int32(len(buf)))
outputBuf.Write(buf)
var r pgtype.Bool
err = r.DecodeBinary(outputBuf)
if err != nil {
t.Errorf("BinaryFormat %d: %v", i, err)
}
if r != tt.result {
t.Errorf("BinaryFormat %d: expected %v, got %v", i, tt.result, r)
}
} }
} }
@ -89,12 +56,12 @@ func TestBoolConvertFrom(t *testing.T) {
source interface{} source interface{}
result pgtype.Bool result pgtype.Bool
}{ }{
{source: true, result: pgtype.Bool(true)}, {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
{source: false, result: pgtype.Bool(false)}, {source: false, result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
{source: "true", result: pgtype.Bool(true)}, {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
{source: "false", result: pgtype.Bool(false)}, {source: "false", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
{source: "t", result: pgtype.Bool(true)}, {source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
{source: "f", result: pgtype.Bool(false)}, {source: "f", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
} }
for i, tt := range successfulTests { for i, tt := range successfulTests {

View File

@ -1,7 +1,18 @@
package pgtype package pgtype
import ( import (
"errors"
"io" "io"
"github.com/jackc/pgx/pgio"
)
type Status byte
const (
Undefined Status = iota
Null
Present
) )
type Value interface { type Value interface {
@ -24,3 +35,16 @@ type BinaryEncoder interface {
type TextEncoder interface { type TextEncoder interface {
EncodeText(w io.Writer) error EncodeText(w io.Writer) error
} }
var errUndefined = errors.New("cannot encode status undefined")
func encodeNotPresent(w io.Writer, status Status) (done bool, err error) {
switch status {
case Undefined:
return true, errUndefined
case Null:
_, err = pgio.WriteInt32(w, -1)
return true, err
}
return false, nil
}

View File

@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"golang.org/x/net/context" "golang.org/x/net/context"
"time" "time"
"github.com/jackc/pgx/pgtype"
) )
// Row is a convenience wrapper over Rows that is returned by QueryRow. // Row is a convenience wrapper over Rows that is returned by QueryRow.
@ -228,6 +230,18 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
if err != nil { if err != nil {
rows.Fatal(scanArgError{col: i, err: err}) rows.Fatal(scanArgError{col: i, err: err})
} }
} else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode {
vr.err = errRewoundLen
err = s.DecodeBinary(&valueReader2{vr})
if err != nil {
rows.Fatal(scanArgError{col: i, err: err})
}
} else if s, ok := d.(pgtype.TextDecoder); ok && vr.Type().FormatCode == TextFormatCode {
vr.err = errRewoundLen
err = s.DecodeText(&valueReader2{vr})
if err != nil {
rows.Fatal(scanArgError{col: i, err: err})
}
} else if s, ok := d.(sql.Scanner); ok { } else if s, ok := d.(sql.Scanner); ok {
var val interface{} var val interface{}
if 0 <= vr.Len() { if 0 <= vr.Len() {

View File

@ -1026,6 +1026,10 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error {
switch arg := arg.(type) { switch arg := arg.(type) {
case Encoder: case Encoder:
return arg.Encode(wbuf, oid) return arg.Encode(wbuf, oid)
case pgtype.BinaryEncoder:
return arg.EncodeBinary(wbuf)
case pgtype.TextEncoder:
return arg.EncodeText(wbuf)
case driver.Valuer: case driver.Valuer:
v, err := arg.Value() v, err := arg.Value()
if err != nil { if err != nil {
@ -1398,21 +1402,11 @@ func Decode(vr *ValueReader, d interface{}) error {
} }
func decodeBool(vr *ValueReader) bool { func decodeBool(vr *ValueReader) bool {
if vr.Len() == -1 {
vr.Fatal(ProtocolError("Cannot decode null into bool"))
return false
}
if vr.Type().DataType != BoolOID { if vr.Type().DataType != BoolOID {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType))) vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType)))
return false return false
} }
if vr.Type().FormatCode != BinaryFormatCode {
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return false
}
vr.err = errRewoundLen vr.err = errRewoundLen
var b pgtype.Bool var b pgtype.Bool
@ -1432,7 +1426,12 @@ func decodeBool(vr *ValueReader) bool {
return false return false
} }
return bool(b) if b.Status != pgtype.Present {
vr.Fatal(fmt.Errorf("Cannot decode null into bool"))
return false
}
return b.Bool
} }
func encodeBool(w *WriteBuf, oid OID, value bool) error { func encodeBool(w *WriteBuf, oid OID, value bool) error {
@ -1440,7 +1439,7 @@ func encodeBool(w *WriteBuf, oid OID, value bool) error {
return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid) return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid)
} }
b := pgtype.Bool(value) b := pgtype.Bool{Bool: value, Status: pgtype.Present}
return b.EncodeBinary(w) return b.EncodeBinary(w)
} }