diff --git a/.travis.yml b/.travis.yml index 314e26e8..b120b33a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ language: go go: - - 1.6.2 - - 1.5.2 + - 1.7.1 + - 1.6.3 - tip # Derived from https://github.com/lib/pq/blob/master/.travis.yml @@ -29,6 +29,8 @@ env: - PGVERSION=9.2 - PGVERSION=9.1 +# The tricky test user, below, has to actually exist so that it can be used in a test +# of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. before_script: - mv conn_config_test.go.travis conn_config_test.go - psql -U postgres -c 'create database pgx_test' @@ -36,6 +38,12 @@ before_script: - psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" + - psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" + +install: + - go get -u github.com/shopspring/decimal + - go get -u gopkg.in/inconshreveable/log15.v2 + - go get -u github.com/jackc/fake script: - go test -v -race -short ./... diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d185f2b..bedf106b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,26 @@ ## Fixes +* Oid underlying type changed to uint32, previously it was incorrectly int32 (Manni Wood) + +## Features + +* Add xid type support (Manni Wood) +* Add cid type support (Manni Wood) +* Add tid type support (Manni Wood) +* Add "char" type support (Manni Wood) +* Add NullOid type (Manni Wood) +* Add json/jsonb binary support to allow use with CopyTo +* Add named error ErrAcquireTimeout (Alexander Staubo) + +## Compatibility + +* jsonb now defaults to binary format. This means passing a []byte to a jsonb column will no longer work. + +# 2.9.0 (August 26, 2016) + +## Fixes + * Fix *ConnPool.Deallocate() not deleting prepared statement from map * Fix stdlib not logging unprepared query SQL (Krzysztof Dryƛ) * Fix Rows.Values() with varchar binary format @@ -9,12 +29,13 @@ ## Features +* Add CopyTo * Add PrepareEx * Add basic record to []interface{} decoding * Encode and decode between all Go and PostgreSQL integer types with bounds checking * Decode inet/cidr to net.IP * Encode/decode [][]byte to/from bytea[] -* Encode/decode named types whoses underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64 +* Encode/decode named types whose underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64 ## Performance diff --git a/README.md b/README.md index b2795fca..87b10797 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Pgx supports many additional features beyond what is available through database/ * Transaction isolation level control * Full TLS connection control * Binary format support for custom types (can be much faster) +* Copy protocol support for faster bulk data loads * Logging support * Configurable connection pool with after connect hooks to do arbitrary connection setup * PostgreSQL array to Go slice mapping for integers, floats, and strings @@ -60,16 +61,23 @@ skip tests for connection types that are not configured. ### Normal Test Environment -To setup the normal test environment run the following SQL: +To setup the normal test environment, first install these dependencies: + + go get github.com/jackc/fake + go get github.com/shopspring/decimal + go get gopkg.in/inconshreveable/log15.v2 + +Then run the following SQL: create user pgx_md5 password 'secret'; + create user " tricky, ' } "" \ test user " password 'secret'; create database pgx_test; Connect to database pgx_test and run: create extension hstore; -Next open connection_settings_test.go.example and make a copy without the +Next open conn_config_test.go.example and make a copy without the .example. If your PostgreSQL server is accepting connections on 127.0.0.1, then you are done. @@ -98,7 +106,8 @@ If you are developing on Windows with TCP connections: ## Version Policy -pgx follows semantic versioning for the documented public API. ```master``` -branch tracks the latest stable branch (```v2```). Consider using ```import -"gopkg.in/jackc/pgx.v2"``` to lock to the ```v2``` branch or use a vendoring -tool such as [godep](https://github.com/tools/godep). +pgx follows semantic versioning for the documented public API on stable releases. Branch ```v2``` is the latest stable release. ```master``` can contain new features or behavior that will change or be removed before being merged to the stable ```v2``` branch (in practice, this occurs very rarely). + +Consider using a vendoring +tool such as [godep](https://github.com/tools/godep) or importing pgx via ```import +"gopkg.in/jackc/pgx.v2"``` to lock to the ```v2``` branch. diff --git a/aclitem_parse_test.go b/aclitem_parse_test.go new file mode 100644 index 00000000..5c7c748f --- /dev/null +++ b/aclitem_parse_test.go @@ -0,0 +1,126 @@ +package pgx + +import ( + "reflect" + "testing" +) + +func TestEscapeAclItem(t *testing.T) { + tests := []struct { + input string + expected string + }{ + { + "foo", + "foo", + }, + { + `foo, "\}`, + `foo\, \"\\\}`, + }, + } + + for i, tt := range tests { + actual, err := escapeAclItem(tt.input) + + if err != nil { + t.Errorf("%d. Unexpected error %v", i, err) + } + + if actual != tt.expected { + t.Errorf("%d.\nexpected: %s,\nactual: %s", i, tt.expected, actual) + } + } +} + +func TestParseAclItemArray(t *testing.T) { + tests := []struct { + input string + expected []AclItem + errMsg string + }{ + { + "", + []AclItem{}, + "", + }, + { + "one", + []AclItem{"one"}, + "", + }, + { + `"one"`, + []AclItem{"one"}, + "", + }, + { + "one,two,three", + []AclItem{"one", "two", "three"}, + "", + }, + { + `"one","two","three"`, + []AclItem{"one", "two", "three"}, + "", + }, + { + `"one",two,"three"`, + []AclItem{"one", "two", "three"}, + "", + }, + { + `one,two,"three"`, + []AclItem{"one", "two", "three"}, + "", + }, + { + `"one","two",three`, + []AclItem{"one", "two", "three"}, + "", + }, + { + `"one","t w o",three`, + []AclItem{"one", "t w o", "three"}, + "", + }, + { + `"one","t, w o\"\}\\",three`, + []AclItem{"one", `t, w o"}\`, "three"}, + "", + }, + { + `"one","two",three"`, + []AclItem{"one", "two", `three"`}, + "", + }, + { + `"one","two,"three"`, + nil, + "unexpected rune after quoted value", + }, + { + `"one","two","three`, + nil, + "unexpected end of quoted value", + }, + } + + for i, tt := range tests { + actual, err := parseAclItemArray(tt.input) + + if err != nil { + if tt.errMsg == "" { + t.Errorf("%d. Unexpected error %v", i, err) + } else if err.Error() != tt.errMsg { + t.Errorf("%d. Expected error %v did not match actual error %v", i, tt.errMsg, err.Error()) + } + } else if tt.errMsg != "" { + t.Errorf("%d. Expected error not returned: \"%v\"", i, tt.errMsg) + } + + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("%d. Expected %v did not match actual %v", i, tt.expected, actual) + } + } +} diff --git a/bench_test.go b/bench_test.go index 99f65c2b..b08c2b4e 100644 --- a/bench_test.go +++ b/bench_test.go @@ -1,6 +1,9 @@ package pgx_test import ( + "bytes" + "fmt" + "strings" "testing" "time" @@ -397,3 +400,331 @@ func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) { } } } + +const benchmarkWriteTableCreateSQL = `drop table if exists t; + +create table t( + varchar_1 varchar not null, + varchar_2 varchar not null, + varchar_null_1 varchar, + date_1 date not null, + date_null_1 date, + int4_1 int4 not null, + int4_2 int4 not null, + int4_null_1 int4, + tstz_1 timestamptz not null, + tstz_2 timestamptz, + bool_1 bool not null, + bool_2 bool not null, + bool_3 bool not null +); +` + +const benchmarkWriteTableInsertSQL = `insert into t( + varchar_1, + varchar_2, + varchar_null_1, + date_1, + date_null_1, + int4_1, + int4_2, + int4_null_1, + tstz_1, + tstz_2, + bool_1, + bool_2, + bool_3 +) values ( + $1::varchar, + $2::varchar, + $3::varchar, + $4::date, + $5::date, + $6::int4, + $7::int4, + $8::int4, + $9::timestamptz, + $10::timestamptz, + $11::bool, + $12::bool, + $13::bool +)` + +type benchmarkWriteTableCopyToSrc struct { + count int + idx int + row []interface{} +} + +func (s *benchmarkWriteTableCopyToSrc) Next() bool { + s.idx++ + return s.idx < s.count +} + +func (s *benchmarkWriteTableCopyToSrc) Values() ([]interface{}, error) { + return s.row, nil +} + +func (s *benchmarkWriteTableCopyToSrc) Err() error { + return nil +} + +func newBenchmarkWriteTableCopyToSrc(count int) pgx.CopyToSource { + return &benchmarkWriteTableCopyToSrc{ + count: count, + row: []interface{}{ + "varchar_1", + "varchar_2", + pgx.NullString{}, + time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), + pgx.NullTime{}, + 1, + 2, + pgx.NullInt32{}, + time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local), + true, + false, + true, + }, + } +} + +func benchmarkWriteNRowsViaInsert(b *testing.B, n int) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + mustExec(b, conn, benchmarkWriteTableCreateSQL) + _, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + src := newBenchmarkWriteTableCopyToSrc(n) + + tx, err := conn.Begin() + if err != nil { + b.Fatal(err) + } + + for src.Next() { + values, _ := src.Values() + if _, err = tx.Exec("insert_t", values...); err != nil { + b.Fatalf("Exec unexpectedly failed with: %v", err) + } + } + + err = tx.Commit() + if err != nil { + b.Fatal(err) + } + } +} + +// note this function is only used for benchmarks -- it doesn't escape tableName +// or columnNames +func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc pgx.CopyToSource) (int, error) { + maxRowsPerInsert := 65535 / len(columnNames) + rowsThisInsert := 0 + rowCount := 0 + + sqlBuf := &bytes.Buffer{} + args := make(pgx.QueryArgs, 0) + + resetQuery := func() { + sqlBuf.Reset() + fmt.Fprintf(sqlBuf, "insert into %s(%s) values", tableName, strings.Join(columnNames, ", ")) + + args = args[0:0] + + rowsThisInsert = 0 + } + resetQuery() + + tx, err := conn.Begin() + if err != nil { + return 0, err + } + defer tx.Rollback() + + for rowSrc.Next() { + if rowsThisInsert > 0 { + sqlBuf.WriteByte(',') + } + + sqlBuf.WriteByte('(') + + values, err := rowSrc.Values() + if err != nil { + return 0, err + } + + for i, val := range values { + if i > 0 { + sqlBuf.WriteByte(',') + } + sqlBuf.WriteString(args.Append(val)) + } + + sqlBuf.WriteByte(')') + + rowsThisInsert++ + + if rowsThisInsert == maxRowsPerInsert { + _, err := tx.Exec(sqlBuf.String(), args...) + if err != nil { + return 0, err + } + + rowCount += rowsThisInsert + resetQuery() + } + } + + if rowsThisInsert > 0 { + _, err := tx.Exec(sqlBuf.String(), args...) + if err != nil { + return 0, err + } + + rowCount += rowsThisInsert + } + + if err := tx.Commit(); err != nil { + return 0, nil + } + + return rowCount, nil + +} + +func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + mustExec(b, conn, benchmarkWriteTableCreateSQL) + _, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + src := newBenchmarkWriteTableCopyToSrc(n) + + _, err := multiInsert(conn, "t", + []string{"varchar_1", + "varchar_2", + "varchar_null_1", + "date_1", + "date_null_1", + "int4_1", + "int4_2", + "int4_null_1", + "tstz_1", + "tstz_2", + "bool_1", + "bool_2", + "bool_3"}, + src) + if err != nil { + b.Fatal(err) + } + } +} + +func benchmarkWriteNRowsViaCopy(b *testing.B, n int) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + mustExec(b, conn, benchmarkWriteTableCreateSQL) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + src := newBenchmarkWriteTableCopyToSrc(n) + + _, err := conn.CopyTo("t", + []string{"varchar_1", + "varchar_2", + "varchar_null_1", + "date_1", + "date_null_1", + "int4_1", + "int4_2", + "int4_null_1", + "tstz_1", + "tstz_2", + "bool_1", + "bool_2", + "bool_3"}, + src) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkWrite5RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 5) +} + +func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 5) +} + +func BenchmarkWrite5RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 5) +} + +func BenchmarkWrite10RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 10) +} + +func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 10) +} + +func BenchmarkWrite10RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 10) +} + +func BenchmarkWrite100RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 100) +} + +func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 100) +} + +func BenchmarkWrite100RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 100) +} + +func BenchmarkWrite1000RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 1000) +} + +func BenchmarkWrite1000RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 1000) +} + +func BenchmarkWrite1000RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 1000) +} + +func BenchmarkWrite10000RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 10000) +} + +func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 10000) +} + +func BenchmarkWrite10000RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 10000) +} diff --git a/conn.go b/conn.go index 752c3ddd..8fa4d400 100644 --- a/conn.go +++ b/conn.go @@ -63,8 +63,8 @@ type Conn struct { logLevel int mr msgReader fp *fastpath - pgsql_af_inet *byte - pgsql_af_inet6 *byte + pgsqlAfInet *byte + pgsqlAfInet6 *byte busy bool poolResetCount int preallocatedRows []Rows @@ -145,7 +145,7 @@ func Connect(config ConnConfig) (c *Conn, err error) { return connect(config, nil, nil, nil) } -func connect(config ConnConfig, pgTypes map[OID]PgType, pgsql_af_inet *byte, pgsql_af_inet6 *byte) (c *Conn, err error) { +func connect(config ConnConfig, pgTypes map[OID]PgType, pgsqlAfInet *byte, pgsqlAfInet6 *byte) (c *Conn, err error) { c = new(Conn) c.config = config @@ -157,13 +157,13 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsql_af_inet *byte, pgs } } - if pgsql_af_inet != nil { - c.pgsql_af_inet = new(byte) - *c.pgsql_af_inet = *pgsql_af_inet + if pgsqlAfInet != nil { + c.pgsqlAfInet = new(byte) + *c.pgsqlAfInet = *pgsqlAfInet } - if pgsql_af_inet6 != nil { - c.pgsql_af_inet6 = new(byte) - *c.pgsql_af_inet6 = *pgsql_af_inet6 + if pgsqlAfInet6 != nil { + c.pgsqlAfInet6 = new(byte) + *c.pgsqlAfInet6 = *pgsqlAfInet6 } if c.config.LogLevel != 0 { @@ -209,12 +209,21 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsql_af_inet *byte, pgs c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial } + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address)) + } err = c.connect(config, network, address, config.TLSConfig) if err != nil && config.UseFallbackTLS { + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, fmt.Sprintf("Connect with TLSConfig failed, trying FallbackTLSConfig: %v", err)) + } err = c.connect(config, network, address, config.FallbackTLSConfig) } if err != nil { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, fmt.Sprintf("Connect failed: %v", err)) + } return nil, err } @@ -222,23 +231,14 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsql_af_inet *byte, pgs } func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address)) - } c.conn, err = c.config.Dial(network, address) if err != nil { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, fmt.Sprintf("Connection failed: %v", err)) - } return err } defer func() { if c != nil && err != nil { c.conn.Close() c.alive = false - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, err.Error()) - } } }() @@ -253,9 +253,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.log(LogLevelDebug, "Starting TLS handshake") } if err := c.startTLS(tlsConfig); err != nil { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, fmt.Sprintf("TLS failed: %v", err)) - } return err } } @@ -315,7 +312,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - if c.pgsql_af_inet == nil || c.pgsql_af_inet6 == nil { + if c.pgsqlAfInet == nil || c.pgsqlAfInet6 == nil { err = c.loadInetConstants() if err != nil { return err @@ -372,8 +369,8 @@ func (c *Conn) loadInetConstants() error { return err } - c.pgsql_af_inet = &ipv4[0] - c.pgsql_af_inet6 = &ipv6[0] + c.pgsqlAfInet = &ipv4[0] + c.pgsqlAfInet6 = &ipv6[0] return nil } @@ -430,7 +427,7 @@ func ParseURI(uri string) (ConnConfig, error) { } ignoreKeys := map[string]struct{}{ - "sslmode": struct{}{}, + "sslmode": {}, } cp.RuntimeParams = make(map[string]string) @@ -446,7 +443,7 @@ func ParseURI(uri string) (ConnConfig, error) { return cp, nil } -var dsn_regexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) +var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) // ParseDSN parses a database DSN (data source name) into a ConnConfig // @@ -462,7 +459,7 @@ var dsn_regexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) func ParseDSN(s string) (ConnConfig, error) { var cp ConnConfig - m := dsn_regexp.FindAllStringSubmatch(s, -1) + m := dsnRegexp.FindAllStringSubmatch(s, -1) var sslmode string @@ -477,11 +474,11 @@ func ParseDSN(s string) (ConnConfig, error) { case "host": cp.Host = b[2] case "port": - if p, err := strconv.ParseUint(b[2], 10, 16); err != nil { + p, err := strconv.ParseUint(b[2], 10, 16) + if err != nil { return cp, err - } else { - cp.Port = uint16(p) } + cp.Port = uint16(p) case "dbname": cp.Database = b[2] case "sslmode": @@ -627,7 +624,7 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared if opts != nil { if len(opts.ParameterOIDs) > 65535 { - return nil, errors.New(fmt.Sprintf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs))) + return nil, fmt.Errorf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs)) } wbuf.WriteInt16(int16(len(opts.ParameterOIDs))) for _, oid := range opts.ParameterOIDs { @@ -917,7 +914,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOID, ByteaOID, Int2OID, Int4OID, Int8OID, Float4OID, Float8OID, TimestampTzOID, TimestampTzArrayOID, TimestampOID, TimestampArrayOID, DateOID, BoolArrayOID, ByteaArrayOID, Int2ArrayOID, Int4ArrayOID, Int8ArrayOID, Float4ArrayOID, Float8ArrayOID, TextArrayOID, VarcharArrayOID, OIDOID, InetOID, CidrOID, InetArrayOID, CidrArrayOID, RecordOID: + case BoolOID, ByteaOID, Int2OID, Int4OID, Int8OID, Float4OID, Float8OID, TimestampTzOID, TimestampTzArrayOID, TimestampOID, TimestampArrayOID, DateOID, BoolArrayOID, ByteaArrayOID, Int2ArrayOID, Int4ArrayOID, Int8ArrayOID, Float4ArrayOID, Float8ArrayOID, TextArrayOID, VarcharArrayOID, OIDOID, InetOID, CidrOID, InetArrayOID, CidrArrayOID, RecordOID, JSONOID, JSONBOID: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 358e0247..0b80d490 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -11,7 +11,6 @@ var tcpConnConfig *pgx.ConnConfig = nil var unixSocketConnConfig *pgx.ConnConfig = nil var md5ConnConfig *pgx.ConnConfig = nil var plainPasswordConnConfig *pgx.ConnConfig = nil -var noPasswordConnConfig *pgx.ConnConfig = nil var invalidUserConnConfig *pgx.ConnConfig = nil var tlsConnConfig *pgx.ConnConfig = nil var customDialerConnConfig *pgx.ConnConfig = nil @@ -20,7 +19,6 @@ var customDialerConnConfig *pgx.ConnConfig = nil // var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"} // var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"} -// var noPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_none", Database: "pgx_test"} // var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"} // var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} // var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_pool.go b/conn_pool.go index 1627af10..67868769 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -28,8 +28,10 @@ type ConnPool struct { preparedStatements map[string]*PreparedStatement acquireTimeout time.Duration pgTypes map[OID]PgType - pgsql_af_inet *byte - pgsql_af_inet6 *byte + pgsqlAfInet *byte + pgsqlAfInet6 *byte + txAfterClose func(tx *Tx) + rowsAfterClose func(rows *Rows) } type ConnPoolStat struct { @@ -38,6 +40,9 @@ type ConnPoolStat struct { AvailableConnections int // unused live connections } +// ErrAcquireTimeout occurs when an attempt to acquire a connection times out. +var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool") + // NewConnPool creates a new ConnPool. config.ConnConfig is passed through to // Connect directly. func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { @@ -68,6 +73,14 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { p.logLevel = LogLevelNone } + p.txAfterClose = func(tx *Tx) { + p.Release(tx.Conn()) + } + + p.rowsAfterClose = func(rows *Rows) { + p.Release(rows.Conn()) + } + p.allConnections = make([]*Conn, 0, p.maxConnections) p.availableConnections = make([]*Conn, 0, p.maxConnections) p.preparedStatements = make(map[string]*PreparedStatement) @@ -121,7 +134,7 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { // Make sure the deadline (if it is) has not passed yet if p.deadlinePassed(deadline) { - return nil, errors.New("Timeout: Acquire connection timeout") + return nil, ErrAcquireTimeout } // If there is a deadline then start a timeout timer @@ -138,26 +151,25 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { // Create a new connection. // Careful here: createConnectionUnlocked() removes the current lock, // creates a connection and then locks it back. - if c, err := p.createConnectionUnlocked(); err == nil { - c.poolResetCount = p.resetCount - p.allConnections = append(p.allConnections, c) - return c, nil - } else { + c, err := p.createConnectionUnlocked() + if err != nil { return nil, err } - } else { - // All connections are in use and we cannot create more - if p.logLevel >= LogLevelWarn { - p.logger.Log(LogLevelWarn, "All connections in pool are busy - waiting...") - } + c.poolResetCount = p.resetCount + p.allConnections = append(p.allConnections, c) + return c, nil + } + // All connections are in use and we cannot create more + if p.logLevel >= LogLevelWarn { + p.logger.Log(LogLevelWarn, "All connections in pool are busy - waiting...") + } - // Wait until there is an available connection OR room to create a new connection - for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections { - if p.deadlinePassed(deadline) { - return nil, errors.New("Timeout: All connections in pool are busy") - } - p.cond.Wait() + // Wait until there is an available connection OR room to create a new connection + for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections { + if p.deadlinePassed(deadline) { + return nil, ErrAcquireTimeout } + p.cond.Wait() } // Stop the timer so that we do not spawn it on every acquire call. @@ -272,7 +284,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) { } func (p *ConnPool) createConnection() (*Conn, error) { - c, err := connect(p.config, p.pgTypes, p.pgsql_af_inet, p.pgsql_af_inet6) + c, err := connect(p.config, p.pgTypes, p.pgsqlAfInet, p.pgsqlAfInet6) if err != nil { return nil, err } @@ -308,8 +320,8 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { // all the known statements for the new connection. func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { p.pgTypes = c.PgTypes - p.pgsql_af_inet = c.pgsql_af_inet - p.pgsql_af_inet6 = c.pgsql_af_inet6 + p.pgsqlAfInet = c.pgsqlAfInet + p.pgsqlAfInet6 = c.pgsqlAfInet6 if p.afterConnect != nil { err := p.afterConnect(c) @@ -487,10 +499,13 @@ func (p *ConnPool) BeginIso(iso string) (*Tx, error) { } } -func (p *ConnPool) txAfterClose(tx *Tx) { - p.Release(tx.Conn()) -} +// CopyTo acquires a connection, delegates the call to that connection, and releases the connection +func (p *ConnPool) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { + c, err := p.Acquire() + if err != nil { + return 0, err + } + defer p.Release(c) -func (p *ConnPool) rowsAfterClose(rows *Rows) { - p.Release(rows.Conn()) + return c.CopyTo(tableName, columnNames, rowSrc) } diff --git a/conn_pool_test.go b/conn_pool_test.go index 773a0272..db8702fb 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -40,7 +40,7 @@ func releaseAllConnections(pool *pgx.ConnPool, connections []*pgx.Conn) { func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) { startTime := time.Now() c, err := pool.Acquire() - return c, time.Now().Sub(startTime), err + return c, time.Since(startTime), err } func TestNewConnPool(t *testing.T) { @@ -215,7 +215,7 @@ func TestPoolNonBlockingConnections(t *testing.T) { // Prior to createConnectionUnlocked() use the test took // maxConnections * openTimeout seconds to complete. // With createConnectionUnlocked() it takes ~ 1 * openTimeout seconds. - timeTaken := time.Now().Sub(startedAt) + timeTaken := time.Since(startedAt) if timeTaken > openTimeout+1*time.Second { t.Fatalf("Expected all Acquire() to run in parallel and take about %v, instead it took '%v'", openTimeout, timeTaken) } @@ -276,8 +276,8 @@ func TestPoolWithAcquireTimeoutSet(t *testing.T) { // ... then try to consume 1 more. It should fail after a short timeout. _, timeTaken, err := acquireWithTimeTaken(pool) - if err == nil || err.Error() != "Timeout: All connections in pool are busy" { - t.Fatalf("Expected error to be 'Timeout: All connections in pool are busy', instead it was '%v'", err) + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) } if timeTaken < connAllocTimeout { t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken) @@ -366,12 +366,12 @@ func TestPoolAcquireAndReleaseCycleAutoConnect(t *testing.T) { if err != nil { t.Fatalf("Unable to Acquire: %v", err) } - rows, _ := c.Query("select 1") + rows, _ := c.Query("select 1, pg_sleep(0.02)") rows.Close() pool.Release(c) } - for i := 0; i < 1000; i++ { + for i := 0; i < 10; i++ { doSomething() } @@ -381,7 +381,7 @@ func TestPoolAcquireAndReleaseCycleAutoConnect(t *testing.T) { } var wg sync.WaitGroup - for i := 0; i < 1000; i++ { + for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() diff --git a/conn_test.go b/conn_test.go index 60f762ed..10f40552 100644 --- a/conn_test.go +++ b/conn_test.go @@ -914,7 +914,7 @@ func TestPrepareQueryManyParameters(t *testing.T) { args := make([]interface{}, 0, tt.count) for j := 0; j < tt.count; j++ { params = append(params, fmt.Sprintf("($%d::text)", j+1)) - args = append(args, strconv.FormatInt(int64(j), 10)) + args = append(args, strconv.Itoa(j)) } sql := "values" + strings.Join(params, ", ") diff --git a/copy_to.go b/copy_to.go new file mode 100644 index 00000000..91292bb0 --- /dev/null +++ b/copy_to.go @@ -0,0 +1,241 @@ +package pgx + +import ( + "bytes" + "fmt" +) + +// CopyToRows returns a CopyToSource interface over the provided rows slice +// making it usable by *Conn.CopyTo. +func CopyToRows(rows [][]interface{}) CopyToSource { + return ©ToRows{rows: rows, idx: -1} +} + +type copyToRows struct { + rows [][]interface{} + idx int +} + +func (ctr *copyToRows) Next() bool { + ctr.idx++ + return ctr.idx < len(ctr.rows) +} + +func (ctr *copyToRows) Values() ([]interface{}, error) { + return ctr.rows[ctr.idx], nil +} + +func (ctr *copyToRows) Err() error { + return nil +} + +// CopyToSource is the interface used by *Conn.CopyTo as the source for copy data. +type CopyToSource interface { + // Next returns true if there is another row and makes the next row data + // available to Values(). When there are no more rows available or an error + // has occurred it returns false. + Next() bool + + // Values returns the values for the current row. + Values() ([]interface{}, error) + + // Err returns any error that has been encountered by the CopyToSource. If + // this is not nil *Conn.CopyTo will abort the copy. + Err() error +} + +type copyTo struct { + conn *Conn + tableName string + columnNames []string + rowSrc CopyToSource + readerErrChan chan error +} + +func (ct *copyTo) readUntilReadyForQuery() { + for { + t, r, err := ct.conn.rxMsg() + if err != nil { + ct.readerErrChan <- err + close(ct.readerErrChan) + return + } + + switch t { + case readyForQuery: + ct.conn.rxReadyForQuery(r) + close(ct.readerErrChan) + return + case commandComplete: + case errorResponse: + ct.readerErrChan <- ct.conn.rxErrorResponse(r) + default: + err = ct.conn.processContextFreeMsg(t, r) + if err != nil { + ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r) + } + } + } +} + +func (ct *copyTo) waitForReaderDone() error { + var err error + for err = range ct.readerErrChan { + } + return err +} + +func (ct *copyTo) run() (int, error) { + quotedTableName := quoteIdentifier(ct.tableName) + buf := &bytes.Buffer{} + for i, cn := range ct.columnNames { + if i != 0 { + buf.WriteString(", ") + } + buf.WriteString(quoteIdentifier(cn)) + } + quotedColumnNames := buf.String() + + ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) + if err != nil { + return 0, err + } + + err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) + if err != nil { + return 0, err + } + + err = ct.conn.readUntilCopyInResponse() + if err != nil { + return 0, err + } + + go ct.readUntilReadyForQuery() + defer ct.waitForReaderDone() + + wbuf := newWriteBuf(ct.conn, copyData) + + wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000")) + wbuf.WriteInt32(0) + wbuf.WriteInt32(0) + + var sentCount int + + for ct.rowSrc.Next() { + select { + case err = <-ct.readerErrChan: + return 0, err + default: + } + + if len(wbuf.buf) > 65536 { + wbuf.closeMsg() + _, err = ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return 0, err + } + + // Directly manipulate wbuf to reset to reuse the same buffer + wbuf.buf = wbuf.buf[0:5] + wbuf.buf[0] = copyData + wbuf.sizeIdx = 1 + } + + sentCount++ + + values, err := ct.rowSrc.Values() + if err != nil { + ct.cancelCopyIn() + return 0, err + } + if len(values) != len(ct.columnNames) { + ct.cancelCopyIn() + return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) + } + + wbuf.WriteInt16(int16(len(ct.columnNames))) + for i, val := range values { + err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val) + if err != nil { + ct.cancelCopyIn() + return 0, err + } + + } + } + + if ct.rowSrc.Err() != nil { + ct.cancelCopyIn() + return 0, ct.rowSrc.Err() + } + + wbuf.WriteInt16(-1) // terminate the copy stream + + wbuf.startMsg(copyDone) + wbuf.closeMsg() + _, err = ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return 0, err + } + + err = ct.waitForReaderDone() + if err != nil { + return 0, err + } + return sentCount, nil +} + +func (c *Conn) readUntilCopyInResponse() error { + for { + var t byte + var r *msgReader + t, r, err := c.rxMsg() + if err != nil { + return err + } + + switch t { + case copyInResponse: + return nil + default: + err = c.processContextFreeMsg(t, r) + if err != nil { + return err + } + } + } +} + +func (ct *copyTo) cancelCopyIn() error { + wbuf := newWriteBuf(ct.conn, copyFail) + wbuf.WriteCString("client error: abort") + wbuf.closeMsg() + _, err := ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return err + } + + return nil +} + +// CopyTo uses the PostgreSQL copy protocol to perform bulk data insertion. +// It returns the number of rows copied and an error. +// +// CopyTo requires all values use the binary format. Almost all types +// implemented by pgx use the binary format by default. Types implementing +// Encoder can only be used if they encode to the binary format. +func (c *Conn) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { + ct := ©To{ + conn: c, + tableName: tableName, + columnNames: columnNames, + rowSrc: rowSrc, + readerErrChan: make(chan error), + } + + return ct.run() +} diff --git a/copy_to_test.go b/copy_to_test.go new file mode 100644 index 00000000..43cb5acc --- /dev/null +++ b/copy_to_test.go @@ -0,0 +1,428 @@ +package pgx_test + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/jackc/pgx" +) + +func TestConnCopyToSmall(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz + )`) + + inputRows := [][]interface{}{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyToRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyTo: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToLarge(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz, + h bytea + )`) + + inputRows := [][]interface{}{} + + for i := 0; i < 10000; i++ { + inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyToRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyTo: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal") + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToJSON(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + for _, oid := range []pgx.OID{pgx.JSONOID, pgx.JSONBOID} { + if _, ok := conn.PgTypes[oid]; !ok { + return // No JSON/JSONB type -- must be running against old PostgreSQL + } + } + + mustExec(t, conn, `create temporary table foo( + a json, + b jsonb + )`) + + inputRows := [][]interface{}{ + {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, + {nil, nil}, + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyTo: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToFailServerSideMidway(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int4, + b varchar not null + )`) + + inputRows := [][]interface{}{ + {int32(1), "abc"}, + {int32(2), nil}, // this row should trigger a failure + {int32(3), "def"}, + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows)) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type failSource struct { + count int +} + +func (fs *failSource) Next() bool { + time.Sleep(time.Millisecond * 100) + fs.count++ + return fs.count < 100 +} + +func (fs *failSource) Values() ([]interface{}, error) { + if fs.count == 3 { + return []interface{}{nil}, nil + } + return []interface{}{make([]byte, 100000)}, nil +} + +func (fs *failSource) Err() error { + return nil +} + +func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + startTime := time.Now() + + copyCount, err := conn.CopyTo("foo", []string{"a"}, &failSource{}) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + endTime := time.Now() + copyTime := endTime.Sub(startTime) + if copyTime > time.Second { + t.Errorf("Failing CopyTo shouldn't have taken so long: %v", copyTime) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type clientFailSource struct { + count int + err error +} + +func (cfs *clientFailSource) Next() bool { + cfs.count++ + return cfs.count < 100 +} + +func (cfs *clientFailSource) Values() ([]interface{}, error) { + if cfs.count == 3 { + cfs.err = fmt.Errorf("client error") + return nil, cfs.err + } + return []interface{}{make([]byte, 100000)}, nil +} + +func (cfs *clientFailSource) Err() error { + return cfs.err +} + +func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFailSource{}) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type clientFinalErrSource struct { + count int +} + +func (cfs *clientFinalErrSource) Next() bool { + cfs.count++ + return cfs.count < 5 +} + +func (cfs *clientFinalErrSource) Values() ([]interface{}, error) { + return []interface{}{make([]byte, 100000)}, nil +} + +func (cfs *clientFinalErrSource) Err() error { + return fmt.Errorf("final error") +} + +func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFinalErrSource{}) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} diff --git a/doc.go b/doc.go index 7964aa82..248c7e26 100644 --- a/doc.go +++ b/doc.go @@ -50,7 +50,7 @@ pgx also implements QueryRow in the same style as database/sql. return err } -Use exec to execute a query that does not return a result set. +Use Exec to execute a query that does not return a result set. commandTag, err := conn.Exec("delete from widgets where id=$1", 42) if err != nil { @@ -81,43 +81,39 @@ releasing connections when you do not need that level of control. return err } -Transactions +Base Type Mapping -Transactions are started by calling Begin or BeginIso. The BeginIso variant -creates a transaction with a specified isolation level. +pgx maps between all common base types directly between Go and PostgreSQL. In +particular: - tx, err := conn.Begin() - if err != nil { - return err - } - // Rollback is safe to call even if the tx is already closed, so if - // the tx commits successfully, this is a no-op - defer tx.Rollback() + Go PostgreSQL + ----------------------- + string varchar + text - _, err = tx.Exec("insert into foo(id) values (1)") - if err != nil { - return err - } + // Integers are automatically be converted to any other integer type if + // it can be done without overflow or underflow. + int8 + int16 smallint + int32 int + int64 bigint + int + uint8 + uint16 + uint32 + uint64 + uint - err = tx.Commit() - if err != nil { - return err - } + // Floats are strict and do not automatically convert like integers. + float32 float4 + float64 float8 -Listen and Notify + time.Time date + timestamp + timestamptz -pgx can listen to the PostgreSQL notification system with the -WaitForNotification function. It takes a maximum time to wait for a -notification. + []byte bytea - err := conn.Listen("channelname") - if err != nil { - return nil - } - - if notification, err := conn.WaitForNotification(time.Second); err != nil { - // do something with notification - } Null Mapping @@ -136,7 +132,7 @@ Array Mapping pgx maps between int16, int32, int64, float32, float64, and string Go slices and the equivalent PostgreSQL array type. Go slices of native types do not -support nulls, so if a PostgreSQL array that contains a slice is read into a +support nulls, so if a PostgreSQL array that contains a null is read into a native Go slice an error will occur. Hstore Mapping @@ -192,6 +188,64 @@ the raw bytes returned by PostgreSQL. This can be especially useful for reading varchar, text, json, and jsonb values directly into a []byte and avoiding the type conversion from string. +Transactions + +Transactions are started by calling Begin or BeginIso. The BeginIso variant +creates a transaction with a specified isolation level. + + tx, err := conn.Begin() + if err != nil { + return err + } + // Rollback is safe to call even if the tx is already closed, so if + // the tx commits successfully, this is a no-op + defer tx.Rollback() + + _, err = tx.Exec("insert into foo(id) values (1)") + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + +Copy Protocol + +Use CopyTo to efficiently insert multiple rows at a time using the PostgreSQL +copy protocol. CopyTo accepts a CopyToSource interface. If the data is already +in a [][]interface{} use CopyToRows to wrap it in a CopyToSource interface. Or +implement CopyToSource to avoid buffering the entire data set in memory. + + rows := [][]interface{}{ + {"John", "Smith", int32(36)}, + {"Jane", "Doe", int32(29)}, + } + + copyCount, err := conn.CopyTo( + "people", + []string{"first_name", "last_name", "age"}, + pgx.CopyToRows(rows), + ) + +CopyTo can be faster than an insert with as few as 5 rows. + +Listen and Notify + +pgx can listen to the PostgreSQL notification system with the +WaitForNotification function. It takes a maximum time to wait for a +notification. + + err := conn.Listen("channelname") + if err != nil { + return nil + } + + if notification, err := conn.WaitForNotification(time.Second); err != nil { + // do something with notification + } + TLS The pgx ConnConfig struct has a TLSConfig field. If this field is diff --git a/fastpath.go b/fastpath.go index a1212e1b..28f88d5e 100644 --- a/fastpath.go +++ b/fastpath.go @@ -4,8 +4,6 @@ import ( "encoding/binary" ) -type fastpathArg []byte - func newFastpath(cn *Conn) *fastpath { return &fastpath{cn: cn, fns: make(map[string]OID)} } diff --git a/hstore.go b/hstore.go index a5d40cce..0ab9f779 100644 --- a/hstore.go +++ b/hstore.go @@ -15,7 +15,6 @@ const ( hsVal hsNul hsNext - hsEnd ) type hstoreParser struct { diff --git a/hstore_test.go b/hstore_test.go index dba5206b..c948f0cd 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -85,23 +85,23 @@ func TestNullHstoreTranscode(t *testing.T) { {pgx.NullHstore{}, "null"}, {pgx.NullHstore{Valid: true}, "empty"}, {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}}, Valid: true}, "single key/value"}, {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar", Valid: true}, "baz": pgx.NullString{String: "quz", Valid: true}}, + Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}, "baz": {String: "quz", Valid: true}}, Valid: true}, "multiple key/values"}, {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"NULL": pgx.NullString{String: "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{"NULL": {String: "bar", Valid: true}}, Valid: true}, `string "NULL" key`}, {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "NULL", Valid: true}}, + Hstore: map[string]pgx.NullString{"foo": {String: "NULL", Valid: true}}, Valid: true}, `string "NULL" value`}, {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "", Valid: false}}, + Hstore: map[string]pgx.NullString{"foo": {String: "", Valid: false}}, Valid: true}, `NULL value`}, } @@ -120,36 +120,36 @@ func TestNullHstoreTranscode(t *testing.T) { } for _, sst := range specialStringTests { tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{sst.input + "foo": pgx.NullString{String: "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{sst.input + "foo": {String: "bar", Valid: true}}, Valid: true}, "key with " + sst.description + " at beginning"}) tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": pgx.NullString{String: "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": {String: "bar", Valid: true}}, Valid: true}, "key with " + sst.description + " in middle"}) tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo" + sst.input: pgx.NullString{String: "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{"foo" + sst.input: {String: "bar", Valid: true}}, Valid: true}, "key with " + sst.description + " at end"}) tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{sst.input: pgx.NullString{String: "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{sst.input: {String: "bar", Valid: true}}, Valid: true}, "key is " + sst.description}) tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: sst.input + "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{"foo": {String: sst.input + "bar", Valid: true}}, Valid: true}, "value with " + sst.description + " at beginning"}) tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar" + sst.input + "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input + "bar", Valid: true}}, Valid: true}, "value with " + sst.description + " in middle"}) tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar" + sst.input, Valid: true}}, + Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input, Valid: true}}, Valid: true}, "value with " + sst.description + " at end"}) tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: sst.input, Valid: true}}, + Hstore: map[string]pgx.NullString{"foo": {String: sst.input, Valid: true}}, Valid: true}, "value is " + sst.description}) } diff --git a/messages.go b/messages.go index e4bdfb2c..c2964b82 100644 --- a/messages.go +++ b/messages.go @@ -25,6 +25,10 @@ const ( noData = 'n' closeComplete = '3' flush = 'H' + copyInResponse = 'G' + copyData = 'd' + copyFail = 'f' + copyDone = 'c' ) type startupMessage struct { @@ -35,10 +39,10 @@ func newStartupMessage() *startupMessage { return &startupMessage{map[string]string{}} } -func (self *startupMessage) Bytes() (buf []byte) { +func (s *startupMessage) Bytes() (buf []byte) { buf = make([]byte, 8, 128) binary.BigEndian.PutUint32(buf[4:8], uint32(protocolVersionNumber)) - for key, value := range self.options { + for key, value := range s.options { buf = append(buf, key...) buf = append(buf, 0) buf = append(buf, value...) @@ -49,8 +53,6 @@ func (self *startupMessage) Bytes() (buf []byte) { return buf } -type OID int32 - type FieldDescription struct { Name string Table OID @@ -85,8 +87,8 @@ type PgError struct { Routine string } -func (self PgError) Error() string { - return self.Severity + ": " + self.Message + " (SQLSTATE " + self.Code + ")" +func (pe PgError) Error() string { + return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } func newWriteBuf(c *Conn, t byte) *WriteBuf { @@ -95,7 +97,7 @@ func newWriteBuf(c *Conn, t byte) *WriteBuf { return &c.writeBuf } -// WrifeBuf is used build messages to send to the PostgreSQL server. It is used +// WriteBuf is used build messages to send to the PostgreSQL server. It is used // by the Encoder interface when implementing custom encoders. type WriteBuf struct { buf []byte @@ -128,12 +130,24 @@ func (wb *WriteBuf) WriteInt16(n int16) { wb.buf = append(wb.buf, b...) } +func (wb *WriteBuf) WriteUint16(n uint16) { + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, n) + wb.buf = append(wb.buf, b...) +} + func (wb *WriteBuf) WriteInt32(n int32) { b := make([]byte, 4) binary.BigEndian.PutUint32(b, uint32(n)) wb.buf = append(wb.buf, b...) } +func (wb *WriteBuf) WriteUint32(n uint32) { + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, n) + wb.buf = append(wb.buf, b...) +} + func (wb *WriteBuf) WriteInt64(n int64) { b := make([]byte, 8) binary.BigEndian.PutUint64(b, uint64(n)) diff --git a/msg_reader.go b/msg_reader.go index b5848946..43e80d98 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -10,7 +10,6 @@ import ( // msgReader is a helper that reads values from a PostgreSQL message. type msgReader struct { reader *bufio.Reader - buf [128]byte msgBytesRemaining int32 err error log func(lvl int, msg string, ctx ...interface{}) @@ -47,10 +46,15 @@ func (r *msgReader) rxMsg() (byte, error) { } } - b := r.buf[0:5] - _, err := io.ReadFull(r.reader, b) + b, err := r.reader.Peek(5) + if err != nil { + r.fatal(err) + return 0, err + } + msgType := b[0] r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4 - return b[0], err + r.reader.Discard(5) + return msgType, nil } func (r *msgReader) readByte() byte { @@ -58,7 +62,7 @@ func (r *msgReader) readByte() byte { return 0 } - r.msgBytesRemaining -= 1 + r.msgBytesRemaining-- if r.msgBytesRemaining < 0 { r.fatal(errors.New("read past end of message")) return 0 @@ -88,8 +92,7 @@ func (r *msgReader) readInt16() int16 { return 0 } - b := r.buf[0:2] - _, err := io.ReadFull(r.reader, b) + b, err := r.reader.Peek(2) if err != nil { r.fatal(err) return 0 @@ -97,6 +100,8 @@ func (r *msgReader) readInt16() int16 { n := int16(binary.BigEndian.Uint16(b)) + r.reader.Discard(2) + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) } @@ -115,8 +120,7 @@ func (r *msgReader) readInt32() int32 { return 0 } - b := r.buf[0:4] - _, err := io.ReadFull(r.reader, b) + b, err := r.reader.Peek(4) if err != nil { r.fatal(err) return 0 @@ -124,6 +128,8 @@ func (r *msgReader) readInt32() int32 { n := int32(binary.BigEndian.Uint32(b)) + r.reader.Discard(4) + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) } @@ -131,6 +137,62 @@ func (r *msgReader) readInt32() int32 { return n } +func (r *msgReader) readUint16() uint16 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 2 + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return 0 + } + + b, err := r.reader.Peek(2) + if err != nil { + r.fatal(err) + return 0 + } + + n := uint16(binary.BigEndian.Uint16(b)) + + r.reader.Discard(2) + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + } + + return n +} + +func (r *msgReader) readUint32() uint32 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 4 + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return 0 + } + + b, err := r.reader.Peek(4) + if err != nil { + r.fatal(err) + return 0 + } + + n := uint32(binary.BigEndian.Uint32(b)) + + r.reader.Discard(4) + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + } + + return n +} + func (r *msgReader) readInt64() int64 { if r.err != nil { return 0 @@ -142,8 +204,7 @@ func (r *msgReader) readInt64() int64 { return 0 } - b := r.buf[0:8] - _, err := io.ReadFull(r.reader, b) + b, err := r.reader.Peek(8) if err != nil { r.fatal(err) return 0 @@ -151,6 +212,8 @@ func (r *msgReader) readInt64() int64 { n := int64(binary.BigEndian.Uint64(b)) + r.reader.Discard(8) + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining) } @@ -190,32 +253,34 @@ func (r *msgReader) readCString() string { } // readString reads count bytes and returns as string -func (r *msgReader) readString(count int32) string { +func (r *msgReader) readString(countI32 int32) string { if r.err != nil { return "" } - r.msgBytesRemaining -= count + r.msgBytesRemaining -= countI32 if r.msgBytesRemaining < 0 { r.fatal(errors.New("read past end of message")) return "" } - var b []byte - if count <= int32(len(r.buf)) { - b = r.buf[0:int(count)] + count := int(countI32) + var s string + + if r.reader.Buffered() >= count { + buf, _ := r.reader.Peek(count) + s = string(buf) + r.reader.Discard(count) } else { - b = make([]byte, int(count)) + buf := make([]byte, count) + _, err := io.ReadFull(r.reader, buf) + if err != nil { + r.fatal(err) + return "" + } + s = string(buf) } - _, err := io.ReadFull(r.reader, b) - if err != nil { - r.fatal(err) - return "" - } - - s := string(b) - if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) } diff --git a/query.go b/query.go index 34035794..30e0476e 100644 --- a/query.go +++ b/query.go @@ -298,13 +298,17 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } - } else if vr.Type().DataType == JSONOID || vr.Type().DataType == JSONBOID { + } else if vr.Type().DataType == JSONOID { // Because the argument passed to decodeJSON will escape the heap. // This allows d to be stack allocated and only copied to the heap when // we actually are decoding JSON. This saves one memory allocation per // row. d2 := d decodeJSON(vr, &d2) + } else if vr.Type().DataType == JSONBOID { + // Same trick as above for getting stack allocation + d2 := d + decodeJSONB(vr, &d2) } else { if err := Decode(vr, d); err != nil { rows.Fatal(scanArgError{col: i, err: err}) @@ -393,7 +397,7 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, d) case JSONBOID: var d interface{} - decodeJSON(vr, &d) + decodeJSONB(vr, &d) values = append(values, d) default: rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) diff --git a/query_test.go b/query_test.go index 21496c19..791b65cc 100644 --- a/query_test.go +++ b/query_test.go @@ -3,11 +3,12 @@ package pgx_test import ( "bytes" "database/sql" - "github.com/jackc/pgx" "strings" "testing" "time" + "github.com/jackc/pgx" + "github.com/shopspring/decimal" ) @@ -784,7 +785,7 @@ func TestQueryRowCoreByteSlice(t *testing.T) { t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) } - if bytes.Compare(actual, tt.expected) != 0 { + if !bytes.Equal(actual, tt.expected) { t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql) } @@ -1281,7 +1282,7 @@ func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) { } var num decimal.Decimal - err = conn.QueryRow("select $1::decimal", expected).Scan(&num) + err = conn.QueryRow("select $1::decimal", &expected).Scan(&num) if err != nil { t.Fatalf("Scan failed: %v", err) } diff --git a/sql.go b/sql.go index 9445c263..7ee0f2a0 100644 --- a/sql.go +++ b/sql.go @@ -14,7 +14,7 @@ func init() { placeholders = make([]string, 64) for i := 1; i < 64; i++ { - placeholders[i] = "$" + strconv.FormatInt(int64(i), 10) + placeholders[i] = "$" + strconv.Itoa(i) } } @@ -25,5 +25,5 @@ func (qa *QueryArgs) Append(v interface{}) string { if len(*qa) < len(placeholders) { return placeholders[len(*qa)] } - return "$" + strconv.FormatInt(int64(len(*qa)), 10) + return "$" + strconv.Itoa(len(*qa)) } diff --git a/sql_test.go b/sql_test.go index eafd92fa..dd036035 100644 --- a/sql_test.go +++ b/sql_test.go @@ -1,16 +1,17 @@ package pgx_test import ( - "github.com/jackc/pgx" "strconv" "testing" + + "github.com/jackc/pgx" ) func TestQueryArgs(t *testing.T) { var qa pgx.QueryArgs for i := 1; i < 512; i++ { - expectedPlaceholder := "$" + strconv.FormatInt(int64(i), 10) + expectedPlaceholder := "$" + strconv.Itoa(i) placeholder := qa.Append(i) if placeholder != expectedPlaceholder { t.Errorf(`Expected qa.Append to return "%s", but it returned "%s"`, expectedPlaceholder, placeholder) diff --git a/tx.go b/tx.go index e5c90c23..36f99c28 100644 --- a/tx.go +++ b/tx.go @@ -158,6 +158,15 @@ func (tx *Tx) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } +// CopyTo delegates to the underlying *Conn +func (tx *Tx) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { + if tx.status != TxStatusInProgress { + return 0, ErrTxClosed + } + + return tx.conn.CopyTo(tableName, columnNames, rowSrc) +} + // Conn returns the *Conn this transaction is using. func (tx *Tx) Conn() *Conn { return tx.conn diff --git a/value_reader.go b/value_reader.go index a47a1d17..249b8ba3 100644 --- a/value_reader.go +++ b/value_reader.go @@ -60,6 +60,20 @@ func (r *ValueReader) ReadInt16() int16 { return r.mr.readInt16() } +func (r *ValueReader) ReadUint16() uint16 { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 2 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.readUint16() +} + func (r *ValueReader) ReadInt32() int32 { if r.err != nil { return 0 @@ -74,6 +88,20 @@ func (r *ValueReader) ReadInt32() int32 { return r.mr.readInt32() } +func (r *ValueReader) ReadUint32() uint32 { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 4 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.readUint32() +} + func (r *ValueReader) ReadInt64() int64 { if r.err != nil { return 0 @@ -89,7 +117,7 @@ func (r *ValueReader) ReadInt64() int64 { } func (r *ValueReader) ReadOID() OID { - return OID(r.ReadInt32()) + return OID(r.ReadUint32()) } // ReadString reads count bytes and returns as string diff --git a/values.go b/values.go index e8e5a6d5..231a37f7 100644 --- a/values.go +++ b/values.go @@ -5,9 +5,11 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "io" "math" "net" "reflect" + "regexp" "strconv" "strings" "time" @@ -17,11 +19,16 @@ import ( const ( BoolOID = 16 ByteaOID = 17 + CharOID = 18 + NameOID = 19 Int8OID = 20 Int2OID = 21 Int4OID = 23 TextOID = 25 OIDOID = 26 + TidOID = 27 + XidOID = 28 + CidOID = 29 JSONOID = 114 CidrOID = 650 CidrArrayOID = 651 @@ -38,6 +45,8 @@ const ( Int8ArrayOID = 1016 Float4ArrayOID = 1021 Float8ArrayOID = 1022 + AclItemOID = 1033 + AclItemArrayOID = 1034 InetArrayOID = 1041 VarcharOID = 1043 DateOID = 1082 @@ -64,11 +73,13 @@ const minInt = -maxInt - 1 // or binary). In theory the Scanner interface should be the one to determine // the format of the returned values. However, the query has already been // executed by the time Scan is called so it has no chance to set the format. -// So for types that should be returned in binary th +// So for types that should always be returned in binary the format should be +// set here. var DefaultTypeFormats map[string]int16 func init() { DefaultTypeFormats = map[string]int16{ + "_aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin) "_bool": BinaryFormatCode, "_bytea": BinaryFormatCode, "_cidr": BinaryFormatCode, @@ -82,22 +93,30 @@ func init() { "_timestamp": BinaryFormatCode, "_timestamptz": BinaryFormatCode, "_varchar": BinaryFormatCode, + "aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin) "bool": BinaryFormatCode, "bytea": BinaryFormatCode, + "char": BinaryFormatCode, + "cid": BinaryFormatCode, "cidr": BinaryFormatCode, "date": BinaryFormatCode, "float4": BinaryFormatCode, "float8": BinaryFormatCode, + "json": BinaryFormatCode, + "jsonb": BinaryFormatCode, "inet": BinaryFormatCode, "int2": BinaryFormatCode, "int4": BinaryFormatCode, "int8": BinaryFormatCode, + "name": BinaryFormatCode, "oid": BinaryFormatCode, "record": BinaryFormatCode, "text": BinaryFormatCode, + "tid": BinaryFormatCode, "timestamp": BinaryFormatCode, "timestamptz": BinaryFormatCode, "varchar": BinaryFormatCode, + "xid": BinaryFormatCode, } } @@ -225,16 +244,16 @@ type NullString struct { Valid bool // Valid is true if String is not NULL } -func (s *NullString) Scan(vr *ValueReader) error { +func (n *NullString) Scan(vr *ValueReader) error { // Not checking oid as so we can scan anything into into a NullString - may revisit this decision later if vr.Len() == -1 { - s.String, s.Valid = "", false + n.String, n.Valid = "", false return nil } - s.Valid = true - s.String = decodeText(vr) + n.Valid = true + n.String = decodeText(vr) return vr.Err() } @@ -249,7 +268,156 @@ func (s NullString) Encode(w *WriteBuf, oid OID) error { return encodeString(w, oid, s.String) } -// NullInt16 represents an smallint that may be null. NullInt16 implements the +// AclItem is used for PostgreSQL's aclitem data type. A sample aclitem +// might look like this: +// +// postgres=arwdDxt/postgres +// +// Note, however, that because the user/role name part of an aclitem is +// an identifier, it follows all the usual formatting rules for SQL +// identifiers: if it contains spaces and other special characters, +// it should appear in double-quotes: +// +// postgres=arwdDxt/"role with spaces" +// +type AclItem string + +// NullAclItem represents a pgx.AclItem that may be null. NullAclItem implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan for prepared and unprepared queries. +// +// If Valid is false then the value is NULL. +type NullAclItem struct { + AclItem AclItem + Valid bool // Valid is true if AclItem is not NULL +} + +func (n *NullAclItem) Scan(vr *ValueReader) error { + if vr.Type().DataType != AclItemOID { + return SerializationError(fmt.Sprintf("NullAclItem.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.AclItem, n.Valid = "", false + return nil + } + + n.Valid = true + n.AclItem = AclItem(decodeText(vr)) + return vr.Err() +} + +// Particularly important to return TextFormatCode, seeing as Postgres +// only ever sends aclitem as text, not binary. +func (n NullAclItem) FormatCode() int16 { return TextFormatCode } + +func (n NullAclItem) Encode(w *WriteBuf, oid OID) error { + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeString(w, oid, string(n.AclItem)) +} + +// Name is a type used for PostgreSQL's special 63-byte +// name data type, used for identifiers like table names. +// The pg_class.relname column is a good example of where the +// name data type is used. +// +// Note that the underlying Go data type of pgx.Name is string, +// so there is no way to enforce the 63-byte length. Inputting +// a longer name into PostgreSQL will result in silent truncation +// to 63 bytes. +// +// Also, if you have custom-compiled PostgreSQL and set +// NAMEDATALEN to a different value, obviously that number of +// bytes applies, rather than the default 63. +type Name string + +// NullName represents a pgx.Name that may be null. NullName implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan for prepared and unprepared queries. +// +// If Valid is false then the value is NULL. +type NullName struct { + Name Name + Valid bool // Valid is true if Name is not NULL +} + +func (n *NullName) Scan(vr *ValueReader) error { + if vr.Type().DataType != NameOID { + return SerializationError(fmt.Sprintf("NullName.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Name, n.Valid = "", false + return nil + } + + n.Valid = true + n.Name = Name(decodeText(vr)) + return vr.Err() +} + +func (n NullName) FormatCode() int16 { return TextFormatCode } + +func (n NullName) Encode(w *WriteBuf, oid OID) error { + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeString(w, oid, string(n.Name)) +} + +// The pgx.Char type is for PostgreSQL's special 8-bit-only +// "char" type more akin to the C language's char type, or Go's byte type. +// (Note that the name in PostgreSQL itself is "char", in double-quotes, +// and not char.) It gets used a lot in PostgreSQL's system tables to hold +// a single ASCII character value (eg pg_class.relkind). +type Char byte + +// NullChar represents a pgx.Char that may be null. NullChar implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan for prepared and unprepared queries. +// +// If Valid is false then the value is NULL. +type NullChar struct { + Char Char + Valid bool // Valid is true if Char is not NULL +} + +func (n *NullChar) Scan(vr *ValueReader) error { + if vr.Type().DataType != CharOID { + return SerializationError(fmt.Sprintf("NullChar.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Char, n.Valid = 0, false + return nil + } + n.Valid = true + n.Char = decodeChar(vr) + return vr.Err() +} + +func (n NullChar) FormatCode() int16 { return BinaryFormatCode } + +func (n NullChar) Encode(w *WriteBuf, oid OID) error { + if oid != CharOID { + return SerializationError(fmt.Sprintf("NullChar.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeChar(w, oid, n.Char) +} + +// NullInt16 represents a smallint that may be null. NullInt16 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan for prepared and unprepared queries. // @@ -327,6 +495,213 @@ func (n NullInt32) Encode(w *WriteBuf, oid OID) error { return encodeInt32(w, oid, n.Int32) } +// OID (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, +// used internally by PostgreSQL as a primary key for various system tables. It is currently implemented +// as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h +// in the PostgreSQL sources. +type OID uint32 + +// NullOID represents a Command Identifier (OID) that may be null. NullOID implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullOID struct { + OID OID + Valid bool // Valid is true if OID is not NULL +} + +func (n *NullOID) Scan(vr *ValueReader) error { + if vr.Type().DataType != OIDOID { + return SerializationError(fmt.Sprintf("NullOID.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.OID, n.Valid = 0, false + return nil + } + n.Valid = true + n.OID = decodeOID(vr) + return vr.Err() +} + +func (n NullOID) FormatCode() int16 { return BinaryFormatCode } + +func (n NullOID) Encode(w *WriteBuf, oid OID) error { + if oid != OIDOID { + return SerializationError(fmt.Sprintf("NullOID.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeOID(w, oid, n.OID) +} + +// Xid is PostgreSQL's Transaction ID type. +// +// In later versions of PostgreSQL, it is the type used for the backend_xid +// and backend_xmin columns of the pg_stat_activity system view. +// +// Also, when one does +// +// select xmin, xmax, * from some_table; +// +// it is the data type of the xmin and xmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/postgres_ext.h as TransactionId +// in the PostgreSQL sources. +type Xid uint32 + +// NullXid represents a Transaction ID (Xid) that may be null. NullXid implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullXid struct { + Xid Xid + Valid bool // Valid is true if Xid is not NULL +} + +func (n *NullXid) Scan(vr *ValueReader) error { + if vr.Type().DataType != XidOID { + return SerializationError(fmt.Sprintf("NullXid.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Xid, n.Valid = 0, false + return nil + } + n.Valid = true + n.Xid = decodeXid(vr) + return vr.Err() +} + +func (n NullXid) FormatCode() int16 { return BinaryFormatCode } + +func (n NullXid) Encode(w *WriteBuf, oid OID) error { + if oid != XidOID { + return SerializationError(fmt.Sprintf("NullXid.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeXid(w, oid, n.Xid) +} + +// Cid is PostgreSQL's Command Identifier type. +// +// When one does +// +// select cmin, cmax, * from some_table; +// +// it is the data type of the cmin and cmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/c.h as CommandId +// in the PostgreSQL sources. +type Cid uint32 + +// NullCid represents a Command Identifier (Cid) that may be null. NullCid implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullCid struct { + Cid Cid + Valid bool // Valid is true if Cid is not NULL +} + +func (n *NullCid) Scan(vr *ValueReader) error { + if vr.Type().DataType != CidOID { + return SerializationError(fmt.Sprintf("NullCid.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Cid, n.Valid = 0, false + return nil + } + n.Valid = true + n.Cid = decodeCid(vr) + return vr.Err() +} + +func (n NullCid) FormatCode() int16 { return BinaryFormatCode } + +func (n NullCid) Encode(w *WriteBuf, oid OID) error { + if oid != CidOID { + return SerializationError(fmt.Sprintf("NullCid.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeCid(w, oid, n.Cid) +} + +// Tid is PostgreSQL's Tuple Identifier type. +// +// When one does +// +// select ctid, * from some_table; +// +// it is the data type of the ctid hidden system column. +// +// It is currently implemented as a pair unsigned two byte integers. +// Its conversion functions can be found in src/backend/utils/adt/tid.c +// in the PostgreSQL sources. +type Tid struct { + BlockNumber uint32 + OffsetNumber uint16 +} + +// NullTid represents a Tuple Identifier (Tid) that may be null. NullTid implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullTid struct { + Tid Tid + Valid bool // Valid is true if Tid is not NULL +} + +func (n *NullTid) Scan(vr *ValueReader) error { + if vr.Type().DataType != TidOID { + return SerializationError(fmt.Sprintf("NullTid.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Tid, n.Valid = Tid{BlockNumber: 0, OffsetNumber: 0}, false + return nil + } + n.Valid = true + n.Tid = decodeTid(vr) + return vr.Err() +} + +func (n NullTid) FormatCode() int16 { return BinaryFormatCode } + +func (n NullTid) Encode(w *WriteBuf, oid OID) error { + if oid != TidOID { + return SerializationError(fmt.Sprintf("NullTid.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeTid(w, oid, n.Tid) +} + // NullInt64 represents an bigint that may be null. NullInt64 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. @@ -609,6 +984,8 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return Encode(wbuf, oid, v) case string: return encodeString(wbuf, oid, arg) + case []AclItem: + return encodeAclItemSlice(wbuf, oid, arg) case []byte: return encodeByteSlice(wbuf, oid, arg) case [][]byte: @@ -621,15 +998,17 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { if refVal.IsNil() { wbuf.WriteInt32(-1) return nil - } else { - arg = refVal.Elem().Interface() - return Encode(wbuf, oid, arg) } + arg = refVal.Elem().Interface() + return Encode(wbuf, oid, arg) } - if oid == JSONOID || oid == JSONBOID { + if oid == JSONOID { return encodeJSON(wbuf, oid, arg) } + if oid == JSONBOID { + return encodeJSONB(wbuf, oid, arg) + } switch arg := arg.(type) { case []string: @@ -642,6 +1021,16 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return encodeInt(wbuf, oid, arg) case uint: return encodeUInt(wbuf, oid, arg) + case Char: + return encodeChar(wbuf, oid, arg) + case AclItem: + // The aclitem data type goes over the wire using the same format as string, + // so just cast to string and use encodeString + return encodeString(wbuf, oid, string(arg)) + case Name: + // The name data type goes over the wire using the same format as string, + // so just cast to string and use encodeString + return encodeString(wbuf, oid, string(arg)) case int8: return encodeInt8(wbuf, oid, arg) case uint8: @@ -692,6 +1081,10 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return encodeIPNetSlice(wbuf, oid, arg) case OID: return encodeOID(wbuf, oid, arg) + case Xid: + return encodeXid(wbuf, oid, arg) + case Cid: + return encodeCid(wbuf, oid, arg) default: if strippedArg, ok := stripNamedType(&refVal); ok { return Encode(wbuf, oid, strippedArg) @@ -814,14 +1207,30 @@ func Decode(vr *ValueReader, d interface{}) error { return fmt.Errorf("%d is less than zero for uint64", n) } *v = uint64(n) + case *Char: + *v = decodeChar(vr) + case *AclItem: + // aclitem goes over the wire just like text + *v = AclItem(decodeText(vr)) + case *Name: + // name goes over the wire just like text + *v = Name(decodeText(vr)) case *OID: *v = decodeOID(vr) + case *Xid: + *v = decodeXid(vr) + case *Tid: + *v = decodeTid(vr) + case *Cid: + *v = decodeCid(vr) case *string: *v = decodeText(vr) case *float32: *v = decodeFloat4(vr) case *float64: *v = decodeFloat8(vr) + case *[]AclItem: + *v = decodeAclItemArray(vr) case *[]bool: *v = decodeBoolArray(vr) case *[]int16: @@ -892,14 +1301,13 @@ func Decode(vr *ValueReader, d interface{}) error { el.Set(reflect.Zero(el.Type())) } return nil - } else { - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - d = el.Interface() - return Decode(vr, d) } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + d = el.Interface() + return Decode(vr, d) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: n := decodeInt(vr) if el.OverflowInt(n) { @@ -1008,6 +1416,30 @@ func decodeInt8(vr *ValueReader) int64 { return vr.ReadInt64() } +func decodeChar(vr *ValueReader) Char { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into char")) + return Char(0) + } + + if vr.Type().DataType != CharOID { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into char", vr.Type().DataType))) + return Char(0) + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Char(0) + } + + if vr.Len() != 1 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a char: %d", vr.Len()))) + return Char(0) + } + + return Char(vr.ReadByte()) +} + func decodeInt2(vr *ValueReader) int16 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into int16")) @@ -1093,6 +1525,12 @@ func encodeUInt(w *WriteBuf, oid OID, value uint) error { return nil } +func encodeChar(w *WriteBuf, oid OID, value Char) error { + w.WriteInt32(1) + w.WriteByte(byte(value)) + return nil +} + func encodeInt8(w *WriteBuf, oid OID, value int8) error { switch oid { case Int2OID: @@ -1301,19 +1739,19 @@ func decodeInt4(vr *ValueReader) int32 { func decodeOID(vr *ValueReader) OID { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into OID")) - return 0 + return OID(0) } if vr.Type().DataType != OIDOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.OID", vr.Type().DataType))) - return 0 + return OID(0) } // OID needs to decode text format because it is used in loadPgTypes switch vr.Type().FormatCode { case TextFormatCode: s := vr.ReadString(vr.Len()) - n, err := strconv.ParseInt(s, 10, 32) + n, err := strconv.ParseUint(s, 10, 32) if err != nil { vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid OID: %v", s))) } @@ -1321,7 +1759,7 @@ func decodeOID(vr *ValueReader) OID { case BinaryFormatCode: if vr.Len() != 4 { vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an OID: %d", vr.Len()))) - return 0 + return OID(0) } return OID(vr.ReadInt32()) default: @@ -1336,7 +1774,153 @@ func encodeOID(w *WriteBuf, oid OID, value OID) error { } w.WriteInt32(4) - w.WriteInt32(int32(value)) + w.WriteUint32(uint32(value)) + + return nil +} + +func decodeXid(vr *ValueReader) Xid { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into Xid")) + return Xid(0) + } + + if vr.Type().DataType != XidOID { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Xid", vr.Type().DataType))) + return Xid(0) + } + + // Unlikely Xid will ever go over the wire as text format, but who knows? + switch vr.Type().FormatCode { + case TextFormatCode: + s := vr.ReadString(vr.Len()) + n, err := strconv.ParseUint(s, 10, 32) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid OID: %v", s))) + } + return Xid(n) + case BinaryFormatCode: + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an OID: %d", vr.Len()))) + return Xid(0) + } + return Xid(vr.ReadUint32()) + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Xid(0) + } +} + +func encodeXid(w *WriteBuf, oid OID, value Xid) error { + if oid != XidOID { + return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Xid", oid) + } + + w.WriteInt32(4) + w.WriteUint32(uint32(value)) + + return nil +} + +func decodeCid(vr *ValueReader) Cid { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into Cid")) + return Cid(0) + } + + if vr.Type().DataType != CidOID { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Cid", vr.Type().DataType))) + return Cid(0) + } + + // Unlikely Cid will ever go over the wire as text format, but who knows? + switch vr.Type().FormatCode { + case TextFormatCode: + s := vr.ReadString(vr.Len()) + n, err := strconv.ParseUint(s, 10, 32) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid OID: %v", s))) + } + return Cid(n) + case BinaryFormatCode: + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an OID: %d", vr.Len()))) + return Cid(0) + } + return Cid(vr.ReadUint32()) + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Cid(0) + } +} + +func encodeCid(w *WriteBuf, oid OID, value Cid) error { + if oid != CidOID { + return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Cid", oid) + } + + w.WriteInt32(4) + w.WriteUint32(uint32(value)) + + return nil +} + +// Note that we do not match negative numbers, because neither the +// BlockNumber nor OffsetNumber of a Tid can be negative. +var tidRegexp *regexp.Regexp = regexp.MustCompile(`^\((\d*),(\d*)\)$`) + +func decodeTid(vr *ValueReader) Tid { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into Tid")) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + + if vr.Type().DataType != TidOID { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Tid", vr.Type().DataType))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + + // Unlikely Tid will ever go over the wire as text format, but who knows? + switch vr.Type().FormatCode { + case TextFormatCode: + s := vr.ReadString(vr.Len()) + + match := tidRegexp.FindStringSubmatch(s) + if match == nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid OID: %v", s))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + + blockNumber, err := strconv.ParseUint(s, 10, 16) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid BlockNumber part of a Tid: %v", s))) + } + + offsetNumber, err := strconv.ParseUint(s, 10, 16) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid offsetNumber part of a Tid: %v", s))) + } + return Tid{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber)} + case BinaryFormatCode: + if vr.Len() != 6 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an OID: %d", vr.Len()))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + return Tid{BlockNumber: vr.ReadUint32(), OffsetNumber: vr.ReadUint16()} + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } +} + +func encodeTid(w *WriteBuf, oid OID, value Tid) error { + if oid != TidOID { + return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Tid", oid) + } + + w.WriteInt32(6) + w.WriteUint32(value.BlockNumber) + w.WriteUint16(value.OffsetNumber) return nil } @@ -1463,7 +2047,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error { return nil } - if vr.Type().DataType != JSONOID && vr.Type().DataType != JSONBOID { + if vr.Type().DataType != JSONOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType))) } @@ -1476,7 +2060,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error { } func encodeJSON(w *WriteBuf, oid OID, value interface{}) error { - if oid != JSONOID && oid != JSONBOID { + if oid != JSONOID { return fmt.Errorf("cannot encode JSON into oid %v", oid) } @@ -1491,6 +2075,51 @@ func encodeJSON(w *WriteBuf, oid OID, value interface{}) error { return nil } +func decodeJSONB(vr *ValueReader, d interface{}) error { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != JSONBOID { + err := ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType)) + vr.Fatal(err) + return err + } + + bytes := vr.ReadBytes(vr.Len()) + if vr.Type().FormatCode == BinaryFormatCode { + if bytes[0] != 1 { + err := ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", bytes[0])) + vr.Fatal(err) + return err + } + bytes = bytes[1:] + } + + err := json.Unmarshal(bytes, d) + if err != nil { + vr.Fatal(err) + } + return err +} + +func encodeJSONB(w *WriteBuf, oid OID, value interface{}) error { + if oid != JSONBOID { + return fmt.Errorf("cannot encode JSON into oid %v", oid) + } + + s, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("Failed to encode json from type: %T", value) + } + + w.WriteInt32(int32(len(s) + 1)) + w.WriteByte(1) // JSONB format header + w.WriteBytes(s) + + return nil +} + func decodeDate(vr *ValueReader) time.Time { var zeroTime time.Time @@ -1645,10 +2274,10 @@ func encodeIPNet(w *WriteBuf, oid OID, value net.IPNet) error { switch len(value.IP) { case net.IPv4len: size = 8 - family = *w.conn.pgsql_af_inet + family = *w.conn.pgsqlAfInet case net.IPv6len: size = 20 - family = *w.conn.pgsql_af_inet6 + family = *w.conn.pgsqlAfInet6 default: return fmt.Errorf("Unexpected IP length: %v", len(value.IP)) } @@ -2378,6 +3007,210 @@ func decodeTextArray(vr *ValueReader) []string { return a } +// escapeAclItem escapes an AclItem before it is added to +// its aclitem[] string representation. The PostgreSQL aclitem +// datatype itself can need escapes because it follows the +// formatting rules of SQL identifiers. Think of this function +// as escaping the escapes, so that PostgreSQL's array parser +// will do the right thing. +func escapeAclItem(acl string) (string, error) { + var escapedAclItem bytes.Buffer + reader := strings.NewReader(acl) + for { + rn, _, err := reader.ReadRune() + if err != nil { + if err == io.EOF { + // Here, EOF is an expected end state, not an error. + return escapedAclItem.String(), nil + } + // This error was not expected + return "", err + } + if needsEscape(rn) { + escapedAclItem.WriteRune('\\') + } + escapedAclItem.WriteRune(rn) + } +} + +// needsEscape determines whether or not a rune needs escaping +// before being placed in the textual representation of an +// aclitem[] array. +func needsEscape(rn rune) bool { + return rn == '\\' || rn == ',' || rn == '"' || rn == '}' +} + +// encodeAclItemSlice encodes a slice of AclItems in +// their textual represention for PostgreSQL. +func encodeAclItemSlice(w *WriteBuf, oid OID, aclitems []AclItem) error { + strs := make([]string, len(aclitems)) + var escapedAclItem string + var err error + for i := range strs { + escapedAclItem, err = escapeAclItem(string(aclitems[i])) + if err != nil { + return err + } + strs[i] = string(escapedAclItem) + } + + var buf bytes.Buffer + buf.WriteRune('{') + buf.WriteString(strings.Join(strs, ",")) + buf.WriteRune('}') + str := buf.String() + w.WriteInt32(int32(len(str))) + w.WriteBytes([]byte(str)) + return nil +} + +// parseAclItemArray parses the textual representation +// of the aclitem[] type. The textual representation is chosen because +// Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin). +// See https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +// for formatting notes. +func parseAclItemArray(arr string) ([]AclItem, error) { + reader := strings.NewReader(arr) + // Difficult to guess a performant initial capacity for a slice of + // aclitems, but let's go with 5. + aclItems := make([]AclItem, 0, 5) + // A single value + aclItem := AclItem("") + for { + // Grab the first/next/last rune to see if we are dealing with a + // quoted value, an unquoted value, or the end of the string. + rn, _, err := reader.ReadRune() + if err != nil { + if err == io.EOF { + // Here, EOF is an expected end state, not an error. + return aclItems, nil + } + // This error was not expected + return nil, err + } + + if rn == '"' { + // Discard the opening quote of the quoted value. + aclItem, err = parseQuotedAclItem(reader) + } else { + // We have just read the first rune of an unquoted (bare) value; + // put it back so that ParseBareValue can read it. + err := reader.UnreadRune() + if err != nil { + return nil, err + } + aclItem, err = parseBareAclItem(reader) + } + + if err != nil { + if err == io.EOF { + // Here, EOF is an expected end state, not an error.. + aclItems = append(aclItems, aclItem) + return aclItems, nil + } + // This error was not expected. + return nil, err + } + aclItems = append(aclItems, aclItem) + } +} + +// parseBareAclItem parses a bare (unquoted) aclitem from reader +func parseBareAclItem(reader *strings.Reader) (AclItem, error) { + var aclItem bytes.Buffer + for { + rn, _, err := reader.ReadRune() + if err != nil { + // Return the read value in case the error is a harmless io.EOF. + // (io.EOF marks the end of a bare aclitem at the end of a string) + return AclItem(aclItem.String()), err + } + if rn == ',' { + // A comma marks the end of a bare aclitem. + return AclItem(aclItem.String()), nil + } else { + aclItem.WriteRune(rn) + } + } +} + +// parseQuotedAclItem parses an aclitem which is in double quotes from reader +func parseQuotedAclItem(reader *strings.Reader) (AclItem, error) { + var aclItem bytes.Buffer + for { + rn, escaped, err := readPossiblyEscapedRune(reader) + if err != nil { + if err == io.EOF { + // Even when it is the last value, the final rune of + // a quoted aclitem should be the final closing quote, not io.EOF. + return AclItem(""), fmt.Errorf("unexpected end of quoted value") + } + // Return the read aclitem in case the error is a harmless io.EOF, + // which will be determined by the caller. + return AclItem(aclItem.String()), err + } + if !escaped && rn == '"' { + // An unescaped double quote marks the end of a quoted value. + // The next rune should either be a comma or the end of the string. + rn, _, err := reader.ReadRune() + if err != nil { + // Return the read value in case the error is a harmless io.EOF, + // which will be determined by the caller. + return AclItem(aclItem.String()), err + } + if rn != ',' { + return AclItem(""), fmt.Errorf("unexpected rune after quoted value") + } + return AclItem(aclItem.String()), nil + } + aclItem.WriteRune(rn) + } +} + +// Returns the next rune from r, unless it is a backslash; +// in that case, it returns the rune after the backslash. The second +// return value tells us whether or not the rune was +// preceeded by a backslash (escaped). +func readPossiblyEscapedRune(reader *strings.Reader) (rune, bool, error) { + rn, _, err := reader.ReadRune() + if err != nil { + return 0, false, err + } + if rn == '\\' { + // Discard the backslash and read the next rune. + rn, _, err = reader.ReadRune() + if err != nil { + return 0, false, err + } + return rn, true, nil + } + return rn, false, nil +} + +func decodeAclItemArray(vr *ValueReader) []AclItem { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into []AclItem")) + return nil + } + + str := vr.ReadString(vr.Len()) + + // Short-circuit empty array. + if str == "{}" { + return []AclItem{} + } + + // Remove the '{' at the front and the '}' at the end, + // so that parseAclItemArray doesn't have to deal with them. + str = str[1 : len(str)-1] + aclItems, err := parseAclItemArray(str) + if err != nil { + vr.Fatal(ProtocolError(err.Error())) + return nil + } + return aclItems +} + func encodeStringSlice(w *WriteBuf, oid OID, slice []string) error { var elOID OID switch oid { diff --git a/values_test.go b/values_test.go index 2ef9c774..6ab221f7 100644 --- a/values_test.go +++ b/values_test.go @@ -88,67 +88,74 @@ func TestJSONAndJSONBTranscode(t *testing.T) { if _, ok := conn.PgTypes[oid]; !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } - typename := conn.PgTypes[oid].Name - testJSONString(t, conn, typename) - testJSONStringPointer(t, conn, typename) - testJSONSingleLevelStringMap(t, conn, typename) - testJSONNestedMap(t, conn, typename) - testJSONStringArray(t, conn, typename) - testJSONInt64Array(t, conn, typename) - testJSONInt16ArrayFailureDueToOverflow(t, conn, typename) - testJSONStruct(t, conn, typename) + for _, format := range []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} { + pgtype := conn.PgTypes[oid] + pgtype.DefaultFormat = format + conn.PgTypes[oid] = pgtype + + typename := conn.PgTypes[oid].Name + + testJSONString(t, conn, typename, format) + testJSONStringPointer(t, conn, typename, format) + testJSONSingleLevelStringMap(t, conn, typename, format) + testJSONNestedMap(t, conn, typename, format) + testJSONStringArray(t, conn, typename, format) + testJSONInt64Array(t, conn, typename, format) + testJSONInt16ArrayFailureDueToOverflow(t, conn, typename, format) + testJSONStruct(t, conn, typename, format) + } } } -func testJSONString(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONString(t *testing.T, conn *pgx.Conn, typename string, format int16) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) return } if !reflect.DeepEqual(expectedOutput, output) { - t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output) + t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) return } } -func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string, format int16) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string err := conn.QueryRow("select $1::"+typename, &input).Scan(&output) if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) return } if !reflect.DeepEqual(expectedOutput, output) { - t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output) + t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) return } } -func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { input := map[string]string{"key": "value"} var output map[string]string err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) return } if !reflect.DeepEqual(input, output) { - t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, input, output) + t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, input, output) return } } -func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { input := map[string]interface{}{ "name": "Uncanny", "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, @@ -157,52 +164,52 @@ func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) { var output map[string]interface{} err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) return } if !reflect.DeepEqual(input, output) { - t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output) + t.Errorf("%s %d: Did not transcode map[string]interface{} successfully: %v is not %v", typename, format, input, output) return } } -func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string, format int16) { input := []string{"foo", "bar", "baz"} var output []string err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) } if !reflect.DeepEqual(input, output) { - t.Errorf("%s: Did not transcode []string successfully: %v is not %v", typename, input, output) + t.Errorf("%s %d: Did not transcode []string successfully: %v is not %v", typename, format, input, output) } } -func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string, format int16) { input := []int64{1, 2, 234432} var output []int64 err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) } if !reflect.DeepEqual(input, output) { - t.Errorf("%s: Did not transcode []int64 successfully: %v is not %v", typename, input, output) + t.Errorf("%s %d: Did not transcode []int64 successfully: %v is not %v", typename, format, input, output) } } -func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string, format int16) { input := []int{1, 2, 234432} var output []int16 err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { - t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err) + t.Errorf("%s %d: Expected *json.UnmarkalTypeError, but got %v", typename, format, err) } } -func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) { type person struct { Name string `json:"name"` Age int `json:"age"` @@ -217,11 +224,11 @@ func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) { err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) } if !reflect.DeepEqual(input, output) { - t.Errorf("%s: Did not transcode struct successfully: %v is not %v", typename, input, output) + t.Errorf("%s %d: Did not transcode struct successfully: %v is not %v", typename, format, input, output) } } @@ -561,6 +568,13 @@ func TestNullX(t *testing.T) { s pgx.NullString i16 pgx.NullInt16 i32 pgx.NullInt32 + c pgx.NullChar + a pgx.NullAclItem + n pgx.NullName + oid pgx.NullOID + xid pgx.NullXid + cid pgx.NullCid + tid pgx.NullTid i64 pgx.NullInt64 f32 pgx.NullFloat32 f64 pgx.NullFloat64 @@ -582,6 +596,27 @@ func TestNullX(t *testing.T) { {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}}, + {"select $1::oid", []interface{}{pgx.NullOID{OID: 1, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 1, Valid: true}}}, + {"select $1::oid", []interface{}{pgx.NullOID{OID: 1, Valid: false}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 0, Valid: false}}}, + {"select $1::oid", []interface{}{pgx.NullOID{OID: 4294967295, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 4294967295, Valid: true}}}, + {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 1, Valid: true}}}, + {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: false}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 0, Valid: false}}}, + {"select $1::xid", []interface{}{pgx.NullXid{Xid: 4294967295, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 4294967295, Valid: true}}}, + {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}}, + {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}}, + {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, + {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: true}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "foo", Valid: true}}}, + {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: false}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "", Valid: false}}}, + {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}}, + {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}}, + // A tricky (and valid) aclitem can still be used, especially with Go's useful backticks + {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}}, + {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, + {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}}, + {"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}}, + {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}}, + {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: false}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 0, OffsetNumber: 0}, Valid: false}}}, + {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}}, {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}}, {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}}, {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}}, @@ -615,6 +650,52 @@ func TestNullX(t *testing.T) { } } +func assertAclItemSlicesEqual(t *testing.T, query, scan []pgx.AclItem) { + if !reflect.DeepEqual(query, scan) { + t.Errorf("failed to encode aclitem[]\n EXPECTED: %d %v\n ACTUAL: %d %v", len(query), query, len(scan), scan) + } +} + +func TestAclArrayDecoding(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := "select $1::aclitem[]" + var scan []pgx.AclItem + + tests := []struct { + query []pgx.AclItem + }{ + { + []pgx.AclItem{}, + }, + { + []pgx.AclItem{"=r/postgres"}, + }, + { + []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"}, + }, + { + []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/" tricky, ' } "" \ test user "`}, + }, + } + for i, tt := range tests { + err := conn.QueryRow(sql, tt.query).Scan(&scan) + if err != nil { + // t.Errorf(`%d. error reading array: %v`, i, err) + t.Errorf(`%d. error reading array: %v query: %s`, i, err, tt.query) + if pgerr, ok := err.(pgx.PgError); ok { + t.Errorf(`%d. error reading array (detail): %s`, i, pgerr.Detail) + } + continue + } + assertAclItemSlicesEqual(t, tt.query, scan) + ensureConnValid(t, conn) + } +} + func TestArrayDecoding(t *testing.T) { t.Parallel() @@ -630,7 +711,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::bool[]", []bool{true, false, true}, &[]bool{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]bool))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]bool))) { t.Errorf("failed to encode bool[]") } }, @@ -638,7 +719,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]int16))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]int16))) { t.Errorf("failed to encode smallint[]") } }, @@ -646,7 +727,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]uint16))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { t.Errorf("failed to encode smallint[]") } }, @@ -654,7 +735,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]int32))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]int32))) { t.Errorf("failed to encode int[]") } }, @@ -662,7 +743,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]uint32))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { t.Errorf("failed to encode int[]") } }, @@ -670,7 +751,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]int64))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]int64))) { t.Errorf("failed to encode bigint[]") } }, @@ -678,7 +759,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]uint64))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { t.Errorf("failed to encode bigint[]") } }, @@ -686,7 +767,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]string))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]string))) { t.Errorf("failed to encode text[]") } }, @@ -694,7 +775,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) { t.Errorf("failed to encode time.Time[] to timestamp[]") } }, @@ -702,7 +783,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) { t.Errorf("failed to encode time.Time[] to timestamptz[]") } }, @@ -718,7 +799,7 @@ func TestArrayDecoding(t *testing.T) { for i := range queryBytesSliceSlice { qb := queryBytesSliceSlice[i] sb := scanBytesSliceSlice[i] - if bytes.Compare(qb, sb) != 0 { + if !bytes.Equal(qb, sb) { t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb) } } @@ -1075,11 +1156,11 @@ func TestRowDecode(t *testing.T) { expected []interface{} }{ { - "select row(1, 'cat', '2015-01-01 08:12:42'::timestamptz)", + "select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)", []interface{}{ int32(1), "cat", - time.Date(2015, 1, 1, 8, 12, 42, 0, time.Local), + time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(), }, }, }