mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
Add pgtype.Point
This commit is contained in:
parent
c09c356b19
commit
5a2feadf11
@ -2,7 +2,6 @@ package pgx_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@ -18,6 +17,25 @@ type Point struct {
|
|||||||
Status pgtype.Status
|
Status pgtype.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dst *Point) Set(src interface{}) error {
|
||||||
|
return fmt.Errorf("cannot convert %v to Point", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *Point) Get() interface{} {
|
||||||
|
switch dst.Status {
|
||||||
|
case pgtype.Present:
|
||||||
|
return dst
|
||||||
|
case pgtype.Null:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return dst.Status
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *Point) AssignTo(dst interface{}) error {
|
||||||
|
return fmt.Errorf("cannot assign %v to %T", src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
|
func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = Point{Status: pgtype.Null}
|
*dst = Point{Status: pgtype.Null}
|
||||||
@ -44,23 +62,12 @@ func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (src Point) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) {
|
func (src *Point) String() string {
|
||||||
switch src.Status {
|
if src.Status == pgtype.Null {
|
||||||
case pgtype.Null:
|
return "null point"
|
||||||
return true, nil
|
|
||||||
case pgtype.Undefined:
|
|
||||||
return false, fmt.Errorf("undefined")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := io.WriteString(w, fmt.Sprintf("point(%v,%v)", src.X, src.Y))
|
return fmt.Sprintf("%.1f, %.1f", src.X, src.Y)
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p Point) String() string {
|
|
||||||
if p.Status == pgtype.Present {
|
|
||||||
return fmt.Sprintf("%v, %v", p.X, p.Y)
|
|
||||||
}
|
|
||||||
return "null point"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Example_CustomType() {
|
func Example_CustomType() {
|
||||||
@ -70,15 +77,22 @@ func Example_CustomType() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var p Point
|
// Override registered handler for point
|
||||||
err = conn.QueryRow("select null::point").Scan(&p)
|
conn.ConnInfo.RegisterDataType(pgtype.DataType{
|
||||||
|
Value: &Point{},
|
||||||
|
Name: "point",
|
||||||
|
Oid: 600,
|
||||||
|
})
|
||||||
|
|
||||||
|
p := &Point{}
|
||||||
|
err = conn.QueryRow("select null::point").Scan(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fmt.Println(p)
|
fmt.Println(p)
|
||||||
|
|
||||||
err = conn.QueryRow("select point(1.5,2.5)").Scan(&p)
|
err = conn.QueryRow("select point(1.5,2.5)").Scan(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
return
|
return
|
||||||
|
@ -245,6 +245,7 @@ func init() {
|
|||||||
"numeric": &Numeric{},
|
"numeric": &Numeric{},
|
||||||
"numrange": &Numrange{},
|
"numrange": &Numrange{},
|
||||||
"oid": &OidValue{},
|
"oid": &OidValue{},
|
||||||
|
"point": &Point{},
|
||||||
"record": &Record{},
|
"record": &Record{},
|
||||||
"text": &Text{},
|
"text": &Text{},
|
||||||
"tid": &Tid{},
|
"tid": &Tid{},
|
||||||
|
139
pgtype/point.go
Normal file
139
pgtype/point.go
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Point struct {
|
||||||
|
X float64
|
||||||
|
Y float64
|
||||||
|
Status Status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *Point) Set(src interface{}) error {
|
||||||
|
return fmt.Errorf("cannot convert %v to Point", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *Point) Get() interface{} {
|
||||||
|
switch dst.Status {
|
||||||
|
case Present:
|
||||||
|
return dst
|
||||||
|
case Null:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return dst.Status
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *Point) AssignTo(dst interface{}) error {
|
||||||
|
return fmt.Errorf("cannot assign %v to %T", src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = Point{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src) < 5 {
|
||||||
|
return fmt.Errorf("invalid length for point: %v", len(src))
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2)
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return fmt.Errorf("invalid format for point")
|
||||||
|
}
|
||||||
|
|
||||||
|
x, err := strconv.ParseFloat(parts[0], 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
y, err := strconv.ParseFloat(parts[1], 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = Point{X: x, Y: y, Status: Present}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = Point{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src) != 16 {
|
||||||
|
return fmt.Errorf("invalid length for point: %v", len(src))
|
||||||
|
}
|
||||||
|
|
||||||
|
x := binary.BigEndian.Uint64(src)
|
||||||
|
y := binary.BigEndian.Uint64(src[8:])
|
||||||
|
|
||||||
|
*dst = Point{
|
||||||
|
X: math.Float64frombits(x),
|
||||||
|
Y: math.Float64frombits(y),
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *Point) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Null:
|
||||||
|
return true, nil
|
||||||
|
case Undefined:
|
||||||
|
return false, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, src.X, src.Y))
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *Point) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Null:
|
||||||
|
return true, nil
|
||||||
|
case Undefined:
|
||||||
|
return false, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := pgio.WriteUint64(w, math.Float64bits(src.X))
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = pgio.WriteUint64(w, math.Float64bits(src.Y))
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the database/sql Scanner interface.
|
||||||
|
func (dst *Point) Scan(src interface{}) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = Point{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch src := src.(type) {
|
||||||
|
case string:
|
||||||
|
return dst.DecodeText(nil, []byte(src))
|
||||||
|
case []byte:
|
||||||
|
return dst.DecodeText(nil, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the database/sql/driver Valuer interface.
|
||||||
|
func (src *Point) Value() (driver.Value, error) {
|
||||||
|
return encodeValueText(src)
|
||||||
|
}
|
15
pgtype/point_test.go
Normal file
15
pgtype/point_test.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package pgtype_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/pgtype"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPointTranscode(t *testing.T) {
|
||||||
|
testSuccessfulTranscode(t, "point", []interface{}{
|
||||||
|
&pgtype.Point{X: 1.234, Y: 5.6789, Status: pgtype.Present},
|
||||||
|
&pgtype.Point{X: -1.234, Y: -5.6789, Status: pgtype.Present},
|
||||||
|
&pgtype.Point{Status: pgtype.Null},
|
||||||
|
})
|
||||||
|
}
|
@ -710,6 +710,19 @@ func TestQueryRowUnknownType(t *testing.T) {
|
|||||||
conn := mustConnect(t, *defaultConnConfig)
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
// Clear existing type mappings
|
||||||
|
conn.ConnInfo = pgtype.NewConnInfo()
|
||||||
|
conn.ConnInfo.RegisterDataType(pgtype.DataType{
|
||||||
|
Value: &pgtype.GenericText{},
|
||||||
|
Name: "point",
|
||||||
|
Oid: 600,
|
||||||
|
})
|
||||||
|
conn.ConnInfo.RegisterDataType(pgtype.DataType{
|
||||||
|
Value: &pgtype.Int4{},
|
||||||
|
Name: "int4",
|
||||||
|
Oid: pgtype.Int4Oid,
|
||||||
|
})
|
||||||
|
|
||||||
sql := "select $1::point"
|
sql := "select $1::point"
|
||||||
expected := "(1,0)"
|
expected := "(1,0)"
|
||||||
var actual string
|
var actual string
|
||||||
@ -751,7 +764,7 @@ func TestQueryRowErrors(t *testing.T) {
|
|||||||
{"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`},
|
{"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`},
|
||||||
{"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"},
|
{"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"},
|
||||||
{"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot decode"},
|
{"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot decode"},
|
||||||
{"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Text"},
|
{"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Point"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user