From 5a2feadf1128e1a3217691013b7d98ad0eb324d7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Apr 2017 17:53:32 -0500 Subject: [PATCH] Add pgtype.Point --- example_custom_type_test.go | 52 +++++++++----- pgtype/pgtype.go | 1 + pgtype/point.go | 139 ++++++++++++++++++++++++++++++++++++ pgtype/point_test.go | 15 ++++ query_test.go | 15 +++- v3.md | 1 - 6 files changed, 202 insertions(+), 21 deletions(-) create mode 100644 pgtype/point.go create mode 100644 pgtype/point_test.go diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 1c21c7e6..647b97e6 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -2,7 +2,6 @@ package pgx_test import ( "fmt" - "io" "regexp" "strconv" @@ -18,6 +17,25 @@ type Point struct { Status pgtype.Status } +func (dst *Point) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Point", src) +} + +func (dst *Point) Get() interface{} { + switch dst.Status { + case pgtype.Present: + return dst + case pgtype.Null: + return nil + default: + return dst.Status + } +} + +func (src *Point) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { if src == nil { *dst = Point{Status: pgtype.Null} @@ -44,23 +62,12 @@ func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return nil } -func (src Point) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { - switch src.Status { - case pgtype.Null: - return true, nil - case pgtype.Undefined: - return false, fmt.Errorf("undefined") +func (src *Point) String() string { + if src.Status == pgtype.Null { + return "null point" } - _, err := io.WriteString(w, fmt.Sprintf("point(%v,%v)", src.X, src.Y)) - return false, err -} - -func (p Point) String() string { - if p.Status == pgtype.Present { - return fmt.Sprintf("%v, %v", p.X, p.Y) - } - return "null point" + return fmt.Sprintf("%.1f, %.1f", src.X, src.Y) } func Example_CustomType() { @@ -70,15 +77,22 @@ func Example_CustomType() { return } - var p Point - err = conn.QueryRow("select null::point").Scan(&p) + // Override registered handler for point + conn.ConnInfo.RegisterDataType(pgtype.DataType{ + Value: &Point{}, + Name: "point", + Oid: 600, + }) + + p := &Point{} + err = conn.QueryRow("select null::point").Scan(p) if err != nil { fmt.Println(err) return } fmt.Println(p) - err = conn.QueryRow("select point(1.5,2.5)").Scan(&p) + err = conn.QueryRow("select point(1.5,2.5)").Scan(p) if err != nil { fmt.Println(err) return diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 208b1f00..911ab70e 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -245,6 +245,7 @@ func init() { "numeric": &Numeric{}, "numrange": &Numrange{}, "oid": &OidValue{}, + "point": &Point{}, "record": &Record{}, "text": &Text{}, "tid": &Tid{}, diff --git a/pgtype/point.go b/pgtype/point.go new file mode 100644 index 00000000..1b40bc44 --- /dev/null +++ b/pgtype/point.go @@ -0,0 +1,139 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Point struct { + X float64 + Y float64 + Status Status +} + +func (dst *Point) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Point", src) +} + +func (dst *Point) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Point) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + y, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + *dst = Point{X: x, Y: y, Status: Present} + return nil +} + +func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + + *dst = Point{ + X: math.Float64frombits(x), + Y: math.Float64frombits(y), + Status: Present, + } + return nil +} + +func (src *Point) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, src.X, src.Y)) + return false, err +} + +func (src *Point) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := pgio.WriteUint64(w, math.Float64bits(src.X)) + if err != nil { + return false, err + } + + _, err = pgio.WriteUint64(w, math.Float64bits(src.Y)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Point) Scan(src interface{}) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Point) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/point_test.go b/pgtype/point_test.go new file mode 100644 index 00000000..4ddb8009 --- /dev/null +++ b/pgtype/point_test.go @@ -0,0 +1,15 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestPointTranscode(t *testing.T) { + testSuccessfulTranscode(t, "point", []interface{}{ + &pgtype.Point{X: 1.234, Y: 5.6789, Status: pgtype.Present}, + &pgtype.Point{X: -1.234, Y: -5.6789, Status: pgtype.Present}, + &pgtype.Point{Status: pgtype.Null}, + }) +} diff --git a/query_test.go b/query_test.go index 25347ec5..d0fcb706 100644 --- a/query_test.go +++ b/query_test.go @@ -710,6 +710,19 @@ func TestQueryRowUnknownType(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) + // Clear existing type mappings + conn.ConnInfo = pgtype.NewConnInfo() + conn.ConnInfo.RegisterDataType(pgtype.DataType{ + Value: &pgtype.GenericText{}, + Name: "point", + Oid: 600, + }) + conn.ConnInfo.RegisterDataType(pgtype.DataType{ + Value: &pgtype.Int4{}, + Name: "int4", + Oid: pgtype.Int4Oid, + }) + sql := "select $1::point" expected := "(1,0)" var actual string @@ -751,7 +764,7 @@ func TestQueryRowErrors(t *testing.T) { {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot decode"}, - {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Text"}, + {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Point"}, } for i, tt := range tests { diff --git a/v3.md b/v3.md index f1ec1990..70a378ad 100644 --- a/v3.md +++ b/v3.md @@ -68,7 +68,6 @@ something like: select array[1,2,3], array[4,5,6,7] pgtype TODO: -point line lseg box