From 6694e0e61876db7827791ef1af197847dee9b2e3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 16:48:37 -0600 Subject: [PATCH] Move Tid to pgtype --- conn.go | 3 +- pgtype/pgtype.go | 9 +++- pgtype/tid.go | 104 ++++++++++++++++++++++++++++++++++++ pgtype/tid_test.go | 15 ++++++ query.go | 8 ++- values.go | 130 +++------------------------------------------ values_test.go | 4 -- 7 files changed, 142 insertions(+), 131 deletions(-) create mode 100644 pgtype/tid.go create mode 100644 pgtype/tid_test.go diff --git a/conn.go b/conn.go index f55dd82a..c2cc5d3c 100644 --- a/conn.go +++ b/conn.go @@ -268,8 +268,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.closedChan = make(chan error) c.oidPgtypeValues = map[OID]pgtype.Value{ - ACLItemOID: &pgtype.ACLItem{}, ACLItemArrayOID: &pgtype.ACLItemArray{}, + ACLItemOID: &pgtype.ACLItem{}, BoolArrayOID: &pgtype.BoolArray{}, BoolOID: &pgtype.Bool{}, ByteaArrayOID: &pgtype.ByteaArray{}, @@ -296,6 +296,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl OIDOID: &pgtype.OID{}, TextArrayOID: &pgtype.TextArray{}, TextOID: &pgtype.Text{}, + TIDOID: &pgtype.TID{}, TimestampArrayOID: &pgtype.TimestampArray{}, TimestampOID: &pgtype.Timestamp{}, TimestampTzArrayOID: &pgtype.TimestamptzArray{}, diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d72217ac..8c67c630 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -16,7 +16,7 @@ const ( Int4OID = 23 TextOID = 25 OIDOID = 26 - TidOID = 27 + TIDOID = 27 XIDOID = 28 CIDOID = 29 JSONOID = 114 @@ -66,8 +66,13 @@ const ( NegativeInfinity InfinityModifier = -Infinity ) -type Value interface { +type Value interface{} + +type ConverterFrom interface { ConvertFrom(src interface{}) error +} + +type AssignerTo interface { AssignTo(dst interface{}) error } diff --git a/pgtype/tid.go b/pgtype/tid.go new file mode 100644 index 00000000..804cced2 --- /dev/null +++ b/pgtype/tid.go @@ -0,0 +1,104 @@ +package pgtype + +import ( + "encoding/binary" + "fmt" + "io" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +// TID is PostgreSQL's Tuple Identifier type. +// +// When one does +// +// select ctid, * from some_table; +// +// it is the data type of the ctid hidden system column. +// +// It is currently implemented as a pair unsigned two byte integers. +// Its conversion functions can be found in src/backend/utils/adt/tid.c +// in the PostgreSQL sources. +type TID struct { + BlockNumber uint32 + OffsetNumber uint16 + Status Status +} + +func (dst *TID) DecodeText(src []byte) error { + if src == nil { + *dst = TID{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid format for tid") + } + + blockNumber, err := strconv.ParseUint(parts[0], 10, 32) + if err != nil { + return err + } + + offsetNumber, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return err + } + + *dst = TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} + return nil +} + +func (dst *TID) DecodeBinary(src []byte) error { + if src == nil { + *dst = TID{Status: Null} + return nil + } + + if len(src) != 6 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + *dst = TID{ + BlockNumber: binary.BigEndian.Uint32(src), + OffsetNumber: binary.BigEndian.Uint16(src[4:]), + Status: Present, + } + return nil +} + +func (src TID) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber)) + return false, err +} + +func (src TID) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := pgio.WriteUint32(w, src.BlockNumber) + if err != nil { + return false, err + } + + _, err = pgio.WriteUint16(w, src.OffsetNumber) + return false, err +} diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go new file mode 100644 index 00000000..a5aab8a3 --- /dev/null +++ b/pgtype/tid_test.go @@ -0,0 +1,15 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestTIDTranscode(t *testing.T) { + testSuccessfulTranscode(t, "tid", []interface{}{ + pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, + pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, + pgtype.TID{Status: pgtype.Null}, + }) +} diff --git a/query.go b/query.go index d1191c7c..5730f1c6 100644 --- a/query.go +++ b/query.go @@ -299,8 +299,12 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { vr.Fatal(fmt.Errorf("unknown format code: %v", vr.Type().FormatCode)) } - if err := pgVal.AssignTo(d); err != nil { - vr.Fatal(err) + if assignerTo, ok := pgVal.(pgtype.AssignerTo); ok { + if err := assignerTo.AssignTo(d); err != nil { + vr.Fatal(err) + } + } else { + vr.Fatal(fmt.Errorf("cannot assign %T", pgVal)) } } else { if err := Decode(vr, d); err != nil { diff --git a/values.go b/values.go index f34735d4..72f836bb 100644 --- a/values.go +++ b/values.go @@ -9,7 +9,6 @@ import ( "io" "math" "reflect" - "regexp" "strconv" "strings" "time" @@ -29,7 +28,7 @@ const ( Int4OID = 23 TextOID = 25 OIDOID = 26 - TidOID = 27 + TIDOID = 27 XIDOID = 28 CIDOID = 29 JSONOID = 114 @@ -444,61 +443,6 @@ func (src OID) EncodeBinary(w io.Writer) (bool, error) { return false, err } -// Tid is PostgreSQL's Tuple Identifier type. -// -// When one does -// -// select ctid, * from some_table; -// -// it is the data type of the ctid hidden system column. -// -// It is currently implemented as a pair unsigned two byte integers. -// Its conversion functions can be found in src/backend/utils/adt/tid.c -// in the PostgreSQL sources. -type Tid struct { - BlockNumber uint32 - OffsetNumber uint16 -} - -// NullTid represents a Tuple Identifier (Tid) that may be null. NullTid implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullTid struct { - Tid Tid - Valid bool // Valid is true if Tid is not NULL -} - -func (n *NullTid) Scan(vr *ValueReader) error { - if vr.Type().DataType != TidOID { - return SerializationError(fmt.Sprintf("NullTid.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Tid, n.Valid = Tid{BlockNumber: 0, OffsetNumber: 0}, false - return nil - } - n.Valid = true - n.Tid = decodeTid(vr) - return vr.Err() -} - -func (n NullTid) FormatCode() int16 { return BinaryFormatCode } - -func (n NullTid) Encode(w *WriteBuf, oid OID) error { - if oid != TidOID { - return SerializationError(fmt.Sprintf("NullTid.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeTid(w, oid, n.Tid) -} - // NullInt64 represents an bigint that may be null. NullInt64 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. @@ -836,9 +780,13 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { } if value, ok := wbuf.conn.oidPgtypeValues[oid]; ok { - err := value.ConvertFrom(arg) - if err != nil { - return err + if converterFrom, ok := value.(pgtype.ConverterFrom); ok { + err := converterFrom.ConvertFrom(arg) + if err != nil { + return err + } + } else { + return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } buf := &bytes.Buffer{} @@ -906,8 +854,6 @@ func stripNamedType(val *reflect.Value) (interface{}, bool) { // decoding to the built-in functionality. func Decode(vr *ValueReader, d interface{}) error { switch v := d.(type) { - case *Tid: - *v = decodeTid(vr) case *string: *v = decodeText(vr) case *[]interface{}: @@ -1092,66 +1038,6 @@ func decodeInt4(vr *ValueReader) int32 { return n.Int } -// Note that we do not match negative numbers, because neither the -// BlockNumber nor OffsetNumber of a Tid can be negative. -var tidRegexp *regexp.Regexp = regexp.MustCompile(`^\((\d*),(\d*)\)$`) - -func decodeTid(vr *ValueReader) Tid { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into Tid")) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } - - if vr.Type().DataType != TidOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Tid", vr.Type().DataType))) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } - - // Unlikely Tid will ever go over the wire as text format, but who knows? - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - - match := tidRegexp.FindStringSubmatch(s) - if match == nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid OID: %v", s))) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } - - blockNumber, err := strconv.ParseUint(s, 10, 16) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid BlockNumber part of a Tid: %v", s))) - } - - offsetNumber, err := strconv.ParseUint(s, 10, 16) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid offsetNumber part of a Tid: %v", s))) - } - return Tid{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber)} - case BinaryFormatCode: - if vr.Len() != 6 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an OID: %d", vr.Len()))) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } - return Tid{BlockNumber: vr.ReadUint32(), OffsetNumber: vr.ReadUint16()} - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } -} - -func encodeTid(w *WriteBuf, oid OID, value Tid) error { - if oid != TidOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Tid", oid) - } - - w.WriteInt32(6) - w.WriteUint32(value.BlockNumber) - w.WriteUint16(value.OffsetNumber) - - return nil -} - func decodeFloat4(vr *ValueReader) float32 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into float32")) diff --git a/values_test.go b/values_test.go index 9cf2b219..eb570fe6 100644 --- a/values_test.go +++ b/values_test.go @@ -568,7 +568,6 @@ func TestNullX(t *testing.T) { s pgx.NullString i16 pgx.NullInt16 i32 pgx.NullInt32 - tid pgx.NullTid i64 pgx.NullInt64 f32 pgx.NullFloat32 f64 pgx.NullFloat64 @@ -590,9 +589,6 @@ func TestNullX(t *testing.T) { {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}}, - {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}}, - {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: false}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 0, OffsetNumber: 0}, Valid: false}}}, - {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}}, {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}}, {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}}, {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}},