diff --git a/pgtype/line.go b/pgtype/line.go index c3192b2a..db584862 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -11,34 +11,177 @@ import ( "github.com/jackc/pgio" ) +type LineScanner interface { + ScanLine(v Line) error +} + +type LineValuer interface { + LineValue() (Line, error) +} + type Line struct { A, B, C float64 Valid bool } -func (dst *Line) Set(src interface{}) error { +func (line *Line) ScanLine(v Line) error { + *line = v + return nil +} + +func (line Line) LineValue() (Line, error) { + return line, nil +} + +func (line *Line) Set(src interface{}) error { return fmt.Errorf("cannot convert %v to Line", src) } -func (dst Line) Get() interface{} { - if !dst.Valid { +// Scan implements the database/sql Scanner interface. +func (line *Line) Scan(src interface{}) error { + if src == nil { + *line = Line{} return nil } - return dst + + switch src := src.(type) { + case string: + return scanPlanTextAnyToLineScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), line) + } + + return fmt.Errorf("cannot scan %T", src) } -func (src *Line) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) +// Value implements the database/sql/driver Valuer interface. +func (line Line) Value() (driver.Value, error) { + if !line.Valid { + return nil, nil + } + + buf, err := LineCodec{}.PlanEncode(nil, 0, TextFormatCode, line).Encode(line, nil) + if err != nil { + return nil, err + } + return string(buf), err } -func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Line{} +type LineCodec struct{} + +func (LineCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (LineCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (LineCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(LineValuer); !ok { return nil } + switch format { + case BinaryFormatCode: + return encodePlanLineCodecBinary{} + case TextFormatCode: + return encodePlanLineCodecText{} + } + + return nil +} + +type encodePlanLineCodecBinary struct{} + +func (encodePlanLineCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + line, err := value.(LineValuer).LineValue() + if err != nil { + return nil, err + } + + if !line.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(line.A)) + buf = pgio.AppendUint64(buf, math.Float64bits(line.B)) + buf = pgio.AppendUint64(buf, math.Float64bits(line.C)) + return buf, nil +} + +type encodePlanLineCodecText struct{} + +func (encodePlanLineCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + line, err := value.(LineValuer).LineValue() + if err != nil { + return nil, err + } + + if !line.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`{%s,%s,%s}`, + strconv.FormatFloat(line.A, 'f', -1, 64), + strconv.FormatFloat(line.B, 'f', -1, 64), + strconv.FormatFloat(line.C, 'f', -1, 64), + )...) + return buf, nil +} + +func (LineCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case LineScanner: + return scanPlanBinaryLineToLineScanner{} + } + case TextFormatCode: + switch target.(type) { + case LineScanner: + return scanPlanTextAnyToLineScanner{} + } + } + + return nil +} + +type scanPlanBinaryLineToLineScanner struct{} + +func (scanPlanBinaryLineToLineScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(LineScanner) + + if src == nil { + return scanner.ScanLine(Line{}) + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for line: %v", len(src)) + } + + a := binary.BigEndian.Uint64(src) + b := binary.BigEndian.Uint64(src[8:]) + c := binary.BigEndian.Uint64(src[16:]) + + return scanner.ScanLine(Line{ + A: math.Float64frombits(a), + B: math.Float64frombits(b), + C: math.Float64frombits(c), + Valid: true, + }) +} + +type scanPlanTextAnyToLineScanner struct{} + +func (scanPlanTextAnyToLineScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(LineScanner) + + if src == nil { + return scanner.ScanLine(Line{}) + } + if len(src) < 7 { - return fmt.Errorf("invalid length for Line: %v", len(src)) + return fmt.Errorf("invalid length for line: %v", len(src)) } parts := strings.SplitN(string(src[1:len(src)-1]), ",", 3) @@ -61,78 +204,22 @@ func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Line{A: a, B: b, C: c, Valid: true} - return nil + return scanner.ScanLine(Line{A: a, B: b, C: c, Valid: true}) } -func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { +func (c LineCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c LineCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Line{} - return nil - } - - if len(src) != 24 { - return fmt.Errorf("invalid length for Line: %v", len(src)) - } - - a := binary.BigEndian.Uint64(src) - b := binary.BigEndian.Uint64(src[8:]) - c := binary.BigEndian.Uint64(src[16:]) - - *dst = Line{ - A: math.Float64frombits(a), - B: math.Float64frombits(b), - C: math.Float64frombits(c), - Valid: true, - } - return nil -} - -func (src Line) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { return nil, nil } - buf = append(buf, fmt.Sprintf(`{%s,%s,%s}`, - strconv.FormatFloat(src.A, 'f', -1, 64), - strconv.FormatFloat(src.B, 'f', -1, 64), - strconv.FormatFloat(src.C, 'f', -1, 64), - )...) - - return buf, nil -} - -func (src Line) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil + var line Line + err := codecScan(c, ci, oid, format, src, &line) + if err != nil { + return nil, err } - - buf = pgio.AppendUint64(buf, math.Float64bits(src.A)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.B)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.C)) - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Line) Scan(src interface{}) error { - if src == nil { - *dst = Line{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Line) Value() (driver.Value, error) { - return EncodeValueText(src) + return line, nil } diff --git a/pgtype/line_test.go b/pgtype/line_test.go index b171a7a5..669d9b8d 100644 --- a/pgtype/line_test.go +++ b/pgtype/line_test.go @@ -10,6 +10,7 @@ import ( func TestLineTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) + defer conn.Close(context.Background()) if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { t.Skip("Skipping due to no line type") } @@ -24,15 +25,30 @@ func TestLineTranscode(t *testing.T) { t.Skip("Skipping due to unimplemented line type in PG 9.3") } - testutil.TestSuccessfulTranscode(t, "line", []interface{}{ - &pgtype.Line{ - A: 1.23, B: 4.56, C: 7.89012345, - Valid: true, + testPgxCodec(t, "line", []PgxTranscodeTestCase{ + { + pgtype.Line{ + A: 1.23, B: 4.56, C: 7.89012345, + Valid: true, + }, + new(pgtype.Line), + isExpectedEq(pgtype.Line{ + A: 1.23, B: 4.56, C: 7.89012345, + Valid: true, + }), }, - &pgtype.Line{ - A: -1.23, B: -4.56, C: -7.89, - Valid: true, + { + pgtype.Line{ + A: -1.23, B: -4.56, C: -7.89, + Valid: true, + }, + new(pgtype.Line), + isExpectedEq(pgtype.Line{ + A: -1.23, B: -4.56, C: -7.89, + Valid: true, + }), }, - &pgtype.Line{}, + {pgtype.Line{}, new(pgtype.Line), isExpectedEq(pgtype.Line{})}, + {nil, new(pgtype.Line), isExpectedEq(pgtype.Line{})}, }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 300037df..605d9132 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -311,7 +311,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) ci.RegisterDataType(DataType{Value: &JSONBArray{}, Name: "_jsonb", OID: JSONBArrayOID}) - ci.RegisterDataType(DataType{Value: &Line{}, Name: "line", OID: LineOID}) + ci.RegisterDataType(DataType{Name: "line", OID: LineOID, Codec: LineCodec{}}) ci.RegisterDataType(DataType{Value: &Lseg{}, Name: "lseg", OID: LsegOID}) ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) ci.RegisterDataType(DataType{Name: "name", OID: NameOID, Codec: TextCodec{}})