Merge branch 'master' into v3-experimental

v3-experimental-wait-ping-context
Jack Christensen 2016-12-10 12:21:08 -06:00
commit 93e5c68f69
26 changed files with 2500 additions and 239 deletions

View File

@ -1,8 +1,8 @@
language: go
go:
- 1.6.2
- 1.5.2
- 1.7.1
- 1.6.3
- tip
# Derived from https://github.com/lib/pq/blob/master/.travis.yml
@ -29,6 +29,8 @@ env:
- PGVERSION=9.2
- PGVERSION=9.1
# The tricky test user, below, has to actually exist so that it can be used in a test
# of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles.
before_script:
- mv conn_config_test.go.travis conn_config_test.go
- psql -U postgres -c 'create database pgx_test'
@ -36,6 +38,12 @@ before_script:
- psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'"
- psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'"
- psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'"
- psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'"
install:
- go get -u github.com/shopspring/decimal
- go get -u gopkg.in/inconshreveable/log15.v2
- go get -u github.com/jackc/fake
script:
- go test -v -race -short ./...

View File

@ -2,6 +2,26 @@
## Fixes
* Oid underlying type changed to uint32, previously it was incorrectly int32 (Manni Wood)
## Features
* Add xid type support (Manni Wood)
* Add cid type support (Manni Wood)
* Add tid type support (Manni Wood)
* Add "char" type support (Manni Wood)
* Add NullOid type (Manni Wood)
* Add json/jsonb binary support to allow use with CopyTo
* Add named error ErrAcquireTimeout (Alexander Staubo)
## Compatibility
* jsonb now defaults to binary format. This means passing a []byte to a jsonb column will no longer work.
# 2.9.0 (August 26, 2016)
## Fixes
* Fix *ConnPool.Deallocate() not deleting prepared statement from map
* Fix stdlib not logging unprepared query SQL (Krzysztof Dryś)
* Fix Rows.Values() with varchar binary format
@ -9,12 +29,13 @@
## Features
* Add CopyTo
* Add PrepareEx
* Add basic record to []interface{} decoding
* Encode and decode between all Go and PostgreSQL integer types with bounds checking
* Decode inet/cidr to net.IP
* Encode/decode [][]byte to/from bytea[]
* Encode/decode named types whoses underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64
* Encode/decode named types whose underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64
## Performance

View File

@ -19,6 +19,7 @@ Pgx supports many additional features beyond what is available through database/
* Transaction isolation level control
* Full TLS connection control
* Binary format support for custom types (can be much faster)
* Copy protocol support for faster bulk data loads
* Logging support
* Configurable connection pool with after connect hooks to do arbitrary connection setup
* PostgreSQL array to Go slice mapping for integers, floats, and strings
@ -60,16 +61,23 @@ skip tests for connection types that are not configured.
### Normal Test Environment
To setup the normal test environment run the following SQL:
To set up the normal test environment, first install these dependencies:
go get github.com/jackc/fake
go get github.com/shopspring/decimal
go get gopkg.in/inconshreveable/log15.v2
Then run the following SQL:
create user pgx_md5 password 'secret';
create user " tricky, ' } "" \ test user " password 'secret';
create database pgx_test;
Connect to database pgx_test and run:
create extension hstore;
Next open connection_settings_test.go.example and make a copy without the
Next open conn_config_test.go.example and make a copy without the
.example. If your PostgreSQL server is accepting connections on 127.0.0.1,
then you are done.
@ -98,7 +106,8 @@ If you are developing on Windows with TCP connections:
## Version Policy
pgx follows semantic versioning for the documented public API. ```master```
branch tracks the latest stable branch (```v2```). Consider using ```import
"gopkg.in/jackc/pgx.v2"``` to lock to the ```v2``` branch or use a vendoring
tool such as [godep](https://github.com/tools/godep).
pgx follows semantic versioning for the documented public API on stable releases. Branch ```v2``` is the latest stable release. ```master``` can contain new features or behavior that will change or be removed before being merged to the stable ```v2``` branch (in practice, this occurs very rarely).
Consider using a vendoring
tool such as [godep](https://github.com/tools/godep) or importing pgx via ```import
"gopkg.in/jackc/pgx.v2"``` to lock to the ```v2``` branch.

126
aclitem_parse_test.go Normal file
View File

@ -0,0 +1,126 @@
package pgx
import (
"reflect"
"testing"
)
// TestEscapeAclItem verifies that escapeAclItem backslash-escapes the
// characters that are special inside an aclitem array element.
func TestEscapeAclItem(t *testing.T) {
	cases := []struct {
		in   string
		want string
	}{
		{"foo", "foo"},
		{`foo, "\}`, `foo\, \"\\\}`},
	}

	for i, c := range cases {
		got, err := escapeAclItem(c.in)
		if err != nil {
			t.Errorf("%d. Unexpected error %v", i, err)
		}
		if got != c.want {
			t.Errorf("%d.\nexpected: %s,\nactual: %s", i, c.want, got)
		}
	}
}
// TestParseAclItemArray verifies parsing of PostgreSQL aclitem array text,
// covering quoted/unquoted elements, embedded escapes, and malformed input.
func TestParseAclItemArray(t *testing.T) {
	cases := []struct {
		in     string
		want   []AclItem
		errMsg string
	}{
		{in: "", want: []AclItem{}},
		{in: "one", want: []AclItem{"one"}},
		{in: `"one"`, want: []AclItem{"one"}},
		{in: "one,two,three", want: []AclItem{"one", "two", "three"}},
		{in: `"one","two","three"`, want: []AclItem{"one", "two", "three"}},
		{in: `"one",two,"three"`, want: []AclItem{"one", "two", "three"}},
		{in: `one,two,"three"`, want: []AclItem{"one", "two", "three"}},
		{in: `"one","two",three`, want: []AclItem{"one", "two", "three"}},
		{in: `"one","t w o",three`, want: []AclItem{"one", "t w o", "three"}},
		{in: `"one","t, w o\"\}\\",three`, want: []AclItem{"one", `t, w o"}\`, "three"}},
		{in: `"one","two",three"`, want: []AclItem{"one", "two", `three"`}},
		{in: `"one","two,"three"`, errMsg: "unexpected rune after quoted value"},
		{in: `"one","two","three`, errMsg: "unexpected end of quoted value"},
	}

	for i, c := range cases {
		got, err := parseAclItemArray(c.in)
		if err != nil {
			if c.errMsg == "" {
				t.Errorf("%d. Unexpected error %v", i, err)
			} else if err.Error() != c.errMsg {
				t.Errorf("%d. Expected error %v did not match actual error %v", i, c.errMsg, err.Error())
			}
		} else if c.errMsg != "" {
			t.Errorf("%d. Expected error not returned: \"%v\"", i, c.errMsg)
		}
		if !reflect.DeepEqual(got, c.want) {
			t.Errorf("%d. Expected %v did not match actual %v", i, c.want, got)
		}
	}
}

View File

@ -1,6 +1,9 @@
package pgx_test
import (
"bytes"
"fmt"
"strings"
"testing"
"time"
@ -397,3 +400,331 @@ func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) {
}
}
}
// benchmarkWriteTableCreateSQL (re)creates the benchmark target table t,
// mixing not-null and nullable columns across several data types so the
// write benchmarks exercise both value and NULL encoding.
const benchmarkWriteTableCreateSQL = `drop table if exists t;
create table t(
varchar_1 varchar not null,
varchar_2 varchar not null,
varchar_null_1 varchar,
date_1 date not null,
date_null_1 date,
int4_1 int4 not null,
int4_2 int4 not null,
int4_null_1 int4,
tstz_1 timestamptz not null,
tstz_2 timestamptz,
bool_1 bool not null,
bool_2 bool not null,
bool_3 bool not null
);
`
// benchmarkWriteTableInsertSQL inserts one row into t. Each placeholder has
// an explicit cast so the prepared statement's parameter types match the
// column types exactly.
const benchmarkWriteTableInsertSQL = `insert into t(
varchar_1,
varchar_2,
varchar_null_1,
date_1,
date_null_1,
int4_1,
int4_2,
int4_null_1,
tstz_1,
tstz_2,
bool_1,
bool_2,
bool_3
) values (
$1::varchar,
$2::varchar,
$3::varchar,
$4::date,
$5::date,
$6::int4,
$7::int4,
$8::int4,
$9::timestamptz,
$10::timestamptz,
$11::bool,
$12::bool,
$13::bool
)`
// benchmarkWriteTableCopyToSrc is a CopyToSource that yields the same fixed
// row count times.
type benchmarkWriteTableCopyToSrc struct {
	count int
	idx   int
	row   []interface{}
}

// Next advances to the next row, reporting false after count rows have been
// produced.
func (s *benchmarkWriteTableCopyToSrc) Next() bool {
	s.idx++
	return s.idx < s.count
}

// Values returns the fixed row.
func (s *benchmarkWriteTableCopyToSrc) Values() ([]interface{}, error) {
	return s.row, nil
}

// Err always reports success; this in-memory source cannot fail.
func (s *benchmarkWriteTableCopyToSrc) Err() error {
	return nil
}

