diff --git a/pgtype/box.go b/pgtype/box.go index 868b40a2..438a4f21 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -11,32 +11,189 @@ import ( "github.com/jackc/pgio" ) +type BoxScanner interface { + ScanBox(v Box) error +} + +type BoxValuer interface { + BoxValue() (Box, error) +} + type Box struct { P [2]Vec2 Valid bool } -func (dst *Box) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Box", src) +func (b *Box) ScanBox(v Box) error { + *b = v + return nil } -func (dst Box) Get() interface{} { - if !dst.Valid { - return nil - } - return dst +func (b Box) BoxValue() (Box, error) { + return b, nil } -func (src *Box) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { +// Scan implements the database/sql Scanner interface. +func (dst *Box) Scan(src interface{}) error { if src == nil { *dst = Box{} return nil } + switch src := src.(type) { + case string: + return scanPlanTextAnyToBoxScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Box) Value() (driver.Value, error) { + buf, err := BoxCodec{}.Encode(nil, 0, TextFormatCode, src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type BoxCodec struct{} + +func (BoxCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (BoxCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (BoxCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { + if value == nil { + return nil, nil + } + + var box Box + if v, ok := value.(BoxValuer); ok { + b, err := v.BoxValue() + if err != nil { + return nil, err + } + box = b + } else { + return nil, fmt.Errorf("cannot convert %v to box: %v", value, err) + } + + if !box.Valid { + return nil, nil + } + + switch format { + case BinaryFormatCode: + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].Y)) + return buf, nil + case TextFormatCode: + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(box.P[0].X, 'f', -1, 64), + strconv.FormatFloat(box.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(box.P[1].X, 'f', -1, 64), + strconv.FormatFloat(box.P[1].Y, 'f', -1, 64), + )...) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } +} + +func (BoxCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case BoxScanner: + return scanPlanBinaryBoxToBoxScanner{} + } + case TextFormatCode: + switch target.(type) { + case BoxScanner: + return scanPlanTextAnyToBoxScanner{} + } + } + + return nil +} + +func (c BoxCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if format == TextFormatCode { + return string(src), nil + } else { + box, err := c.DecodeValue(ci, oid, format, src) + if err != nil { + return nil, err + } + buf, err := c.Encode(ci, oid, TextFormatCode, box, nil) + if err != nil { + return nil, err + } + return string(buf), nil + } +} + +func (c BoxCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var box Box + scanPlan := c.PlanScan(ci, oid, format, &box, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &box) + if err != nil { + return nil, err + } + return box, nil +} + +type scanPlanBinaryBoxToBoxScanner struct{} + +func (scanPlanBinaryBoxToBoxScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(BoxScanner) + + if src == nil { + return scanner.ScanBox(Box{}) + } + + if len(src) != 32 { + return fmt.Errorf("invalid length for Box: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + return scanner.ScanBox(Box{ + P: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Valid: true, + }) +} + +type scanPlanTextAnyToBoxScanner struct{} + +func (scanPlanTextAnyToBoxScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(BoxScanner) + + if src == nil { + return scanner.ScanBox(Box{}) + } + if len(src) < 11 { return fmt.Errorf("invalid length for Box: %v", len(src)) } @@ -74,82 +231,5 @@ func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true} - return nil -} - -func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Box{} - return nil - } - - if len(src) != 32 { - return fmt.Errorf("invalid length for Box: %v", len(src)) - } - - x1 := binary.BigEndian.Uint64(src) - y1 := binary.BigEndian.Uint64(src[8:]) - x2 := binary.BigEndian.Uint64(src[16:]) - y2 := binary.BigEndian.Uint64(src[24:]) - - *dst = Box{ - P: [2]Vec2{ - {math.Float64frombits(x1), math.Float64frombits(y1)}, - {math.Float64frombits(x2), math.Float64frombits(y2)}, - }, - Valid: true, - } - return nil -} - -func (src Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, - strconv.FormatFloat(src.P[0].X, 'f', -1, 64), - strconv.FormatFloat(src.P[0].Y, 'f', -1, 64), - strconv.FormatFloat(src.P[1].X, 'f', -1, 64), - strconv.FormatFloat(src.P[1].Y, 'f', -1, 64), - )...) - return buf, nil -} - -func (src Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Box) Scan(src interface{}) error { - if src == nil { - *dst = Box{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Box) Value() (driver.Value, error) { - return EncodeValueText(src) + return scanner.ScanBox(Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true}) } diff --git a/pgtype/box_test.go b/pgtype/box_test.go index 481723b5..f4e26370 100644 --- a/pgtype/box_test.go +++ b/pgtype/box_test.go @@ -7,17 +7,31 @@ import ( "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestBoxTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "box", []interface{}{ - &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, - Valid: true, +func TestBoxCodec(t *testing.T) { + testPgxCodec(t, "box", []PgxTranscodeTestCase{ + { + pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, + Valid: true, + }, + new(pgtype.Box), + isExpectedEq(pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, + Valid: true, + }), }, - &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Valid: true, + { + pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {-13.14, -5.234}}, + Valid: true, + }, + new(pgtype.Box), + isExpectedEq(pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {-13.14, -5.234}}, + Valid: true, + }), }, - &pgtype.Box{}, + {nil, new(pgtype.Box), isExpectedEq(pgtype.Box{})}, }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 372f755b..b72255ce 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -312,7 +312,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &ACLItem{}, Name: "aclitem", OID: ACLItemOID}) ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID}) ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) - ci.RegisterDataType(DataType{Value: &Box{}, Name: "box", OID: BoxOID}) + ci.RegisterDataType(DataType{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) ci.RegisterDataType(DataType{Value: &BPChar{}, Name: "bpchar", OID: BPCharOID}) ci.RegisterDataType(DataType{Value: &Bytea{}, Name: "bytea", OID: ByteaOID}) ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) diff --git a/pgtype/zzz.box.go b/pgtype/zzz.box.go deleted file mode 100644 index 5ca2df43..00000000 --- a/pgtype/zzz.box.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Box) BinaryFormatSupported() bool { - return true -} - -func (Box) TextFormatSupported() bool { - return true -} - -func (Box) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Box) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Box) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -}