From 9f9a9779ac794e885f6fbb04935724f03149f092 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 31 Dec 2015 14:46:43 -0600 Subject: [PATCH] Add compatibility with database/sql custom types Support database/sql.Scanner Support database/sql/driver.Valuer --- CHANGELOG.md | 1 + README.md | 1 + conn.go | 8 +++- doc.go | 3 ++ query.go | 35 ++++++++++++++++ query_test.go | 113 ++++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 160 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 88c0a74d..f946d67f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Tip +* Add support for database/sql.Scanner and database/sql/driver.Valuer interfaces * Go float64 can no longer be encoded to a PostgreSQL float4 * Add ConnPool.Reset method * []byte skips encoding/decoding diff --git a/README.md b/README.md index 31434a02..51b0ad26 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ Pgx supports many additional features beyond what is available through database/ * Maps inet and cidr PostgreSQL types to net.IPNet * Large object support * Null mapping to Null* struct or pointer to pointer. +* Supports database/sql.Scanner and database/sql/driver/Valuer interfaces for custom types ## Performance diff --git a/conn.go b/conn.go index 66eb9fc8..1c10d449 100644 --- a/conn.go +++ b/conn.go @@ -4,6 +4,7 @@ import ( "bufio" "crypto/md5" "crypto/tls" + "database/sql/driver" "encoding/binary" "encoding/hex" "errors" @@ -851,15 +852,20 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(int16(len(arguments))) for i, oid := range ps.ParameterOids { + encode: if arguments[i] == nil { wbuf.WriteInt32(-1) continue } - encode: switch arg := arguments[i].(type) { case Encoder: err = arg.Encode(wbuf, oid) + case driver.Valuer: + arguments[i], err = arg.Value() + if err == nil { + goto encode + } case string: err = encodeText(wbuf, arguments[i]) case []byte: diff --git a/doc.go b/doc.go index 2d54ab82..0fd3d2f6 100644 --- a/doc.go +++ b/doc.go @@ -181,6 +181,9 @@ Conn.PgTypes. See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type. +pgx also includes support for custom types implementing the database/sql.Scanner +and database/sql/driver.Valuer interfaces. + Raw Bytes Mapping []byte passed as arguments to Query, QueryRow, and Exec are passed unmodified diff --git a/query.go b/query.go index 8c0e9d07..8398562b 100644 --- a/query.go +++ b/query.go @@ -1,6 +1,7 @@ package pgx import ( + "database/sql" "errors" "fmt" "net" @@ -255,6 +256,40 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } + } else if s, ok := d.(sql.Scanner); ok { + var val interface{} + if 0 <= vr.Len() { + switch vr.Type().DataType { + case BoolOid: + val = decodeBool(vr) + case Int8Oid: + val = int64(decodeInt8(vr)) + case Int2Oid: + val = int64(decodeInt2(vr)) + case Int4Oid: + val = int64(decodeInt4(vr)) + case TextOid, VarcharOid: + val = decodeText(vr) + case OidOid: + val = int64(decodeOid(vr)) + case Float4Oid: + val = float64(decodeFloat4(vr)) + case Float8Oid: + val = decodeFloat8(vr) + case DateOid: + val = decodeDate(vr) + case TimestampOid: + val = decodeTimestamp(vr) + case TimestampTzOid: + val = decodeTimestampTz(vr) + default: + val = vr.ReadBytes(vr.Len()) + } + } + err = s.Scan(val) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } } else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid { decodeJson(vr, &d) } else { diff --git a/query_test.go b/query_test.go index 2f9e7ee9..06df6c95 100644 --- a/query_test.go +++ b/query_test.go @@ -2,10 +2,13 @@ package pgx_test import ( "bytes" + "database/sql" "github.com/jackc/pgx" "strings" "testing" "time" + + "github.com/shopspring/decimal" ) func TestConnQueryScan(t *testing.T) { @@ -904,3 +907,113 @@ func TestReadingNullByteArrays(t *testing.T) { t.Errorf("Expected to read 2 rows, read: ", count) } } + +// Use github.com/shopspring/decimal as real-world database/sql custom type +// to test against. +func TestConnQueryDatabaseSQLScanner(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var num decimal.Decimal + + err := conn.QueryRow("select '1234.567'::decimal").Scan(&num) + if err != nil { + t.Fatalf("Scan failed: %v", err) + } + + expected, err := decimal.NewFromString("1234.567") + if err != nil { + t.Fatal(err) + } + + if !num.Equals(expected) { + t.Errorf("Expected num to be %v, but it was %v", expected, num) + } + + ensureConnValid(t, conn) +} + +// Use github.com/shopspring/decimal as real-world database/sql custom type +// to test against. +func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + expected, err := decimal.NewFromString("1234.567") + if err != nil { + t.Fatal(err) + } + var num decimal.Decimal + + err = conn.QueryRow("select $1::decimal", expected).Scan(&num) + if err != nil { + t.Fatalf("Scan failed: %v", err) + } + + if !num.Equals(expected) { + t.Errorf("Expected num to be %v, but it was %v", expected, num) + } + + ensureConnValid(t, conn) +} + +func TestConnQueryDatabaseSQLNullX(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type row struct { + boolValid sql.NullBool + boolNull sql.NullBool + int64Valid sql.NullInt64 + int64Null sql.NullInt64 + float64Valid sql.NullFloat64 + float64Null sql.NullFloat64 + stringValid sql.NullString + stringNull sql.NullString + } + + expected := row{ + boolValid: sql.NullBool{Bool: true, Valid: true}, + int64Valid: sql.NullInt64{Int64: 123, Valid: true}, + float64Valid: sql.NullFloat64{Float64: 3.14, Valid: true}, + stringValid: sql.NullString{String: "pgx", Valid: true}, + } + + var actual row + + err := conn.QueryRow( + "select $1::bool, $2::bool, $3::int8, $4::int8, $5::float8, $6::float8, $7::text, $8::text", + expected.boolValid, + expected.boolNull, + expected.int64Valid, + expected.int64Null, + expected.float64Valid, + expected.float64Null, + expected.stringValid, + expected.stringNull, + ).Scan( + &actual.boolValid, + &actual.boolNull, + &actual.int64Valid, + &actual.int64Null, + &actual.float64Valid, + &actual.float64Null, + &actual.stringValid, + &actual.stringNull, + ) + if err != nil { + t.Fatalf("Scan failed: %v", err) + } + + if expected != actual { + t.Errorf("Expected %v, but got %v", expected, actual) + } + + ensureConnValid(t, conn) +}