// newBenchmarkWriteTableCopyToSrc builds a source producing exactly count
// copies of a representative row matching benchmarkWriteTableCreateSQL.
func newBenchmarkWriteTableCopyToSrc(count int) pgx.CopyToSource {
	return &benchmarkWriteTableCopyToSrc{
		count: count,
		// BUG FIX: idx must start at -1 (as copyToRows does) so Next yields
		// count rows; leaving it at the zero value yielded only count-1.
		idx: -1,
		row: []interface{}{
			"varchar_1",
			"varchar_2",
			pgx.NullString{},
			time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local),
			pgx.NullTime{},
			1,
			2,
			pgx.NullInt32{},
			time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local),
			time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local),
			true,
			false,
			true,
		},
	}
}
// benchmarkWriteNRowsViaInsert benchmarks writing n rows per iteration with
// one prepared-statement execution per row inside a single transaction.
func benchmarkWriteNRowsViaInsert(b *testing.B, n int) {
	conn := mustConnect(b, *defaultConnConfig)
	defer closeConn(b, conn)

	mustExec(b, conn, benchmarkWriteTableCreateSQL)
	if _, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL); err != nil {
		b.Fatal(err)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		src := newBenchmarkWriteTableCopyToSrc(n)

		tx, err := conn.Begin()
		if err != nil {
			b.Fatal(err)
		}

		for src.Next() {
			values, _ := src.Values()
			if _, err = tx.Exec("insert_t", values...); err != nil {
				b.Fatalf("Exec unexpectedly failed with: %v", err)
			}
		}

		if err = tx.Commit(); err != nil {
			b.Fatal(err)
		}
	}
}
// multiInsert inserts all rows from rowSrc into tableName using batched
// multi-row insert statements, all inside one transaction. It batches as
// many rows per statement as PostgreSQL's 65535 bind-parameter limit allows
// and returns the number of rows inserted.
//
// note this function is only used for benchmarks -- it doesn't escape tableName
// or columnNames
func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc pgx.CopyToSource) (int, error) {
	// PostgreSQL limits a single statement to 65535 bind parameters.
	maxRowsPerInsert := 65535 / len(columnNames)
	rowsThisInsert := 0
	rowCount := 0

	sqlBuf := &bytes.Buffer{}
	args := make(pgx.QueryArgs, 0)

	// resetQuery rewinds the SQL buffer and args for the next batch.
	resetQuery := func() {
		sqlBuf.Reset()
		fmt.Fprintf(sqlBuf, "insert into %s(%s) values", tableName, strings.Join(columnNames, ", "))
		args = args[0:0]
		rowsThisInsert = 0
	}
	resetQuery()

	tx, err := conn.Begin()
	if err != nil {
		return 0, err
	}
	defer tx.Rollback() // no-op once Commit has succeeded

	for rowSrc.Next() {
		if rowsThisInsert > 0 {
			sqlBuf.WriteByte(',')
		}
		sqlBuf.WriteByte('(')
		values, err := rowSrc.Values()
		if err != nil {
			return 0, err
		}
		for i, val := range values {
			if i > 0 {
				sqlBuf.WriteByte(',')
			}
			sqlBuf.WriteString(args.Append(val))
		}
		sqlBuf.WriteByte(')')
		rowsThisInsert++

		// Flush a full batch.
		if rowsThisInsert == maxRowsPerInsert {
			_, err := tx.Exec(sqlBuf.String(), args...)
			if err != nil {
				return 0, err
			}
			rowCount += rowsThisInsert
			resetQuery()
		}
	}

	// BUG FIX: surface any error encountered by the row source; previously it
	// was never checked.
	if err := rowSrc.Err(); err != nil {
		return 0, err
	}

	// Flush the final partial batch.
	if rowsThisInsert > 0 {
		_, err := tx.Exec(sqlBuf.String(), args...)
		if err != nil {
			return 0, err
		}
		rowCount += rowsThisInsert
	}

	if err := tx.Commit(); err != nil {
		// BUG FIX: previously returned (0, nil) here, silently swallowing the
		// commit error.
		return 0, err
	}

	return rowCount, nil
}
// benchmarkWriteNRowsViaMultiInsert benchmarks writing n rows per iteration
// using batched multi-row insert statements built by multiInsert.
func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) {
	conn := mustConnect(b, *defaultConnConfig)
	defer closeConn(b, conn)

	mustExec(b, conn, benchmarkWriteTableCreateSQL)
	// Note: multiInsert builds its own SQL, so no statement is prepared here.
	// (A copy-pasted, unused Prepare("insert_t", ...) was removed.)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		src := newBenchmarkWriteTableCopyToSrc(n)

		_, err := multiInsert(conn, "t",
			[]string{"varchar_1",
				"varchar_2",
				"varchar_null_1",
				"date_1",
				"date_null_1",
				"int4_1",
				"int4_2",
				"int4_null_1",
				"tstz_1",
				"tstz_2",
				"bool_1",
				"bool_2",
				"bool_3"},
			src)
		if err != nil {
			b.Fatal(err)
		}
	}
}
// benchmarkWriteNRowsViaCopy benchmarks writing n rows per iteration using
// the PostgreSQL copy protocol via Conn.CopyTo.
func benchmarkWriteNRowsViaCopy(b *testing.B, n int) {
	conn := mustConnect(b, *defaultConnConfig)
	defer closeConn(b, conn)

	mustExec(b, conn, benchmarkWriteTableCreateSQL)

	columns := []string{
		"varchar_1", "varchar_2", "varchar_null_1",
		"date_1", "date_null_1",
		"int4_1", "int4_2", "int4_null_1",
		"tstz_1", "tstz_2",
		"bool_1", "bool_2", "bool_3",
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		src := newBenchmarkWriteTableCopyToSrc(n)
		if _, err := conn.CopyTo("t", columns, src); err != nil {
			b.Fatal(err)
		}
	}
}
// The benchmarks below compare per-row prepared inserts, batched multi-row
// inserts, and the copy protocol at increasing row counts per iteration.
func BenchmarkWrite5RowsViaInsert(b *testing.B) {
benchmarkWriteNRowsViaInsert(b, 5)
}
func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 5)
}
func BenchmarkWrite5RowsViaCopy(b *testing.B) {
benchmarkWriteNRowsViaCopy(b, 5)
}
func BenchmarkWrite10RowsViaInsert(b *testing.B) {
benchmarkWriteNRowsViaInsert(b, 10)
}
func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 10)
}
func BenchmarkWrite10RowsViaCopy(b *testing.B) {
benchmarkWriteNRowsViaCopy(b, 10)
}
func BenchmarkWrite100RowsViaInsert(b *testing.B) {
benchmarkWriteNRowsViaInsert(b, 100)
}
func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 100)
}
func BenchmarkWrite100RowsViaCopy(b *testing.B) {
benchmarkWriteNRowsViaCopy(b, 100)
}
func BenchmarkWrite1000RowsViaInsert(b *testing.B) {
benchmarkWriteNRowsViaInsert(b, 1000)
}
func BenchmarkWrite1000RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 1000)
}
func BenchmarkWrite1000RowsViaCopy(b *testing.B) {
benchmarkWriteNRowsViaCopy(b, 1000)
}
func BenchmarkWrite10000RowsViaInsert(b *testing.B) {
benchmarkWriteNRowsViaInsert(b, 10000)
}
func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 10000)
}
func BenchmarkWrite10000RowsViaCopy(b *testing.B) {
benchmarkWriteNRowsViaCopy(b, 10000)
}

61
conn.go
View File

@ -63,8 +63,8 @@ type Conn struct {
logLevel int
mr msgReader
fp *fastpath
pgsql_af_inet *byte
pgsql_af_inet6 *byte
pgsqlAfInet *byte
pgsqlAfInet6 *byte
busy bool
poolResetCount int
preallocatedRows []Rows
@ -145,7 +145,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
return connect(config, nil, nil, nil)
}
func connect(config ConnConfig, pgTypes map[OID]PgType, pgsql_af_inet *byte, pgsql_af_inet6 *byte) (c *Conn, err error) {
func connect(config ConnConfig, pgTypes map[OID]PgType, pgsqlAfInet *byte, pgsqlAfInet6 *byte) (c *Conn, err error) {
c = new(Conn)
c.config = config
@ -157,13 +157,13 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsql_af_inet *byte, pgs
}
}
if pgsql_af_inet != nil {
c.pgsql_af_inet = new(byte)
*c.pgsql_af_inet = *pgsql_af_inet
if pgsqlAfInet != nil {
c.pgsqlAfInet = new(byte)
*c.pgsqlAfInet = *pgsqlAfInet
}
if pgsql_af_inet6 != nil {
c.pgsql_af_inet6 = new(byte)
*c.pgsql_af_inet6 = *pgsql_af_inet6
if pgsqlAfInet6 != nil {
c.pgsqlAfInet6 = new(byte)
*c.pgsqlAfInet6 = *pgsqlAfInet6
}
if c.config.LogLevel != 0 {
@ -209,12 +209,21 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsql_af_inet *byte, pgs
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
}
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
}
err = c.connect(config, network, address, config.TLSConfig)
if err != nil && config.UseFallbackTLS {
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, fmt.Sprintf("Connect with TLSConfig failed, trying FallbackTLSConfig: %v", err))
}
err = c.connect(config, network, address, config.FallbackTLSConfig)
}
if err != nil {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, fmt.Sprintf("Connect failed: %v", err))
}
return nil, err
}
@ -222,23 +231,14 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsql_af_inet *byte, pgs
}
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
}
c.conn, err = c.config.Dial(network, address)
if err != nil {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, fmt.Sprintf("Connection failed: %v", err))
}
return err
}
defer func() {
if c != nil && err != nil {
c.conn.Close()
c.alive = false
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, err.Error())
}
}
}()
@ -253,9 +253,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.log(LogLevelDebug, "Starting TLS handshake")
}
if err := c.startTLS(tlsConfig); err != nil {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, fmt.Sprintf("TLS failed: %v", err))
}
return err
}
}
@ -315,7 +312,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
}
}
if c.pgsql_af_inet == nil || c.pgsql_af_inet6 == nil {
if c.pgsqlAfInet == nil || c.pgsqlAfInet6 == nil {
err = c.loadInetConstants()
if err != nil {
return err
@ -372,8 +369,8 @@ func (c *Conn) loadInetConstants() error {
return err
}
c.pgsql_af_inet = &ipv4[0]
c.pgsql_af_inet6 = &ipv6[0]
c.pgsqlAfInet = &ipv4[0]
c.pgsqlAfInet6 = &ipv6[0]
return nil
}
@ -430,7 +427,7 @@ func ParseURI(uri string) (ConnConfig, error) {
}
ignoreKeys := map[string]struct{}{
"sslmode": struct{}{},
"sslmode": {},
}
cp.RuntimeParams = make(map[string]string)
@ -446,7 +443,7 @@ func ParseURI(uri string) (ConnConfig, error) {
return cp, nil
}
var dsn_regexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
// ParseDSN parses a database DSN (data source name) into a ConnConfig
//
@ -462,7 +459,7 @@ var dsn_regexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
func ParseDSN(s string) (ConnConfig, error) {
var cp ConnConfig
m := dsn_regexp.FindAllStringSubmatch(s, -1)
m := dsnRegexp.FindAllStringSubmatch(s, -1)
var sslmode string
@ -477,11 +474,11 @@ func ParseDSN(s string) (ConnConfig, error) {
case "host":
cp.Host = b[2]
case "port":
if p, err := strconv.ParseUint(b[2], 10, 16); err != nil {
p, err := strconv.ParseUint(b[2], 10, 16)
if err != nil {
return cp, err
} else {
cp.Port = uint16(p)
}
cp.Port = uint16(p)
case "dbname":
cp.Database = b[2]
case "sslmode":
@ -627,7 +624,7 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
if opts != nil {
if len(opts.ParameterOIDs) > 65535 {
return nil, errors.New(fmt.Sprintf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs)))
return nil, fmt.Errorf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs))
}
wbuf.WriteInt16(int16(len(opts.ParameterOIDs)))
for _, oid := range opts.ParameterOIDs {
@ -917,7 +914,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
wbuf.WriteInt16(TextFormatCode)
default:
switch oid {
case BoolOID, ByteaOID, Int2OID, Int4OID, Int8OID, Float4OID, Float8OID, TimestampTzOID, TimestampTzArrayOID, TimestampOID, TimestampArrayOID, DateOID, BoolArrayOID, ByteaArrayOID, Int2ArrayOID, Int4ArrayOID, Int8ArrayOID, Float4ArrayOID, Float8ArrayOID, TextArrayOID, VarcharArrayOID, OIDOID, InetOID, CidrOID, InetArrayOID, CidrArrayOID, RecordOID:
case BoolOID, ByteaOID, Int2OID, Int4OID, Int8OID, Float4OID, Float8OID, TimestampTzOID, TimestampTzArrayOID, TimestampOID, TimestampArrayOID, DateOID, BoolArrayOID, ByteaArrayOID, Int2ArrayOID, Int4ArrayOID, Int8ArrayOID, Float4ArrayOID, Float8ArrayOID, TextArrayOID, VarcharArrayOID, OIDOID, InetOID, CidrOID, InetArrayOID, CidrArrayOID, RecordOID, JSONOID, JSONBOID:
wbuf.WriteInt16(BinaryFormatCode)
default:
wbuf.WriteInt16(TextFormatCode)

View File

@ -11,7 +11,6 @@ var tcpConnConfig *pgx.ConnConfig = nil
var unixSocketConnConfig *pgx.ConnConfig = nil
var md5ConnConfig *pgx.ConnConfig = nil
var plainPasswordConnConfig *pgx.ConnConfig = nil
var noPasswordConnConfig *pgx.ConnConfig = nil
var invalidUserConnConfig *pgx.ConnConfig = nil
var tlsConnConfig *pgx.ConnConfig = nil
var customDialerConnConfig *pgx.ConnConfig = nil
@ -20,7 +19,6 @@ var customDialerConnConfig *pgx.ConnConfig = nil
// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"}
// var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
// var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"}
// var noPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_none", Database: "pgx_test"}
// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}

View File

@ -28,8 +28,10 @@ type ConnPool struct {
preparedStatements map[string]*PreparedStatement
acquireTimeout time.Duration
pgTypes map[OID]PgType
pgsql_af_inet *byte
pgsql_af_inet6 *byte
pgsqlAfInet *byte
pgsqlAfInet6 *byte
txAfterClose func(tx *Tx)
rowsAfterClose func(rows *Rows)
}
type ConnPoolStat struct {
@ -38,6 +40,9 @@ type ConnPoolStat struct {
AvailableConnections int // unused live connections
}
// ErrAcquireTimeout occurs when an attempt to acquire a connection times out.
var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool")
// NewConnPool creates a new ConnPool. config.ConnConfig is passed through to
// Connect directly.
func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) {
@ -68,6 +73,14 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) {
p.logLevel = LogLevelNone
}
p.txAfterClose = func(tx *Tx) {
p.Release(tx.Conn())
}
p.rowsAfterClose = func(rows *Rows) {
p.Release(rows.Conn())
}
p.allConnections = make([]*Conn, 0, p.maxConnections)
p.availableConnections = make([]*Conn, 0, p.maxConnections)
p.preparedStatements = make(map[string]*PreparedStatement)
@ -121,7 +134,7 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
// Make sure the deadline (if it is) has not passed yet
if p.deadlinePassed(deadline) {
return nil, errors.New("Timeout: Acquire connection timeout")
return nil, ErrAcquireTimeout
}
// If there is a deadline then start a timeout timer
@ -138,26 +151,25 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
// Create a new connection.
// Careful here: createConnectionUnlocked() removes the current lock,
// creates a connection and then locks it back.
if c, err := p.createConnectionUnlocked(); err == nil {
c.poolResetCount = p.resetCount
p.allConnections = append(p.allConnections, c)
return c, nil
} else {
c, err := p.createConnectionUnlocked()
if err != nil {
return nil, err
}
} else {
// All connections are in use and we cannot create more
if p.logLevel >= LogLevelWarn {
p.logger.Log(LogLevelWarn, "All connections in pool are busy - waiting...")
}
c.poolResetCount = p.resetCount
p.allConnections = append(p.allConnections, c)
return c, nil
}
// All connections are in use and we cannot create more
if p.logLevel >= LogLevelWarn {
p.logger.Log(LogLevelWarn, "All connections in pool are busy - waiting...")
}
// Wait until there is an available connection OR room to create a new connection
for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections {
if p.deadlinePassed(deadline) {
return nil, errors.New("Timeout: All connections in pool are busy")
}
p.cond.Wait()
// Wait until there is an available connection OR room to create a new connection
for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections {
if p.deadlinePassed(deadline) {
return nil, ErrAcquireTimeout
}
p.cond.Wait()
}
// Stop the timer so that we do not spawn it on every acquire call.
@ -272,7 +284,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) {
}
func (p *ConnPool) createConnection() (*Conn, error) {
c, err := connect(p.config, p.pgTypes, p.pgsql_af_inet, p.pgsql_af_inet6)
c, err := connect(p.config, p.pgTypes, p.pgsqlAfInet, p.pgsqlAfInet6)
if err != nil {
return nil, err
}
@ -308,8 +320,8 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) {
// all the known statements for the new connection.
func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
p.pgTypes = c.PgTypes
p.pgsql_af_inet = c.pgsql_af_inet
p.pgsql_af_inet6 = c.pgsql_af_inet6
p.pgsqlAfInet = c.pgsqlAfInet
p.pgsqlAfInet6 = c.pgsqlAfInet6
if p.afterConnect != nil {
err := p.afterConnect(c)
@ -487,10 +499,13 @@ func (p *ConnPool) BeginIso(iso string) (*Tx, error) {
}
}
func (p *ConnPool) txAfterClose(tx *Tx) {
p.Release(tx.Conn())
}
// CopyTo acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) {
c, err := p.Acquire()
if err != nil {
return 0, err
}
defer p.Release(c)
func (p *ConnPool) rowsAfterClose(rows *Rows) {
p.Release(rows.Conn())
return c.CopyTo(tableName, columnNames, rowSrc)
}

View File

@ -40,7 +40,7 @@ func releaseAllConnections(pool *pgx.ConnPool, connections []*pgx.Conn) {
func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) {
startTime := time.Now()
c, err := pool.Acquire()
return c, time.Now().Sub(startTime), err
return c, time.Since(startTime), err
}
func TestNewConnPool(t *testing.T) {
@ -215,7 +215,7 @@ func TestPoolNonBlockingConnections(t *testing.T) {
// Prior to createConnectionUnlocked() use the test took
// maxConnections * openTimeout seconds to complete.
// With createConnectionUnlocked() it takes ~ 1 * openTimeout seconds.
timeTaken := time.Now().Sub(startedAt)
timeTaken := time.Since(startedAt)
if timeTaken > openTimeout+1*time.Second {
t.Fatalf("Expected all Acquire() to run in parallel and take about %v, instead it took '%v'", openTimeout, timeTaken)
}
@ -276,8 +276,8 @@ func TestPoolWithAcquireTimeoutSet(t *testing.T) {
// ... then try to consume 1 more. It should fail after a short timeout.
_, timeTaken, err := acquireWithTimeTaken(pool)
if err == nil || err.Error() != "Timeout: All connections in pool are busy" {
t.Fatalf("Expected error to be 'Timeout: All connections in pool are busy', instead it was '%v'", err)
if err == nil || err != pgx.ErrAcquireTimeout {
t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err)
}
if timeTaken < connAllocTimeout {
t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken)
@ -366,12 +366,12 @@ func TestPoolAcquireAndReleaseCycleAutoConnect(t *testing.T) {
if err != nil {
t.Fatalf("Unable to Acquire: %v", err)
}
rows, _ := c.Query("select 1")
rows, _ := c.Query("select 1, pg_sleep(0.02)")
rows.Close()
pool.Release(c)
}
for i := 0; i < 1000; i++ {
for i := 0; i < 10; i++ {
doSomething()
}
@ -381,7 +381,7 @@ func TestPoolAcquireAndReleaseCycleAutoConnect(t *testing.T) {
}
var wg sync.WaitGroup
for i := 0; i < 1000; i++ {
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()

View File

@ -914,7 +914,7 @@ func TestPrepareQueryManyParameters(t *testing.T) {
args := make([]interface{}, 0, tt.count)
for j := 0; j < tt.count; j++ {
params = append(params, fmt.Sprintf("($%d::text)", j+1))
args = append(args, strconv.FormatInt(int64(j), 10))
args = append(args, strconv.Itoa(j))
}
sql := "values" + strings.Join(params, ", ")

241
copy_to.go Normal file
View File

@ -0,0 +1,241 @@
package pgx
import (
"bytes"
"fmt"
)
// CopyToRows returns a CopyToSource interface over the provided rows slice
// making it usable by *Conn.CopyTo.
// CopyToRows returns a CopyToSource interface over the provided rows slice
// making it usable by *Conn.CopyTo.
func CopyToRows(rows [][]interface{}) CopyToSource {
	return &copyToRows{rows: rows, idx: -1}
}

// copyToRows adapts an in-memory slice of rows to the CopyToSource interface.
type copyToRows struct {
	rows [][]interface{}
	idx  int
}

// Next advances the cursor and reports whether a row is available.
func (r *copyToRows) Next() bool {
	r.idx++
	return r.idx < len(r.rows)
}

// Values returns the row at the current cursor position.
func (r *copyToRows) Values() ([]interface{}, error) {
	return r.rows[r.idx], nil
}

// Err always returns nil; an in-memory source cannot fail.
func (r *copyToRows) Err() error {
	return nil
}

// CopyToSource is the interface used by *Conn.CopyTo as the source for copy data.
type CopyToSource interface {
	// Next returns true if there is another row and makes the next row data
	// available to Values(). When there are no more rows available or an error
	// has occurred it returns false.
	Next() bool

	// Values returns the values for the current row.
	Values() ([]interface{}, error)

	// Err returns any error that has been encountered by the CopyToSource. If
	// this is not nil *Conn.CopyTo will abort the copy.
	Err() error
}
// copyTo bundles the state for a single in-progress CopyTo operation.
type copyTo struct {
conn *Conn
tableName string
columnNames []string
rowSrc CopyToSource
// readerErrChan carries errors from the background read goroutine and is
// closed once readyForQuery has been received (see readUntilReadyForQuery).
readerErrChan chan error
}
// readUntilReadyForQuery runs on a background goroutine, consuming server
// messages until readyForQuery arrives. Errors encountered along the way are
// forwarded on ct.readerErrChan; the channel is closed when the loop ends.
func (ct *copyTo) readUntilReadyForQuery() {
	for {
		t, r, err := ct.conn.rxMsg()
		if err != nil {
			ct.readerErrChan <- err
			close(ct.readerErrChan)
			return
		}

		switch t {
		case readyForQuery:
			ct.conn.rxReadyForQuery(r)
			close(ct.readerErrChan)
			return
		case commandComplete:
			// Nothing to do; readyForQuery follows.
		case errorResponse:
			ct.readerErrChan <- ct.conn.rxErrorResponse(r)
		default:
			err = ct.conn.processContextFreeMsg(t, r)
			if err != nil {
				// BUG FIX: previously this called processContextFreeMsg a
				// second time, consuming an unrelated message and sending its
				// result instead of the actual error.
				ct.readerErrChan <- err
			}
		}
	}
}
// waitForReaderDone blocks until the background reader closes readerErrChan
// and returns the last error it reported (nil if there was none).
func (ct *copyTo) waitForReaderDone() error {
	var lastErr error
	for e := range ct.readerErrChan {
		lastErr = e
	}
	return lastErr
}
// run executes the copy: it prepares an anonymous statement to learn the
// column types, issues "copy ... from stdin binary", streams each row from
// ct.rowSrc in the PostgreSQL binary copy format, and returns the number of
// rows sent.
func (ct *copyTo) run() (int, error) {
quotedTableName := quoteIdentifier(ct.tableName)
buf := &bytes.Buffer{}
for i, cn := range ct.columnNames {
if i != 0 {
buf.WriteString(", ")
}
buf.WriteString(quoteIdentifier(cn))
}
quotedColumnNames := buf.String()
// Prepare a select over the target columns solely to obtain their type OIDs
// (ps.FieldDescriptions), needed to binary-encode each value below.
ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
if err != nil {
return 0, err
}
err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
if err != nil {
return 0, err
}
err = ct.conn.readUntilCopyInResponse()
if err != nil {
return 0, err
}
// Read server responses concurrently so a mid-copy error can abort the send
// loop; always drain the reader before returning.
go ct.readUntilReadyForQuery()
defer ct.waitForReaderDone()
// Binary copy header: signature, flags, header-extension length.
wbuf := newWriteBuf(ct.conn, copyData)
wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000"))
wbuf.WriteInt32(0)
wbuf.WriteInt32(0)
var sentCount int
for ct.rowSrc.Next() {
// Bail out early if the background reader has already seen an error.
select {
case err = <-ct.readerErrChan:
return 0, err
default:
}
// Flush once the buffer exceeds 64KiB to bound memory usage.
if len(wbuf.buf) > 65536 {
wbuf.closeMsg()
_, err = ct.conn.conn.Write(wbuf.buf)
if err != nil {
ct.conn.die(err)
return 0, err
}
// Directly manipulate wbuf to reset to reuse the same buffer
wbuf.buf = wbuf.buf[0:5]
wbuf.buf[0] = copyData
wbuf.sizeIdx = 1
}
sentCount++
values, err := ct.rowSrc.Values()
if err != nil {
ct.cancelCopyIn()
return 0, err
}
if len(values) != len(ct.columnNames) {
ct.cancelCopyIn()
return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
}
// Each binary-format tuple is a field count followed by encoded fields.
wbuf.WriteInt16(int16(len(ct.columnNames)))
for i, val := range values {
err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val)
if err != nil {
ct.cancelCopyIn()
return 0, err
}
}
}
if ct.rowSrc.Err() != nil {
ct.cancelCopyIn()
return 0, ct.rowSrc.Err()
}
wbuf.WriteInt16(-1) // terminate the copy stream
wbuf.startMsg(copyDone)
wbuf.closeMsg()
_, err = ct.conn.conn.Write(wbuf.buf)
if err != nil {
ct.conn.die(err)
return 0, err
}
// Wait for commandComplete/readyForQuery; surfaces any server-side error.
err = ct.waitForReaderDone()
if err != nil {
return 0, err
}
return sentCount, nil
}
// readUntilCopyInResponse consumes messages from the server until a
// copyInResponse arrives, handling context-free messages along the way.
func (c *Conn) readUntilCopyInResponse() error {
	for {
		t, r, err := c.rxMsg()
		if err != nil {
			return err
		}

		if t == copyInResponse {
			return nil
		}

		if err := c.processContextFreeMsg(t, r); err != nil {
			return err
		}
	}
}
// cancelCopyIn aborts an in-progress COPY by sending a copyFail message.
// The connection is killed if the write itself fails.
func (ct *copyTo) cancelCopyIn() error {
	wbuf := newWriteBuf(ct.conn, copyFail)
	wbuf.WriteCString("client error: abort")
	wbuf.closeMsg()

	if _, err := ct.conn.conn.Write(wbuf.buf); err != nil {
		ct.conn.die(err)
		return err
	}
	return nil
}
// CopyTo uses the PostgreSQL copy protocol to perform bulk data insertion.
// It returns the number of rows copied and an error.
//
// CopyTo requires all values use the binary format. Almost all types
// implemented by pgx use the binary format by default. Types implementing
// Encoder can only be used if they encode to the binary format.
func (c *Conn) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) {
	op := copyTo{
		conn:          c,
		tableName:     tableName,
		columnNames:   columnNames,
		rowSrc:        rowSrc,
		readerErrChan: make(chan error),
	}
	return (&op).run()
}

428
copy_to_test.go Normal file
View File

@ -0,0 +1,428 @@
package pgx_test
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/jackc/pgx"
)
// TestConnCopyToSmall round-trips two rows (one all-NULL) through CopyTo and
// verifies the table contents match the input exactly.
func TestConnCopyToSmall(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	mustExec(t, conn, `create temporary table foo(
a int2,
b int4,
c int8,
d varchar,
e text,
f date,
g timestamptz
)`)

	inputRows := [][]interface{}{
		{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)},
		{nil, nil, nil, nil, nil, nil, nil},
	}

	copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyToRows(inputRows))
	if err != nil {
		t.Errorf("Unexpected error for CopyTo: %v", err)
	}
	if copyCount != len(inputRows) {
		t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
	}

	// Read everything back and compare against the input.
	fetchAll := func() (out [][]interface{}) {
		rows, err := conn.Query("select * from foo")
		if err != nil {
			t.Errorf("Unexpected error for Query: %v", err)
		}
		for rows.Next() {
			row, err := rows.Values()
			if err != nil {
				t.Errorf("Unexpected error for rows.Values(): %v", err)
			}
			out = append(out, row)
		}
		if rows.Err() != nil {
			t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
		}
		return out
	}
	outputRows := fetchAll()

	if !reflect.DeepEqual(inputRows, outputRows) {
		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
	}

	ensureConnValid(t, conn)
}
// TestConnCopyToLarge verifies CopyTo with 10,000 rows, enough data to span
// multiple CopyData messages and internal buffer flushes.
func TestConnCopyToLarge(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	mustExec(t, conn, `create temporary table foo(
		a int2,
		b int4,
		c int8,
		d varchar,
		e text,
		f date,
		g timestamptz,
		h bytea
	)`)

	inputRows := [][]interface{}{}
	for i := 0; i < 10000; i++ {
		inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}})
	}

	copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyToRows(inputRows))
	if err != nil {
		t.Errorf("Unexpected error for CopyTo: %v", err)
	}
	if copyCount != len(inputRows) {
		t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
	}

	rows, err := conn.Query("select * from foo")
	if err != nil {
		// Fatal: rows is nil when Query fails, so iterating it below would panic.
		t.Fatalf("Unexpected error for Query: %v", err)
	}

	var outputRows [][]interface{}
	for rows.Next() {
		row, err := rows.Values()
		if err != nil {
			t.Errorf("Unexpected error for rows.Values(): %v", err)
		}
		outputRows = append(outputRows, row)
	}

	if rows.Err() != nil {
		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
	}

	if !reflect.DeepEqual(inputRows, outputRows) {
		t.Errorf("Input rows and output rows do not equal")
	}

	ensureConnValid(t, conn)
}
// TestConnCopyToJSON verifies CopyTo into json and jsonb columns. It is
// skipped (by returning) when the server does not expose the JSON/JSONB
// types, i.e. an old PostgreSQL.
func TestConnCopyToJSON(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	for _, oid := range []pgx.OID{pgx.JSONOID, pgx.JSONBOID} {
		if _, ok := conn.PgTypes[oid]; !ok {
			return // No JSON/JSONB type -- must be running against old PostgreSQL
		}
	}

	mustExec(t, conn, `create temporary table foo(
		a json,
		b jsonb
	)`)

	inputRows := [][]interface{}{
		{map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}},
		{nil, nil},
	}

	copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows))
	if err != nil {
		t.Errorf("Unexpected error for CopyTo: %v", err)
	}
	if copyCount != len(inputRows) {
		t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
	}

	rows, err := conn.Query("select * from foo")
	if err != nil {
		// Fatal: rows is nil when Query fails, so iterating it below would panic.
		t.Fatalf("Unexpected error for Query: %v", err)
	}

	var outputRows [][]interface{}
	for rows.Next() {
		row, err := rows.Values()
		if err != nil {
			t.Errorf("Unexpected error for rows.Values(): %v", err)
		}
		outputRows = append(outputRows, row)
	}

	if rows.Err() != nil {
		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
	}

	if !reflect.DeepEqual(inputRows, outputRows) {
		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
	}

	ensureConnValid(t, conn)
}
// TestConnCopyToFailServerSideMidway verifies that a server-side failure
// (NULL into a NOT NULL column) surfaces as a pgx.PgError, copies zero rows,
// and leaves the table empty and the connection usable.
func TestConnCopyToFailServerSideMidway(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	mustExec(t, conn, `create temporary table foo(
		a int4,
		b varchar not null
	)`)

	inputRows := [][]interface{}{
		{int32(1), "abc"},
		{int32(2), nil}, // this row should trigger a failure
		{int32(3), "def"},
	}

	copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows))
	if err == nil {
		t.Errorf("Expected CopyTo return error, but it did not")
	}
	if _, ok := err.(pgx.PgError); !ok {
		t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err)
	}
	if copyCount != 0 {
		t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
	}

	rows, err := conn.Query("select * from foo")
	if err != nil {
		// Fatal: rows is nil when Query fails, so iterating it below would panic.
		t.Fatalf("Unexpected error for Query: %v", err)
	}

	var outputRows [][]interface{}
	for rows.Next() {
		row, err := rows.Values()
		if err != nil {
			t.Errorf("Unexpected error for rows.Values(): %v", err)
		}
		outputRows = append(outputRows, row)
	}

	if rows.Err() != nil {
		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
	}

	if len(outputRows) != 0 {
		t.Errorf("Expected 0 rows, but got %v", outputRows)
	}

	ensureConnValid(t, conn)
}
// failSource is a CopyToSource that emits 100 KB rows, yields a nil value on
// its third row (to provoke a server-side NOT NULL failure mid-copy), and
// sleeps before each row so tests can measure how quickly the copy aborts.
type failSource struct {
	count int
}

func (s *failSource) Next() bool {
	time.Sleep(100 * time.Millisecond)
	s.count++
	return s.count < 100
}

func (s *failSource) Values() ([]interface{}, error) {
	if s.count == 3 {
		return []interface{}{nil}, nil
	}
	return []interface{}{make([]byte, 100000)}, nil
}

func (s *failSource) Err() error {
	return nil
}
// TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting verifies that once
// the server reports an error mid-copy, CopyTo aborts promptly instead of
// draining the (slow) source — failSource would otherwise take ~10 seconds.
func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	mustExec(t, conn, `create temporary table foo(
		a bytea not null
	)`)

	startTime := time.Now()

	copyCount, err := conn.CopyTo("foo", []string{"a"}, &failSource{})
	if err == nil {
		t.Errorf("Expected CopyTo return error, but it did not")
	}
	if _, ok := err.(pgx.PgError); !ok {
		t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err)
	}
	if copyCount != 0 {
		t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
	}

	endTime := time.Now()
	copyTime := endTime.Sub(startTime)
	if copyTime > time.Second {
		t.Errorf("Failing CopyTo shouldn't have taken so long: %v", copyTime)
	}

	rows, err := conn.Query("select * from foo")
	if err != nil {
		// Fatal: rows is nil when Query fails, so iterating it below would panic.
		t.Fatalf("Unexpected error for Query: %v", err)
	}

	var outputRows [][]interface{}
	for rows.Next() {
		row, err := rows.Values()
		if err != nil {
			t.Errorf("Unexpected error for rows.Values(): %v", err)
		}
		outputRows = append(outputRows, row)
	}

	if rows.Err() != nil {
		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
	}

	if len(outputRows) != 0 {
		t.Errorf("Expected 0 rows, but got %v", outputRows)
	}

	ensureConnValid(t, conn)
}
// clientFailSource is a CopyToSource that fails on the client side: its
// third call to Values returns an error, which is also retained and
// reported by Err.
type clientFailSource struct {
	count int
	err   error
}

