mirror of https://github.com/jackc/pgx.git
Add status to pgtype.Bool
parent
325f700b6e
commit
720451f06d
7
conn.go
7
conn.go
|
@ -279,13 +279,12 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
c.doneChan = make(chan struct{})
|
||||
c.closedChan = make(chan error)
|
||||
|
||||
b := pgtype.Bool(false)
|
||||
i2 := pgtype.Int2(0)
|
||||
i4 := pgtype.Int4(0)
|
||||
i8 := pgtype.Int8(0)
|
||||
|
||||
c.oidPgtypeValues = map[OID]pgtype.Value{
|
||||
BoolOID: &b,
|
||||
BoolOID: &pgtype.Bool{},
|
||||
DateOID: &pgtype.Date{},
|
||||
Int2OID: &i2,
|
||||
Int4OID: &i4,
|
||||
|
@ -978,6 +977,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
switch arg := arguments[i].(type) {
|
||||
case Encoder:
|
||||
wbuf.WriteInt16(arg.FormatCode())
|
||||
case pgtype.BinaryEncoder:
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
case pgtype.TextEncoder:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
case string, *string:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
default:
|
||||
|
|
|
@ -8,20 +8,23 @@ import (
|
|||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Bool bool
|
||||
type Bool struct {
|
||||
Bool bool
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (b *Bool) ConvertFrom(src interface{}) error {
|
||||
switch value := src.(type) {
|
||||
case Bool:
|
||||
*b = value
|
||||
case bool:
|
||||
*b = Bool(value)
|
||||
*b = Bool{Bool: value, Status: Present}
|
||||
case string:
|
||||
bb, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*b = Bool(bb)
|
||||
*b = Bool{Bool: bb, Status: Present}
|
||||
default:
|
||||
if originalSrc, ok := underlyingBoolType(src); ok {
|
||||
return b.ConvertFrom(originalSrc)
|
||||
|
@ -42,6 +45,11 @@ func (b *Bool) DecodeText(r io.Reader) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if size == -1 {
|
||||
*b = Bool{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
if size != 1 {
|
||||
return fmt.Errorf("invalid length for bool: %v", size)
|
||||
}
|
||||
|
@ -51,7 +59,7 @@ func (b *Bool) DecodeText(r io.Reader) error {
|
|||
return err
|
||||
}
|
||||
|
||||
*b = Bool(byt == 't')
|
||||
*b = Bool{Bool: byt == 't', Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -61,6 +69,11 @@ func (b *Bool) DecodeBinary(r io.Reader) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if size == -1 {
|
||||
*b = Bool{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
if size != 1 {
|
||||
return fmt.Errorf("invalid length for bool: %v", size)
|
||||
}
|
||||
|
@ -70,18 +83,22 @@ func (b *Bool) DecodeBinary(r io.Reader) error {
|
|||
return err
|
||||
}
|
||||
|
||||
*b = Bool(byt == 1)
|
||||
*b = Bool{Bool: byt == 1, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b Bool) EncodeText(w io.Writer) error {
|
||||
if done, err := encodeNotPresent(w, b.Status); done {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := pgio.WriteInt32(w, 1)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
if b {
|
||||
if b.Bool {
|
||||
buf = []byte{'t'}
|
||||
} else {
|
||||
buf = []byte{'f'}
|
||||
|
@ -92,13 +109,17 @@ func (b Bool) EncodeText(w io.Writer) error {
|
|||
}
|
||||
|
||||
func (b Bool) EncodeBinary(w io.Writer) error {
|
||||
if done, err := encodeNotPresent(w, b.Status); done {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := pgio.WriteInt32(w, 1)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
if b {
|
||||
if b.Bool {
|
||||
buf = []byte{1}
|
||||
} else {
|
||||
buf = []byte{0}
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
|
@ -21,63 +19,32 @@ func TestBoolTranscode(t *testing.T) {
|
|||
tests := []struct {
|
||||
result pgtype.Bool
|
||||
}{
|
||||
{result: pgtype.Bool(false)},
|
||||
{result: pgtype.Bool(true)},
|
||||
{result: pgtype.Bool{Bool: false, Status: pgtype.Present}},
|
||||
{result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||
{result: pgtype.Bool{Bool: false, Status: pgtype.Null}},
|
||||
}
|
||||
|
||||
ps.FieldDescriptions[0].FormatCode = pgx.TextFormatCode
|
||||
for i, tt := range tests {
|
||||
inputBuf := &bytes.Buffer{}
|
||||
err = tt.result.EncodeText(inputBuf)
|
||||
if err != nil {
|
||||
t.Errorf("TextFormat %d: %v", i, err)
|
||||
}
|
||||
|
||||
var s string
|
||||
err := conn.QueryRow("test", string(inputBuf.Bytes()[4:])).Scan(&s)
|
||||
if err != nil {
|
||||
t.Errorf("TextFormat %d: %v", i, err)
|
||||
}
|
||||
|
||||
outputBuf := &bytes.Buffer{}
|
||||
pgio.WriteInt32(outputBuf, int32(len(s)))
|
||||
outputBuf.WriteString(s)
|
||||
var r pgtype.Bool
|
||||
err = r.DecodeText(outputBuf)
|
||||
if err != nil {
|
||||
t.Errorf("TextFormat %d: %v", i, err)
|
||||
}
|
||||
|
||||
if r != tt.result {
|
||||
t.Errorf("TextFormat %d: expected %v, got %v", i, tt.result, r)
|
||||
}
|
||||
formats := []struct {
|
||||
name string
|
||||
formatCode int16
|
||||
}{
|
||||
{name: "TextFormat", formatCode: pgx.TextFormatCode},
|
||||
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
|
||||
}
|
||||
|
||||
ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode
|
||||
for i, tt := range tests {
|
||||
inputBuf := &bytes.Buffer{}
|
||||
err = tt.result.EncodeBinary(inputBuf)
|
||||
if err != nil {
|
||||
t.Errorf("BinaryFormat %d: %v", i, err)
|
||||
}
|
||||
for _, fc := range formats {
|
||||
ps.FieldDescriptions[0].FormatCode = fc.formatCode
|
||||
|
||||
var buf []byte
|
||||
err := conn.QueryRow("test", inputBuf.Bytes()[4:]).Scan(&buf)
|
||||
if err != nil {
|
||||
t.Errorf("BinaryFormat %d: %v", i, err)
|
||||
}
|
||||
for i, tt := range tests {
|
||||
var r pgtype.Bool
|
||||
err := conn.QueryRow("test", tt.result).Scan(&r)
|
||||
if err != nil {
|
||||
t.Errorf("%v %d: %v", fc.name, i, err)
|
||||
}
|
||||
|
||||
outputBuf := &bytes.Buffer{}
|
||||
pgio.WriteInt32(outputBuf, int32(len(buf)))
|
||||
outputBuf.Write(buf)
|
||||
var r pgtype.Bool
|
||||
err = r.DecodeBinary(outputBuf)
|
||||
if err != nil {
|
||||
t.Errorf("BinaryFormat %d: %v", i, err)
|
||||
}
|
||||
|
||||
if r != tt.result {
|
||||
t.Errorf("BinaryFormat %d: expected %v, got %v", i, tt.result, r)
|
||||
if r != tt.result {
|
||||
t.Errorf("%v %d: expected %v, got %v", fc.name, i, tt.result, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -89,12 +56,12 @@ func TestBoolConvertFrom(t *testing.T) {
|
|||
source interface{}
|
||||
result pgtype.Bool
|
||||
}{
|
||||
{source: true, result: pgtype.Bool(true)},
|
||||
{source: false, result: pgtype.Bool(false)},
|
||||
{source: "true", result: pgtype.Bool(true)},
|
||||
{source: "false", result: pgtype.Bool(false)},
|
||||
{source: "t", result: pgtype.Bool(true)},
|
||||
{source: "f", result: pgtype.Bool(false)},
|
||||
{source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||
{source: false, result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||
{source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||
{source: "false", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||
{source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||
{source: "f", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
|
|
|
@ -1,7 +1,18 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Status byte
|
||||
|
||||
const (
|
||||
Undefined Status = iota
|
||||
Null
|
||||
Present
|
||||
)
|
||||
|
||||
type Value interface {
|
||||
|
@ -24,3 +35,16 @@ type BinaryEncoder interface {
|
|||
type TextEncoder interface {
|
||||
EncodeText(w io.Writer) error
|
||||
}
|
||||
|
||||
var errUndefined = errors.New("cannot encode status undefined")
|
||||
|
||||
func encodeNotPresent(w io.Writer, status Status) (done bool, err error) {
|
||||
switch status {
|
||||
case Undefined:
|
||||
return true, errUndefined
|
||||
case Null:
|
||||
_, err = pgio.WriteInt32(w, -1)
|
||||
return true, err
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
|
14
query.go
14
query.go
|
@ -6,6 +6,8 @@ import (
|
|||
"fmt"
|
||||
"golang.org/x/net/context"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
// Row is a convenience wrapper over Rows that is returned by QueryRow.
|
||||
|
@ -228,6 +230,18 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||
if err != nil {
|
||||
rows.Fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
} else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode {
|
||||
vr.err = errRewoundLen
|
||||
err = s.DecodeBinary(&valueReader2{vr})
|
||||
if err != nil {
|
||||
rows.Fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
} else if s, ok := d.(pgtype.TextDecoder); ok && vr.Type().FormatCode == TextFormatCode {
|
||||
vr.err = errRewoundLen
|
||||
err = s.DecodeText(&valueReader2{vr})
|
||||
if err != nil {
|
||||
rows.Fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
} else if s, ok := d.(sql.Scanner); ok {
|
||||
var val interface{}
|
||||
if 0 <= vr.Len() {
|
||||
|
|
23
values.go
23
values.go
|
@ -1026,6 +1026,10 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error {
|
|||
switch arg := arg.(type) {
|
||||
case Encoder:
|
||||
return arg.Encode(wbuf, oid)
|
||||
case pgtype.BinaryEncoder:
|
||||
return arg.EncodeBinary(wbuf)
|
||||
case pgtype.TextEncoder:
|
||||
return arg.EncodeText(wbuf)
|
||||
case driver.Valuer:
|
||||
v, err := arg.Value()
|
||||
if err != nil {
|
||||
|
@ -1398,21 +1402,11 @@ func Decode(vr *ValueReader, d interface{}) error {
|
|||
}
|
||||
|
||||
func decodeBool(vr *ValueReader) bool {
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into bool"))
|
||||
return false
|
||||
}
|
||||
|
||||
if vr.Type().DataType != BoolOID {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType)))
|
||||
return false
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return false
|
||||
}
|
||||
|
||||
vr.err = errRewoundLen
|
||||
|
||||
var b pgtype.Bool
|
||||
|
@ -1432,7 +1426,12 @@ func decodeBool(vr *ValueReader) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
return bool(b)
|
||||
if b.Status != pgtype.Present {
|
||||
vr.Fatal(fmt.Errorf("Cannot decode null into bool"))
|
||||
return false
|
||||
}
|
||||
|
||||
return b.Bool
|
||||
}
|
||||
|
||||
func encodeBool(w *WriteBuf, oid OID, value bool) error {
|
||||
|
@ -1440,7 +1439,7 @@ func encodeBool(w *WriteBuf, oid OID, value bool) error {
|
|||
return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid)
|
||||
}
|
||||
|
||||
b := pgtype.Bool(value)
|
||||
b := pgtype.Bool{Bool: value, Status: pgtype.Present}
|
||||
return b.EncodeBinary(w)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue