Add database/sql support to pgtype

v3-numeric-wip
Jack Christensen 2017-03-18 21:11:43 -05:00
parent 5572c002dc
commit bec9bd261b
55 changed files with 1459 additions and 201 deletions

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"fmt"
"io"
)
@ -93,3 +94,32 @@ func (src Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) {
_, err := io.WriteString(w, src.String)
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Aclitem) Scan(src interface{}) error {
if src == nil {
*dst = Aclitem{Status: Null}
return nil
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Aclitem) Value() (driver.Value, error) {
switch src.Status {
case Present:
return src.String, nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"fmt"
"io"
@ -194,3 +195,33 @@ func (src *AclitemArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) {
return false, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *AclitemArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *AclitemArray) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"fmt"
"io"
"strconv"
@ -126,3 +127,35 @@ func (src Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err := w.Write(buf)
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Bool) Scan(src interface{}) error {
if src == nil {
*dst = Bool{Status: Null}
return nil
}
switch src := src.(type) {
case bool:
*dst = Bool{Bool: src, Status: Present}
return nil
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Bool) Value() (driver.Value, error) {
switch src.Status {
case Present:
return src.Bool, nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -296,3 +297,33 @@ func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *BoolArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *BoolArray) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"encoding/hex"
"fmt"
"io"
@ -12,6 +13,11 @@ type Bytea struct {
}
func (dst *Bytea) Set(src interface{}) error {
if src == nil {
*dst = Bytea{Status: Null}
return nil
}
switch value := src.(type) {
case []byte:
if value != nil {
@ -124,3 +130,35 @@ func (src Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err := w.Write(src.Bytes)
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Bytea) Scan(src interface{}) error {
if src == nil {
*dst = Bytea{Status: Null}
return nil
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
buf := make([]byte, len(src))
copy(buf, src)
*dst = Bytea{Bytes: buf, Status: Present}
return nil
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Bytea) Value() (driver.Value, error) {
switch src.Status {
case Present:
return src.Bytes, nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -296,3 +297,33 @@ func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *ByteaArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *ByteaArray) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"io"
)
@ -49,3 +50,13 @@ func (src Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) {
func (src Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return (pguint32)(src).EncodeBinary(ci, w)
}
// Scan implements the database/sql Scanner interface.
func (dst *Cid) Scan(src interface{}) error {
return (*pguint32)(dst).Scan(src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Cid) Value() (driver.Value, error) {
return (pguint32)(src).Value()
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -325,3 +326,33 @@ func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *CidrArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *CidrArray) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -2,47 +2,13 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"errors"
)
func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) {
switch src := src.(type) {
case *Bool:
return src.Bool, nil
case *Bytea:
return src.Bytes, nil
case *Date:
if src.InfinityModifier == None {
return src.Time, nil
}
case *Float4:
return float64(src.Float), nil
case *Float8:
return src.Float, nil
case *GenericBinary:
return src.Bytes, nil
case *GenericText:
return src.String, nil
case *Int2:
return int64(src.Int), nil
case *Int4:
return int64(src.Int), nil
case *Int8:
return int64(src.Int), nil
case *Text:
return src.String, nil
case *Timestamp:
if src.InfinityModifier == None {
return src.Time, nil
}
case *Timestamptz:
if src.InfinityModifier == None {
return src.Time, nil
}
case *Unknown:
return src.String, nil
case *Varchar:
return src.String, nil
if valuer, ok := src.(driver.Valuer); ok {
return valuer.Value()
}
buf := &bytes.Buffer{}
@ -64,3 +30,15 @@ func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) {
return nil, errors.New("cannot convert to database/sql compatible value")
}
func encodeValueText(src TextEncoder) (interface{}, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), err
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -10,9 +11,9 @@ import (
)
type Date struct {
Time time.Time
Status Status
InfinityModifier
Time time.Time
Status Status
InfinityModifier InfinityModifier
}
const (
@ -21,6 +22,11 @@ const (
)
func (dst *Date) Set(src interface{}) error {
if src == nil {
*dst = Date{Status: Null}
return nil
}
switch value := src.(type) {
case time.Time:
*dst = Date{Time: value, Status: Present}
@ -167,3 +173,38 @@ func (src Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err := pgio.WriteInt32(w, daysSinceDateEpoch)
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Date) Scan(src interface{}) error {
if src == nil {
*dst = Date{Status: Null}
return nil
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
case time.Time:
*dst = Date{Time: src, Status: Present}
return nil
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Date) Value() (driver.Value, error) {
switch src.Status {
case Present:
if src.InfinityModifier != None {
return src.InfinityModifier.String(), nil
}
return src.Time, nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -297,3 +298,33 @@ func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *DateArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *DateArray) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -9,7 +9,7 @@ import (
)
func TestDateTranscode(t *testing.T) {
testSuccessfulTranscode(t, "date", []interface{}{
testSuccessfulTranscodeEqFunc(t, "date", []interface{}{
pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present},
pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present},
pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present},
@ -19,6 +19,11 @@ func TestDateTranscode(t *testing.T) {
pgtype.Date{Status: pgtype.Null},
pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity},
pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity},
}, func(a, b interface{}) bool {
at := a.(pgtype.Date)
bt := b.(pgtype.Date)
return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier
})
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -16,6 +17,11 @@ type Float4 struct {
}
func (dst *Float4) Set(src interface{}) error {
if src == nil {
*dst = Float4{Status: Null}
return nil
}
switch value := src.(type) {
case float32:
*dst = Float4{Float: value, Status: Present}
@ -156,3 +162,35 @@ func (src Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err := pgio.WriteInt32(w, int32(math.Float32bits(src.Float)))
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Float4) Scan(src interface{}) error {
if src == nil {
*dst = Float4{Status: Null}
return nil
}
switch src := src.(type) {
case float64:
*dst = Float4{Float: float32(src), Status: Present}
return nil
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Float4) Value() (driver.Value, error) {
switch src.Status {
case Present:
return float64(src.Float), nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -296,3 +297,33 @@ func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Float4Array) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *Float4Array) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -16,6 +17,11 @@ type Float8 struct {
}
func (dst *Float8) Set(src interface{}) error {
if src == nil {
*dst = Float8{Status: Null}
return nil
}
switch value := src.(type) {
case float32:
*dst = Float8{Float: float64(value), Status: Present}
@ -146,3 +152,35 @@ func (src Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err := pgio.WriteInt64(w, int64(math.Float64bits(src.Float)))
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Float8) Scan(src interface{}) error {
if src == nil {
*dst = Float8{Status: Null}
return nil
}
switch src := src.(type) {
case float64:
*dst = Float8{Float: src, Status: Present}
return nil
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Float8) Value() (driver.Value, error) {
switch src.Status {
case Present:
return src.Float, nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -296,3 +297,33 @@ func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Float8Array) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *Float8Array) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"io"
)
@ -27,3 +28,13 @@ func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error {
func (src GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return (Bytea)(src).EncodeBinary(ci, w)
}
// Scan implements the database/sql Scanner interface.
func (dst *GenericBinary) Scan(src interface{}) error {
return (*Bytea)(dst).Scan(src)
}
// Value implements the database/sql/driver Valuer interface.
func (src GenericBinary) Value() (driver.Value, error) {
return (Bytea)(src).Value()
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"io"
)
@ -27,3 +28,13 @@ func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error {
func (src GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) {
return (Text)(src).EncodeText(ci, w)
}
// Scan implements the database/sql Scanner interface.
func (dst *GenericText) Scan(src interface{}) error {
return (*Text)(dst).Scan(src)
}
// Value implements the database/sql/driver Valuer interface.
func (src GenericText) Value() (driver.Value, error) {
return (Text)(src).Value()
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
@ -21,6 +22,11 @@ type Hstore struct {
}
func (dst *Hstore) Set(src interface{}) error {
if src == nil {
*dst = Hstore{Status: Null}
return nil
}
switch value := src.(type) {
case map[string]string:
m := make(map[string]Text, len(value))
@ -437,3 +443,25 @@ func parseHstore(s string) (k []string, v []Text, err error) {
v = values
return
}
// Scan implements the database/sql Scanner interface.
func (dst *Hstore) Scan(src interface{}) error {
if src == nil {
*dst = Hstore{Status: Null}
return nil
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Hstore) Value() (driver.Value, error) {
return encodeValueText(src)
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -296,3 +297,33 @@ func (src *HstoreArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *HstoreArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *HstoreArray) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"fmt"
"io"
"net"
@ -23,6 +24,11 @@ type Inet struct {
}
func (dst *Inet) Set(src interface{}) error {
if src == nil {
*dst = Inet{Status: Null}
return nil
}
switch value := src.(type) {
case net.IPNet:
*dst = Inet{IPNet: &value, Status: Present}
@ -189,3 +195,25 @@ func (src Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err := w.Write(src.IPNet.IP)
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Inet) Scan(src interface{}) error {
if src == nil {
*dst = Inet{Status: Null}
return nil
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Inet) Value() (driver.Value, error) {
return encodeValueText(src)
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -325,3 +326,33 @@ func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *InetArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *InetArray) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -16,6 +17,11 @@ type Int2 struct {
}
func (dst *Int2) Set(src interface{}) error {
if src == nil {
*dst = Int2{Status: Null}
return nil
}
switch value := src.(type) {
case int8:
*dst = Int2{Int: int16(value), Status: Present}
@ -151,3 +157,41 @@ func (src Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err := pgio.WriteInt16(w, src.Int)
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Int2) Scan(src interface{}) error {
if src == nil {
*dst = Int2{Status: Null}
return nil
}
switch src := src.(type) {
case int64:
if src < math.MinInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", src)
}
if src > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", src)
}
*dst = Int2{Int: int16(src), Status: Present}
return nil
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Int2) Value() (driver.Value, error) {
switch src.Status {
case Present:
return int64(src.Int), nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -324,3 +325,33 @@ func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Int2Array) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *Int2Array) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -16,6 +17,11 @@ type Int4 struct {
}
func (dst *Int4) Set(src interface{}) error {
if src == nil {
*dst = Int4{Status: Null}
return nil
}
switch value := src.(type) {
case int8:
*dst = Int4{Int: int32(value), Status: Present}
@ -68,7 +74,7 @@ func (dst *Int4) Set(src interface{}) error {
if originalSrc, ok := underlyingNumberType(src); ok {
return dst.Set(originalSrc)
}
return fmt.Errorf("cannot convert %v to Int8", value)
return fmt.Errorf("cannot convert %v to Int4", value)
}
return nil
@ -142,3 +148,41 @@ func (src Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err := pgio.WriteInt32(w, src.Int)
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Int4) Scan(src interface{}) error {
if src == nil {
*dst = Int4{Status: Null}
return nil
}
switch src := src.(type) {
case int64:
if src < math.MinInt32 {
return fmt.Errorf("%d is greater than maximum value for Int4", src)
}
if src > math.MaxInt32 {
return fmt.Errorf("%d is greater than maximum value for Int4", src)
}
*dst = Int4{Int: int32(src), Status: Present}
return nil
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Int4) Value() (driver.Value, error) {
switch src.Status {
case Present:
return int64(src.Int), nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -324,3 +325,33 @@ func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Int4Array) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *Int4Array) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -16,6 +17,11 @@ type Int8 struct {
}
func (dst *Int8) Set(src interface{}) error {
if src == nil {
*dst = Int8{Status: Null}
return nil
}
switch value := src.(type) {
case int8:
*dst = Int8{Int: int64(value), Status: Present}
@ -134,3 +140,35 @@ func (src Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err := pgio.WriteInt64(w, src.Int)
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Int8) Scan(src interface{}) error {
if src == nil {
*dst = Int8{Status: Null}
return nil
}
switch src := src.(type) {
case int64:
*dst = Int8{Int: src, Status: Present}
return nil
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Int8) Value() (driver.Value, error) {
switch src.Status {
case Present:
return int64(src.Int), nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -324,3 +325,33 @@ func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Int8Array) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *Int8Array) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -1,7 +1,9 @@
package pgtype
import (
"database/sql/driver"
"encoding/json"
"fmt"
"io"
)
@ -11,6 +13,11 @@ type Json struct {
}
func (dst *Json) Set(src interface{}) error {
if src == nil {
*dst = Json{Status: Null}
return nil
}
switch value := src.(type) {
case string:
*dst = Json{Bytes: []byte(value), Status: Present}
@ -116,3 +123,32 @@ func (src Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) {
func (src Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return src.EncodeText(ci, w)
}
// Scan implements the database/sql Scanner interface.
func (dst *Json) Scan(src interface{}) error {
if src == nil {
*dst = Json{Status: Null}
return nil
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Json) Value() (driver.Value, error) {
switch src.Status {
case Present:
return src.Bytes, nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"fmt"
"io"
)
@ -66,3 +67,13 @@ func (src Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err = w.Write(src.Bytes)
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Jsonb) Scan(src interface{}) error {
return (*Json)(dst).Scan(src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Jsonb) Value() (driver.Value, error) {
return (Json)(src).Value()
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"io"
)
@ -46,3 +47,13 @@ func (src Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) {
func (src Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return (Text)(src).EncodeBinary(ci, w)
}
// Scan implements the database/sql Scanner interface.
func (dst *Name) Scan(src interface{}) error {
return (*Text)(dst).Scan(src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Name) Value() (driver.Value, error) {
return (Text)(src).Value()
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -55,3 +56,27 @@ func (src Oid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err := pgio.WriteUint32(w, uint32(src))
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Oid) Scan(src interface{}) error {
if src == nil {
return fmt.Errorf("cannot scan NULL into %T", src)
}
switch src := src.(type) {
case int64:
*dst = Oid(src)
return nil
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Oid) Value() (driver.Value, error) {
return int64(src), nil
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"io"
)
@ -43,3 +44,13 @@ func (src OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) {
func (src OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return (pguint32)(src).EncodeBinary(ci, w)
}
// Scan implements the database/sql Scanner interface.
func (dst *OidValue) Scan(src interface{}) error {
return (*pguint32)(dst).Scan(src)
}
// Value implements the database/sql/driver Valuer interface.
func (src OidValue) Value() (driver.Value, error) {
return (pguint32)(src).Value()
}

View File

@ -67,6 +67,19 @@ const (
NegativeInfinity InfinityModifier = -Infinity
)
func (im InfinityModifier) String() string {
switch im {
case None:
return "none"
case Infinity:
return "infinity"
case NegativeInfinity:
return "-infinity"
default:
return "invalid"
}
}
type Value interface {
// Set converts and assigns src to itself.
Set(src interface{}) error

View File

@ -1,6 +1,7 @@
package pgtype_test
import (
"database/sql"
"fmt"
"io"
"net"
@ -10,6 +11,8 @@ import (
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgtype"
_ "github.com/jackc/pgx/stdlib"
_ "github.com/lib/pq"
)
// Test for renamed types
@ -24,6 +27,25 @@ type _float32Slice []float32
type _float64Slice []float64
type _byteSlice []byte
func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB {
var sqlDriverName string
switch driverName {
case "github.com/lib/pq":
sqlDriverName = "postgres"
case "github.com/jackc/pgx/stdlib":
sqlDriverName = "pgx"
default:
t.Fatalf("Unknown driver %v", driverName)
}
db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL"))
if err != nil {
t.Fatal(err)
}
return db
}
func mustConnectPgx(t testing.TB) *pgx.Conn {
config, err := pgx.ParseURI(os.Getenv("DATABASE_URL"))
if err != nil {
@ -93,6 +115,13 @@ func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface
}
func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
}
}
func testPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
conn := mustConnectPgx(t)
defer mustClose(t, conn)
@ -114,7 +143,7 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int
ps.FieldDescriptions[0].FormatCode = fc.formatCode
vEncoder := forceEncoder(v, fc.formatCode)
if vEncoder == nil {
t.Logf("%#v does not implement %v", v, fc.name)
t.Logf("Skipping: %#v does not implement %v", v, fc.name)
continue
}
// Derefence value if it is a pointer
@ -136,3 +165,33 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int
}
}
}
func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
conn := mustConnectDatabaseSQL(t, driverName)
defer mustClose(t, conn)
ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName))
if err != nil {
t.Fatal(err)
}
for i, v := range values {
// Derefence value if it is a pointer
derefV := v
refVal := reflect.ValueOf(v)
if refVal.Kind() == reflect.Ptr {
derefV = refVal.Elem().Interface()
}
result := reflect.New(reflect.TypeOf(derefV))
err := ps.QueryRow(v).Scan(result.Interface())
if err != nil {
t.Errorf("%v %d: %v", driverName, i, err)
}
if !eqFunc(result.Elem().Interface(), derefV) {
t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
}
}
}

View File

@ -1,9 +1,11 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
"math"
"strconv"
"github.com/jackc/pgx/pgio"
@ -21,6 +23,14 @@ type pguint32 struct {
// types do.
func (dst *pguint32) Set(src interface{}) error {
switch value := src.(type) {
case int64:
if value < 0 {
return fmt.Errorf("%d is less than minimum value for pguint32", value)
}
if value > math.MaxUint32 {
return fmt.Errorf("%d is greater than maximum value for pguint32", value)
}
*dst = pguint32{Uint: uint32(value), Status: Present}
case uint32:
*dst = pguint32{Uint: value, Status: Present}
default:
@ -116,3 +126,38 @@ func (src pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err := pgio.WriteUint32(w, src.Uint)
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *pguint32) Scan(src interface{}) error {
if src == nil {
*dst = pguint32{Status: Null}
return nil
}
switch src := src.(type) {
case uint32:
*dst = pguint32{Uint: src, Status: Present}
return nil
case int64:
*dst = pguint32{Uint: uint32(src), Status: Present}
return nil
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src pguint32) Value() (driver.Value, error) {
switch src.Status {
case Present:
return int64(src.Uint), nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -17,13 +17,20 @@ import (
// standard type char.
//
// Not all possible values of QChar are representable in the text format.
// Therefore, QChar does not implement TextEncoder and TextDecoder.
// Therefore, QChar does not implement TextEncoder and TextDecoder. In
// addition, database/sql Scanner and database/sql/driver Value are not
// implemented.
type QChar struct {
Int int8
Status Status
}
func (dst *QChar) Set(src interface{}) error {
if src == nil {
*dst = QChar{Status: Null}
return nil
}
switch value := src.(type) {
case int8:
*dst = QChar{Int: value, Status: Present}

View File

@ -9,13 +9,15 @@ import (
)
func TestQCharTranscode(t *testing.T) {
testSuccessfulTranscode(t, `"char"`, []interface{}{
testPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{
pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present},
pgtype.QChar{Int: -1, Status: pgtype.Present},
pgtype.QChar{Int: 0, Status: pgtype.Present},
pgtype.QChar{Int: 1, Status: pgtype.Present},
pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present},
pgtype.QChar{Int: 0, Status: pgtype.Null},
}, func(a, b interface{}) bool {
return reflect.DeepEqual(a, b)
})
}

View File

@ -16,6 +16,11 @@ type Record struct {
}
func (dst *Record) Set(src interface{}) error {
if src == nil {
*dst = Record{Status: Null}
return nil
}
switch value := src.(type) {
case []Value:
*dst = Record{Fields: value, Status: Present}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"fmt"
"io"
)
@ -11,6 +12,11 @@ type Text struct {
}
func (dst *Text) Set(src interface{}) error {
if src == nil {
*dst = Text{Status: Null}
return nil
}
switch value := src.(type) {
case string:
*dst = Text{String: value, Status: Present}
@ -20,6 +26,12 @@ func (dst *Text) Set(src interface{}) error {
} else {
*dst = Text{String: *value, Status: Present}
}
case []byte:
if value == nil {
*dst = Text{Status: Null}
} else {
*dst = Text{String: string(value), Status: Present}
}
default:
if originalSrc, ok := underlyingStringType(src); ok {
return dst.Set(originalSrc)
@ -93,3 +105,32 @@ func (src Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) {
func (src Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return src.EncodeText(ci, w)
}
// Scan implements the database/sql Scanner interface.
func (dst *Text) Scan(src interface{}) error {
if src == nil {
*dst = Text{Status: Null}
return nil
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Text) Value() (driver.Value, error) {
switch src.Status {
case Present:
return src.String, nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -296,3 +297,33 @@ func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *TextArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *TextArray) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -121,3 +122,25 @@ func (src Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err = pgio.WriteUint16(w, src.OffsetNumber)
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Tid) Scan(src interface{}) error {
if src == nil {
*dst = Tid{Status: Null}
return nil
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Tid) Value() (driver.Value, error) {
return encodeValueText(src)
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -17,14 +18,19 @@ const pgTimestampFormat = "2006-01-02 15:04:05.999999999"
// recommended to use timestamptz whenever possible. Timestamp methods either
// convert to UTC or return an error on non-UTC times.
type Timestamp struct {
Time time.Time // Time must always be in UTC.
Status Status
InfinityModifier
Time time.Time // Time must always be in UTC.
Status Status
InfinityModifier InfinityModifier
}
// Set converts src into a Timestamp and stores in dst. If src is a
// time.Time in a non-UTC time zone, the time zone is discarded.
func (dst *Timestamp) Set(src interface{}) error {
if src == nil {
*dst = Timestamp{Status: Null}
return nil
}
switch value := src.(type) {
case time.Time:
*dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present}
@ -183,3 +189,38 @@ func (src Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err := pgio.WriteInt64(w, microsecSinceY2K)
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Timestamp) Scan(src interface{}) error {
if src == nil {
*dst = Timestamp{Status: Null}
return nil
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
case time.Time:
*dst = Timestamp{Time: src, Status: Present}
return nil
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Timestamp) Value() (driver.Value, error) {
switch src.Status {
case Present:
if src.InfinityModifier != None {
return src.InfinityModifier.String(), nil
}
return src.Time, nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -297,3 +298,33 @@ func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error)
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *TimestampArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *TimestampArray) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -20,12 +21,17 @@ const (
)
type Timestamptz struct {
Time time.Time
Status Status
InfinityModifier
Time time.Time
Status Status
InfinityModifier InfinityModifier
}
func (dst *Timestamptz) Set(src interface{}) error {
if src == nil {
*dst = Timestamptz{Status: Null}
return nil
}
switch value := src.(type) {
case time.Time:
*dst = Timestamptz{Time: value, Status: Present}
@ -179,3 +185,38 @@ func (src Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
_, err := pgio.WriteInt64(w, microsecSinceY2K)
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *Timestamptz) Scan(src interface{}) error {
if src == nil {
*dst = Timestamptz{Status: Null}
return nil
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
case time.Time:
*dst = Timestamptz{Time: src, Status: Present}
return nil
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Timestamptz) Value() (driver.Value, error) {
switch src.Status {
case Present:
if src.InfinityModifier != None {
return src.InfinityModifier.String(), nil
}
return src.Time, nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -297,3 +298,33 @@ func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, erro
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *TimestamptzArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *TimestamptzArray) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -299,3 +299,33 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool
return false, err
}
<% end %>
// Scan implements the database/sql Scanner interface.
func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *<%= pgtype_array_type %>) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -1,5 +1,7 @@
package pgtype
import "database/sql/driver"
// Unknown represents the PostgreSQL unknown type. It is either a string literal
// or NULL. It is used when PostgreSQL does not know the type of a value. In
// general, this will only be used in pgx when selecting a null value without
@ -30,3 +32,13 @@ func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error {
func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error {
return (*Text)(dst).DecodeBinary(ci, src)
}
// Scan implements the database/sql Scanner interface.
func (dst *Unknown) Scan(src interface{}) error {
return (*Text)(dst).Scan(src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Unknown) Value() (driver.Value, error) {
return (Text)(src).Value()
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"io"
)
@ -38,3 +39,13 @@ func (src Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) {
func (src Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return (Text)(src).EncodeBinary(ci, w)
}
// Scan implements the database/sql Scanner interface.
func (dst *Varchar) Scan(src interface{}) error {
return (*Text)(dst).Scan(src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Varchar) Value() (driver.Value, error) {
return (Text)(src).Value()
}

View File

@ -2,6 +2,7 @@ package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
@ -296,3 +297,33 @@ func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return false, err
}
// Scan implements the database/sql Scanner interface.
func (dst *VarcharArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *VarcharArray) Value() (driver.Value, error) {
buf := &bytes.Buffer{}
null, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql/driver"
"io"
)
@ -52,3 +53,13 @@ func (src Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) {
func (src Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return (pguint32)(src).EncodeBinary(ci, w)
}
// Scan implements the database/sql Scanner interface.
func (dst *Xid) Scan(src interface{}) error {
return (*pguint32)(dst).Scan(src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Xid) Value() (driver.Value, error) {
return (pguint32)(src).Value()
}

117
query.go
View File

@ -208,47 +208,6 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
if err != nil {
rows.Fatal(scanArgError{col: i, err: err})
}
} else if s, ok := d.(sql.Scanner); ok {
var sqlSrc interface{}
if 0 <= vr.Len() {
if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok {
value := dt.Value
switch vr.Type().FormatCode {
case TextFormatCode:
decoder := value.(pgtype.TextDecoder)
if decoder == nil {
decoder = &pgtype.GenericText{}
}
err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes())
if err != nil {
rows.Fatal(err)
}
case BinaryFormatCode:
decoder := value.(pgtype.BinaryDecoder)
if decoder == nil {
decoder = &pgtype.GenericBinary{}
}
err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes())
if err != nil {
rows.Fatal(err)
}
default:
rows.Fatal(errors.New("Unknown format code"))
}
sqlSrc, err = pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value)
if err != nil {
rows.Fatal(err)
}
} else {
rows.Fatal(errors.New("Unknown type"))
}
}
err = s.Scan(sqlSrc)
if err != nil {
rows.Fatal(scanArgError{col: i, err: err})
}
} else {
if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok {
value := dt.Value
@ -276,7 +235,16 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
}
if vr.Err() == nil {
if err := value.AssignTo(d); err != nil {
if scanner, ok := d.(sql.Scanner); ok {
sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value)
if err != nil {
rows.Fatal(err)
}
err = scanner.Scan(sqlSrc)
if err != nil {
rows.Fatal(scanArgError{col: i, err: err})
}
} else if err := value.AssignTo(d); err != nil {
vr.Fatal(err)
}
}
@ -355,71 +323,6 @@ func (rows *Rows) Values() ([]interface{}, error) {
return values, rows.Err()
}
// ValuesForStdlib is a temporary function to keep all systems operational
// while refactoring. Do not use.
func (rows *Rows) ValuesForStdlib() ([]interface{}, error) {
if rows.closed {
return nil, errors.New("rows is closed")
}
values := make([]interface{}, 0, len(rows.fields))
for range rows.fields {
vr, _ := rows.nextColumn()
if vr.Len() == -1 {
values = append(values, nil)
continue
}
if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok {
value := dt.Value
switch vr.Type().FormatCode {
case TextFormatCode:
decoder := value.(pgtype.TextDecoder)
if decoder == nil {
decoder = &pgtype.GenericText{}
}
err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes())
if err != nil {
rows.Fatal(err)
}
case BinaryFormatCode:
decoder := value.(pgtype.BinaryDecoder)
if decoder == nil {
decoder = &pgtype.GenericBinary{}
}
err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes())
if err != nil {
rows.Fatal(err)
}
default:
rows.Fatal(errors.New("Unknown format code"))
}
sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value)
if err != nil {
rows.Fatal(err)
}
values = append(values, sqlSrc)
} else {
rows.Fatal(errors.New("Unknown type"))
}
if vr.Err() != nil {
rows.Fatal(vr.Err())
}
if rows.Err() != nil {
return nil, rows.Err()
}
}
return values, rows.Err()
}
// AfterClose adds f to a LILO queue of functions that will be called when
// rows is closed.
func (rows *Rows) AfterClose(f func(*Rows)) {

View File

@ -704,30 +704,6 @@ func TestQueryRowCoreByteSlice(t *testing.T) {
}
}
func TestQueryRowByteSliceArgument(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
sql := "select $1::int4"
queryArg := []byte{14, 63, 53, 49}
expected := int32(239023409)
var actual int32
err := conn.QueryRow(sql, queryArg).Scan(&actual)
if err != nil {
t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
}
if expected != actual {
t.Errorf("Expected %v, got %v (sql -> %v)", expected, actual, sql)
}
ensureConnValid(t, conn)
}
func TestQueryRowUnknownType(t *testing.T) {
t.Parallel()

View File

@ -68,14 +68,17 @@ func init() {
databaseSqlOids = make(map[pgtype.Oid]bool)
databaseSqlOids[pgtype.BoolOid] = true
databaseSqlOids[pgtype.ByteaOid] = true
databaseSqlOids[pgtype.CidOid] = true
databaseSqlOids[pgtype.DateOid] = true
databaseSqlOids[pgtype.Float4Oid] = true
databaseSqlOids[pgtype.Float8Oid] = true
databaseSqlOids[pgtype.Int2Oid] = true
databaseSqlOids[pgtype.Int4Oid] = true
databaseSqlOids[pgtype.Int8Oid] = true
databaseSqlOids[pgtype.Float4Oid] = true
databaseSqlOids[pgtype.Float8Oid] = true
databaseSqlOids[pgtype.DateOid] = true
databaseSqlOids[pgtype.TimestamptzOid] = true
databaseSqlOids[pgtype.OidOid] = true
databaseSqlOids[pgtype.TimestampOid] = true
databaseSqlOids[pgtype.TimestamptzOid] = true
databaseSqlOids[pgtype.XidOid] = true
}
type Driver struct {
@ -292,9 +295,9 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
return s.conn.queryPrepared(s.ps.Name, argsV)
}
// TODO - rename to avoid alloc
type Rows struct {
rows *pgx.Rows
rows *pgx.Rows
values []interface{}
}
func (r *Rows) Columns() []string {
@ -312,6 +315,42 @@ func (r *Rows) Close() error {
}
func (r *Rows) Next(dest []driver.Value) error {
if r.values == nil {
r.values = make([]interface{}, len(r.rows.FieldDescriptions()))
for i, fd := range r.rows.FieldDescriptions() {
switch fd.DataType {
case pgtype.BoolOid:
r.values[i] = &pgtype.Bool{}
case pgtype.ByteaOid:
r.values[i] = &pgtype.Bytea{}
case pgtype.CidOid:
r.values[i] = &pgtype.Cid{}
case pgtype.DateOid:
r.values[i] = &pgtype.Date{}
case pgtype.Float4Oid:
r.values[i] = &pgtype.Float4{}
case pgtype.Float8Oid:
r.values[i] = &pgtype.Float8{}
case pgtype.Int2Oid:
r.values[i] = &pgtype.Int2{}
case pgtype.Int4Oid:
r.values[i] = &pgtype.Int4{}
case pgtype.Int8Oid:
r.values[i] = &pgtype.Int8{}
case pgtype.OidOid:
r.values[i] = &pgtype.OidValue{}
case pgtype.TimestampOid:
r.values[i] = &pgtype.Timestamp{}
case pgtype.TimestamptzOid:
r.values[i] = &pgtype.Timestamptz{}
case pgtype.XidOid:
r.values[i] = &pgtype.Xid{}
default:
r.values[i] = &pgtype.GenericText{}
}
}
}
more := r.rows.Next()
if !more {
if r.rows.Err() == nil {
@ -321,19 +360,16 @@ func (r *Rows) Next(dest []driver.Value) error {
}
}
values, err := r.rows.ValuesForStdlib()
err := r.rows.Scan(r.values...)
if err != nil {
return err
}
if len(dest) < len(values) {
fmt.Printf("%d: %#v\n", len(dest), dest)
fmt.Printf("%d: %#v\n", len(values), values)
return errors.New("expected more values than were received")
}
for i, v := range values {
dest[i] = driver.Value(v)
for i, v := range r.values {
dest[i], err = v.(driver.Valuer).Value()
if err != nil {
return err
}
}
return nil

View File

@ -65,10 +65,6 @@ func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interfa
wbuf.WriteInt32(int32(len(arg)))
wbuf.WriteBytes([]byte(arg))
return nil
case []byte:
wbuf.WriteInt32(int32(len(arg)))
wbuf.WriteBytes(arg)
return nil
}
refVal := reflect.ValueOf(arg)