diff --git a/msg_reader.go b/msg_reader.go index 2bcd2d51..59617b73 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -137,6 +137,34 @@ func (r *msgReader) readInt32() int32 { return n } +func (r *msgReader) readUint16() uint16 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 2 + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return 0 + } + + b, err := r.reader.Peek(2) + if err != nil { + r.fatal(err) + return 0 + } + + n := uint16(binary.BigEndian.Uint16(b)) + + r.reader.Discard(2) + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + } + + return n +} + func (r *msgReader) readUint32() uint32 { if r.err != nil { return 0 diff --git a/values.go b/values.go index 04c66317..51adfa1c 100644 --- a/values.go +++ b/values.go @@ -344,7 +344,7 @@ func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { // // it is the data type of the xmin and xmax hidden system columns. // -// It is currently implemented as an unsigned for byte integer. +// It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/postgres_ext.h as TransactionId // in the PostgreSQL sources. type Xid uint32 @@ -396,7 +396,7 @@ func (n NullXid) Encode(w *WriteBuf, oid Oid) error { // // it is the data type of the cmin and cmax hidden system columns. // -// It is currently implemented as an unsigned for byte integer. +// It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/c.h as CommandId // in the PostgreSQL sources. type Cid uint32 @@ -440,6 +440,61 @@ func (n NullCid) Encode(w *WriteBuf, oid Oid) error { return encodeCid(w, oid, n.Cid) } +// Tid is PostgreSQL's Tuple Identifier type. +// +// When one does +// +// select ctid, * from some_table; +// +// it is the data type of the ctid hidden system column. +// +// It is currently implemented as a pair unsigned two byte integers. +// Its conversion functions can be found in src/backend/utils/adt/tid.c +// in the PostgreSQL sources. +type Tid struct { + BlockNumber uint16 + OffsetNumber uint16 +} + +// NullTid represents a Tuple Identifier (Tid) that may be null. NullTid implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullTid struct { + Tid Tid + Valid bool // Valid is true if Int32 is not NULL +} + +func (n *NullTid) Scan(vr *ValueReader) error { + if vr.Type().DataType != TidOid { + return SerializationError(fmt.Sprintf("NullTid.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Tid, n.Valid = 0, false + return nil + } + n.Valid = true + n.Tid = decodeTid(vr) + return vr.Err() +} + +func (n NullTid) FormatCode() int16 { return BinaryFormatCode } + +func (n NullTid) Encode(w *WriteBuf, oid Oid) error { + if oid != TidOid { + return SerializationError(fmt.Sprintf("NullTid.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeTid(w, oid, n.Tid) +} + // NullInt64 represents an bigint that may be null. NullInt64 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. @@ -934,6 +989,8 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeOid(vr) case *Xid: *v = decodeXid(vr) + case *Tid: + *v = decodeTid(vr) case *Cid: *v = decodeCid(vr) case *string: @@ -1546,6 +1603,49 @@ func encodeCid(w *WriteBuf, oid Oid, value Cid) error { return nil } +func decodeTid(vr *ValueReader) Tid { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into Tid")) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + + if vr.Type().DataType != TidOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Tid", vr.Type().DataType))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + + // Unlikely Tid will ever go over the wire as text format, but who knows? + switch vr.Type().FormatCode { + case TextFormatCode: // XXX: not done yet src/backend/utils/adt/tid.c for hints; s already contains the string, so we just have to parse out (uint16,uint16) + s := vr.ReadString(vr.Len()) + n, err := strconv.ParseUint(s, 10, 32) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) + } + return Tid(n) + case BinaryFormatCode: + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + return Tid{BlockNumber: vr.ReadUint16(), OffsetNumber: vr.ReadUint16()} + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } +} + +func encodeTid(w *WriteBuf, oid Oid, value Tid) error { + if oid != TidOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Tid", oid) + } + + w.WriteInt32(4) + w.WriteUint32(uint32(value)) + + return nil +} + func decodeFloat4(vr *ValueReader) float32 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into float32"))