func (s *clientFailSource) Next() bool {
	s.count++
	return s.count < 100
}

func (s *clientFailSource) Values() ([]interface{}, error) {
	if s.count == 3 {
		s.err = fmt.Errorf("client error")
		return nil, s.err
	}
	return []interface{}{make([]byte, 100000)}, nil
}

func (s *clientFailSource) Err() error {
	return s.err
}
// TestConnCopyToCopyToSourceErrorMidway verifies that an error returned by
// the CopyToSource mid-copy aborts the copy, leaves the table empty, and
// keeps the connection usable.
func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	mustExec(t, conn, `create temporary table foo(
		a bytea not null
	)`)

	copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFailSource{})
	if err == nil {
		t.Errorf("Expected CopyTo return error, but it did not")
	}
	if copyCount != 0 {
		t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
	}

	rows, err := conn.Query("select * from foo")
	if err != nil {
		// Fatal: rows is nil when Query fails, so iterating it below would panic.
		t.Fatalf("Unexpected error for Query: %v", err)
	}

	var outputRows [][]interface{}
	for rows.Next() {
		row, err := rows.Values()
		if err != nil {
			t.Errorf("Unexpected error for rows.Values(): %v", err)
		}
		outputRows = append(outputRows, row)
	}

	if rows.Err() != nil {
		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
	}

	if len(outputRows) != 0 {
		t.Errorf("Expected 0 rows, but got %v", outputRows)
	}

	ensureConnValid(t, conn)
}
// clientFinalErrSource is a CopyToSource that produces four valid rows and
// then reports an error from Err after the rows are exhausted.
type clientFinalErrSource struct {
	count int
}

func (s *clientFinalErrSource) Next() bool {
	s.count++
	return s.count < 5
}

func (s *clientFinalErrSource) Values() ([]interface{}, error) {
	return []interface{}{make([]byte, 100000)}, nil
}

func (s *clientFinalErrSource) Err() error {
	return fmt.Errorf("final error")
}
// TestConnCopyToCopyToSourceErrorEnd verifies that an error reported by the
// CopyToSource's Err method after all rows were produced still fails the
// copy, leaving the table empty and the connection usable.
func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	mustExec(t, conn, `create temporary table foo(
		a bytea not null
	)`)

	copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFinalErrSource{})
	if err == nil {
		t.Errorf("Expected CopyTo return error, but it did not")
	}
	if copyCount != 0 {
		t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
	}

	rows, err := conn.Query("select * from foo")
	if err != nil {
		// Fatal: rows is nil when Query fails, so iterating it below would panic.
		t.Fatalf("Unexpected error for Query: %v", err)
	}

	var outputRows [][]interface{}
	for rows.Next() {
		row, err := rows.Values()
		if err != nil {
			t.Errorf("Unexpected error for rows.Values(): %v", err)
		}
		outputRows = append(outputRows, row)
	}

	if rows.Err() != nil {
		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
	}

	if len(outputRows) != 0 {
		t.Errorf("Expected 0 rows, but got %v", outputRows)
	}

	ensureConnValid(t, conn)
}

118
doc.go
View File

