mirror of https://github.com/jackc/pgx.git
Add status to pgtype.Bool
parent
325f700b6e
commit
720451f06d
7
conn.go
7
conn.go
|
@ -279,13 +279,12 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
||||||
c.doneChan = make(chan struct{})
|
c.doneChan = make(chan struct{})
|
||||||
c.closedChan = make(chan error)
|
c.closedChan = make(chan error)
|
||||||
|
|
||||||
b := pgtype.Bool(false)
|
|
||||||
i2 := pgtype.Int2(0)
|
i2 := pgtype.Int2(0)
|
||||||
i4 := pgtype.Int4(0)
|
i4 := pgtype.Int4(0)
|
||||||
i8 := pgtype.Int8(0)
|
i8 := pgtype.Int8(0)
|
||||||
|
|
||||||
c.oidPgtypeValues = map[OID]pgtype.Value{
|
c.oidPgtypeValues = map[OID]pgtype.Value{
|
||||||
BoolOID: &b,
|
BoolOID: &pgtype.Bool{},
|
||||||
DateOID: &pgtype.Date{},
|
DateOID: &pgtype.Date{},
|
||||||
Int2OID: &i2,
|
Int2OID: &i2,
|
||||||
Int4OID: &i4,
|
Int4OID: &i4,
|
||||||
|
@ -978,6 +977,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
||||||
switch arg := arguments[i].(type) {
|
switch arg := arguments[i].(type) {
|
||||||
case Encoder:
|
case Encoder:
|
||||||
wbuf.WriteInt16(arg.FormatCode())
|
wbuf.WriteInt16(arg.FormatCode())
|
||||||
|
case pgtype.BinaryEncoder:
|
||||||
|
wbuf.WriteInt16(BinaryFormatCode)
|
||||||
|
case pgtype.TextEncoder:
|
||||||
|
wbuf.WriteInt16(TextFormatCode)
|
||||||
case string, *string:
|
case string, *string:
|
||||||
wbuf.WriteInt16(TextFormatCode)
|
wbuf.WriteInt16(TextFormatCode)
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -8,20 +8,23 @@ import (
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgx/pgio"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Bool bool
|
type Bool struct {
|
||||||
|
Bool bool
|
||||||
|
Status Status
|
||||||
|
}
|
||||||
|
|
||||||
func (b *Bool) ConvertFrom(src interface{}) error {
|
func (b *Bool) ConvertFrom(src interface{}) error {
|
||||||
switch value := src.(type) {
|
switch value := src.(type) {
|
||||||
case Bool:
|
case Bool:
|
||||||
*b = value
|
*b = value
|
||||||
case bool:
|
case bool:
|
||||||
*b = Bool(value)
|
*b = Bool{Bool: value, Status: Present}
|
||||||
case string:
|
case string:
|
||||||
bb, err := strconv.ParseBool(value)
|
bb, err := strconv.ParseBool(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
*b = Bool(bb)
|
*b = Bool{Bool: bb, Status: Present}
|
||||||
default:
|
default:
|
||||||
if originalSrc, ok := underlyingBoolType(src); ok {
|
if originalSrc, ok := underlyingBoolType(src); ok {
|
||||||
return b.ConvertFrom(originalSrc)
|
return b.ConvertFrom(originalSrc)
|
||||||
|
@ -42,6 +45,11 @@ func (b *Bool) DecodeText(r io.Reader) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if size == -1 {
|
||||||
|
*b = Bool{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if size != 1 {
|
if size != 1 {
|
||||||
return fmt.Errorf("invalid length for bool: %v", size)
|
return fmt.Errorf("invalid length for bool: %v", size)
|
||||||
}
|
}
|
||||||
|
@ -51,7 +59,7 @@ func (b *Bool) DecodeText(r io.Reader) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
*b = Bool(byt == 't')
|
*b = Bool{Bool: byt == 't', Status: Present}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,6 +69,11 @@ func (b *Bool) DecodeBinary(r io.Reader) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if size == -1 {
|
||||||
|
*b = Bool{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if size != 1 {
|
if size != 1 {
|
||||||
return fmt.Errorf("invalid length for bool: %v", size)
|
return fmt.Errorf("invalid length for bool: %v", size)
|
||||||
}
|
}
|
||||||
|
@ -70,18 +83,22 @@ func (b *Bool) DecodeBinary(r io.Reader) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
*b = Bool(byt == 1)
|
*b = Bool{Bool: byt == 1, Status: Present}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b Bool) EncodeText(w io.Writer) error {
|
func (b Bool) EncodeText(w io.Writer) error {
|
||||||
|
if done, err := encodeNotPresent(w, b.Status); done {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
_, err := pgio.WriteInt32(w, 1)
|
_, err := pgio.WriteInt32(w, 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var buf []byte
|
var buf []byte
|
||||||
if b {
|
if b.Bool {
|
||||||
buf = []byte{'t'}
|
buf = []byte{'t'}
|
||||||
} else {
|
} else {
|
||||||
buf = []byte{'f'}
|
buf = []byte{'f'}
|
||||||
|
@ -92,13 +109,17 @@ func (b Bool) EncodeText(w io.Writer) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b Bool) EncodeBinary(w io.Writer) error {
|
func (b Bool) EncodeBinary(w io.Writer) error {
|
||||||
|
if done, err := encodeNotPresent(w, b.Status); done {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
_, err := pgio.WriteInt32(w, 1)
|
_, err := pgio.WriteInt32(w, 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var buf []byte
|
var buf []byte
|
||||||
if b {
|
if b.Bool {
|
||||||
buf = []byte{1}
|
buf = []byte{1}
|
||||||
} else {
|
} else {
|
||||||
buf = []byte{0}
|
buf = []byte{0}
|
||||||
|
|
|
@ -1,11 +1,9 @@
|
||||||
package pgtype_test
|
package pgtype_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
"github.com/jackc/pgx/pgtype"
|
"github.com/jackc/pgx/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -21,64 +19,33 @@ func TestBoolTranscode(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
result pgtype.Bool
|
result pgtype.Bool
|
||||||
}{
|
}{
|
||||||
{result: pgtype.Bool(false)},
|
{result: pgtype.Bool{Bool: false, Status: pgtype.Present}},
|
||||||
{result: pgtype.Bool(true)},
|
{result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||||
|
{result: pgtype.Bool{Bool: false, Status: pgtype.Null}},
|
||||||
}
|
}
|
||||||
|
|
||||||
ps.FieldDescriptions[0].FormatCode = pgx.TextFormatCode
|
formats := []struct {
|
||||||
|
name string
|
||||||
|
formatCode int16
|
||||||
|
}{
|
||||||
|
{name: "TextFormat", formatCode: pgx.TextFormatCode},
|
||||||
|
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, fc := range formats {
|
||||||
|
ps.FieldDescriptions[0].FormatCode = fc.formatCode
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
inputBuf := &bytes.Buffer{}
|
|
||||||
err = tt.result.EncodeText(inputBuf)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("TextFormat %d: %v", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var s string
|
|
||||||
err := conn.QueryRow("test", string(inputBuf.Bytes()[4:])).Scan(&s)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("TextFormat %d: %v", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
outputBuf := &bytes.Buffer{}
|
|
||||||
pgio.WriteInt32(outputBuf, int32(len(s)))
|
|
||||||
outputBuf.WriteString(s)
|
|
||||||
var r pgtype.Bool
|
var r pgtype.Bool
|
||||||
err = r.DecodeText(outputBuf)
|
err := conn.QueryRow("test", tt.result).Scan(&r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("TextFormat %d: %v", i, err)
|
t.Errorf("%v %d: %v", fc.name, i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r != tt.result {
|
if r != tt.result {
|
||||||
t.Errorf("TextFormat %d: expected %v, got %v", i, tt.result, r)
|
t.Errorf("%v %d: expected %v, got %v", fc.name, i, tt.result, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode
|
|
||||||
for i, tt := range tests {
|
|
||||||
inputBuf := &bytes.Buffer{}
|
|
||||||
err = tt.result.EncodeBinary(inputBuf)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("BinaryFormat %d: %v", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf []byte
|
|
||||||
err := conn.QueryRow("test", inputBuf.Bytes()[4:]).Scan(&buf)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("BinaryFormat %d: %v", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
outputBuf := &bytes.Buffer{}
|
|
||||||
pgio.WriteInt32(outputBuf, int32(len(buf)))
|
|
||||||
outputBuf.Write(buf)
|
|
||||||
var r pgtype.Bool
|
|
||||||
err = r.DecodeBinary(outputBuf)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("BinaryFormat %d: %v", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r != tt.result {
|
|
||||||
t.Errorf("BinaryFormat %d: expected %v, got %v", i, tt.result, r)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,12 +56,12 @@ func TestBoolConvertFrom(t *testing.T) {
|
||||||
source interface{}
|
source interface{}
|
||||||
result pgtype.Bool
|
result pgtype.Bool
|
||||||
}{
|
}{
|
||||||
{source: true, result: pgtype.Bool(true)},
|
{source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||||
{source: false, result: pgtype.Bool(false)},
|
{source: false, result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||||
{source: "true", result: pgtype.Bool(true)},
|
{source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||||
{source: "false", result: pgtype.Bool(false)},
|
{source: "false", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||||
{source: "t", result: pgtype.Bool(true)},
|
{source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||||
{source: "f", result: pgtype.Bool(false)},
|
{source: "f", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range successfulTests {
|
for i, tt := range successfulTests {
|
||||||
|
|
|
@ -1,7 +1,18 @@
|
||||||
package pgtype
|
package pgtype
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Status byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
Undefined Status = iota
|
||||||
|
Null
|
||||||
|
Present
|
||||||
)
|
)
|
||||||
|
|
||||||
type Value interface {
|
type Value interface {
|
||||||
|
@ -24,3 +35,16 @@ type BinaryEncoder interface {
|
||||||
type TextEncoder interface {
|
type TextEncoder interface {
|
||||||
EncodeText(w io.Writer) error
|
EncodeText(w io.Writer) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var errUndefined = errors.New("cannot encode status undefined")
|
||||||
|
|
||||||
|
func encodeNotPresent(w io.Writer, status Status) (done bool, err error) {
|
||||||
|
switch status {
|
||||||
|
case Undefined:
|
||||||
|
return true, errUndefined
|
||||||
|
case Null:
|
||||||
|
_, err = pgio.WriteInt32(w, -1)
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
14
query.go
14
query.go
|
@ -6,6 +6,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Row is a convenience wrapper over Rows that is returned by QueryRow.
|
// Row is a convenience wrapper over Rows that is returned by QueryRow.
|
||||||
|
@ -228,6 +230,18 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rows.Fatal(scanArgError{col: i, err: err})
|
rows.Fatal(scanArgError{col: i, err: err})
|
||||||
}
|
}
|
||||||
|
} else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode {
|
||||||
|
vr.err = errRewoundLen
|
||||||
|
err = s.DecodeBinary(&valueReader2{vr})
|
||||||
|
if err != nil {
|
||||||
|
rows.Fatal(scanArgError{col: i, err: err})
|
||||||
|
}
|
||||||
|
} else if s, ok := d.(pgtype.TextDecoder); ok && vr.Type().FormatCode == TextFormatCode {
|
||||||
|
vr.err = errRewoundLen
|
||||||
|
err = s.DecodeText(&valueReader2{vr})
|
||||||
|
if err != nil {
|
||||||
|
rows.Fatal(scanArgError{col: i, err: err})
|
||||||
|
}
|
||||||
} else if s, ok := d.(sql.Scanner); ok {
|
} else if s, ok := d.(sql.Scanner); ok {
|
||||||
var val interface{}
|
var val interface{}
|
||||||
if 0 <= vr.Len() {
|
if 0 <= vr.Len() {
|
||||||
|
|
23
values.go
23
values.go
|
@ -1026,6 +1026,10 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error {
|
||||||
switch arg := arg.(type) {
|
switch arg := arg.(type) {
|
||||||
case Encoder:
|
case Encoder:
|
||||||
return arg.Encode(wbuf, oid)
|
return arg.Encode(wbuf, oid)
|
||||||
|
case pgtype.BinaryEncoder:
|
||||||
|
return arg.EncodeBinary(wbuf)
|
||||||
|
case pgtype.TextEncoder:
|
||||||
|
return arg.EncodeText(wbuf)
|
||||||
case driver.Valuer:
|
case driver.Valuer:
|
||||||
v, err := arg.Value()
|
v, err := arg.Value()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1398,21 +1402,11 @@ func Decode(vr *ValueReader, d interface{}) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeBool(vr *ValueReader) bool {
|
func decodeBool(vr *ValueReader) bool {
|
||||||
if vr.Len() == -1 {
|
|
||||||
vr.Fatal(ProtocolError("Cannot decode null into bool"))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if vr.Type().DataType != BoolOID {
|
if vr.Type().DataType != BoolOID {
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType)))
|
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType)))
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if vr.Type().FormatCode != BinaryFormatCode {
|
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
vr.err = errRewoundLen
|
vr.err = errRewoundLen
|
||||||
|
|
||||||
var b pgtype.Bool
|
var b pgtype.Bool
|
||||||
|
@ -1432,7 +1426,12 @@ func decodeBool(vr *ValueReader) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return bool(b)
|
if b.Status != pgtype.Present {
|
||||||
|
vr.Fatal(fmt.Errorf("Cannot decode null into bool"))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodeBool(w *WriteBuf, oid OID, value bool) error {
|
func encodeBool(w *WriteBuf, oid OID, value bool) error {
|
||||||
|
@ -1440,7 +1439,7 @@ func encodeBool(w *WriteBuf, oid OID, value bool) error {
|
||||||
return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid)
|
return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid)
|
||||||
}
|
}
|
||||||
|
|
||||||
b := pgtype.Bool(value)
|
b := pgtype.Bool{Bool: value, Status: pgtype.Present}
|
||||||
return b.EncodeBinary(w)
|
return b.EncodeBinary(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue