mirror of https://github.com/jackc/pgx.git
Add pgtype.Point
parent
c09c356b19
commit
5a2feadf11
|
@ -2,7 +2,6 @@ package pgx_test
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
|
@ -18,6 +17,25 @@ type Point struct {
|
|||
Status pgtype.Status
|
||||
}
|
||||
|
||||
func (dst *Point) Set(src interface{}) error {
|
||||
return fmt.Errorf("cannot convert %v to Point", src)
|
||||
}
|
||||
|
||||
func (dst *Point) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case pgtype.Present:
|
||||
return dst
|
||||
case pgtype.Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *Point) AssignTo(dst interface{}) error {
|
||||
return fmt.Errorf("cannot assign %v to %T", src, dst)
|
||||
}
|
||||
|
||||
func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Point{Status: pgtype.Null}
|
||||
|
@ -44,23 +62,12 @@ func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (src Point) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) {
|
||||
switch src.Status {
|
||||
case pgtype.Null:
|
||||
return true, nil
|
||||
case pgtype.Undefined:
|
||||
return false, fmt.Errorf("undefined")
|
||||
func (src *Point) String() string {
|
||||
if src.Status == pgtype.Null {
|
||||
return "null point"
|
||||
}
|
||||
|
||||
_, err := io.WriteString(w, fmt.Sprintf("point(%v,%v)", src.X, src.Y))
|
||||
return false, err
|
||||
}
|
||||
|
||||
func (p Point) String() string {
|
||||
if p.Status == pgtype.Present {
|
||||
return fmt.Sprintf("%v, %v", p.X, p.Y)
|
||||
}
|
||||
return "null point"
|
||||
return fmt.Sprintf("%.1f, %.1f", src.X, src.Y)
|
||||
}
|
||||
|
||||
func Example_CustomType() {
|
||||
|
@ -70,15 +77,22 @@ func Example_CustomType() {
|
|||
return
|
||||
}
|
||||
|
||||
var p Point
|
||||
err = conn.QueryRow("select null::point").Scan(&p)
|
||||
// Override registered handler for point
|
||||
conn.ConnInfo.RegisterDataType(pgtype.DataType{
|
||||
Value: &Point{},
|
||||
Name: "point",
|
||||
Oid: 600,
|
||||
})
|
||||
|
||||
p := &Point{}
|
||||
err = conn.QueryRow("select null::point").Scan(p)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(p)
|
||||
|
||||
err = conn.QueryRow("select point(1.5,2.5)").Scan(&p)
|
||||
err = conn.QueryRow("select point(1.5,2.5)").Scan(p)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
|
|
|
@ -245,6 +245,7 @@ func init() {
|
|||
"numeric": &Numeric{},
|
||||
"numrange": &Numrange{},
|
||||
"oid": &OidValue{},
|
||||
"point": &Point{},
|
||||
"record": &Record{},
|
||||
"text": &Text{},
|
||||
"tid": &Tid{},
|
||||
|
|
|
@ -0,0 +1,139 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Point struct {
|
||||
X float64
|
||||
Y float64
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (dst *Point) Set(src interface{}) error {
|
||||
return fmt.Errorf("cannot convert %v to Point", src)
|
||||
}
|
||||
|
||||
func (dst *Point) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case Present:
|
||||
return dst
|
||||
case Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *Point) AssignTo(dst interface{}) error {
|
||||
return fmt.Errorf("cannot assign %v to %T", src, dst)
|
||||
}
|
||||
|
||||
func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Point{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(src) < 5 {
|
||||
return fmt.Errorf("invalid length for point: %v", len(src))
|
||||
}
|
||||
|
||||
parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2)
|
||||
if len(parts) < 2 {
|
||||
return fmt.Errorf("invalid format for point")
|
||||
}
|
||||
|
||||
x, err := strconv.ParseFloat(parts[0], 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
y, err := strconv.ParseFloat(parts[1], 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*dst = Point{X: x, Y: y, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Point{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(src) != 16 {
|
||||
return fmt.Errorf("invalid length for point: %v", len(src))
|
||||
}
|
||||
|
||||
x := binary.BigEndian.Uint64(src)
|
||||
y := binary.BigEndian.Uint64(src[8:])
|
||||
|
||||
*dst = Point{
|
||||
X: math.Float64frombits(x),
|
||||
Y: math.Float64frombits(y),
|
||||
Status: Present,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Point) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return true, nil
|
||||
case Undefined:
|
||||
return false, errUndefined
|
||||
}
|
||||
|
||||
_, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, src.X, src.Y))
|
||||
return false, err
|
||||
}
|
||||
|
||||
func (src *Point) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return true, nil
|
||||
case Undefined:
|
||||
return false, errUndefined
|
||||
}
|
||||
|
||||
_, err := pgio.WriteUint64(w, math.Float64bits(src.X))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, err = pgio.WriteUint64(w, math.Float64bits(src.Y))
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Scan implements the database/sql Scanner interface.
|
||||
func (dst *Point) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
*dst = Point{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case string:
|
||||
return dst.DecodeText(nil, []byte(src))
|
||||
case []byte:
|
||||
return dst.DecodeText(nil, src)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot scan %T", src)
|
||||
}
|
||||
|
||||
// Value implements the database/sql/driver Valuer interface.
|
||||
func (src *Point) Value() (driver.Value, error) {
|
||||
return encodeValueText(src)
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
func TestPointTranscode(t *testing.T) {
|
||||
testSuccessfulTranscode(t, "point", []interface{}{
|
||||
&pgtype.Point{X: 1.234, Y: 5.6789, Status: pgtype.Present},
|
||||
&pgtype.Point{X: -1.234, Y: -5.6789, Status: pgtype.Present},
|
||||
&pgtype.Point{Status: pgtype.Null},
|
||||
})
|
||||
}
|
|
@ -710,6 +710,19 @@ func TestQueryRowUnknownType(t *testing.T) {
|
|||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
// Clear existing type mappings
|
||||
conn.ConnInfo = pgtype.NewConnInfo()
|
||||
conn.ConnInfo.RegisterDataType(pgtype.DataType{
|
||||
Value: &pgtype.GenericText{},
|
||||
Name: "point",
|
||||
Oid: 600,
|
||||
})
|
||||
conn.ConnInfo.RegisterDataType(pgtype.DataType{
|
||||
Value: &pgtype.Int4{},
|
||||
Name: "int4",
|
||||
Oid: pgtype.Int4Oid,
|
||||
})
|
||||
|
||||
sql := "select $1::point"
|
||||
expected := "(1,0)"
|
||||
var actual string
|
||||
|
@ -751,7 +764,7 @@ func TestQueryRowErrors(t *testing.T) {
|
|||
{"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`},
|
||||
{"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"},
|
||||
{"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot decode"},
|
||||
{"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Text"},
|
||||
{"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Point"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
|
|
Loading…
Reference in New Issue