@ -50,7 +50,7 @@ pgx also implements QueryRow in the same style as database/sql.
return err
}
Use exec to execute a query that does not return a result set.
Use Exec to execute a query that does not return a result set.
commandTag, err := conn.Exec("delete from widgets where id=$1", 42)
if err != nil {
@ -81,43 +81,39 @@ releasing connections when you do not need that level of control.
return err
}
Transactions
Base Type Mapping
Transactions are started by calling Begin or BeginIso. The BeginIso variant
creates a transaction with a specified isolation level.
pgx maps between all common base types directly between Go and PostgreSQL. In
particular:
tx, err := conn.Begin()
if err != nil {
return err
}
// Rollback is safe to call even if the tx is already closed, so if
// the tx commits successfully, this is a no-op
defer tx.Rollback()
Go PostgreSQL
-----------------------
string varchar
text
_, err = tx.Exec("insert into foo(id) values (1)")
if err != nil {
return err
}
// Integers are automatically be converted to any other integer type if
// it can be done without overflow or underflow.
int8
int16 smallint
int32 int
int64 bigint
int
uint8
uint16
uint32
uint64
uint
err = tx.Commit()
if err != nil {
return err
}
// Floats are strict and do not automatically convert like integers.
float32 float4
float64 float8
Listen and Notify
time.Time date
timestamp
timestamptz
pgx can listen to the PostgreSQL notification system with the
WaitForNotification function. It takes a maximum time to wait for a
notification.
[]byte bytea
err := conn.Listen("channelname")
if err != nil {
return nil
}
if notification, err := conn.WaitForNotification(time.Second); err != nil {
// do something with notification
}
Null Mapping
@ -136,7 +132,7 @@ Array Mapping
pgx maps between int16, int32, int64, float32, float64, and string Go slices
and the equivalent PostgreSQL array type. Go slices of native types do not
support nulls, so if a PostgreSQL array that contains a slice is read into a
support nulls, so if a PostgreSQL array that contains a null is read into a
native Go slice an error will occur.
Hstore Mapping
@ -192,6 +188,64 @@ the raw bytes returned by PostgreSQL. This can be especially useful for reading
varchar, text, json, and jsonb values directly into a []byte and avoiding the
type conversion from string.
Transactions
Transactions are started by calling Begin or BeginIso. The BeginIso variant
creates a transaction with a specified isolation level.
tx, err := conn.Begin()
if err != nil {
return err
}
// Rollback is safe to call even if the tx is already closed, so if
// the tx commits successfully, this is a no-op
defer tx.Rollback()
_, err = tx.Exec("insert into foo(id) values (1)")
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
Copy Protocol
Use CopyTo to efficiently insert multiple rows at a time using the PostgreSQL
copy protocol. CopyTo accepts a CopyToSource interface. If the data is already
in a [][]interface{} use CopyToRows to wrap it in a CopyToSource interface. Or
implement CopyToSource to avoid buffering the entire data set in memory.
rows := [][]interface{}{
{"John", "Smith", int32(36)},
{"Jane", "Doe", int32(29)},
}
copyCount, err := conn.CopyTo(
"people",
[]string{"first_name", "last_name", "age"},
pgx.CopyToRows(rows),
)
CopyTo can be faster than an insert with as few as 5 rows.
Listen and Notify
pgx can listen to the PostgreSQL notification system with the
WaitForNotification function. It takes a maximum time to wait for a
notification.
err := conn.Listen("channelname")
if err != nil {
return nil
}
if notification, err := conn.WaitForNotification(time.Second); err != nil {
// do something with notification
}
TLS
The pgx ConnConfig struct has a TLSConfig field. If this field is

View File

@ -4,8 +4,6 @@ import (
"encoding/binary"
)
type fastpathArg []byte
func newFastpath(cn *Conn) *fastpath {
return &fastpath{cn: cn, fns: make(map[string]OID)}
}

View File

@ -15,7 +15,6 @@ const (
hsVal
hsNul
hsNext
hsEnd
)
type hstoreParser struct {

View File

@ -85,23 +85,23 @@ func TestNullHstoreTranscode(t *testing.T) {
{pgx.NullHstore{}, "null"},
{pgx.NullHstore{Valid: true}, "empty"},
{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar", Valid: true}},
Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}},
Valid: true},
"single key/value"},
{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar", Valid: true}, "baz": pgx.NullString{String: "quz", Valid: true}},
Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}, "baz": {String: "quz", Valid: true}},
Valid: true},
"multiple key/values"},
{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"NULL": pgx.NullString{String: "bar", Valid: true}},
Hstore: map[string]pgx.NullString{"NULL": {String: "bar", Valid: true}},
Valid: true},
`string "NULL" key`},
{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "NULL", Valid: true}},
Hstore: map[string]pgx.NullString{"foo": {String: "NULL", Valid: true}},
Valid: true},
`string "NULL" value`},
{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "", Valid: false}},
Hstore: map[string]pgx.NullString{"foo": {String: "", Valid: false}},
Valid: true},
`NULL value`},
}
@ -120,36 +120,36 @@ func TestNullHstoreTranscode(t *testing.T) {
}
for _, sst := range specialStringTests {
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{sst.input + "foo": pgx.NullString{String: "bar", Valid: true}},
Hstore: map[string]pgx.NullString{sst.input + "foo": {String: "bar", Valid: true}},
Valid: true},
"key with " + sst.description + " at beginning"})
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": pgx.NullString{String: "bar", Valid: true}},
Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": {String: "bar", Valid: true}},
Valid: true},
"key with " + sst.description + " in middle"})
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo" + sst.input: pgx.NullString{String: "bar", Valid: true}},
Hstore: map[string]pgx.NullString{"foo" + sst.input: {String: "bar", Valid: true}},
Valid: true},
"key with " + sst.description + " at end"})
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{sst.input: pgx.NullString{String: "bar", Valid: true}},
Hstore: map[string]pgx.NullString{sst.input: {String: "bar", Valid: true}},
Valid: true},
"key is " + sst.description})
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: sst.input + "bar", Valid: true}},
Hstore: map[string]pgx.NullString{"foo": {String: sst.input + "bar", Valid: true}},
Valid: true},
"value with " + sst.description + " at beginning"})
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar" + sst.input + "bar", Valid: true}},
Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input + "bar", Valid: true}},
Valid: true},
"value with " + sst.description + " in middle"})
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar" + sst.input, Valid: true}},
Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input, Valid: true}},
Valid: true},
"value with " + sst.description + " at end"})
tests = append(tests, test{pgx.NullHstore{
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: sst.input, Valid: true}},
Hstore: map[string]pgx.NullString{"foo": {String: sst.input, Valid: true}},
Valid: true},
"value is " + sst.description})
}

View File

@ -25,6 +25,10 @@ const (
noData = 'n'
closeComplete = '3'
flush = 'H'
copyInResponse = 'G'
copyData = 'd'
copyFail = 'f'
copyDone = 'c'
)
type startupMessage struct {
@ -35,10 +39,10 @@ func newStartupMessage() *startupMessage {
return &startupMessage{map[string]string{}}
}
func (self *startupMessage) Bytes() (buf []byte) {
func (s *startupMessage) Bytes() (buf []byte) {
buf = make([]byte, 8, 128)
binary.BigEndian.PutUint32(buf[4:8], uint32(protocolVersionNumber))
for key, value := range self.options {
for key, value := range s.options {
buf = append(buf, key...)
buf = append(buf, 0)
buf = append(buf, value...)
@ -49,8 +53,6 @@ func (self *startupMessage) Bytes() (buf []byte) {
return buf
}
type OID int32
type FieldDescription struct {
Name string
Table OID
@ -85,8 +87,8 @@ type PgError struct {
Routine string
}
func (self PgError) Error() string {
return self.Severity + ": " + self.Message + " (SQLSTATE " + self.Code + ")"
func (pe PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
func newWriteBuf(c *Conn, t byte) *WriteBuf {
@ -95,7 +97,7 @@ func newWriteBuf(c *Conn, t byte) *WriteBuf {
return &c.writeBuf
}
// WrifeBuf is used build messages to send to the PostgreSQL server. It is used
// WriteBuf is used build messages to send to the PostgreSQL server. It is used
// by the Encoder interface when implementing custom encoders.
type WriteBuf struct {
buf []byte
@ -128,12 +130,24 @@ func (wb *WriteBuf) WriteInt16(n int16) {
wb.buf = append(wb.buf, b...)
}
// WriteUint16 appends n to the buffer in big-endian (network) byte order,
// the wire order used by the PostgreSQL frontend/backend protocol.
func (wb *WriteBuf) WriteUint16(n uint16) {
	b := make([]byte, 2)
	binary.BigEndian.PutUint16(b, n)
	wb.buf = append(wb.buf, b...)
}
func (wb *WriteBuf) WriteInt32(n int32) {
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(n))
wb.buf = append(wb.buf, b...)
}
// WriteUint32 appends n to the buffer in big-endian (network) byte order,
// the wire order used by the PostgreSQL frontend/backend protocol.
func (wb *WriteBuf) WriteUint32(n uint32) {
	b := make([]byte, 4)
	binary.BigEndian.PutUint32(b, n)
	wb.buf = append(wb.buf, b...)
}
func (wb *WriteBuf) WriteInt64(n int64) {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, uint64(n))

View File

@ -10,7 +10,6 @@ import (
// msgReader is a helper that reads values from a PostgreSQL message.
type msgReader struct {
reader *bufio.Reader
buf [128]byte
msgBytesRemaining int32
err error
log func(lvl int, msg string, ctx ...interface{})
@ -47,10 +46,15 @@ func (r *msgReader) rxMsg() (byte, error) {
}
}
b := r.buf[0:5]
_, err := io.ReadFull(r.reader, b)
b, err := r.reader.Peek(5)
if err != nil {
r.fatal(err)
return 0, err
}
msgType := b[0]
r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4
return b[0], err
r.reader.Discard(5)
return msgType, nil
}
func (r *msgReader) readByte() byte {
@ -58,7 +62,7 @@ func (r *msgReader) readByte() byte {
return 0
}
r.msgBytesRemaining -= 1
r.msgBytesRemaining--
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message"))
return 0
@ -88,8 +92,7 @@ func (r *msgReader) readInt16() int16 {
return 0
}
b := r.buf[0:2]
_, err := io.ReadFull(r.reader, b)
b, err := r.reader.Peek(2)
if err != nil {
r.fatal(err)
return 0
@ -97,6 +100,8 @@ func (r *msgReader) readInt16() int16 {
n := int16(binary.BigEndian.Uint16(b))
r.reader.Discard(2)
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
}
@ -115,8 +120,7 @@ func (r *msgReader) readInt32() int32 {
return 0
}
b := r.buf[0:4]
_, err := io.ReadFull(r.reader, b)
b, err := r.reader.Peek(4)
if err != nil {
r.fatal(err)
return 0
@ -124,6 +128,8 @@ func (r *msgReader) readInt32() int32 {
n := int32(binary.BigEndian.Uint32(b))
r.reader.Discard(4)
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
}
@ -131,6 +137,62 @@ func (r *msgReader) readInt32() int32 {
return n
}
// readUint16 reads a big-endian uint16 from the current message body. On a
// prior error, on reading past the end of the message, or on a failed read
// it records the error (via r.err / r.fatal) and returns 0.
func (r *msgReader) readUint16() uint16 {
	if r.err != nil {
		return 0
	}

	// Deduct before reading so an overrun is detected up front.
	r.msgBytesRemaining -= 2
	if r.msgBytesRemaining < 0 {
		r.fatal(errors.New("read past end of message"))
		return 0
	}

	// Peek then Discard reads directly from the bufio buffer without an
	// intermediate copy.
	b, err := r.reader.Peek(2)
	if err != nil {
		r.fatal(err)
		return 0
	}
	n := uint16(binary.BigEndian.Uint16(b))
	r.reader.Discard(2)

	if r.shouldLog(LogLevelTrace) {
		r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
	}

	return n
}
// readUint32 reads a big-endian uint32 from the current message body. On a
// prior error, on reading past the end of the message, or on a failed read
// it records the error (via r.err / r.fatal) and returns 0.
func (r *msgReader) readUint32() uint32 {
	if r.err != nil {
		return 0
	}

	// Deduct before reading so an overrun is detected up front.
	r.msgBytesRemaining -= 4
	if r.msgBytesRemaining < 0 {
		r.fatal(errors.New("read past end of message"))
		return 0
	}

	// Peek then Discard reads directly from the bufio buffer without an
	// intermediate copy.
	b, err := r.reader.Peek(4)
	if err != nil {
		r.fatal(err)
		return 0
	}
	n := uint32(binary.BigEndian.Uint32(b))
	r.reader.Discard(4)

	if r.shouldLog(LogLevelTrace) {
		r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
	}

	return n
}
func (r *msgReader) readInt64() int64 {
if r.err != nil {
return 0
@ -142,8 +204,7 @@ func (r *msgReader) readInt64() int64 {
return 0
}
b := r.buf[0:8]
_, err := io.ReadFull(r.reader, b)
b, err := r.reader.Peek(8)
if err != nil {
r.fatal(err)
return 0
@ -151,6 +212,8 @@ func (r *msgReader) readInt64() int64 {
n := int64(binary.BigEndian.Uint64(b))
r.reader.Discard(8)
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
}
@ -190,32 +253,34 @@ func (r *msgReader) readCString() string {
}
// readString reads count bytes and returns as string
func (r *msgReader) readString(count int32) string {
func (r *msgReader) readString(countI32 int32) string {
if r.err != nil {
return ""
}
r.msgBytesRemaining -= count
r.msgBytesRemaining -= countI32
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message"))
return ""
}
var b []byte
if count <= int32(len(r.buf)) {
b = r.buf[0:int(count)]
count := int(countI32)
var s string
if r.reader.Buffered() >= count {
buf, _ := r.reader.Peek(count)
s = string(buf)
r.reader.Discard(count)
} else {
b = make([]byte, int(count))
buf := make([]byte, count)
_, err := io.ReadFull(r.reader, buf)
if err != nil {
r.fatal(err)
return ""
}
s = string(buf)
}
_, err := io.ReadFull(r.reader, b)
if err != nil {
r.fatal(err)
return ""
}
s := string(b)
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
}

View File

@ -298,13 +298,17 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
if err != nil {
rows.Fatal(scanArgError{col: i, err: err})
}
} else if vr.Type().DataType == JSONOID || vr.Type().DataType == JSONBOID {
} else if vr.Type().DataType == JSONOID {
// Because the argument passed to decodeJSON will escape the heap.
// This allows d to be stack allocated and only copied to the heap when
// we actually are decoding JSON. This saves one memory allocation per
// row.
d2 := d
decodeJSON(vr, &d2)
} else if vr.Type().DataType == JSONBOID {
// Same trick as above for getting stack allocation
d2 := d
decodeJSONB(vr, &d2)
} else {
if err := Decode(vr, d); err != nil {
rows.Fatal(scanArgError{col: i, err: err})
@ -393,7 +397,7 @@ func (rows *Rows) Values() ([]interface{}, error) {
values = append(values, d)
case JSONBOID:
var d interface{}
decodeJSON(vr, &d)
decodeJSONB(vr, &d)
values = append(values, d)
default:
rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types"))

View File

@ -3,11 +3,12 @@ package pgx_test
import (
"bytes"
"database/sql"
"github.com/jackc/pgx"
"strings"
"testing"
"time"
"github.com/jackc/pgx"
"github.com/shopspring/decimal"
)
@ -784,7 +785,7 @@ func TestQueryRowCoreByteSlice(t *testing.T) {
t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
}
if bytes.Compare(actual, tt.expected) != 0 {
if !bytes.Equal(actual, tt.expected) {
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql)
}
@ -1281,7 +1282,7 @@ func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) {
}
var num decimal.Decimal
err = conn.QueryRow("select $1::decimal", expected).Scan(&num)
err = conn.QueryRow("select $1::decimal", &expected).Scan(&num)
if err != nil {
t.Fatalf("Scan failed: %v", err)
}

4
sql.go
View File

@ -14,7 +14,7 @@ func init() {
placeholders = make([]string, 64)
for i := 1; i < 64; i++ {
placeholders[i] = "$" + strconv.FormatInt(int64(i), 10)
placeholders[i] = "$" + strconv.Itoa(i)
}
}
@ -25,5 +25,5 @@ func (qa *QueryArgs) Append(v interface{}) string {
if len(*qa) < len(placeholders) {
return placeholders[len(*qa)]
}
return "$" + strconv.FormatInt(int64(len(*qa)), 10)
return "$" + strconv.Itoa(len(*qa))
}

View File

@ -1,16 +1,17 @@
package pgx_test
import (
"github.com/jackc/pgx"
"strconv"
"testing"
"github.com/jackc/pgx"
)
func TestQueryArgs(t *testing.T) {
var qa pgx.QueryArgs
for i := 1; i < 512; i++ {
expectedPlaceholder := "$" + strconv.FormatInt(int64(i), 10)
expectedPlaceholder := "$" + strconv.Itoa(i)
placeholder := qa.Append(i)
if placeholder != expectedPlaceholder {
t.Errorf(`Expected qa.Append to return "%s", but it returned "%s"`, expectedPlaceholder, placeholder)

9
tx.go
View File

@ -158,6 +158,15 @@ func (tx *Tx) QueryRow(sql string, args ...interface{}) *Row {
return (*Row)(rows)
}
// CopyTo delegates to the underlying *Conn, running the bulk copy on this
// transaction's connection. It returns ErrTxClosed if the transaction has
// already been committed or rolled back.
func (tx *Tx) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) {
	if tx.status != TxStatusInProgress {
		return 0, ErrTxClosed
	}

	return tx.conn.CopyTo(tableName, columnNames, rowSrc)
}
// Conn returns the *Conn this transaction is using.
func (tx *Tx) Conn() *Conn {
return tx.conn

View File

@ -60,6 +60,20 @@ func (r *ValueReader) ReadInt16() int16 {
return r.mr.readInt16()
}
// ReadUint16 reads a big-endian uint16 from the current value. If the reader
// is already in an error state, or the read would pass the end of the value,
// it records the error via Fatal and returns 0.
func (r *ValueReader) ReadUint16() uint16 {
	if r.err != nil {
		return 0
	}

	r.valueBytesRemaining -= 2
	if r.valueBytesRemaining < 0 {
		r.Fatal(errors.New("read past end of value"))
		return 0
	}

	return r.mr.readUint16()
}
func (r *ValueReader) ReadInt32() int32 {
if r.err != nil {
return 0
@ -74,6 +88,20 @@ func (r *ValueReader) ReadInt32() int32 {
return r.mr.readInt32()
}
// ReadUint32 reads a big-endian uint32 from the current value. If the reader
// is already in an error state, or the read would pass the end of the value,
// it records the error via Fatal and returns 0.
func (r *ValueReader) ReadUint32() uint32 {
	if r.err != nil {
		return 0
	}

	r.valueBytesRemaining -= 4
	if r.valueBytesRemaining < 0 {
		r.Fatal(errors.New("read past end of value"))
		return 0
	}

	return r.mr.readUint32()
}
func (r *ValueReader) ReadInt64() int64 {
if r.err != nil {
return 0
@ -89,7 +117,7 @@ func (r *ValueReader) ReadInt64() int64 {
}
func (r *ValueReader) ReadOID() OID {
return OID(r.ReadInt32())
return OID(r.ReadUint32())
}
// ReadString reads count bytes and returns as string

885
values.go

File diff suppressed because it is too large Load Diff

View File

@ -88,67 +88,74 @@ func TestJSONAndJSONBTranscode(t *testing.T) {
if _, ok := conn.PgTypes[oid]; !ok {
return // No JSON/JSONB type -- must be running against old PostgreSQL
}
typename := conn.PgTypes[oid].Name
testJSONString(t, conn, typename)
testJSONStringPointer(t, conn, typename)
testJSONSingleLevelStringMap(t, conn, typename)
testJSONNestedMap(t, conn, typename)
testJSONStringArray(t, conn, typename)
testJSONInt64Array(t, conn, typename)
testJSONInt16ArrayFailureDueToOverflow(t, conn, typename)
testJSONStruct(t, conn, typename)
for _, format := range []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} {
pgtype := conn.PgTypes[oid]
pgtype.DefaultFormat = format
conn.PgTypes[oid] = pgtype
typename := conn.PgTypes[oid].Name
testJSONString(t, conn, typename, format)
testJSONStringPointer(t, conn, typename, format)
testJSONSingleLevelStringMap(t, conn, typename, format)
testJSONNestedMap(t, conn, typename, format)
testJSONStringArray(t, conn, typename, format)
testJSONInt64Array(t, conn, typename, format)
testJSONInt16ArrayFailureDueToOverflow(t, conn, typename, format)
testJSONStruct(t, conn, typename, format)
}
}
}
func testJSONString(t *testing.T, conn *pgx.Conn, typename string) {
func testJSONString(t *testing.T, conn *pgx.Conn, typename string, format int16) {
input := `{"key": "value"}`
expectedOutput := map[string]string{"key": "value"}
var output map[string]string
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
if err != nil {
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
return
}
if !reflect.DeepEqual(expectedOutput, output) {
t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output)
t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output)
return
}
}
func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string) {
func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string, format int16) {
input := `{"key": "value"}`
expectedOutput := map[string]string{"key": "value"}
var output map[string]string
err := conn.QueryRow("select $1::"+typename, &input).Scan(&output)
if err != nil {
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
return
}
if !reflect.DeepEqual(expectedOutput, output) {
t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output)
t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output)
return
}
}
func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) {
func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string, format int16) {
input := map[string]string{"key": "value"}
var output map[string]string
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
if err != nil {
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
return
}
if !reflect.DeepEqual(input, output) {
t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, input, output)
t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, input, output)
return
}
}
func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) {
func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int16) {
input := map[string]interface{}{
"name": "Uncanny",
"stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)},
@ -157,52 +164,52 @@ func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) {
var output map[string]interface{}
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
if err != nil {
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
return
}
if !reflect.DeepEqual(input, output) {
t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output)
t.Errorf("%s %d: Did not transcode map[string]interface{} successfully: %v is not %v", typename, format, input, output)
return
}
}
func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string) {
func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string, format int16) {
input := []string{"foo", "bar", "baz"}
var output []string
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
if err != nil {
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
}
if !reflect.DeepEqual(input, output) {
t.Errorf("%s: Did not transcode []string successfully: %v is not %v", typename, input, output)
t.Errorf("%s %d: Did not transcode []string successfully: %v is not %v", typename, format, input, output)
}
}
func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string) {
func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string, format int16) {
input := []int64{1, 2, 234432}
var output []int64
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
if err != nil {
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
}
if !reflect.DeepEqual(input, output) {
t.Errorf("%s: Did not transcode []int64 successfully: %v is not %v", typename, input, output)
t.Errorf("%s %d: Did not transcode []int64 successfully: %v is not %v", typename, format, input, output)
}
}
func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) {
func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string, format int16) {
input := []int{1, 2, 234432}
var output []int16
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" {
t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err)
t.Errorf("%s %d: Expected *json.UnmarkalTypeError, but got %v", typename, format, err)
}
}
func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) {
func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) {
type person struct {
Name string `json:"name"`
Age int `json:"age"`
@ -217,11 +224,11 @@ func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) {
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
if err != nil {
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
}
if !reflect.DeepEqual(input, output) {
t.Errorf("%s: Did not transcode struct successfully: %v is not %v", typename, input, output)
t.Errorf("%s %d: Did not transcode struct successfully: %v is not %v", typename, format, input, output)
}
}
@ -561,6 +568,13 @@ func TestNullX(t *testing.T) {
s pgx.NullString
i16 pgx.NullInt16
i32 pgx.NullInt32
c pgx.NullChar
a pgx.NullAclItem
n pgx.NullName
oid pgx.NullOID
xid pgx.NullXid
cid pgx.NullCid
tid pgx.NullTid
i64 pgx.NullInt64
f32 pgx.NullFloat32
f64 pgx.NullFloat64
@ -582,6 +596,27 @@ func TestNullX(t *testing.T) {
{"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}},
{"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}},
{"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}},
{"select $1::oid", []interface{}{pgx.NullOID{OID: 1, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 1, Valid: true}}},
{"select $1::oid", []interface{}{pgx.NullOID{OID: 1, Valid: false}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 0, Valid: false}}},
{"select $1::oid", []interface{}{pgx.NullOID{OID: 4294967295, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 4294967295, Valid: true}}},
{"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 1, Valid: true}}},
{"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: false}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 0, Valid: false}}},
{"select $1::xid", []interface{}{pgx.NullXid{Xid: 4294967295, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 4294967295, Valid: true}}},
{"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}},
{"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}},
{"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}},
{"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: true}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "foo", Valid: true}}},
{"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: false}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "", Valid: false}}},
{"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}},
{"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}},
// A tricky (and valid) aclitem can still be used, especially with Go's useful backticks
{"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}},
{"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}},
{"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}},
{"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}},
{"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}},
{"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: false}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 0, OffsetNumber: 0}, Valid: false}}},
{"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}},
{"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}},
{"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}},
{"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}},
@ -615,6 +650,52 @@ func TestNullX(t *testing.T) {
}
}
func assertAclItemSlicesEqual(t *testing.T, query, scan []pgx.AclItem) {
if !reflect.DeepEqual(query, scan) {
t.Errorf("failed to encode aclitem[]\n EXPECTED: %d %v\n ACTUAL: %d %v", len(query), query, len(scan), scan)
}
}
func TestAclArrayDecoding(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
sql := "select $1::aclitem[]"
var scan []pgx.AclItem
tests := []struct {
query []pgx.AclItem
}{
{
[]pgx.AclItem{},
},
{
[]pgx.AclItem{"=r/postgres"},
},
{
[]pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"},
},
{
[]pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/" tricky, ' } "" \ test user "`},
},
}
for i, tt := range tests {
err := conn.QueryRow(sql, tt.query).Scan(&scan)
if err != nil {
// t.Errorf(`%d. error reading array: %v`, i, err)
t.Errorf(`%d. error reading array: %v query: %s`, i, err, tt.query)
if pgerr, ok := err.(pgx.PgError); ok {
t.Errorf(`%d. error reading array (detail): %s`, i, pgerr.Detail)
}
continue
}
assertAclItemSlicesEqual(t, tt.query, scan)
ensureConnValid(t, conn)
}
}
func TestArrayDecoding(t *testing.T) {
t.Parallel()
@ -630,7 +711,7 @@ func TestArrayDecoding(t *testing.T) {
{
"select $1::bool[]", []bool{true, false, true}, &[]bool{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]bool))) == false {
if !reflect.DeepEqual(query, *(scan.(*[]bool))) {
t.Errorf("failed to encode bool[]")
}
},
@ -638,7 +719,7 @@ func TestArrayDecoding(t *testing.T) {
{
"select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]int16))) == false {
if !reflect.DeepEqual(query, *(scan.(*[]int16))) {
t.Errorf("failed to encode smallint[]")
}
},
@ -646,7 +727,7 @@ func TestArrayDecoding(t *testing.T) {
{
"select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]uint16))) == false {
if !reflect.DeepEqual(query, *(scan.(*[]uint16))) {
t.Errorf("failed to encode smallint[]")
}
},
@ -654,7 +735,7 @@ func TestArrayDecoding(t *testing.T) {
{
"select $1::int[]", []int32{2, 4, 484}, &[]int32{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]int32))) == false {
if !reflect.DeepEqual(query, *(scan.(*[]int32))) {
t.Errorf("failed to encode int[]")
}
},
@ -662,7 +743,7 @@ func TestArrayDecoding(t *testing.T) {
{
"select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]uint32))) == false {
if !reflect.DeepEqual(query, *(scan.(*[]uint32))) {
t.Errorf("failed to encode int[]")
}
},
@ -670,7 +751,7 @@ func TestArrayDecoding(t *testing.T) {
{
"select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]int64))) == false {
if !reflect.DeepEqual(query, *(scan.(*[]int64))) {
t.Errorf("failed to encode bigint[]")
}
},
@ -678,7 +759,7 @@ func TestArrayDecoding(t *testing.T) {
{
"select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]uint64))) == false {
if !reflect.DeepEqual(query, *(scan.(*[]uint64))) {
t.Errorf("failed to encode bigint[]")
}
},
@ -686,7 +767,7 @@ func TestArrayDecoding(t *testing.T) {
{
"select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]string))) == false {
if !reflect.DeepEqual(query, *(scan.(*[]string))) {
t.Errorf("failed to encode text[]")
}
},
@ -694,7 +775,7 @@ func TestArrayDecoding(t *testing.T) {
{
"select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false {
if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) {
t.Errorf("failed to encode time.Time[] to timestamp[]")
}
},
@ -702,7 +783,7 @@ func TestArrayDecoding(t *testing.T) {
{
"select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false {
if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) {
t.Errorf("failed to encode time.Time[] to timestamptz[]")
}
},
@ -718,7 +799,7 @@ func TestArrayDecoding(t *testing.T) {
for i := range queryBytesSliceSlice {
qb := queryBytesSliceSlice[i]
sb := scanBytesSliceSlice[i]
if bytes.Compare(qb, sb) != 0 {
if !bytes.Equal(qb, sb) {
t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb)
}
}
@ -1075,11 +1156,11 @@ func TestRowDecode(t *testing.T) {
expected []interface{}
}{
{
"select row(1, 'cat', '2015-01-01 08:12:42'::timestamptz)",
"select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)",
[]interface{}{
int32(1),
"cat",
time.Date(2015, 1, 1, 8, 12, 42, 0, time.Local),
time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(),
},
},
}