diff --git a/aclitem.go b/aclitem.go index e8386ae7..77e385e6 100644 --- a/aclitem.go +++ b/aclitem.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" ) @@ -93,3 +94,32 @@ func (src Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { _, err := io.WriteString(w, src.String) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Aclitem) Scan(src interface{}) error { + if src == nil { + *dst = Aclitem{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Aclitem) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.String, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/aclitem_array.go b/aclitem_array.go index 1c97e74f..20a7636a 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "fmt" "io" @@ -194,3 +195,33 @@ func (src *AclitemArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } + +// Scan implements the database/sql Scanner interface. +func (dst *AclitemArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *AclitemArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/bool.go b/bool.go index 608a6f95..736d19cf 100644 --- a/bool.go +++ b/bool.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" "strconv" @@ -126,3 +127,35 @@ func (src Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := w.Write(buf) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Bool) Scan(src interface{}) error { + if src == nil { + *dst = Bool{Status: Null} + return nil + } + + switch src := src.(type) { + case bool: + *dst = Bool{Bool: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bool) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bool, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/bool_array.go b/bool_array.go index cdfe9685..4705d734 100644 --- a/bool_array.go +++ b/bool_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *BoolArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *BoolArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/bytea.go b/bytea.go index 00bed8e8..9f0266e7 100644 --- a/bytea.go +++ b/bytea.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/hex" "fmt" "io" @@ -12,6 +13,11 @@ type Bytea struct { } func (dst *Bytea) Set(src interface{}) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + switch value := src.(type) { case []byte: if value != nil { @@ -124,3 +130,35 @@ func (src Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := w.Write(src.Bytes) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Bytea) Scan(src interface{}) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + buf := make([]byte, len(src)) + copy(buf, src) + *dst = Bytea{Bytes: buf, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bytea) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bytes, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/bytea_array.go b/bytea_array.go index 175ca2f6..268364c1 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *ByteaArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *ByteaArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/cid.go b/cid.go index d86e8063..63ba6a2f 100644 --- a/cid.go +++ b/cid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -49,3 +50,13 @@ func (src Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Cid) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Cid) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/cidr_array.go b/cidr_array.go index 49a2728b..6643bb47 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -325,3 +326,33 @@ func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *CidrArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *CidrArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/database_sql.go b/database_sql.go index 969d6542..2ddd842d 100644 --- a/database_sql.go +++ b/database_sql.go @@ -2,47 +2,13 @@ package pgtype import ( "bytes" + "database/sql/driver" "errors" ) func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { - switch src := src.(type) { - case *Bool: - return src.Bool, nil - case *Bytea: - return src.Bytes, nil - case *Date: - if src.InfinityModifier == None { - return src.Time, nil - } - case *Float4: - return float64(src.Float), nil - case *Float8: - return src.Float, nil - case *GenericBinary: - return src.Bytes, nil - case *GenericText: - return src.String, nil - case *Int2: - return int64(src.Int), nil - case *Int4: - return int64(src.Int), nil - case *Int8: - return int64(src.Int), nil - case *Text: - return src.String, nil - case *Timestamp: - if src.InfinityModifier == None { - return src.Time, nil - } - case *Timestamptz: - if src.InfinityModifier == None { - return src.Time, nil - } - case *Unknown: - return src.String, nil - case *Varchar: - return src.String, nil + if valuer, ok := src.(driver.Valuer); ok { + return valuer.Value() } buf := &bytes.Buffer{} @@ -64,3 +30,15 @@ func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { return nil, errors.New("cannot convert to database/sql compatible value") } + +func encodeValueText(src TextEncoder) (interface{}, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + return buf.String(), err +} diff --git a/date.go b/date.go index ab854eb2..7dd2c4f0 100644 --- a/date.go +++ b/date.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -10,9 +11,9 @@ import ( ) type Date struct { - Time time.Time - Status Status - InfinityModifier + Time time.Time + Status Status + InfinityModifier InfinityModifier } const ( @@ -21,6 +22,11 @@ const ( ) func (dst *Date) Set(src interface{}) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + switch value := src.(type) { case time.Time: *dst = Date{Time: value, Status: Present} @@ -167,3 +173,38 @@ func (src Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt32(w, daysSinceDateEpoch) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Date) Scan(src interface{}) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + case time.Time: + *dst = Date{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Date) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/date_array.go b/date_array.go index bf791677..f58de011 100644 --- a/date_array.go +++ b/date_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -297,3 +298,33 @@ func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *DateArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *DateArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/date_test.go b/date_test.go index cfc3dd70..1832b5b4 100644 --- a/date_test.go +++ b/date_test.go @@ -9,7 +9,7 @@ import ( ) func TestDateTranscode(t *testing.T) { - testSuccessfulTranscode(t, "date", []interface{}{ + testSuccessfulTranscodeEqFunc(t, "date", []interface{}{ pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, @@ -19,6 +19,11 @@ func TestDateTranscode(t *testing.T) { pgtype.Date{Status: pgtype.Null}, pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Date) + bt := b.(pgtype.Date) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier }) } diff --git a/float4.go b/float4.go index 94b7b7a1..e92149a6 100644 --- a/float4.go +++ b/float4.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Float4 struct { } func (dst *Float4) Set(src interface{}) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + switch value := src.(type) { case float32: *dst = Float4{Float: value, Status: Present} @@ -156,3 +162,35 @@ func (src Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt32(w, int32(math.Float32bits(src.Float))) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float4) Scan(src interface{}) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Float4{Float: float32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float4) Value() (driver.Value, error) { + switch src.Status { + case Present: + return float64(src.Float), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/float4_array.go b/float4_array.go index b4d05c55..b9ee4b9e 100644 --- a/float4_array.go +++ b/float4_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float4Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Float4Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/float8.go b/float8.go index dd2d592d..4d094757 100644 --- a/float8.go +++ b/float8.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Float8 struct { } func (dst *Float8) Set(src interface{}) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + switch value := src.(type) { case float32: *dst = Float8{Float: float64(value), Status: Present} @@ -146,3 +152,35 @@ func (src Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, int64(math.Float64bits(src.Float))) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float8) Scan(src interface{}) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Float8{Float: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float8) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Float, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/float8_array.go b/float8_array.go index e000807e..d49f18a7 100644 --- a/float8_array.go +++ b/float8_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float8Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Float8Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/generic_binary.go b/generic_binary.go index aa28bb62..f834bfb2 100644 --- a/generic_binary.go +++ b/generic_binary.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -27,3 +28,13 @@ func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { func (src GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (Bytea)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *GenericBinary) Scan(src interface{}) error { + return (*Bytea)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src GenericBinary) Value() (driver.Value, error) { + return (Bytea)(src).Value() +} diff --git a/generic_text.go b/generic_text.go index bd75e0d0..053ec504 100644 --- a/generic_text.go +++ b/generic_text.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -27,3 +28,13 @@ func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { func (src GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return (Text)(src).EncodeText(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *GenericText) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src GenericText) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/hstore.go b/hstore.go index 8dc5b4d8..b8b0c6f3 100644 --- a/hstore.go +++ b/hstore.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "errors" "fmt" @@ -21,6 +22,11 @@ type Hstore struct { } func (dst *Hstore) Set(src interface{}) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + switch value := src.(type) { case map[string]string: m := make(map[string]Text, len(value)) @@ -437,3 +443,25 @@ func parseHstore(s string) (k []string, v []Text, err error) { v = values return } + +// Scan implements the database/sql Scanner interface. +func (dst *Hstore) Scan(src interface{}) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Hstore) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/hstore_array.go b/hstore_array.go index 9bd0ed3b..097fec7b 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *HstoreArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *HstoreArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *HstoreArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/inet.go b/inet.go index 13764814..0ca3ee7a 100644 --- a/inet.go +++ b/inet.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" "net" @@ -23,6 +24,11 @@ type Inet struct { } func (dst *Inet) Set(src interface{}) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + switch value := src.(type) { case net.IPNet: *dst = Inet{IPNet: &value, Status: Present} @@ -189,3 +195,25 @@ func (src Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := w.Write(src.IPNet.IP) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Inet) Scan(src interface{}) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Inet) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/inet_array.go b/inet_array.go index 1988a145..a108d75b 100644 --- a/inet_array.go +++ b/inet_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -325,3 +326,33 @@ func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *InetArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *InetArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/int2.go b/int2.go index 6996cd4f..3bcac63c 100644 --- a/int2.go +++ b/int2.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Int2 struct { } func (dst *Int2) Set(src interface{}) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = Int2{Int: int16(value), Status: Present} @@ -151,3 +157,41 @@ func (src Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt16(w, src.Int) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int2) Scan(src interface{}) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + if src < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", src) + } + if src > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", src) + } + *dst = Int2{Int: int16(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int2) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/int2_array.go b/int2_array.go index 531e7dd6..bddb5ac2 100644 --- a/int2_array.go +++ b/int2_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -324,3 +325,33 @@ func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int2Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int2Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/int4.go b/int4.go index 62ee366f..5069dab4 100644 --- a/int4.go +++ b/int4.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Int4 struct { } func (dst *Int4) Set(src interface{}) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = Int4{Int: int32(value), Status: Present} @@ -68,7 +74,7 @@ func (dst *Int4) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int8", value) + return fmt.Errorf("cannot convert %v to Int4", value) } return nil @@ -142,3 +148,41 @@ func (src Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt32(w, src.Int) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int4) Scan(src interface{}) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + if src < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", src) + } + if src > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", src) + } + *dst = Int4{Int: int32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/int4_array.go b/int4_array.go index 3617050f..d5c8f911 100644 --- a/int4_array.go +++ b/int4_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -324,3 +325,33 @@ func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int4Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int4Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/int8.go b/int8.go index 7ed54f8e..cf701dc6 100644 --- a/int8.go +++ b/int8.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Int8 struct { } func (dst *Int8) Set(src interface{}) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = Int8{Int: int64(value), Status: Present} @@ -134,3 +140,35 @@ func (src Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, src.Int) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int8) Scan(src interface{}) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + *dst = Int8{Int: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/int8_array.go b/int8_array.go index 4f04b660..ae2521fa 100644 --- a/int8_array.go +++ b/int8_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -324,3 +325,33 @@ func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int8Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int8Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/json.go b/json.go index bfffae14..05d965ca 100644 --- a/json.go +++ b/json.go @@ -1,7 +1,9 @@ package pgtype import ( + "database/sql/driver" "encoding/json" + "fmt" "io" ) @@ -11,6 +13,11 @@ type Json struct { } func (dst *Json) Set(src interface{}) error { + if src == nil { + *dst = Json{Status: Null} + return nil + } + switch value := src.(type) { case string: *dst = Json{Bytes: []byte(value), Status: Present} @@ -116,3 +123,32 @@ func (src Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return src.EncodeText(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Json) Scan(src interface{}) error { + if src == nil { + *dst = Json{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Json) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bytes, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/jsonb.go b/jsonb.go index e44f3c41..f47476d6 100644 --- a/jsonb.go +++ b/jsonb.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" ) @@ -66,3 +67,13 @@ func (src Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err = w.Write(src.Bytes) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Jsonb) Scan(src interface{}) error { + return (*Json)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Jsonb) Value() (driver.Value, error) { + return (Json)(src).Value() +} diff --git a/name.go b/name.go index 9ebf63d3..cc4ae23b 100644 --- a/name.go +++ b/name.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -46,3 +47,13 @@ func (src Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (Text)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Name) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Name) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/oid.go b/oid.go index 3edd7f3c..339dee0f 100644 --- a/oid.go +++ b/oid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -55,3 +56,27 @@ func (src Oid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteUint32(w, uint32(src)) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Oid) Scan(src interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", src) + } + + switch src := src.(type) { + case int64: + *dst = Oid(src) + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Oid) Value() (driver.Value, error) { + return int64(src), nil +} diff --git a/oid_value.go b/oid_value.go index 1bce6e11..cb03802e 100644 --- a/oid_value.go +++ b/oid_value.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -43,3 +44,13 @@ func (src OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *OidValue) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src OidValue) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/pgtype.go b/pgtype.go index 674c0db7..7e6633d9 100644 --- a/pgtype.go +++ b/pgtype.go @@ -67,6 +67,19 @@ const ( NegativeInfinity InfinityModifier = -Infinity ) +func (im InfinityModifier) String() string { + switch im { + case None: + return "none" + case Infinity: + return "infinity" + case NegativeInfinity: + return "-infinity" + default: + return "invalid" + } +} + type Value interface { // Set converts and assigns src to itself. Set(src interface{}) error diff --git a/pgtype_test.go b/pgtype_test.go index 391fed57..16cabfd1 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "database/sql" "fmt" "io" "net" @@ -10,6 +11,8 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" + _ "github.com/jackc/pgx/stdlib" + _ "github.com/lib/pq" ) // Test for renamed types @@ -24,6 +27,25 @@ type _float32Slice []float32 type _float64Slice []float64 type _byteSlice []byte +func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { + var sqlDriverName string + switch driverName { + case "github.com/lib/pq": + sqlDriverName = "postgres" + case "github.com/jackc/pgx/stdlib": + sqlDriverName = "pgx" + default: + t.Fatalf("Unknown driver %v", driverName) + } + + db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL")) + if err != nil { + t.Fatal(err) + } + + return db +} + func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) if err != nil { @@ -93,6 +115,13 @@ func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface } func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } +} + +func testPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { conn := mustConnectPgx(t) defer mustClose(t, conn) @@ -114,7 +143,7 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int ps.FieldDescriptions[0].FormatCode = fc.formatCode vEncoder := forceEncoder(v, fc.formatCode) if vEncoder == nil { - t.Logf("%#v does not implement %v", v, fc.name) + t.Logf("Skipping: %#v does not implement %v", v, fc.name) continue } // Derefence value if it is a pointer @@ -136,3 +165,33 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int } } } + +func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectDatabaseSQL(t, driverName) + defer mustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := ps.QueryRow(v).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } + +} diff --git a/pguint32.go b/pguint32.go index 3f9e7bf7..7138a409 100644 --- a/pguint32.go +++ b/pguint32.go @@ -1,9 +1,11 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" + "math" "strconv" "github.com/jackc/pgx/pgio" @@ -21,6 +23,14 @@ type pguint32 struct { // types do. func (dst *pguint32) Set(src interface{}) error { switch value := src.(type) { + case int64: + if value < 0 { + return fmt.Errorf("%d is less than minimum value for pguint32", value) + } + if value > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for pguint32", value) + } + *dst = pguint32{Uint: uint32(value), Status: Present} case uint32: *dst = pguint32{Uint: value, Status: Present} default: @@ -116,3 +126,38 @@ func (src pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteUint32(w, src.Uint) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *pguint32) Scan(src interface{}) error { + if src == nil { + *dst = pguint32{Status: Null} + return nil + } + + switch src := src.(type) { + case uint32: + *dst = pguint32{Uint: src, Status: Present} + return nil + case int64: + *dst = pguint32{Uint: uint32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src pguint32) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Uint), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/qchar.go b/qchar.go index 4b32ee4a..49475bd3 100644 --- a/qchar.go +++ b/qchar.go @@ -17,13 +17,20 @@ import ( // standard type char. // // Not all possible values of QChar are representable in the text format. -// Therefore, QChar does not implement TextEncoder and TextDecoder. +// Therefore, QChar does not implement TextEncoder and TextDecoder. In +// addition, database/sql Scanner and database/sql/driver Value are not +// implemented. type QChar struct { Int int8 Status Status } func (dst *QChar) Set(src interface{}) error { + if src == nil { + *dst = QChar{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = QChar{Int: value, Status: Present} diff --git a/qchar_test.go b/qchar_test.go index a1b6d22e..afac5016 100644 --- a/qchar_test.go +++ b/qchar_test.go @@ -9,13 +9,15 @@ import ( ) func TestQCharTranscode(t *testing.T) { - testSuccessfulTranscode(t, `"char"`, []interface{}{ + testPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, pgtype.QChar{Int: -1, Status: pgtype.Present}, pgtype.QChar{Int: 0, Status: pgtype.Present}, pgtype.QChar{Int: 1, Status: pgtype.Present}, pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, pgtype.QChar{Int: 0, Status: pgtype.Null}, + }, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) }) } diff --git a/record.go b/record.go index 89e081ca..9c42c907 100644 --- a/record.go +++ b/record.go @@ -16,6 +16,11 @@ type Record struct { } func (dst *Record) Set(src interface{}) error { + if src == nil { + *dst = Record{Status: Null} + return nil + } + switch value := src.(type) { case []Value: *dst = Record{Fields: value, Status: Present} diff --git a/text.go b/text.go index dbc9362b..482c9023 100644 --- a/text.go +++ b/text.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" ) @@ -11,6 +12,11 @@ type Text struct { } func (dst *Text) Set(src interface{}) error { + if src == nil { + *dst = Text{Status: Null} + return nil + } + switch value := src.(type) { case string: *dst = Text{String: value, Status: Present} @@ -20,6 +26,12 @@ func (dst *Text) Set(src interface{}) error { } else { *dst = Text{String: *value, Status: Present} } + case []byte: + if value == nil { + *dst = Text{Status: Null} + } else { + *dst = Text{String: string(value), Status: Present} + } default: if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) @@ -93,3 +105,32 @@ func (src Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return src.EncodeText(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Text) Scan(src interface{}) error { + if src == nil { + *dst = Text{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Text) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.String, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/text_array.go b/text_array.go index 6e8ead26..64728048 100644 --- a/text_array.go +++ b/text_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *TextArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *TextArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/tid.go b/tid.go index b91711d3..b363c1f9 100644 --- a/tid.go +++ b/tid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -121,3 +122,25 @@ func (src Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err = pgio.WriteUint16(w, src.OffsetNumber) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Tid) Scan(src interface{}) error { + if src == nil { + *dst = Tid{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Tid) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/timestamp.go b/timestamp.go index 4b42f3cf..78c6355e 100644 --- a/timestamp.go +++ b/timestamp.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -17,14 +18,19 @@ const pgTimestampFormat = "2006-01-02 15:04:05.999999999" // recommended to use timestamptz whenever possible. Timestamp methods either // convert to UTC or return an error on non-UTC times. type Timestamp struct { - Time time.Time // Time must always be in UTC. - Status Status - InfinityModifier + Time time.Time // Time must always be in UTC. + Status Status + InfinityModifier InfinityModifier } // Set converts src into a Timestamp and stores in dst. If src is a // time.Time in a non-UTC time zone, the time zone is discarded. func (dst *Timestamp) Set(src interface{}) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + switch value := src.(type) { case time.Time: *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} @@ -183,3 +189,38 @@ func (src Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, microsecSinceY2K) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamp) Scan(src interface{}) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + case time.Time: + *dst = Timestamp{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamp) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/timestamp_array.go b/timestamp_array.go index 6a6950c7..5d08f9cc 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -297,3 +298,33 @@ func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *TimestampArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *TimestampArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/timestamptz.go b/timestamptz.go index ba849ac8..50370335 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -20,12 +21,17 @@ const ( ) type Timestamptz struct { - Time time.Time - Status Status - InfinityModifier + Time time.Time + Status Status + InfinityModifier InfinityModifier } func (dst *Timestamptz) Set(src interface{}) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + switch value := src.(type) { case time.Time: *dst = Timestamptz{Time: value, Status: Present} @@ -179,3 +185,38 @@ func (src Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, microsecSinceY2K) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamptz) Scan(src interface{}) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + case time.Time: + *dst = Timestamptz{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamptz) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/timestamptz_array.go b/timestamptz_array.go index 347d0b8b..107be06a 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -297,3 +298,33 @@ func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, erro return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *TimestamptzArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *TimestamptzArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/typed_array.go.erb b/typed_array.go.erb index 0e5725ce..4b8f1a28 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -299,3 +299,33 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool return false, err } <% end %> + +// Scan implements the database/sql Scanner interface. +func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *<%= pgtype_array_type %>) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/unknown.go b/unknown.go index b951ad99..2dca0f87 100644 --- a/unknown.go +++ b/unknown.go @@ -1,5 +1,7 @@ package pgtype +import "database/sql/driver" + // Unknown represents the PostgreSQL unknown type. It is either a string literal // or NULL. It is used when PostgreSQL does not know the type of a value. In // general, this will only be used in pgx when selecting a null value without @@ -30,3 +32,13 @@ func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error { func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } + +// Scan implements the database/sql Scanner interface. +func (dst *Unknown) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Unknown) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/varchar.go b/varchar.go index adda6c49..f25ada5d 100644 --- a/varchar.go +++ b/varchar.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -38,3 +39,13 @@ func (src Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (Text)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Varchar) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Varchar) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/varchar_array.go b/varchar_array.go index e1dd3910..2712b4d2 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *VarcharArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *VarcharArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/xid.go b/xid.go index c76548a4..0a7fc7d9 100644 --- a/xid.go +++ b/xid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -52,3 +53,13 @@ func (src Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Xid) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Xid) Value() (driver.Value, error) { + return (pguint32)(src).Value() +}