mirror of https://github.com/jackc/pgx.git
Merge branch 'master' into v3-experimental
commit
93e5c68f69
12
.travis.yml
12
.travis.yml
|
@ -1,8 +1,8 @@
|
|||
language: go
|
||||
|
||||
go:
|
||||
- 1.6.2
|
||||
- 1.5.2
|
||||
- 1.7.1
|
||||
- 1.6.3
|
||||
- tip
|
||||
|
||||
# Derived from https://github.com/lib/pq/blob/master/.travis.yml
|
||||
|
@ -29,6 +29,8 @@ env:
|
|||
- PGVERSION=9.2
|
||||
- PGVERSION=9.1
|
||||
|
||||
# The tricky test user, below, has to actually exist so that it can be used in a test
|
||||
# of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles.
|
||||
before_script:
|
||||
- mv conn_config_test.go.travis conn_config_test.go
|
||||
- psql -U postgres -c 'create database pgx_test'
|
||||
|
@ -36,6 +38,12 @@ before_script:
|
|||
- psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'"
|
||||
- psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'"
|
||||
- psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'"
|
||||
- psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'"
|
||||
|
||||
install:
|
||||
- go get -u github.com/shopspring/decimal
|
||||
- go get -u gopkg.in/inconshreveable/log15.v2
|
||||
- go get -u github.com/jackc/fake
|
||||
|
||||
script:
|
||||
- go test -v -race -short ./...
|
||||
|
|
23
CHANGELOG.md
23
CHANGELOG.md
|
@ -2,6 +2,26 @@
|
|||
|
||||
## Fixes
|
||||
|
||||
* Oid underlying type changed to uint32, previously it was incorrectly int32 (Manni Wood)
|
||||
|
||||
## Features
|
||||
|
||||
* Add xid type support (Manni Wood)
|
||||
* Add cid type support (Manni Wood)
|
||||
* Add tid type support (Manni Wood)
|
||||
* Add "char" type support (Manni Wood)
|
||||
* Add NullOid type (Manni Wood)
|
||||
* Add json/jsonb binary support to allow use with CopyTo
|
||||
* Add named error ErrAcquireTimeout (Alexander Staubo)
|
||||
|
||||
## Compatibility
|
||||
|
||||
* jsonb now defaults to binary format. This means passing a []byte to a jsonb column will no longer work.
|
||||
|
||||
# 2.9.0 (August 26, 2016)
|
||||
|
||||
## Fixes
|
||||
|
||||
* Fix *ConnPool.Deallocate() not deleting prepared statement from map
|
||||
* Fix stdlib not logging unprepared query SQL (Krzysztof Dryś)
|
||||
* Fix Rows.Values() with varchar binary format
|
||||
|
@ -9,12 +29,13 @@
|
|||
|
||||
## Features
|
||||
|
||||
* Add CopyTo
|
||||
* Add PrepareEx
|
||||
* Add basic record to []interface{} decoding
|
||||
* Encode and decode between all Go and PostgreSQL integer types with bounds checking
|
||||
* Decode inet/cidr to net.IP
|
||||
* Encode/decode [][]byte to/from bytea[]
|
||||
* Encode/decode named types whoses underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64
|
||||
* Encode/decode named types whose underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64
|
||||
|
||||
## Performance
|
||||
|
||||
|
|
21
README.md
21
README.md
|
@ -19,6 +19,7 @@ Pgx supports many additional features beyond what is available through database/
|
|||
* Transaction isolation level control
|
||||
* Full TLS connection control
|
||||
* Binary format support for custom types (can be much faster)
|
||||
* Copy protocol support for faster bulk data loads
|
||||
* Logging support
|
||||
* Configurable connection pool with after connect hooks to do arbitrary connection setup
|
||||
* PostgreSQL array to Go slice mapping for integers, floats, and strings
|
||||
|
@ -60,16 +61,23 @@ skip tests for connection types that are not configured.
|
|||
|
||||
### Normal Test Environment
|
||||
|
||||
To setup the normal test environment run the following SQL:
|
||||
To setup the normal test environment, first install these dependencies:
|
||||
|
||||
go get github.com/jackc/fake
|
||||
go get github.com/shopspring/decimal
|
||||
go get gopkg.in/inconshreveable/log15.v2
|
||||
|
||||
Then run the following SQL:
|
||||
|
||||
create user pgx_md5 password 'secret';
|
||||
create user " tricky, ' } "" \ test user " password 'secret';
|
||||
create database pgx_test;
|
||||
|
||||
Connect to database pgx_test and run:
|
||||
|
||||
create extension hstore;
|
||||
|
||||
Next open connection_settings_test.go.example and make a copy without the
|
||||
Next open conn_config_test.go.example and make a copy without the
|
||||
.example. If your PostgreSQL server is accepting connections on 127.0.0.1,
|
||||
then you are done.
|
||||
|
||||
|
@ -98,7 +106,8 @@ If you are developing on Windows with TCP connections:
|
|||
|
||||
## Version Policy
|
||||
|
||||
pgx follows semantic versioning for the documented public API. ```master```
|
||||
branch tracks the latest stable branch (```v2```). Consider using ```import
|
||||
"gopkg.in/jackc/pgx.v2"``` to lock to the ```v2``` branch or use a vendoring
|
||||
tool such as [godep](https://github.com/tools/godep).
|
||||
pgx follows semantic versioning for the documented public API on stable releases. Branch ```v2``` is the latest stable release. ```master``` can contain new features or behavior that will change or be removed before being merged to the stable ```v2``` branch (in practice, this occurs very rarely).
|
||||
|
||||
Consider using a vendoring
|
||||
tool such as [godep](https://github.com/tools/godep) or importing pgx via ```import
|
||||
"gopkg.in/jackc/pgx.v2"``` to lock to the ```v2``` branch.
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEscapeAclItem(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"foo",
|
||||
"foo",
|
||||
},
|
||||
{
|
||||
`foo, "\}`,
|
||||
`foo\, \"\\\}`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
actual, err := escapeAclItem(tt.input)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected error %v", i, err)
|
||||
}
|
||||
|
||||
if actual != tt.expected {
|
||||
t.Errorf("%d.\nexpected: %s,\nactual: %s", i, tt.expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAclItemArray(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []AclItem
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
"",
|
||||
[]AclItem{},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"one",
|
||||
[]AclItem{"one"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one"`,
|
||||
[]AclItem{"one"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"one,two,three",
|
||||
[]AclItem{"one", "two", "three"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one","two","three"`,
|
||||
[]AclItem{"one", "two", "three"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one",two,"three"`,
|
||||
[]AclItem{"one", "two", "three"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`one,two,"three"`,
|
||||
[]AclItem{"one", "two", "three"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one","two",three`,
|
||||
[]AclItem{"one", "two", "three"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one","t w o",three`,
|
||||
[]AclItem{"one", "t w o", "three"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one","t, w o\"\}\\",three`,
|
||||
[]AclItem{"one", `t, w o"}\`, "three"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one","two",three"`,
|
||||
[]AclItem{"one", "two", `three"`},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one","two,"three"`,
|
||||
nil,
|
||||
"unexpected rune after quoted value",
|
||||
},
|
||||
{
|
||||
`"one","two","three`,
|
||||
nil,
|
||||
"unexpected end of quoted value",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
actual, err := parseAclItemArray(tt.input)
|
||||
|
||||
if err != nil {
|
||||
if tt.errMsg == "" {
|
||||
t.Errorf("%d. Unexpected error %v", i, err)
|
||||
} else if err.Error() != tt.errMsg {
|
||||
t.Errorf("%d. Expected error %v did not match actual error %v", i, tt.errMsg, err.Error())
|
||||
}
|
||||
} else if tt.errMsg != "" {
|
||||
t.Errorf("%d. Expected error not returned: \"%v\"", i, tt.errMsg)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, tt.expected) {
|
||||
t.Errorf("%d. Expected %v did not match actual %v", i, tt.expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
331
bench_test.go
331
bench_test.go
|
@ -1,6 +1,9 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -397,3 +400,331 @@ func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
const benchmarkWriteTableCreateSQL = `drop table if exists t;
|
||||
|
||||
create table t(
|
||||
varchar_1 varchar not null,
|
||||
varchar_2 varchar not null,
|
||||
varchar_null_1 varchar,
|
||||
date_1 date not null,
|
||||
date_null_1 date,
|
||||
int4_1 int4 not null,
|
||||
int4_2 int4 not null,
|
||||
int4_null_1 int4,
|
||||
tstz_1 timestamptz not null,
|
||||
tstz_2 timestamptz,
|
||||
bool_1 bool not null,
|
||||
bool_2 bool not null,
|
||||
bool_3 bool not null
|
||||
);
|
||||
`
|
||||
|
||||
const benchmarkWriteTableInsertSQL = `insert into t(
|
||||
varchar_1,
|
||||
varchar_2,
|
||||
varchar_null_1,
|
||||
date_1,
|
||||
date_null_1,
|
||||
int4_1,
|
||||
int4_2,
|
||||
int4_null_1,
|
||||
tstz_1,
|
||||
tstz_2,
|
||||
bool_1,
|
||||
bool_2,
|
||||
bool_3
|
||||
) values (
|
||||
$1::varchar,
|
||||
$2::varchar,
|
||||
$3::varchar,
|
||||
$4::date,
|
||||
$5::date,
|
||||
$6::int4,
|
||||
$7::int4,
|
||||
$8::int4,
|
||||
$9::timestamptz,
|
||||
$10::timestamptz,
|
||||
$11::bool,
|
||||
$12::bool,
|
||||
$13::bool
|
||||
)`
|
||||
|
||||
type benchmarkWriteTableCopyToSrc struct {
|
||||
count int
|
||||
idx int
|
||||
row []interface{}
|
||||
}
|
||||
|
||||
func (s *benchmarkWriteTableCopyToSrc) Next() bool {
|
||||
s.idx++
|
||||
return s.idx < s.count
|
||||
}
|
||||
|
||||
func (s *benchmarkWriteTableCopyToSrc) Values() ([]interface{}, error) {
|
||||
return s.row, nil
|
||||
}
|
||||
|
||||
func (s *benchmarkWriteTableCopyToSrc) Err() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newBenchmarkWriteTableCopyToSrc(count int) pgx.CopyToSource {
|
||||
return &benchmarkWriteTableCopyToSrc{
|
||||
count: count,
|
||||
row: []interface{}{
|
||||
"varchar_1",
|
||||
"varchar_2",
|
||||
pgx.NullString{},
|
||||
time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local),
|
||||
pgx.NullTime{},
|
||||
1,
|
||||
2,
|
||||
pgx.NullInt32{},
|
||||
time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local),
|
||||
time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local),
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkWriteNRowsViaInsert(b *testing.B, n int) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
mustExec(b, conn, benchmarkWriteTableCreateSQL)
|
||||
_, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
src := newBenchmarkWriteTableCopyToSrc(n)
|
||||
|
||||
tx, err := conn.Begin()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for src.Next() {
|
||||
values, _ := src.Values()
|
||||
if _, err = tx.Exec("insert_t", values...); err != nil {
|
||||
b.Fatalf("Exec unexpectedly failed with: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// note this function is only used for benchmarks -- it doesn't escape tableName
|
||||
// or columnNames
|
||||
func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc pgx.CopyToSource) (int, error) {
|
||||
maxRowsPerInsert := 65535 / len(columnNames)
|
||||
rowsThisInsert := 0
|
||||
rowCount := 0
|
||||
|
||||
sqlBuf := &bytes.Buffer{}
|
||||
args := make(pgx.QueryArgs, 0)
|
||||
|
||||
resetQuery := func() {
|
||||
sqlBuf.Reset()
|
||||
fmt.Fprintf(sqlBuf, "insert into %s(%s) values", tableName, strings.Join(columnNames, ", "))
|
||||
|
||||
args = args[0:0]
|
||||
|
||||
rowsThisInsert = 0
|
||||
}
|
||||
resetQuery()
|
||||
|
||||
tx, err := conn.Begin()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for rowSrc.Next() {
|
||||
if rowsThisInsert > 0 {
|
||||
sqlBuf.WriteByte(',')
|
||||
}
|
||||
|
||||
sqlBuf.WriteByte('(')
|
||||
|
||||
values, err := rowSrc.Values()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for i, val := range values {
|
||||
if i > 0 {
|
||||
sqlBuf.WriteByte(',')
|
||||
}
|
||||
sqlBuf.WriteString(args.Append(val))
|
||||
}
|
||||
|
||||
sqlBuf.WriteByte(')')
|
||||
|
||||
rowsThisInsert++
|
||||
|
||||
if rowsThisInsert == maxRowsPerInsert {
|
||||
_, err := tx.Exec(sqlBuf.String(), args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
rowCount += rowsThisInsert
|
||||
resetQuery()
|
||||
}
|
||||
}
|
||||
|
||||
if rowsThisInsert > 0 {
|
||||
_, err := tx.Exec(sqlBuf.String(), args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
rowCount += rowsThisInsert
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return rowCount, nil
|
||||
|
||||
}
|
||||
|
||||
func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
mustExec(b, conn, benchmarkWriteTableCreateSQL)
|
||||
_, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
src := newBenchmarkWriteTableCopyToSrc(n)
|
||||
|
||||
_, err := multiInsert(conn, "t",
|
||||
[]string{"varchar_1",
|
||||
"varchar_2",
|
||||
"varchar_null_1",
|
||||
"date_1",
|
||||
"date_null_1",
|
||||
"int4_1",
|
||||
"int4_2",
|
||||
"int4_null_1",
|
||||
"tstz_1",
|
||||
"tstz_2",
|
||||
"bool_1",
|
||||
"bool_2",
|
||||
"bool_3"},
|
||||
src)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkWriteNRowsViaCopy(b *testing.B, n int) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
mustExec(b, conn, benchmarkWriteTableCreateSQL)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
src := newBenchmarkWriteTableCopyToSrc(n)
|
||||
|
||||
_, err := conn.CopyTo("t",
|
||||
[]string{"varchar_1",
|
||||
"varchar_2",
|
||||
"varchar_null_1",
|
||||
"date_1",
|
||||
"date_null_1",
|
||||
"int4_1",
|
||||
"int4_2",
|
||||
"int4_null_1",
|
||||
"tstz_1",
|
||||
"tstz_2",
|
||||
"bool_1",
|
||||
"bool_2",
|
||||
"bool_3"},
|
||||
src)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWrite5RowsViaInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaInsert(b, 5)
|
||||
}
|
||||
|
||||
func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaMultiInsert(b, 5)
|
||||
}
|
||||
|
||||
func BenchmarkWrite5RowsViaCopy(b *testing.B) {
|
||||
benchmarkWriteNRowsViaCopy(b, 5)
|
||||
}
|
||||
|
||||
func BenchmarkWrite10RowsViaInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaInsert(b, 10)
|
||||
}
|
||||
|
||||
func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaMultiInsert(b, 10)
|
||||
}
|
||||
|
||||
func BenchmarkWrite10RowsViaCopy(b *testing.B) {
|
||||
benchmarkWriteNRowsViaCopy(b, 10)
|
||||
}
|
||||
|
||||
func BenchmarkWrite100RowsViaInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaInsert(b, 100)
|
||||
}
|
||||
|
||||
func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaMultiInsert(b, 100)
|
||||
}
|
||||
|
||||
func BenchmarkWrite100RowsViaCopy(b *testing.B) {
|
||||
benchmarkWriteNRowsViaCopy(b, 100)
|
||||
}
|
||||
|
||||
func BenchmarkWrite1000RowsViaInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaInsert(b, 1000)
|
||||
}
|
||||
|
||||
func BenchmarkWrite1000RowsViaMultiInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaMultiInsert(b, 1000)
|
||||
}
|
||||
|
||||
func BenchmarkWrite1000RowsViaCopy(b *testing.B) {
|
||||
benchmarkWriteNRowsViaCopy(b, 1000)
|
||||
}
|
||||
|
||||
func BenchmarkWrite10000RowsViaInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaInsert(b, 10000)
|
||||
}
|
||||
|
||||
func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) {
|
||||
benchmarkWriteNRowsViaMultiInsert(b, 10000)
|
||||
}
|
||||
|
||||
func BenchmarkWrite10000RowsViaCopy(b *testing.B) {
|
||||
benchmarkWriteNRowsViaCopy(b, 10000)
|
||||
}
|
||||
|
|
61
conn.go
61
conn.go
|
@ -63,8 +63,8 @@ type Conn struct {
|
|||
logLevel int
|
||||
mr msgReader
|
||||
fp *fastpath
|
||||
pgsql_af_inet *byte
|
||||
pgsql_af_inet6 *byte
|
||||
pgsqlAfInet *byte
|
||||
pgsqlAfInet6 *byte
|
||||
busy bool
|
||||
poolResetCount int
|
||||
preallocatedRows []Rows
|
||||
|
@ -145,7 +145,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
return connect(config, nil, nil, nil)
|
||||
}
|
||||
|
||||
func connect(config ConnConfig, pgTypes map[OID]PgType, pgsql_af_inet *byte, pgsql_af_inet6 *byte) (c *Conn, err error) {
|
||||
func connect(config ConnConfig, pgTypes map[OID]PgType, pgsqlAfInet *byte, pgsqlAfInet6 *byte) (c *Conn, err error) {
|
||||
c = new(Conn)
|
||||
|
||||
c.config = config
|
||||
|
@ -157,13 +157,13 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsql_af_inet *byte, pgs
|
|||
}
|
||||
}
|
||||
|
||||
if pgsql_af_inet != nil {
|
||||
c.pgsql_af_inet = new(byte)
|
||||
*c.pgsql_af_inet = *pgsql_af_inet
|
||||
if pgsqlAfInet != nil {
|
||||
c.pgsqlAfInet = new(byte)
|
||||
*c.pgsqlAfInet = *pgsqlAfInet
|
||||
}
|
||||
if pgsql_af_inet6 != nil {
|
||||
c.pgsql_af_inet6 = new(byte)
|
||||
*c.pgsql_af_inet6 = *pgsql_af_inet6
|
||||
if pgsqlAfInet6 != nil {
|
||||
c.pgsqlAfInet6 = new(byte)
|
||||
*c.pgsqlAfInet6 = *pgsqlAfInet6
|
||||
}
|
||||
|
||||
if c.config.LogLevel != 0 {
|
||||
|
@ -209,12 +209,21 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsql_af_inet *byte, pgs
|
|||
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
|
||||
}
|
||||
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
|
||||
}
|
||||
err = c.connect(config, network, address, config.TLSConfig)
|
||||
if err != nil && config.UseFallbackTLS {
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, fmt.Sprintf("Connect with TLSConfig failed, trying FallbackTLSConfig: %v", err))
|
||||
}
|
||||
err = c.connect(config, network, address, config.FallbackTLSConfig)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, fmt.Sprintf("Connect failed: %v", err))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -222,23 +231,14 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsql_af_inet *byte, pgs
|
|||
}
|
||||
|
||||
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
|
||||
}
|
||||
c.conn, err = c.config.Dial(network, address)
|
||||
if err != nil {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, fmt.Sprintf("Connection failed: %v", err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if c != nil && err != nil {
|
||||
c.conn.Close()
|
||||
c.alive = false
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -253,9 +253,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
c.log(LogLevelDebug, "Starting TLS handshake")
|
||||
}
|
||||
if err := c.startTLS(tlsConfig); err != nil {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, fmt.Sprintf("TLS failed: %v", err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -315,7 +312,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
}
|
||||
}
|
||||
|
||||
if c.pgsql_af_inet == nil || c.pgsql_af_inet6 == nil {
|
||||
if c.pgsqlAfInet == nil || c.pgsqlAfInet6 == nil {
|
||||
err = c.loadInetConstants()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -372,8 +369,8 @@ func (c *Conn) loadInetConstants() error {
|
|||
return err
|
||||
}
|
||||
|
||||
c.pgsql_af_inet = &ipv4[0]
|
||||
c.pgsql_af_inet6 = &ipv6[0]
|
||||
c.pgsqlAfInet = &ipv4[0]
|
||||
c.pgsqlAfInet6 = &ipv6[0]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -430,7 +427,7 @@ func ParseURI(uri string) (ConnConfig, error) {
|
|||
}
|
||||
|
||||
ignoreKeys := map[string]struct{}{
|
||||
"sslmode": struct{}{},
|
||||
"sslmode": {},
|
||||
}
|
||||
|
||||
cp.RuntimeParams = make(map[string]string)
|
||||
|
@ -446,7 +443,7 @@ func ParseURI(uri string) (ConnConfig, error) {
|
|||
return cp, nil
|
||||
}
|
||||
|
||||
var dsn_regexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
|
||||
var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
|
||||
|
||||
// ParseDSN parses a database DSN (data source name) into a ConnConfig
|
||||
//
|
||||
|
@ -462,7 +459,7 @@ var dsn_regexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
|
|||
func ParseDSN(s string) (ConnConfig, error) {
|
||||
var cp ConnConfig
|
||||
|
||||
m := dsn_regexp.FindAllStringSubmatch(s, -1)
|
||||
m := dsnRegexp.FindAllStringSubmatch(s, -1)
|
||||
|
||||
var sslmode string
|
||||
|
||||
|
@ -477,11 +474,11 @@ func ParseDSN(s string) (ConnConfig, error) {
|
|||
case "host":
|
||||
cp.Host = b[2]
|
||||
case "port":
|
||||
if p, err := strconv.ParseUint(b[2], 10, 16); err != nil {
|
||||
p, err := strconv.ParseUint(b[2], 10, 16)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
} else {
|
||||
cp.Port = uint16(p)
|
||||
}
|
||||
cp.Port = uint16(p)
|
||||
case "dbname":
|
||||
cp.Database = b[2]
|
||||
case "sslmode":
|
||||
|
@ -627,7 +624,7 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||
|
||||
if opts != nil {
|
||||
if len(opts.ParameterOIDs) > 65535 {
|
||||
return nil, errors.New(fmt.Sprintf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs)))
|
||||
return nil, fmt.Errorf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs))
|
||||
}
|
||||
wbuf.WriteInt16(int16(len(opts.ParameterOIDs)))
|
||||
for _, oid := range opts.ParameterOIDs {
|
||||
|
@ -917,7 +914,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
wbuf.WriteInt16(TextFormatCode)
|
||||
default:
|
||||
switch oid {
|
||||
case BoolOID, ByteaOID, Int2OID, Int4OID, Int8OID, Float4OID, Float8OID, TimestampTzOID, TimestampTzArrayOID, TimestampOID, TimestampArrayOID, DateOID, BoolArrayOID, ByteaArrayOID, Int2ArrayOID, Int4ArrayOID, Int8ArrayOID, Float4ArrayOID, Float8ArrayOID, TextArrayOID, VarcharArrayOID, OIDOID, InetOID, CidrOID, InetArrayOID, CidrArrayOID, RecordOID:
|
||||
case BoolOID, ByteaOID, Int2OID, Int4OID, Int8OID, Float4OID, Float8OID, TimestampTzOID, TimestampTzArrayOID, TimestampOID, TimestampArrayOID, DateOID, BoolArrayOID, ByteaArrayOID, Int2ArrayOID, Int4ArrayOID, Int8ArrayOID, Float4ArrayOID, Float8ArrayOID, TextArrayOID, VarcharArrayOID, OIDOID, InetOID, CidrOID, InetArrayOID, CidrArrayOID, RecordOID, JSONOID, JSONBOID:
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
default:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
|
|
|
@ -11,7 +11,6 @@ var tcpConnConfig *pgx.ConnConfig = nil
|
|||
var unixSocketConnConfig *pgx.ConnConfig = nil
|
||||
var md5ConnConfig *pgx.ConnConfig = nil
|
||||
var plainPasswordConnConfig *pgx.ConnConfig = nil
|
||||
var noPasswordConnConfig *pgx.ConnConfig = nil
|
||||
var invalidUserConnConfig *pgx.ConnConfig = nil
|
||||
var tlsConnConfig *pgx.ConnConfig = nil
|
||||
var customDialerConnConfig *pgx.ConnConfig = nil
|
||||
|
@ -20,7 +19,6 @@ var customDialerConnConfig *pgx.ConnConfig = nil
|
|||
// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"}
|
||||
// var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
// var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"}
|
||||
// var noPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_none", Database: "pgx_test"}
|
||||
// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
|
||||
// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
|
|
69
conn_pool.go
69
conn_pool.go
|
@ -28,8 +28,10 @@ type ConnPool struct {
|
|||
preparedStatements map[string]*PreparedStatement
|
||||
acquireTimeout time.Duration
|
||||
pgTypes map[OID]PgType
|
||||
pgsql_af_inet *byte
|
||||
pgsql_af_inet6 *byte
|
||||
pgsqlAfInet *byte
|
||||
pgsqlAfInet6 *byte
|
||||
txAfterClose func(tx *Tx)
|
||||
rowsAfterClose func(rows *Rows)
|
||||
}
|
||||
|
||||
type ConnPoolStat struct {
|
||||
|
@ -38,6 +40,9 @@ type ConnPoolStat struct {
|
|||
AvailableConnections int // unused live connections
|
||||
}
|
||||
|
||||
// ErrAcquireTimeout occurs when an attempt to acquire a connection times out.
|
||||
var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool")
|
||||
|
||||
// NewConnPool creates a new ConnPool. config.ConnConfig is passed through to
|
||||
// Connect directly.
|
||||
func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) {
|
||||
|
@ -68,6 +73,14 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) {
|
|||
p.logLevel = LogLevelNone
|
||||
}
|
||||
|
||||
p.txAfterClose = func(tx *Tx) {
|
||||
p.Release(tx.Conn())
|
||||
}
|
||||
|
||||
p.rowsAfterClose = func(rows *Rows) {
|
||||
p.Release(rows.Conn())
|
||||
}
|
||||
|
||||
p.allConnections = make([]*Conn, 0, p.maxConnections)
|
||||
p.availableConnections = make([]*Conn, 0, p.maxConnections)
|
||||
p.preparedStatements = make(map[string]*PreparedStatement)
|
||||
|
@ -121,7 +134,7 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
|
|||
|
||||
// Make sure the deadline (if it is) has not passed yet
|
||||
if p.deadlinePassed(deadline) {
|
||||
return nil, errors.New("Timeout: Acquire connection timeout")
|
||||
return nil, ErrAcquireTimeout
|
||||
}
|
||||
|
||||
// If there is a deadline then start a timeout timer
|
||||
|
@ -138,26 +151,25 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
|
|||
// Create a new connection.
|
||||
// Careful here: createConnectionUnlocked() removes the current lock,
|
||||
// creates a connection and then locks it back.
|
||||
if c, err := p.createConnectionUnlocked(); err == nil {
|
||||
c.poolResetCount = p.resetCount
|
||||
p.allConnections = append(p.allConnections, c)
|
||||
return c, nil
|
||||
} else {
|
||||
c, err := p.createConnectionUnlocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// All connections are in use and we cannot create more
|
||||
if p.logLevel >= LogLevelWarn {
|
||||
p.logger.Log(LogLevelWarn, "All connections in pool are busy - waiting...")
|
||||
}
|
||||
c.poolResetCount = p.resetCount
|
||||
p.allConnections = append(p.allConnections, c)
|
||||
return c, nil
|
||||
}
|
||||
// All connections are in use and we cannot create more
|
||||
if p.logLevel >= LogLevelWarn {
|
||||
p.logger.Log(LogLevelWarn, "All connections in pool are busy - waiting...")
|
||||
}
|
||||
|
||||
// Wait until there is an available connection OR room to create a new connection
|
||||
for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections {
|
||||
if p.deadlinePassed(deadline) {
|
||||
return nil, errors.New("Timeout: All connections in pool are busy")
|
||||
}
|
||||
p.cond.Wait()
|
||||
// Wait until there is an available connection OR room to create a new connection
|
||||
for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections {
|
||||
if p.deadlinePassed(deadline) {
|
||||
return nil, ErrAcquireTimeout
|
||||
}
|
||||
p.cond.Wait()
|
||||
}
|
||||
|
||||
// Stop the timer so that we do not spawn it on every acquire call.
|
||||
|
@ -272,7 +284,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) {
|
|||
}
|
||||
|
||||
func (p *ConnPool) createConnection() (*Conn, error) {
|
||||
c, err := connect(p.config, p.pgTypes, p.pgsql_af_inet, p.pgsql_af_inet6)
|
||||
c, err := connect(p.config, p.pgTypes, p.pgsqlAfInet, p.pgsqlAfInet6)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -308,8 +320,8 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) {
|
|||
// all the known statements for the new connection.
|
||||
func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
|
||||
p.pgTypes = c.PgTypes
|
||||
p.pgsql_af_inet = c.pgsql_af_inet
|
||||
p.pgsql_af_inet6 = c.pgsql_af_inet6
|
||||
p.pgsqlAfInet = c.pgsqlAfInet
|
||||
p.pgsqlAfInet6 = c.pgsqlAfInet6
|
||||
|
||||
if p.afterConnect != nil {
|
||||
err := p.afterConnect(c)
|
||||
|
@ -487,10 +499,13 @@ func (p *ConnPool) BeginIso(iso string) (*Tx, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (p *ConnPool) txAfterClose(tx *Tx) {
|
||||
p.Release(tx.Conn())
|
||||
}
|
||||
// CopyTo acquires a connection, delegates the call to that connection, and releases the connection
|
||||
func (p *ConnPool) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
func (p *ConnPool) rowsAfterClose(rows *Rows) {
|
||||
p.Release(rows.Conn())
|
||||
return c.CopyTo(tableName, columnNames, rowSrc)
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func releaseAllConnections(pool *pgx.ConnPool, connections []*pgx.Conn) {
|
|||
func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) {
|
||||
startTime := time.Now()
|
||||
c, err := pool.Acquire()
|
||||
return c, time.Now().Sub(startTime), err
|
||||
return c, time.Since(startTime), err
|
||||
}
|
||||
|
||||
func TestNewConnPool(t *testing.T) {
|
||||
|
@ -215,7 +215,7 @@ func TestPoolNonBlockingConnections(t *testing.T) {
|
|||
// Prior to createConnectionUnlocked() use the test took
|
||||
// maxConnections * openTimeout seconds to complete.
|
||||
// With createConnectionUnlocked() it takes ~ 1 * openTimeout seconds.
|
||||
timeTaken := time.Now().Sub(startedAt)
|
||||
timeTaken := time.Since(startedAt)
|
||||
if timeTaken > openTimeout+1*time.Second {
|
||||
t.Fatalf("Expected all Acquire() to run in parallel and take about %v, instead it took '%v'", openTimeout, timeTaken)
|
||||
}
|
||||
|
@ -276,8 +276,8 @@ func TestPoolWithAcquireTimeoutSet(t *testing.T) {
|
|||
// ... then try to consume 1 more. It should fail after a short timeout.
|
||||
_, timeTaken, err := acquireWithTimeTaken(pool)
|
||||
|
||||
if err == nil || err.Error() != "Timeout: All connections in pool are busy" {
|
||||
t.Fatalf("Expected error to be 'Timeout: All connections in pool are busy', instead it was '%v'", err)
|
||||
if err == nil || err != pgx.ErrAcquireTimeout {
|
||||
t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err)
|
||||
}
|
||||
if timeTaken < connAllocTimeout {
|
||||
t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken)
|
||||
|
@ -366,12 +366,12 @@ func TestPoolAcquireAndReleaseCycleAutoConnect(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("Unable to Acquire: %v", err)
|
||||
}
|
||||
rows, _ := c.Query("select 1")
|
||||
rows, _ := c.Query("select 1, pg_sleep(0.02)")
|
||||
rows.Close()
|
||||
pool.Release(c)
|
||||
}
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
for i := 0; i < 10; i++ {
|
||||
doSomething()
|
||||
}
|
||||
|
||||
|
@ -381,7 +381,7 @@ func TestPoolAcquireAndReleaseCycleAutoConnect(t *testing.T) {
|
|||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1000; i++ {
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
|
|
@ -914,7 +914,7 @@ func TestPrepareQueryManyParameters(t *testing.T) {
|
|||
args := make([]interface{}, 0, tt.count)
|
||||
for j := 0; j < tt.count; j++ {
|
||||
params = append(params, fmt.Sprintf("($%d::text)", j+1))
|
||||
args = append(args, strconv.FormatInt(int64(j), 10))
|
||||
args = append(args, strconv.Itoa(j))
|
||||
}
|
||||
|
||||
sql := "values" + strings.Join(params, ", ")
|
||||
|
|
|
@ -0,0 +1,241 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// CopyToRows returns a CopyToSource interface over the provided rows slice
|
||||
// making it usable by *Conn.CopyTo.
|
||||
func CopyToRows(rows [][]interface{}) CopyToSource {
|
||||
return ©ToRows{rows: rows, idx: -1}
|
||||
}
|
||||
|
||||
type copyToRows struct {
|
||||
rows [][]interface{}
|
||||
idx int
|
||||
}
|
||||
|
||||
func (ctr *copyToRows) Next() bool {
|
||||
ctr.idx++
|
||||
return ctr.idx < len(ctr.rows)
|
||||
}
|
||||
|
||||
func (ctr *copyToRows) Values() ([]interface{}, error) {
|
||||
return ctr.rows[ctr.idx], nil
|
||||
}
|
||||
|
||||
func (ctr *copyToRows) Err() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CopyToSource is the interface used by *Conn.CopyTo as the source for copy data.
|
||||
type CopyToSource interface {
|
||||
// Next returns true if there is another row and makes the next row data
|
||||
// available to Values(). When there are no more rows available or an error
|
||||
// has occurred it returns false.
|
||||
Next() bool
|
||||
|
||||
// Values returns the values for the current row.
|
||||
Values() ([]interface{}, error)
|
||||
|
||||
// Err returns any error that has been encountered by the CopyToSource. If
|
||||
// this is not nil *Conn.CopyTo will abort the copy.
|
||||
Err() error
|
||||
}
|
||||
|
||||
type copyTo struct {
|
||||
conn *Conn
|
||||
tableName string
|
||||
columnNames []string
|
||||
rowSrc CopyToSource
|
||||
readerErrChan chan error
|
||||
}
|
||||
|
||||
func (ct *copyTo) readUntilReadyForQuery() {
|
||||
for {
|
||||
t, r, err := ct.conn.rxMsg()
|
||||
if err != nil {
|
||||
ct.readerErrChan <- err
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
}
|
||||
|
||||
switch t {
|
||||
case readyForQuery:
|
||||
ct.conn.rxReadyForQuery(r)
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
case commandComplete:
|
||||
case errorResponse:
|
||||
ct.readerErrChan <- ct.conn.rxErrorResponse(r)
|
||||
default:
|
||||
err = ct.conn.processContextFreeMsg(t, r)
|
||||
if err != nil {
|
||||
ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ct *copyTo) waitForReaderDone() error {
|
||||
var err error
|
||||
for err = range ct.readerErrChan {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (ct *copyTo) run() (int, error) {
|
||||
quotedTableName := quoteIdentifier(ct.tableName)
|
||||
buf := &bytes.Buffer{}
|
||||
for i, cn := range ct.columnNames {
|
||||
if i != 0 {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
buf.WriteString(quoteIdentifier(cn))
|
||||
}
|
||||
quotedColumnNames := buf.String()
|
||||
|
||||
ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = ct.conn.readUntilCopyInResponse()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
go ct.readUntilReadyForQuery()
|
||||
defer ct.waitForReaderDone()
|
||||
|
||||
wbuf := newWriteBuf(ct.conn, copyData)
|
||||
|
||||
wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000"))
|
||||
wbuf.WriteInt32(0)
|
||||
wbuf.WriteInt32(0)
|
||||
|
||||
var sentCount int
|
||||
|
||||
for ct.rowSrc.Next() {
|
||||
select {
|
||||
case err = <-ct.readerErrChan:
|
||||
return 0, err
|
||||
default:
|
||||
}
|
||||
|
||||
if len(wbuf.buf) > 65536 {
|
||||
wbuf.closeMsg()
|
||||
_, err = ct.conn.conn.Write(wbuf.buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Directly manipulate wbuf to reset to reuse the same buffer
|
||||
wbuf.buf = wbuf.buf[0:5]
|
||||
wbuf.buf[0] = copyData
|
||||
wbuf.sizeIdx = 1
|
||||
}
|
||||
|
||||
sentCount++
|
||||
|
||||
values, err := ct.rowSrc.Values()
|
||||
if err != nil {
|
||||
ct.cancelCopyIn()
|
||||
return 0, err
|
||||
}
|
||||
if len(values) != len(ct.columnNames) {
|
||||
ct.cancelCopyIn()
|
||||
return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
||||
}
|
||||
|
||||
wbuf.WriteInt16(int16(len(ct.columnNames)))
|
||||
for i, val := range values {
|
||||
err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val)
|
||||
if err != nil {
|
||||
ct.cancelCopyIn()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if ct.rowSrc.Err() != nil {
|
||||
ct.cancelCopyIn()
|
||||
return 0, ct.rowSrc.Err()
|
||||
}
|
||||
|
||||
wbuf.WriteInt16(-1) // terminate the copy stream
|
||||
|
||||
wbuf.startMsg(copyDone)
|
||||
wbuf.closeMsg()
|
||||
_, err = ct.conn.conn.Write(wbuf.buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = ct.waitForReaderDone()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return sentCount, nil
|
||||
}
|
||||
|
||||
func (c *Conn) readUntilCopyInResponse() error {
|
||||
for {
|
||||
var t byte
|
||||
var r *msgReader
|
||||
t, r, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case copyInResponse:
|
||||
return nil
|
||||
default:
|
||||
err = c.processContextFreeMsg(t, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ct *copyTo) cancelCopyIn() error {
|
||||
wbuf := newWriteBuf(ct.conn, copyFail)
|
||||
wbuf.WriteCString("client error: abort")
|
||||
wbuf.closeMsg()
|
||||
_, err := ct.conn.conn.Write(wbuf.buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CopyTo uses the PostgreSQL copy protocol to perform bulk data insertion.
|
||||
// It returns the number of rows copied and an error.
|
||||
//
|
||||
// CopyTo requires all values use the binary format. Almost all types
|
||||
// implemented by pgx use the binary format by default. Types implementing
|
||||
// Encoder can only be used if they encode to the binary format.
|
||||
func (c *Conn) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) {
|
||||
ct := ©To{
|
||||
conn: c,
|
||||
tableName: tableName,
|
||||
columnNames: columnNames,
|
||||
rowSrc: rowSrc,
|
||||
readerErrChan: make(chan error),
|
||||
}
|
||||
|
||||
return ct.run()
|
||||
}
|
|
@ -0,0 +1,428 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
func TestConnCopyToSmall(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int2,
|
||||
b int4,
|
||||
c int8,
|
||||
d varchar,
|
||||
e text,
|
||||
f date,
|
||||
g timestamptz
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)},
|
||||
{nil, nil, nil, nil, nil, nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyToRows(inputRows))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyTo: %v", err)
|
||||
}
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToLarge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int2,
|
||||
b int4,
|
||||
c int8,
|
||||
d varchar,
|
||||
e text,
|
||||
f date,
|
||||
g timestamptz,
|
||||
h bytea
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{}
|
||||
|
||||
for i := 0; i < 10000; i++ {
|
||||
inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}})
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyToRows(inputRows))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyTo: %v", err)
|
||||
}
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
for _, oid := range []pgx.OID{pgx.JSONOID, pgx.JSONBOID} {
|
||||
if _, ok := conn.PgTypes[oid]; !ok {
|
||||
return // No JSON/JSONB type -- must be running against old PostgreSQL
|
||||
}
|
||||
}
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a json,
|
||||
b jsonb
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}},
|
||||
{nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyTo: %v", err)
|
||||
}
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToFailServerSideMidway(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int4,
|
||||
b varchar not null
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{int32(1), "abc"},
|
||||
{int32(2), nil}, // this row should trigger a failure
|
||||
{int32(3), "def"},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows))
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyTo return error, but it did not")
|
||||
}
|
||||
if _, ok := err.(pgx.PgError); !ok {
|
||||
t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err)
|
||||
}
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if len(outputRows) != 0 {
|
||||
t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
type failSource struct {
|
||||
count int
|
||||
}
|
||||
|
||||
func (fs *failSource) Next() bool {
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
fs.count++
|
||||
return fs.count < 100
|
||||
}
|
||||
|
||||
func (fs *failSource) Values() ([]interface{}, error) {
|
||||
if fs.count == 3 {
|
||||
return []interface{}{nil}, nil
|
||||
}
|
||||
return []interface{}{make([]byte, 100000)}, nil
|
||||
}
|
||||
|
||||
func (fs *failSource) Err() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a bytea not null
|
||||
)`)
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a"}, &failSource{})
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyTo return error, but it did not")
|
||||
}
|
||||
if _, ok := err.(pgx.PgError); !ok {
|
||||
t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err)
|
||||
}
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
endTime := time.Now()
|
||||
copyTime := endTime.Sub(startTime)
|
||||
if copyTime > time.Second {
|
||||
t.Errorf("Failing CopyTo shouldn't have taken so long: %v", copyTime)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if len(outputRows) != 0 {
|
||||
t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
type clientFailSource struct {
|
||||
count int
|
||||
err error
|
||||
}
|
||||
|
||||
func (cfs *clientFailSource) Next() bool {
|
||||
cfs.count++
|
||||
return cfs.count < 100
|
||||
}
|
||||
|
||||
func (cfs *clientFailSource) Values() ([]interface{}, error) {
|
||||
if cfs.count == 3 {
|
||||
cfs.err = fmt.Errorf("client error")
|
||||
return nil, cfs.err
|
||||
}
|
||||
return []interface{}{make([]byte, 100000)}, nil
|
||||
}
|
||||
|
||||
func (cfs *clientFailSource) Err() error {
|
||||
return cfs.err
|
||||
}
|
||||
|
||||
func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a bytea not null
|
||||
)`)
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFailSource{})
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyTo return error, but it did not")
|
||||
}
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if len(outputRows) != 0 {
|
||||
t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
type clientFinalErrSource struct {
|
||||
count int
|
||||
}
|
||||
|
||||
func (cfs *clientFinalErrSource) Next() bool {
|
||||
cfs.count++
|
||||
return cfs.count < 5
|
||||
}
|
||||
|
||||
func (cfs *clientFinalErrSource) Values() ([]interface{}, error) {
|
||||
return []interface{}{make([]byte, 100000)}, nil
|
||||
}
|
||||
|
||||
func (cfs *clientFinalErrSource) Err() error {
|
||||
return fmt.Errorf("final error")
|
||||
}
|
||||
|
||||
func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a bytea not null
|
||||
)`)
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFinalErrSource{})
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyTo return error, but it did not")
|
||||
}
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if len(outputRows) != 0 {
|
||||
t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
118
doc.go
118
doc.go
|
@ -50,7 +50,7 @@ pgx also implements QueryRow in the same style as database/sql.
|
|||
return err
|
||||
}
|
||||
|
||||
Use exec to execute a query that does not return a result set.
|
||||
Use Exec to execute a query that does not return a result set.
|
||||
|
||||
commandTag, err := conn.Exec("delete from widgets where id=$1", 42)
|
||||
if err != nil {
|
||||
|
@ -81,43 +81,39 @@ releasing connections when you do not need that level of control.
|
|||
return err
|
||||
}
|
||||
|
||||
Transactions
|
||||
Base Type Mapping
|
||||
|
||||
Transactions are started by calling Begin or BeginIso. The BeginIso variant
|
||||
creates a transaction with a specified isolation level.
|
||||
pgx maps between all common base types directly between Go and PostgreSQL. In
|
||||
particular:
|
||||
|
||||
tx, err := conn.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Rollback is safe to call even if the tx is already closed, so if
|
||||
// the tx commits successfully, this is a no-op
|
||||
defer tx.Rollback()
|
||||
Go PostgreSQL
|
||||
-----------------------
|
||||
string varchar
|
||||
text
|
||||
|
||||
_, err = tx.Exec("insert into foo(id) values (1)")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Integers are automatically be converted to any other integer type if
|
||||
// it can be done without overflow or underflow.
|
||||
int8
|
||||
int16 smallint
|
||||
int32 int
|
||||
int64 bigint
|
||||
int
|
||||
uint8
|
||||
uint16
|
||||
uint32
|
||||
uint64
|
||||
uint
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Floats are strict and do not automatically convert like integers.
|
||||
float32 float4
|
||||
float64 float8
|
||||
|
||||
Listen and Notify
|
||||
time.Time date
|
||||
timestamp
|
||||
timestamptz
|
||||
|
||||
pgx can listen to the PostgreSQL notification system with the
|
||||
WaitForNotification function. It takes a maximum time to wait for a
|
||||
notification.
|
||||
[]byte bytea
|
||||
|
||||
err := conn.Listen("channelname")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if notification, err := conn.WaitForNotification(time.Second); err != nil {
|
||||
// do something with notification
|
||||
}
|
||||
|
||||
Null Mapping
|
||||
|
||||
|
@ -136,7 +132,7 @@ Array Mapping
|
|||
|
||||
pgx maps between int16, int32, int64, float32, float64, and string Go slices
|
||||
and the equivalent PostgreSQL array type. Go slices of native types do not
|
||||
support nulls, so if a PostgreSQL array that contains a slice is read into a
|
||||
support nulls, so if a PostgreSQL array that contains a null is read into a
|
||||
native Go slice an error will occur.
|
||||
|
||||
Hstore Mapping
|
||||
|
@ -192,6 +188,64 @@ the raw bytes returned by PostgreSQL. This can be especially useful for reading
|
|||
varchar, text, json, and jsonb values directly into a []byte and avoiding the
|
||||
type conversion from string.
|
||||
|
||||
Transactions
|
||||
|
||||
Transactions are started by calling Begin or BeginIso. The BeginIso variant
|
||||
creates a transaction with a specified isolation level.
|
||||
|
||||
tx, err := conn.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Rollback is safe to call even if the tx is already closed, so if
|
||||
// the tx commits successfully, this is a no-op
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec("insert into foo(id) values (1)")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Copy Protocol
|
||||
|
||||
Use CopyTo to efficiently insert multiple rows at a time using the PostgreSQL
|
||||
copy protocol. CopyTo accepts a CopyToSource interface. If the data is already
|
||||
in a [][]interface{} use CopyToRows to wrap it in a CopyToSource interface. Or
|
||||
implement CopyToSource to avoid buffering the entire data set in memory.
|
||||
|
||||
rows := [][]interface{}{
|
||||
{"John", "Smith", int32(36)},
|
||||
{"Jane", "Doe", int32(29)},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyTo(
|
||||
"people",
|
||||
[]string{"first_name", "last_name", "age"},
|
||||
pgx.CopyToRows(rows),
|
||||
)
|
||||
|
||||
CopyTo can be faster than an insert with as few as 5 rows.
|
||||
|
||||
Listen and Notify
|
||||
|
||||
pgx can listen to the PostgreSQL notification system with the
|
||||
WaitForNotification function. It takes a maximum time to wait for a
|
||||
notification.
|
||||
|
||||
err := conn.Listen("channelname")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if notification, err := conn.WaitForNotification(time.Second); err != nil {
|
||||
// do something with notification
|
||||
}
|
||||
|
||||
TLS
|
||||
|
||||
The pgx ConnConfig struct has a TLSConfig field. If this field is
|
||||
|
|
|
@ -4,8 +4,6 @@ import (
|
|||
"encoding/binary"
|
||||
)
|
||||
|
||||
type fastpathArg []byte
|
||||
|
||||
func newFastpath(cn *Conn) *fastpath {
|
||||
return &fastpath{cn: cn, fns: make(map[string]OID)}
|
||||
}
|
||||
|
|
|
@ -15,7 +15,6 @@ const (
|
|||
hsVal
|
||||
hsNul
|
||||
hsNext
|
||||
hsEnd
|
||||
)
|
||||
|
||||
type hstoreParser struct {
|
||||
|
|
|
@ -85,23 +85,23 @@ func TestNullHstoreTranscode(t *testing.T) {
|
|||
{pgx.NullHstore{}, "null"},
|
||||
{pgx.NullHstore{Valid: true}, "empty"},
|
||||
{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar", Valid: true}},
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}},
|
||||
Valid: true},
|
||||
"single key/value"},
|
||||
{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar", Valid: true}, "baz": pgx.NullString{String: "quz", Valid: true}},
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}, "baz": {String: "quz", Valid: true}},
|
||||
Valid: true},
|
||||
"multiple key/values"},
|
||||
{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"NULL": pgx.NullString{String: "bar", Valid: true}},
|
||||
Hstore: map[string]pgx.NullString{"NULL": {String: "bar", Valid: true}},
|
||||
Valid: true},
|
||||
`string "NULL" key`},
|
||||
{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "NULL", Valid: true}},
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: "NULL", Valid: true}},
|
||||
Valid: true},
|
||||
`string "NULL" value`},
|
||||
{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "", Valid: false}},
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: "", Valid: false}},
|
||||
Valid: true},
|
||||
`NULL value`},
|
||||
}
|
||||
|
@ -120,36 +120,36 @@ func TestNullHstoreTranscode(t *testing.T) {
|
|||
}
|
||||
for _, sst := range specialStringTests {
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{sst.input + "foo": pgx.NullString{String: "bar", Valid: true}},
|
||||
Hstore: map[string]pgx.NullString{sst.input + "foo": {String: "bar", Valid: true}},
|
||||
Valid: true},
|
||||
"key with " + sst.description + " at beginning"})
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": pgx.NullString{String: "bar", Valid: true}},
|
||||
Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": {String: "bar", Valid: true}},
|
||||
Valid: true},
|
||||
"key with " + sst.description + " in middle"})
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo" + sst.input: pgx.NullString{String: "bar", Valid: true}},
|
||||
Hstore: map[string]pgx.NullString{"foo" + sst.input: {String: "bar", Valid: true}},
|
||||
Valid: true},
|
||||
"key with " + sst.description + " at end"})
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{sst.input: pgx.NullString{String: "bar", Valid: true}},
|
||||
Hstore: map[string]pgx.NullString{sst.input: {String: "bar", Valid: true}},
|
||||
Valid: true},
|
||||
"key is " + sst.description})
|
||||
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: sst.input + "bar", Valid: true}},
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: sst.input + "bar", Valid: true}},
|
||||
Valid: true},
|
||||
"value with " + sst.description + " at beginning"})
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar" + sst.input + "bar", Valid: true}},
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input + "bar", Valid: true}},
|
||||
Valid: true},
|
||||
"value with " + sst.description + " in middle"})
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar" + sst.input, Valid: true}},
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input, Valid: true}},
|
||||
Valid: true},
|
||||
"value with " + sst.description + " at end"})
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: sst.input, Valid: true}},
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: sst.input, Valid: true}},
|
||||
Valid: true},
|
||||
"value is " + sst.description})
|
||||
}
|
||||
|
|
28
messages.go
28
messages.go
|
@ -25,6 +25,10 @@ const (
|
|||
noData = 'n'
|
||||
closeComplete = '3'
|
||||
flush = 'H'
|
||||
copyInResponse = 'G'
|
||||
copyData = 'd'
|
||||
copyFail = 'f'
|
||||
copyDone = 'c'
|
||||
)
|
||||
|
||||
type startupMessage struct {
|
||||
|
@ -35,10 +39,10 @@ func newStartupMessage() *startupMessage {
|
|||
return &startupMessage{map[string]string{}}
|
||||
}
|
||||
|
||||
func (self *startupMessage) Bytes() (buf []byte) {
|
||||
func (s *startupMessage) Bytes() (buf []byte) {
|
||||
buf = make([]byte, 8, 128)
|
||||
binary.BigEndian.PutUint32(buf[4:8], uint32(protocolVersionNumber))
|
||||
for key, value := range self.options {
|
||||
for key, value := range s.options {
|
||||
buf = append(buf, key...)
|
||||
buf = append(buf, 0)
|
||||
buf = append(buf, value...)
|
||||
|
@ -49,8 +53,6 @@ func (self *startupMessage) Bytes() (buf []byte) {
|
|||
return buf
|
||||
}
|
||||
|
||||
type OID int32
|
||||
|
||||
type FieldDescription struct {
|
||||
Name string
|
||||
Table OID
|
||||
|
@ -85,8 +87,8 @@ type PgError struct {
|
|||
Routine string
|
||||
}
|
||||
|
||||
func (self PgError) Error() string {
|
||||
return self.Severity + ": " + self.Message + " (SQLSTATE " + self.Code + ")"
|
||||
func (pe PgError) Error() string {
|
||||
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||
}
|
||||
|
||||
func newWriteBuf(c *Conn, t byte) *WriteBuf {
|
||||
|
@ -95,7 +97,7 @@ func newWriteBuf(c *Conn, t byte) *WriteBuf {
|
|||
return &c.writeBuf
|
||||
}
|
||||
|
||||
// WrifeBuf is used build messages to send to the PostgreSQL server. It is used
|
||||
// WriteBuf is used build messages to send to the PostgreSQL server. It is used
|
||||
// by the Encoder interface when implementing custom encoders.
|
||||
type WriteBuf struct {
|
||||
buf []byte
|
||||
|
@ -128,12 +130,24 @@ func (wb *WriteBuf) WriteInt16(n int16) {
|
|||
wb.buf = append(wb.buf, b...)
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteUint16(n uint16) {
|
||||
b := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(b, n)
|
||||
wb.buf = append(wb.buf, b...)
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteInt32(n int32) {
|
||||
b := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(b, uint32(n))
|
||||
wb.buf = append(wb.buf, b...)
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteUint32(n uint32) {
|
||||
b := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(b, n)
|
||||
wb.buf = append(wb.buf, b...)
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteInt64(n int64) {
|
||||
b := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(b, uint64(n))
|
||||
|
|
115
msg_reader.go
115
msg_reader.go
|
@ -10,7 +10,6 @@ import (
|
|||
// msgReader is a helper that reads values from a PostgreSQL message.
|
||||
type msgReader struct {
|
||||
reader *bufio.Reader
|
||||
buf [128]byte
|
||||
msgBytesRemaining int32
|
||||
err error
|
||||
log func(lvl int, msg string, ctx ...interface{})
|
||||
|
@ -47,10 +46,15 @@ func (r *msgReader) rxMsg() (byte, error) {
|
|||
}
|
||||
}
|
||||
|
||||
b := r.buf[0:5]
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
b, err := r.reader.Peek(5)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0, err
|
||||
}
|
||||
msgType := b[0]
|
||||
r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4
|
||||
return b[0], err
|
||||
r.reader.Discard(5)
|
||||
return msgType, nil
|
||||
}
|
||||
|
||||
func (r *msgReader) readByte() byte {
|
||||
|
@ -58,7 +62,7 @@ func (r *msgReader) readByte() byte {
|
|||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 1
|
||||
r.msgBytesRemaining--
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
|
@ -88,8 +92,7 @@ func (r *msgReader) readInt16() int16 {
|
|||
return 0
|
||||
}
|
||||
|
||||
b := r.buf[0:2]
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
b, err := r.reader.Peek(2)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
|
@ -97,6 +100,8 @@ func (r *msgReader) readInt16() int16 {
|
|||
|
||||
n := int16(binary.BigEndian.Uint16(b))
|
||||
|
||||
r.reader.Discard(2)
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
@ -115,8 +120,7 @@ func (r *msgReader) readInt32() int32 {
|
|||
return 0
|
||||
}
|
||||
|
||||
b := r.buf[0:4]
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
b, err := r.reader.Peek(4)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
|
@ -124,6 +128,8 @@ func (r *msgReader) readInt32() int32 {
|
|||
|
||||
n := int32(binary.BigEndian.Uint32(b))
|
||||
|
||||
r.reader.Discard(4)
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
@ -131,6 +137,62 @@ func (r *msgReader) readInt32() int32 {
|
|||
return n
|
||||
}
|
||||
|
||||
func (r *msgReader) readUint16() uint16 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 2
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(2)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := uint16(binary.BigEndian.Uint16(b))
|
||||
|
||||
r.reader.Discard(2)
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (r *msgReader) readUint32() uint32 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 4
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(4)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := uint32(binary.BigEndian.Uint32(b))
|
||||
|
||||
r.reader.Discard(4)
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (r *msgReader) readInt64() int64 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
|
@ -142,8 +204,7 @@ func (r *msgReader) readInt64() int64 {
|
|||
return 0
|
||||
}
|
||||
|
||||
b := r.buf[0:8]
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
b, err := r.reader.Peek(8)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
|
@ -151,6 +212,8 @@ func (r *msgReader) readInt64() int64 {
|
|||
|
||||
n := int64(binary.BigEndian.Uint64(b))
|
||||
|
||||
r.reader.Discard(8)
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
@ -190,32 +253,34 @@ func (r *msgReader) readCString() string {
|
|||
}
|
||||
|
||||
// readString reads count bytes and returns as string
|
||||
func (r *msgReader) readString(count int32) string {
|
||||
func (r *msgReader) readString(countI32 int32) string {
|
||||
if r.err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= count
|
||||
r.msgBytesRemaining -= countI32
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return ""
|
||||
}
|
||||
|
||||
var b []byte
|
||||
if count <= int32(len(r.buf)) {
|
||||
b = r.buf[0:int(count)]
|
||||
count := int(countI32)
|
||||
var s string
|
||||
|
||||
if r.reader.Buffered() >= count {
|
||||
buf, _ := r.reader.Peek(count)
|
||||
s = string(buf)
|
||||
r.reader.Discard(count)
|
||||
} else {
|
||||
b = make([]byte, int(count))
|
||||
buf := make([]byte, count)
|
||||
_, err := io.ReadFull(r.reader, buf)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return ""
|
||||
}
|
||||
s = string(buf)
|
||||
}
|
||||
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return ""
|
||||
}
|
||||
|
||||
s := string(b)
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
|
8
query.go
8
query.go
|
@ -298,13 +298,17 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||
if err != nil {
|
||||
rows.Fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
} else if vr.Type().DataType == JSONOID || vr.Type().DataType == JSONBOID {
|
||||
} else if vr.Type().DataType == JSONOID {
|
||||
// Because the argument passed to decodeJSON will escape the heap.
|
||||
// This allows d to be stack allocated and only copied to the heap when
|
||||
// we actually are decoding JSON. This saves one memory allocation per
|
||||
// row.
|
||||
d2 := d
|
||||
decodeJSON(vr, &d2)
|
||||
} else if vr.Type().DataType == JSONBOID {
|
||||
// Same trick as above for getting stack allocation
|
||||
d2 := d
|
||||
decodeJSONB(vr, &d2)
|
||||
} else {
|
||||
if err := Decode(vr, d); err != nil {
|
||||
rows.Fatal(scanArgError{col: i, err: err})
|
||||
|
@ -393,7 +397,7 @@ func (rows *Rows) Values() ([]interface{}, error) {
|
|||
values = append(values, d)
|
||||
case JSONBOID:
|
||||
var d interface{}
|
||||
decodeJSON(vr, &d)
|
||||
decodeJSONB(vr, &d)
|
||||
values = append(values, d)
|
||||
default:
|
||||
rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types"))
|
||||
|
|
|
@ -3,11 +3,12 @@ package pgx_test
|
|||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"github.com/jackc/pgx"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
|
@ -784,7 +785,7 @@ func TestQueryRowCoreByteSlice(t *testing.T) {
|
|||
t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
|
||||
}
|
||||
|
||||
if bytes.Compare(actual, tt.expected) != 0 {
|
||||
if !bytes.Equal(actual, tt.expected) {
|
||||
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql)
|
||||
}
|
||||
|
||||
|
@ -1281,7 +1282,7 @@ func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) {
|
|||
}
|
||||
var num decimal.Decimal
|
||||
|
||||
err = conn.QueryRow("select $1::decimal", expected).Scan(&num)
|
||||
err = conn.QueryRow("select $1::decimal", &expected).Scan(&num)
|
||||
if err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
|
|
4
sql.go
4
sql.go
|
@ -14,7 +14,7 @@ func init() {
|
|||
placeholders = make([]string, 64)
|
||||
|
||||
for i := 1; i < 64; i++ {
|
||||
placeholders[i] = "$" + strconv.FormatInt(int64(i), 10)
|
||||
placeholders[i] = "$" + strconv.Itoa(i)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -25,5 +25,5 @@ func (qa *QueryArgs) Append(v interface{}) string {
|
|||
if len(*qa) < len(placeholders) {
|
||||
return placeholders[len(*qa)]
|
||||
}
|
||||
return "$" + strconv.FormatInt(int64(len(*qa)), 10)
|
||||
return "$" + strconv.Itoa(len(*qa))
|
||||
}
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
func TestQueryArgs(t *testing.T) {
|
||||
var qa pgx.QueryArgs
|
||||
|
||||
for i := 1; i < 512; i++ {
|
||||
expectedPlaceholder := "$" + strconv.FormatInt(int64(i), 10)
|
||||
expectedPlaceholder := "$" + strconv.Itoa(i)
|
||||
placeholder := qa.Append(i)
|
||||
if placeholder != expectedPlaceholder {
|
||||
t.Errorf(`Expected qa.Append to return "%s", but it returned "%s"`, expectedPlaceholder, placeholder)
|
||||
|
|
9
tx.go
9
tx.go
|
@ -158,6 +158,15 @@ func (tx *Tx) QueryRow(sql string, args ...interface{}) *Row {
|
|||
return (*Row)(rows)
|
||||
}
|
||||
|
||||
// CopyTo delegates to the underlying *Conn
|
||||
func (tx *Tx) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) {
|
||||
if tx.status != TxStatusInProgress {
|
||||
return 0, ErrTxClosed
|
||||
}
|
||||
|
||||
return tx.conn.CopyTo(tableName, columnNames, rowSrc)
|
||||
}
|
||||
|
||||
// Conn returns the *Conn this transaction is using.
|
||||
func (tx *Tx) Conn() *Conn {
|
||||
return tx.conn
|
||||
|
|
|
@ -60,6 +60,20 @@ func (r *ValueReader) ReadInt16() int16 {
|
|||
return r.mr.readInt16()
|
||||
}
|
||||
|
||||
func (r *ValueReader) ReadUint16() uint16 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.valueBytesRemaining -= 2
|
||||
if r.valueBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of value"))
|
||||
return 0
|
||||
}
|
||||
|
||||
return r.mr.readUint16()
|
||||
}
|
||||
|
||||
func (r *ValueReader) ReadInt32() int32 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
|
@ -74,6 +88,20 @@ func (r *ValueReader) ReadInt32() int32 {
|
|||
return r.mr.readInt32()
|
||||
}
|
||||
|
||||
func (r *ValueReader) ReadUint32() uint32 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.valueBytesRemaining -= 4
|
||||
if r.valueBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of value"))
|
||||
return 0
|
||||
}
|
||||
|
||||
return r.mr.readUint32()
|
||||
}
|
||||
|
||||
func (r *ValueReader) ReadInt64() int64 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
|
@ -89,7 +117,7 @@ func (r *ValueReader) ReadInt64() int64 {
|
|||
}
|
||||
|
||||
func (r *ValueReader) ReadOID() OID {
|
||||
return OID(r.ReadInt32())
|
||||
return OID(r.ReadUint32())
|
||||
}
|
||||
|
||||
// ReadString reads count bytes and returns as string
|
||||
|
|
171
values_test.go
171
values_test.go
|
@ -88,67 +88,74 @@ func TestJSONAndJSONBTranscode(t *testing.T) {
|
|||
if _, ok := conn.PgTypes[oid]; !ok {
|
||||
return // No JSON/JSONB type -- must be running against old PostgreSQL
|
||||
}
|
||||
typename := conn.PgTypes[oid].Name
|
||||
|
||||
testJSONString(t, conn, typename)
|
||||
testJSONStringPointer(t, conn, typename)
|
||||
testJSONSingleLevelStringMap(t, conn, typename)
|
||||
testJSONNestedMap(t, conn, typename)
|
||||
testJSONStringArray(t, conn, typename)
|
||||
testJSONInt64Array(t, conn, typename)
|
||||
testJSONInt16ArrayFailureDueToOverflow(t, conn, typename)
|
||||
testJSONStruct(t, conn, typename)
|
||||
for _, format := range []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} {
|
||||
pgtype := conn.PgTypes[oid]
|
||||
pgtype.DefaultFormat = format
|
||||
conn.PgTypes[oid] = pgtype
|
||||
|
||||
typename := conn.PgTypes[oid].Name
|
||||
|
||||
testJSONString(t, conn, typename, format)
|
||||
testJSONStringPointer(t, conn, typename, format)
|
||||
testJSONSingleLevelStringMap(t, conn, typename, format)
|
||||
testJSONNestedMap(t, conn, typename, format)
|
||||
testJSONStringArray(t, conn, typename, format)
|
||||
testJSONInt64Array(t, conn, typename, format)
|
||||
testJSONInt16ArrayFailureDueToOverflow(t, conn, typename, format)
|
||||
testJSONStruct(t, conn, typename, format)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testJSONString(t *testing.T, conn *pgx.Conn, typename string) {
|
||||
func testJSONString(t *testing.T, conn *pgx.Conn, typename string, format int16) {
|
||||
input := `{"key": "value"}`
|
||||
expectedOutput := map[string]string{"key": "value"}
|
||||
var output map[string]string
|
||||
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
|
||||
if err != nil {
|
||||
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
|
||||
t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expectedOutput, output) {
|
||||
t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output)
|
||||
t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string) {
|
||||
func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string, format int16) {
|
||||
input := `{"key": "value"}`
|
||||
expectedOutput := map[string]string{"key": "value"}
|
||||
var output map[string]string
|
||||
err := conn.QueryRow("select $1::"+typename, &input).Scan(&output)
|
||||
if err != nil {
|
||||
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
|
||||
t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expectedOutput, output) {
|
||||
t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output)
|
||||
t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) {
|
||||
func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string, format int16) {
|
||||
input := map[string]string{"key": "value"}
|
||||
var output map[string]string
|
||||
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
|
||||
if err != nil {
|
||||
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
|
||||
t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(input, output) {
|
||||
t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, input, output)
|
||||
t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, input, output)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) {
|
||||
func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int16) {
|
||||
input := map[string]interface{}{
|
||||
"name": "Uncanny",
|
||||
"stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)},
|
||||
|
@ -157,52 +164,52 @@ func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) {
|
|||
var output map[string]interface{}
|
||||
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
|
||||
if err != nil {
|
||||
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
|
||||
t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(input, output) {
|
||||
t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output)
|
||||
t.Errorf("%s %d: Did not transcode map[string]interface{} successfully: %v is not %v", typename, format, input, output)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string) {
|
||||
func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string, format int16) {
|
||||
input := []string{"foo", "bar", "baz"}
|
||||
var output []string
|
||||
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
|
||||
if err != nil {
|
||||
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
|
||||
t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(input, output) {
|
||||
t.Errorf("%s: Did not transcode []string successfully: %v is not %v", typename, input, output)
|
||||
t.Errorf("%s %d: Did not transcode []string successfully: %v is not %v", typename, format, input, output)
|
||||
}
|
||||
}
|
||||
|
||||
func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string) {
|
||||
func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string, format int16) {
|
||||
input := []int64{1, 2, 234432}
|
||||
var output []int64
|
||||
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
|
||||
if err != nil {
|
||||
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
|
||||
t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(input, output) {
|
||||
t.Errorf("%s: Did not transcode []int64 successfully: %v is not %v", typename, input, output)
|
||||
t.Errorf("%s %d: Did not transcode []int64 successfully: %v is not %v", typename, format, input, output)
|
||||
}
|
||||
}
|
||||
|
||||
func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) {
|
||||
func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string, format int16) {
|
||||
input := []int{1, 2, 234432}
|
||||
var output []int16
|
||||
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
|
||||
if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" {
|
||||
t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err)
|
||||
t.Errorf("%s %d: Expected *json.UnmarkalTypeError, but got %v", typename, format, err)
|
||||
}
|
||||
}
|
||||
|
||||
func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) {
|
||||
func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) {
|
||||
type person struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
|
@ -217,11 +224,11 @@ func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) {
|
|||
|
||||
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
|
||||
if err != nil {
|
||||
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
|
||||
t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(input, output) {
|
||||
t.Errorf("%s: Did not transcode struct successfully: %v is not %v", typename, input, output)
|
||||
t.Errorf("%s %d: Did not transcode struct successfully: %v is not %v", typename, format, input, output)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -561,6 +568,13 @@ func TestNullX(t *testing.T) {
|
|||
s pgx.NullString
|
||||
i16 pgx.NullInt16
|
||||
i32 pgx.NullInt32
|
||||
c pgx.NullChar
|
||||
a pgx.NullAclItem
|
||||
n pgx.NullName
|
||||
oid pgx.NullOID
|
||||
xid pgx.NullXid
|
||||
cid pgx.NullCid
|
||||
tid pgx.NullTid
|
||||
i64 pgx.NullInt64
|
||||
f32 pgx.NullFloat32
|
||||
f64 pgx.NullFloat64
|
||||
|
@ -582,6 +596,27 @@ func TestNullX(t *testing.T) {
|
|||
{"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}},
|
||||
{"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}},
|
||||
{"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}},
|
||||
{"select $1::oid", []interface{}{pgx.NullOID{OID: 1, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 1, Valid: true}}},
|
||||
{"select $1::oid", []interface{}{pgx.NullOID{OID: 1, Valid: false}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 0, Valid: false}}},
|
||||
{"select $1::oid", []interface{}{pgx.NullOID{OID: 4294967295, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 4294967295, Valid: true}}},
|
||||
{"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 1, Valid: true}}},
|
||||
{"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: false}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 0, Valid: false}}},
|
||||
{"select $1::xid", []interface{}{pgx.NullXid{Xid: 4294967295, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 4294967295, Valid: true}}},
|
||||
{"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}},
|
||||
{"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}},
|
||||
{"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}},
|
||||
{"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: true}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "foo", Valid: true}}},
|
||||
{"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: false}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "", Valid: false}}},
|
||||
{"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}},
|
||||
{"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}},
|
||||
// A tricky (and valid) aclitem can still be used, especially with Go's useful backticks
|
||||
{"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}},
|
||||
{"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}},
|
||||
{"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}},
|
||||
{"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}},
|
||||
{"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}},
|
||||
{"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: false}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 0, OffsetNumber: 0}, Valid: false}}},
|
||||
{"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}},
|
||||
{"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}},
|
||||
{"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}},
|
||||
{"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}},
|
||||
|
@ -615,6 +650,52 @@ func TestNullX(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func assertAclItemSlicesEqual(t *testing.T, query, scan []pgx.AclItem) {
|
||||
if !reflect.DeepEqual(query, scan) {
|
||||
t.Errorf("failed to encode aclitem[]\n EXPECTED: %d %v\n ACTUAL: %d %v", len(query), query, len(scan), scan)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAclArrayDecoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
sql := "select $1::aclitem[]"
|
||||
var scan []pgx.AclItem
|
||||
|
||||
tests := []struct {
|
||||
query []pgx.AclItem
|
||||
}{
|
||||
{
|
||||
[]pgx.AclItem{},
|
||||
},
|
||||
{
|
||||
[]pgx.AclItem{"=r/postgres"},
|
||||
},
|
||||
{
|
||||
[]pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"},
|
||||
},
|
||||
{
|
||||
[]pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/" tricky, ' } "" \ test user "`},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
err := conn.QueryRow(sql, tt.query).Scan(&scan)
|
||||
if err != nil {
|
||||
// t.Errorf(`%d. error reading array: %v`, i, err)
|
||||
t.Errorf(`%d. error reading array: %v query: %s`, i, err, tt.query)
|
||||
if pgerr, ok := err.(pgx.PgError); ok {
|
||||
t.Errorf(`%d. error reading array (detail): %s`, i, pgerr.Detail)
|
||||
}
|
||||
continue
|
||||
}
|
||||
assertAclItemSlicesEqual(t, tt.query, scan)
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArrayDecoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -630,7 +711,7 @@ func TestArrayDecoding(t *testing.T) {
|
|||
{
|
||||
"select $1::bool[]", []bool{true, false, true}, &[]bool{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
if reflect.DeepEqual(query, *(scan.(*[]bool))) == false {
|
||||
if !reflect.DeepEqual(query, *(scan.(*[]bool))) {
|
||||
t.Errorf("failed to encode bool[]")
|
||||
}
|
||||
},
|
||||
|
@ -638,7 +719,7 @@ func TestArrayDecoding(t *testing.T) {
|
|||
{
|
||||
"select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
if reflect.DeepEqual(query, *(scan.(*[]int16))) == false {
|
||||
if !reflect.DeepEqual(query, *(scan.(*[]int16))) {
|
||||
t.Errorf("failed to encode smallint[]")
|
||||
}
|
||||
},
|
||||
|
@ -646,7 +727,7 @@ func TestArrayDecoding(t *testing.T) {
|
|||
{
|
||||
"select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
if reflect.DeepEqual(query, *(scan.(*[]uint16))) == false {
|
||||
if !reflect.DeepEqual(query, *(scan.(*[]uint16))) {
|
||||
t.Errorf("failed to encode smallint[]")
|
||||
}
|
||||
},
|
||||
|
@ -654,7 +735,7 @@ func TestArrayDecoding(t *testing.T) {
|
|||
{
|
||||
"select $1::int[]", []int32{2, 4, 484}, &[]int32{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
if reflect.DeepEqual(query, *(scan.(*[]int32))) == false {
|
||||
if !reflect.DeepEqual(query, *(scan.(*[]int32))) {
|
||||
t.Errorf("failed to encode int[]")
|
||||
}
|
||||
},
|
||||
|
@ -662,7 +743,7 @@ func TestArrayDecoding(t *testing.T) {
|
|||
{
|
||||
"select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
if reflect.DeepEqual(query, *(scan.(*[]uint32))) == false {
|
||||
if !reflect.DeepEqual(query, *(scan.(*[]uint32))) {
|
||||
t.Errorf("failed to encode int[]")
|
||||
}
|
||||
},
|
||||
|
@ -670,7 +751,7 @@ func TestArrayDecoding(t *testing.T) {
|
|||
{
|
||||
"select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
if reflect.DeepEqual(query, *(scan.(*[]int64))) == false {
|
||||
if !reflect.DeepEqual(query, *(scan.(*[]int64))) {
|
||||
t.Errorf("failed to encode bigint[]")
|
||||
}
|
||||
},
|
||||
|
@ -678,7 +759,7 @@ func TestArrayDecoding(t *testing.T) {
|
|||
{
|
||||
"select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
if reflect.DeepEqual(query, *(scan.(*[]uint64))) == false {
|
||||
if !reflect.DeepEqual(query, *(scan.(*[]uint64))) {
|
||||
t.Errorf("failed to encode bigint[]")
|
||||
}
|
||||
},
|
||||
|
@ -686,7 +767,7 @@ func TestArrayDecoding(t *testing.T) {
|
|||
{
|
||||
"select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
if reflect.DeepEqual(query, *(scan.(*[]string))) == false {
|
||||
if !reflect.DeepEqual(query, *(scan.(*[]string))) {
|
||||
t.Errorf("failed to encode text[]")
|
||||
}
|
||||
},
|
||||
|
@ -694,7 +775,7 @@ func TestArrayDecoding(t *testing.T) {
|
|||
{
|
||||
"select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false {
|
||||
if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) {
|
||||
t.Errorf("failed to encode time.Time[] to timestamp[]")
|
||||
}
|
||||
},
|
||||
|
@ -702,7 +783,7 @@ func TestArrayDecoding(t *testing.T) {
|
|||
{
|
||||
"select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false {
|
||||
if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) {
|
||||
t.Errorf("failed to encode time.Time[] to timestamptz[]")
|
||||
}
|
||||
},
|
||||
|
@ -718,7 +799,7 @@ func TestArrayDecoding(t *testing.T) {
|
|||
for i := range queryBytesSliceSlice {
|
||||
qb := queryBytesSliceSlice[i]
|
||||
sb := scanBytesSliceSlice[i]
|
||||
if bytes.Compare(qb, sb) != 0 {
|
||||
if !bytes.Equal(qb, sb) {
|
||||
t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb)
|
||||
}
|
||||
}
|
||||
|
@ -1075,11 +1156,11 @@ func TestRowDecode(t *testing.T) {
|
|||
expected []interface{}
|
||||
}{
|
||||
{
|
||||
"select row(1, 'cat', '2015-01-01 08:12:42'::timestamptz)",
|
||||
"select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)",
|
||||
[]interface{}{
|
||||
int32(1),
|
||||
"cat",
|
||||
time.Date(2015, 1, 1, 8, 12, 42, 0, time.Local),
|
||||
time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue