From 5be6819a8cae0a9ff30f77b35b65b2738683655d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 20:39:48 -0500 Subject: [PATCH] Add pgtype.Circle Also rename Point.Vec2 to Point.P to conform to rest of geometric types. --- pgtype/circle.go | 150 ++++++++++++++++++++++++++++++++++++++++++ pgtype/circle_test.go | 15 +++++ pgtype/pgtype.go | 1 + pgtype/point.go | 12 ++-- pgtype/point_test.go | 4 +- v3.md | 1 - 6 files changed, 174 insertions(+), 9 deletions(-) create mode 100644 pgtype/circle.go create mode 100644 pgtype/circle_test.go diff --git a/pgtype/circle.go b/pgtype/circle.go new file mode 100644 index 00000000..62e2e8b3 --- /dev/null +++ b/pgtype/circle.go @@ -0,0 +1,150 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Circle struct { + P Vec2 + R float64 + Status Status +} + +func (dst *Circle) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Circle", src) +} + +func (dst *Circle) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Circle) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + if len(src) < 9 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } + + str := string(src[2:]) + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+2 : len(str)-1] + + r, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Circle{P: Vec2{x, y}, R: r, Status: Present} + return nil +} + +func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + r := binary.BigEndian.Uint64(src[16:]) + + *dst = Circle{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + R: math.Float64frombits(r), + Status: Present, + } + return nil +} + +func (src *Circle) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`<(%f,%f),%f>`, src.P.X, src.P.Y, src.R)) + return false, err +} + +func (src *Circle) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P.X)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P.Y)); err != nil { + return false, err + } + + _, err := pgio.WriteUint64(w, math.Float64bits(src.R)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Circle) Scan(src interface{}) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Circle) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go new file mode 100644 index 00000000..9746dd74 --- /dev/null +++ b/pgtype/circle_test.go @@ -0,0 +1,15 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestCircleTranscode(t *testing.T) { + testSuccessfulTranscode(t, "circle", []interface{}{ + &pgtype.Circle{P: pgtype.Vec2{1.234, 5.6789}, R: 3.5, Status: pgtype.Present}, + &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Status: pgtype.Present}, + &pgtype.Circle{Status: pgtype.Null}, + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index cb0cec2c..52cad561 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -228,6 +228,7 @@ func init() { "char": &QChar{}, "cid": &Cid{}, "cidr": &Cidr{}, + "circle": &Circle{}, "date": &Date{}, "daterange": &Daterange{}, "decimal": &Decimal{}, diff --git a/pgtype/point.go b/pgtype/point.go index 94f753e3..788a76c9 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -18,7 +18,7 @@ type Vec2 struct { } type Point struct { - Vec2 + P Vec2 Status Status } @@ -66,7 +66,7 @@ func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Point{Vec2: Vec2{x, y}, Status: Present} + *dst = Point{P: Vec2{x, y}, Status: Present} return nil } @@ -84,7 +84,7 @@ func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { y := binary.BigEndian.Uint64(src[8:]) *dst = Point{ - Vec2: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, Status: Present, } return nil @@ -98,7 +98,7 @@ func (src *Point) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, src.X, src.Y)) + _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, src.P.X, src.P.Y)) return false, err } @@ -110,12 +110,12 @@ func (src *Point) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, errUndefined } - _, err := pgio.WriteUint64(w, math.Float64bits(src.X)) + _, err := pgio.WriteUint64(w, math.Float64bits(src.P.X)) if err != nil { return false, err } - _, err = pgio.WriteUint64(w, math.Float64bits(src.Y)) + _, err = pgio.WriteUint64(w, math.Float64bits(src.P.Y)) return false, err } diff --git a/pgtype/point_test.go b/pgtype/point_test.go index 723dfa60..c921f794 100644 --- a/pgtype/point_test.go +++ b/pgtype/point_test.go @@ -8,8 +8,8 @@ import ( func TestPointTranscode(t *testing.T) { testSuccessfulTranscode(t, "point", []interface{}{ - &pgtype.Point{Vec2: pgtype.Vec2{1.234, 5.6789}, Status: pgtype.Present}, - &pgtype.Point{Vec2: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, + &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789}, Status: pgtype.Present}, + &pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, &pgtype.Point{Status: pgtype.Null}, }) } diff --git a/v3.md b/v3.md index a879e384..9a69a2f2 100644 --- a/v3.md +++ b/v3.md @@ -68,6 +68,5 @@ something like: select array[1,2,3], array[4,5,6,7] pgtype TODO: -circle macaddr varbit