diff --git a/conn.go b/conn.go index 0c86d169..3414d7cf 100644 --- a/conn.go +++ b/conn.go @@ -31,6 +31,20 @@ const ( connStatusBusy ) +// minimalConnInfo has just enough static type information to establish the +// connection and retrieve the type data. +var minimalConnInfo *pgtype.ConnInfo + +func init() { + minimalConnInfo = pgtype.NewConnInfo() + minimalConnInfo.InitializeDataTypes(map[string]pgtype.Oid{ + "int4": Int4Oid, + "name": NameOid, + "oid": OidOid, + "text": TextOid, + }) +} + // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(network, addr string) (net.Conn, error) @@ -74,11 +88,10 @@ type Conn struct { lastActivityTime time.Time // the last time the connection was used wbuf [1024]byte writeBuf WriteBuf - pid int32 // backend pid - secretKey int32 // key to use to send a cancel query message to the server - RuntimeParams map[string]string // parameters that have been reported by the server - PgTypes map[pgtype.Oid]PgType // oids to PgTypes - config ConnConfig // config used when establishing this connection + pid int32 // backend pid + secretKey int32 // key to use to send a cancel query message to the server + RuntimeParams map[string]string // parameters that have been reported by the server + config ConnConfig // config used when establishing this connection txStatus byte preparedStatements map[string]*PreparedStatement channels map[string]struct{} @@ -102,7 +115,7 @@ type Conn struct { doneChan chan struct{} closedChan chan error - oidPgtypeValues map[pgtype.Oid]pgtype.Value + ConnInfo *pgtype.ConnInfo } // PreparedStatement is a description of a prepared statement @@ -125,12 +138,6 @@ type Notification struct { Payload string } -// PgType is information about PostgreSQL type and how to encode and decode it -type PgType struct { - Name string // name of type e.g. int4, text, date - DefaultFormat int16 // default format (text or binary) this type will be requested in -} - // CommandTag is the result of an Exec function type CommandTag string @@ -190,20 +197,14 @@ func (e ProtocolError) Error() string { // config.Host must be specified. config.User will default to the OS user name. // Other config fields are optional. func Connect(config ConnConfig) (c *Conn, err error) { - return connect(config, nil) + return connect(config, minimalConnInfo) } -func connect(config ConnConfig, pgTypes map[pgtype.Oid]PgType) (c *Conn, err error) { +func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) { c = new(Conn) c.config = config - - if pgTypes != nil { - c.PgTypes = make(map[pgtype.Oid]PgType, len(pgTypes)) - for k, v := range pgTypes { - c.PgTypes[k] = v - } - } + c.ConnInfo = connInfo if c.config.LogLevel != 0 { c.logLevel = c.config.LogLevel @@ -289,8 +290,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - c.loadStaticOidPgtypeValues() - c.mr.cr = chunkreader.NewChunkReader(c.conn) msg := newStartupMessage() @@ -344,13 +343,12 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl return nil } - if c.PgTypes == nil { - err = c.loadPgTypes() + if c.ConnInfo == minimalConnInfo { + err = c.initConnInfo() if err != nil { return err } } - c.loadDynamicOidPgtypeValues() return nil default: @@ -361,88 +359,37 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } -func (c *Conn) loadPgTypes() error { +func (c *Conn) initConnInfo() error { + nameOids := make(map[string]pgtype.Oid, 256) + rows, err := c.Query(`select t.oid, t.typname from pg_type t left join pg_type base_type on t.typelem=base_type.oid where ( - t.typtype='b' - and (base_type.oid is null or base_type.typtype='b') - ) - or t.typname in('record');`) + t.typtype in('b', 'p') + and (base_type.oid is null or base_type.typtype in('b', 'p')) + )`) if err != nil { return err } - c.PgTypes = make(map[pgtype.Oid]PgType, 128) - for rows.Next() { - var oid uint32 - var t PgType + var oid pgtype.Oid + var name pgtype.Text + if err := rows.Scan(&oid, &name); err != nil { + return err + } - rows.Scan(&oid, &t.Name) - - // The zero value is text format so we ignore any types without a default type format - t.DefaultFormat, _ = DefaultTypeFormats[t.Name] - - c.PgTypes[pgtype.Oid(oid)] = t + nameOids[name.String] = oid } - return rows.Err() -} - -func (c *Conn) loadStaticOidPgtypeValues() { - c.oidPgtypeValues = map[pgtype.Oid]pgtype.Value{ - AclitemArrayOid: &pgtype.AclitemArray{}, - AclitemOid: &pgtype.Aclitem{}, - BoolArrayOid: &pgtype.BoolArray{}, - BoolOid: &pgtype.Bool{}, - ByteaArrayOid: &pgtype.ByteaArray{}, - ByteaOid: &pgtype.Bytea{}, - CharOid: &pgtype.QChar{}, - CidOid: &pgtype.Cid{}, - CidrArrayOid: &pgtype.CidrArray{}, - CidrOid: &pgtype.Inet{}, - DateArrayOid: &pgtype.DateArray{}, - DateOid: &pgtype.Date{}, - Float4ArrayOid: &pgtype.Float4Array{}, - Float4Oid: &pgtype.Float4{}, - Float8ArrayOid: &pgtype.Float8Array{}, - Float8Oid: &pgtype.Float8{}, - InetArrayOid: &pgtype.InetArray{}, - InetOid: &pgtype.Inet{}, - Int2ArrayOid: &pgtype.Int2Array{}, - Int2Oid: &pgtype.Int2{}, - Int4ArrayOid: &pgtype.Int4Array{}, - Int4Oid: &pgtype.Int4{}, - Int8ArrayOid: &pgtype.Int8Array{}, - Int8Oid: &pgtype.Int8{}, - JsonbOid: &pgtype.Jsonb{}, - JsonOid: &pgtype.Json{}, - NameOid: &pgtype.Name{}, - OidOid: &pgtype.OidValue{}, - TextArrayOid: &pgtype.TextArray{}, - TextOid: &pgtype.Text{}, - TidOid: &pgtype.Tid{}, - TimestampArrayOid: &pgtype.TimestampArray{}, - TimestampOid: &pgtype.Timestamp{}, - TimestampTzArrayOid: &pgtype.TimestamptzArray{}, - TimestampTzOid: &pgtype.Timestamptz{}, - VarcharArrayOid: &pgtype.VarcharArray{}, - VarcharOid: &pgtype.Text{}, - XidOid: &pgtype.Xid{}, - } -} - -func (c *Conn) loadDynamicOidPgtypeValues() { - nameOids := make(map[string]pgtype.Oid, len(c.PgTypes)) - for k, v := range c.PgTypes { - nameOids[v.Name] = k + if rows.Err() != nil { + return rows.Err() } - if oid, ok := nameOids["hstore"]; ok { - c.oidPgtypeValues[oid] = &pgtype.Hstore{} - } + c.ConnInfo = pgtype.NewConnInfo() + c.ConnInfo.InitializeDataTypes(nameOids) + return nil } // PID returns the backend PID for this connection. @@ -805,9 +752,16 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared case rowDescription: ps.FieldDescriptions = c.rxRowDescription(r) for i := range ps.FieldDescriptions { - t, _ := c.PgTypes[ps.FieldDescriptions[i].DataType] - ps.FieldDescriptions[i].DataTypeName = t.Name - ps.FieldDescriptions[i].FormatCode = t.DefaultFormat + if dt, ok := c.ConnInfo.DataTypeForOid(ps.FieldDescriptions[i].DataType); ok { + ps.FieldDescriptions[i].DataTypeName = dt.Name + if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { + ps.FieldDescriptions[i].FormatCode = BinaryFormatCode + } else { + ps.FieldDescriptions[i].FormatCode = TextFormatCode + } + } else { + return nil, fmt.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType) + } } case readyForQuery: c.rxReadyForQuery(r) diff --git a/conn_pool.go b/conn_pool.go index 653ed0ba..44559ea8 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -30,7 +30,7 @@ type ConnPool struct { closed bool preparedStatements map[string]*PreparedStatement acquireTimeout time.Duration - pgTypes map[pgtype.Oid]PgType + connInfo *pgtype.ConnInfo txAfterClose func(tx *Tx) rowsAfterClose func(rows *Rows) } @@ -49,6 +49,7 @@ var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool") func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { p = new(ConnPool) p.config = config.ConnConfig + p.connInfo = minimalConnInfo p.maxConnections = config.MaxConnections if p.maxConnections == 0 { p.maxConnections = 5 @@ -95,6 +96,7 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { } p.allConnections = append(p.allConnections, c) p.availableConnections = append(p.availableConnections, c) + p.connInfo = c.ConnInfo.DeepCopy() return } @@ -294,7 +296,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) { } func (p *ConnPool) createConnection() (*Conn, error) { - c, err := connect(p.config, p.pgTypes) + c, err := connect(p.config, p.connInfo) if err != nil { return nil, err } @@ -329,8 +331,6 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { // afterConnectionCreated executes (if it is) afterConnect() callback and prepares // all the known statements for the new connection. func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { - p.pgTypes = c.PgTypes - if p.afterConnect != nil { err := p.afterConnect(c) if err != nil { diff --git a/copy_from_test.go b/copy_from_test.go index e17575de..6df4ebb1 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" ) func TestConnCopyFromSmall(t *testing.T) { @@ -126,8 +125,8 @@ func TestConnCopyFromJSON(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - for _, oid := range []pgtype.Oid{pgx.JsonOid, pgx.JsonbOid} { - if _, ok := conn.PgTypes[oid]; !ok { + for _, typeName := range []string{"json", "jsonb"} { + if _, ok := conn.ConnInfo.DataTypeForName(typeName); !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } } diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 71110f85..1c21c7e6 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -18,7 +18,7 @@ type Point struct { Status pgtype.Status } -func (dst *Point) DecodeText(src []byte) error { +func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { if src == nil { *dst = Point{Status: pgtype.Null} return nil @@ -44,7 +44,7 @@ func (dst *Point) DecodeText(src []byte) error { return nil } -func (src Point) EncodeText(w io.Writer) (bool, error) { +func (src Point) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { switch src.Status { case pgtype.Null: return true, nil diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index b8a1549e..f9faab20 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -90,7 +90,7 @@ func (src *Aclitem) AssignTo(dst interface{}) error { return nil } -func (dst *Aclitem) DecodeText(src []byte) error { +func (dst *Aclitem) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Aclitem{Status: Null} return nil @@ -100,7 +100,7 @@ func (dst *Aclitem) DecodeText(src []byte) error { return nil } -func (src Aclitem) EncodeText(w io.Writer) (bool, error) { +func (src Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index 5e3647b7..f02d339e 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -82,7 +82,7 @@ func (src *AclitemArray) AssignTo(dst interface{}) error { return nil } -func (dst *AclitemArray) DecodeText(src []byte) error { +func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = AclitemArray{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *AclitemArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -118,7 +118,7 @@ func (dst *AclitemArray) DecodeText(src []byte) error { return nil } -func (src *AclitemArray) EncodeText(w io.Writer) (bool, error) { +func (src *AclitemArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -165,7 +165,7 @@ func (src *AclitemArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/array.go b/pgtype/array.go index dff0fe81..9561afe5 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -27,7 +27,7 @@ type ArrayDimension struct { LowerBound int32 } -func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { +func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { if len(src) < 12 { return 0, fmt.Errorf("array header too short: %d", len(src)) } @@ -60,7 +60,7 @@ func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { return rp, nil } -func (src *ArrayHeader) EncodeBinary(w io.Writer) error { +func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, w io.Writer) error { _, err := pgio.WriteInt32(w, int32(len(src.Dimensions))) if err != nil { return err diff --git a/pgtype/bool.go b/pgtype/bool.go index a8e9b8e1..87316381 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -79,7 +79,7 @@ func (src *Bool) AssignTo(dst interface{}) error { return nil } -func (dst *Bool) DecodeText(src []byte) error { +func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bool{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *Bool) DecodeText(src []byte) error { return nil } -func (dst *Bool) DecodeBinary(src []byte) error { +func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bool{Status: Null} return nil @@ -107,7 +107,7 @@ func (dst *Bool) DecodeBinary(src []byte) error { return nil } -func (src Bool) EncodeText(w io.Writer) (bool, error) { +func (src Bool) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -126,7 +126,7 @@ func (src Bool) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Bool) EncodeBinary(w io.Writer) (bool, error) { +func (src Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index 4c5fc563..1cb46cf6 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -83,7 +83,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return nil } -func (dst *BoolArray) DecodeText(src []byte) error { +func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = BoolArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *BoolArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *BoolArray) DecodeText(src []byte) error { return nil } -func (dst *BoolArray) DecodeBinary(src []byte) error { +func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = BoolArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *BoolArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *BoolArray) DecodeBinary(src []byte) error { return nil } -func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { +func (src *BoolArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *BoolArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, BoolOid) +func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, BoolOid) } -func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *BoolArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 5df05360..dc1e9c07 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -78,7 +78,7 @@ func (src *Bytea) AssignTo(dst interface{}) error { // DecodeText only supports the hex format. This has been the default since // PostgreSQL 9.0. -func (dst *Bytea) DecodeText(src []byte) error { +func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bytea{Status: Null} return nil @@ -98,7 +98,7 @@ func (dst *Bytea) DecodeText(src []byte) error { return nil } -func (dst *Bytea) DecodeBinary(src []byte) error { +func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bytea{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Bytea) DecodeBinary(src []byte) error { return nil } -func (src Bytea) EncodeText(w io.Writer) (bool, error) { +func (src Bytea) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -128,7 +128,7 @@ func (src Bytea) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Bytea) EncodeBinary(w io.Writer) (bool, error) { +func (src Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index c6f676a4..30405509 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -83,7 +83,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { return nil } -func (dst *ByteaArray) DecodeText(src []byte) error { +func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = ByteaArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *ByteaArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *ByteaArray) DecodeText(src []byte) error { return nil } -func (dst *ByteaArray) DecodeBinary(src []byte) error { +func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = ByteaArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *ByteaArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *ByteaArray) DecodeBinary(src []byte) error { return nil } -func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { +func (src *ByteaArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *ByteaArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, ByteaOid) +func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, ByteaOid) } -func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *ByteaArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/cid.go b/pgtype/cid.go index 20957f36..d86e8063 100644 --- a/pgtype/cid.go +++ b/pgtype/cid.go @@ -34,18 +34,18 @@ func (src *Cid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *Cid) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *Cid) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *Cid) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *Cid) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src Cid) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src Cid) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) } diff --git a/pgtype/cidr.go b/pgtype/cidr.go new file mode 100644 index 00000000..463b279d --- /dev/null +++ b/pgtype/cidr.go @@ -0,0 +1,35 @@ +package pgtype + +import ( + "io" +) + +type Cidr Inet + +func (dst *Cidr) Set(src interface{}) error { + return (*Inet)(dst).Set(src) +} + +func (dst *Cidr) Get() interface{} { + return (*Inet)(dst).Get() +} + +func (src *Cidr) AssignTo(dst interface{}) error { + return (*Inet)(src).AssignTo(dst) +} + +func (dst *Cidr) DecodeText(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeText(ci, src) +} + +func (dst *Cidr) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeBinary(ci, src) +} + +func (src Cidr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Inet)(src).EncodeText(ci, w) +} + +func (src Cidr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Inet)(src).EncodeBinary(ci, w) +} diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index c30c53d3..32d2e7bf 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -1,35 +1,328 @@ package pgtype import ( + "bytes" + "encoding/binary" + "fmt" "io" + "net" + + "github.com/jackc/pgx/pgio" ) -type CidrArray InetArray +type CidrArray struct { + Elements []Cidr + Dimensions []ArrayDimension + Status Status +} func (dst *CidrArray) Set(src interface{}) error { - return (*InetArray)(dst).Set(src) + switch value := src.(type) { + + case []*net.IPNet: + if value == nil { + *dst = CidrArray{Status: Null} + } else if len(value) == 0 { + *dst = CidrArray{Status: Present} + } else { + elements := make([]Cidr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CidrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []net.IP: + if value == nil { + *dst = CidrArray{Status: Null} + } else if len(value) == 0 { + *dst = CidrArray{Status: Present} + } else { + elements := make([]Cidr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CidrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Cidr", value) + } + + return nil } func (dst *CidrArray) Get() interface{} { - return (*InetArray)(dst).Get() + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } } func (src *CidrArray) AssignTo(dst interface{}) error { - return (*InetArray)(src).AssignTo(dst) + switch v := dst.(type) { + + case *[]*net.IPNet: + if src.Status == Present { + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + case *[]net.IP: + if src.Status == Present { + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil } -func (dst *CidrArray) DecodeText(src []byte) error { - return (*InetArray)(dst).DecodeText(src) +func (dst *CidrArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CidrArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Cidr + + if len(uta.Elements) > 0 { + elements = make([]Cidr, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Cidr + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = CidrArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil } -func (dst *CidrArray) DecodeBinary(src []byte) error { - return (*InetArray)(dst).DecodeBinary(src) +func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CidrArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = CidrArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Cidr, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = CidrArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil } -func (src *CidrArray) EncodeText(w io.Writer) (bool, error) { - return (*InetArray)(src).EncodeText(w) +func (src *CidrArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil } -func (src *CidrArray) EncodeBinary(w io.Writer) (bool, error) { - return (*InetArray)(src).encodeBinary(w, CidrOid) +func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, CidrOid) +} + +func (src *CidrArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + ElementOid: elementOid, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err } diff --git a/pgtype/cidr_array_test.go b/pgtype/cidr_array_test.go new file mode 100644 index 00000000..ec105914 --- /dev/null +++ b/pgtype/cidr_array_test.go @@ -0,0 +1,164 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestCidrArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "cidr[]", []interface{}{ + &pgtype.CidrArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.CidrArray{Status: pgtype.Null}, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Cidr{Status: pgtype.Null}, + pgtype.Cidr{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestCidrArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.CidrArray + }{ + { + source: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + result: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]*net.IPNet)(nil)), + result: pgtype.CidrArray{Status: pgtype.Null}, + }, + { + source: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + result: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]net.IP)(nil)), + result: pgtype.CidrArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.CidrArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestCidrArrayAssignTo(t *testing.T) { + var ipnetSlice []*net.IPNet + var ipSlice []net.IP + + simpleTests := []struct { + src pgtype.CidrArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{nil}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{nil}, + }, + { + src: pgtype.CidrArray{Status: pgtype.Null}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, + { + src: pgtype.CidrArray{Status: pgtype.Null}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/database_sql.go b/pgtype/database_sql.go new file mode 100644 index 00000000..969d6542 --- /dev/null +++ b/pgtype/database_sql.go @@ -0,0 +1,66 @@ +package pgtype + +import ( + "bytes" + "errors" +) + +func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { + switch src := src.(type) { + case *Bool: + return src.Bool, nil + case *Bytea: + return src.Bytes, nil + case *Date: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Float4: + return float64(src.Float), nil + case *Float8: + return src.Float, nil + case *GenericBinary: + return src.Bytes, nil + case *GenericText: + return src.String, nil + case *Int2: + return int64(src.Int), nil + case *Int4: + return int64(src.Int), nil + case *Int8: + return int64(src.Int), nil + case *Text: + return src.String, nil + case *Timestamp: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Timestamptz: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Unknown: + return src.String, nil + case *Varchar: + return src.String, nil + } + + buf := &bytes.Buffer{} + if textEncoder, ok := src.(TextEncoder); ok { + _, err := textEncoder.EncodeText(ci, buf) + if err != nil { + return nil, err + } + return buf.String(), nil + } + + if binaryEncoder, ok := src.(BinaryEncoder); ok { + _, err := binaryEncoder.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + return buf.Bytes(), nil + } + + return nil, errors.New("cannot convert to database/sql compatible value") +} diff --git a/pgtype/date.go b/pgtype/date.go index d0481637..b6cc8329 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -38,6 +38,9 @@ func (dst *Date) Set(src interface{}) error { func (dst *Date) Get() interface{} { switch dst.Status { case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } return dst.Time case Null: return nil @@ -76,7 +79,7 @@ func (src *Date) AssignTo(dst interface{}) error { return nil } -func (dst *Date) DecodeText(src []byte) error { +func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Date{Status: Null} return nil @@ -100,7 +103,7 @@ func (dst *Date) DecodeText(src []byte) error { return nil } -func (dst *Date) DecodeBinary(src []byte) error { +func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Date{Status: Null} return nil @@ -125,7 +128,7 @@ func (dst *Date) DecodeBinary(src []byte) error { return nil } -func (src Date) EncodeText(w io.Writer) (bool, error) { +func (src Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -148,7 +151,7 @@ func (src Date) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Date) EncodeBinary(w io.Writer) (bool, error) { +func (src Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/date_array.go b/pgtype/date_array.go index 7f602d83..ba68d561 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -84,7 +84,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { return nil } -func (dst *DateArray) DecodeText(src []byte) error { +func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = DateArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *DateArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *DateArray) DecodeText(src []byte) error { return nil } -func (dst *DateArray) DecodeBinary(src []byte) error { +func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = DateArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *DateArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *DateArray) DecodeBinary(src []byte) error { return nil } -func (src *DateArray) EncodeText(w io.Writer) (bool, error) { +func (src *DateArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *DateArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, DateOid) +func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, DateOid) } -func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *DateArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/float4.go b/pgtype/float4.go index 053af44b..94b7b7a1 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -102,7 +102,7 @@ func (src *Float4) AssignTo(dst interface{}) error { return float64AssignTo(float64(src.Float), src.Status, dst) } -func (dst *Float4) DecodeText(src []byte) error { +func (dst *Float4) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4{Status: Null} return nil @@ -117,7 +117,7 @@ func (dst *Float4) DecodeText(src []byte) error { return nil } -func (dst *Float4) DecodeBinary(src []byte) error { +func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4{Status: Null} return nil @@ -133,7 +133,7 @@ func (dst *Float4) DecodeBinary(src []byte) error { return nil } -func (src Float4) EncodeText(w io.Writer) (bool, error) { +func (src Float4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -145,7 +145,7 @@ func (src Float4) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Float4) EncodeBinary(w io.Writer) (bool, error) { +func (src Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index 0e815e0b..40152bcf 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -83,7 +83,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float4Array) DecodeText(src []byte) error { +func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4Array{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *Float4Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *Float4Array) DecodeText(src []byte) error { return nil } -func (dst *Float4Array) DecodeBinary(src []byte) error { +func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *Float4Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *Float4Array) DecodeBinary(src []byte) error { return nil } -func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { +func (src *Float4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Float4Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Float4Oid) +func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Float4Oid) } -func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Float4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/float8.go b/pgtype/float8.go index 635b7a09..dd2d592d 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -92,7 +92,7 @@ func (src *Float8) AssignTo(dst interface{}) error { return float64AssignTo(src.Float, src.Status, dst) } -func (dst *Float8) DecodeText(src []byte) error { +func (dst *Float8) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8{Status: Null} return nil @@ -107,7 +107,7 @@ func (dst *Float8) DecodeText(src []byte) error { return nil } -func (dst *Float8) DecodeBinary(src []byte) error { +func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8{Status: Null} return nil @@ -123,7 +123,7 @@ func (dst *Float8) DecodeBinary(src []byte) error { return nil } -func (src Float8) EncodeText(w io.Writer) (bool, error) { +func (src Float8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -135,7 +135,7 @@ func (src Float8) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Float8) EncodeBinary(w io.Writer) (bool, error) { +func (src Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index 811c5a1f..d0ee0d70 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -83,7 +83,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float8Array) DecodeText(src []byte) error { +func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8Array{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *Float8Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *Float8Array) DecodeText(src []byte) error { return nil } -func (dst *Float8Array) DecodeBinary(src []byte) error { +func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *Float8Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *Float8Array) DecodeBinary(src []byte) error { return nil } -func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { +func (src *Float8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Float8Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Float8Oid) +func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Float8Oid) } -func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Float8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/generic_binary.go b/pgtype/generic_binary.go index ac35ea60..aa28bb62 100644 --- a/pgtype/generic_binary.go +++ b/pgtype/generic_binary.go @@ -20,10 +20,10 @@ func (src *GenericBinary) AssignTo(dst interface{}) error { return (*Bytea)(src).AssignTo(dst) } -func (dst *GenericBinary) DecodeBinary(src []byte) error { - return (*Bytea)(dst).DecodeBinary(src) +func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Bytea)(dst).DecodeBinary(ci, src) } -func (src GenericBinary) EncodeBinary(w io.Writer) (bool, error) { - return (Bytea)(src).EncodeBinary(w) +func (src GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Bytea)(src).EncodeBinary(ci, w) } diff --git a/pgtype/generic_text.go b/pgtype/generic_text.go index 19f41059..bd75e0d0 100644 --- a/pgtype/generic_text.go +++ b/pgtype/generic_text.go @@ -20,10 +20,10 @@ func (src *GenericText) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } -func (dst *GenericText) DecodeText(src []byte) error { - return (*Text)(dst).DecodeText(src) +func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) } -func (src GenericText) EncodeText(w io.Writer) (bool, error) { - return (Text)(src).EncodeText(w) +func (src GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) } diff --git a/pgtype/hstore.go b/pgtype/hstore.go index c48ae6da..d771d6e6 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -70,7 +70,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { return nil } -func (dst *Hstore) DecodeText(src []byte) error { +func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Hstore{Status: Null} return nil @@ -90,7 +90,7 @@ func (dst *Hstore) DecodeText(src []byte) error { return nil } -func (dst *Hstore) DecodeBinary(src []byte) error { +func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Hstore{Status: Null} return nil @@ -132,7 +132,7 @@ func (dst *Hstore) DecodeBinary(src []byte) error { rp += valueLen var value Text - err := value.DecodeBinary(valueBuf) + err := value.DecodeBinary(ci, valueBuf) if err != nil { return err } @@ -144,7 +144,7 @@ func (dst *Hstore) DecodeBinary(src []byte) error { return nil } -func (src Hstore) EncodeText(w io.Writer) (bool, error) { +func (src Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -175,7 +175,7 @@ func (src Hstore) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := v.EncodeText(elemBuf) + null, err := v.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -196,7 +196,7 @@ func (src Hstore) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src Hstore) EncodeBinary(w io.Writer) (bool, error) { +func (src Hstore) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -220,7 +220,7 @@ func (src Hstore) EncodeBinary(w io.Writer) (bool, error) { return false, err } - null, err := v.EncodeText(elemBuf) + null, err := v.EncodeText(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/inet.go b/pgtype/inet.go index 87d675f9..b83bd1c9 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -100,7 +100,7 @@ func (src *Inet) AssignTo(dst interface{}) error { return nil } -func (dst *Inet) DecodeText(src []byte) error { +func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Inet{Status: Null} return nil @@ -128,7 +128,7 @@ func (dst *Inet) DecodeText(src []byte) error { return nil } -func (dst *Inet) DecodeBinary(src []byte) error { +func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Inet{Status: Null} return nil @@ -153,7 +153,7 @@ func (dst *Inet) DecodeBinary(src []byte) error { return nil } -func (src Inet) EncodeText(w io.Writer) (bool, error) { +func (src Inet) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Inet) EncodeText(w io.Writer) (bool, error) { } // EncodeBinary encodes src into w. -func (src Inet) EncodeBinary(w io.Writer) (bool, error) { +func (src Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index 1d1cf3fd..6cad82e7 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -115,7 +115,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { return nil } -func (dst *InetArray) DecodeText(src []byte) error { +func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = InetArray{Status: Null} return nil @@ -137,7 +137,7 @@ func (dst *InetArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -151,14 +151,14 @@ func (dst *InetArray) DecodeText(src []byte) error { return nil } -func (dst *InetArray) DecodeBinary(src []byte) error { +func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = InetArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -183,7 +183,7 @@ func (dst *InetArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -193,7 +193,7 @@ func (dst *InetArray) DecodeBinary(src []byte) error { return nil } -func (src *InetArray) EncodeText(w io.Writer) (bool, error) { +func (src *InetArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -240,7 +240,7 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -269,11 +269,11 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *InetArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, InetOid) +func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, InetOid) } -func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *InetArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -293,7 +293,7 @@ func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -303,7 +303,7 @@ func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/int2.go b/pgtype/int2.go index 62e1bc69..6996cd4f 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -98,7 +98,7 @@ func (src *Int2) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int2) DecodeText(src []byte) error { +func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2{Status: Null} return nil @@ -113,7 +113,7 @@ func (dst *Int2) DecodeText(src []byte) error { return nil } -func (dst *Int2) DecodeBinary(src []byte) error { +func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2{Status: Null} return nil @@ -128,7 +128,7 @@ func (dst *Int2) DecodeBinary(src []byte) error { return nil } -func (src Int2) EncodeText(w io.Writer) (bool, error) { +func (src Int2) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -140,7 +140,7 @@ func (src Int2) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int2) EncodeBinary(w io.Writer) (bool, error) { +func (src Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index 3d06c018..2bf1c237 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -114,7 +114,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int2Array) DecodeText(src []byte) error { +func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int2Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int2Array) DecodeText(src []byte) error { return nil } -func (dst *Int2Array) DecodeBinary(src []byte) error { +func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int2Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int2Array) DecodeBinary(src []byte) error { return nil } -func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int2Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int2Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int2Oid) +func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int2Oid) } -func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int2Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/int4.go b/pgtype/int4.go index 8eaf5094..62ee366f 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -89,7 +89,7 @@ func (src *Int4) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int4) DecodeText(src []byte) error { +func (dst *Int4) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *Int4) DecodeText(src []byte) error { return nil } -func (dst *Int4) DecodeBinary(src []byte) error { +func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4{Status: Null} return nil @@ -119,7 +119,7 @@ func (dst *Int4) DecodeBinary(src []byte) error { return nil } -func (src Int4) EncodeText(w io.Writer) (bool, error) { +func (src Int4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -131,7 +131,7 @@ func (src Int4) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int4) EncodeBinary(w io.Writer) (bool, error) { +func (src Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index 5cd91c04..dda88eaf 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -114,7 +114,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int4Array) DecodeText(src []byte) error { +func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int4Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int4Array) DecodeText(src []byte) error { return nil } -func (dst *Int4Array) DecodeBinary(src []byte) error { +func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int4Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int4Array) DecodeBinary(src []byte) error { return nil } -func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int4Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int4Oid) +func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int4Oid) } -func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/int8.go b/pgtype/int8.go index 2416500d..7ed54f8e 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -80,7 +80,7 @@ func (src *Int8) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int8) DecodeText(src []byte) error { +func (dst *Int8) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8{Status: Null} return nil @@ -95,7 +95,7 @@ func (dst *Int8) DecodeText(src []byte) error { return nil } -func (dst *Int8) DecodeBinary(src []byte) error { +func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Int8) DecodeBinary(src []byte) error { return nil } -func (src Int8) EncodeText(w io.Writer) (bool, error) { +func (src Int8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -123,7 +123,7 @@ func (src Int8) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int8) EncodeBinary(w io.Writer) (bool, error) { +func (src Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 5efc0f45..468c126b 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -114,7 +114,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int8Array) DecodeText(src []byte) error { +func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int8Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int8Array) DecodeText(src []byte) error { return nil } -func (dst *Int8Array) DecodeBinary(src []byte) error { +func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int8Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int8Array) DecodeBinary(src []byte) error { return nil } -func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int8Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int8Oid) +func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int8Oid) } -func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/json.go b/pgtype/json.go index ecdb3dab..bfffae14 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -84,7 +84,7 @@ func (src *Json) AssignTo(dst interface{}) error { return nil } -func (dst *Json) DecodeText(src []byte) error { +func (dst *Json) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Json{Status: Null} return nil @@ -97,11 +97,11 @@ func (dst *Json) DecodeText(src []byte) error { return nil } -func (dst *Json) DecodeBinary(src []byte) error { - return dst.DecodeText(src) +func (dst *Json) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) } -func (src Json) EncodeText(w io.Writer) (bool, error) { +func (src Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -113,6 +113,6 @@ func (src Json) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Json) EncodeBinary(w io.Writer) (bool, error) { - return src.EncodeText(w) +func (src Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.EncodeText(ci, w) } diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 13062e8e..e44f3c41 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -19,11 +19,11 @@ func (src *Jsonb) AssignTo(dst interface{}) error { return (*Json)(src).AssignTo(dst) } -func (dst *Jsonb) DecodeText(src []byte) error { - return (*Json)(dst).DecodeText(src) +func (dst *Jsonb) DecodeText(ci *ConnInfo, src []byte) error { + return (*Json)(dst).DecodeText(ci, src) } -func (dst *Jsonb) DecodeBinary(src []byte) error { +func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Jsonb{Status: Null} return nil @@ -46,11 +46,11 @@ func (dst *Jsonb) DecodeBinary(src []byte) error { } -func (src Jsonb) EncodeText(w io.Writer) (bool, error) { - return (Json)(src).EncodeText(w) +func (src Jsonb) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Json)(src).EncodeText(ci, w) } -func (src Jsonb) EncodeBinary(w io.Writer) (bool, error) { +func (src Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/name.go b/pgtype/name.go index 9eb12ece..9ebf63d3 100644 --- a/pgtype/name.go +++ b/pgtype/name.go @@ -31,18 +31,18 @@ func (src *Name) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } -func (dst *Name) DecodeText(src []byte) error { - return (*Text)(dst).DecodeText(src) +func (dst *Name) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) } -func (dst *Name) DecodeBinary(src []byte) error { - return (*Text)(dst).DecodeBinary(src) +func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) } -func (src Name) EncodeText(w io.Writer) (bool, error) { - return (Text)(src).EncodeText(w) +func (src Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) } -func (src Name) EncodeBinary(w io.Writer) (bool, error) { - return (Text)(src).EncodeBinary(w) +func (src Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeBinary(ci, w) } diff --git a/pgtype/oid.go b/pgtype/oid.go index eab1fbcb..3edd7f3c 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -18,7 +18,7 @@ import ( // allow for NULL Oids use OidValue. type Oid uint32 -func (dst *Oid) DecodeText(src []byte) error { +func (dst *Oid) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { return fmt.Errorf("cannot decode nil into Oid") } @@ -32,7 +32,7 @@ func (dst *Oid) DecodeText(src []byte) error { return nil } -func (dst *Oid) DecodeBinary(src []byte) error { +func (dst *Oid) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { return fmt.Errorf("cannot decode nil into Oid") } @@ -46,12 +46,12 @@ func (dst *Oid) DecodeBinary(src []byte) error { return nil } -func (src Oid) EncodeText(w io.Writer) (bool, error) { +func (src Oid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { _, err := io.WriteString(w, strconv.FormatUint(uint64(src), 10)) return false, err } -func (src Oid) EncodeBinary(w io.Writer) (bool, error) { +func (src Oid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteUint32(w, uint32(src)) return false, err } diff --git a/pgtype/oid_value.go b/pgtype/oid_value.go index a2b2dcbe..1bce6e11 100644 --- a/pgtype/oid_value.go +++ b/pgtype/oid_value.go @@ -28,18 +28,18 @@ func (src *OidValue) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *OidValue) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *OidValue) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *OidValue) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *OidValue) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src OidValue) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src OidValue) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 7b1470b7..674c0db7 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -3,6 +3,7 @@ package pgtype import ( "errors" "io" + "reflect" ) // PostgreSQL oids for common types @@ -83,14 +84,14 @@ type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the // original SQL value is NULL. BinaryDecoder MUST not retain a reference to // src. It MUST make a copy if it needs to retain the raw bytes. - DecodeBinary(src []byte) error + DecodeBinary(ci *ConnInfo, src []byte) error } type TextDecoder interface { // DecodeText decodes src into TextDecoder. If src is nil then the original // SQL value is NULL. TextDecoder MUST not retain a reference to src. It MUST // make a copy if it needs to retain the raw bytes. - DecodeText(src []byte) error + DecodeText(ci *ConnInfo, src []byte) error } // BinaryEncoder is implemented by types that can encode themselves into the @@ -100,7 +101,7 @@ type BinaryEncoder interface { // SQL value NULL then write nothing and return (true, nil). The caller of // EncodeBinary is responsible for writing the correct NULL value or the // length of the data written. - EncodeBinary(w io.Writer) (null bool, err error) + EncodeBinary(ci *ConnInfo, w io.Writer) (null bool, err error) } // TextEncoder is implemented by types that can encode themselves into the @@ -110,7 +111,127 @@ type TextEncoder interface { // value NULL then write nothing and return (true, nil). The caller of // EncodeText is responsible for writing the correct NULL value or the length // of the data written. - EncodeText(w io.Writer) (null bool, err error) + EncodeText(ci *ConnInfo, w io.Writer) (null bool, err error) } var errUndefined = errors.New("cannot encode status undefined") + +type DataType struct { + Value Value + Name string + Oid Oid +} + +type ConnInfo struct { + oidToDataType map[Oid]*DataType + nameToDataType map[string]*DataType + reflectTypeToDataType map[reflect.Type]*DataType +} + +func NewConnInfo() *ConnInfo { + return &ConnInfo{ + oidToDataType: make(map[Oid]*DataType, 256), + nameToDataType: make(map[string]*DataType, 256), + reflectTypeToDataType: make(map[reflect.Type]*DataType, 256), + } +} + +func (ci *ConnInfo) InitializeDataTypes(nameOids map[string]Oid) { + for name, oid := range nameOids { + var value Value + if t, ok := nameValues[name]; ok { + value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value) + } else { + value = &GenericText{} + } + ci.RegisterDataType(DataType{Value: value, Name: name, Oid: oid}) + } +} + +func (ci *ConnInfo) RegisterDataType(t DataType) { + ci.oidToDataType[t.Oid] = &t + ci.nameToDataType[t.Name] = &t + ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t +} + +func (ci *ConnInfo) DataTypeForOid(oid Oid) (*DataType, bool) { + dt, ok := ci.oidToDataType[oid] + return dt, ok +} + +func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) { + dt, ok := ci.nameToDataType[name] + return dt, ok +} + +func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) { + dt, ok := ci.reflectTypeToDataType[reflect.ValueOf(v).Type()] + return dt, ok +} + +// DeepCopy makes a deep copy of the ConnInfo. +func (ci *ConnInfo) DeepCopy() *ConnInfo { + ci2 := &ConnInfo{ + oidToDataType: make(map[Oid]*DataType, len(ci.oidToDataType)), + nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)), + reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)), + } + + for _, dt := range ci.oidToDataType { + ci2.RegisterDataType(DataType{ + Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value), + Name: dt.Name, + Oid: dt.Oid, + }) + } + + return ci2 +} + +var nameValues map[string]Value + +func init() { + nameValues = map[string]Value{ + "_aclitem": &AclitemArray{}, + "_bool": &BoolArray{}, + "_bytea": &ByteaArray{}, + "_cidr": &CidrArray{}, + "_date": &DateArray{}, + "_float4": &Float4Array{}, + "_float8": &Float8Array{}, + "_inet": &InetArray{}, + "_int2": &Int2Array{}, + "_int4": &Int4Array{}, + "_int8": &Int8Array{}, + "_text": &TextArray{}, + "_timestamp": &TimestampArray{}, + "_timestamptz": &TimestamptzArray{}, + "_varchar": &VarcharArray{}, + "aclitem": &Aclitem{}, + "bool": &Bool{}, + "bytea": &Bytea{}, + "char": &QChar{}, + "cid": &Cid{}, + "cidr": &Cidr{}, + "date": &Date{}, + "float4": &Float4{}, + "float8": &Float8{}, + "hstore": &Hstore{}, + "inet": &Inet{}, + "int2": &Int2{}, + "int4": &Int4{}, + "int8": &Int8{}, + "json": &Json{}, + "jsonb": &Jsonb{}, + "name": &Name{}, + "oid": &OidValue{}, + "record": &Record{}, + "text": &Text{}, + "tid": &Tid{}, + "timestamp": &Timestamp{}, + "timestamptz": &Timestamptz{}, + "unknown": &Unknown{}, + "varchar": &Varchar{}, + "xid": &Xid{}, + } +} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index f9b6f56d..391fed57 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -60,16 +60,16 @@ type forceTextEncoder struct { e pgtype.TextEncoder } -func (f forceTextEncoder) EncodeText(w io.Writer) (bool, error) { - return f.e.EncodeText(w) +func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeText(ci, w) } type forceBinaryEncoder struct { e pgtype.BinaryEncoder } -func (f forceBinaryEncoder) EncodeBinary(w io.Writer) (bool, error) { - return f.e.EncodeBinary(w) +func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeBinary(ci, w) } func forceEncoder(e interface{}, formatCode int16) interface{} { @@ -114,7 +114,7 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int ps.FieldDescriptions[0].FormatCode = fc.formatCode vEncoder := forceEncoder(v, fc.formatCode) if vEncoder == nil { - t.Logf("%v does not implement %v", fc.name) + t.Logf("%#v does not implement %v", v, fc.name) continue } // Derefence value if it is a pointer diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index 05c79c0e..3f9e7bf7 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -63,7 +63,7 @@ func (src *pguint32) AssignTo(dst interface{}) error { return nil } -func (dst *pguint32) DecodeText(src []byte) error { +func (dst *pguint32) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = pguint32{Status: Null} return nil @@ -78,7 +78,7 @@ func (dst *pguint32) DecodeText(src []byte) error { return nil } -func (dst *pguint32) DecodeBinary(src []byte) error { +func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = pguint32{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *pguint32) DecodeBinary(src []byte) error { return nil } -func (src pguint32) EncodeText(w io.Writer) (bool, error) { +func (src pguint32) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -105,7 +105,7 @@ func (src pguint32) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src pguint32) EncodeBinary(w io.Writer) (bool, error) { +func (src pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/qchar.go b/pgtype/qchar.go index d46e716d..4b32ee4a 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -115,7 +115,7 @@ func (src *QChar) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *QChar) DecodeBinary(src []byte) error { +func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = QChar{Status: Null} return nil @@ -129,7 +129,7 @@ func (dst *QChar) DecodeBinary(src []byte) error { return nil } -func (src QChar) EncodeBinary(w io.Writer) (bool, error) { +func (src QChar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/record.go b/pgtype/record.go new file mode 100644 index 00000000..1bfd05b9 --- /dev/null +++ b/pgtype/record.go @@ -0,0 +1,123 @@ +package pgtype + +import ( + "encoding/binary" + "fmt" +) + +// Record is the generic PostgreSQL record type such as is created with the +// "row" function. Record only implements BinaryEncoder and Value. The text +// format output format from PostgreSQL does not include type information and is +// therefore impossible to decode. No encoders are implemented because +// PostgreSQL does not support input of generic records. +type Record struct { + Fields []Value + Status Status +} + +func (dst *Record) Set(src interface{}) error { + switch value := src.(type) { + case []Value: + *dst = Record{Fields: value, Status: Present} + default: + return fmt.Errorf("cannot convert %v to Record", src) + } + + return nil +} + +func (dst *Record) Get() interface{} { + switch dst.Status { + case Present: + return dst.Fields + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Record) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *[]Value: + switch src.Status { + case Present: + *v = make([]Value, len(src.Fields)) + copy(*v, src.Fields) + case Null: + *v = nil + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + case *[]interface{}: + switch src.Status { + case Present: + *v = make([]interface{}, len(src.Fields)) + for i := range *v { + (*v)[i] = src.Fields[i].Get() + } + case Null: + *v = nil + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Record{Status: Null} + return nil + } + + rp := 0 + + if len(src[rp:]) < 4 { + return fmt.Errorf("Record incomplete %v", src) + } + fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + fields := make([]Value, fieldCount) + + for i := 0; i < fieldCount; i++ { + if len(src[rp:]) < 8 { + return fmt.Errorf("Record incomplete %v", src) + } + fieldOid := Oid(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + var binaryDecoder BinaryDecoder + if dt, ok := ci.DataTypeForOid(fieldOid); ok { + if binaryDecoder, ok = dt.Value.(BinaryDecoder); !ok { + return fmt.Errorf("unknown oid while decoding record: %v", fieldOid) + } + } + + var fieldBytes []byte + if fieldLen >= 0 { + if len(src[rp:]) < fieldLen { + return fmt.Errorf("Record incomplete %v", src) + } + fieldBytes = src[rp : rp+fieldLen] + rp += fieldLen + } + + if err := binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + fields[i] = binaryDecoder.(Value) + } + + *dst = Record{Fields: fields, Status: Present} + + return nil +} diff --git a/pgtype/record_test.go b/pgtype/record_test.go new file mode 100644 index 00000000..bc6e5893 --- /dev/null +++ b/pgtype/record_test.go @@ -0,0 +1,150 @@ +package pgtype_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func TestRecordTranscode(t *testing.T) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + tests := []struct { + sql string + expected pgtype.Record + }{ + { + sql: `select row()`, + expected: pgtype.Record{ + Fields: []pgtype.Value{}, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: 2, Status: pgtype.Present}, + pgtype.Int4{Status: pgtype.Null}, + pgtype.Int4{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row(null)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Unknown{Status: pgtype.Null}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select null::record`, + expected: pgtype.Record{ + Status: pgtype.Null, + }, + }, + } + + for i, tt := range tests { + psName := fmt.Sprintf("test%d", i) + ps, err := conn.Prepare(psName, tt.sql) + if err != nil { + t.Fatal(err) + } + ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode + + var result pgtype.Record + if err := conn.QueryRow(psName).Scan(&result); err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(tt.expected, result) { + t.Errorf("%d: expected %v, got %v", i, tt.expected, result) + } + } +} + +func TestRecordAssignTo(t *testing.T) { + var valueSlice []pgtype.Value + var interfaceSlice []interface{} + + simpleTests := []struct { + src pgtype.Record + dst interface{} + expected interface{} + }{ + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + dst: &valueSlice, + expected: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + }, + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + dst: &interfaceSlice, + expected: []interface{}{"foo", int32(42)}, + }, + { + src: pgtype.Record{Status: pgtype.Null}, + dst: &valueSlice, + expected: (([]pgtype.Value)(nil)), + }, + { + src: pgtype.Record{Status: pgtype.Null}, + dst: &interfaceSlice, + expected: (([]interface{})(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/text.go b/pgtype/text.go index 3dd082c9..f1a76b6e 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -78,7 +78,7 @@ func (src *Text) AssignTo(dst interface{}) error { return nil } -func (dst *Text) DecodeText(src []byte) error { +func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Text{Status: Null} return nil @@ -88,11 +88,11 @@ func (dst *Text) DecodeText(src []byte) error { return nil } -func (dst *Text) DecodeBinary(src []byte) error { - return dst.DecodeText(src) +func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) } -func (src Text) EncodeText(w io.Writer) (bool, error) { +func (src Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -104,6 +104,6 @@ func (src Text) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Text) EncodeBinary(w io.Writer) (bool, error) { - return src.EncodeText(w) +func (src Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.EncodeText(ci, w) } diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 1e6677a9..6e89708f 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -83,7 +83,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { return nil } -func (dst *TextArray) DecodeText(src []byte) error { +func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TextArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *TextArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *TextArray) DecodeText(src []byte) error { return nil } -func (dst *TextArray) DecodeBinary(src []byte) error { +func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TextArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *TextArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *TextArray) DecodeBinary(src []byte) error { return nil } -func (src *TextArray) EncodeText(w io.Writer) (bool, error) { +func (src *TextArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TextArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TextOid) +func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TextOid) } -func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TextArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/tid.go b/pgtype/tid.go index 20d962df..b91711d3 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -46,7 +46,7 @@ func (src *Tid) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign %v to %T", src, dst) } -func (dst *Tid) DecodeText(src []byte) error { +func (dst *Tid) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Tid{Status: Null} return nil @@ -75,7 +75,7 @@ func (dst *Tid) DecodeText(src []byte) error { return nil } -func (dst *Tid) DecodeBinary(src []byte) error { +func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Tid{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *Tid) DecodeBinary(src []byte) error { return nil } -func (src Tid) EncodeText(w io.Writer) (bool, error) { +func (src Tid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -105,7 +105,7 @@ func (src Tid) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Tid) EncodeBinary(w io.Writer) (bool, error) { +func (src Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 3bb8f080..9a9e74ea 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -85,7 +85,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { // DecodeText decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeText(src []byte) error { +func (dst *Timestamp) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamp{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Timestamp) DecodeText(src []byte) error { // DecodeBinary decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeBinary(src []byte) error { +func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamp{Status: Null} return nil @@ -139,7 +139,7 @@ func (dst *Timestamp) DecodeBinary(src []byte) error { // EncodeText writes the text encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeText(w io.Writer) (bool, error) { +func (src Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -167,7 +167,7 @@ func (src Timestamp) EncodeText(w io.Writer) (bool, error) { // EncodeBinary writes the binary encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeBinary(w io.Writer) (bool, error) { +func (src Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index c955dc42..064ad483 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -84,7 +84,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestampArray) DecodeText(src []byte) error { +func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestampArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *TimestampArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *TimestampArray) DecodeText(src []byte) error { return nil } -func (dst *TimestampArray) DecodeBinary(src []byte) error { +func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestampArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *TimestampArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *TimestampArray) DecodeBinary(src []byte) error { return nil } -func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { +func (src *TimestampArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TimestampArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TimestampOid) +func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TimestampOid) } -func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TimestampArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, er } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, er for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 5b9f5038..7f57f4b7 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -84,7 +84,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { return nil } -func (dst *Timestamptz) DecodeText(src []byte) error { +func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamptz{Status: Null} return nil @@ -117,7 +117,7 @@ func (dst *Timestamptz) DecodeText(src []byte) error { return nil } -func (dst *Timestamptz) DecodeBinary(src []byte) error { +func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamptz{Status: Null} return nil @@ -143,7 +143,7 @@ func (dst *Timestamptz) DecodeBinary(src []byte) error { return nil } -func (src Timestamptz) EncodeText(w io.Writer) (bool, error) { +func (src Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Timestamptz) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Timestamptz) EncodeBinary(w io.Writer) (bool, error) { +func (src Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index cd63e02e..4af1460b 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -84,7 +84,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestamptzArray) DecodeText(src []byte) error { +func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestamptzArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *TimestamptzArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *TimestamptzArray) DecodeText(src []byte) error { return nil } -func (dst *TimestamptzArray) DecodeBinary(src []byte) error { +func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestamptzArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *TimestamptzArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *TimestamptzArray) DecodeBinary(src []byte) error { return nil } -func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { +func (src *TimestamptzArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TimestamptzArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TimestamptzOid) +func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TimestamptzOid) } -func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TimestamptzArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index a56097c0..2a46a658 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -82,7 +82,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { +func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -118,14 +118,14 @@ func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { +func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -150,7 +150,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { elemSrc = src[rp:rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -160,7 +160,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { return nil } -func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { +func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -207,7 +207,7 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -236,11 +236,11 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, <%= element_oid %>) +func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, <%= element_oid %>) } -func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *<%= pgtype_array_type %>) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -260,7 +260,7 @@ func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -270,7 +270,7 @@ func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 41c1313f..5fde32aa 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -8,6 +8,8 @@ erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_type erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4Oid text_null=NULL typed_array.go.erb > float4_array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8Oid text_null=NULL typed_array.go.erb > float8_array.go erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOid text_null=NULL typed_array.go.erb > inet_array.go +erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_oid=CidrOid text_null=NULL typed_array.go.erb > cidr_array.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOid text_null='"NULL"' typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_oid=VarcharOid text_null='"NULL"' typed_array.go.erb > varchar_array.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOid text_null=NULL typed_array.go.erb > bytea_array.go erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_oid=AclitemOid text_null=NULL typed_array.go.erb > aclitem_array.go diff --git a/pgtype/unknown.go b/pgtype/unknown.go new file mode 100644 index 00000000..b951ad99 --- /dev/null +++ b/pgtype/unknown.go @@ -0,0 +1,32 @@ +package pgtype + +// Unknown represents the PostgreSQL unknown type. It is either a string literal +// or NULL. It is used when PostgreSQL does not know the type of a value. In +// general, this will only be used in pgx when selecting a null value without +// type information. e.g. SELECT NULL; +type Unknown struct { + String string + Status Status +} + +func (dst *Unknown) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Unknown) Get() interface{} { + return (*Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Unknown is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Unknown) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} diff --git a/pgtype/varchar.go b/pgtype/varchar.go new file mode 100644 index 00000000..adda6c49 --- /dev/null +++ b/pgtype/varchar.go @@ -0,0 +1,40 @@ +package pgtype + +import ( + "io" +) + +type Varchar Text + +// Set converts from src to dst. Note that as Varchar is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *Varchar) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Varchar) Get() interface{} { + return (*Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Varchar is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Varchar) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Varchar) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (src Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) +} + +func (src Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeBinary(ci, w) +} diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index 693b9a61..21e9ccff 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -1,35 +1,296 @@ package pgtype import ( + "bytes" + "encoding/binary" + "fmt" "io" + + "github.com/jackc/pgx/pgio" ) -type VarcharArray TextArray +type VarcharArray struct { + Elements []Varchar + Dimensions []ArrayDimension + Status Status +} func (dst *VarcharArray) Set(src interface{}) error { - return (*TextArray)(dst).Set(src) + switch value := src.(type) { + + case []string: + if value == nil { + *dst = VarcharArray{Status: Null} + } else if len(value) == 0 { + *dst = VarcharArray{Status: Present} + } else { + elements := make([]Varchar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = VarcharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Varchar", value) + } + + return nil } func (dst *VarcharArray) Get() interface{} { - return (*TextArray)(dst).Get() + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } } func (src *VarcharArray) AssignTo(dst interface{}) error { - return (*TextArray)(src).AssignTo(dst) + switch v := dst.(type) { + + case *[]string: + if src.Status == Present { + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil } -func (dst *VarcharArray) DecodeText(src []byte) error { - return (*TextArray)(dst).DecodeText(src) +func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Varchar + + if len(uta.Elements) > 0 { + elements = make([]Varchar, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Varchar + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil } -func (dst *VarcharArray) DecodeBinary(src []byte) error { - return (*TextArray)(dst).DecodeBinary(src) +func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = VarcharArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Varchar, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil } -func (src *VarcharArray) EncodeText(w io.Writer) (bool, error) { - return (*TextArray)(src).EncodeText(w) +func (src *VarcharArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `"NULL"`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil } -func (src *VarcharArray) EncodeBinary(w io.Writer) (bool, error) { - return (*TextArray)(src).encodeBinary(w, VarcharOid) +func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, VarcharOid) +} + +func (src *VarcharArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + ElementOid: elementOid, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err } diff --git a/pgtype/varchar_array_test.go b/pgtype/varchar_array_test.go new file mode 100644 index 00000000..4a8b09b8 --- /dev/null +++ b/pgtype/varchar_array_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestVarcharArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "varchar[]", []interface{}{ + &pgtype.VarcharArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "foo", Status: pgtype.Present}, + pgtype.Varchar{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{Status: pgtype.Null}, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "bar ", Status: pgtype.Present}, + pgtype.Varchar{String: "NuLL", Status: pgtype.Present}, + pgtype.Varchar{String: `wow"quz\`, Status: pgtype.Present}, + pgtype.Varchar{String: "", Status: pgtype.Present}, + pgtype.Varchar{Status: pgtype.Null}, + pgtype.Varchar{String: "null", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "bar", Status: pgtype.Present}, + pgtype.Varchar{String: "baz", Status: pgtype.Present}, + pgtype.Varchar{String: "quz", Status: pgtype.Present}, + pgtype.Varchar{String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestVarcharArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.VarcharArray + }{ + { + source: []string{"foo"}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.VarcharArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.VarcharArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestVarcharArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.VarcharArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.VarcharArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.VarcharArray + dst interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/xid.go b/pgtype/xid.go index a53120de..c76548a4 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -37,18 +37,18 @@ func (src *Xid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *Xid) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *Xid) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *Xid) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *Xid) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src Xid) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src Xid) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) } diff --git a/query.go b/query.go index 63ce91ed..48a657f9 100644 --- a/query.go +++ b/query.go @@ -212,74 +212,86 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { } } } else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode { - err = s.DecodeBinary(vr.bytes()) + err = s.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } } else if s, ok := d.(pgtype.TextDecoder); ok && vr.Type().FormatCode == TextFormatCode { - err = s.DecodeText(vr.bytes()) + err = s.DecodeText(rows.conn.ConnInfo, vr.bytes()) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } } else if s, ok := d.(sql.Scanner); ok { - var val interface{} + var sqlSrc interface{} if 0 <= vr.Len() { - switch vr.Type().DataType { - case BoolOid: - val = decodeBool(vr) - case Int8Oid: - val = int64(decodeInt8(vr)) - case Int2Oid: - val = int64(decodeInt2(vr)) - case Int4Oid: - val = int64(decodeInt4(vr)) - case TextOid, VarcharOid: - val = decodeText(vr) - case Float4Oid: - val = float64(decodeFloat4(vr)) - case Float8Oid: - val = decodeFloat8(vr) - case DateOid: - val = decodeDate(vr) - case TimestampOid: - val = decodeTimestamp(vr) - case TimestampTzOid: - val = decodeTimestampTz(vr) - default: - val = vr.ReadBytes(vr.Len()) + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { + value := dt.Value + + switch vr.Type().FormatCode { + case TextFormatCode: + decoder := value.(pgtype.TextDecoder) + if decoder == nil { + decoder = &pgtype.GenericText{} + } + err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + case BinaryFormatCode: + decoder := value.(pgtype.BinaryDecoder) + if decoder == nil { + decoder = &pgtype.GenericBinary{} + } + err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + default: + rows.Fatal(errors.New("Unknown format code")) + } + + sqlSrc, err = pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) + if err != nil { + rows.Fatal(err) + } + } else { + rows.Fatal(errors.New("Unknown type")) } } - err = s.Scan(val) + err = s.Scan(sqlSrc) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } } else { - if pgVal, present := rows.conn.oidPgtypeValues[vr.Type().DataType]; present { + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { + value := dt.Value switch vr.Type().FormatCode { case TextFormatCode: - if textDecoder, ok := pgVal.(pgtype.TextDecoder); ok { - err = textDecoder.DecodeText(vr.bytes()) + if textDecoder, ok := value.(pgtype.TextDecoder); ok { + err = textDecoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) if err != nil { vr.Fatal(err) } } else { - vr.Fatal(fmt.Errorf("%T is not a pgtype.TextDecoder", pgVal)) + vr.Fatal(fmt.Errorf("%T is not a pgtype.TextDecoder", value)) } case BinaryFormatCode: - if binaryDecoder, ok := pgVal.(pgtype.BinaryDecoder); ok { - err = binaryDecoder.DecodeBinary(vr.bytes()) + if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok { + err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) if err != nil { vr.Fatal(err) } } else { - vr.Fatal(fmt.Errorf("%T is not a pgtype.BinaryDecoder", pgVal)) + vr.Fatal(fmt.Errorf("%T is not a pgtype.BinaryDecoder", value)) } default: vr.Fatal(fmt.Errorf("unknown format code: %v", vr.Type().FormatCode)) } - if err := pgVal.AssignTo(d); err != nil { - vr.Fatal(err) + if vr.Err() == nil { + if err := value.AssignTo(d); err != nil { + vr.Fatal(err) + } } } else { if err := Decode(vr, d); err != nil { @@ -315,29 +327,35 @@ func (rows *Rows) Values() ([]interface{}, error) { continue } - switch vr.Type().FormatCode { - case TextFormatCode: - decoder := rows.conn.oidPgtypeValues[vr.Type().DataType].(pgtype.TextDecoder) - if decoder == nil { - decoder = &pgtype.GenericText{} + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { + value := dt.Value + + switch vr.Type().FormatCode { + case TextFormatCode: + decoder := value.(pgtype.TextDecoder) + if decoder == nil { + decoder = &pgtype.GenericText{} + } + err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + values = append(values, decoder.(pgtype.Value).Get()) + case BinaryFormatCode: + decoder := value.(pgtype.BinaryDecoder) + if decoder == nil { + decoder = &pgtype.GenericBinary{} + } + err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + values = append(values, value.Get()) + default: + rows.Fatal(errors.New("Unknown format code")) } - err := decoder.DecodeText(vr.bytes()) - if err != nil { - rows.Fatal(err) - } - values = append(values, decoder.(pgtype.Value).Get()) - case BinaryFormatCode: - decoder := rows.conn.oidPgtypeValues[vr.Type().DataType].(pgtype.BinaryDecoder) - if decoder == nil { - decoder = &pgtype.GenericBinary{} - } - err := decoder.DecodeBinary(vr.bytes()) - if err != nil { - rows.Fatal(err) - } - values = append(values, decoder.(pgtype.Value).Get()) - default: - rows.Fatal(errors.New("Unknown format code")) + } else { + rows.Fatal(errors.New("Unknown type")) } if vr.Err() != nil { @@ -368,49 +386,41 @@ func (rows *Rows) ValuesForStdlib() ([]interface{}, error) { values = append(values, nil) continue } - // TODO - consider what are the implications of returning complex types since database/sql uses this method - switch vr.Type().FormatCode { - // All intrinsic types (except string) are encoded with binary - // encoding so anything else should be treated as a string - case TextFormatCode: - values = append(values, vr.ReadString(vr.Len())) - case BinaryFormatCode: - switch vr.Type().DataType { - case TextOid, VarcharOid: - values = append(values, decodeText(vr)) - case BoolOid: - values = append(values, decodeBool(vr)) - case ByteaOid: - values = append(values, decodeBytea(vr)) - case Int8Oid: - values = append(values, decodeInt8(vr)) - case Int2Oid: - values = append(values, decodeInt2(vr)) - case Int4Oid: - values = append(values, decodeInt4(vr)) - case Float4Oid: - values = append(values, decodeFloat4(vr)) - case Float8Oid: - values = append(values, decodeFloat8(vr)) - case DateOid: - values = append(values, decodeDate(vr)) - case TimestampTzOid: - values = append(values, decodeTimestampTz(vr)) - case TimestampOid: - values = append(values, decodeTimestamp(vr)) - case JsonOid: - var d interface{} - decodeJSON(vr, &d) - values = append(values, d) - case JsonbOid: - var d interface{} - decodeJSONB(vr, &d) - values = append(values, d) + + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { + value := dt.Value + + switch vr.Type().FormatCode { + case TextFormatCode: + decoder := value.(pgtype.TextDecoder) + if decoder == nil { + decoder = &pgtype.GenericText{} + } + err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + case BinaryFormatCode: + decoder := value.(pgtype.BinaryDecoder) + if decoder == nil { + decoder = &pgtype.GenericBinary{} + } + err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } default: - rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) + rows.Fatal(errors.New("Unknown format code")) } - default: - rows.Fatal(errors.New("Unknown format code")) + + sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) + if err != nil { + rows.Fatal(err) + } + + values = append(values, sqlSrc) + } else { + rows.Fatal(errors.New("Unknown type")) } if vr.Err() != nil { diff --git a/query_test.go b/query_test.go index 01889444..480959e8 100644 --- a/query_test.go +++ b/query_test.go @@ -776,7 +776,7 @@ func TestQueryRowErrors(t *testing.T) { {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot decode"}, - {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "Cannot encode int into oid 600"}, + {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Text"}, } for i, tt := range tests { diff --git a/values.go b/values.go index d90c363b..4eb24eef 100644 --- a/values.go +++ b/values.go @@ -5,9 +5,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" - "math" "reflect" - "time" "github.com/jackc/pgx/pgtype" ) @@ -167,7 +165,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { switch arg := arg.(type) { case pgtype.BinaryEncoder: buf := &bytes.Buffer{} - null, err := arg.EncodeBinary(buf) + null, err := arg.EncodeBinary(wbuf.conn.ConnInfo, buf) if err != nil { return err } @@ -180,7 +178,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { return nil case pgtype.TextEncoder: buf := &bytes.Buffer{} - null, err := arg.EncodeText(buf) + null, err := arg.EncodeText(wbuf.conn.ConnInfo, buf) if err != nil { return err } @@ -214,14 +212,15 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { return Encode(wbuf, oid, arg) } - if value, ok := wbuf.conn.oidPgtypeValues[oid]; ok { + if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok { + value := dt.Value err := value.Set(arg) if err != nil { return err } buf := &bytes.Buffer{} - null, err := value.(pgtype.BinaryEncoder).EncodeBinary(buf) + null, err := value.(pgtype.BinaryEncoder).EncodeBinary(wbuf.conn.ConnInfo, buf) if err != nil { return err } @@ -287,8 +286,6 @@ func Decode(vr *ValueReader, d interface{}) error { switch v := d.(type) { case *string: *v = decodeText(vr) - case *[]interface{}: - *v = decodeRecord(vr) default: if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { el := v.Elem() @@ -320,232 +317,6 @@ func Decode(vr *ValueReader, d interface{}) error { return nil } -func decodeBool(vr *ValueReader) bool { - if vr.Type().DataType != BoolOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType))) - return false - } - - var b pgtype.Bool - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = b.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = b.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return false - } - - if err != nil { - vr.Fatal(err) - return false - } - - if b.Status != pgtype.Present { - vr.Fatal(fmt.Errorf("Cannot decode null into bool")) - return false - } - - return b.Bool -} - -func decodeInt(vr *ValueReader) int64 { - switch vr.Type().DataType { - case Int2Oid: - return int64(decodeInt2(vr)) - case Int4Oid: - return int64(decodeInt4(vr)) - case Int8Oid: - return int64(decodeInt8(vr)) - } - - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into any integer type", vr.Type().DataType))) - return 0 -} - -func decodeInt8(vr *ValueReader) int64 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into int64")) - return 0 - } - - if vr.Type().DataType != Int8Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int8", vr.Type().DataType))) - return 0 - } - - var n pgtype.Int8 - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = n.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = n.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if err != nil { - vr.Fatal(err) - return 0 - } - - if n.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return 0 - } - - return n.Int -} - -func decodeInt2(vr *ValueReader) int16 { - - if vr.Type().DataType != Int2Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType))) - return 0 - } - - var n pgtype.Int2 - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = n.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = n.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if err != nil { - vr.Fatal(err) - return 0 - } - - if n.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return 0 - } - - return n.Int -} - -func decodeInt4(vr *ValueReader) int32 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into int32")) - return 0 - } - - if vr.Type().DataType != Int4Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int32", vr.Type().DataType))) - return 0 - } - - var n pgtype.Int4 - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = n.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = n.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if err != nil { - vr.Fatal(err) - return 0 - } - - if n.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return 0 - } - - return n.Int -} - -func decodeFloat4(vr *ValueReader) float32 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into float32")) - return 0 - } - - if vr.Type().DataType != Float4Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float32", vr.Type().DataType))) - return 0 - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len()))) - return 0 - } - - i := vr.ReadInt32() - return math.Float32frombits(uint32(i)) -} - -func encodeFloat32(w *WriteBuf, oid pgtype.Oid, value float32) error { - switch oid { - case Float4Oid: - w.WriteInt32(4) - w.WriteInt32(int32(math.Float32bits(value))) - case Float8Oid: - w.WriteInt32(8) - w.WriteInt64(int64(math.Float64bits(float64(value)))) - default: - return fmt.Errorf("cannot encode %s into oid %v", "float32", oid) - } - - return nil -} - -func decodeFloat8(vr *ValueReader) float64 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into float64")) - return 0 - } - - if vr.Type().DataType != Float8Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float64", vr.Type().DataType))) - return 0 - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", vr.Len()))) - return 0 - } - - i := vr.ReadInt64() - return math.Float64frombits(uint64(i)) -} - -func encodeFloat64(w *WriteBuf, oid pgtype.Oid, value float64) error { - switch oid { - case Float8Oid: - w.WriteInt32(8) - w.WriteInt64(int64(math.Float64bits(value))) - default: - return fmt.Errorf("cannot encode %s into oid %v", "float64", oid) - } - - return nil -} - func decodeText(vr *ValueReader) string { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into string")) @@ -677,215 +448,3 @@ func encodeJSONB(w *WriteBuf, oid pgtype.Oid, value interface{}) error { return nil } - -func decodeDate(vr *ValueReader) time.Time { - if vr.Type().DataType != DateOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return time.Time{} - } - - var d pgtype.Date - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = d.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = d.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return time.Time{} - } - - if err != nil { - vr.Fatal(err) - return time.Time{} - } - - if d.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return time.Time{} - } - - return d.Time -} - -func encodeTime(w *WriteBuf, oid pgtype.Oid, value time.Time) error { - switch oid { - case DateOid: - var d pgtype.Date - err := d.Set(value) - if err != nil { - return err - } - - buf := &bytes.Buffer{} - null, err := d.EncodeBinary(buf) - if err != nil { - return err - } - if null { - w.WriteInt32(-1) - } else { - w.WriteInt32(int32(buf.Len())) - w.WriteBytes(buf.Bytes()) - } - return nil - - case TimestampTzOid, TimestampOid: - var t pgtype.Timestamptz - err := t.Set(value) - if err != nil { - return err - } - - buf := &bytes.Buffer{} - null, err := t.EncodeBinary(buf) - if err != nil { - return err - } - if null { - w.WriteInt32(-1) - } else { - w.WriteInt32(int32(buf.Len())) - w.WriteBytes(buf.Bytes()) - } - return nil - default: - return fmt.Errorf("cannot encode %s into oid %v", "time.Time", oid) - } -} - -const microsecFromUnixEpochToY2K = 946684800 * 1000000 - -func decodeTimestampTz(vr *ValueReader) time.Time { - var zeroTime time.Time - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into time.Time")) - return zeroTime - } - - if vr.Type().DataType != TimestampTzOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return zeroTime - } - - var t pgtype.Timestamptz - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = t.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = t.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return time.Time{} - } - - if err != nil { - vr.Fatal(err) - return time.Time{} - } - - if t.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into time.Time")) - return time.Time{} - } - - return t.Time -} - -func decodeTimestamp(vr *ValueReader) time.Time { - var zeroTime time.Time - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into timestamp")) - return zeroTime - } - - if vr.Type().DataType != TimestampOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return zeroTime - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return zeroTime - } - - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamp: %d", vr.Len()))) - return zeroTime - } - - microsecSinceY2K := vr.ReadInt64() - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) -} - -func decodeRecord(vr *ValueReader) []interface{} { - if vr.Len() == -1 { - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - if vr.Type().DataType != RecordOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []interface{}", vr.Type().DataType))) - return nil - } - - valueCount := vr.ReadInt32() - record := make([]interface{}, 0, int(valueCount)) - - for i := int32(0); i < valueCount; i++ { - fd := FieldDescription{FormatCode: BinaryFormatCode} - fieldVR := ValueReader{mr: vr.mr, fd: &fd} - fd.DataType = vr.ReadOid() - fieldVR.valueBytesRemaining = vr.ReadInt32() - vr.valueBytesRemaining -= fieldVR.valueBytesRemaining - - switch fd.DataType { - case BoolOid: - record = append(record, decodeBool(&fieldVR)) - case ByteaOid: - record = append(record, decodeBytea(&fieldVR)) - case Int8Oid: - record = append(record, decodeInt8(&fieldVR)) - case Int2Oid: - record = append(record, decodeInt2(&fieldVR)) - case Int4Oid: - record = append(record, decodeInt4(&fieldVR)) - case Float4Oid: - record = append(record, decodeFloat4(&fieldVR)) - case Float8Oid: - record = append(record, decodeFloat8(&fieldVR)) - case DateOid: - record = append(record, decodeDate(&fieldVR)) - case TimestampTzOid: - record = append(record, decodeTimestampTz(&fieldVR)) - case TimestampOid: - record = append(record, decodeTimestamp(&fieldVR)) - case TextOid, VarcharOid, UnknownOid: - record = append(record, decodeTextAllowBinary(&fieldVR)) - default: - vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType)) - return nil - } - - // Consume any remaining data - if fieldVR.Len() > 0 { - fieldVR.ReadBytes(fieldVR.Len()) - } - - if fieldVR.Err() != nil { - vr.Fatal(fieldVR.Err()) - return nil - } - } - - return record -} diff --git a/values_test.go b/values_test.go index e7ae7e1d..1d09eb18 100644 --- a/values_test.go +++ b/values_test.go @@ -6,9 +6,6 @@ import ( "reflect" "testing" "time" - - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" ) func TestDateTranscode(t *testing.T) { @@ -78,159 +75,161 @@ func TestTimestampTzTranscode(t *testing.T) { } } -func TestJSONAndJSONBTranscode(t *testing.T) { - t.Parallel() +// TODO - move these tests to pgtype - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) +// func TestJSONAndJSONBTranscode(t *testing.T) { +// t.Parallel() - for _, oid := range []pgtype.Oid{pgx.JsonOid, pgx.JsonbOid} { - if _, ok := conn.PgTypes[oid]; !ok { - return // No JSON/JSONB type -- must be running against old PostgreSQL - } +// conn := mustConnect(t, *defaultConnConfig) +// defer closeConn(t, conn) - for _, format := range []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} { - pgtype := conn.PgTypes[oid] - pgtype.DefaultFormat = format - conn.PgTypes[oid] = pgtype +// for _, oid := range []pgtype.Oid{pgx.JsonOid, pgx.JsonbOid} { +// if _, ok := conn.ConnInfo.DataTypeForOid(oid); !ok { +// return // No JSON/JSONB type -- must be running against old PostgreSQL +// } - typename := conn.PgTypes[oid].Name +// for _, format := range []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} { +// pgtype := conn.PgTypes[oid] +// pgtype.DefaultFormat = format +// conn.PgTypes[oid] = pgtype - testJSONString(t, conn, typename, format) - testJSONStringPointer(t, conn, typename, format) - testJSONSingleLevelStringMap(t, conn, typename, format) - testJSONNestedMap(t, conn, typename, format) - testJSONStringArray(t, conn, typename, format) - testJSONInt64Array(t, conn, typename, format) - testJSONInt16ArrayFailureDueToOverflow(t, conn, typename, format) - testJSONStruct(t, conn, typename, format) - } - } -} +// typename := conn.PgTypes[oid].Name -func testJSONString(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := `{"key": "value"}` - expectedOutput := map[string]string{"key": "value"} - var output map[string]string - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - return - } +// testJSONString(t, conn, typename, format) +// testJSONStringPointer(t, conn, typename, format) +// testJSONSingleLevelStringMap(t, conn, typename, format) +// testJSONNestedMap(t, conn, typename, format) +// testJSONStringArray(t, conn, typename, format) +// testJSONInt64Array(t, conn, typename, format) +// testJSONInt16ArrayFailureDueToOverflow(t, conn, typename, format) +// testJSONStruct(t, conn, typename, format) +// } +// } +// } - if !reflect.DeepEqual(expectedOutput, output) { - t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) - return - } -} +// func testJSONString(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := `{"key": "value"}` +// expectedOutput := map[string]string{"key": "value"} +// var output map[string]string +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// return +// } -func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := `{"key": "value"}` - expectedOutput := map[string]string{"key": "value"} - var output map[string]string - err := conn.QueryRow("select $1::"+typename, &input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - return - } +// if !reflect.DeepEqual(expectedOutput, output) { +// t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) +// return +// } +// } - if !reflect.DeepEqual(expectedOutput, output) { - t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) - return - } -} +// func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := `{"key": "value"}` +// expectedOutput := map[string]string{"key": "value"} +// var output map[string]string +// err := conn.QueryRow("select $1::"+typename, &input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// return +// } -func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := map[string]string{"key": "value"} - var output map[string]string - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - return - } +// if !reflect.DeepEqual(expectedOutput, output) { +// t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) +// return +// } +// } - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, input, output) - return - } -} +// func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := map[string]string{"key": "value"} +// var output map[string]string +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// return +// } -func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := map[string]interface{}{ - "name": "Uncanny", - "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, - "inventory": []interface{}{"phone", "key"}, - } - var output map[string]interface{} - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - return - } +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, input, output) +// return +// } +// } - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode map[string]interface{} successfully: %v is not %v", typename, format, input, output) - return - } -} +// func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := map[string]interface{}{ +// "name": "Uncanny", +// "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, +// "inventory": []interface{}{"phone", "key"}, +// } +// var output map[string]interface{} +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// return +// } -func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := []string{"foo", "bar", "baz"} - var output []string - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - } +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode map[string]interface{} successfully: %v is not %v", typename, format, input, output) +// return +// } +// } - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode []string successfully: %v is not %v", typename, format, input, output) - } -} +// func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := []string{"foo", "bar", "baz"} +// var output []string +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// } -func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := []int64{1, 2, 234432} - var output []int64 - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - } +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode []string successfully: %v is not %v", typename, format, input, output) +// } +// } - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode []int64 successfully: %v is not %v", typename, format, input, output) - } -} +// func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := []int64{1, 2, 234432} +// var output []int64 +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// } -func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := []int{1, 2, 234432} - var output []int16 - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { - t.Errorf("%s %d: Expected *json.UnmarkalTypeError, but got %v", typename, format, err) - } -} +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode []int64 successfully: %v is not %v", typename, format, input, output) +// } +// } -func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) { - type person struct { - Name string `json:"name"` - Age int `json:"age"` - } +// func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := []int{1, 2, 234432} +// var output []int16 +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { +// t.Errorf("%s %d: Expected *json.UnmarkalTypeError, but got %v", typename, format, err) +// } +// } - input := person{ - Name: "John", - Age: 42, - } +// func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// type person struct { +// Name string `json:"name"` +// Age int `json:"age"` +// } - var output person +// input := person{ +// Name: "John", +// Age: 42, +// } - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - } +// var output person - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode struct successfully: %v is not %v", typename, format, input, output) - } -} +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// } + +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode struct successfully: %v is not %v", typename, format, input, output) +// } +// } func mustParseCidr(t *testing.T, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s)