mirror of https://github.com/jackc/pgx.git
Add compatibility with database/sql custom types
Support database/sql.Scanner Support database/sql/driver.Valuerredshift-ssl-tests
parent
029bd49065
commit
9f9a9779ac
|
@ -1,5 +1,6 @@
|
|||
# Tip
|
||||
|
||||
* Add support for database/sql.Scanner and database/sql/driver.Valuer interfaces
|
||||
* Go float64 can no longer be encoded to a PostgreSQL float4
|
||||
* Add ConnPool.Reset method
|
||||
* []byte skips encoding/decoding
|
||||
|
|
|
@ -23,6 +23,7 @@ Pgx supports many additional features beyond what is available through database/
|
|||
* Maps inet and cidr PostgreSQL types to net.IPNet
|
||||
* Large object support
|
||||
* Null mapping to Null* struct or pointer to pointer.
|
||||
* Supports database/sql.Scanner and database/sql/driver/Valuer interfaces for custom types
|
||||
|
||||
## Performance
|
||||
|
||||
|
|
8
conn.go
8
conn.go
|
@ -4,6 +4,7 @@ import (
|
|||
"bufio"
|
||||
"crypto/md5"
|
||||
"crypto/tls"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
|
@ -851,15 +852,20 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
|
||||
wbuf.WriteInt16(int16(len(arguments)))
|
||||
for i, oid := range ps.ParameterOids {
|
||||
encode:
|
||||
if arguments[i] == nil {
|
||||
wbuf.WriteInt32(-1)
|
||||
continue
|
||||
}
|
||||
|
||||
encode:
|
||||
switch arg := arguments[i].(type) {
|
||||
case Encoder:
|
||||
err = arg.Encode(wbuf, oid)
|
||||
case driver.Valuer:
|
||||
arguments[i], err = arg.Value()
|
||||
if err == nil {
|
||||
goto encode
|
||||
}
|
||||
case string:
|
||||
err = encodeText(wbuf, arguments[i])
|
||||
case []byte:
|
||||
|
|
3
doc.go
3
doc.go
|
@ -181,6 +181,9 @@ Conn.PgTypes.
|
|||
See example_custom_type_test.go for an example of a custom type for the
|
||||
PostgreSQL point type.
|
||||
|
||||
pgx also includes support for custom types implementing the database/sql.Scanner
|
||||
and database/sql/driver.Valuer interfaces.
|
||||
|
||||
Raw Bytes Mapping
|
||||
|
||||
[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified
|
||||
|
|
35
query.go
35
query.go
|
@ -1,6 +1,7 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -255,6 +256,40 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||
if err != nil {
|
||||
rows.Fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
} else if s, ok := d.(sql.Scanner); ok {
|
||||
var val interface{}
|
||||
if 0 <= vr.Len() {
|
||||
switch vr.Type().DataType {
|
||||
case BoolOid:
|
||||
val = decodeBool(vr)
|
||||
case Int8Oid:
|
||||
val = int64(decodeInt8(vr))
|
||||
case Int2Oid:
|
||||
val = int64(decodeInt2(vr))
|
||||
case Int4Oid:
|
||||
val = int64(decodeInt4(vr))
|
||||
case TextOid, VarcharOid:
|
||||
val = decodeText(vr)
|
||||
case OidOid:
|
||||
val = int64(decodeOid(vr))
|
||||
case Float4Oid:
|
||||
val = float64(decodeFloat4(vr))
|
||||
case Float8Oid:
|
||||
val = decodeFloat8(vr)
|
||||
case DateOid:
|
||||
val = decodeDate(vr)
|
||||
case TimestampOid:
|
||||
val = decodeTimestamp(vr)
|
||||
case TimestampTzOid:
|
||||
val = decodeTimestampTz(vr)
|
||||
default:
|
||||
val = vr.ReadBytes(vr.Len())
|
||||
}
|
||||
}
|
||||
err = s.Scan(val)
|
||||
if err != nil {
|
||||
rows.Fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
} else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid {
|
||||
decodeJson(vr, &d)
|
||||
} else {
|
||||
|
|
113
query_test.go
113
query_test.go
|
@ -2,10 +2,13 @@ package pgx_test
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"github.com/jackc/pgx"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
func TestConnQueryScan(t *testing.T) {
|
||||
|
@ -904,3 +907,113 @@ func TestReadingNullByteArrays(t *testing.T) {
|
|||
t.Errorf("Expected to read 2 rows, read: ", count)
|
||||
}
|
||||
}
|
||||
|
||||
// Use github.com/shopspring/decimal as real-world database/sql custom type
|
||||
// to test against.
|
||||
func TestConnQueryDatabaseSQLScanner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var num decimal.Decimal
|
||||
|
||||
err := conn.QueryRow("select '1234.567'::decimal").Scan(&num)
|
||||
if err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
|
||||
expected, err := decimal.NewFromString("1234.567")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !num.Equals(expected) {
|
||||
t.Errorf("Expected num to be %v, but it was %v", expected, num)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Use github.com/shopspring/decimal as real-world database/sql custom type
|
||||
// to test against.
|
||||
func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
expected, err := decimal.NewFromString("1234.567")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var num decimal.Decimal
|
||||
|
||||
err = conn.QueryRow("select $1::decimal", expected).Scan(&num)
|
||||
if err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
|
||||
if !num.Equals(expected) {
|
||||
t.Errorf("Expected num to be %v, but it was %v", expected, num)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnQueryDatabaseSQLNullX(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
type row struct {
|
||||
boolValid sql.NullBool
|
||||
boolNull sql.NullBool
|
||||
int64Valid sql.NullInt64
|
||||
int64Null sql.NullInt64
|
||||
float64Valid sql.NullFloat64
|
||||
float64Null sql.NullFloat64
|
||||
stringValid sql.NullString
|
||||
stringNull sql.NullString
|
||||
}
|
||||
|
||||
expected := row{
|
||||
boolValid: sql.NullBool{Bool: true, Valid: true},
|
||||
int64Valid: sql.NullInt64{Int64: 123, Valid: true},
|
||||
float64Valid: sql.NullFloat64{Float64: 3.14, Valid: true},
|
||||
stringValid: sql.NullString{String: "pgx", Valid: true},
|
||||
}
|
||||
|
||||
var actual row
|
||||
|
||||
err := conn.QueryRow(
|
||||
"select $1::bool, $2::bool, $3::int8, $4::int8, $5::float8, $6::float8, $7::text, $8::text",
|
||||
expected.boolValid,
|
||||
expected.boolNull,
|
||||
expected.int64Valid,
|
||||
expected.int64Null,
|
||||
expected.float64Valid,
|
||||
expected.float64Null,
|
||||
expected.stringValid,
|
||||
expected.stringNull,
|
||||
).Scan(
|
||||
&actual.boolValid,
|
||||
&actual.boolNull,
|
||||
&actual.int64Valid,
|
||||
&actual.int64Null,
|
||||
&actual.float64Valid,
|
||||
&actual.float64Null,
|
||||
&actual.stringValid,
|
||||
&actual.stringNull,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
|
||||
if expected != actual {
|
||||
t.Errorf("Expected %v, but got %v", expected, actual)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue