Add compatibility with database/sql custom types

Support database/sql.Scanner
Support database/sql/driver.Valuer
redshift-ssl-tests
Jack Christensen 2015-12-31 14:46:43 -06:00
parent 029bd49065
commit 9f9a9779ac
6 changed files with 160 additions and 1 deletions

View File

@ -1,5 +1,6 @@
# Tip
* Add support for database/sql.Scanner and database/sql/driver.Valuer interfaces
* Go float64 can no longer be encoded to a PostgreSQL float4
* Add ConnPool.Reset method
* []byte skips encoding/decoding

View File

@ -23,6 +23,7 @@ Pgx supports many additional features beyond what is available through database/
* Maps inet and cidr PostgreSQL types to net.IPNet
* Large object support
* Null mapping to Null* struct or pointer to pointer.
* Supports database/sql.Scanner and database/sql/driver/Valuer interfaces for custom types
## Performance

View File

@ -4,6 +4,7 @@ import (
"bufio"
"crypto/md5"
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"encoding/hex"
"errors"
@ -851,15 +852,20 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
wbuf.WriteInt16(int16(len(arguments)))
for i, oid := range ps.ParameterOids {
encode:
if arguments[i] == nil {
wbuf.WriteInt32(-1)
continue
}
encode:
switch arg := arguments[i].(type) {
case Encoder:
err = arg.Encode(wbuf, oid)
case driver.Valuer:
arguments[i], err = arg.Value()
if err == nil {
goto encode
}
case string:
err = encodeText(wbuf, arguments[i])
case []byte:

3
doc.go
View File

@ -181,6 +181,9 @@ Conn.PgTypes.
See example_custom_type_test.go for an example of a custom type for the
PostgreSQL point type.
pgx also includes support for custom types implementing the database/sql.Scanner
and database/sql/driver.Valuer interfaces.
Raw Bytes Mapping
[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified

View File

@ -1,6 +1,7 @@
package pgx
import (
"database/sql"
"errors"
"fmt"
"net"
@ -255,6 +256,40 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
if err != nil {
rows.Fatal(scanArgError{col: i, err: err})
}
} else if s, ok := d.(sql.Scanner); ok {
var val interface{}
if 0 <= vr.Len() {
switch vr.Type().DataType {
case BoolOid:
val = decodeBool(vr)
case Int8Oid:
val = int64(decodeInt8(vr))
case Int2Oid:
val = int64(decodeInt2(vr))
case Int4Oid:
val = int64(decodeInt4(vr))
case TextOid, VarcharOid:
val = decodeText(vr)
case OidOid:
val = int64(decodeOid(vr))
case Float4Oid:
val = float64(decodeFloat4(vr))
case Float8Oid:
val = decodeFloat8(vr)
case DateOid:
val = decodeDate(vr)
case TimestampOid:
val = decodeTimestamp(vr)
case TimestampTzOid:
val = decodeTimestampTz(vr)
default:
val = vr.ReadBytes(vr.Len())
}
}
err = s.Scan(val)
if err != nil {
rows.Fatal(scanArgError{col: i, err: err})
}
} else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid {
decodeJson(vr, &d)
} else {

View File

@ -2,10 +2,13 @@ package pgx_test
import (
"bytes"
"database/sql"
"github.com/jackc/pgx"
"strings"
"testing"
"time"
"github.com/shopspring/decimal"
)
func TestConnQueryScan(t *testing.T) {
@ -904,3 +907,113 @@ func TestReadingNullByteArrays(t *testing.T) {
t.Errorf("Expected to read 2 rows, read: ", count)
}
}
// Use github.com/shopspring/decimal as real-world database/sql custom type
// to test against.
func TestConnQueryDatabaseSQLScanner(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var num decimal.Decimal
err := conn.QueryRow("select '1234.567'::decimal").Scan(&num)
if err != nil {
t.Fatalf("Scan failed: %v", err)
}
expected, err := decimal.NewFromString("1234.567")
if err != nil {
t.Fatal(err)
}
if !num.Equals(expected) {
t.Errorf("Expected num to be %v, but it was %v", expected, num)
}
ensureConnValid(t, conn)
}
// Use github.com/shopspring/decimal as real-world database/sql custom type
// to test against.
func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
expected, err := decimal.NewFromString("1234.567")
if err != nil {
t.Fatal(err)
}
var num decimal.Decimal
err = conn.QueryRow("select $1::decimal", expected).Scan(&num)
if err != nil {
t.Fatalf("Scan failed: %v", err)
}
if !num.Equals(expected) {
t.Errorf("Expected num to be %v, but it was %v", expected, num)
}
ensureConnValid(t, conn)
}
func TestConnQueryDatabaseSQLNullX(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
type row struct {
boolValid sql.NullBool
boolNull sql.NullBool
int64Valid sql.NullInt64
int64Null sql.NullInt64
float64Valid sql.NullFloat64
float64Null sql.NullFloat64
stringValid sql.NullString
stringNull sql.NullString
}
expected := row{
boolValid: sql.NullBool{Bool: true, Valid: true},
int64Valid: sql.NullInt64{Int64: 123, Valid: true},
float64Valid: sql.NullFloat64{Float64: 3.14, Valid: true},
stringValid: sql.NullString{String: "pgx", Valid: true},
}
var actual row
err := conn.QueryRow(
"select $1::bool, $2::bool, $3::int8, $4::int8, $5::float8, $6::float8, $7::text, $8::text",
expected.boolValid,
expected.boolNull,
expected.int64Valid,
expected.int64Null,
expected.float64Valid,
expected.float64Null,
expected.stringValid,
expected.stringNull,
).Scan(
&actual.boolValid,
&actual.boolNull,
&actual.int64Valid,
&actual.int64Null,
&actual.float64Valid,
&actual.float64Null,
&actual.stringValid,
&actual.stringNull,
)
if err != nil {
t.Fatalf("Scan failed: %v", err)
}
if expected != actual {
t.Errorf("Expected %v, but got %v", expected, actual)
}
ensureConnValid(t, conn)
}