diff --git a/conn.go b/conn.go index 33b7d410..750aa7f5 100644 --- a/conn.go +++ b/conn.go @@ -279,13 +279,12 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.doneChan = make(chan struct{}) c.closedChan = make(chan error) - b := pgtype.Bool(false) i2 := pgtype.Int2(0) i4 := pgtype.Int4(0) i8 := pgtype.Int8(0) c.oidPgtypeValues = map[OID]pgtype.Value{ - BoolOID: &b, + BoolOID: &pgtype.Bool{}, DateOID: &pgtype.Date{}, Int2OID: &i2, Int4OID: &i4, @@ -978,6 +977,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} switch arg := arguments[i].(type) { case Encoder: wbuf.WriteInt16(arg.FormatCode()) + case pgtype.BinaryEncoder: + wbuf.WriteInt16(BinaryFormatCode) + case pgtype.TextEncoder: + wbuf.WriteInt16(TextFormatCode) case string, *string: wbuf.WriteInt16(TextFormatCode) default: diff --git a/pgtype/bool.go b/pgtype/bool.go index 2e44e84c..d645780d 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -8,20 +8,23 @@ import ( "github.com/jackc/pgx/pgio" ) -type Bool bool +type Bool struct { + Bool bool + Status Status +} func (b *Bool) ConvertFrom(src interface{}) error { switch value := src.(type) { case Bool: *b = value case bool: - *b = Bool(value) + *b = Bool{Bool: value, Status: Present} case string: bb, err := strconv.ParseBool(value) if err != nil { return err } - *b = Bool(bb) + *b = Bool{Bool: bb, Status: Present} default: if originalSrc, ok := underlyingBoolType(src); ok { return b.ConvertFrom(originalSrc) @@ -42,6 +45,11 @@ func (b *Bool) DecodeText(r io.Reader) error { return err } + if size == -1 { + *b = Bool{Status: Null} + return nil + } + if size != 1 { return fmt.Errorf("invalid length for bool: %v", size) } @@ -51,7 +59,7 @@ func (b *Bool) DecodeText(r io.Reader) error { return err } - *b = Bool(byt == 't') + *b = Bool{Bool: byt == 't', Status: Present} return nil } @@ -61,6 +69,11 @@ func (b *Bool) DecodeBinary(r io.Reader) error { return err } + if size == -1 { + *b = Bool{Status: Null} + return nil + } + if size != 1 { return fmt.Errorf("invalid length for bool: %v", size) } @@ -70,18 +83,22 @@ func (b *Bool) DecodeBinary(r io.Reader) error { return err } - *b = Bool(byt == 1) + *b = Bool{Bool: byt == 1, Status: Present} return nil } func (b Bool) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, b.Status); done { + return err + } + _, err := pgio.WriteInt32(w, 1) if err != nil { return nil } var buf []byte - if b { + if b.Bool { buf = []byte{'t'} } else { buf = []byte{'f'} @@ -92,13 +109,17 @@ func (b Bool) EncodeText(w io.Writer) error { } func (b Bool) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, b.Status); done { + return err + } + _, err := pgio.WriteInt32(w, 1) if err != nil { return nil } var buf []byte - if b { + if b.Bool { buf = []byte{1} } else { buf = []byte{0} diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go index c68067b4..a10ee3ca 100644 --- a/pgtype/bool_test.go +++ b/pgtype/bool_test.go @@ -1,11 +1,9 @@ package pgtype_test import ( - "bytes" "testing" "github.com/jackc/pgx" - "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgtype" ) @@ -21,63 +19,32 @@ func TestBoolTranscode(t *testing.T) { tests := []struct { result pgtype.Bool }{ - {result: pgtype.Bool(false)}, - {result: pgtype.Bool(true)}, + {result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {result: pgtype.Bool{Bool: false, Status: pgtype.Null}}, } - ps.FieldDescriptions[0].FormatCode = pgx.TextFormatCode - for i, tt := range tests { - inputBuf := &bytes.Buffer{} - err = tt.result.EncodeText(inputBuf) - if err != nil { - t.Errorf("TextFormat %d: %v", i, err) - } - - var s string - err := conn.QueryRow("test", string(inputBuf.Bytes()[4:])).Scan(&s) - if err != nil { - t.Errorf("TextFormat %d: %v", i, err) - } - - outputBuf := &bytes.Buffer{} - pgio.WriteInt32(outputBuf, int32(len(s))) - outputBuf.WriteString(s) - var r pgtype.Bool - err = r.DecodeText(outputBuf) - if err != nil { - t.Errorf("TextFormat %d: %v", i, err) - } - - if r != tt.result { - t.Errorf("TextFormat %d: expected %v, got %v", i, tt.result, r) - } + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, } - ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode - for i, tt := range tests { - inputBuf := &bytes.Buffer{} - err = tt.result.EncodeBinary(inputBuf) - if err != nil { - t.Errorf("BinaryFormat %d: %v", i, err) - } + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode - var buf []byte - err := conn.QueryRow("test", inputBuf.Bytes()[4:]).Scan(&buf) - if err != nil { - t.Errorf("BinaryFormat %d: %v", i, err) - } + for i, tt := range tests { + var r pgtype.Bool + err := conn.QueryRow("test", tt.result).Scan(&r) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } - outputBuf := &bytes.Buffer{} - pgio.WriteInt32(outputBuf, int32(len(buf))) - outputBuf.Write(buf) - var r pgtype.Bool - err = r.DecodeBinary(outputBuf) - if err != nil { - t.Errorf("BinaryFormat %d: %v", i, err) - } - - if r != tt.result { - t.Errorf("BinaryFormat %d: expected %v, got %v", i, tt.result, r) + if r != tt.result { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, tt.result, r) + } } } } @@ -89,12 +56,12 @@ func TestBoolConvertFrom(t *testing.T) { source interface{} result pgtype.Bool }{ - {source: true, result: pgtype.Bool(true)}, - {source: false, result: pgtype.Bool(false)}, - {source: "true", result: pgtype.Bool(true)}, - {source: "false", result: pgtype.Bool(false)}, - {source: "t", result: pgtype.Bool(true)}, - {source: "f", result: pgtype.Bool(false)}, + {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: false, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "false", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "f", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, } for i, tt := range successfulTests { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index a89ea2bb..84e35b21 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1,7 +1,18 @@ package pgtype import ( + "errors" "io" + + "github.com/jackc/pgx/pgio" +) + +type Status byte + +const ( + Undefined Status = iota + Null + Present ) type Value interface { @@ -24,3 +35,16 @@ type BinaryEncoder interface { type TextEncoder interface { EncodeText(w io.Writer) error } + +var errUndefined = errors.New("cannot encode status undefined") + +func encodeNotPresent(w io.Writer, status Status) (done bool, err error) { + switch status { + case Undefined: + return true, errUndefined + case Null: + _, err = pgio.WriteInt32(w, -1) + return true, err + } + return false, nil +} diff --git a/query.go b/query.go index 3bf1a87c..db99cddd 100644 --- a/query.go +++ b/query.go @@ -6,6 +6,8 @@ import ( "fmt" "golang.org/x/net/context" "time" + + "github.com/jackc/pgx/pgtype" ) // Row is a convenience wrapper over Rows that is returned by QueryRow. @@ -228,6 +230,18 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } + } else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode { + vr.err = errRewoundLen + err = s.DecodeBinary(&valueReader2{vr}) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } + } else if s, ok := d.(pgtype.TextDecoder); ok && vr.Type().FormatCode == TextFormatCode { + vr.err = errRewoundLen + err = s.DecodeText(&valueReader2{vr}) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } } else if s, ok := d.(sql.Scanner); ok { var val interface{} if 0 <= vr.Len() { diff --git a/values.go b/values.go index 49daa3b4..0dea734a 100644 --- a/values.go +++ b/values.go @@ -1026,6 +1026,10 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { switch arg := arg.(type) { case Encoder: return arg.Encode(wbuf, oid) + case pgtype.BinaryEncoder: + return arg.EncodeBinary(wbuf) + case pgtype.TextEncoder: + return arg.EncodeText(wbuf) case driver.Valuer: v, err := arg.Value() if err != nil { @@ -1398,21 +1402,11 @@ func Decode(vr *ValueReader, d interface{}) error { } func decodeBool(vr *ValueReader) bool { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into bool")) - return false - } - if vr.Type().DataType != BoolOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType))) return false } - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return false - } - vr.err = errRewoundLen var b pgtype.Bool @@ -1432,7 +1426,12 @@ func decodeBool(vr *ValueReader) bool { return false } - return bool(b) + if b.Status != pgtype.Present { + vr.Fatal(fmt.Errorf("Cannot decode null into bool")) + return false + } + + return b.Bool } func encodeBool(w *WriteBuf, oid OID, value bool) error { @@ -1440,7 +1439,7 @@ func encodeBool(w *WriteBuf, oid OID, value bool) error { return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid) } - b := pgtype.Bool(value) + b := pgtype.Bool{Bool: value, Status: pgtype.Present} return b.EncodeBinary(w) }