diff --git a/conn.go b/conn.go index d97942aa..7bb26677 100644 --- a/conn.go +++ b/conn.go @@ -270,6 +270,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.oidPgtypeValues = map[OID]pgtype.Value{ BoolArrayOID: &pgtype.BoolArray{}, BoolOID: &pgtype.Bool{}, + CIDOID: &pgtype.CID{}, CidrArrayOID: &pgtype.CidrArray{}, CidrOID: &pgtype.Inet{}, DateArrayOID: &pgtype.DateArray{}, diff --git a/pgtype/cid.go b/pgtype/cid.go new file mode 100644 index 00000000..9f8c87d8 --- /dev/null +++ b/pgtype/cid.go @@ -0,0 +1,141 @@ +package pgtype + +import ( + "fmt" + "io" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +// CID is PostgreSQL's Command Identifier type. +// +// When one does +// +// select cmin, cmax, * from some_table; +// +// it is the data type of the cmin and cmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/c.h as CommandId +// in the PostgreSQL sources. +type CID struct { + Uint uint32 + Status Status +} + +// ConvertFrom converts from src to dst. Note that as CID is not a general +// number type ConvertFrom does not do automatic type conversion as other number +// types do. +func (dst *CID) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case CID: + *dst = value + case uint32: + *dst = CID{Uint: value, Status: Present} + default: + return fmt.Errorf("cannot convert %v to CID", value) + } + + return nil +} + +// AssignTo assigns from src to dst. Note that as CID is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *CID) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *uint32: + if src.Status == Present { + *v = src.Uint + } else { + return fmt.Errorf("cannot assign %v into %T", src, dst) + } + case **uint32: + if src.Status == Present { + n := src.Uint + *v = &n + } else { + *v = nil + } + } + + return nil +} + +func (dst *CID) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = CID{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseUint(string(buf), 10, 32) + if err != nil { + return err + } + + *dst = CID{Uint: uint32(n), Status: Present} + return nil +} + +func (dst *CID) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = CID{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for cid: %v", size) + } + + n, err := pgio.ReadUint32(r) + if err != nil { + return err + } + + *dst = CID{Uint: n, Status: Present} + return nil +} + +func (src CID) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + s := strconv.FormatUint(uint64(src.Uint), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (src CID) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + _, err = pgio.WriteUint32(w, src.Uint) + return err +} diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go new file mode 100644 index 00000000..72f5dfea --- /dev/null +++ b/pgtype/cid_test.go @@ -0,0 +1,94 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestCIDTranscode(t *testing.T) { + testSuccessfulTranscode(t, "cid", []interface{}{ + pgtype.CID{Uint: 42, Status: pgtype.Present}, + pgtype.CID{Status: pgtype.Null}, + }) +} + +func TestCIDConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.CID + }{ + {source: uint32(1), result: pgtype.CID{Uint: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.CID + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestCIDAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.CID + dst interface{} + expected interface{} + }{ + {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.CID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.CID + dst interface{} + expected interface{} + }{ + {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.CID + dst interface{} + }{ + {src: pgtype.CID{Status: pgtype.Null}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5722c8ab..1200bf12 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -20,7 +20,7 @@ const ( OIDOID = 26 TidOID = 27 XidOID = 28 - CidOID = 29 + CIDOID = 29 JSONOID = 114 CidrOID = 650 CidrArrayOID = 651 diff --git a/values.go b/values.go index f050726e..b143ac1a 100644 --- a/values.go +++ b/values.go @@ -29,7 +29,7 @@ const ( OIDOID = 26 TidOID = 27 XidOID = 28 - CidOID = 29 + CIDOID = 29 JSONOID = 114 CidrOID = 650 CidrArrayOID = 651 @@ -645,58 +645,6 @@ func (n NullXid) Encode(w *WriteBuf, oid OID) error { return encodeXid(w, oid, n.Xid) } -// Cid is PostgreSQL's Command Identifier type. -// -// When one does -// -// select cmin, cmax, * from some_table; -// -// it is the data type of the cmin and cmax hidden system columns. -// -// It is currently implemented as an unsigned four byte integer. -// Its definition can be found in src/include/c.h as CommandId -// in the PostgreSQL sources. -type Cid uint32 - -// NullCid represents a Command Identifier (Cid) that may be null. NullCid implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullCid struct { - Cid Cid - Valid bool // Valid is true if Cid is not NULL -} - -func (n *NullCid) Scan(vr *ValueReader) error { - if vr.Type().DataType != CidOID { - return SerializationError(fmt.Sprintf("NullCid.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Cid, n.Valid = 0, false - return nil - } - n.Valid = true - n.Cid = decodeCid(vr) - return vr.Err() -} - -func (n NullCid) FormatCode() int16 { return BinaryFormatCode } - -func (n NullCid) Encode(w *WriteBuf, oid OID) error { - if oid != CidOID { - return SerializationError(fmt.Sprintf("NullCid.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeCid(w, oid, n.Cid) -} - // Tid is PostgreSQL's Tuple Identifier type. // // When one does @@ -1087,8 +1035,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return encodeOID(wbuf, oid, arg) case Xid: return encodeXid(wbuf, oid, arg) - case Cid: - return encodeCid(wbuf, oid, arg) default: if strippedArg, ok := stripNamedType(&refVal); ok { return Encode(wbuf, oid, strippedArg) @@ -1170,8 +1116,6 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeXid(vr) case *Tid: *v = decodeTid(vr) - case *Cid: - *v = decodeCid(vr) case *string: *v = decodeText(vr) case *[]AclItem: @@ -1493,49 +1437,6 @@ func encodeXid(w *WriteBuf, oid OID, value Xid) error { return nil } -func decodeCid(vr *ValueReader) Cid { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into Cid")) - return Cid(0) - } - - if vr.Type().DataType != CidOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Cid", vr.Type().DataType))) - return Cid(0) - } - - // Unlikely Cid will ever go over the wire as text format, but who knows? - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - n, err := strconv.ParseUint(s, 10, 32) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid OID: %v", s))) - } - return Cid(n) - case BinaryFormatCode: - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an OID: %d", vr.Len()))) - return Cid(0) - } - return Cid(vr.ReadUint32()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return Cid(0) - } -} - -func encodeCid(w *WriteBuf, oid OID, value Cid) error { - if oid != CidOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Cid", oid) - } - - w.WriteInt32(4) - w.WriteUint32(uint32(value)) - - return nil -} - // Note that we do not match negative numbers, because neither the // BlockNumber nor OffsetNumber of a Tid can be negative. var tidRegexp *regexp.Regexp = regexp.MustCompile(`^\((\d*),(\d*)\)$`) diff --git a/values_test.go b/values_test.go index d6ce705a..ae3ecc84 100644 --- a/values_test.go +++ b/values_test.go @@ -573,7 +573,6 @@ func TestNullX(t *testing.T) { n pgx.NullName oid pgx.NullOID xid pgx.NullXid - cid pgx.NullCid tid pgx.NullTid i64 pgx.NullInt64 f32 pgx.NullFloat32 @@ -611,9 +610,6 @@ func TestNullX(t *testing.T) { {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}}, // A tricky (and valid) aclitem can still be used, especially with Go's useful backticks {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}}, - {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, - {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}}, - {"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}}, {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}}, {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: false}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 0, OffsetNumber: 0}, Valid: false}}}, {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}},