mirror of https://github.com/jackc/pgx.git
Make v3 main release
commit
f79e52f1ee
|
@ -22,3 +22,4 @@ _testmain.go
|
|||
*.exe
|
||||
|
||||
conn_config_test.go
|
||||
.envrc
|
||||
|
|
14
.travis.yml
14
.travis.yml
|
@ -1,8 +1,7 @@
|
|||
language: go
|
||||
|
||||
go:
|
||||
- 1.7.4
|
||||
- 1.6.4
|
||||
- 1.8
|
||||
- tip
|
||||
|
||||
# Derived from https://github.com/lib/pq/blob/master/.travis.yml
|
||||
|
@ -28,6 +27,8 @@ before_install:
|
|||
- sudo /etc/init.d/postgresql restart
|
||||
|
||||
env:
|
||||
global:
|
||||
- PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
matrix:
|
||||
- PGVERSION=9.6
|
||||
- PGVERSION=9.5
|
||||
|
@ -40,7 +41,7 @@ env:
|
|||
before_script:
|
||||
- mv conn_config_test.go.travis conn_config_test.go
|
||||
- psql -U postgres -c 'create database pgx_test'
|
||||
- "[[ \"${PGVERSION}\" = '9.0' ]] && psql -U postgres -f /usr/share/postgresql/9.0/contrib/hstore.sql pgx_test || psql -U postgres pgx_test -c 'create extension hstore'"
|
||||
- psql -U postgres pgx_test -c 'create extension hstore'
|
||||
- psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'"
|
||||
- psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'"
|
||||
- psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'"
|
||||
|
@ -51,9 +52,14 @@ install:
|
|||
- go get -u github.com/shopspring/decimal
|
||||
- go get -u gopkg.in/inconshreveable/log15.v2
|
||||
- go get -u github.com/jackc/fake
|
||||
- go get -u github.com/lib/pq
|
||||
- go get -u github.com/hashicorp/go-version
|
||||
- go get -u github.com/satori/go.uuid
|
||||
- go get -u github.com/Sirupsen/logrus
|
||||
- go get -u github.com/pkg/errors
|
||||
|
||||
script:
|
||||
- go test -v -race -short ./...
|
||||
- go test -v -race ./...
|
||||
|
||||
matrix:
|
||||
allow_failures:
|
||||
|
|
60
CHANGELOG.md
60
CHANGELOG.md
|
@ -1,4 +1,62 @@
|
|||
# Unreleased
|
||||
# Unreleased V3
|
||||
|
||||
## Changes
|
||||
|
||||
* Pid to PID in accordance with Go naming conventions.
|
||||
* Conn.Pid changed to accessor method Conn.PID()
|
||||
* Conn.SecretKey removed
|
||||
* Remove Conn.TxStatus
|
||||
* Logger interface reduced to single Log method.
|
||||
* Replace BeginIso with BeginEx. BeginEx adds support for read/write mode and deferrable mode.
|
||||
* Transaction isolation level constants are now typed strings instead of bare strings.
|
||||
* Conn.WaitForNotification now takes context.Context instead of time.Duration for cancellation support.
|
||||
* Conn.WaitForNotification no longer automatically pings internally every 15 seconds.
|
||||
* ReplicationConn.WaitForReplicationMessage now takes context.Context instead of time.Duration for cancellation support.
|
||||
* Reject scanning binary format values into a string (e.g. binary encoded timestamptz to string). See https://github.com/jackc/pgx/issues/219 and https://github.com/jackc/pgx/issues/228
|
||||
* No longer can read raw bytes of any value into a []byte. Use pgtype.GenericBinary if this functionality is needed.
|
||||
* Remove CopyTo (functionality is now in CopyFrom)
|
||||
* OID constants moved from pgx to pgtype package
|
||||
* Replaced Scanner, Encoder, and PgxScanner interfaces with pgtype system
|
||||
* Removed ValueReader
|
||||
* ConnPool.Close no longer waits for all acquired connections to be released. Instead, it immediately closes all available connections, and closes acquired connections when they are released in the same manner as ConnPool.Reset.
|
||||
* Removed Rows.Fatal(error)
|
||||
* Removed Rows.AfterClose()
|
||||
* Removed Rows.Conn()
|
||||
* Removed Tx.AfterClose()
|
||||
* Removed Tx.Conn()
|
||||
* Use Go casing convention for OID, UUID, JSON(B), ACLItem, CID, TID, XID, and CIDR
|
||||
* Replaced stdlib.OpenFromConnPool with DriverConfig system
|
||||
|
||||
## Features
|
||||
|
||||
* Entirely revamped pluggable type system that supports approximately 60 PostgreSQL types.
|
||||
* Types support database/sql interfaces and therefore can be used with other drivers
|
||||
* Added context methods supporting cancellation where appropriate
|
||||
* Added simple query protocol support
|
||||
* Added single round-trip query mode
|
||||
* Added batch query operations
|
||||
* Added OnNotice
|
||||
* github.com/pkg/errors used where possible for errors
|
||||
* Added stdlib.DriverConfig which directly allows full configuration of underlying pgx connections without needing to use a pgx.ConnPool
|
||||
* Added AcquireConn and ReleaseConn to stdlib to allow acquiring a connection from a database/sql connection.
|
||||
|
||||
# 2.11.0 (June 5, 2017)
|
||||
|
||||
## Fixes
|
||||
|
||||
* Fix race with concurrent execution of stdlib.OpenFromConnPool (Terin Stock)
|
||||
|
||||
## Features
|
||||
|
||||
* .pgpass support (j7b)
|
||||
* Add missing CopyFrom delegators to Tx and ConnPool (Jack Christensen)
|
||||
* Add ParseConnectionString (James Lawrence)
|
||||
|
||||
## Performance
|
||||
|
||||
* Optimize HStore encoding (René Kroon)
|
||||
|
||||
# 2.10.0 (March 17, 2017)
|
||||
|
||||
## Fixes
|
||||
|
||||
|
|
74
README.md
74
README.md
|
@ -1,63 +1,57 @@
|
|||
[](https://godoc.org/github.com/jackc/pgx)
|
||||
|
||||
# Pgx
|
||||
# pgx - PostgreSQL Driver and Toolkit
|
||||
|
||||
## Master Branch
|
||||
|
||||
This is the `master` branch which tracks the stable release of the current
|
||||
version. At the moment this is `v2`. The `v3` branch which is currently in beta.
|
||||
General release is planned for July. `v3` is considered to be stable in the
|
||||
sense of lack of known bugs, but the API is not considered stable until general
|
||||
release. No further changes are planned, but the beta process may surface
|
||||
desirable changes. If possible API changes are acceptable, then `v3` is the
|
||||
recommended branch for new development. Regardless, please lock to the `v2` or
|
||||
`v3` branch as when `v3` is released breaking changes will be applied to the
|
||||
master branch.
|
||||
|
||||
Pgx is a pure Go database connection library designed specifically for
|
||||
PostgreSQL. Pgx is different from other drivers such as
|
||||
[pq](http://godoc.org/github.com/lib/pq) because, while it can operate as a
|
||||
database/sql compatible driver, pgx is primarily intended to be used directly.
|
||||
It offers a native interface similar to database/sql that offers better
|
||||
performance and more features.
|
||||
pgx is a pure Go driver and toolkit for PostgreSQL. pgx is different from other drivers such as [pq](http://godoc.org/github.com/lib/pq) because, while it can operate as a database/sql compatible driver, pgx is also usable directly. It offers a native interface similar to database/sql that offers better performance and more features.
|
||||
|
||||
## Features
|
||||
|
||||
Pgx supports many additional features beyond what is available through database/sql.
|
||||
pgx supports many additional features beyond what is available through database/sql.
|
||||
|
||||
* Listen / notify
|
||||
* Transaction isolation level control
|
||||
* Support for approximately 60 different PostgreSQL types
|
||||
* Batch queries
|
||||
* Single-round trip query mode
|
||||
* Full TLS connection control
|
||||
* Binary format support for custom types (can be much faster)
|
||||
* Copy from protocol support for faster bulk data loads
|
||||
* Logging support
|
||||
* Configurable connection pool with after connect hooks to do arbitrary connection setup
|
||||
* Copy protocol support for faster bulk data loads
|
||||
* Extendable logging support including built-in support for log15 and logrus
|
||||
* Connection pool with after connect hook to do arbitrary connection setup
|
||||
* Listen / notify
|
||||
* PostgreSQL array to Go slice mapping for integers, floats, and strings
|
||||
* Hstore support
|
||||
* JSON and JSONB support
|
||||
* Maps inet and cidr PostgreSQL types to net.IPNet and net.IP
|
||||
* Large object support
|
||||
* Null mapping to Null* struct or pointer to pointer.
|
||||
* NULL mapping to Null* struct or pointer to pointer.
|
||||
* Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types
|
||||
* Logical replication connections, including receiving WAL and sending standby status updates
|
||||
* Notice response handling (this is different than listen / notify)
|
||||
|
||||
## Performance
|
||||
|
||||
Pgx performs roughly equivalent to [pq](http://godoc.org/github.com/lib/pq) and
|
||||
[go-pg](https://github.com/go-pg/pg) for selecting a single column from a single
|
||||
row, but it is substantially faster when selecting multiple entire rows (6893
|
||||
queries/sec for pgx vs. 3968 queries/sec for pq -- 73% faster).
|
||||
pgx performs roughly equivalent to [go-pg](https://github.com/go-pg/pg) and is almost always faster than [pq](http://godoc.org/github.com/lib/pq). When parsing large result sets the percentage difference can be significant (16483 queries/sec for pgx vs. 10106 queries/sec for pq -- 63% faster).
|
||||
|
||||
See this [gist](https://gist.github.com/jackc/d282f39e088b495fba3e) for the
|
||||
underlying benchmark results or checkout
|
||||
[go_db_bench](https://github.com/jackc/go_db_bench) to run tests for yourself.
|
||||
In many use cases a significant cause of latency is network round trips between the application and the server. pgx supports query batching to bundle multiple queries into a single round trip. Even in the case of a connection with the lowest possible latency, a local Unix domain socket, batching as few as three queries together can yield an improvement of 57%. With a typical network connection the results can be even more substantial.
|
||||
|
||||
## database/sql
|
||||
See this [gist](https://gist.github.com/jackc/4996e8648a0c59839bff644f49d6e434) for the underlying benchmark results or checkout [go_db_bench](https://github.com/jackc/go_db_bench) to run tests for yourself.
|
||||
|
||||
Import the ```github.com/jackc/pgx/stdlib``` package to use pgx as a driver for
|
||||
database/sql. It is possible to retrieve a pgx connection from database/sql on
|
||||
demand. This allows using the database/sql interface in most places, but using
|
||||
pgx directly when more performance or PostgreSQL specific features are needed.
|
||||
In addition to the native driver, pgx also includes a number of packages that provide additional functionality.
|
||||
|
||||
## github.com/jackc/pgxstdlib
|
||||
|
||||
database/sql compatibility layer for pgx. pgx can be used as a normal database/sql driver, but at any time the native interface may be acquired for more performance or PostgreSQL specific functionality.
|
||||
|
||||
## github.com/jackc/pgx/pgtype
|
||||
|
||||
Approximately 60 PostgreSQL types are supported including uuid, hstore, json, bytea, numeric, interval, inet, and arrays. These types support database/sql interfaces and are usable even outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver.
|
||||
|
||||
## github.com/jackc/pgx/pgproto3
|
||||
|
||||
pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling.
|
||||
|
||||
## github.com/jackc/pgx/pgmock
|
||||
|
||||
pgmock offers the ability to create a server that mocks the PostgreSQL wire protocol. This is used internally to test pgx by purposely inducing unusual errors. pgproto3 and pgmock together provide most of the foundational tooling required to implement a PostgreSQL proxy or MitM (such as for a custom connection pooler).
|
||||
|
||||
## Documentation
|
||||
|
||||
|
@ -132,6 +126,8 @@ Change the following settings in your postgresql.conf:
|
|||
max_wal_senders=5
|
||||
max_replication_slots=5
|
||||
|
||||
Set `replicationConnConfig` appropriately in `conn_config_test.go`.
|
||||
|
||||
## Version Policy
|
||||
|
||||
pgx follows semantic versioning for the documented public API on stable releases. Branch `v2` is the latest stable release. `master` can contain new features or behavior that will change or be removed before being merged to the stable `v2` branch (in practice, this occurs very rarely).
|
||||
pgx follows semantic versioning for the documented public API on stable releases. Branch `v3` is the latest stable release. `master` can contain new features or behavior that will change or be removed before being merged to the stable `v3` branch (in practice, this occurs very rarely). `v2` is the previous stable release.
|
||||
|
|
|
@ -1,126 +0,0 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEscapeAclItem(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"foo",
|
||||
"foo",
|
||||
},
|
||||
{
|
||||
`foo, "\}`,
|
||||
`foo\, \"\\\}`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
actual, err := escapeAclItem(tt.input)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected error %v", i, err)
|
||||
}
|
||||
|
||||
if actual != tt.expected {
|
||||
t.Errorf("%d.\nexpected: %s,\nactual: %s", i, tt.expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAclItemArray(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []AclItem
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
"",
|
||||
[]AclItem{},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"one",
|
||||
[]AclItem{"one"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one"`,
|
||||
[]AclItem{"one"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"one,two,three",
|
||||
[]AclItem{"one", "two", "three"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one","two","three"`,
|
||||
[]AclItem{"one", "two", "three"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one",two,"three"`,
|
||||
[]AclItem{"one", "two", "three"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`one,two,"three"`,
|
||||
[]AclItem{"one", "two", "three"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one","two",three`,
|
||||
[]AclItem{"one", "two", "three"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one","t w o",three`,
|
||||
[]AclItem{"one", "t w o", "three"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one","t, w o\"\}\\",three`,
|
||||
[]AclItem{"one", `t, w o"}\`, "three"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one","two",three"`,
|
||||
[]AclItem{"one", "two", `three"`},
|
||||
"",
|
||||
},
|
||||
{
|
||||
`"one","two,"three"`,
|
||||
nil,
|
||||
"unexpected rune after quoted value",
|
||||
},
|
||||
{
|
||||
`"one","two","three`,
|
||||
nil,
|
||||
"unexpected end of quoted value",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
actual, err := parseAclItemArray(tt.input)
|
||||
|
||||
if err != nil {
|
||||
if tt.errMsg == "" {
|
||||
t.Errorf("%d. Unexpected error %v", i, err)
|
||||
} else if err.Error() != tt.errMsg {
|
||||
t.Errorf("%d. Expected error %v did not match actual error %v", i, tt.errMsg, err.Error())
|
||||
}
|
||||
} else if tt.errMsg != "" {
|
||||
t.Errorf("%d. Expected error not returned: \"%v\"", i, tt.errMsg)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, tt.expected) {
|
||||
t.Errorf("%d. Expected %v did not match actual %v", i, tt.expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,246 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
type batchItem struct {
|
||||
query string
|
||||
arguments []interface{}
|
||||
parameterOIDs []pgtype.OID
|
||||
resultFormatCodes []int16
|
||||
}
|
||||
|
||||
// Batch queries are a way of bundling multiple queries together to avoid
|
||||
// unnecessary network round trips.
|
||||
type Batch struct {
|
||||
conn *Conn
|
||||
connPool *ConnPool
|
||||
items []*batchItem
|
||||
resultsRead int
|
||||
sent bool
|
||||
ctx context.Context
|
||||
err error
|
||||
}
|
||||
|
||||
// BeginBatch returns a *Batch query for c.
|
||||
func (c *Conn) BeginBatch() *Batch {
|
||||
return &Batch{conn: c}
|
||||
}
|
||||
|
||||
// Conn returns the underlying connection that b will or was performed on.
|
||||
func (b *Batch) Conn() *Conn {
|
||||
return b.conn
|
||||
}
|
||||
|
||||
// Queue queues a query to batch b. parameterOIDs are required if there are
|
||||
// parameters and query is not the name of a prepared statement.
|
||||
// resultFormatCodes are required if there is a result.
|
||||
func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgtype.OID, resultFormatCodes []int16) {
|
||||
b.items = append(b.items, &batchItem{
|
||||
query: query,
|
||||
arguments: arguments,
|
||||
parameterOIDs: parameterOIDs,
|
||||
resultFormatCodes: resultFormatCodes,
|
||||
})
|
||||
}
|
||||
|
||||
// Send sends all queued queries to the server at once. All queries are wrapped
|
||||
// in a transaction. The transaction can optionally be configured with
|
||||
// txOptions. The context is in effect until the Batch is closed.
|
||||
func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
|
||||
if b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
|
||||
b.ctx = ctx
|
||||
|
||||
err := b.conn.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := b.conn.ensureConnectionReadyForQuery(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = b.conn.initContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := appendQuery(b.conn.wbuf, txOptions.beginSQL())
|
||||
|
||||
for _, bi := range b.items {
|
||||
var psName string
|
||||
var psParameterOIDs []pgtype.OID
|
||||
|
||||
if ps, ok := b.conn.preparedStatements[bi.query]; ok {
|
||||
psName = ps.Name
|
||||
psParameterOIDs = ps.ParameterOIDs
|
||||
} else {
|
||||
psParameterOIDs = bi.parameterOIDs
|
||||
buf = appendParse(buf, "", bi.query, psParameterOIDs)
|
||||
}
|
||||
|
||||
var err error
|
||||
buf, err = appendBind(buf, "", psName, b.conn.ConnInfo, psParameterOIDs, bi.arguments, bi.resultFormatCodes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf = appendDescribe(buf, 'P', "")
|
||||
buf = appendExecute(buf, "", 0)
|
||||
}
|
||||
|
||||
buf = appendSync(buf)
|
||||
buf = appendQuery(buf, "commit")
|
||||
|
||||
n, err := b.conn.conn.Write(buf)
|
||||
if err != nil {
|
||||
if fatalWriteErr(n, err) {
|
||||
b.conn.die(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// expect ReadyForQuery from sync and from commit
|
||||
b.conn.pendingReadyForQueryCount = b.conn.pendingReadyForQueryCount + 2
|
||||
|
||||
b.sent = true
|
||||
|
||||
for {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
return nil
|
||||
default:
|
||||
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with Exec.
|
||||
func (b *Batch) ExecResults() (CommandTag, error) {
|
||||
if b.err != nil {
|
||||
return "", b.err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-b.ctx.Done():
|
||||
b.die(b.ctx.Err())
|
||||
return "", b.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
b.resultsRead++
|
||||
|
||||
for {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CommandComplete:
|
||||
return CommandTag(msg.CommandTag), nil
|
||||
default:
|
||||
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// QueryResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with Query.
|
||||
func (b *Batch) QueryResults() (*Rows, error) {
|
||||
rows := b.conn.getRows("batch query", nil)
|
||||
|
||||
if b.err != nil {
|
||||
rows.fatal(b.err)
|
||||
return rows, b.err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-b.ctx.Done():
|
||||
b.die(b.ctx.Err())
|
||||
rows.fatal(b.err)
|
||||
return rows, b.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
b.resultsRead++
|
||||
|
||||
fieldDescriptions, err := b.conn.readUntilRowDescription()
|
||||
if err != nil {
|
||||
b.die(err)
|
||||
rows.fatal(b.err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
rows.batch = b
|
||||
rows.fields = fieldDescriptions
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// QueryRowResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with QueryRow.
|
||||
func (b *Batch) QueryRowResults() *Row {
|
||||
rows, _ := b.QueryResults()
|
||||
return (*Row)(rows)
|
||||
|
||||
}
|
||||
|
||||
// Close closes the batch operation. Any error that occured during a batch
|
||||
// operation may have made it impossible to resyncronize the connection with the
|
||||
// server. In this case the underlying connection will have been closed.
|
||||
func (b *Batch) Close() (err error) {
|
||||
if b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = b.conn.termContext(err)
|
||||
if b.conn != nil && b.connPool != nil {
|
||||
b.connPool.Release(b.conn)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := b.resultsRead; i < len(b.items); i++ {
|
||||
if _, err = b.ExecResults(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err = b.conn.ensureConnectionReadyForQuery(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Batch) die(err error) {
|
||||
if b.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
b.err = err
|
||||
b.conn.die(err)
|
||||
|
||||
if b.conn != nil && b.connPool != nil {
|
||||
b.connPool.Release(b.conn)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,478 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
func TestConnBeginBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2)",
|
||||
[]interface{}{"q1", 1},
|
||||
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
|
||||
nil,
|
||||
)
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2)",
|
||||
[]interface{}{"q2", 2},
|
||||
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
|
||||
nil,
|
||||
)
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2)",
|
||||
[]interface{}{"q3", 3},
|
||||
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
|
||||
nil,
|
||||
)
|
||||
batch.Queue("select id, description, amount from ledger order by id",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode},
|
||||
)
|
||||
batch.Queue("select sum(amount) from ledger",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ct, err := batch.ExecResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if ct.RowsAffected() != 1 {
|
||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||
}
|
||||
|
||||
ct, err = batch.ExecResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if ct.RowsAffected() != 1 {
|
||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||
}
|
||||
|
||||
rows, err := batch.QueryResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var id int32
|
||||
var description string
|
||||
var amount int32
|
||||
if !rows.Next() {
|
||||
t.Fatal("expected a row to be available")
|
||||
}
|
||||
if err := rows.Scan(&id, &description, &amount); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if id != 1 {
|
||||
t.Errorf("id => %v, want %v", id, 1)
|
||||
}
|
||||
if description != "q1" {
|
||||
t.Errorf("description => %v, want %v", description, "q1")
|
||||
}
|
||||
if amount != 1 {
|
||||
t.Errorf("amount => %v, want %v", amount, 1)
|
||||
}
|
||||
|
||||
if !rows.Next() {
|
||||
t.Fatal("expected a row to be available")
|
||||
}
|
||||
if err := rows.Scan(&id, &description, &amount); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if id != 2 {
|
||||
t.Errorf("id => %v, want %v", id, 2)
|
||||
}
|
||||
if description != "q2" {
|
||||
t.Errorf("description => %v, want %v", description, "q2")
|
||||
}
|
||||
if amount != 2 {
|
||||
t.Errorf("amount => %v, want %v", amount, 2)
|
||||
}
|
||||
|
||||
if !rows.Next() {
|
||||
t.Fatal("expected a row to be available")
|
||||
}
|
||||
if err := rows.Scan(&id, &description, &amount); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if id != 3 {
|
||||
t.Errorf("id => %v, want %v", id, 3)
|
||||
}
|
||||
if description != "q3" {
|
||||
t.Errorf("description => %v, want %v", description, "q3")
|
||||
}
|
||||
if amount != 3 {
|
||||
t.Errorf("amount => %v, want %v", amount, 3)
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
t.Fatal("did not expect a row to be available")
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Fatal(rows.Err())
|
||||
}
|
||||
|
||||
err = batch.QueryRowResults().Scan(&amount)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if amount != 6 {
|
||||
t.Errorf("amount => %v, want %v", amount, 6)
|
||||
}
|
||||
|
||||
err = batch.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
_, err := conn.Prepare("ps1", "select n from generate_series(0,$1::int) n")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
|
||||
queryCount := 3
|
||||
for i := 0; i < queryCount; i++ {
|
||||
batch.Queue("ps1",
|
||||
[]interface{}{5},
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
}
|
||||
|
||||
err = batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for i := 0; i < queryCount; i++ {
|
||||
rows, err := batch.QueryResults()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for k := 0; rows.Next(); k++ {
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != k {
|
||||
t.Fatalf("n => %v, want %v", n, k)
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Fatal(rows.Err())
|
||||
}
|
||||
}
|
||||
|
||||
err = batch.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchContextCancelBeforeExecResults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2)",
|
||||
[]interface{}{"q1", 1},
|
||||
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
|
||||
nil,
|
||||
)
|
||||
batch.Queue("select pg_sleep(2)",
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
|
||||
err := batch.Send(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cancelFn()
|
||||
|
||||
_, err = batch.ExecResults()
|
||||
if err != context.Canceled {
|
||||
t.Errorf("err => %v, want %v", err, context.Canceled)
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
t.Error("conn should be dead, but was alive")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnBeginBatchContextCancelBeforeQueryResults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
batch.Queue("select pg_sleep(2)",
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
batch.Queue("select pg_sleep(2)",
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
|
||||
err := batch.Send(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cancelFn()
|
||||
|
||||
_, err = batch.QueryResults()
|
||||
if err != context.Canceled {
|
||||
t.Errorf("err => %v, want %v", err, context.Canceled)
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
t.Error("conn should be dead, but was alive")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
batch.Queue("select pg_sleep(2)",
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
batch.Queue("select pg_sleep(2)",
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
|
||||
err := batch.Send(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cancelFn()
|
||||
|
||||
err = batch.Close()
|
||||
if err != context.Canceled {
|
||||
t.Errorf("err => %v, want %v", err, context.Canceled)
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
t.Error("conn should be dead, but was alive")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
batch.Queue("select n from generate_series(0,5) n",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
batch.Queue("select n from generate_series(0,5) n",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := batch.QueryResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
if !rows.Next() {
|
||||
t.Error("expected a row to be available")
|
||||
}
|
||||
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if n != i {
|
||||
t.Errorf("n => %v, want %v", n, i)
|
||||
}
|
||||
}
|
||||
|
||||
rows.Close()
|
||||
|
||||
rows, err = batch.QueryResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
for i := 0; rows.Next(); i++ {
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if n != i {
|
||||
t.Errorf("n => %v, want %v", n, i)
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Error(rows.Err())
|
||||
}
|
||||
|
||||
err = batch.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchQueryError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
batch.Queue("select n from generate_series(0,5) n",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := batch.QueryResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
for i := 0; rows.Next(); i++ {
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if n != i {
|
||||
t.Errorf("n => %v, want %v", n, i)
|
||||
}
|
||||
}
|
||||
|
||||
if pgErr, ok := rows.Err().(pgx.PgError); !(ok && pgErr.Code == "22012") {
|
||||
t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012)
|
||||
}
|
||||
|
||||
err = batch.Close()
|
||||
if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "22012") {
|
||||
t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
t.Error("conn should be dead, but was alive")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
batch.Queue("select 1 1",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var n int32
|
||||
err = batch.QueryRowResults().Scan(&n)
|
||||
if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "42601") {
|
||||
t.Errorf("rows.Err() => %v, want error code %v", err, 42601)
|
||||
}
|
||||
|
||||
err = batch.Close()
|
||||
if err == nil {
|
||||
t.Error("Expected error")
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
t.Error("conn should be dead, but was alive")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkPgtypeInt4ParseBinary(b *testing.B) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
_, err := conn.Prepare("selectBinary", "select n::int4 from generate_series(1, 100) n")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var n int32
|
||||
|
||||
rows, err := conn.Query("selectBinary")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&n)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
b.Fatal(rows.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPgtypeInt4EncodeBinary(b *testing.B) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
_, err := conn.Prepare("encodeBinary", "select $1::int4, $2::int4, $3::int4, $4::int4, $5::int4, $6::int4, $7::int4")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := conn.Query("encodeBinary", int32(i), int32(i), int32(i), int32(i), int32(i), int32(i), int32(i))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
}
|
306
bench_test.go
306
bench_test.go
|
@ -2,13 +2,14 @@ package pgx_test
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
log "gopkg.in/inconshreveable/log15.v2"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
func BenchmarkConnPool(b *testing.B) {
|
||||
|
@ -50,126 +51,6 @@ func BenchmarkConnPoolQueryRow(b *testing.B) {
|
|||
}
|
||||
}
|
||||
|
||||
func BenchmarkNullXWithNullValues(b *testing.B) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
_, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', null::text, null::text, null::text, null::date, null::timestamptz")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var record struct {
|
||||
id int32
|
||||
userName string
|
||||
email pgx.NullString
|
||||
name pgx.NullString
|
||||
sex pgx.NullString
|
||||
birthDate pgx.NullTime
|
||||
lastLoginTime pgx.NullTime
|
||||
}
|
||||
|
||||
err = conn.QueryRow("selectNulls").Scan(
|
||||
&record.id,
|
||||
&record.userName,
|
||||
&record.email,
|
||||
&record.name,
|
||||
&record.sex,
|
||||
&record.birthDate,
|
||||
&record.lastLoginTime,
|
||||
)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// These checks both ensure that the correct data was returned
|
||||
// and provide a benchmark of accessing the returned values.
|
||||
if record.id != 1 {
|
||||
b.Fatalf("bad value for id: %v", record.id)
|
||||
}
|
||||
if record.userName != "johnsmith" {
|
||||
b.Fatalf("bad value for userName: %v", record.userName)
|
||||
}
|
||||
if record.email.Valid {
|
||||
b.Fatalf("bad value for email: %v", record.email)
|
||||
}
|
||||
if record.name.Valid {
|
||||
b.Fatalf("bad value for name: %v", record.name)
|
||||
}
|
||||
if record.sex.Valid {
|
||||
b.Fatalf("bad value for sex: %v", record.sex)
|
||||
}
|
||||
if record.birthDate.Valid {
|
||||
b.Fatalf("bad value for birthDate: %v", record.birthDate)
|
||||
}
|
||||
if record.lastLoginTime.Valid {
|
||||
b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNullXWithPresentValues(b *testing.B) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
_, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var record struct {
|
||||
id int32
|
||||
userName string
|
||||
email pgx.NullString
|
||||
name pgx.NullString
|
||||
sex pgx.NullString
|
||||
birthDate pgx.NullTime
|
||||
lastLoginTime pgx.NullTime
|
||||
}
|
||||
|
||||
err = conn.QueryRow("selectNulls").Scan(
|
||||
&record.id,
|
||||
&record.userName,
|
||||
&record.email,
|
||||
&record.name,
|
||||
&record.sex,
|
||||
&record.birthDate,
|
||||
&record.lastLoginTime,
|
||||
)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// These checks both ensure that the correct data was returned
|
||||
// and provide a benchmark of accessing the returned values.
|
||||
if record.id != 1 {
|
||||
b.Fatalf("bad value for id: %v", record.id)
|
||||
}
|
||||
if record.userName != "johnsmith" {
|
||||
b.Fatalf("bad value for userName: %v", record.userName)
|
||||
}
|
||||
if !record.email.Valid || record.email.String != "johnsmith@example.com" {
|
||||
b.Fatalf("bad value for email: %v", record.email)
|
||||
}
|
||||
if !record.name.Valid || record.name.String != "John Smith" {
|
||||
b.Fatalf("bad value for name: %v", record.name)
|
||||
}
|
||||
if !record.sex.Valid || record.sex.String != "male" {
|
||||
b.Fatalf("bad value for sex: %v", record.sex)
|
||||
}
|
||||
if !record.birthDate.Valid || record.birthDate.Time != time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local) {
|
||||
b.Fatalf("bad value for birthDate: %v", record.birthDate)
|
||||
}
|
||||
if !record.lastLoginTime.Valid || record.lastLoginTime.Time != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) {
|
||||
b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPointerPointerWithNullValues(b *testing.B) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
@ -297,71 +178,51 @@ func BenchmarkSelectWithoutLogging(b *testing.B) {
|
|||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithLoggingTraceWithLog15(b *testing.B) {
|
||||
connConfig := *defaultConnConfig
|
||||
type discardLogger struct{}
|
||||
|
||||
logger := log.New()
|
||||
lvl, err := log.LvlFromString("debug")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))
|
||||
connConfig.Logger = logger
|
||||
connConfig.LogLevel = pgx.LogLevelTrace
|
||||
conn := mustConnect(b, connConfig)
|
||||
func (dl discardLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {}
|
||||
|
||||
func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
var logger discardLogger
|
||||
conn.SetLogger(logger)
|
||||
conn.SetLogLevel(pgx.LogLevelTrace)
|
||||
|
||||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithLoggingDebugWithLog15(b *testing.B) {
|
||||
connConfig := *defaultConnConfig
|
||||
|
||||
logger := log.New()
|
||||
lvl, err := log.LvlFromString("debug")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))
|
||||
connConfig.Logger = logger
|
||||
connConfig.LogLevel = pgx.LogLevelDebug
|
||||
conn := mustConnect(b, connConfig)
|
||||
func BenchmarkSelectWithLoggingDebugWithDiscard(b *testing.B) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
var logger discardLogger
|
||||
conn.SetLogger(logger)
|
||||
conn.SetLogLevel(pgx.LogLevelDebug)
|
||||
|
||||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithLoggingInfoWithLog15(b *testing.B) {
|
||||
connConfig := *defaultConnConfig
|
||||
|
||||
logger := log.New()
|
||||
lvl, err := log.LvlFromString("info")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))
|
||||
connConfig.Logger = logger
|
||||
connConfig.LogLevel = pgx.LogLevelInfo
|
||||
conn := mustConnect(b, connConfig)
|
||||
func BenchmarkSelectWithLoggingInfoWithDiscard(b *testing.B) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
var logger discardLogger
|
||||
conn.SetLogger(logger)
|
||||
conn.SetLogLevel(pgx.LogLevelInfo)
|
||||
|
||||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithLoggingErrorWithLog15(b *testing.B) {
|
||||
connConfig := *defaultConnConfig
|
||||
|
||||
logger := log.New()
|
||||
lvl, err := log.LvlFromString("error")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))
|
||||
connConfig.Logger = logger
|
||||
connConfig.LogLevel = pgx.LogLevelError
|
||||
conn := mustConnect(b, connConfig)
|
||||
func BenchmarkSelectWithLoggingErrorWithDiscard(b *testing.B) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
var logger discardLogger
|
||||
conn.SetLogger(logger)
|
||||
conn.SetLogLevel(pgx.LogLevelError)
|
||||
|
||||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
|
@ -422,20 +283,6 @@ func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) {
|
|||
}
|
||||
}
|
||||
|
||||
func BenchmarkLog15Discard(b *testing.B) {
|
||||
logger := log.New()
|
||||
lvl, err := log.LvlFromString("error")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Debug("benchmark", "i", i, "b.N", b.N)
|
||||
}
|
||||
}
|
||||
|
||||
const benchmarkWriteTableCreateSQL = `drop table if exists t;
|
||||
|
||||
create table t(
|
||||
|
@ -510,12 +357,12 @@ func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource {
|
|||
row: []interface{}{
|
||||
"varchar_1",
|
||||
"varchar_2",
|
||||
pgx.NullString{},
|
||||
pgtype.Text{},
|
||||
time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local),
|
||||
pgx.NullTime{},
|
||||
pgtype.Date{},
|
||||
1,
|
||||
2,
|
||||
pgx.NullInt32{},
|
||||
pgtype.Int4{},
|
||||
time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local),
|
||||
time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local),
|
||||
true,
|
||||
|
@ -763,3 +610,92 @@ func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) {
|
|||
func BenchmarkWrite10000RowsViaCopy(b *testing.B) {
|
||||
benchmarkWriteNRowsViaCopy(b, 10000)
|
||||
}
|
||||
|
||||
func BenchmarkMultipleQueriesNonBatch(b *testing.B) {
|
||||
config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5}
|
||||
pool, err := pgx.NewConnPool(config)
|
||||
if err != nil {
|
||||
b.Fatalf("Unable to create connection pool: %v", err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
queryCount := 3
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for j := 0; j < queryCount; j++ {
|
||||
rows, err := pool.Query("select n from generate_series(0, 5) n")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for k := 0; rows.Next(); k++ {
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if n != k {
|
||||
b.Fatalf("n => %v, want %v", n, k)
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
b.Fatal(rows.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMultipleQueriesBatch(b *testing.B) {
|
||||
config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5}
|
||||
pool, err := pgx.NewConnPool(config)
|
||||
if err != nil {
|
||||
b.Fatalf("Unable to create connection pool: %v", err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
queryCount := 3
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
batch := pool.BeginBatch()
|
||||
for j := 0; j < queryCount; j++ {
|
||||
batch.Queue("select n from generate_series(0,5) n",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
}
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for j := 0; j < queryCount; j++ {
|
||||
rows, err := batch.QueryResults()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for k := 0; rows.Next(); k++ {
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if n != k {
|
||||
b.Fatalf("n => %v, want %v", n, k)
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
b.Fatal(rows.Err())
|
||||
}
|
||||
}
|
||||
|
||||
err = batch.Close()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
package chunkreader
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
type ChunkReader struct {
|
||||
r io.Reader
|
||||
|
||||
buf []byte
|
||||
rp, wp int // buf read position and write position
|
||||
|
||||
options Options
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
MinBufLen int // Minimum buffer length
|
||||
}
|
||||
|
||||
func NewChunkReader(r io.Reader) *ChunkReader {
|
||||
cr, err := NewChunkReaderEx(r, Options{})
|
||||
if err != nil {
|
||||
panic("default options can't be bad")
|
||||
}
|
||||
|
||||
return cr
|
||||
}
|
||||
|
||||
func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) {
|
||||
if options.MinBufLen == 0 {
|
||||
options.MinBufLen = 4096
|
||||
}
|
||||
|
||||
return &ChunkReader{
|
||||
r: r,
|
||||
buf: make([]byte, options.MinBufLen),
|
||||
options: options,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Next returns buf filled with the next n bytes. If an error occurs, buf will
|
||||
// be nil.
|
||||
func (r *ChunkReader) Next(n int) (buf []byte, err error) {
|
||||
// n bytes already in buf
|
||||
if (r.wp - r.rp) >= n {
|
||||
buf = r.buf[r.rp : r.rp+n]
|
||||
r.rp += n
|
||||
return buf, err
|
||||
}
|
||||
|
||||
// available space in buf is less than n
|
||||
if len(r.buf) < n {
|
||||
r.copyBufContents(r.newBuf(n))
|
||||
}
|
||||
|
||||
// buf is large enough, but need to shift filled area to start to make enough contiguous space
|
||||
minReadCount := n - (r.wp - r.rp)
|
||||
if (len(r.buf) - r.wp) < minReadCount {
|
||||
newBuf := r.newBuf(n)
|
||||
r.copyBufContents(newBuf)
|
||||
}
|
||||
|
||||
if err := r.appendAtLeast(minReadCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf = r.buf[r.rp : r.rp+n]
|
||||
r.rp += n
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (r *ChunkReader) appendAtLeast(fillLen int) error {
|
||||
n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen)
|
||||
r.wp += n
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ChunkReader) newBuf(size int) []byte {
|
||||
if size < r.options.MinBufLen {
|
||||
size = r.options.MinBufLen
|
||||
}
|
||||
return make([]byte, size)
|
||||
}
|
||||
|
||||
func (r *ChunkReader) copyBufContents(dest []byte) {
|
||||
r.wp = copy(dest, r.buf[r.rp:r.wp])
|
||||
r.rp = 0
|
||||
r.buf = dest
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
package chunkreader
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
|
||||
server := &bytes.Buffer{}
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
src := []byte{1, 2, 3, 4}
|
||||
server.Write(src)
|
||||
|
||||
n1, err := r.Next(2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n1, src[0:2]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1)
|
||||
}
|
||||
|
||||
n2, err := r.Next(2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n2, src[2:4]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2)
|
||||
}
|
||||
|
||||
if bytes.Compare(r.buf, src) != 0 {
|
||||
t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf)
|
||||
}
|
||||
if r.rp != 4 {
|
||||
t.Fatalf("Expected r.rp to be %v, but it was %v", 4, r.rp)
|
||||
}
|
||||
if r.wp != 4 {
|
||||
t.Fatalf("Expected r.wp to be %v, but it was %v", 4, r.wp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) {
|
||||
server := &bytes.Buffer{}
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
server.Write(src)
|
||||
|
||||
n1, err := r.Next(5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n1, src[0:5]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1)
|
||||
}
|
||||
if len(r.buf) != 5 {
|
||||
t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkReaderDoesNotReuseBuf(t *testing.T) {
|
||||
server := &bytes.Buffer{}
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
server.Write(src)
|
||||
|
||||
n1, err := r.Next(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n1, src[0:4]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1)
|
||||
}
|
||||
|
||||
n2, err := r.Next(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n2, src[4:8]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2)
|
||||
}
|
||||
|
||||
if bytes.Compare(n1, src[0:4]) != 0 {
|
||||
t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1)
|
||||
}
|
||||
}
|
|
@ -23,3 +23,5 @@ var replicationConnConfig *pgx.ConnConfig = nil
|
|||
// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
|
||||
// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
// var replicationConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_replication", Password: "secret", Database: "pgx_test"}
|
||||
|
||||
|
|
133
conn_pool.go
133
conn_pool.go
|
@ -1,9 +1,13 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
type ConnPoolConfig struct {
|
||||
|
@ -27,11 +31,7 @@ type ConnPool struct {
|
|||
closed bool
|
||||
preparedStatements map[string]*PreparedStatement
|
||||
acquireTimeout time.Duration
|
||||
pgTypes map[Oid]PgType
|
||||
pgsqlAfInet *byte
|
||||
pgsqlAfInet6 *byte
|
||||
txAfterClose func(tx *Tx)
|
||||
rowsAfterClose func(rows *Rows)
|
||||
connInfo *pgtype.ConnInfo
|
||||
}
|
||||
|
||||
type ConnPoolStat struct {
|
||||
|
@ -48,6 +48,7 @@ var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool")
|
|||
func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) {
|
||||
p = new(ConnPool)
|
||||
p.config = config.ConnConfig
|
||||
p.connInfo = minimalConnInfo
|
||||
p.maxConnections = config.MaxConnections
|
||||
if p.maxConnections == 0 {
|
||||
p.maxConnections = 5
|
||||
|
@ -73,14 +74,6 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) {
|
|||
p.logLevel = LogLevelNone
|
||||
}
|
||||
|
||||
p.txAfterClose = func(tx *Tx) {
|
||||
p.Release(tx.Conn())
|
||||
}
|
||||
|
||||
p.rowsAfterClose = func(rows *Rows) {
|
||||
p.Release(rows.Conn())
|
||||
}
|
||||
|
||||
p.allConnections = make([]*Conn, 0, p.maxConnections)
|
||||
p.availableConnections = make([]*Conn, 0, p.maxConnections)
|
||||
p.preparedStatements = make(map[string]*PreparedStatement)
|
||||
|
@ -94,6 +87,7 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) {
|
|||
}
|
||||
p.allConnections = append(p.allConnections, c)
|
||||
p.availableConnections = append(p.availableConnections, c)
|
||||
p.connInfo = c.ConnInfo.DeepCopy()
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -161,7 +155,7 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
|
|||
}
|
||||
// All connections are in use and we cannot create more
|
||||
if p.logLevel >= LogLevelWarn {
|
||||
p.logger.Warn("All connections in pool are busy - waiting...")
|
||||
p.logger.Log(LogLevelWarn, "waiting for available connection", nil)
|
||||
}
|
||||
|
||||
// Wait until there is an available connection OR room to create a new connection
|
||||
|
@ -181,7 +175,11 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
|
|||
|
||||
// Release gives up use of a connection.
|
||||
func (p *ConnPool) Release(conn *Conn) {
|
||||
if conn.TxStatus != 'I' {
|
||||
if conn.ctxInProgress {
|
||||
panic("should never release when context is in progress")
|
||||
}
|
||||
|
||||
if conn.txStatus != 'I' {
|
||||
conn.Exec("rollback")
|
||||
}
|
||||
|
||||
|
@ -223,25 +221,21 @@ func (p *ConnPool) removeFromAllConnections(conn *Conn) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// Close ends the use of a connection pool. It prevents any new connections
|
||||
// from being acquired, waits until all acquired connections are released,
|
||||
// then closes all underlying connections.
|
||||
// Close ends the use of a connection pool. It prevents any new connections from
|
||||
// being acquired and closes available underlying connections. Any acquired
|
||||
// connections will be closed when they are released.
|
||||
func (p *ConnPool) Close() {
|
||||
p.cond.L.Lock()
|
||||
defer p.cond.L.Unlock()
|
||||
|
||||
p.closed = true
|
||||
|
||||
// Wait until all connections are released
|
||||
if len(p.availableConnections) != len(p.allConnections) {
|
||||
for len(p.availableConnections) != len(p.allConnections) {
|
||||
p.cond.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
for _, c := range p.allConnections {
|
||||
for _, c := range p.availableConnections {
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
// This will cause any checked out connections to be closed on release
|
||||
p.resetCount++
|
||||
}
|
||||
|
||||
// Reset closes all open connections, but leaves the pool open. It is intended
|
||||
|
@ -289,7 +283,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) {
|
|||
}
|
||||
|
||||
func (p *ConnPool) createConnection() (*Conn, error) {
|
||||
c, err := connect(p.config, p.pgTypes, p.pgsqlAfInet, p.pgsqlAfInet6)
|
||||
c, err := connect(p.config, p.connInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -324,10 +318,6 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) {
|
|||
// afterConnectionCreated executes (if it is) afterConnect() callback and prepares
|
||||
// all the known statements for the new connection.
|
||||
func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
|
||||
p.pgTypes = c.PgTypes
|
||||
p.pgsqlAfInet = c.pgsqlAfInet
|
||||
p.pgsqlAfInet6 = c.pgsqlAfInet6
|
||||
|
||||
if p.afterConnect != nil {
|
||||
err := p.afterConnect(c)
|
||||
if err != nil {
|
||||
|
@ -357,6 +347,16 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman
|
|||
return c.Exec(sql, arguments...)
|
||||
}
|
||||
|
||||
func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
var c *Conn
|
||||
if c, err = p.Acquire(); err != nil {
|
||||
return
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
return c.ExecEx(ctx, sql, options, arguments...)
|
||||
}
|
||||
|
||||
// Query acquires a connection and delegates the call to that connection. When
|
||||
// *Rows are closed, the connection is released automatically.
|
||||
func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
|
||||
|
@ -372,7 +372,25 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
|
|||
return rows, err
|
||||
}
|
||||
|
||||
rows.AfterClose(p.rowsAfterClose)
|
||||
rows.connPool = p
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func (p *ConnPool) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
// Because checking for errors can be deferred to the *Rows, build one with the error
|
||||
return &Rows{closed: true, err: err}, err
|
||||
}
|
||||
|
||||
rows, err := c.QueryEx(ctx, sql, options, args...)
|
||||
if err != nil {
|
||||
p.Release(c)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
rows.connPool = p
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
@ -385,10 +403,15 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row {
|
|||
return (*Row)(rows)
|
||||
}
|
||||
|
||||
func (p *ConnPool) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row {
|
||||
rows, _ := p.QueryEx(ctx, sql, options, args...)
|
||||
return (*Row)(rows)
|
||||
}
|
||||
|
||||
// Begin acquires a connection and begins a transaction on it. When the
|
||||
// transaction is closed the connection will be automatically released.
|
||||
func (p *ConnPool) Begin() (*Tx, error) {
|
||||
return p.BeginIso("")
|
||||
return p.BeginEx(context.Background(), nil)
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement on a connection in the pool to test the
|
||||
|
@ -403,7 +426,7 @@ func (p *ConnPool) Begin() (*Tx, error) {
|
|||
// the same name and sql arguments. This allows a code path to Prepare and
|
||||
// Query/Exec/PrepareEx without concern for if the statement has already been prepared.
|
||||
func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) {
|
||||
return p.PrepareEx(name, sql, nil)
|
||||
return p.PrepareEx(context.Background(), name, sql, nil)
|
||||
}
|
||||
|
||||
// PrepareEx creates a prepared statement on a connection in the pool to test the
|
||||
|
@ -417,7 +440,7 @@ func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) {
|
|||
// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same
|
||||
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec/Prepare without
|
||||
// concern for if the statement has already been prepared.
|
||||
func (p *ConnPool) PrepareEx(name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) {
|
||||
func (p *ConnPool) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) {
|
||||
p.cond.L.Lock()
|
||||
defer p.cond.L.Unlock()
|
||||
|
||||
|
@ -439,13 +462,13 @@ func (p *ConnPool) PrepareEx(name, sql string, opts *PrepareExOptions) (*Prepare
|
|||
return ps, nil
|
||||
}
|
||||
|
||||
ps, err := c.PrepareEx(name, sql, opts)
|
||||
ps, err := c.PrepareEx(ctx, name, sql, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, c := range p.availableConnections {
|
||||
_, err := c.PrepareEx(name, sql, opts)
|
||||
_, err := c.PrepareEx(ctx, name, sql, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -474,17 +497,17 @@ func (p *ConnPool) Deallocate(name string) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
// BeginIso acquires a connection and begins a transaction in isolation mode iso
|
||||
// on it. When the transaction is closed the connection will be automatically
|
||||
// released.
|
||||
func (p *ConnPool) BeginIso(iso string) (*Tx, error) {
|
||||
// BeginEx acquires a connection and starts a transaction with txOptions
|
||||
// determining the transaction mode. When the transaction is closed the
|
||||
// connection will be automatically released.
|
||||
func (p *ConnPool) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) {
|
||||
for {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tx, err := c.BeginIso(iso)
|
||||
tx, err := c.BeginEx(ctx, txOptions)
|
||||
if err != nil {
|
||||
alive := c.IsAlive()
|
||||
p.Release(c)
|
||||
|
@ -499,25 +522,12 @@ func (p *ConnPool) BeginIso(iso string) (*Tx, error) {
|
|||
continue
|
||||
}
|
||||
|
||||
tx.AfterClose(p.txAfterClose)
|
||||
tx.connPool = p
|
||||
return tx, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated. Use CopyFrom instead. CopyTo acquires a connection, delegates the
|
||||
// call to that connection, and releases the connection.
|
||||
func (p *ConnPool) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
return c.CopyTo(tableName, columnNames, rowSrc)
|
||||
}
|
||||
|
||||
// CopyFrom acquires a connection, delegates the call to that connection, and
|
||||
// releases the connection.
|
||||
// CopyFrom acquires a connection, delegates the call to that connection, and releases the connection
|
||||
func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
|
@ -527,3 +537,10 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C
|
|||
|
||||
return c.CopyFrom(tableName, columnNames, rowSrc)
|
||||
}
|
||||
|
||||
// BeginBatch acquires a connection and begins a batch on that connection. When
|
||||
// *Batch is finished, the connection is released automatically.
|
||||
func (p *ConnPool) BeginBatch() *Batch {
|
||||
c, err := p.Acquire()
|
||||
return &Batch{conn: c, connPool: p, err: err}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
|
@ -297,7 +299,7 @@ func TestPoolWithoutAcquireTimeoutSet(t *testing.T) {
|
|||
// ... then try to consume 1 more. It should hang forever.
|
||||
// To unblock it we release the previously taken connection in a goroutine.
|
||||
stopDeadWaitTimeout := 5 * time.Second
|
||||
timer := time.AfterFunc(stopDeadWaitTimeout, func() {
|
||||
timer := time.AfterFunc(stopDeadWaitTimeout+100*time.Millisecond, func() {
|
||||
releaseAllConnections(pool, allConnections)
|
||||
})
|
||||
defer timer.Stop()
|
||||
|
@ -328,14 +330,14 @@ func TestPoolReleaseWithTransactions(t *testing.T) {
|
|||
t.Fatal("Did not receive expected error")
|
||||
}
|
||||
|
||||
if conn.TxStatus != 'E' {
|
||||
t.Fatalf("Expected TxStatus to be 'E', instead it was '%c'", conn.TxStatus)
|
||||
if conn.TxStatus() != 'E' {
|
||||
t.Fatalf("Expected TxStatus to be 'E', instead it was '%c'", conn.TxStatus())
|
||||
}
|
||||
|
||||
pool.Release(conn)
|
||||
|
||||
if conn.TxStatus != 'I' {
|
||||
t.Fatalf("Expected release to rollback errored transaction, but it did not: '%c'", conn.TxStatus)
|
||||
if conn.TxStatus() != 'I' {
|
||||
t.Fatalf("Expected release to rollback errored transaction, but it did not: '%c'", conn.TxStatus())
|
||||
}
|
||||
|
||||
conn, err = pool.Acquire()
|
||||
|
@ -343,14 +345,14 @@ func TestPoolReleaseWithTransactions(t *testing.T) {
|
|||
t.Fatalf("Unable to acquire connection: %v", err)
|
||||
}
|
||||
mustExec(t, conn, "begin")
|
||||
if conn.TxStatus != 'T' {
|
||||
t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.TxStatus)
|
||||
if conn.TxStatus() != 'T' {
|
||||
t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.TxStatus())
|
||||
}
|
||||
|
||||
pool.Release(conn)
|
||||
|
||||
if conn.TxStatus != 'I' {
|
||||
t.Fatalf("Expected release to rollback uncommitted transaction, but it did not: '%c'", conn.TxStatus)
|
||||
if conn.TxStatus() != 'I' {
|
||||
t.Fatalf("Expected release to rollback uncommitted transaction, but it did not: '%c'", conn.TxStatus())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -428,7 +430,7 @@ func TestPoolReleaseDiscardsDeadConnections(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
if _, err = c2.Exec("select pg_terminate_backend($1)", c1.Pid); err != nil {
|
||||
if _, err = c2.Exec("select pg_terminate_backend($1)", c1.PID()); err != nil {
|
||||
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
|
||||
}
|
||||
|
||||
|
@ -635,9 +637,9 @@ func TestConnPoolTransactionIso(t *testing.T) {
|
|||
pool := createConnPool(t, 2)
|
||||
defer pool.Close()
|
||||
|
||||
tx, err := pool.BeginIso(pgx.Serializable)
|
||||
tx, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
|
||||
if err != nil {
|
||||
t.Fatalf("pool.Begin failed: %v", err)
|
||||
t.Fatalf("pool.BeginEx failed: %v", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
|
@ -674,7 +676,7 @@ func TestConnPoolBeginRetry(t *testing.T) {
|
|||
pool.Release(victimConn)
|
||||
|
||||
// Terminate connection that was released to pool
|
||||
if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.Pid); err != nil {
|
||||
if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.PID()); err != nil {
|
||||
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
|
||||
}
|
||||
|
||||
|
@ -686,13 +688,13 @@ func TestConnPoolBeginRetry(t *testing.T) {
|
|||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
var txPid int32
|
||||
err = tx.QueryRow("select pg_backend_pid()").Scan(&txPid)
|
||||
var txPID uint32
|
||||
err = tx.QueryRow("select pg_backend_pid()").Scan(&txPID)
|
||||
if err != nil {
|
||||
t.Fatalf("tx.QueryRow Scan failed: %v", err)
|
||||
}
|
||||
if txPid == victimConn.Pid {
|
||||
t.Error("Expected txPid to defer from killed conn pid, but it didn't")
|
||||
if txPID == victimConn.PID() {
|
||||
t.Error("Expected txPID to defer from killed conn pid, but it didn't")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
@ -980,3 +982,70 @@ func TestConnPoolPrepareWhenConnIsAlreadyAcquired(t *testing.T) {
|
|||
t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnPoolBeginBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pool := createConnPool(t, 2)
|
||||
defer pool.Close()
|
||||
|
||||
batch := pool.BeginBatch()
|
||||
batch.Queue("select n from generate_series(0,5) n",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
batch.Queue("select n from generate_series(0,5) n",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := batch.QueryResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
for i := 0; rows.Next(); i++ {
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if n != i {
|
||||
t.Errorf("n => %v, want %v", n, i)
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Error(rows.Err())
|
||||
}
|
||||
|
||||
rows, err = batch.QueryResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
for i := 0; rows.Next(); i++ {
|
||||
var n int
|
||||
if err := rows.Scan(&n); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if n != i {
|
||||
t.Errorf("n => %v, want %v", n, i)
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Error(rows.Err())
|
||||
}
|
||||
|
||||
err = batch.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
255
conn_test.go
255
conn_test.go
|
@ -1,6 +1,7 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -13,6 +14,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
|
@ -27,14 +29,10 @@ func TestConnect(t *testing.T) {
|
|||
t.Error("Runtime parameters not stored")
|
||||
}
|
||||
|
||||
if conn.Pid == 0 {
|
||||
if conn.PID() == 0 {
|
||||
t.Error("Backend PID not stored")
|
||||
}
|
||||
|
||||
if conn.SecretKey == 0 {
|
||||
t.Error("Backend secret key not stored")
|
||||
}
|
||||
|
||||
var currentDB string
|
||||
err = conn.QueryRow("select current_database()").Scan(¤tDB)
|
||||
if err != nil {
|
||||
|
@ -752,7 +750,7 @@ func TestParseConnectionString(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestParseEnvLibpq(t *testing.T) {
|
||||
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME"}
|
||||
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE"}
|
||||
|
||||
savedEnv := make(map[string]string)
|
||||
for _, n := range pgEnvvars {
|
||||
|
@ -1035,6 +1033,169 @@ func TestExecFailure(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExecExContextWithoutCancelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(id integer primary key);", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecExContextFailureWithoutCancelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil {
|
||||
t.Fatal("Expected SQL syntax error")
|
||||
}
|
||||
|
||||
rows, _ := conn.Query("select 1")
|
||||
rows.Close()
|
||||
if rows.Err() != nil {
|
||||
t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecExContextCancelationCancelsQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancelFunc()
|
||||
}()
|
||||
|
||||
_, err := conn.ExecEx(ctx, "select pg_sleep(60)", nil)
|
||||
if err != context.Canceled {
|
||||
t.Fatalf("Expected context.Canceled err, got %v", err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestExecExExtendedProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
|
||||
commandTag, err = conn.ExecEx(
|
||||
ctx,
|
||||
"insert into foo(name) values($1);",
|
||||
nil,
|
||||
"bar",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "INSERT 0 1" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestExecExSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
|
||||
commandTag, err = conn.ExecEx(
|
||||
ctx,
|
||||
"insert into foo(name) values($1);",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
"bar'; drop table foo;--",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "INSERT 0 1" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
|
||||
|
||||
commandTag, err := conn.ExecEx(
|
||||
context.Background(),
|
||||
"insert into foo(name) values($1);",
|
||||
&pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.VarcharOID}},
|
||||
"bar'; drop table foo;--",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "INSERT 0 1" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
|
||||
|
||||
_, err := conn.ExecEx(
|
||||
context.Background(),
|
||||
"insert into foo(name) values($1);",
|
||||
&pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.Int4OID}},
|
||||
"bar'; drop table foo;--",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected error but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -1206,7 +1367,7 @@ func TestPrepareEx(t *testing.T) {
|
|||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
_, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgx.Oid{pgx.TextOid}})
|
||||
_, err := conn.PrepareEx(context.Background(), "test", "select $1", &pgx.PrepareExOptions{ParameterOIDs: []pgtype.OID{pgtype.TextOID}})
|
||||
if err != nil {
|
||||
t.Errorf("Unable to prepare statement: %v", err)
|
||||
return
|
||||
|
@ -1244,7 +1405,7 @@ func TestListenNotify(t *testing.T) {
|
|||
mustExec(t, notifier, "notify chat")
|
||||
|
||||
// when notification is waiting on the socket to be read
|
||||
notification, err := listener.WaitForNotification(time.Second)
|
||||
notification, err := listener.WaitForNotification(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
|
@ -1259,7 +1420,10 @@ func TestListenNotify(t *testing.T) {
|
|||
if rows.Err() != nil {
|
||||
t.Fatalf("Unexpected error on Query: %v", rows.Err())
|
||||
}
|
||||
notification, err = listener.WaitForNotification(0)
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
cancelFn()
|
||||
notification, err = listener.WaitForNotification(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
|
@ -1268,8 +1432,9 @@ func TestListenNotify(t *testing.T) {
|
|||
}
|
||||
|
||||
// when timeout occurs
|
||||
notification, err = listener.WaitForNotification(time.Millisecond)
|
||||
if err != pgx.ErrNotificationTimeout {
|
||||
ctx, _ = context.WithTimeout(context.Background(), time.Millisecond)
|
||||
notification, err = listener.WaitForNotification(ctx)
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
|
||||
}
|
||||
if notification != nil {
|
||||
|
@ -1278,7 +1443,7 @@ func TestListenNotify(t *testing.T) {
|
|||
|
||||
// listener can listen again after a timeout
|
||||
mustExec(t, notifier, "notify chat")
|
||||
notification, err = listener.WaitForNotification(time.Second)
|
||||
notification, err = listener.WaitForNotification(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
|
@ -1303,7 +1468,7 @@ func TestUnlistenSpecificChannel(t *testing.T) {
|
|||
mustExec(t, notifier, "notify unlisten_test")
|
||||
|
||||
// when notification is waiting on the socket to be read
|
||||
notification, err := listener.WaitForNotification(time.Second)
|
||||
notification, err := listener.WaitForNotification(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
|
@ -1323,8 +1488,10 @@ func TestUnlistenSpecificChannel(t *testing.T) {
|
|||
if rows.Err() != nil {
|
||||
t.Fatalf("Unexpected error on Query: %v", rows.Err())
|
||||
}
|
||||
notification, err = listener.WaitForNotification(100 * time.Millisecond)
|
||||
if err != pgx.ErrNotificationTimeout {
|
||||
|
||||
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
notification, err = listener.WaitForNotification(ctx)
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -1376,13 +1543,9 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
notifierDone := make(chan bool)
|
||||
go func() {
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
defer func() {
|
||||
notifierDone <- true
|
||||
}()
|
||||
|
||||
for i := 0; i < 100000; i++ {
|
||||
mustExec(t, conn, "notify busysafe, 'hello'")
|
||||
|
@ -1406,7 +1569,8 @@ func TestListenNotifySelfNotification(t *testing.T) {
|
|||
// Notify self and WaitForNotification immediately
|
||||
mustExec(t, conn, "notify self")
|
||||
|
||||
notification, err := conn.WaitForNotification(time.Second)
|
||||
ctx, _ := context.WithTimeout(context.Background(), time.Second)
|
||||
notification, err := conn.WaitForNotification(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
|
@ -1423,7 +1587,8 @@ func TestListenNotifySelfNotification(t *testing.T) {
|
|||
t.Fatalf("Unexpected error on Query: %v", rows.Err())
|
||||
}
|
||||
|
||||
notification, err = conn.WaitForNotification(time.Second)
|
||||
ctx, _ = context.WithTimeout(context.Background(), time.Second)
|
||||
notification, err = conn.WaitForNotification(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
|
@ -1474,7 +1639,7 @@ func TestFatalRxError(t *testing.T) {
|
|||
}
|
||||
defer otherConn.Close()
|
||||
|
||||
if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.Pid); err != nil {
|
||||
if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.PID()); err != nil {
|
||||
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
|
||||
}
|
||||
|
||||
|
@ -1500,7 +1665,7 @@ func TestFatalTxError(t *testing.T) {
|
|||
}
|
||||
defer otherConn.Close()
|
||||
|
||||
_, err = otherConn.Exec("select pg_terminate_backend($1)", conn.Pid)
|
||||
_, err = otherConn.Exec("select pg_terminate_backend($1)", conn.PID())
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
|
||||
}
|
||||
|
@ -1611,26 +1776,17 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) {
|
|||
}
|
||||
|
||||
type testLog struct {
|
||||
lvl int
|
||||
lvl pgx.LogLevel
|
||||
msg string
|
||||
ctx []interface{}
|
||||
data map[string]interface{}
|
||||
}
|
||||
|
||||
type testLogger struct {
|
||||
logs []testLog
|
||||
}
|
||||
|
||||
func (l *testLogger) Debug(msg string, ctx ...interface{}) {
|
||||
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelDebug, msg: msg, ctx: ctx})
|
||||
}
|
||||
func (l *testLogger) Info(msg string, ctx ...interface{}) {
|
||||
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelInfo, msg: msg, ctx: ctx})
|
||||
}
|
||||
func (l *testLogger) Warn(msg string, ctx ...interface{}) {
|
||||
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelWarn, msg: msg, ctx: ctx})
|
||||
}
|
||||
func (l *testLogger) Error(msg string, ctx ...interface{}) {
|
||||
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelError, msg: msg, ctx: ctx})
|
||||
func (l *testLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data})
|
||||
}
|
||||
|
||||
func TestSetLogger(t *testing.T) {
|
||||
|
@ -1742,3 +1898,30 @@ func TestIdentifierSanitize(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnOnNotice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var msg string
|
||||
|
||||
connConfig := *defaultConnConfig
|
||||
connConfig.OnNotice = func(c *pgx.Conn, notice *pgx.Notice) {
|
||||
msg = notice.Message
|
||||
}
|
||||
conn := mustConnect(t, connConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
_, err := conn.Exec(`do $$
|
||||
begin
|
||||
raise notice 'hello, world';
|
||||
end$$;`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if msg != "hello, world" {
|
||||
t.Errorf("msg => %v, want %v", msg, "hello, world")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
|
90
copy_from.go
90
copy_from.go
|
@ -3,6 +3,10 @@ package pgx
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
|
||||
|
@ -54,25 +58,25 @@ type copyFrom struct {
|
|||
|
||||
func (ct *copyFrom) readUntilReadyForQuery() {
|
||||
for {
|
||||
t, r, err := ct.conn.rxMsg()
|
||||
msg, err := ct.conn.rxMsg()
|
||||
if err != nil {
|
||||
ct.readerErrChan <- err
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
}
|
||||
|
||||
switch t {
|
||||
case readyForQuery:
|
||||
ct.conn.rxReadyForQuery(r)
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
ct.conn.rxReadyForQuery(msg)
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
case commandComplete:
|
||||
case errorResponse:
|
||||
ct.readerErrChan <- ct.conn.rxErrorResponse(r)
|
||||
case *pgproto3.CommandComplete:
|
||||
case *pgproto3.ErrorResponse:
|
||||
ct.readerErrChan <- ct.conn.rxErrorResponse(msg)
|
||||
default:
|
||||
err = ct.conn.processContextFreeMsg(t, r)
|
||||
err = ct.conn.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r)
|
||||
ct.readerErrChan <- ct.conn.processContextFreeMsg(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -87,14 +91,14 @@ func (ct *copyFrom) waitForReaderDone() error {
|
|||
|
||||
func (ct *copyFrom) run() (int, error) {
|
||||
quotedTableName := ct.tableName.Sanitize()
|
||||
buf := &bytes.Buffer{}
|
||||
cbuf := &bytes.Buffer{}
|
||||
for i, cn := range ct.columnNames {
|
||||
if i != 0 {
|
||||
buf.WriteString(", ")
|
||||
cbuf.WriteString(", ")
|
||||
}
|
||||
buf.WriteString(quoteIdentifier(cn))
|
||||
cbuf.WriteString(quoteIdentifier(cn))
|
||||
}
|
||||
quotedColumnNames := buf.String()
|
||||
quotedColumnNames := cbuf.String()
|
||||
|
||||
ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
|
||||
if err != nil {
|
||||
|
@ -114,11 +118,14 @@ func (ct *copyFrom) run() (int, error) {
|
|||
go ct.readUntilReadyForQuery()
|
||||
defer ct.waitForReaderDone()
|
||||
|
||||
wbuf := newWriteBuf(ct.conn, copyData)
|
||||
buf := ct.conn.wbuf
|
||||
buf = append(buf, copyData)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000"))
|
||||
wbuf.WriteInt32(0)
|
||||
wbuf.WriteInt32(0)
|
||||
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
|
||||
var sentCount int
|
||||
|
||||
|
@ -129,18 +136,16 @@ func (ct *copyFrom) run() (int, error) {
|
|||
default:
|
||||
}
|
||||
|
||||
if len(wbuf.buf) > 65536 {
|
||||
wbuf.closeMsg()
|
||||
_, err = ct.conn.conn.Write(wbuf.buf)
|
||||
if len(buf) > 65536 {
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
_, err = ct.conn.conn.Write(buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Directly manipulate wbuf to reset to reuse the same buffer
|
||||
wbuf.buf = wbuf.buf[0:5]
|
||||
wbuf.buf[0] = copyData
|
||||
wbuf.sizeIdx = 1
|
||||
buf = buf[0:5]
|
||||
}
|
||||
|
||||
sentCount++
|
||||
|
@ -152,12 +157,12 @@ func (ct *copyFrom) run() (int, error) {
|
|||
}
|
||||
if len(values) != len(ct.columnNames) {
|
||||
ct.cancelCopyIn()
|
||||
return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
||||
return 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
||||
}
|
||||
|
||||
wbuf.WriteInt16(int16(len(ct.columnNames)))
|
||||
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
||||
for i, val := range values {
|
||||
err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val)
|
||||
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
|
||||
if err != nil {
|
||||
ct.cancelCopyIn()
|
||||
return 0, err
|
||||
|
@ -171,11 +176,13 @@ func (ct *copyFrom) run() (int, error) {
|
|||
return 0, ct.rowSrc.Err()
|
||||
}
|
||||
|
||||
wbuf.WriteInt16(-1) // terminate the copy stream
|
||||
buf = pgio.AppendInt16(buf, -1) // terminate the copy stream
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
wbuf.startMsg(copyDone)
|
||||
wbuf.closeMsg()
|
||||
_, err = ct.conn.conn.Write(wbuf.buf)
|
||||
buf = append(buf, copyDone)
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
_, err = ct.conn.conn.Write(buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return 0, err
|
||||
|
@ -190,18 +197,16 @@ func (ct *copyFrom) run() (int, error) {
|
|||
|
||||
func (c *Conn) readUntilCopyInResponse() error {
|
||||
for {
|
||||
var t byte
|
||||
var r *msgReader
|
||||
t, r, err := c.rxMsg()
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case copyInResponse:
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CopyInResponse:
|
||||
return nil
|
||||
default:
|
||||
err = c.processContextFreeMsg(t, r)
|
||||
err = c.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -210,10 +215,15 @@ func (c *Conn) readUntilCopyInResponse() error {
|
|||
}
|
||||
|
||||
func (ct *copyFrom) cancelCopyIn() error {
|
||||
wbuf := newWriteBuf(ct.conn, copyFail)
|
||||
wbuf.WriteCString("client error: abort")
|
||||
wbuf.closeMsg()
|
||||
_, err := ct.conn.conn.Write(wbuf.buf)
|
||||
buf := ct.conn.wbuf
|
||||
buf = append(buf, copyFail)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, "client error: abort"...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
_, err := ct.conn.conn.Write(buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return err
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func TestConnCopyFromSmall(t *testing.T) {
|
||||
|
@ -26,7 +26,7 @@ func TestConnCopyFromSmall(t *testing.T) {
|
|||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)},
|
||||
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)},
|
||||
{nil, nil, nil, nil, nil, nil, nil},
|
||||
}
|
||||
|
||||
|
@ -83,7 +83,7 @@ func TestConnCopyFromLarge(t *testing.T) {
|
|||
inputRows := [][]interface{}{}
|
||||
|
||||
for i := 0; i < 10000; i++ {
|
||||
inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}})
|
||||
inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}})
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
|
||||
|
@ -125,8 +125,8 @@ func TestConnCopyFromJSON(t *testing.T) {
|
|||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} {
|
||||
if _, ok := conn.PgTypes[oid]; !ok {
|
||||
for _, typeName := range []string{"json", "jsonb"} {
|
||||
if _, ok := conn.ConnInfo.DataTypeForName(typeName); !ok {
|
||||
return // No JSON/JSONB type -- must be running against old PostgreSQL
|
||||
}
|
||||
}
|
||||
|
@ -174,6 +174,28 @@ func TestConnCopyFromJSON(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
type clientFailSource struct {
|
||||
count int
|
||||
err error
|
||||
}
|
||||
|
||||
func (cfs *clientFailSource) Next() bool {
|
||||
cfs.count++
|
||||
return cfs.count < 100
|
||||
}
|
||||
|
||||
func (cfs *clientFailSource) Values() ([]interface{}, error) {
|
||||
if cfs.count == 3 {
|
||||
cfs.err = errors.Errorf("client error")
|
||||
return nil, cfs.err
|
||||
}
|
||||
return []interface{}{make([]byte, 100000)}, nil
|
||||
}
|
||||
|
||||
func (cfs *clientFailSource) Err() error {
|
||||
return cfs.err
|
||||
}
|
||||
|
||||
func TestConnCopyFromFailServerSideMidway(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -302,28 +324,6 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
type clientFailSource struct {
|
||||
count int
|
||||
err error
|
||||
}
|
||||
|
||||
func (cfs *clientFailSource) Next() bool {
|
||||
cfs.count++
|
||||
return cfs.count < 100
|
||||
}
|
||||
|
||||
func (cfs *clientFailSource) Values() ([]interface{}, error) {
|
||||
if cfs.count == 3 {
|
||||
cfs.err = fmt.Errorf("client error")
|
||||
return nil, cfs.err
|
||||
}
|
||||
return []interface{}{make([]byte, 100000)}, nil
|
||||
}
|
||||
|
||||
func (cfs *clientFailSource) Err() error {
|
||||
return cfs.err
|
||||
}
|
||||
|
||||
func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -381,7 +381,7 @@ func (cfs *clientFinalErrSource) Values() ([]interface{}, error) {
|
|||
}
|
||||
|
||||
func (cfs *clientFinalErrSource) Err() error {
|
||||
return fmt.Errorf("final error")
|
||||
return errors.Errorf("final error")
|
||||
}
|
||||
|
||||
func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
|
||||
|
|
222
copy_to.go
222
copy_to.go
|
@ -1,222 +0,0 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Deprecated. Use CopyFromRows instead. CopyToRows returns a CopyToSource
|
||||
// interface over the provided rows slice making it usable by *Conn.CopyTo.
|
||||
func CopyToRows(rows [][]interface{}) CopyToSource {
|
||||
return ©ToRows{rows: rows, idx: -1}
|
||||
}
|
||||
|
||||
type copyToRows struct {
|
||||
rows [][]interface{}
|
||||
idx int
|
||||
}
|
||||
|
||||
func (ctr *copyToRows) Next() bool {
|
||||
ctr.idx++
|
||||
return ctr.idx < len(ctr.rows)
|
||||
}
|
||||
|
||||
func (ctr *copyToRows) Values() ([]interface{}, error) {
|
||||
return ctr.rows[ctr.idx], nil
|
||||
}
|
||||
|
||||
func (ctr *copyToRows) Err() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deprecated. Use CopyFromSource instead. CopyToSource is the interface used by
|
||||
// *Conn.CopyTo as the source for copy data.
|
||||
type CopyToSource interface {
|
||||
// Next returns true if there is another row and makes the next row data
|
||||
// available to Values(). When there are no more rows available or an error
|
||||
// has occurred it returns false.
|
||||
Next() bool
|
||||
|
||||
// Values returns the values for the current row.
|
||||
Values() ([]interface{}, error)
|
||||
|
||||
// Err returns any error that has been encountered by the CopyToSource. If
|
||||
// this is not nil *Conn.CopyTo will abort the copy.
|
||||
Err() error
|
||||
}
|
||||
|
||||
type copyTo struct {
|
||||
conn *Conn
|
||||
tableName string
|
||||
columnNames []string
|
||||
rowSrc CopyToSource
|
||||
readerErrChan chan error
|
||||
}
|
||||
|
||||
func (ct *copyTo) readUntilReadyForQuery() {
|
||||
for {
|
||||
t, r, err := ct.conn.rxMsg()
|
||||
if err != nil {
|
||||
ct.readerErrChan <- err
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
}
|
||||
|
||||
switch t {
|
||||
case readyForQuery:
|
||||
ct.conn.rxReadyForQuery(r)
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
case commandComplete:
|
||||
case errorResponse:
|
||||
ct.readerErrChan <- ct.conn.rxErrorResponse(r)
|
||||
default:
|
||||
err = ct.conn.processContextFreeMsg(t, r)
|
||||
if err != nil {
|
||||
ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ct *copyTo) waitForReaderDone() error {
|
||||
var err error
|
||||
for err = range ct.readerErrChan {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (ct *copyTo) run() (int, error) {
|
||||
quotedTableName := quoteIdentifier(ct.tableName)
|
||||
buf := &bytes.Buffer{}
|
||||
for i, cn := range ct.columnNames {
|
||||
if i != 0 {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
buf.WriteString(quoteIdentifier(cn))
|
||||
}
|
||||
quotedColumnNames := buf.String()
|
||||
|
||||
ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = ct.conn.readUntilCopyInResponse()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
go ct.readUntilReadyForQuery()
|
||||
defer ct.waitForReaderDone()
|
||||
|
||||
wbuf := newWriteBuf(ct.conn, copyData)
|
||||
|
||||
wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000"))
|
||||
wbuf.WriteInt32(0)
|
||||
wbuf.WriteInt32(0)
|
||||
|
||||
var sentCount int
|
||||
|
||||
for ct.rowSrc.Next() {
|
||||
select {
|
||||
case err = <-ct.readerErrChan:
|
||||
return 0, err
|
||||
default:
|
||||
}
|
||||
|
||||
if len(wbuf.buf) > 65536 {
|
||||
wbuf.closeMsg()
|
||||
_, err = ct.conn.conn.Write(wbuf.buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Directly manipulate wbuf to reset to reuse the same buffer
|
||||
wbuf.buf = wbuf.buf[0:5]
|
||||
wbuf.buf[0] = copyData
|
||||
wbuf.sizeIdx = 1
|
||||
}
|
||||
|
||||
sentCount++
|
||||
|
||||
values, err := ct.rowSrc.Values()
|
||||
if err != nil {
|
||||
ct.cancelCopyIn()
|
||||
return 0, err
|
||||
}
|
||||
if len(values) != len(ct.columnNames) {
|
||||
ct.cancelCopyIn()
|
||||
return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
||||
}
|
||||
|
||||
wbuf.WriteInt16(int16(len(ct.columnNames)))
|
||||
for i, val := range values {
|
||||
err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val)
|
||||
if err != nil {
|
||||
ct.cancelCopyIn()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if ct.rowSrc.Err() != nil {
|
||||
ct.cancelCopyIn()
|
||||
return 0, ct.rowSrc.Err()
|
||||
}
|
||||
|
||||
wbuf.WriteInt16(-1) // terminate the copy stream
|
||||
|
||||
wbuf.startMsg(copyDone)
|
||||
wbuf.closeMsg()
|
||||
_, err = ct.conn.conn.Write(wbuf.buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = ct.waitForReaderDone()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return sentCount, nil
|
||||
}
|
||||
|
||||
func (ct *copyTo) cancelCopyIn() error {
|
||||
wbuf := newWriteBuf(ct.conn, copyFail)
|
||||
wbuf.WriteCString("client error: abort")
|
||||
wbuf.closeMsg()
|
||||
_, err := ct.conn.conn.Write(wbuf.buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deprecated. Use CopyFrom instead. CopyTo uses the PostgreSQL copy protocol to
|
||||
// perform bulk data insertion. It returns the number of rows copied and an
|
||||
// error.
|
||||
//
|
||||
// CopyTo requires all values use the binary format. Almost all types
|
||||
// implemented by pgx use the binary format by default. Types implementing
|
||||
// Encoder can only be used if they encode to the binary format.
|
||||
func (c *Conn) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) {
|
||||
ct := ©To{
|
||||
conn: c,
|
||||
tableName: tableName,
|
||||
columnNames: columnNames,
|
||||
rowSrc: rowSrc,
|
||||
readerErrChan: make(chan error),
|
||||
}
|
||||
|
||||
return ct.run()
|
||||
}
|
367
copy_to_test.go
367
copy_to_test.go
|
@ -1,367 +0,0 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
func TestConnCopyToSmall(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int2,
|
||||
b int4,
|
||||
c int8,
|
||||
d varchar,
|
||||
e text,
|
||||
f date,
|
||||
g timestamptz
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)},
|
||||
{nil, nil, nil, nil, nil, nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyToRows(inputRows))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyTo: %v", err)
|
||||
}
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToLarge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int2,
|
||||
b int4,
|
||||
c int8,
|
||||
d varchar,
|
||||
e text,
|
||||
f date,
|
||||
g timestamptz,
|
||||
h bytea
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{}
|
||||
|
||||
for i := 0; i < 10000; i++ {
|
||||
inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}})
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyToRows(inputRows))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyTo: %v", err)
|
||||
}
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} {
|
||||
if _, ok := conn.PgTypes[oid]; !ok {
|
||||
return // No JSON/JSONB type -- must be running against old PostgreSQL
|
||||
}
|
||||
}
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a json,
|
||||
b jsonb
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}},
|
||||
{nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyTo: %v", err)
|
||||
}
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToFailServerSideMidway(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int4,
|
||||
b varchar not null
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{int32(1), "abc"},
|
||||
{int32(2), nil}, // this row should trigger a failure
|
||||
{int32(3), "def"},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows))
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyTo return error, but it did not")
|
||||
}
|
||||
if _, ok := err.(pgx.PgError); !ok {
|
||||
t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err)
|
||||
}
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if len(outputRows) != 0 {
|
||||
t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a bytea not null
|
||||
)`)
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a"}, &failSource{})
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyTo return error, but it did not")
|
||||
}
|
||||
if _, ok := err.(pgx.PgError); !ok {
|
||||
t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err)
|
||||
}
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
endTime := time.Now()
|
||||
copyTime := endTime.Sub(startTime)
|
||||
if copyTime > time.Second {
|
||||
t.Errorf("Failing CopyTo shouldn't have taken so long: %v", copyTime)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if len(outputRows) != 0 {
|
||||
t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a bytea not null
|
||||
)`)
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFailSource{})
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyTo return error, but it did not")
|
||||
}
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if len(outputRows) != 0 {
|
||||
t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a bytea not null
|
||||
)`)
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFinalErrSource{})
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyTo return error, but it did not")
|
||||
}
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if len(outputRows) != 0 {
|
||||
t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
76
doc.go
76
doc.go
|
@ -62,17 +62,15 @@ Use Exec to execute a query that does not return a result set.
|
|||
|
||||
Connection Pool
|
||||
|
||||
Connection pool usage is explicit and configurable. In pgx, a connection can
|
||||
be created and managed directly, or a connection pool with a configurable
|
||||
maximum connections can be used. Also, the connection pool offers an after
|
||||
connect hook that allows every connection to be automatically setup before
|
||||
being made available in the connection pool. This is especially useful to
|
||||
ensure all connections have the same prepared statements available or to
|
||||
change any other connection settings.
|
||||
Connection pool usage is explicit and configurable. In pgx, a connection can be
|
||||
created and managed directly, or a connection pool with a configurable maximum
|
||||
connections can be used. The connection pool offers an after connect hook that
|
||||
allows every connection to be automatically setup before being made available in
|
||||
the connection pool.
|
||||
|
||||
It delegates Query, QueryRow, Exec, and Begin functions to an automatically
|
||||
checked out and released connection so you can avoid manually acquiring and
|
||||
releasing connections when you do not need that level of control.
|
||||
It delegates methods such as QueryRow to an automatically checked out and
|
||||
released connection so you can avoid manually acquiring and releasing
|
||||
connections when you do not need that level of control.
|
||||
|
||||
var name string
|
||||
var weight int64
|
||||
|
@ -117,11 +115,11 @@ particular:
|
|||
|
||||
Null Mapping
|
||||
|
||||
pgx can map nulls in two ways. The first is Null* types that have a data field
|
||||
and a valid field. They work in a similar fashion to database/sql. The second
|
||||
is to use a pointer to a pointer.
|
||||
pgx can map nulls in two ways. The first is package pgtype provides types that
|
||||
have a data field and a status field. They work in a similar fashion to
|
||||
database/sql. The second is to use a pointer to a pointer.
|
||||
|
||||
var foo pgx.NullString
|
||||
var foo pgtype.Varchar
|
||||
var bar *string
|
||||
err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&a, &b)
|
||||
if err != nil {
|
||||
|
@ -133,20 +131,15 @@ Array Mapping
|
|||
pgx maps between int16, int32, int64, float32, float64, and string Go slices
|
||||
and the equivalent PostgreSQL array type. Go slices of native types do not
|
||||
support nulls, so if a PostgreSQL array that contains a null is read into a
|
||||
native Go slice an error will occur.
|
||||
|
||||
Hstore Mapping
|
||||
|
||||
pgx includes an Hstore type and a NullHstore type. Hstore is simply a
|
||||
map[string]string and is preferred when the hstore contains no nulls. NullHstore
|
||||
follows the Null* pattern and supports null values.
|
||||
native Go slice an error will occur. The pgtype package includes many more
|
||||
array types for PostgreSQL types that do not directly map to native Go types.
|
||||
|
||||
JSON and JSONB Mapping
|
||||
|
||||
pgx includes built-in support to marshal and unmarshal between Go types and
|
||||
the PostgreSQL JSON and JSONB.
|
||||
|
||||
Inet and Cidr Mapping
|
||||
Inet and CIDR Mapping
|
||||
|
||||
pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In
|
||||
addition, as a convenience pgx will encode from a net.IP; it will assume a /32
|
||||
|
@ -155,25 +148,10 @@ netmask for IPv4 and a /128 for IPv6.
|
|||
Custom Type Support
|
||||
|
||||
pgx includes support for the common data types like integers, floats, strings,
|
||||
dates, and times that have direct mappings between Go and SQL. Support can be
|
||||
added for additional types like point, hstore, numeric, etc. that do not have
|
||||
direct mappings in Go by the types implementing ScannerPgx and Encoder.
|
||||
|
||||
Custom types can support text or binary formats. Binary format can provide a
|
||||
large performance increase. The natural place for deciding the format for a
|
||||
value would be in ScannerPgx as it is responsible for decoding the returned
|
||||
data. However, that is impossible as the query has already been sent by the time
|
||||
the ScannerPgx is invoked. The solution to this is the global
|
||||
DefaultTypeFormats. If a custom type prefers binary format it should register it
|
||||
there.
|
||||
|
||||
pgx.DefaultTypeFormats["point"] = pgx.BinaryFormatCode
|
||||
|
||||
Note that the type is referred to by name, not by OID. This is because custom
|
||||
PostgreSQL types like hstore will have different OIDs on different servers. When
|
||||
pgx establishes a connection it queries the pg_type table for all types. It then
|
||||
matches the names in DefaultTypeFormats with the returned OIDs and stores it in
|
||||
Conn.PgTypes.
|
||||
dates, and times that have direct mappings between Go and SQL. In addition,
|
||||
pgx uses the github.com/jackc/pgx/pgtype library to support more types. See
|
||||
documention for that library for instructions on how to implement custom
|
||||
types.
|
||||
|
||||
See example_custom_type_test.go for an example of a custom type for the
|
||||
PostgreSQL point type.
|
||||
|
@ -184,15 +162,12 @@ and database/sql/driver.Valuer interfaces.
|
|||
Raw Bytes Mapping
|
||||
|
||||
[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified
|
||||
to PostgreSQL. In like manner, a *[]byte passed to Scan will be filled with
|
||||
the raw bytes returned by PostgreSQL. This can be especially useful for reading
|
||||
varchar, text, json, and jsonb values directly into a []byte and avoiding the
|
||||
type conversion from string.
|
||||
to PostgreSQL.
|
||||
|
||||
Transactions
|
||||
|
||||
Transactions are started by calling Begin or BeginIso. The BeginIso variant
|
||||
creates a transaction with a specified isolation level.
|
||||
Transactions are started by calling Begin or BeginEx. The BeginEx variant
|
||||
can create a transaction with a specified isolation level.
|
||||
|
||||
tx, err := conn.Begin()
|
||||
if err != nil {
|
||||
|
@ -257,9 +232,8 @@ connection.
|
|||
Logging
|
||||
|
||||
pgx defines a simple logger interface. Connections optionally accept a logger
|
||||
that satisfies this interface. The log15 package
|
||||
(http://gopkg.in/inconshreveable/log15.v2) satisfies this interface and it is
|
||||
simple to define adapters for other loggers. Set LogLevel to control logging
|
||||
verbosity.
|
||||
that satisfies this interface. Set LogLevel to control logging verbosity.
|
||||
Adapters for github.com/inconshreveable/log15, github.com/Sirupsen/logrus, and
|
||||
the testing log are provided in the log directory.
|
||||
*/
|
||||
package pgx
|
||||
|
|
|
@ -1,83 +1,76 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`)
|
||||
|
||||
// NullPoint represents a point that may be null.
|
||||
//
|
||||
// If Valid is false then the value is NULL.
|
||||
type NullPoint struct {
|
||||
// Point represents a point that may be null.
|
||||
type Point struct {
|
||||
X, Y float64 // Coordinates of point
|
||||
Valid bool // Valid is true if not NULL
|
||||
Status pgtype.Status
|
||||
}
|
||||
|
||||
func (p *NullPoint) ScanPgx(vr *pgx.ValueReader) error {
|
||||
if vr.Type().DataTypeName != "point" {
|
||||
return pgx.SerializationError(fmt.Sprintf("NullPoint.Scan cannot decode %s (OID %d)", vr.Type().DataTypeName, vr.Type().DataType))
|
||||
func (dst *Point) Set(src interface{}) error {
|
||||
return errors.Errorf("cannot convert %v to Point", src)
|
||||
}
|
||||
|
||||
if vr.Len() == -1 {
|
||||
p.X, p.Y, p.Valid = 0, 0, false
|
||||
func (dst *Point) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case pgtype.Present:
|
||||
return dst
|
||||
case pgtype.Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *Point) AssignTo(dst interface{}) error {
|
||||
return errors.Errorf("cannot assign %v to %T", src, dst)
|
||||
}
|
||||
|
||||
func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Point{Status: pgtype.Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch vr.Type().FormatCode {
|
||||
case pgx.TextFormatCode:
|
||||
s := vr.ReadString(vr.Len())
|
||||
s := string(src)
|
||||
match := pointRegexp.FindStringSubmatch(s)
|
||||
if match == nil {
|
||||
return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s))
|
||||
return errors.Errorf("Received invalid point: %v", s)
|
||||
}
|
||||
|
||||
var err error
|
||||
p.X, err = strconv.ParseFloat(match[1], 64)
|
||||
x, err := strconv.ParseFloat(match[1], 64)
|
||||
if err != nil {
|
||||
return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s))
|
||||
return errors.Errorf("Received invalid point: %v", s)
|
||||
}
|
||||
p.Y, err = strconv.ParseFloat(match[2], 64)
|
||||
y, err := strconv.ParseFloat(match[2], 64)
|
||||
if err != nil {
|
||||
return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s))
|
||||
}
|
||||
case pgx.BinaryFormatCode:
|
||||
return errors.New("binary format not implemented")
|
||||
default:
|
||||
return fmt.Errorf("unknown format %v", vr.Type().FormatCode)
|
||||
return errors.Errorf("Received invalid point: %v", s)
|
||||
}
|
||||
|
||||
p.Valid = true
|
||||
return vr.Err()
|
||||
}
|
||||
|
||||
func (p NullPoint) FormatCode() int16 { return pgx.TextFormatCode }
|
||||
|
||||
func (p NullPoint) Encode(w *pgx.WriteBuf, oid pgx.Oid) error {
|
||||
if !p.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
s := fmt.Sprintf("(%v,%v)", p.X, p.Y)
|
||||
w.WriteInt32(int32(len(s)))
|
||||
w.WriteBytes([]byte(s))
|
||||
*dst = Point{X: x, Y: y, Status: pgtype.Present}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p NullPoint) String() string {
|
||||
if p.Valid {
|
||||
return fmt.Sprintf("%v, %v", p.X, p.Y)
|
||||
}
|
||||
func (src *Point) String() string {
|
||||
if src.Status == pgtype.Null {
|
||||
return "null point"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%.1f, %.1f", src.X, src.Y)
|
||||
}
|
||||
|
||||
func Example_CustomType() {
|
||||
conn, err := pgx.Connect(*defaultConnConfig)
|
||||
if err != nil {
|
||||
|
@ -85,22 +78,22 @@ func Example_CustomType() {
|
|||
return
|
||||
}
|
||||
|
||||
var p NullPoint
|
||||
err = conn.QueryRow("select null::point").Scan(&p)
|
||||
// Override registered handler for point
|
||||
conn.ConnInfo.RegisterDataType(pgtype.DataType{
|
||||
Value: &Point{},
|
||||
Name: "point",
|
||||
OID: 600,
|
||||
})
|
||||
|
||||
p := &Point{}
|
||||
err = conn.QueryRow("select null::point").Scan(p)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(p)
|
||||
|
||||
err = conn.QueryRow("select point(1.5,2.5)").Scan(&p)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(p)
|
||||
|
||||
err = conn.QueryRow("select $1::point", &NullPoint{X: 0.5, Y: 0.75, Valid: true}).Scan(&p)
|
||||
err = conn.QueryRow("select point(1.5,2.5)").Scan(p)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
|
@ -109,5 +102,4 @@ func Example_CustomType() {
|
|||
// Output:
|
||||
// null point
|
||||
// 1.5, 2.5
|
||||
// 0.5, 0.75
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package pgx_test
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
|
@ -12,13 +13,6 @@ func Example_JSON() {
|
|||
return
|
||||
}
|
||||
|
||||
if _, ok := conn.PgTypes[pgx.JsonOid]; !ok {
|
||||
// No JSON type -- must be running against very old PostgreSQL
|
||||
// Pretend it works
|
||||
fmt.Println("John", 42)
|
||||
return
|
||||
}
|
||||
|
||||
type person struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
|
|
|
@ -2,10 +2,11 @@ package main
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/jackc/pgx"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
var pool *pgx.ConnPool
|
||||
|
@ -58,16 +59,13 @@ func listen() {
|
|||
conn.Listen("chat")
|
||||
|
||||
for {
|
||||
notification, err := conn.WaitForNotification(time.Second)
|
||||
if err == pgx.ErrNotificationTimeout {
|
||||
continue
|
||||
}
|
||||
notification, err := conn.WaitForNotification(context.Background())
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error waiting for notification:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Println("PID:", notification.Pid, "Channel:", notification.Channel, "Payload:", notification.Payload)
|
||||
fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx"
|
||||
log "gopkg.in/inconshreveable/log15.v2"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/log/log15adapter"
|
||||
log "gopkg.in/inconshreveable/log15.v2"
|
||||
)
|
||||
|
||||
var pool *pgx.ConnPool
|
||||
|
@ -89,6 +91,8 @@ func urlHandler(w http.ResponseWriter, req *http.Request) {
|
|||
}
|
||||
|
||||
func main() {
|
||||
logger := log15adapter.NewLogger(log.New("module", "pgx"))
|
||||
|
||||
var err error
|
||||
connPoolConfig := pgx.ConnPoolConfig{
|
||||
ConnConfig: pgx.ConnConfig{
|
||||
|
@ -96,7 +100,7 @@ func main() {
|
|||
User: "jack",
|
||||
Password: "jack",
|
||||
Database: "url_shortener",
|
||||
Logger: log.New("module", "pgx"),
|
||||
Logger: logger,
|
||||
},
|
||||
MaxConnections: 5,
|
||||
AfterConnect: afterConnect,
|
||||
|
|
67
fastpath.go
67
fastpath.go
|
@ -2,29 +2,33 @@ package pgx
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
func newFastpath(cn *Conn) *fastpath {
|
||||
return &fastpath{cn: cn, fns: make(map[string]Oid)}
|
||||
return &fastpath{cn: cn, fns: make(map[string]pgtype.OID)}
|
||||
}
|
||||
|
||||
type fastpath struct {
|
||||
cn *Conn
|
||||
fns map[string]Oid
|
||||
fns map[string]pgtype.OID
|
||||
}
|
||||
|
||||
func (f *fastpath) functionOID(name string) Oid {
|
||||
func (f *fastpath) functionOID(name string) pgtype.OID {
|
||||
return f.fns[name]
|
||||
}
|
||||
|
||||
func (f *fastpath) addFunction(name string, oid Oid) {
|
||||
func (f *fastpath) addFunction(name string, oid pgtype.OID) {
|
||||
f.fns[name] = oid
|
||||
}
|
||||
|
||||
func (f *fastpath) addFunctions(rows *Rows) error {
|
||||
for rows.Next() {
|
||||
var name string
|
||||
var oid Oid
|
||||
var oid pgtype.OID
|
||||
if err := rows.Scan(&name, &oid); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -47,41 +51,46 @@ func fpInt64Arg(n int64) fpArg {
|
|||
return res
|
||||
}
|
||||
|
||||
func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) {
|
||||
wbuf := newWriteBuf(f.cn, 'F') // function call
|
||||
wbuf.WriteInt32(int32(oid)) // function object id
|
||||
wbuf.WriteInt16(1) // # of argument format codes
|
||||
wbuf.WriteInt16(1) // format code: binary
|
||||
wbuf.WriteInt16(int16(len(args))) // # of arguments
|
||||
for _, arg := range args {
|
||||
wbuf.WriteInt32(int32(len(arg))) // length of argument
|
||||
wbuf.WriteBytes(arg) // argument value
|
||||
func (f *fastpath) Call(oid pgtype.OID, args []fpArg) (res []byte, err error) {
|
||||
if err := f.cn.ensureConnectionReadyForQuery(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wbuf.WriteInt16(1) // response format code (binary)
|
||||
wbuf.closeMsg()
|
||||
|
||||
if _, err := f.cn.conn.Write(wbuf.buf); err != nil {
|
||||
buf := f.cn.wbuf
|
||||
buf = append(buf, 'F') // function call
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
buf = pgio.AppendInt32(buf, int32(oid)) // function object id
|
||||
buf = pgio.AppendInt16(buf, 1) // # of argument format codes
|
||||
buf = pgio.AppendInt16(buf, 1) // format code: binary
|
||||
buf = pgio.AppendInt16(buf, int16(len(args))) // # of arguments
|
||||
for _, arg := range args {
|
||||
buf = pgio.AppendInt32(buf, int32(len(arg))) // length of argument
|
||||
buf = append(buf, arg...) // argument value
|
||||
}
|
||||
buf = pgio.AppendInt16(buf, 1) // response format code (binary)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
if _, err := f.cn.conn.Write(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
var t byte
|
||||
var r *msgReader
|
||||
t, r, err = f.cn.rxMsg()
|
||||
msg, err := f.cn.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch t {
|
||||
case 'V': // FunctionCallResponse
|
||||
data := r.readBytes(r.readInt32())
|
||||
res = make([]byte, len(data))
|
||||
copy(res, data)
|
||||
case 'Z': // Ready for query
|
||||
f.cn.rxReadyForQuery(r)
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.FunctionCallResponse:
|
||||
res = make([]byte, len(msg.Result))
|
||||
copy(res, msg.Result)
|
||||
case *pgproto3.ReadyForQuery:
|
||||
f.cn.rxReadyForQuery(msg)
|
||||
// done
|
||||
return
|
||||
return res, err
|
||||
default:
|
||||
if err := f.cn.processContextFreeMsg(t, r); err != nil {
|
||||
if err := f.cn.processContextFreeMsg(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn {
|
||||
|
@ -21,7 +22,6 @@ func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.Replicatio
|
|||
return conn
|
||||
}
|
||||
|
||||
|
||||
func closeConn(t testing.TB, conn *pgx.Conn) {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
|
|
222
hstore.go
222
hstore.go
|
@ -1,222 +0,0 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
hsPre = iota
|
||||
hsKey
|
||||
hsSep
|
||||
hsVal
|
||||
hsNul
|
||||
hsNext
|
||||
)
|
||||
|
||||
type hstoreParser struct {
|
||||
str string
|
||||
pos int
|
||||
}
|
||||
|
||||
func newHSP(in string) *hstoreParser {
|
||||
return &hstoreParser{
|
||||
pos: 0,
|
||||
str: in,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *hstoreParser) Consume() (r rune, end bool) {
|
||||
if p.pos >= len(p.str) {
|
||||
end = true
|
||||
return
|
||||
}
|
||||
r, w := utf8.DecodeRuneInString(p.str[p.pos:])
|
||||
p.pos += w
|
||||
return
|
||||
}
|
||||
|
||||
func (p *hstoreParser) Peek() (r rune, end bool) {
|
||||
if p.pos >= len(p.str) {
|
||||
end = true
|
||||
return
|
||||
}
|
||||
r, _ = utf8.DecodeRuneInString(p.str[p.pos:])
|
||||
return
|
||||
}
|
||||
|
||||
func parseHstoreToMap(s string) (m map[string]string, err error) {
|
||||
keys, values, err := ParseHstore(s)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
m = make(map[string]string, len(keys))
|
||||
for i, key := range keys {
|
||||
if !values[i].Valid {
|
||||
err = fmt.Errorf("key '%s' has NULL value", key)
|
||||
m = nil
|
||||
return
|
||||
}
|
||||
m[key] = values[i].String
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseHstoreToNullHstore(s string) (store map[string]NullString, err error) {
|
||||
keys, values, err := ParseHstore(s)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
store = make(map[string]NullString, len(keys))
|
||||
|
||||
for i, key := range keys {
|
||||
store[key] = values[i]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ParseHstore parses the string representation of an hstore column (the same
|
||||
// you would get from an ordinary SELECT) into two slices of keys and values. it
|
||||
// is used internally in the default parsing of hstores, but is exported for use
|
||||
// in handling custom data structures backed by an hstore column without the
|
||||
// overhead of creating a map[string]string
|
||||
func ParseHstore(s string) (k []string, v []NullString, err error) {
|
||||
if s == "" {
|
||||
return
|
||||
}
|
||||
|
||||
buf := bytes.Buffer{}
|
||||
keys := []string{}
|
||||
values := []NullString{}
|
||||
p := newHSP(s)
|
||||
|
||||
r, end := p.Consume()
|
||||
state := hsPre
|
||||
|
||||
for !end {
|
||||
switch state {
|
||||
case hsPre:
|
||||
if r == '"' {
|
||||
state = hsKey
|
||||
} else {
|
||||
err = errors.New("String does not begin with \"")
|
||||
}
|
||||
case hsKey:
|
||||
switch r {
|
||||
case '"': //End of the key
|
||||
if buf.Len() == 0 {
|
||||
err = errors.New("Empty Key is invalid")
|
||||
} else {
|
||||
keys = append(keys, buf.String())
|
||||
buf = bytes.Buffer{}
|
||||
state = hsSep
|
||||
}
|
||||
case '\\': //Potential escaped character
|
||||
n, end := p.Consume()
|
||||
switch {
|
||||
case end:
|
||||
err = errors.New("Found EOS in key, expecting character or \"")
|
||||
case n == '"', n == '\\':
|
||||
buf.WriteRune(n)
|
||||
default:
|
||||
buf.WriteRune(r)
|
||||
buf.WriteRune(n)
|
||||
}
|
||||
default: //Any other character
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
case hsSep:
|
||||
if r == '=' {
|
||||
r, end = p.Consume()
|
||||
switch {
|
||||
case end:
|
||||
err = errors.New("Found EOS after '=', expecting '>'")
|
||||
case r == '>':
|
||||
r, end = p.Consume()
|
||||
switch {
|
||||
case end:
|
||||
err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'")
|
||||
case r == '"':
|
||||
state = hsVal
|
||||
case r == 'N':
|
||||
state = hsNul
|
||||
default:
|
||||
err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r)
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("Invalid character after '=', expecting '>'")
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r)
|
||||
}
|
||||
case hsVal:
|
||||
switch r {
|
||||
case '"': //End of the value
|
||||
values = append(values, NullString{String: buf.String(), Valid: true})
|
||||
buf = bytes.Buffer{}
|
||||
state = hsNext
|
||||
case '\\': //Potential escaped character
|
||||
n, end := p.Consume()
|
||||
switch {
|
||||
case end:
|
||||
err = errors.New("Found EOS in key, expecting character or \"")
|
||||
case n == '"', n == '\\':
|
||||
buf.WriteRune(n)
|
||||
default:
|
||||
buf.WriteRune(r)
|
||||
buf.WriteRune(n)
|
||||
}
|
||||
default: //Any other character
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
case hsNul:
|
||||
nulBuf := make([]rune, 3)
|
||||
nulBuf[0] = r
|
||||
for i := 1; i < 3; i++ {
|
||||
r, end = p.Consume()
|
||||
if end {
|
||||
err = errors.New("Found EOS in NULL value")
|
||||
return
|
||||
}
|
||||
nulBuf[i] = r
|
||||
}
|
||||
if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' {
|
||||
values = append(values, NullString{String: "", Valid: false})
|
||||
state = hsNext
|
||||
} else {
|
||||
err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf))
|
||||
}
|
||||
case hsNext:
|
||||
if r == ',' {
|
||||
r, end = p.Consume()
|
||||
switch {
|
||||
case end:
|
||||
err = errors.New("Found EOS after ',', expcting space")
|
||||
case (unicode.IsSpace(r)):
|
||||
r, end = p.Consume()
|
||||
state = hsKey
|
||||
default:
|
||||
err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r)
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, end = p.Consume()
|
||||
}
|
||||
if state != hsNext {
|
||||
err = errors.New("Improperly formatted hstore")
|
||||
return
|
||||
}
|
||||
k = keys
|
||||
v = values
|
||||
return
|
||||
}
|
181
hstore_test.go
181
hstore_test.go
|
@ -1,181 +0,0 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHstoreTranscode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
type test struct {
|
||||
hstore pgx.Hstore
|
||||
description string
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
{pgx.Hstore{}, "empty"},
|
||||
{pgx.Hstore{"foo": "bar"}, "single key/value"},
|
||||
{pgx.Hstore{"foo": "bar", "baz": "quz"}, "multiple key/values"},
|
||||
{pgx.Hstore{"NULL": "bar"}, `string "NULL" key`},
|
||||
{pgx.Hstore{"foo": "NULL"}, `string "NULL" value`},
|
||||
}
|
||||
|
||||
specialStringTests := []struct {
|
||||
input string
|
||||
description string
|
||||
}{
|
||||
{`"`, `double quote (")`},
|
||||
{`'`, `single quote (')`},
|
||||
{`\`, `backslash (\)`},
|
||||
{`\\`, `multiple backslashes (\\)`},
|
||||
{`=>`, `separator (=>)`},
|
||||
{` `, `space`},
|
||||
{`\ / / \\ => " ' " '`, `multiple special characters`},
|
||||
}
|
||||
for _, sst := range specialStringTests {
|
||||
tests = append(tests, test{pgx.Hstore{sst.input + "foo": "bar"}, "key with " + sst.description + " at beginning"})
|
||||
tests = append(tests, test{pgx.Hstore{"foo" + sst.input + "foo": "bar"}, "key with " + sst.description + " in middle"})
|
||||
tests = append(tests, test{pgx.Hstore{"foo" + sst.input: "bar"}, "key with " + sst.description + " at end"})
|
||||
tests = append(tests, test{pgx.Hstore{sst.input: "bar"}, "key is " + sst.description})
|
||||
|
||||
tests = append(tests, test{pgx.Hstore{"foo": sst.input + "bar"}, "value with " + sst.description + " at beginning"})
|
||||
tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input + "bar"}, "value with " + sst.description + " in middle"})
|
||||
tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input}, "value with " + sst.description + " at end"})
|
||||
tests = append(tests, test{pgx.Hstore{"foo": sst.input}, "value is " + sst.description})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
var result pgx.Hstore
|
||||
err := conn.QueryRow("select $1::hstore", tt.hstore).Scan(&result)
|
||||
if err != nil {
|
||||
t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err)
|
||||
}
|
||||
|
||||
for key, inValue := range tt.hstore {
|
||||
outValue, ok := result[key]
|
||||
if ok {
|
||||
if inValue != outValue {
|
||||
t.Errorf(`%s: Key %s mismatch - expected %s, received %s`, tt.description, key, inValue, outValue)
|
||||
}
|
||||
} else {
|
||||
t.Errorf(`%s: Missing key %s`, tt.description, key)
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullHstoreTranscode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
type test struct {
|
||||
nullHstore pgx.NullHstore
|
||||
description string
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
{pgx.NullHstore{}, "null"},
|
||||
{pgx.NullHstore{Valid: true}, "empty"},
|
||||
{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}},
|
||||
Valid: true},
|
||||
"single key/value"},
|
||||
{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}, "baz": {String: "quz", Valid: true}},
|
||||
Valid: true},
|
||||
"multiple key/values"},
|
||||
{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"NULL": {String: "bar", Valid: true}},
|
||||
Valid: true},
|
||||
`string "NULL" key`},
|
||||
{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: "NULL", Valid: true}},
|
||||
Valid: true},
|
||||
`string "NULL" value`},
|
||||
{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: "", Valid: false}},
|
||||
Valid: true},
|
||||
`NULL value`},
|
||||
}
|
||||
|
||||
specialStringTests := []struct {
|
||||
input string
|
||||
description string
|
||||
}{
|
||||
{`"`, `double quote (")`},
|
||||
{`'`, `single quote (')`},
|
||||
{`\`, `backslash (\)`},
|
||||
{`\\`, `multiple backslashes (\\)`},
|
||||
{`=>`, `separator (=>)`},
|
||||
{` `, `space`},
|
||||
{`\ / / \\ => " ' " '`, `multiple special characters`},
|
||||
}
|
||||
for _, sst := range specialStringTests {
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{sst.input + "foo": {String: "bar", Valid: true}},
|
||||
Valid: true},
|
||||
"key with " + sst.description + " at beginning"})
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": {String: "bar", Valid: true}},
|
||||
Valid: true},
|
||||
"key with " + sst.description + " in middle"})
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo" + sst.input: {String: "bar", Valid: true}},
|
||||
Valid: true},
|
||||
"key with " + sst.description + " at end"})
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{sst.input: {String: "bar", Valid: true}},
|
||||
Valid: true},
|
||||
"key is " + sst.description})
|
||||
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: sst.input + "bar", Valid: true}},
|
||||
Valid: true},
|
||||
"value with " + sst.description + " at beginning"})
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input + "bar", Valid: true}},
|
||||
Valid: true},
|
||||
"value with " + sst.description + " in middle"})
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input, Valid: true}},
|
||||
Valid: true},
|
||||
"value with " + sst.description + " at end"})
|
||||
tests = append(tests, test{pgx.NullHstore{
|
||||
Hstore: map[string]pgx.NullString{"foo": {String: sst.input, Valid: true}},
|
||||
Valid: true},
|
||||
"value is " + sst.description})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
var result pgx.NullHstore
|
||||
err := conn.QueryRow("select $1::hstore", tt.nullHstore).Scan(&result)
|
||||
if err != nil {
|
||||
t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err)
|
||||
}
|
||||
|
||||
if result.Valid != tt.nullHstore.Valid {
|
||||
t.Errorf(`%s: Valid mismatch - expected %v, received %v`, tt.description, tt.nullHstore.Valid, result.Valid)
|
||||
}
|
||||
|
||||
for key, inValue := range tt.nullHstore.Hstore {
|
||||
outValue, ok := result.Hstore[key]
|
||||
if ok {
|
||||
if inValue != outValue {
|
||||
t.Errorf(`%s: Key %s mismatch - expected %v, received %v`, tt.description, key, inValue, outValue)
|
||||
}
|
||||
} else {
|
||||
t.Errorf(`%s: Missing key %s`, tt.description, key)
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,237 @@
|
|||
package sanitize
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Part is either a string or an int. A string is raw SQL. An int is a
|
||||
// argument placeholder.
|
||||
type Part interface{}
|
||||
|
||||
type Query struct {
|
||||
Parts []Part
|
||||
}
|
||||
|
||||
func (q *Query) Sanitize(args ...interface{}) (string, error) {
|
||||
argUse := make([]bool, len(args))
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
for _, part := range q.Parts {
|
||||
var str string
|
||||
switch part := part.(type) {
|
||||
case string:
|
||||
str = part
|
||||
case int:
|
||||
argIdx := part - 1
|
||||
if argIdx >= len(args) {
|
||||
return "", errors.Errorf("insufficient arguments")
|
||||
}
|
||||
arg := args[argIdx]
|
||||
switch arg := arg.(type) {
|
||||
case nil:
|
||||
str = "null"
|
||||
case int64:
|
||||
str = strconv.FormatInt(arg, 10)
|
||||
case float64:
|
||||
str = strconv.FormatFloat(arg, 'f', -1, 64)
|
||||
case bool:
|
||||
str = strconv.FormatBool(arg)
|
||||
case []byte:
|
||||
str = QuoteBytes(arg)
|
||||
case string:
|
||||
str = QuoteString(arg)
|
||||
case time.Time:
|
||||
str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
|
||||
default:
|
||||
return "", errors.Errorf("invalid arg type: %T", arg)
|
||||
}
|
||||
argUse[argIdx] = true
|
||||
default:
|
||||
return "", errors.Errorf("invalid Part type: %T", part)
|
||||
}
|
||||
buf.WriteString(str)
|
||||
}
|
||||
|
||||
for i, used := range argUse {
|
||||
if !used {
|
||||
return "", errors.Errorf("unused argument: %d", i)
|
||||
}
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func NewQuery(sql string) (*Query, error) {
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
}
|
||||
|
||||
for l.stateFn != nil {
|
||||
l.stateFn = l.stateFn(l)
|
||||
}
|
||||
|
||||
query := &Query{Parts: l.parts}
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func QuoteString(str string) string {
|
||||
return "'" + strings.Replace(str, "'", "''", -1) + "'"
|
||||
}
|
||||
|
||||
func QuoteBytes(buf []byte) string {
|
||||
return `'\x` + hex.EncodeToString(buf) + "'"
|
||||
}
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
stateFn stateFn
|
||||
parts []Part
|
||||
}
|
||||
|
||||
type stateFn func(*sqlLexer) stateFn
|
||||
|
||||
func rawState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case 'e', 'E':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '\'' {
|
||||
l.pos += width
|
||||
return escapeStringState
|
||||
}
|
||||
case '\'':
|
||||
return singleQuoteState
|
||||
case '"':
|
||||
return doubleQuoteState
|
||||
case '$':
|
||||
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if '0' <= nextRune && nextRune <= '9' {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos-width])
|
||||
}
|
||||
l.start = l.pos
|
||||
return placeholderState
|
||||
}
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func singleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doubleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '"':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '"' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// placeholderState consumes a placeholder value. The $ must have already has
|
||||
// already been consumed. The first rune must be a digit.
|
||||
func placeholderState(l *sqlLexer) stateFn {
|
||||
num := 0
|
||||
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
if '0' <= r && r <= '9' {
|
||||
num *= 10
|
||||
num += int(r - '0')
|
||||
} else {
|
||||
l.parts = append(l.parts, num)
|
||||
l.pos -= width
|
||||
l.start = l.pos
|
||||
return rawState
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func escapeStringState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||
// as necessary. This function is only safe when standard_conforming_strings is
|
||||
// on.
|
||||
func SanitizeSQL(sql string, args ...interface{}) (string, error) {
|
||||
query, err := NewQuery(sql)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return query.Sanitize(args...)
|
||||
}
|
|
@ -0,0 +1,175 @@
|
|||
package sanitize_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/internal/sanitize"
|
||||
)
|
||||
|
||||
func TestNewQuery(t *testing.T) {
|
||||
successTests := []struct {
|
||||
sql string
|
||||
expected sanitize.Query
|
||||
}{
|
||||
{
|
||||
sql: "select 42",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
|
||||
},
|
||||
{
|
||||
sql: "select $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
},
|
||||
{
|
||||
sql: "select 'quoted $42', $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}},
|
||||
},
|
||||
{
|
||||
sql: `select "doubled quoted $42", $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}},
|
||||
},
|
||||
{
|
||||
sql: "select 'foo''bar', $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}},
|
||||
},
|
||||
{
|
||||
sql: `select "foo""bar", $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}},
|
||||
},
|
||||
{
|
||||
sql: "select '''', $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}},
|
||||
},
|
||||
{
|
||||
sql: `select """", $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}},
|
||||
},
|
||||
{
|
||||
sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}},
|
||||
},
|
||||
{
|
||||
sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}},
|
||||
},
|
||||
{
|
||||
sql: `select E'escape string\' $42', $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}},
|
||||
},
|
||||
{
|
||||
sql: `select e'escape string\' $42', $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successTests {
|
||||
query, err := sanitize.NewQuery(tt.sql)
|
||||
if err != nil {
|
||||
t.Errorf("%d. %v", i, err)
|
||||
}
|
||||
|
||||
if len(query.Parts) == len(tt.expected.Parts) {
|
||||
for j := range query.Parts {
|
||||
if query.Parts[j] != tt.expected.Parts[j] {
|
||||
t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuerySanitize(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
query sanitize.Query
|
||||
args []interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
|
||||
args: []interface{}{},
|
||||
expected: `select 42`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `select 42`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{float64(1.23)},
|
||||
expected: `select 1.23`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{true},
|
||||
expected: `select true`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{[]byte{0, 1, 2, 3, 255}},
|
||||
expected: `select '\x00010203ff'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{nil},
|
||||
expected: `select null`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{"foobar"},
|
||||
expected: `select 'foobar'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{"foo'bar"},
|
||||
expected: `select 'foo''bar'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{`foo\'bar`},
|
||||
expected: `select 'foo\''bar'`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
actual, err := tt.query.Sanitize(tt.args...)
|
||||
if err != nil {
|
||||
t.Errorf("%d. %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if tt.expected != actual {
|
||||
t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
errorTests := []struct {
|
||||
query sanitize.Query
|
||||
args []interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `insufficient arguments`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}},
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `unused argument: 0`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{42},
|
||||
expected: `invalid arg type: int`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range errorTests {
|
||||
_, err := tt.query.Sanitize(tt.args...)
|
||||
if err == nil || err.Error() != tt.expected {
|
||||
t.Errorf("%d. expected error %v, got %v", i, tt.expected, err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2,6 +2,8 @@ package pgx
|
|||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
// LargeObjects is a structure used to access the large objects API. It is only
|
||||
|
@ -60,19 +62,19 @@ const (
|
|||
|
||||
// Create creates a new large object. If id is zero, the server assigns an
|
||||
// unused OID.
|
||||
func (o *LargeObjects) Create(id Oid) (Oid, error) {
|
||||
newOid, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))}))
|
||||
return Oid(newOid), err
|
||||
func (o *LargeObjects) Create(id pgtype.OID) (pgtype.OID, error) {
|
||||
newOID, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))}))
|
||||
return pgtype.OID(newOID), err
|
||||
}
|
||||
|
||||
// Open opens an existing large object with the given mode.
|
||||
func (o *LargeObjects) Open(oid Oid, mode LargeObjectMode) (*LargeObject, error) {
|
||||
func (o *LargeObjects) Open(oid pgtype.OID, mode LargeObjectMode) (*LargeObject, error) {
|
||||
fd, err := fpInt32(o.fp.CallFn("lo_open", []fpArg{fpIntArg(int32(oid)), fpIntArg(int32(mode))}))
|
||||
return &LargeObject{fd: fd, lo: o}, err
|
||||
}
|
||||
|
||||
// Unlink removes a large object from the database.
|
||||
func (o *LargeObjects) Unlink(oid Oid) error {
|
||||
func (o *LargeObjects) Unlink(oid pgtype.OID) error {
|
||||
_, err := o.fp.CallFn("lo_unlink", []fpArg{fpIntArg(int32(oid))})
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
// Package log15adapter provides a logger that writes to a github.com/inconshreveable/log15.Logger
|
||||
// log.
|
||||
package log15adapter
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
// Log15Logger interface defines the subset of
|
||||
// github.com/inconshreveable/log15.Logger that this adapter uses.
|
||||
type Log15Logger interface {
|
||||
Debug(msg string, ctx ...interface{})
|
||||
Info(msg string, ctx ...interface{})
|
||||
Warn(msg string, ctx ...interface{})
|
||||
Error(msg string, ctx ...interface{})
|
||||
Crit(msg string, ctx ...interface{})
|
||||
}
|
||||
|
||||
type Logger struct {
|
||||
l Log15Logger
|
||||
}
|
||||
|
||||
func NewLogger(l Log15Logger) *Logger {
|
||||
return &Logger{l: l}
|
||||
}
|
||||
|
||||
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
logArgs := make([]interface{}, 0, len(data))
|
||||
for k, v := range data {
|
||||
logArgs = append(logArgs, k, v)
|
||||
}
|
||||
|
||||
switch level {
|
||||
case pgx.LogLevelTrace:
|
||||
l.l.Debug(msg, append(logArgs, "PGX_LOG_LEVEL", level)...)
|
||||
case pgx.LogLevelDebug:
|
||||
l.l.Debug(msg, logArgs...)
|
||||
case pgx.LogLevelInfo:
|
||||
l.l.Info(msg, logArgs...)
|
||||
case pgx.LogLevelWarn:
|
||||
l.l.Warn(msg, logArgs...)
|
||||
case pgx.LogLevelError:
|
||||
l.l.Error(msg, logArgs...)
|
||||
default:
|
||||
l.l.Error(msg, append(logArgs, "INVALID_PGX_LOG_LEVEL", level)...)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
// Package logrusadapter provides a logger that writes to a github.com/Sirupsen/logrus.Logger
|
||||
// log.
|
||||
package logrusadapter
|
||||
|
||||
import (
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func NewLogger(l *logrus.Logger) *Logger {
|
||||
return &Logger{l: l}
|
||||
}
|
||||
|
||||
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
var logger logrus.FieldLogger
|
||||
if data != nil {
|
||||
logger = l.l.WithFields(data)
|
||||
} else {
|
||||
logger = l.l
|
||||
}
|
||||
|
||||
switch level {
|
||||
case pgx.LogLevelTrace:
|
||||
logger.WithField("PGX_LOG_LEVEL", level).Debug(msg)
|
||||
case pgx.LogLevelDebug:
|
||||
logger.Debug(msg)
|
||||
case pgx.LogLevelInfo:
|
||||
logger.Info(msg)
|
||||
case pgx.LogLevelWarn:
|
||||
logger.Warn(msg)
|
||||
case pgx.LogLevelError:
|
||||
logger.Error(msg)
|
||||
default:
|
||||
logger.WithField("INVALID_PGX_LOG_LEVEL", level).Error(msg)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
// Package testingadapter provides a logger that writes to a test or benchmark
|
||||
// log.
|
||||
package testingadapter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
// TestingLogger interface defines the subset of testing.TB methods used by this
|
||||
// adapter.
|
||||
type TestingLogger interface {
|
||||
Log(args ...interface{})
|
||||
}
|
||||
|
||||
type Logger struct {
|
||||
l TestingLogger
|
||||
}
|
||||
|
||||
func NewLogger(l TestingLogger) *Logger {
|
||||
return &Logger{l: l}
|
||||
}
|
||||
|
||||
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
logArgs := make([]interface{}, 0, 2+len(data))
|
||||
logArgs = append(logArgs, level, msg)
|
||||
for k, v := range data {
|
||||
logArgs = append(logArgs, fmt.Sprintf("%s=%v", k, v))
|
||||
}
|
||||
l.l.Log(logArgs...)
|
||||
}
|
41
logger.go
41
logger.go
|
@ -2,13 +2,13 @@ package pgx
|
|||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// The values for log levels are chosen such that the zero value means that no
|
||||
// log level was specified and we can default to LogLevelDebug to preserve
|
||||
// the behavior that existed prior to log level introduction.
|
||||
// log level was specified.
|
||||
const (
|
||||
LogLevelTrace = 6
|
||||
LogLevelDebug = 5
|
||||
|
@ -18,16 +18,33 @@ const (
|
|||
LogLevelNone = 1
|
||||
)
|
||||
|
||||
// LogLevel represents the pgx logging level. See LogLevel* constants for
|
||||
// possible values.
|
||||
type LogLevel int
|
||||
|
||||
func (ll LogLevel) String() string {
|
||||
switch ll {
|
||||
case LogLevelTrace:
|
||||
return "trace"
|
||||
case LogLevelDebug:
|
||||
return "debug"
|
||||
case LogLevelInfo:
|
||||
return "info"
|
||||
case LogLevelWarn:
|
||||
return "warn"
|
||||
case LogLevelError:
|
||||
return "error"
|
||||
case LogLevelNone:
|
||||
return "none"
|
||||
default:
|
||||
return fmt.Sprintf("invalid level %d", ll)
|
||||
}
|
||||
}
|
||||
|
||||
// Logger is the interface used to get logging from pgx internals.
|
||||
// https://github.com/inconshreveable/log15 is the recommended logging package.
|
||||
// This logging interface was extracted from there. However, it should be simple
|
||||
// to adapt any logger to this interface.
|
||||
type Logger interface {
|
||||
// Log a message at the given level with context key/value pairs
|
||||
Debug(msg string, ctx ...interface{})
|
||||
Info(msg string, ctx ...interface{})
|
||||
Warn(msg string, ctx ...interface{})
|
||||
Error(msg string, ctx ...interface{})
|
||||
// Log a message at the given level with data key/value pairs. data may be nil.
|
||||
Log(level LogLevel, msg string, data map[string]interface{})
|
||||
}
|
||||
|
||||
// LogLevelFromString converts log level string to constant
|
||||
|
@ -39,7 +56,7 @@ type Logger interface {
|
|||
// warn
|
||||
// error
|
||||
// none
|
||||
func LogLevelFromString(s string) (int, error) {
|
||||
func LogLevelFromString(s string) (LogLevel, error) {
|
||||
switch s {
|
||||
case "trace":
|
||||
return LogLevelTrace, nil
|
||||
|
|
199
messages.go
199
messages.go
|
@ -1,66 +1,24 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
const (
|
||||
protocolVersionNumber = 196608 // 3.0
|
||||
)
|
||||
|
||||
const (
|
||||
backendKeyData = 'K'
|
||||
authenticationX = 'R'
|
||||
readyForQuery = 'Z'
|
||||
rowDescription = 'T'
|
||||
dataRow = 'D'
|
||||
commandComplete = 'C'
|
||||
errorResponse = 'E'
|
||||
noticeResponse = 'N'
|
||||
parseComplete = '1'
|
||||
parameterDescription = 't'
|
||||
bindComplete = '2'
|
||||
notificationResponse = 'A'
|
||||
emptyQueryResponse = 'I'
|
||||
noData = 'n'
|
||||
closeComplete = '3'
|
||||
flush = 'H'
|
||||
copyInResponse = 'G'
|
||||
copyData = 'd'
|
||||
copyFail = 'f'
|
||||
copyDone = 'c'
|
||||
)
|
||||
|
||||
type startupMessage struct {
|
||||
options map[string]string
|
||||
}
|
||||
|
||||
func newStartupMessage() *startupMessage {
|
||||
return &startupMessage{map[string]string{}}
|
||||
}
|
||||
|
||||
func (s *startupMessage) Bytes() (buf []byte) {
|
||||
buf = make([]byte, 8, 128)
|
||||
binary.BigEndian.PutUint32(buf[4:8], uint32(protocolVersionNumber))
|
||||
for key, value := range s.options {
|
||||
buf = append(buf, key...)
|
||||
buf = append(buf, 0)
|
||||
buf = append(buf, value...)
|
||||
buf = append(buf, 0)
|
||||
}
|
||||
buf = append(buf, ("\000")...)
|
||||
binary.BigEndian.PutUint32(buf[0:4], uint32(len(buf)))
|
||||
return buf
|
||||
}
|
||||
|
||||
type FieldDescription struct {
|
||||
Name string
|
||||
Table Oid
|
||||
AttributeNumber int16
|
||||
DataType Oid
|
||||
Table pgtype.OID
|
||||
AttributeNumber uint16
|
||||
DataType pgtype.OID
|
||||
DataTypeSize int16
|
||||
DataTypeName string
|
||||
Modifier int32
|
||||
Modifier uint32
|
||||
FormatCode int16
|
||||
}
|
||||
|
||||
|
@ -91,69 +49,114 @@ func (pe PgError) Error() string {
|
|||
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||
}
|
||||
|
||||
func newWriteBuf(c *Conn, t byte) *WriteBuf {
|
||||
buf := append(c.wbuf[0:0], t, 0, 0, 0, 0)
|
||||
c.writeBuf = WriteBuf{buf: buf, sizeIdx: 1, conn: c}
|
||||
return &c.writeBuf
|
||||
// Notice represents a notice response message reported by the PostgreSQL
|
||||
// server. Be aware that this is distinct from LISTEN/NOTIFY notification.
|
||||
type Notice PgError
|
||||
|
||||
// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it.
|
||||
func appendParse(buf []byte, name string, query string, parameterOIDs []pgtype.OID) []byte {
|
||||
buf = append(buf, 'P')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, name...)
|
||||
buf = append(buf, 0)
|
||||
buf = append(buf, query...)
|
||||
buf = append(buf, 0)
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(parameterOIDs)))
|
||||
for _, oid := range parameterOIDs {
|
||||
buf = pgio.AppendUint32(buf, uint32(oid))
|
||||
}
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// WriteBuf is used build messages to send to the PostgreSQL server. It is used
|
||||
// by the Encoder interface when implementing custom encoders.
|
||||
type WriteBuf struct {
|
||||
buf []byte
|
||||
sizeIdx int
|
||||
conn *Conn
|
||||
// appendDescribe appends a PostgreSQL wire protocol describe message to buf and returns it.
|
||||
func appendDescribe(buf []byte, objectType byte, name string) []byte {
|
||||
buf = append(buf, 'D')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, objectType)
|
||||
buf = append(buf, name...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) startMsg(t byte) {
|
||||
wb.closeMsg()
|
||||
wb.buf = append(wb.buf, t, 0, 0, 0, 0)
|
||||
wb.sizeIdx = len(wb.buf) - 4
|
||||
// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it.
|
||||
func appendSync(buf []byte) []byte {
|
||||
buf = append(buf, 'S')
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) closeMsg() {
|
||||
binary.BigEndian.PutUint32(wb.buf[wb.sizeIdx:wb.sizeIdx+4], uint32(len(wb.buf)-wb.sizeIdx))
|
||||
// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it.
|
||||
func appendBind(
|
||||
buf []byte,
|
||||
destinationPortal,
|
||||
preparedStatement string,
|
||||
connInfo *pgtype.ConnInfo,
|
||||
parameterOIDs []pgtype.OID,
|
||||
arguments []interface{},
|
||||
resultFormatCodes []int16,
|
||||
) ([]byte, error) {
|
||||
buf = append(buf, 'B')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, destinationPortal...)
|
||||
buf = append(buf, 0)
|
||||
buf = append(buf, preparedStatement...)
|
||||
buf = append(buf, 0)
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(parameterOIDs)))
|
||||
for i, oid := range parameterOIDs {
|
||||
buf = pgio.AppendInt16(buf, chooseParameterFormatCode(connInfo, oid, arguments[i]))
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteByte(b byte) {
|
||||
wb.buf = append(wb.buf, b)
|
||||
buf = pgio.AppendInt16(buf, int16(len(arguments)))
|
||||
for i, oid := range parameterOIDs {
|
||||
var err error
|
||||
buf, err = encodePreparedStatementArgument(connInfo, buf, oid, arguments[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteCString(s string) {
|
||||
wb.buf = append(wb.buf, []byte(s)...)
|
||||
wb.buf = append(wb.buf, 0)
|
||||
buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes)))
|
||||
for _, fc := range resultFormatCodes {
|
||||
buf = pgio.AppendInt16(buf, fc)
|
||||
}
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteInt16(n int16) {
|
||||
b := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(b, uint16(n))
|
||||
wb.buf = append(wb.buf, b...)
|
||||
// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it.
|
||||
func appendExecute(buf []byte, portal string, maxRows uint32) []byte {
|
||||
buf = append(buf, 'E')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
buf = append(buf, portal...)
|
||||
buf = append(buf, 0)
|
||||
buf = pgio.AppendUint32(buf, maxRows)
|
||||
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteUint16(n uint16) {
|
||||
b := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(b, n)
|
||||
wb.buf = append(wb.buf, b...)
|
||||
}
|
||||
// appendQuery appends a PostgreSQL wire protocol query message to buf and returns it.
|
||||
func appendQuery(buf []byte, query string) []byte {
|
||||
buf = append(buf, 'Q')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, query...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
func (wb *WriteBuf) WriteInt32(n int32) {
|
||||
b := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(b, uint32(n))
|
||||
wb.buf = append(wb.buf, b...)
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteUint32(n uint32) {
|
||||
b := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(b, n)
|
||||
wb.buf = append(wb.buf, b...)
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteInt64(n int64) {
|
||||
b := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(b, uint64(n))
|
||||
wb.buf = append(wb.buf, b...)
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) WriteBytes(b []byte) {
|
||||
wb.buf = append(wb.buf, b...)
|
||||
return buf
|
||||
}
|
||||
|
|
316
msg_reader.go
316
msg_reader.go
|
@ -1,316 +0,0 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// msgReader is a helper that reads values from a PostgreSQL message.
|
||||
type msgReader struct {
|
||||
reader *bufio.Reader
|
||||
msgBytesRemaining int32
|
||||
err error
|
||||
log func(lvl int, msg string, ctx ...interface{})
|
||||
shouldLog func(lvl int) bool
|
||||
}
|
||||
|
||||
// Err returns any error that the msgReader has experienced
|
||||
func (r *msgReader) Err() error {
|
||||
return r.err
|
||||
}
|
||||
|
||||
// fatal tells rc that a Fatal error has occurred
|
||||
func (r *msgReader) fatal(err error) {
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
r.err = err
|
||||
}
|
||||
|
||||
// rxMsg reads the type and size of the next message.
|
||||
func (r *msgReader) rxMsg() (byte, error) {
|
||||
if r.err != nil {
|
||||
return 0, r.err
|
||||
}
|
||||
|
||||
if r.msgBytesRemaining > 0 {
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
_, err := r.reader.Discard(int(r.msgBytesRemaining))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(5)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0, err
|
||||
}
|
||||
msgType := b[0]
|
||||
r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4
|
||||
r.reader.Discard(5)
|
||||
return msgType, nil
|
||||
}
|
||||
|
||||
func (r *msgReader) readByte() byte {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining--
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.ReadByte()
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *msgReader) readInt16() int16 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 2
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(2)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := int16(binary.BigEndian.Uint16(b))
|
||||
|
||||
r.reader.Discard(2)
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (r *msgReader) readInt32() int32 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 4
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(4)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := int32(binary.BigEndian.Uint32(b))
|
||||
|
||||
r.reader.Discard(4)
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (r *msgReader) readUint16() uint16 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 2
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(2)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := uint16(binary.BigEndian.Uint16(b))
|
||||
|
||||
r.reader.Discard(2)
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (r *msgReader) readUint32() uint32 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 4
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(4)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := uint32(binary.BigEndian.Uint32(b))
|
||||
|
||||
r.reader.Discard(4)
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (r *msgReader) readInt64() int64 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 8
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(8)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := int64(binary.BigEndian.Uint64(b))
|
||||
|
||||
r.reader.Discard(8)
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (r *msgReader) readOid() Oid {
|
||||
return Oid(r.readInt32())
|
||||
}
|
||||
|
||||
// readCString reads a null terminated string
|
||||
func (r *msgReader) readCString() string {
|
||||
if r.err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
b, err := r.reader.ReadBytes(0)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return ""
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= int32(len(b))
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return ""
|
||||
}
|
||||
|
||||
s := string(b[0 : len(b)-1])
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// readString reads count bytes and returns as string
|
||||
func (r *msgReader) readString(countI32 int32) string {
|
||||
if r.err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= countI32
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return ""
|
||||
}
|
||||
|
||||
count := int(countI32)
|
||||
var s string
|
||||
|
||||
if r.reader.Buffered() >= count {
|
||||
buf, _ := r.reader.Peek(count)
|
||||
s = string(buf)
|
||||
r.reader.Discard(count)
|
||||
} else {
|
||||
buf := make([]byte, count)
|
||||
_, err := io.ReadFull(r.reader, buf)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return ""
|
||||
}
|
||||
s = string(buf)
|
||||
}
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// readBytes reads count bytes and returns as []byte
|
||||
func (r *msgReader) readBytes(count int32) []byte {
|
||||
if r.err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= count
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return nil
|
||||
}
|
||||
|
||||
b := make([]byte, int(count))
|
||||
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
|
||||
/*
|
||||
pgio provides functions for appending integers to a []byte while doing byte
|
||||
order conversion.
|
||||
*/
|
||||
package pgio
|
|
@ -0,0 +1,40 @@
|
|||
package pgio
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
func AppendUint16(buf []byte, n uint16) []byte {
|
||||
wp := len(buf)
|
||||
buf = append(buf, 0, 0)
|
||||
binary.BigEndian.PutUint16(buf[wp:], n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func AppendUint32(buf []byte, n uint32) []byte {
|
||||
wp := len(buf)
|
||||
buf = append(buf, 0, 0, 0, 0)
|
||||
binary.BigEndian.PutUint32(buf[wp:], n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func AppendUint64(buf []byte, n uint64) []byte {
|
||||
wp := len(buf)
|
||||
buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
|
||||
binary.BigEndian.PutUint64(buf[wp:], n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func AppendInt16(buf []byte, n int16) []byte {
|
||||
return AppendUint16(buf, uint16(n))
|
||||
}
|
||||
|
||||
func AppendInt32(buf []byte, n int32) []byte {
|
||||
return AppendUint32(buf, uint32(n))
|
||||
}
|
||||
|
||||
func AppendInt64(buf []byte, n int64) []byte {
|
||||
return AppendUint64(buf, uint64(n))
|
||||
}
|
||||
|
||||
func SetInt32(buf []byte, n int32) {
|
||||
binary.BigEndian.PutUint32(buf, uint32(n))
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
package pgio
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAppendUint16NilBuf(t *testing.T) {
|
||||
buf := AppendUint16(nil, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
||||
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint16EmptyBuf(t *testing.T) {
|
||||
buf := []byte{}
|
||||
buf = AppendUint16(buf, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
||||
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint16BufWithCapacityDoesNotAllocate(t *testing.T) {
|
||||
buf := make([]byte, 0, 4)
|
||||
AppendUint16(buf, 1)
|
||||
buf = buf[0:2]
|
||||
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
||||
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint32NilBuf(t *testing.T) {
|
||||
buf := AppendUint32(nil, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint32EmptyBuf(t *testing.T) {
|
||||
buf := []byte{}
|
||||
buf = AppendUint32(buf, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint32BufWithCapacityDoesNotAllocate(t *testing.T) {
|
||||
buf := make([]byte, 0, 4)
|
||||
AppendUint32(buf, 1)
|
||||
buf = buf[0:4]
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint64NilBuf(t *testing.T) {
|
||||
buf := AppendUint64(nil, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint64EmptyBuf(t *testing.T) {
|
||||
buf := []byte{}
|
||||
buf = AppendUint64(buf, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint64BufWithCapacityDoesNotAllocate(t *testing.T) {
|
||||
buf := make([]byte, 0, 8)
|
||||
AppendUint64(buf, 1)
|
||||
buf = buf[0:8]
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,501 @@
|
|||
package pgmock
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
ln net.Listener
|
||||
controller Controller
|
||||
}
|
||||
|
||||
func NewServer(controller Controller) (*Server, error) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
ln: ln,
|
||||
controller: controller,
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (s *Server) Addr() net.Addr {
|
||||
return s.ln.Addr()
|
||||
}
|
||||
|
||||
func (s *Server) ServeOne() error {
|
||||
conn, err := s.ln.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
s.Close()
|
||||
|
||||
backend, err := pgproto3.NewBackend(conn, conn)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return s.controller.Serve(backend)
|
||||
}
|
||||
|
||||
func (s *Server) Close() error {
|
||||
err := s.ln.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type Controller interface {
|
||||
Serve(backend *pgproto3.Backend) error
|
||||
}
|
||||
|
||||
type Step interface {
|
||||
Step(*pgproto3.Backend) error
|
||||
}
|
||||
|
||||
type Script struct {
|
||||
Steps []Step
|
||||
}
|
||||
|
||||
func (s *Script) Run(backend *pgproto3.Backend) error {
|
||||
for _, step := range s.Steps {
|
||||
err := step.Step(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Script) Serve(backend *pgproto3.Backend) error {
|
||||
for _, step := range s.Steps {
|
||||
err := step.Step(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Script) Step(backend *pgproto3.Backend) error {
|
||||
return s.Serve(backend)
|
||||
}
|
||||
|
||||
type expectMessageStep struct {
|
||||
want pgproto3.FrontendMessage
|
||||
any bool
|
||||
}
|
||||
|
||||
func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
msg, err := backend.Receive()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(msg, e.want) {
|
||||
return errors.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type expectStartupMessageStep struct {
|
||||
want *pgproto3.StartupMessage
|
||||
any bool
|
||||
}
|
||||
|
||||
func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
msg, err := backend.ReceiveStartupMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.any {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(msg, e.want) {
|
||||
return errors.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExpectMessage(want pgproto3.FrontendMessage) Step {
|
||||
return expectMessage(want, false)
|
||||
}
|
||||
|
||||
func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
|
||||
return expectMessage(want, true)
|
||||
}
|
||||
|
||||
func expectMessage(want pgproto3.FrontendMessage, any bool) Step {
|
||||
if want, ok := want.(*pgproto3.StartupMessage); ok {
|
||||
return &expectStartupMessageStep{want: want, any: any}
|
||||
}
|
||||
|
||||
return &expectMessageStep{want: want, any: any}
|
||||
}
|
||||
|
||||
type sendMessageStep struct {
|
||||
msg pgproto3.BackendMessage
|
||||
}
|
||||
|
||||
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
return backend.Send(e.msg)
|
||||
}
|
||||
|
||||
func SendMessage(msg pgproto3.BackendMessage) Step {
|
||||
return &sendMessageStep{msg: msg}
|
||||
}
|
||||
|
||||
type waitForCloseMessageStep struct{}
|
||||
|
||||
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
for {
|
||||
msg, err := backend.Receive()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := msg.(*pgproto3.Terminate); ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WaitForClose() Step {
|
||||
return &waitForCloseMessageStep{}
|
||||
}
|
||||
|
||||
func AcceptUnauthenticatedConnRequestSteps() []Step {
|
||||
return []Step{
|
||||
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
||||
SendMessage(&pgproto3.Authentication{Type: pgproto3.AuthTypeOk}),
|
||||
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
||||
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
||||
}
|
||||
}
|
||||
|
||||
func PgxInitSteps() []Step {
|
||||
steps := []Step{
|
||||
ExpectMessage(&pgproto3.Parse{
|
||||
Query: "select t.oid, t.typname\nfrom pg_type t\nleft join pg_type base_type on t.typelem=base_type.oid\nwhere (\n\t t.typtype in('b', 'p', 'r', 'e')\n\t and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))\n\t)",
|
||||
}),
|
||||
ExpectMessage(&pgproto3.Describe{
|
||||
ObjectType: 'S',
|
||||
}),
|
||||
ExpectMessage(&pgproto3.Sync{}),
|
||||
SendMessage(&pgproto3.ParseComplete{}),
|
||||
SendMessage(&pgproto3.ParameterDescription{}),
|
||||
SendMessage(&pgproto3.RowDescription{
|
||||
Fields: []pgproto3.FieldDescription{
|
||||
{Name: "oid",
|
||||
TableOID: 1247,
|
||||
TableAttributeNumber: 65534,
|
||||
DataTypeOID: 26,
|
||||
DataTypeSize: 4,
|
||||
TypeModifier: 4294967295,
|
||||
Format: 0,
|
||||
},
|
||||
{Name: "typname",
|
||||
TableOID: 1247,
|
||||
TableAttributeNumber: 1,
|
||||
DataTypeOID: 19,
|
||||
DataTypeSize: 64,
|
||||
TypeModifier: 4294967295,
|
||||
Format: 0,
|
||||
},
|
||||
},
|
||||
}),
|
||||
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
||||
ExpectMessage(&pgproto3.Bind{
|
||||
ResultFormatCodes: []int16{1, 1},
|
||||
}),
|
||||
ExpectMessage(&pgproto3.Execute{}),
|
||||
ExpectMessage(&pgproto3.Sync{}),
|
||||
SendMessage(&pgproto3.BindComplete{}),
|
||||
}
|
||||
|
||||
rowVals := []struct {
|
||||
oid pgtype.OID
|
||||
name string
|
||||
}{
|
||||
{16, "bool"},
|
||||
{17, "bytea"},
|
||||
{18, "char"},
|
||||
{19, "name"},
|
||||
{20, "int8"},
|
||||
{21, "int2"},
|
||||
{22, "int2vector"},
|
||||
{23, "int4"},
|
||||
{24, "regproc"},
|
||||
{25, "text"},
|
||||
{26, "oid"},
|
||||
{27, "tid"},
|
||||
{28, "xid"},
|
||||
{29, "cid"},
|
||||
{30, "oidvector"},
|
||||
{114, "json"},
|
||||
{142, "xml"},
|
||||
{143, "_xml"},
|
||||
{199, "_json"},
|
||||
{194, "pg_node_tree"},
|
||||
{32, "pg_ddl_command"},
|
||||
{210, "smgr"},
|
||||
{600, "point"},
|
||||
{601, "lseg"},
|
||||
{602, "path"},
|
||||
{603, "box"},
|
||||
{604, "polygon"},
|
||||
{628, "line"},
|
||||
{629, "_line"},
|
||||
{700, "float4"},
|
||||
{701, "float8"},
|
||||
{702, "abstime"},
|
||||
{703, "reltime"},
|
||||
{704, "tinterval"},
|
||||
{705, "unknown"},
|
||||
{718, "circle"},
|
||||
{719, "_circle"},
|
||||
{790, "money"},
|
||||
{791, "_money"},
|
||||
{829, "macaddr"},
|
||||
{869, "inet"},
|
||||
{650, "cidr"},
|
||||
{1000, "_bool"},
|
||||
{1001, "_bytea"},
|
||||
{1002, "_char"},
|
||||
{1003, "_name"},
|
||||
{1005, "_int2"},
|
||||
{1006, "_int2vector"},
|
||||
{1007, "_int4"},
|
||||
{1008, "_regproc"},
|
||||
{1009, "_text"},
|
||||
{1028, "_oid"},
|
||||
{1010, "_tid"},
|
||||
{1011, "_xid"},
|
||||
{1012, "_cid"},
|
||||
{1013, "_oidvector"},
|
||||
{1014, "_bpchar"},
|
||||
{1015, "_varchar"},
|
||||
{1016, "_int8"},
|
||||
{1017, "_point"},
|
||||
{1018, "_lseg"},
|
||||
{1019, "_path"},
|
||||
{1020, "_box"},
|
||||
{1021, "_float4"},
|
||||
{1022, "_float8"},
|
||||
{1023, "_abstime"},
|
||||
{1024, "_reltime"},
|
||||
{1025, "_tinterval"},
|
||||
{1027, "_polygon"},
|
||||
{1033, "aclitem"},
|
||||
{1034, "_aclitem"},
|
||||
{1040, "_macaddr"},
|
||||
{1041, "_inet"},
|
||||
{651, "_cidr"},
|
||||
{1263, "_cstring"},
|
||||
{1042, "bpchar"},
|
||||
{1043, "varchar"},
|
||||
{1082, "date"},
|
||||
{1083, "time"},
|
||||
{1114, "timestamp"},
|
||||
{1115, "_timestamp"},
|
||||
{1182, "_date"},
|
||||
{1183, "_time"},
|
||||
{1184, "timestamptz"},
|
||||
{1185, "_timestamptz"},
|
||||
{1186, "interval"},
|
||||
{1187, "_interval"},
|
||||
{1231, "_numeric"},
|
||||
{1266, "timetz"},
|
||||
{1270, "_timetz"},
|
||||
{1560, "bit"},
|
||||
{1561, "_bit"},
|
||||
{1562, "varbit"},
|
||||
{1563, "_varbit"},
|
||||
{1700, "numeric"},
|
||||
{1790, "refcursor"},
|
||||
{2201, "_refcursor"},
|
||||
{2202, "regprocedure"},
|
||||
{2203, "regoper"},
|
||||
{2204, "regoperator"},
|
||||
{2205, "regclass"},
|
||||
{2206, "regtype"},
|
||||
{4096, "regrole"},
|
||||
{4089, "regnamespace"},
|
||||
{2207, "_regprocedure"},
|
||||
{2208, "_regoper"},
|
||||
{2209, "_regoperator"},
|
||||
{2210, "_regclass"},
|
||||
{2211, "_regtype"},
|
||||
{4097, "_regrole"},
|
||||
{4090, "_regnamespace"},
|
||||
{2950, "uuid"},
|
||||
{2951, "_uuid"},
|
||||
{3220, "pg_lsn"},
|
||||
{3221, "_pg_lsn"},
|
||||
{3614, "tsvector"},
|
||||
{3642, "gtsvector"},
|
||||
{3615, "tsquery"},
|
||||
{3734, "regconfig"},
|
||||
{3769, "regdictionary"},
|
||||
{3643, "_tsvector"},
|
||||
{3644, "_gtsvector"},
|
||||
{3645, "_tsquery"},
|
||||
{3735, "_regconfig"},
|
||||
{3770, "_regdictionary"},
|
||||
{3802, "jsonb"},
|
||||
{3807, "_jsonb"},
|
||||
{2970, "txid_snapshot"},
|
||||
{2949, "_txid_snapshot"},
|
||||
{3904, "int4range"},
|
||||
{3905, "_int4range"},
|
||||
{3906, "numrange"},
|
||||
{3907, "_numrange"},
|
||||
{3908, "tsrange"},
|
||||
{3909, "_tsrange"},
|
||||
{3910, "tstzrange"},
|
||||
{3911, "_tstzrange"},
|
||||
{3912, "daterange"},
|
||||
{3913, "_daterange"},
|
||||
{3926, "int8range"},
|
||||
{3927, "_int8range"},
|
||||
{2249, "record"},
|
||||
{2287, "_record"},
|
||||
{2275, "cstring"},
|
||||
{2276, "any"},
|
||||
{2277, "anyarray"},
|
||||
{2278, "void"},
|
||||
{2279, "trigger"},
|
||||
{3838, "event_trigger"},
|
||||
{2280, "language_handler"},
|
||||
{2281, "internal"},
|
||||
{2282, "opaque"},
|
||||
{2283, "anyelement"},
|
||||
{2776, "anynonarray"},
|
||||
{3500, "anyenum"},
|
||||
{3115, "fdw_handler"},
|
||||
{325, "index_am_handler"},
|
||||
{3310, "tsm_handler"},
|
||||
{3831, "anyrange"},
|
||||
{51367, "gbtreekey4"},
|
||||
{51370, "_gbtreekey4"},
|
||||
{51371, "gbtreekey8"},
|
||||
{51374, "_gbtreekey8"},
|
||||
{51375, "gbtreekey16"},
|
||||
{51378, "_gbtreekey16"},
|
||||
{51379, "gbtreekey32"},
|
||||
{51382, "_gbtreekey32"},
|
||||
{51383, "gbtreekey_var"},
|
||||
{51386, "_gbtreekey_var"},
|
||||
{51921, "hstore"},
|
||||
{51926, "_hstore"},
|
||||
{52005, "ghstore"},
|
||||
{52008, "_ghstore"},
|
||||
}
|
||||
|
||||
for _, rv := range rowVals {
|
||||
step := SendMessage(mustBuildDataRow([]interface{}{rv.oid, rv.name}, []int16{pgproto3.BinaryFormat}))
|
||||
steps = append(steps, step)
|
||||
}
|
||||
|
||||
steps = append(steps, SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT 163"}))
|
||||
steps = append(steps, SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
|
||||
|
||||
return steps
|
||||
}
|
||||
|
||||
type dataRowValue struct {
|
||||
Value interface{}
|
||||
FormatCode int16
|
||||
}
|
||||
|
||||
func mustBuildDataRow(values []interface{}, formatCodes []int16) *pgproto3.DataRow {
|
||||
dr, err := buildDataRow(values, formatCodes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return dr
|
||||
}
|
||||
|
||||
func buildDataRow(values []interface{}, formatCodes []int16) (*pgproto3.DataRow, error) {
|
||||
dr := &pgproto3.DataRow{
|
||||
Values: make([][]byte, len(values)),
|
||||
}
|
||||
|
||||
if len(formatCodes) == 1 {
|
||||
for i := 1; i < len(values); i++ {
|
||||
formatCodes = append(formatCodes, formatCodes[0])
|
||||
}
|
||||
}
|
||||
|
||||
for i := range values {
|
||||
switch v := values[i].(type) {
|
||||
case string:
|
||||
values[i] = &pgtype.Text{String: v, Status: pgtype.Present}
|
||||
case int16:
|
||||
values[i] = &pgtype.Int2{Int: v, Status: pgtype.Present}
|
||||
case int32:
|
||||
values[i] = &pgtype.Int4{Int: v, Status: pgtype.Present}
|
||||
case int64:
|
||||
values[i] = &pgtype.Int8{Int: v, Status: pgtype.Present}
|
||||
}
|
||||
}
|
||||
|
||||
for i := range values {
|
||||
switch formatCodes[i] {
|
||||
case pgproto3.TextFormat:
|
||||
if e, ok := values[i].(pgtype.TextEncoder); ok {
|
||||
buf, err := e.EncodeText(nil, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("failed to encode values[%d]", i)
|
||||
}
|
||||
dr.Values[i] = buf
|
||||
} else {
|
||||
return nil, errors.Errorf("values[%d] does not implement TextExcoder", i)
|
||||
}
|
||||
|
||||
case pgproto3.BinaryFormat:
|
||||
if e, ok := values[i].(pgtype.BinaryEncoder); ok {
|
||||
buf, err := e.EncodeBinary(nil, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("failed to encode values[%d]", i)
|
||||
}
|
||||
dr.Values[i] = buf
|
||||
} else {
|
||||
return nil, errors.Errorf("values[%d] does not implement BinaryEncoder", i)
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("unknown FormatCode")
|
||||
}
|
||||
}
|
||||
|
||||
return dr, nil
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthTypeOk = 0
|
||||
AuthTypeCleartextPassword = 3
|
||||
AuthTypeMD5Password = 5
|
||||
)
|
||||
|
||||
type Authentication struct {
|
||||
Type uint32
|
||||
|
||||
// MD5Password fields
|
||||
Salt [4]byte
|
||||
}
|
||||
|
||||
func (*Authentication) Backend() {}
|
||||
|
||||
func (dst *Authentication) Decode(src []byte) error {
|
||||
*dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])}
|
||||
|
||||
switch dst.Type {
|
||||
case AuthTypeOk:
|
||||
case AuthTypeCleartextPassword:
|
||||
case AuthTypeMD5Password:
|
||||
copy(dst.Salt[:], src[4:8])
|
||||
default:
|
||||
return errors.Errorf("unknown authentication type: %d", dst.Type)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Authentication) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
dst = pgio.AppendUint32(dst, src.Type)
|
||||
|
||||
switch src.Type {
|
||||
case AuthTypeMD5Password:
|
||||
dst = append(dst, src.Salt[:]...)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/chunkreader"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Backend struct {
|
||||
cr *chunkreader.ChunkReader
|
||||
w io.Writer
|
||||
|
||||
// Frontend message flyweights
|
||||
bind Bind
|
||||
_close Close
|
||||
describe Describe
|
||||
execute Execute
|
||||
flush Flush
|
||||
parse Parse
|
||||
passwordMessage PasswordMessage
|
||||
query Query
|
||||
startupMessage StartupMessage
|
||||
sync Sync
|
||||
terminate Terminate
|
||||
}
|
||||
|
||||
func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
|
||||
cr := chunkreader.NewChunkReader(r)
|
||||
return &Backend{cr: cr, w: w}, nil
|
||||
}
|
||||
|
||||
func (b *Backend) Send(msg BackendMessage) error {
|
||||
_, err := b.w.Write(msg.Encode(nil))
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
|
||||
buf, err := b.cr.Next(4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
||||
|
||||
buf, err = b.cr.Next(msgSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = b.startupMessage.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &b.startupMessage, nil
|
||||
}
|
||||
|
||||
func (b *Backend) Receive() (FrontendMessage, error) {
|
||||
header, err := b.cr.Next(5)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msgType := header[0]
|
||||
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||
|
||||
var msg FrontendMessage
|
||||
switch msgType {
|
||||
case 'B':
|
||||
msg = &b.bind
|
||||
case 'C':
|
||||
msg = &b._close
|
||||
case 'D':
|
||||
msg = &b.describe
|
||||
case 'E':
|
||||
msg = &b.execute
|
||||
case 'H':
|
||||
msg = &b.flush
|
||||
case 'P':
|
||||
msg = &b.parse
|
||||
case 'p':
|
||||
msg = &b.passwordMessage
|
||||
case 'Q':
|
||||
msg = &b.query
|
||||
case 'S':
|
||||
msg = &b.sync
|
||||
case 'X':
|
||||
msg = &b.terminate
|
||||
default:
|
||||
return nil, errors.Errorf("unknown message type: %c", msgType)
|
||||
}
|
||||
|
||||
msgBody, err := b.cr.Next(bodyLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = msg.Decode(msgBody)
|
||||
return msg, err
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type BackendKeyData struct {
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}
|
||||
|
||||
func (*BackendKeyData) Backend() {}
|
||||
|
||||
func (dst *BackendKeyData) Decode(src []byte) error {
|
||||
if len(src) != 8 {
|
||||
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
|
||||
}
|
||||
|
||||
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
|
||||
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *BackendKeyData) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'K')
|
||||
dst = pgio.AppendUint32(dst, 12)
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *BackendKeyData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}{
|
||||
Type: "BackendKeyData",
|
||||
ProcessID: src.ProcessID,
|
||||
SecretKey: src.SecretKey,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
type BigEndianBuf [8]byte
|
||||
|
||||
func (b BigEndianBuf) Int16(n int16) []byte {
|
||||
buf := b[0:2]
|
||||
binary.BigEndian.PutUint16(buf, uint16(n))
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Uint16(n uint16) []byte {
|
||||
buf := b[0:2]
|
||||
binary.BigEndian.PutUint16(buf, n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Int32(n int32) []byte {
|
||||
buf := b[0:4]
|
||||
binary.BigEndian.PutUint32(buf, uint32(n))
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Uint32(n uint32) []byte {
|
||||
buf := b[0:4]
|
||||
binary.BigEndian.PutUint32(buf, n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Int64(n int64) []byte {
|
||||
buf := b[0:8]
|
||||
binary.BigEndian.PutUint64(buf, uint64(n))
|
||||
return buf
|
||||
}
|
|
@ -0,0 +1,171 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Bind struct {
|
||||
DestinationPortal string
|
||||
PreparedStatement string
|
||||
ParameterFormatCodes []int16
|
||||
Parameters [][]byte
|
||||
ResultFormatCodes []int16
|
||||
}
|
||||
|
||||
func (*Bind) Frontend() {}
|
||||
|
||||
func (dst *Bind) Decode(src []byte) error {
|
||||
*dst = Bind{}
|
||||
|
||||
idx := bytes.IndexByte(src, 0)
|
||||
if idx < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
dst.DestinationPortal = string(src[:idx])
|
||||
rp := idx + 1
|
||||
|
||||
idx = bytes.IndexByte(src[rp:], 0)
|
||||
if idx < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
dst.PreparedStatement = string(src[rp : rp+idx])
|
||||
rp += idx + 1
|
||||
|
||||
if len(src[rp:]) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
if parameterFormatCodeCount > 0 {
|
||||
dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
|
||||
|
||||
if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
for i := 0; i < parameterFormatCodeCount; i++ {
|
||||
dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
}
|
||||
}
|
||||
|
||||
if len(src[rp:]) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
if parameterCount > 0 {
|
||||
dst.Parameters = make([][]byte, parameterCount)
|
||||
|
||||
for i := 0; i < parameterCount; i++ {
|
||||
if len(src[rp:]) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
|
||||
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
|
||||
// null
|
||||
if msgSize == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(src[rp:]) < msgSize {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
|
||||
dst.Parameters[i] = src[rp : rp+msgSize]
|
||||
rp += msgSize
|
||||
}
|
||||
}
|
||||
|
||||
if len(src[rp:]) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
|
||||
if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
for i := 0; i < resultFormatCodeCount; i++ {
|
||||
dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Bind) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'B')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.DestinationPortal...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.PreparedStatement...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
|
||||
for _, fc := range src.ParameterFormatCodes {
|
||||
dst = pgio.AppendInt16(dst, fc)
|
||||
}
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
|
||||
for _, p := range src.Parameters {
|
||||
if p == nil {
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
continue
|
||||
}
|
||||
|
||||
dst = pgio.AppendInt32(dst, int32(len(p)))
|
||||
dst = append(dst, p...)
|
||||
}
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
|
||||
for _, fc := range src.ResultFormatCodes {
|
||||
dst = pgio.AppendInt16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *Bind) MarshalJSON() ([]byte, error) {
|
||||
formattedParameters := make([]map[string]string, len(src.Parameters))
|
||||
for i, p := range src.Parameters {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if src.ParameterFormatCodes[i] == 0 {
|
||||
formattedParameters[i] = map[string]string{"text": string(p)}
|
||||
} else {
|
||||
formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
DestinationPortal string
|
||||
PreparedStatement string
|
||||
ParameterFormatCodes []int16
|
||||
Parameters []map[string]string
|
||||
ResultFormatCodes []int16
|
||||
}{
|
||||
Type: "Bind",
|
||||
DestinationPortal: src.DestinationPortal,
|
||||
PreparedStatement: src.PreparedStatement,
|
||||
ParameterFormatCodes: src.ParameterFormatCodes,
|
||||
Parameters: formattedParameters,
|
||||
ResultFormatCodes: src.ResultFormatCodes,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type BindComplete struct{}
|
||||
|
||||
func (*BindComplete) Backend() {}
|
||||
|
||||
func (dst *BindComplete) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *BindComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '2', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *BindComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "BindComplete",
|
||||
})
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Close struct {
|
||||
ObjectType byte // 'S' = prepared statement, 'P' = portal
|
||||
Name string
|
||||
}
|
||||
|
||||
func (*Close) Frontend() {}
|
||||
|
||||
func (dst *Close) Decode(src []byte) error {
|
||||
if len(src) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Close"}
|
||||
}
|
||||
|
||||
dst.ObjectType = src[0]
|
||||
rp := 1
|
||||
|
||||
idx := bytes.IndexByte(src[rp:], 0)
|
||||
if idx != len(src[rp:])-1 {
|
||||
return &invalidMessageFormatErr{messageType: "Close"}
|
||||
}
|
||||
|
||||
dst.Name = string(src[rp : len(src)-1])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Close) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.ObjectType)
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *Close) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ObjectType string
|
||||
Name string
|
||||
}{
|
||||
Type: "Close",
|
||||
ObjectType: string(src.ObjectType),
|
||||
Name: src.Name,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type CloseComplete struct{}
|
||||
|
||||
func (*CloseComplete) Backend() {}
|
||||
|
||||
func (dst *CloseComplete) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CloseComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '3', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *CloseComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "CloseComplete",
|
||||
})
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type CommandComplete struct {
|
||||
CommandTag string
|
||||
}
|
||||
|
||||
func (*CommandComplete) Backend() {}
|
||||
|
||||
func (dst *CommandComplete) Decode(src []byte) error {
|
||||
idx := bytes.IndexByte(src, 0)
|
||||
if idx != len(src)-1 {
|
||||
return &invalidMessageFormatErr{messageType: "CommandComplete"}
|
||||
}
|
||||
|
||||
dst.CommandTag = string(src[:idx])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CommandComplete) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.CommandTag...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *CommandComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
CommandTag string
|
||||
}{
|
||||
Type: "CommandComplete",
|
||||
CommandTag: src.CommandTag,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type CopyBothResponse struct {
|
||||
OverallFormat byte
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
|
||||
func (*CopyBothResponse) Backend() {}
|
||||
|
||||
func (dst *CopyBothResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||
}
|
||||
|
||||
overallFormat := buf.Next(1)[0]
|
||||
|
||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
if buf.Len() != columnCount*2 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CopyBothResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'W')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *CopyBothResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ColumnFormatCodes []uint16
|
||||
}{
|
||||
Type: "CopyBothResponse",
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type CopyData struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (*CopyData) Backend() {}
|
||||
func (*CopyData) Frontend() {}
|
||||
|
||||
func (dst *CopyData) Decode(src []byte) error {
|
||||
dst.Data = src
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CopyData) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'd')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||
dst = append(dst, src.Data...)
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *CopyData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data string
|
||||
}{
|
||||
Type: "CopyData",
|
||||
Data: hex.EncodeToString(src.Data),
|
||||
})
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type CopyInResponse struct {
|
||||
OverallFormat byte
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
|
||||
func (*CopyInResponse) Backend() {}
|
||||
|
||||
func (dst *CopyInResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
|
||||
}
|
||||
|
||||
overallFormat := buf.Next(1)[0]
|
||||
|
||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
if buf.Len() != columnCount*2 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
*dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CopyInResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'G')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *CopyInResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ColumnFormatCodes []uint16
|
||||
}{
|
||||
Type: "CopyInResponse",
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type CopyOutResponse struct {
|
||||
OverallFormat byte
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
|
||||
func (*CopyOutResponse) Backend() {}
|
||||
|
||||
func (dst *CopyOutResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
|
||||
}
|
||||
|
||||
overallFormat := buf.Next(1)[0]
|
||||
|
||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
if buf.Len() != columnCount*2 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
*dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CopyOutResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'H')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *CopyOutResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ColumnFormatCodes []uint16
|
||||
}{
|
||||
Type: "CopyOutResponse",
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type DataRow struct {
|
||||
Values [][]byte
|
||||
}
|
||||
|
||||
func (*DataRow) Backend() {}
|
||||
|
||||
func (dst *DataRow) Decode(src []byte) error {
|
||||
if len(src) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
rp := 0
|
||||
fieldCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
// If the capacity of the values slice is too small OR substantially too
|
||||
// large reallocate. This is too avoid one row with many columns from
|
||||
// permanently allocating memory.
|
||||
if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 {
|
||||
newCap := 32
|
||||
if newCap < fieldCount {
|
||||
newCap = fieldCount
|
||||
}
|
||||
dst.Values = make([][]byte, fieldCount, newCap)
|
||||
} else {
|
||||
dst.Values = dst.Values[:fieldCount]
|
||||
}
|
||||
|
||||
for i := 0; i < fieldCount; i++ {
|
||||
if len(src[rp:]) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
|
||||
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
|
||||
// null
|
||||
if msgSize == -1 {
|
||||
dst.Values[i] = nil
|
||||
} else {
|
||||
if len(src[rp:]) < msgSize {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
|
||||
dst.Values[i] = src[rp : rp+msgSize]
|
||||
rp += msgSize
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *DataRow) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'D')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
|
||||
for _, v := range src.Values {
|
||||
if v == nil {
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
continue
|
||||
}
|
||||
|
||||
dst = pgio.AppendInt32(dst, int32(len(v)))
|
||||
dst = append(dst, v...)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *DataRow) MarshalJSON() ([]byte, error) {
|
||||
formattedValues := make([]map[string]string, len(src.Values))
|
||||
for i, v := range src.Values {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var hasNonPrintable bool
|
||||
for _, b := range v {
|
||||
if b < 32 {
|
||||
hasNonPrintable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasNonPrintable {
|
||||
formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)}
|
||||
} else {
|
||||
formattedValues[i] = map[string]string{"text": string(v)}
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Values []map[string]string
|
||||
}{
|
||||
Type: "DataRow",
|
||||
Values: formattedValues,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Describe struct {
|
||||
ObjectType byte // 'S' = prepared statement, 'P' = portal
|
||||
Name string
|
||||
}
|
||||
|
||||
func (*Describe) Frontend() {}
|
||||
|
||||
func (dst *Describe) Decode(src []byte) error {
|
||||
if len(src) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Describe"}
|
||||
}
|
||||
|
||||
dst.ObjectType = src[0]
|
||||
rp := 1
|
||||
|
||||
idx := bytes.IndexByte(src[rp:], 0)
|
||||
if idx != len(src[rp:])-1 {
|
||||
return &invalidMessageFormatErr{messageType: "Describe"}
|
||||
}
|
||||
|
||||
dst.Name = string(src[rp : len(src)-1])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Describe) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'D')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.ObjectType)
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *Describe) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ObjectType string
|
||||
Name string
|
||||
}{
|
||||
Type: "Describe",
|
||||
ObjectType: string(src.ObjectType),
|
||||
Name: src.Name,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type EmptyQueryResponse struct{}
|
||||
|
||||
func (*EmptyQueryResponse) Backend() {}
|
||||
|
||||
func (dst *EmptyQueryResponse) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, 'I', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "EmptyQueryResponse",
|
||||
})
|
||||
}
|
|
@ -0,0 +1,197 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type ErrorResponse struct {
|
||||
Severity string
|
||||
Code string
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position int32
|
||||
InternalPosition int32
|
||||
InternalQuery string
|
||||
Where string
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
DataTypeName string
|
||||
ConstraintName string
|
||||
File string
|
||||
Line int32
|
||||
Routine string
|
||||
|
||||
UnknownFields map[byte]string
|
||||
}
|
||||
|
||||
func (*ErrorResponse) Backend() {}
|
||||
|
||||
func (dst *ErrorResponse) Decode(src []byte) error {
|
||||
*dst = ErrorResponse{}
|
||||
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
for {
|
||||
k, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if k == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
vb, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v := string(vb[:len(vb)-1])
|
||||
|
||||
switch k {
|
||||
case 'S':
|
||||
dst.Severity = v
|
||||
case 'C':
|
||||
dst.Code = v
|
||||
case 'M':
|
||||
dst.Message = v
|
||||
case 'D':
|
||||
dst.Detail = v
|
||||
case 'H':
|
||||
dst.Hint = v
|
||||
case 'P':
|
||||
s := v
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
dst.Position = int32(n)
|
||||
case 'p':
|
||||
s := v
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
dst.InternalPosition = int32(n)
|
||||
case 'q':
|
||||
dst.InternalQuery = v
|
||||
case 'W':
|
||||
dst.Where = v
|
||||
case 's':
|
||||
dst.SchemaName = v
|
||||
case 't':
|
||||
dst.TableName = v
|
||||
case 'c':
|
||||
dst.ColumnName = v
|
||||
case 'd':
|
||||
dst.DataTypeName = v
|
||||
case 'n':
|
||||
dst.ConstraintName = v
|
||||
case 'F':
|
||||
dst.File = v
|
||||
case 'L':
|
||||
s := v
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
dst.Line = int32(n)
|
||||
case 'R':
|
||||
dst.Routine = v
|
||||
|
||||
default:
|
||||
if dst.UnknownFields == nil {
|
||||
dst.UnknownFields = make(map[byte]string)
|
||||
}
|
||||
dst.UnknownFields[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ErrorResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, src.marshalBinary('E')...)
|
||||
}
|
||||
|
||||
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte(typeByte)
|
||||
buf.Write(bigEndian.Uint32(0))
|
||||
|
||||
if src.Severity != "" {
|
||||
buf.WriteString(src.Severity)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Code != "" {
|
||||
buf.WriteString(src.Code)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Message != "" {
|
||||
buf.WriteString(src.Message)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Detail != "" {
|
||||
buf.WriteString(src.Detail)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Hint != "" {
|
||||
buf.WriteString(src.Hint)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Position != 0 {
|
||||
buf.WriteString(strconv.Itoa(int(src.Position)))
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.InternalPosition != 0 {
|
||||
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.InternalQuery != "" {
|
||||
buf.WriteString(src.InternalQuery)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Where != "" {
|
||||
buf.WriteString(src.Where)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.SchemaName != "" {
|
||||
buf.WriteString(src.SchemaName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.TableName != "" {
|
||||
buf.WriteString(src.TableName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.ColumnName != "" {
|
||||
buf.WriteString(src.ColumnName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.DataTypeName != "" {
|
||||
buf.WriteString(src.DataTypeName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.ConstraintName != "" {
|
||||
buf.WriteString(src.ConstraintName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.File != "" {
|
||||
buf.WriteString(src.File)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Line != 0 {
|
||||
buf.WriteString(strconv.Itoa(int(src.Line)))
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Routine != "" {
|
||||
buf.WriteString(src.Routine)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
|
||||
for k, v := range src.UnknownFields {
|
||||
buf.WriteByte(k)
|
||||
buf.WriteByte(0)
|
||||
buf.WriteString(v)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
buf.WriteByte(0)
|
||||
|
||||
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Execute struct {
|
||||
Portal string
|
||||
MaxRows uint32
|
||||
}
|
||||
|
||||
func (*Execute) Frontend() {}
|
||||
|
||||
func (dst *Execute) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Portal = string(b[:len(b)-1])
|
||||
|
||||
if buf.Len() < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "Execute"}
|
||||
}
|
||||
dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Execute) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'E')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.Portal...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint32(dst, src.MaxRows)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *Execute) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Portal string
|
||||
MaxRows uint32
|
||||
}{
|
||||
Type: "Execute",
|
||||
Portal: src.Portal,
|
||||
MaxRows: src.MaxRows,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Flush struct{}
|
||||
|
||||
func (*Flush) Frontend() {}
|
||||
|
||||
func (dst *Flush) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Flush) Encode(dst []byte) []byte {
|
||||
return append(dst, 'H', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *Flush) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "Flush",
|
||||
})
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/chunkreader"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Frontend struct {
|
||||
cr *chunkreader.ChunkReader
|
||||
w io.Writer
|
||||
|
||||
// Backend message flyweights
|
||||
authentication Authentication
|
||||
backendKeyData BackendKeyData
|
||||
bindComplete BindComplete
|
||||
closeComplete CloseComplete
|
||||
commandComplete CommandComplete
|
||||
copyBothResponse CopyBothResponse
|
||||
copyData CopyData
|
||||
copyInResponse CopyInResponse
|
||||
copyOutResponse CopyOutResponse
|
||||
dataRow DataRow
|
||||
emptyQueryResponse EmptyQueryResponse
|
||||
errorResponse ErrorResponse
|
||||
functionCallResponse FunctionCallResponse
|
||||
noData NoData
|
||||
noticeResponse NoticeResponse
|
||||
notificationResponse NotificationResponse
|
||||
parameterDescription ParameterDescription
|
||||
parameterStatus ParameterStatus
|
||||
parseComplete ParseComplete
|
||||
readyForQuery ReadyForQuery
|
||||
rowDescription RowDescription
|
||||
}
|
||||
|
||||
func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
|
||||
cr := chunkreader.NewChunkReader(r)
|
||||
return &Frontend{cr: cr, w: w}, nil
|
||||
}
|
||||
|
||||
func (b *Frontend) Send(msg FrontendMessage) error {
|
||||
_, err := b.w.Write(msg.Encode(nil))
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Frontend) Receive() (BackendMessage, error) {
|
||||
header, err := b.cr.Next(5)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msgType := header[0]
|
||||
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||
|
||||
var msg BackendMessage
|
||||
switch msgType {
|
||||
case '1':
|
||||
msg = &b.parseComplete
|
||||
case '2':
|
||||
msg = &b.bindComplete
|
||||
case '3':
|
||||
msg = &b.closeComplete
|
||||
case 'A':
|
||||
msg = &b.notificationResponse
|
||||
case 'C':
|
||||
msg = &b.commandComplete
|
||||
case 'd':
|
||||
msg = &b.copyData
|
||||
case 'D':
|
||||
msg = &b.dataRow
|
||||
case 'E':
|
||||
msg = &b.errorResponse
|
||||
case 'G':
|
||||
msg = &b.copyInResponse
|
||||
case 'H':
|
||||
msg = &b.copyOutResponse
|
||||
case 'I':
|
||||
msg = &b.emptyQueryResponse
|
||||
case 'K':
|
||||
msg = &b.backendKeyData
|
||||
case 'n':
|
||||
msg = &b.noData
|
||||
case 'N':
|
||||
msg = &b.noticeResponse
|
||||
case 'R':
|
||||
msg = &b.authentication
|
||||
case 'S':
|
||||
msg = &b.parameterStatus
|
||||
case 't':
|
||||
msg = &b.parameterDescription
|
||||
case 'T':
|
||||
msg = &b.rowDescription
|
||||
case 'V':
|
||||
msg = &b.functionCallResponse
|
||||
case 'W':
|
||||
msg = &b.copyBothResponse
|
||||
case 'Z':
|
||||
msg = &b.readyForQuery
|
||||
default:
|
||||
return nil, errors.Errorf("unknown message type: %c", msgType)
|
||||
}
|
||||
|
||||
msgBody, err := b.cr.Next(bodyLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = msg.Decode(msgBody)
|
||||
return msg, err
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type FunctionCallResponse struct {
|
||||
Result []byte
|
||||
}
|
||||
|
||||
func (*FunctionCallResponse) Backend() {}
|
||||
|
||||
func (dst *FunctionCallResponse) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||
}
|
||||
rp := 0
|
||||
resultSize := int(binary.BigEndian.Uint32(src[rp:]))
|
||||
rp += 4
|
||||
|
||||
if resultSize == -1 {
|
||||
dst.Result = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(src[rp:]) != resultSize {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||
}
|
||||
|
||||
dst.Result = src[rp:]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *FunctionCallResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'V')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
if src.Result == nil {
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
} else {
|
||||
dst = pgio.AppendInt32(dst, int32(len(src.Result)))
|
||||
dst = append(dst, src.Result...)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) {
|
||||
var formattedValue map[string]string
|
||||
var hasNonPrintable bool
|
||||
for _, b := range src.Result {
|
||||
if b < 32 {
|
||||
hasNonPrintable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasNonPrintable {
|
||||
formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)}
|
||||
} else {
|
||||
formattedValue = map[string]string{"text": string(src.Result)}
|
||||
}
|
||||
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Result map[string]string
|
||||
}{
|
||||
Type: "FunctionCallResponse",
|
||||
Result: formattedValue,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type NoData struct{}
|
||||
|
||||
func (*NoData) Backend() {}
|
||||
|
||||
func (dst *NoData) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *NoData) Encode(dst []byte) []byte {
|
||||
return append(dst, 'n', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *NoData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "NoData",
|
||||
})
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
package pgproto3
|
||||
|
||||
type NoticeResponse ErrorResponse
|
||||
|
||||
func (*NoticeResponse) Backend() {}
|
||||
|
||||
func (dst *NoticeResponse) Decode(src []byte) error {
|
||||
return (*ErrorResponse)(dst).Decode(src)
|
||||
}
|
||||
|
||||
func (src *NoticeResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type NotificationResponse struct {
|
||||
PID uint32
|
||||
Channel string
|
||||
Payload string
|
||||
}
|
||||
|
||||
func (*NotificationResponse) Backend() {}
|
||||
|
||||
func (dst *NotificationResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
pid := binary.BigEndian.Uint32(buf.Next(4))
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
channel := string(b[:len(b)-1])
|
||||
|
||||
b, err = buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload := string(b[:len(b)-1])
|
||||
|
||||
*dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *NotificationResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'A')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.Channel...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.Payload...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *NotificationResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
PID uint32
|
||||
Channel string
|
||||
Payload string
|
||||
}{
|
||||
Type: "NotificationResponse",
|
||||
PID: src.PID,
|
||||
Channel: src.Channel,
|
||||
Payload: src.Payload,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type ParameterDescription struct {
|
||||
ParameterOIDs []uint32
|
||||
}
|
||||
|
||||
func (*ParameterDescription) Backend() {}
|
||||
|
||||
func (dst *ParameterDescription) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "ParameterDescription"}
|
||||
}
|
||||
|
||||
// Reported parameter count will be incorrect when number of args is greater than uint16
|
||||
buf.Next(2)
|
||||
// Instead infer parameter count by remaining size of message
|
||||
parameterCount := buf.Len() / 4
|
||||
|
||||
*dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)}
|
||||
|
||||
for i := 0; i < parameterCount; i++ {
|
||||
dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ParameterDescription) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 't')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||
for _, oid := range src.ParameterOIDs {
|
||||
dst = pgio.AppendUint32(dst, oid)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *ParameterDescription) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ParameterOIDs []uint32
|
||||
}{
|
||||
Type: "ParameterDescription",
|
||||
ParameterOIDs: src.ParameterOIDs,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type ParameterStatus struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
func (*ParameterStatus) Backend() {}
|
||||
|
||||
func (dst *ParameterStatus) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
name := string(b[:len(b)-1])
|
||||
|
||||
b, err = buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value := string(b[:len(b)-1])
|
||||
|
||||
*dst = ParameterStatus{Name: name, Value: value}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ParameterStatus) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'S')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.Value...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (ps *ParameterStatus) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Name string
|
||||
Value string
|
||||
}{
|
||||
Type: "ParameterStatus",
|
||||
Name: ps.Name,
|
||||
Value: ps.Value,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Parse struct {
|
||||
Name string
|
||||
Query string
|
||||
ParameterOIDs []uint32
|
||||
}
|
||||
|
||||
func (*Parse) Frontend() {}
|
||||
|
||||
func (dst *Parse) Decode(src []byte) error {
|
||||
*dst = Parse{}
|
||||
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Name = string(b[:len(b)-1])
|
||||
|
||||
b, err = buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Query = string(b[:len(b)-1])
|
||||
|
||||
if buf.Len() < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Parse"}
|
||||
}
|
||||
parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
|
||||
for i := 0; i < parameterOIDCount; i++ {
|
||||
if buf.Len() < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "Parse"}
|
||||
}
|
||||
dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Parse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'P')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.Query...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||
for _, oid := range src.ParameterOIDs {
|
||||
dst = pgio.AppendUint32(dst, oid)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *Parse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Name string
|
||||
Query string
|
||||
ParameterOIDs []uint32
|
||||
}{
|
||||
Type: "Parse",
|
||||
Name: src.Name,
|
||||
Query: src.Query,
|
||||
ParameterOIDs: src.ParameterOIDs,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type ParseComplete struct{}
|
||||
|
||||
func (*ParseComplete) Backend() {}
|
||||
|
||||
func (dst *ParseComplete) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ParseComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '1', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *ParseComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "ParseComplete",
|
||||
})
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type PasswordMessage struct {
|
||||
Password string
|
||||
}
|
||||
|
||||
func (*PasswordMessage) Frontend() {}
|
||||
|
||||
func (dst *PasswordMessage) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Password = string(b[:len(b)-1])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *PasswordMessage) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
|
||||
|
||||
dst = append(dst, src.Password...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *PasswordMessage) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Password string
|
||||
}{
|
||||
Type: "PasswordMessage",
|
||||
Password: src.Password,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
package pgproto3
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Message is the interface implemented by an object that can decode and encode
|
||||
// a particular PostgreSQL message.
|
||||
type Message interface {
|
||||
// Decode is allowed and expected to retain a reference to data after
|
||||
// returning (unlike encoding.BinaryUnmarshaler).
|
||||
Decode(data []byte) error
|
||||
|
||||
// Encode appends itself to dst and returns the new buffer.
|
||||
Encode(dst []byte) []byte
|
||||
}
|
||||
|
||||
type FrontendMessage interface {
|
||||
Message
|
||||
Frontend() // no-op method to distinguish frontend from backend methods
|
||||
}
|
||||
|
||||
type BackendMessage interface {
|
||||
Message
|
||||
Backend() // no-op method to distinguish frontend from backend methods
|
||||
}
|
||||
|
||||
type invalidMessageLenErr struct {
|
||||
messageType string
|
||||
expectedLen int
|
||||
actualLen int
|
||||
}
|
||||
|
||||
func (e *invalidMessageLenErr) Error() string {
|
||||
return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen)
|
||||
}
|
||||
|
||||
type invalidMessageFormatErr struct {
|
||||
messageType string
|
||||
}
|
||||
|
||||
func (e *invalidMessageFormatErr) Error() string {
|
||||
return fmt.Sprintf("%s body is invalid", e.messageType)
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Query struct {
|
||||
String string
|
||||
}
|
||||
|
||||
func (*Query) Frontend() {}
|
||||
|
||||
func (dst *Query) Decode(src []byte) error {
|
||||
i := bytes.IndexByte(src, 0)
|
||||
if i != len(src)-1 {
|
||||
return &invalidMessageFormatErr{messageType: "Query"}
|
||||
}
|
||||
|
||||
dst.String = string(src[:i])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Query) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'Q')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))
|
||||
|
||||
dst = append(dst, src.String...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *Query) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
String string
|
||||
}{
|
||||
Type: "Query",
|
||||
String: src.String,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type ReadyForQuery struct {
|
||||
TxStatus byte
|
||||
}
|
||||
|
||||
func (*ReadyForQuery) Backend() {}
|
||||
|
||||
func (dst *ReadyForQuery) Decode(src []byte) error {
|
||||
if len(src) != 1 {
|
||||
return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)}
|
||||
}
|
||||
|
||||
dst.TxStatus = src[0]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ReadyForQuery) Encode(dst []byte) []byte {
|
||||
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus)
|
||||
}
|
||||
|
||||
func (src *ReadyForQuery) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
TxStatus string
|
||||
}{
|
||||
Type: "ReadyForQuery",
|
||||
TxStatus: string(src.TxStatus),
|
||||
})
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
const (
|
||||
TextFormat = 0
|
||||
BinaryFormat = 1
|
||||
)
|
||||
|
||||
type FieldDescription struct {
|
||||
Name string
|
||||
TableOID uint32
|
||||
TableAttributeNumber uint16
|
||||
DataTypeOID uint32
|
||||
DataTypeSize int16
|
||||
TypeModifier uint32
|
||||
Format int16
|
||||
}
|
||||
|
||||
type RowDescription struct {
|
||||
Fields []FieldDescription
|
||||
}
|
||||
|
||||
func (*RowDescription) Backend() {}
|
||||
|
||||
func (dst *RowDescription) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
||||
}
|
||||
fieldCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
|
||||
*dst = RowDescription{Fields: make([]FieldDescription, fieldCount)}
|
||||
|
||||
for i := 0; i < fieldCount; i++ {
|
||||
var fd FieldDescription
|
||||
bName, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fd.Name = string(bName[:len(bName)-1])
|
||||
|
||||
// Since buf.Next() doesn't return an error if we hit the end of the buffer
|
||||
// check Len ahead of time
|
||||
if buf.Len() < 18 {
|
||||
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
||||
}
|
||||
|
||||
fd.TableOID = binary.BigEndian.Uint32(buf.Next(4))
|
||||
fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2))
|
||||
fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4))
|
||||
fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
fd.TypeModifier = binary.BigEndian.Uint32(buf.Next(4))
|
||||
fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
|
||||
dst.Fields[i] = fd
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *RowDescription) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'T')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
|
||||
for _, fd := range src.Fields {
|
||||
dst = append(dst, fd.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint32(dst, fd.TableOID)
|
||||
dst = pgio.AppendUint16(dst, fd.TableAttributeNumber)
|
||||
dst = pgio.AppendUint32(dst, fd.DataTypeOID)
|
||||
dst = pgio.AppendInt16(dst, fd.DataTypeSize)
|
||||
dst = pgio.AppendUint32(dst, fd.TypeModifier)
|
||||
dst = pgio.AppendInt16(dst, fd.Format)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *RowDescription) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Fields []FieldDescription
|
||||
}{
|
||||
Type: "RowDescription",
|
||||
Fields: src.Fields,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
ProtocolVersionNumber = 196608 // 3.0
|
||||
sslRequestNumber = 80877103
|
||||
)
|
||||
|
||||
type StartupMessage struct {
|
||||
ProtocolVersion uint32
|
||||
Parameters map[string]string
|
||||
}
|
||||
|
||||
func (*StartupMessage) Frontend() {}
|
||||
|
||||
func (dst *StartupMessage) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.Errorf("startup message too short")
|
||||
}
|
||||
|
||||
dst.ProtocolVersion = binary.BigEndian.Uint32(src)
|
||||
rp := 4
|
||||
|
||||
if dst.ProtocolVersion == sslRequestNumber {
|
||||
return errors.Errorf("can't handle ssl connection request")
|
||||
}
|
||||
|
||||
if dst.ProtocolVersion != ProtocolVersionNumber {
|
||||
return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
|
||||
}
|
||||
|
||||
dst.Parameters = make(map[string]string)
|
||||
for {
|
||||
idx := bytes.IndexByte(src[rp:], 0)
|
||||
if idx < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "StartupMesage"}
|
||||
}
|
||||
key := string(src[rp : rp+idx])
|
||||
rp += idx + 1
|
||||
|
||||
idx = bytes.IndexByte(src[rp:], 0)
|
||||
if idx < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "StartupMesage"}
|
||||
}
|
||||
value := string(src[rp : rp+idx])
|
||||
rp += idx + 1
|
||||
|
||||
dst.Parameters[key] = value
|
||||
|
||||
if len(src[rp:]) == 1 {
|
||||
if src[rp] != 0 {
|
||||
return errors.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp])
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *StartupMessage) Encode(dst []byte) []byte {
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint32(dst, src.ProtocolVersion)
|
||||
for k, v := range src.Parameters {
|
||||
dst = append(dst, k...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, v...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *StartupMessage) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProtocolVersion uint32
|
||||
Parameters map[string]string
|
||||
}{
|
||||
Type: "StartupMessage",
|
||||
ProtocolVersion: src.ProtocolVersion,
|
||||
Parameters: src.Parameters,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Sync struct{}
|
||||
|
||||
func (*Sync) Frontend() {}
|
||||
|
||||
func (dst *Sync) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Sync) Encode(dst []byte) []byte {
|
||||
return append(dst, 'S', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *Sync) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "Sync",
|
||||
})
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Terminate struct{}
|
||||
|
||||
func (*Terminate) Frontend() {}
|
||||
|
||||
func (dst *Terminate) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Terminate) Encode(dst []byte) []byte {
|
||||
return append(dst, 'X', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *Terminate) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "Terminate",
|
||||
})
|
||||
}
|
|
@ -0,0 +1,126 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem
|
||||
// might look like this:
|
||||
//
|
||||
// postgres=arwdDxt/postgres
|
||||
//
|
||||
// Note, however, that because the user/role name part of an aclitem is
|
||||
// an identifier, it follows all the usual formatting rules for SQL
|
||||
// identifiers: if it contains spaces and other special characters,
|
||||
// it should appear in double-quotes:
|
||||
//
|
||||
// postgres=arwdDxt/"role with spaces"
|
||||
//
|
||||
type ACLItem struct {
|
||||
String string
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (dst *ACLItem) Set(src interface{}) error {
|
||||
switch value := src.(type) {
|
||||
case string:
|
||||
*dst = ACLItem{String: value, Status: Present}
|
||||
case *string:
|
||||
if value == nil {
|
||||
*dst = ACLItem{Status: Null}
|
||||
} else {
|
||||
*dst = ACLItem{String: *value, Status: Present}
|
||||
}
|
||||
default:
|
||||
if originalSrc, ok := underlyingStringType(src); ok {
|
||||
return dst.Set(originalSrc)
|
||||
}
|
||||
return errors.Errorf("cannot convert %v to ACLItem", value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *ACLItem) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case Present:
|
||||
return dst.String
|
||||
case Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *ACLItem) AssignTo(dst interface{}) error {
|
||||
switch src.Status {
|
||||
case Present:
|
||||
switch v := dst.(type) {
|
||||
case *string:
|
||||
*v = src.String
|
||||
return nil
|
||||
default:
|
||||
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||
return src.AssignTo(nextDst)
|
||||
}
|
||||
}
|
||||
case Null:
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot decode %v into %T", src, dst)
|
||||
}
|
||||
|
||||
func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = ACLItem{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
*dst = ACLItem{String: string(src), Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ACLItem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
return append(buf, src.String...), nil
|
||||
}
|
||||
|
||||
// Scan implements the database/sql Scanner interface.
|
||||
func (dst *ACLItem) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
*dst = ACLItem{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case string:
|
||||
return dst.DecodeText(nil, []byte(src))
|
||||
case []byte:
|
||||
srcCopy := make([]byte, len(src))
|
||||
copy(srcCopy, src)
|
||||
return dst.DecodeText(nil, srcCopy)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot scan %T", src)
|
||||
}
|
||||
|
||||
// Value implements the database/sql/driver Valuer interface.
|
||||
func (src *ACLItem) Value() (driver.Value, error) {
|
||||
switch src.Status {
|
||||
case Present:
|
||||
return src.String, nil
|
||||
case Null:
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, errUndefined
|
||||
}
|
||||
}
|
|
@ -0,0 +1,206 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type ACLItemArray struct {
|
||||
Elements []ACLItem
|
||||
Dimensions []ArrayDimension
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (dst *ACLItemArray) Set(src interface{}) error {
|
||||
switch value := src.(type) {
|
||||
|
||||
case []string:
|
||||
if value == nil {
|
||||
*dst = ACLItemArray{Status: Null}
|
||||
} else if len(value) == 0 {
|
||||
*dst = ACLItemArray{Status: Present}
|
||||
} else {
|
||||
elements := make([]ACLItem, len(value))
|
||||
for i := range value {
|
||||
if err := elements[i].Set(value[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*dst = ACLItemArray{
|
||||
Elements: elements,
|
||||
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||
Status: Present,
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
if originalSrc, ok := underlyingSliceType(src); ok {
|
||||
return dst.Set(originalSrc)
|
||||
}
|
||||
return errors.Errorf("cannot convert %v to ACLItem", value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *ACLItemArray) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case Present:
|
||||
return dst
|
||||
case Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *ACLItemArray) AssignTo(dst interface{}) error {
|
||||
switch src.Status {
|
||||
case Present:
|
||||
switch v := dst.(type) {
|
||||
|
||||
case *[]string:
|
||||
*v = make([]string, len(src.Elements))
|
||||
for i := range src.Elements {
|
||||
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||
return src.AssignTo(nextDst)
|
||||
}
|
||||
}
|
||||
case Null:
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot decode %v into %T", src, dst)
|
||||
}
|
||||
|
||||
func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = ACLItemArray{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
uta, err := ParseUntypedTextArray(string(src))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var elements []ACLItem
|
||||
|
||||
if len(uta.Elements) > 0 {
|
||||
elements = make([]ACLItem, len(uta.Elements))
|
||||
|
||||
for i, s := range uta.Elements {
|
||||
var elem ACLItem
|
||||
var elemSrc []byte
|
||||
if s != "NULL" {
|
||||
elemSrc = []byte(s)
|
||||
}
|
||||
err = elem.DecodeText(ci, elemSrc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
elements[i] = elem
|
||||
}
|
||||
}
|
||||
|
||||
*dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
if len(src.Dimensions) == 0 {
|
||||
return append(buf, '{', '}'), nil
|
||||
}
|
||||
|
||||
buf = EncodeTextArrayDimensions(buf, src.Dimensions)
|
||||
|
||||
// dimElemCounts is the multiples of elements that each array lies on. For
|
||||
// example, a single dimension array of length 4 would have a dimElemCounts of
|
||||
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
|
||||
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
|
||||
// or '}'.
|
||||
dimElemCounts := make([]int, len(src.Dimensions))
|
||||
dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
|
||||
for i := len(src.Dimensions) - 2; i > -1; i-- {
|
||||
dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
|
||||
}
|
||||
|
||||
inElemBuf := make([]byte, 0, 32)
|
||||
for i, elem := range src.Elements {
|
||||
if i > 0 {
|
||||
buf = append(buf, ',')
|
||||
}
|
||||
|
||||
for _, dec := range dimElemCounts {
|
||||
if i%dec == 0 {
|
||||
buf = append(buf, '{')
|
||||
}
|
||||
}
|
||||
|
||||
elemBuf, err := elem.EncodeText(ci, inElemBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if elemBuf == nil {
|
||||
buf = append(buf, `NULL`...)
|
||||
} else {
|
||||
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
|
||||
}
|
||||
|
||||
for _, dec := range dimElemCounts {
|
||||
if (i+1)%dec == 0 {
|
||||
buf = append(buf, '}')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Scan implements the database/sql Scanner interface.
|
||||
func (dst *ACLItemArray) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
return dst.DecodeText(nil, nil)
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case string:
|
||||
return dst.DecodeText(nil, []byte(src))
|
||||
case []byte:
|
||||
srcCopy := make([]byte, len(src))
|
||||
copy(srcCopy, src)
|
||||
return dst.DecodeText(nil, srcCopy)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot scan %T", src)
|
||||
}
|
||||
|
||||
// Value implements the database/sql/driver Valuer interface.
|
||||
func (src *ACLItemArray) Value() (driver.Value, error) {
|
||||
buf, err := src.EncodeText(nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if buf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return string(buf), nil
|
||||
}
|
|
@ -0,0 +1,152 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/jackc/pgx/pgtype/testutil"
|
||||
)
|
||||
|
||||
func TestACLItemArrayTranscode(t *testing.T) {
|
||||
testutil.TestSuccessfulTranscode(t, "aclitem[]", []interface{}{
|
||||
&pgtype.ACLItemArray{
|
||||
Elements: nil,
|
||||
Dimensions: nil,
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.ACLItemArray{
|
||||
Elements: []pgtype.ACLItem{
|
||||
pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present},
|
||||
pgtype.ACLItem{Status: pgtype.Null},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.ACLItemArray{Status: pgtype.Null},
|
||||
&pgtype.ACLItemArray{
|
||||
Elements: []pgtype.ACLItem{
|
||||
pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present},
|
||||
pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present},
|
||||
pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present},
|
||||
pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present},
|
||||
pgtype.ACLItem{Status: pgtype.Null},
|
||||
pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.ACLItemArray{
|
||||
Elements: []pgtype.ACLItem{
|
||||
pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present},
|
||||
pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present},
|
||||
pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present},
|
||||
pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{
|
||||
{Length: 2, LowerBound: 4},
|
||||
{Length: 2, LowerBound: 2},
|
||||
},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestACLItemArraySet(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
source interface{}
|
||||
result pgtype.ACLItemArray
|
||||
}{
|
||||
{
|
||||
source: []string{"=r/postgres"},
|
||||
result: pgtype.ACLItemArray{
|
||||
Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
source: (([]string)(nil)),
|
||||
result: pgtype.ACLItemArray{Status: pgtype.Null},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
var r pgtype.ACLItemArray
|
||||
err := r.Set(tt.source)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(r, tt.result) {
|
||||
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestACLItemArrayAssignTo(t *testing.T) {
|
||||
var stringSlice []string
|
||||
type _stringSlice []string
|
||||
var namedStringSlice _stringSlice
|
||||
|
||||
simpleTests := []struct {
|
||||
src pgtype.ACLItemArray
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
src: pgtype.ACLItemArray{
|
||||
Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &stringSlice,
|
||||
expected: []string{"=r/postgres"},
|
||||
},
|
||||
{
|
||||
src: pgtype.ACLItemArray{
|
||||
Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &namedStringSlice,
|
||||
expected: _stringSlice{"=r/postgres"},
|
||||
},
|
||||
{
|
||||
src: pgtype.ACLItemArray{Status: pgtype.Null},
|
||||
dst: &stringSlice,
|
||||
expected: (([]string)(nil)),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range simpleTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
|
||||
errorTests := []struct {
|
||||
src pgtype.ACLItemArray
|
||||
dst interface{}
|
||||
}{
|
||||
{
|
||||
src: pgtype.ACLItemArray{
|
||||
Elements: []pgtype.ACLItem{{Status: pgtype.Null}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &stringSlice,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range errorTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err == nil {
|
||||
t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/jackc/pgx/pgtype/testutil"
|
||||
)
|
||||
|
||||
func TestACLItemTranscode(t *testing.T) {
|
||||
testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{
|
||||
&pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present},
|
||||
&pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present},
|
||||
&pgtype.ACLItem{Status: pgtype.Null},
|
||||
})
|
||||
}
|
||||
|
||||
func TestACLItemSet(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
source interface{}
|
||||
result pgtype.ACLItem
|
||||
}{
|
||||
{source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}},
|
||||
{source: (*string)(nil), result: pgtype.ACLItem{Status: pgtype.Null}},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
var d pgtype.ACLItem
|
||||
err := d.Set(tt.source)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if d != tt.result {
|
||||
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestACLItemAssignTo(t *testing.T) {
|
||||
var s string
|
||||
var ps *string
|
||||
|
||||
simpleTests := []struct {
|
||||
src pgtype.ACLItem
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"},
|
||||
{src: pgtype.ACLItem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))},
|
||||
}
|
||||
|
||||
for i, tt := range simpleTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
|
||||
pointerAllocTests := []struct {
|
||||
src pgtype.ACLItem
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"},
|
||||
}
|
||||
|
||||
for i, tt := range pointerAllocTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
|
||||
errorTests := []struct {
|
||||
src pgtype.ACLItem
|
||||
dst interface{}
|
||||
}{
|
||||
{src: pgtype.ACLItem{Status: pgtype.Null}, dst: &s},
|
||||
}
|
||||
|
||||
for i, tt := range errorTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err == nil {
|
||||
t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,352 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Information on the internals of PostgreSQL arrays can be found in
|
||||
// src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of
|
||||
// particular interest is the array_send function.
|
||||
|
||||
type ArrayHeader struct {
|
||||
ContainsNull bool
|
||||
ElementOID int32
|
||||
Dimensions []ArrayDimension
|
||||
}
|
||||
|
||||
type ArrayDimension struct {
|
||||
Length int32
|
||||
LowerBound int32
|
||||
}
|
||||
|
||||
func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) {
|
||||
if len(src) < 12 {
|
||||
return 0, errors.Errorf("array header too short: %d", len(src))
|
||||
}
|
||||
|
||||
rp := 0
|
||||
|
||||
numDims := int(binary.BigEndian.Uint32(src[rp:]))
|
||||
rp += 4
|
||||
|
||||
dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1
|
||||
rp += 4
|
||||
|
||||
dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:]))
|
||||
rp += 4
|
||||
|
||||
if numDims > 0 {
|
||||
dst.Dimensions = make([]ArrayDimension, numDims)
|
||||
}
|
||||
if len(src) < 12+numDims*8 {
|
||||
return 0, errors.Errorf("array header too short for %d dimensions: %d", numDims, len(src))
|
||||
}
|
||||
for i := range dst.Dimensions {
|
||||
dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:]))
|
||||
rp += 4
|
||||
|
||||
dst.Dimensions[i].LowerBound = int32(binary.BigEndian.Uint32(src[rp:]))
|
||||
rp += 4
|
||||
}
|
||||
|
||||
return rp, nil
|
||||
}
|
||||
|
||||
func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte {
|
||||
buf = pgio.AppendInt32(buf, int32(len(src.Dimensions)))
|
||||
|
||||
var containsNull int32
|
||||
if src.ContainsNull {
|
||||
containsNull = 1
|
||||
}
|
||||
buf = pgio.AppendInt32(buf, containsNull)
|
||||
|
||||
buf = pgio.AppendInt32(buf, src.ElementOID)
|
||||
|
||||
for i := range src.Dimensions {
|
||||
buf = pgio.AppendInt32(buf, src.Dimensions[i].Length)
|
||||
buf = pgio.AppendInt32(buf, src.Dimensions[i].LowerBound)
|
||||
}
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
type UntypedTextArray struct {
|
||||
Elements []string
|
||||
Dimensions []ArrayDimension
|
||||
}
|
||||
|
||||
func ParseUntypedTextArray(src string) (*UntypedTextArray, error) {
|
||||
dst := &UntypedTextArray{}
|
||||
|
||||
buf := bytes.NewBufferString(src)
|
||||
|
||||
skipWhitespace(buf)
|
||||
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("invalid array: %v", err)
|
||||
}
|
||||
|
||||
var explicitDimensions []ArrayDimension
|
||||
|
||||
// Array has explicit dimensions
|
||||
if r == '[' {
|
||||
buf.UnreadRune()
|
||||
|
||||
for {
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("invalid array: %v", err)
|
||||
}
|
||||
|
||||
if r == '=' {
|
||||
break
|
||||
} else if r != '[' {
|
||||
return nil, errors.Errorf("invalid array, expected '[' or '=' got %v", r)
|
||||
}
|
||||
|
||||
lower, err := arrayParseInteger(buf)
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("invalid array: %v", err)
|
||||
}
|
||||
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("invalid array: %v", err)
|
||||
}
|
||||
|
||||
if r != ':' {
|
||||
return nil, errors.Errorf("invalid array, expected ':' got %v", r)
|
||||
}
|
||||
|
||||
upper, err := arrayParseInteger(buf)
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("invalid array: %v", err)
|
||||
}
|
||||
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("invalid array: %v", err)
|
||||
}
|
||||
|
||||
if r != ']' {
|
||||
return nil, errors.Errorf("invalid array, expected ']' got %v", r)
|
||||
}
|
||||
|
||||
explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1})
|
||||
}
|
||||
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("invalid array: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if r != '{' {
|
||||
return nil, errors.Errorf("invalid array, expected '{': %v", err)
|
||||
}
|
||||
|
||||
implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}}
|
||||
|
||||
// Consume all initial opening brackets. This provides number of dimensions.
|
||||
for {
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("invalid array: %v", err)
|
||||
}
|
||||
|
||||
if r == '{' {
|
||||
implicitDimensions[len(implicitDimensions)-1].Length = 1
|
||||
implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1})
|
||||
} else {
|
||||
buf.UnreadRune()
|
||||
break
|
||||
}
|
||||
}
|
||||
currentDim := len(implicitDimensions) - 1
|
||||
counterDim := currentDim
|
||||
|
||||
for {
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("invalid array: %v", err)
|
||||
}
|
||||
|
||||
switch r {
|
||||
case '{':
|
||||
if currentDim == counterDim {
|
||||
implicitDimensions[currentDim].Length++
|
||||
}
|
||||
currentDim++
|
||||
case ',':
|
||||
case '}':
|
||||
currentDim--
|
||||
if currentDim < counterDim {
|
||||
counterDim = currentDim
|
||||
}
|
||||
default:
|
||||
buf.UnreadRune()
|
||||
value, err := arrayParseValue(buf)
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("invalid array value: %v", err)
|
||||
}
|
||||
if currentDim == counterDim {
|
||||
implicitDimensions[currentDim].Length++
|
||||
}
|
||||
dst.Elements = append(dst.Elements, value)
|
||||
}
|
||||
|
||||
if currentDim < 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
skipWhitespace(buf)
|
||||
|
||||
if buf.Len() > 0 {
|
||||
return nil, errors.Errorf("unexpected trailing data: %v", buf.String())
|
||||
}
|
||||
|
||||
if len(dst.Elements) == 0 {
|
||||
dst.Dimensions = nil
|
||||
} else if len(explicitDimensions) > 0 {
|
||||
dst.Dimensions = explicitDimensions
|
||||
} else {
|
||||
dst.Dimensions = implicitDimensions
|
||||
}
|
||||
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
func skipWhitespace(buf *bytes.Buffer) {
|
||||
var r rune
|
||||
var err error
|
||||
for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() {
|
||||
}
|
||||
|
||||
if err != io.EOF {
|
||||
buf.UnreadRune()
|
||||
}
|
||||
}
|
||||
|
||||
func arrayParseValue(buf *bytes.Buffer) (string, error) {
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if r == '"' {
|
||||
return arrayParseQuotedValue(buf)
|
||||
}
|
||||
buf.UnreadRune()
|
||||
|
||||
s := &bytes.Buffer{}
|
||||
|
||||
for {
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch r {
|
||||
case ',', '}':
|
||||
buf.UnreadRune()
|
||||
return s.String(), nil
|
||||
}
|
||||
|
||||
s.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
func arrayParseQuotedValue(buf *bytes.Buffer) (string, error) {
|
||||
s := &bytes.Buffer{}
|
||||
|
||||
for {
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
case '"':
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
buf.UnreadRune()
|
||||
return s.String(), nil
|
||||
}
|
||||
s.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
func arrayParseInteger(buf *bytes.Buffer) (int32, error) {
|
||||
s := &bytes.Buffer{}
|
||||
|
||||
for {
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if '0' <= r && r <= '9' {
|
||||
s.WriteRune(r)
|
||||
} else {
|
||||
buf.UnreadRune()
|
||||
n, err := strconv.ParseInt(s.String(), 10, 32)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int32(n), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func EncodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte {
|
||||
var customDimensions bool
|
||||
for _, dim := range dimensions {
|
||||
if dim.LowerBound != 1 {
|
||||
customDimensions = true
|
||||
}
|
||||
}
|
||||
|
||||
if !customDimensions {
|
||||
return buf
|
||||
}
|
||||
|
||||
for _, dim := range dimensions {
|
||||
buf = append(buf, '[')
|
||||
buf = append(buf, strconv.FormatInt(int64(dim.LowerBound), 10)...)
|
||||
buf = append(buf, ':')
|
||||
buf = append(buf, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)...)
|
||||
buf = append(buf, ']')
|
||||
}
|
||||
|
||||
return append(buf, '=')
|
||||
}
|
||||
|
||||
var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
|
||||
|
||||
func quoteArrayElement(src string) string {
|
||||
return `"` + quoteArrayReplacer.Replace(src) + `"`
|
||||
}
|
||||
|
||||
func QuoteArrayElementIfNeeded(src string) string {
|
||||
if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `{},"\`) {
|
||||
return quoteArrayElement(src)
|
||||
}
|
||||
return src
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
func TestParseUntypedTextArray(t *testing.T) {
|
||||
tests := []struct {
|
||||
source string
|
||||
result pgtype.UntypedTextArray
|
||||
}{
|
||||
{
|
||||
source: "{}",
|
||||
result: pgtype.UntypedTextArray{
|
||||
Elements: nil,
|
||||
Dimensions: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
source: "{1}",
|
||||
result: pgtype.UntypedTextArray{
|
||||
Elements: []string{"1"},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
source: "{a,b}",
|
||||
result: pgtype.UntypedTextArray{
|
||||
Elements: []string{"a", "b"},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
source: `{"NULL"}`,
|
||||
result: pgtype.UntypedTextArray{
|
||||
Elements: []string{"NULL"},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
source: `{""}`,
|
||||
result: pgtype.UntypedTextArray{
|
||||
Elements: []string{""},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
source: `{"He said, \"Hello.\""}`,
|
||||
result: pgtype.UntypedTextArray{
|
||||
Elements: []string{`He said, "Hello."`},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
source: "{{a,b},{c,d},{e,f}}",
|
||||
result: pgtype.UntypedTextArray{
|
||||
Elements: []string{"a", "b", "c", "d", "e", "f"},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
source: "{{{a,b},{c,d},{e,f}},{{a,b},{c,d},{e,f}}}",
|
||||
result: pgtype.UntypedTextArray{
|
||||
Elements: []string{"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f"},
|
||||
Dimensions: []pgtype.ArrayDimension{
|
||||
{Length: 2, LowerBound: 1},
|
||||
{Length: 3, LowerBound: 1},
|
||||
{Length: 2, LowerBound: 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
source: "[4:4]={1}",
|
||||
result: pgtype.UntypedTextArray{
|
||||
Elements: []string{"1"},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 4}},
|
||||
},
|
||||
},
|
||||
{
|
||||
source: "[4:5][2:3]={{a,b},{c,d}}",
|
||||
result: pgtype.UntypedTextArray{
|
||||
Elements: []string{"a", "b", "c", "d"},
|
||||
Dimensions: []pgtype.ArrayDimension{
|
||||
{Length: 2, LowerBound: 4},
|
||||
{Length: 2, LowerBound: 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r, err := pgtype.ParseUntypedTextArray(tt.source)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(*r, tt.result) {
|
||||
t.Errorf("%d: expected %+v to be parsed to %+v, but it was %+v", i, tt.source, tt.result, *r)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,159 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Bool struct {
|
||||
Bool bool
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (dst *Bool) Set(src interface{}) error {
|
||||
switch value := src.(type) {
|
||||
case bool:
|
||||
*dst = Bool{Bool: value, Status: Present}
|
||||
case string:
|
||||
bb, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*dst = Bool{Bool: bb, Status: Present}
|
||||
default:
|
||||
if originalSrc, ok := underlyingBoolType(src); ok {
|
||||
return dst.Set(originalSrc)
|
||||
}
|
||||
return errors.Errorf("cannot convert %v to Bool", value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *Bool) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case Present:
|
||||
return dst.Bool
|
||||
case Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *Bool) AssignTo(dst interface{}) error {
|
||||
switch src.Status {
|
||||
case Present:
|
||||
switch v := dst.(type) {
|
||||
case *bool:
|
||||
*v = src.Bool
|
||||
return nil
|
||||
default:
|
||||
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||
return src.AssignTo(nextDst)
|
||||
}
|
||||
}
|
||||
case Null:
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot decode %v into %T", src, dst)
|
||||
}
|
||||
|
||||
func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Bool{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(src) != 1 {
|
||||
return errors.Errorf("invalid length for bool: %v", len(src))
|
||||
}
|
||||
|
||||
*dst = Bool{Bool: src[0] == 't', Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Bool{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(src) != 1 {
|
||||
return errors.Errorf("invalid length for bool: %v", len(src))
|
||||
}
|
||||
|
||||
*dst = Bool{Bool: src[0] == 1, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
if src.Bool {
|
||||
buf = append(buf, 't')
|
||||
} else {
|
||||
buf = append(buf, 'f')
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (src *Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
if src.Bool {
|
||||
buf = append(buf, 1)
|
||||
} else {
|
||||
buf = append(buf, 0)
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Scan implements the database/sql Scanner interface.
|
||||
func (dst *Bool) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
*dst = Bool{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case bool:
|
||||
*dst = Bool{Bool: src, Status: Present}
|
||||
return nil
|
||||
case string:
|
||||
return dst.DecodeText(nil, []byte(src))
|
||||
case []byte:
|
||||
srcCopy := make([]byte, len(src))
|
||||
copy(srcCopy, src)
|
||||
return dst.DecodeText(nil, srcCopy)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot scan %T", src)
|
||||
}
|
||||
|
||||
// Value implements the database/sql/driver Valuer interface.
|
||||
func (src *Bool) Value() (driver.Value, error) {
|
||||
switch src.Status {
|
||||
case Present:
|
||||
return src.Bool, nil
|
||||
case Null:
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, errUndefined
|
||||
}
|
||||
}
|
|
@ -0,0 +1,294 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type BoolArray struct {
|
||||
Elements []Bool
|
||||
Dimensions []ArrayDimension
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (dst *BoolArray) Set(src interface{}) error {
|
||||
switch value := src.(type) {
|
||||
|
||||
case []bool:
|
||||
if value == nil {
|
||||
*dst = BoolArray{Status: Null}
|
||||
} else if len(value) == 0 {
|
||||
*dst = BoolArray{Status: Present}
|
||||
} else {
|
||||
elements := make([]Bool, len(value))
|
||||
for i := range value {
|
||||
if err := elements[i].Set(value[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*dst = BoolArray{
|
||||
Elements: elements,
|
||||
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||
Status: Present,
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
if originalSrc, ok := underlyingSliceType(src); ok {
|
||||
return dst.Set(originalSrc)
|
||||
}
|
||||
return errors.Errorf("cannot convert %v to Bool", value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *BoolArray) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case Present:
|
||||
return dst
|
||||
case Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *BoolArray) AssignTo(dst interface{}) error {
|
||||
switch src.Status {
|
||||
case Present:
|
||||
switch v := dst.(type) {
|
||||
|
||||
case *[]bool:
|
||||
*v = make([]bool, len(src.Elements))
|
||||
for i := range src.Elements {
|
||||
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||
return src.AssignTo(nextDst)
|
||||
}
|
||||
}
|
||||
case Null:
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot decode %v into %T", src, dst)
|
||||
}
|
||||
|
||||
func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = BoolArray{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
uta, err := ParseUntypedTextArray(string(src))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var elements []Bool
|
||||
|
||||
if len(uta.Elements) > 0 {
|
||||
elements = make([]Bool, len(uta.Elements))
|
||||
|
||||
for i, s := range uta.Elements {
|
||||
var elem Bool
|
||||
var elemSrc []byte
|
||||
if s != "NULL" {
|
||||
elemSrc = []byte(s)
|
||||
}
|
||||
err = elem.DecodeText(ci, elemSrc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
elements[i] = elem
|
||||
}
|
||||
}
|
||||
|
||||
*dst = BoolArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = BoolArray{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
var arrayHeader ArrayHeader
|
||||
rp, err := arrayHeader.DecodeBinary(ci, src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(arrayHeader.Dimensions) == 0 {
|
||||
*dst = BoolArray{Dimensions: arrayHeader.Dimensions, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
elementCount := arrayHeader.Dimensions[0].Length
|
||||
for _, d := range arrayHeader.Dimensions[1:] {
|
||||
elementCount *= d.Length
|
||||
}
|
||||
|
||||
elements := make([]Bool, elementCount)
|
||||
|
||||
for i := range elements {
|
||||
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
var elemSrc []byte
|
||||
if elemLen >= 0 {
|
||||
elemSrc = src[rp : rp+elemLen]
|
||||
rp += elemLen
|
||||
}
|
||||
err = elements[i].DecodeBinary(ci, elemSrc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
*dst = BoolArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
if len(src.Dimensions) == 0 {
|
||||
return append(buf, '{', '}'), nil
|
||||
}
|
||||
|
||||
buf = EncodeTextArrayDimensions(buf, src.Dimensions)
|
||||
|
||||
// dimElemCounts is the multiples of elements that each array lies on. For
|
||||
// example, a single dimension array of length 4 would have a dimElemCounts of
|
||||
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
|
||||
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
|
||||
// or '}'.
|
||||
dimElemCounts := make([]int, len(src.Dimensions))
|
||||
dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
|
||||
for i := len(src.Dimensions) - 2; i > -1; i-- {
|
||||
dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
|
||||
}
|
||||
|
||||
inElemBuf := make([]byte, 0, 32)
|
||||
for i, elem := range src.Elements {
|
||||
if i > 0 {
|
||||
buf = append(buf, ',')
|
||||
}
|
||||
|
||||
for _, dec := range dimElemCounts {
|
||||
if i%dec == 0 {
|
||||
buf = append(buf, '{')
|
||||
}
|
||||
}
|
||||
|
||||
elemBuf, err := elem.EncodeText(ci, inElemBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if elemBuf == nil {
|
||||
buf = append(buf, `NULL`...)
|
||||
} else {
|
||||
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
|
||||
}
|
||||
|
||||
for _, dec := range dimElemCounts {
|
||||
if (i+1)%dec == 0 {
|
||||
buf = append(buf, '}')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (src *BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
arrayHeader := ArrayHeader{
|
||||
Dimensions: src.Dimensions,
|
||||
}
|
||||
|
||||
if dt, ok := ci.DataTypeForName("bool"); ok {
|
||||
arrayHeader.ElementOID = int32(dt.OID)
|
||||
} else {
|
||||
return nil, errors.Errorf("unable to find oid for type name %v", "bool")
|
||||
}
|
||||
|
||||
for i := range src.Elements {
|
||||
if src.Elements[i].Status == Null {
|
||||
arrayHeader.ContainsNull = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
buf = arrayHeader.EncodeBinary(ci, buf)
|
||||
|
||||
for i := range src.Elements {
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
elemBuf, err := src.Elements[i].EncodeBinary(ci, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if elemBuf != nil {
|
||||
buf = elemBuf
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
||||
}
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Scan implements the database/sql Scanner interface.
|
||||
func (dst *BoolArray) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
return dst.DecodeText(nil, nil)
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case string:
|
||||
return dst.DecodeText(nil, []byte(src))
|
||||
case []byte:
|
||||
srcCopy := make([]byte, len(src))
|
||||
copy(srcCopy, src)
|
||||
return dst.DecodeText(nil, srcCopy)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot scan %T", src)
|
||||
}
|
||||
|
||||
// Value implements the database/sql/driver Valuer interface.
|
||||
func (src *BoolArray) Value() (driver.Value, error) {
|
||||
buf, err := src.EncodeText(nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if buf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return string(buf), nil
|
||||
}
|
|
@ -0,0 +1,153 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/jackc/pgx/pgtype/testutil"
|
||||
)
|
||||
|
||||
func TestBoolArrayTranscode(t *testing.T) {
|
||||
testutil.TestSuccessfulTranscode(t, "bool[]", []interface{}{
|
||||
&pgtype.BoolArray{
|
||||
Elements: nil,
|
||||
Dimensions: nil,
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.BoolArray{
|
||||
Elements: []pgtype.Bool{
|
||||
pgtype.Bool{Bool: true, Status: pgtype.Present},
|
||||
pgtype.Bool{Status: pgtype.Null},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.BoolArray{Status: pgtype.Null},
|
||||
&pgtype.BoolArray{
|
||||
Elements: []pgtype.Bool{
|
||||
pgtype.Bool{Bool: true, Status: pgtype.Present},
|
||||
pgtype.Bool{Bool: true, Status: pgtype.Present},
|
||||
pgtype.Bool{Bool: false, Status: pgtype.Present},
|
||||
pgtype.Bool{Bool: true, Status: pgtype.Present},
|
||||
pgtype.Bool{Status: pgtype.Null},
|
||||
pgtype.Bool{Bool: false, Status: pgtype.Present},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.BoolArray{
|
||||
Elements: []pgtype.Bool{
|
||||
pgtype.Bool{Bool: true, Status: pgtype.Present},
|
||||
pgtype.Bool{Bool: false, Status: pgtype.Present},
|
||||
pgtype.Bool{Bool: true, Status: pgtype.Present},
|
||||
pgtype.Bool{Bool: false, Status: pgtype.Present},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{
|
||||
{Length: 2, LowerBound: 4},
|
||||
{Length: 2, LowerBound: 2},
|
||||
},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBoolArraySet(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
source interface{}
|
||||
result pgtype.BoolArray
|
||||
}{
|
||||
{
|
||||
source: []bool{true},
|
||||
result: pgtype.BoolArray{
|
||||
Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
source: (([]bool)(nil)),
|
||||
result: pgtype.BoolArray{Status: pgtype.Null},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
var r pgtype.BoolArray
|
||||
err := r.Set(tt.source)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(r, tt.result) {
|
||||
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolArrayAssignTo(t *testing.T) {
|
||||
var boolSlice []bool
|
||||
type _boolSlice []bool
|
||||
var namedBoolSlice _boolSlice
|
||||
|
||||
simpleTests := []struct {
|
||||
src pgtype.BoolArray
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
src: pgtype.BoolArray{
|
||||
Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &boolSlice,
|
||||
expected: []bool{true},
|
||||
},
|
||||
{
|
||||
src: pgtype.BoolArray{
|
||||
Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &namedBoolSlice,
|
||||
expected: _boolSlice{true},
|
||||
},
|
||||
{
|
||||
src: pgtype.BoolArray{Status: pgtype.Null},
|
||||
dst: &boolSlice,
|
||||
expected: (([]bool)(nil)),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range simpleTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
|
||||
errorTests := []struct {
|
||||
src pgtype.BoolArray
|
||||
dst interface{}
|
||||
}{
|
||||
{
|
||||
src: pgtype.BoolArray{
|
||||
Elements: []pgtype.Bool{{Status: pgtype.Null}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &boolSlice,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range errorTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err == nil {
|
||||
t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/jackc/pgx/pgtype/testutil"
|
||||
)
|
||||
|
||||
func TestBoolTranscode(t *testing.T) {
|
||||
testutil.TestSuccessfulTranscode(t, "bool", []interface{}{
|
||||
&pgtype.Bool{Bool: false, Status: pgtype.Present},
|
||||
&pgtype.Bool{Bool: true, Status: pgtype.Present},
|
||||
&pgtype.Bool{Bool: false, Status: pgtype.Null},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBoolSet(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
source interface{}
|
||||
result pgtype.Bool
|
||||
}{
|
||||
{source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||
{source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}},
|
||||
{source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||
{source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}},
|
||||
{source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||
{source: "f", result: pgtype.Bool{Bool: false, Status: pgtype.Present}},
|
||||
{source: _bool(true), result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
|
||||
{source: _bool(false), result: pgtype.Bool{Bool: false, Status: pgtype.Present}},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
var r pgtype.Bool
|
||||
err := r.Set(tt.source)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if r != tt.result {
|
||||
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolAssignTo(t *testing.T) {
|
||||
var b bool
|
||||
var _b _bool
|
||||
var pb *bool
|
||||
var _pb *_bool
|
||||
|
||||
simpleTests := []struct {
|
||||
src pgtype.Bool
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &b, expected: false},
|
||||
{src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &b, expected: true},
|
||||
{src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &_b, expected: _bool(false)},
|
||||
{src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_b, expected: _bool(true)},
|
||||
{src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &pb, expected: ((*bool)(nil))},
|
||||
{src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &_pb, expected: ((*_bool)(nil))},
|
||||
}
|
||||
|
||||
for i, tt := range simpleTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
|
||||
pointerAllocTests := []struct {
|
||||
src pgtype.Bool
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &pb, expected: true},
|
||||
{src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_pb, expected: _bool(true)},
|
||||
}
|
||||
|
||||
for i, tt := range pointerAllocTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Box struct {
|
||||
P [2]Vec2
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (dst *Box) Set(src interface{}) error {
|
||||
return errors.Errorf("cannot convert %v to Box", src)
|
||||
}
|
||||
|
||||
func (dst *Box) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case Present:
|
||||
return dst
|
||||
case Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *Box) AssignTo(dst interface{}) error {
|
||||
return errors.Errorf("cannot assign %v to %T", src, dst)
|
||||
}
|
||||
|
||||
func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Box{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(src) < 11 {
|
||||
return errors.Errorf("invalid length for Box: %v", len(src))
|
||||
}
|
||||
|
||||
str := string(src[1:])
|
||||
|
||||
var end int
|
||||
end = strings.IndexByte(str, ',')
|
||||
|
||||
x1, err := strconv.ParseFloat(str[:end], 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
str = str[end+1:]
|
||||
end = strings.IndexByte(str, ')')
|
||||
|
||||
y1, err := strconv.ParseFloat(str[:end], 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
str = str[end+3:]
|
||||
end = strings.IndexByte(str, ',')
|
||||
|
||||
x2, err := strconv.ParseFloat(str[:end], 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
str = str[end+1 : len(str)-1]
|
||||
|
||||
y2, err := strconv.ParseFloat(str, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*dst = Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Box{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(src) != 32 {
|
||||
return errors.Errorf("invalid length for Box: %v", len(src))
|
||||
}
|
||||
|
||||
x1 := binary.BigEndian.Uint64(src)
|
||||
y1 := binary.BigEndian.Uint64(src[8:])
|
||||
x2 := binary.BigEndian.Uint64(src[16:])
|
||||
y2 := binary.BigEndian.Uint64(src[24:])
|
||||
|
||||
*dst = Box{
|
||||
P: [2]Vec2{
|
||||
{math.Float64frombits(x1), math.Float64frombits(y1)},
|
||||
{math.Float64frombits(x2), math.Float64frombits(y2)},
|
||||
},
|
||||
Status: Present,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
buf = append(buf, fmt.Sprintf(`(%f,%f),(%f,%f)`,
|
||||
src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)...)
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (src *Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X))
|
||||
buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y))
|
||||
buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X))
|
||||
buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y))
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Scan implements the database/sql Scanner interface.
|
||||
func (dst *Box) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
*dst = Box{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case string:
|
||||
return dst.DecodeText(nil, []byte(src))
|
||||
case []byte:
|
||||
srcCopy := make([]byte, len(src))
|
||||
copy(srcCopy, src)
|
||||
return dst.DecodeText(nil, srcCopy)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot scan %T", src)
|
||||
}
|
||||
|
||||
// Value implements the database/sql/driver Valuer interface.
|
||||
func (src *Box) Value() (driver.Value, error) {
|
||||
return EncodeValueText(src)
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/jackc/pgx/pgtype/testutil"
|
||||
)
|
||||
|
||||
func TestBoxTranscode(t *testing.T) {
|
||||
testutil.TestSuccessfulTranscode(t, "box", []interface{}{
|
||||
&pgtype.Box{
|
||||
P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.Box{
|
||||
P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.Box{Status: pgtype.Null},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBoxNormalize(t *testing.T) {
|
||||
testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{
|
||||
{
|
||||
SQL: "select '3.14, 1.678, 7.1, 5.234'::box",
|
||||
Value: &pgtype.Box{
|
||||
P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Bytea struct {
|
||||
Bytes []byte
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (dst *Bytea) Set(src interface{}) error {
|
||||
if src == nil {
|
||||
*dst = Bytea{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch value := src.(type) {
|
||||
case []byte:
|
||||
if value != nil {
|
||||
*dst = Bytea{Bytes: value, Status: Present}
|
||||
} else {
|
||||
*dst = Bytea{Status: Null}
|
||||
}
|
||||
default:
|
||||
if originalSrc, ok := underlyingBytesType(src); ok {
|
||||
return dst.Set(originalSrc)
|
||||
}
|
||||
return errors.Errorf("cannot convert %v to Bytea", value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *Bytea) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case Present:
|
||||
return dst.Bytes
|
||||
case Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *Bytea) AssignTo(dst interface{}) error {
|
||||
switch src.Status {
|
||||
case Present:
|
||||
switch v := dst.(type) {
|
||||
case *[]byte:
|
||||
buf := make([]byte, len(src.Bytes))
|
||||
copy(buf, src.Bytes)
|
||||
*v = buf
|
||||
return nil
|
||||
default:
|
||||
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||
return src.AssignTo(nextDst)
|
||||
}
|
||||
}
|
||||
case Null:
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot decode %v into %T", src, dst)
|
||||
}
|
||||
|
||||
// DecodeText only supports the hex format. This has been the default since
|
||||
// PostgreSQL 9.0.
|
||||
func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Bytea{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(src) < 2 || src[0] != '\\' || src[1] != 'x' {
|
||||
return errors.Errorf("invalid hex format")
|
||||
}
|
||||
|
||||
buf := make([]byte, (len(src)-2)/2)
|
||||
_, err := hex.Decode(buf, src[2:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*dst = Bytea{Bytes: buf, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Bytea{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
*dst = Bytea{Bytes: src, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Bytea) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
buf = append(buf, `\x`...)
|
||||
buf = append(buf, hex.EncodeToString(src.Bytes)...)
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (src *Bytea) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
return append(buf, src.Bytes...), nil
|
||||
}
|
||||
|
||||
// Scan implements the database/sql Scanner interface.
|
||||
func (dst *Bytea) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
*dst = Bytea{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case string:
|
||||
return dst.DecodeText(nil, []byte(src))
|
||||
case []byte:
|
||||
buf := make([]byte, len(src))
|
||||
copy(buf, src)
|
||||
*dst = Bytea{Bytes: buf, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot scan %T", src)
|
||||
}
|
||||
|
||||
// Value implements the database/sql/driver Valuer interface.
|
||||
func (src *Bytea) Value() (driver.Value, error) {
|
||||
switch src.Status {
|
||||
case Present:
|
||||
return src.Bytes, nil
|
||||
case Null:
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, errUndefined
|
||||
}
|
||||
}
|
|
@ -0,0 +1,294 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type ByteaArray struct {
|
||||
Elements []Bytea
|
||||
Dimensions []ArrayDimension
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (dst *ByteaArray) Set(src interface{}) error {
|
||||
switch value := src.(type) {
|
||||
|
||||
case [][]byte:
|
||||
if value == nil {
|
||||
*dst = ByteaArray{Status: Null}
|
||||
} else if len(value) == 0 {
|
||||
*dst = ByteaArray{Status: Present}
|
||||
} else {
|
||||
elements := make([]Bytea, len(value))
|
||||
for i := range value {
|
||||
if err := elements[i].Set(value[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*dst = ByteaArray{
|
||||
Elements: elements,
|
||||
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||
Status: Present,
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
if originalSrc, ok := underlyingSliceType(src); ok {
|
||||
return dst.Set(originalSrc)
|
||||
}
|
||||
return errors.Errorf("cannot convert %v to Bytea", value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *ByteaArray) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case Present:
|
||||
return dst
|
||||
case Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *ByteaArray) AssignTo(dst interface{}) error {
|
||||
switch src.Status {
|
||||
case Present:
|
||||
switch v := dst.(type) {
|
||||
|
||||
case *[][]byte:
|
||||
*v = make([][]byte, len(src.Elements))
|
||||
for i := range src.Elements {
|
||||
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||
return src.AssignTo(nextDst)
|
||||
}
|
||||
}
|
||||
case Null:
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot decode %v into %T", src, dst)
|
||||
}
|
||||
|
||||
func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = ByteaArray{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
uta, err := ParseUntypedTextArray(string(src))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var elements []Bytea
|
||||
|
||||
if len(uta.Elements) > 0 {
|
||||
elements = make([]Bytea, len(uta.Elements))
|
||||
|
||||
for i, s := range uta.Elements {
|
||||
var elem Bytea
|
||||
var elemSrc []byte
|
||||
if s != "NULL" {
|
||||
elemSrc = []byte(s)
|
||||
}
|
||||
err = elem.DecodeText(ci, elemSrc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
elements[i] = elem
|
||||
}
|
||||
}
|
||||
|
||||
*dst = ByteaArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = ByteaArray{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
var arrayHeader ArrayHeader
|
||||
rp, err := arrayHeader.DecodeBinary(ci, src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(arrayHeader.Dimensions) == 0 {
|
||||
*dst = ByteaArray{Dimensions: arrayHeader.Dimensions, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
elementCount := arrayHeader.Dimensions[0].Length
|
||||
for _, d := range arrayHeader.Dimensions[1:] {
|
||||
elementCount *= d.Length
|
||||
}
|
||||
|
||||
elements := make([]Bytea, elementCount)
|
||||
|
||||
for i := range elements {
|
||||
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
var elemSrc []byte
|
||||
if elemLen >= 0 {
|
||||
elemSrc = src[rp : rp+elemLen]
|
||||
rp += elemLen
|
||||
}
|
||||
err = elements[i].DecodeBinary(ci, elemSrc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
*dst = ByteaArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
if len(src.Dimensions) == 0 {
|
||||
return append(buf, '{', '}'), nil
|
||||
}
|
||||
|
||||
buf = EncodeTextArrayDimensions(buf, src.Dimensions)
|
||||
|
||||
// dimElemCounts is the multiples of elements that each array lies on. For
|
||||
// example, a single dimension array of length 4 would have a dimElemCounts of
|
||||
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
|
||||
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
|
||||
// or '}'.
|
||||
dimElemCounts := make([]int, len(src.Dimensions))
|
||||
dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
|
||||
for i := len(src.Dimensions) - 2; i > -1; i-- {
|
||||
dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
|
||||
}
|
||||
|
||||
inElemBuf := make([]byte, 0, 32)
|
||||
for i, elem := range src.Elements {
|
||||
if i > 0 {
|
||||
buf = append(buf, ',')
|
||||
}
|
||||
|
||||
for _, dec := range dimElemCounts {
|
||||
if i%dec == 0 {
|
||||
buf = append(buf, '{')
|
||||
}
|
||||
}
|
||||
|
||||
elemBuf, err := elem.EncodeText(ci, inElemBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if elemBuf == nil {
|
||||
buf = append(buf, `NULL`...)
|
||||
} else {
|
||||
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
|
||||
}
|
||||
|
||||
for _, dec := range dimElemCounts {
|
||||
if (i+1)%dec == 0 {
|
||||
buf = append(buf, '}')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (src *ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
arrayHeader := ArrayHeader{
|
||||
Dimensions: src.Dimensions,
|
||||
}
|
||||
|
||||
if dt, ok := ci.DataTypeForName("bytea"); ok {
|
||||
arrayHeader.ElementOID = int32(dt.OID)
|
||||
} else {
|
||||
return nil, errors.Errorf("unable to find oid for type name %v", "bytea")
|
||||
}
|
||||
|
||||
for i := range src.Elements {
|
||||
if src.Elements[i].Status == Null {
|
||||
arrayHeader.ContainsNull = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
buf = arrayHeader.EncodeBinary(ci, buf)
|
||||
|
||||
for i := range src.Elements {
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
elemBuf, err := src.Elements[i].EncodeBinary(ci, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if elemBuf != nil {
|
||||
buf = elemBuf
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
||||
}
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Scan implements the database/sql Scanner interface.
|
||||
func (dst *ByteaArray) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
return dst.DecodeText(nil, nil)
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case string:
|
||||
return dst.DecodeText(nil, []byte(src))
|
||||
case []byte:
|
||||
srcCopy := make([]byte, len(src))
|
||||
copy(srcCopy, src)
|
||||
return dst.DecodeText(nil, srcCopy)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot scan %T", src)
|
||||
}
|
||||
|
||||
// Value implements the database/sql/driver Valuer interface.
|
||||
func (src *ByteaArray) Value() (driver.Value, error) {
|
||||
buf, err := src.EncodeText(nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if buf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return string(buf), nil
|
||||
}
|
|
@ -0,0 +1,120 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/jackc/pgx/pgtype/testutil"
|
||||
)
|
||||
|
||||
func TestByteaArrayTranscode(t *testing.T) {
|
||||
testutil.TestSuccessfulTranscode(t, "bytea[]", []interface{}{
|
||||
&pgtype.ByteaArray{
|
||||
Elements: nil,
|
||||
Dimensions: nil,
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.ByteaArray{
|
||||
Elements: []pgtype.Bytea{
|
||||
pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
|
||||
pgtype.Bytea{Status: pgtype.Null},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.ByteaArray{Status: pgtype.Null},
|
||||
&pgtype.ByteaArray{
|
||||
Elements: []pgtype.Bytea{
|
||||
pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
|
||||
pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
|
||||
pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present},
|
||||
pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
|
||||
pgtype.Bytea{Status: pgtype.Null},
|
||||
pgtype.Bytea{Bytes: []byte{1}, Status: pgtype.Present},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.ByteaArray{
|
||||
Elements: []pgtype.Bytea{
|
||||
pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
|
||||
pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present},
|
||||
pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
|
||||
pgtype.Bytea{Bytes: []byte{1}, Status: pgtype.Present},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{
|
||||
{Length: 2, LowerBound: 4},
|
||||
{Length: 2, LowerBound: 2},
|
||||
},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestByteaArraySet(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
source interface{}
|
||||
result pgtype.ByteaArray
|
||||
}{
|
||||
{
|
||||
source: [][]byte{{1, 2, 3}},
|
||||
result: pgtype.ByteaArray{
|
||||
Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
source: (([][]byte)(nil)),
|
||||
result: pgtype.ByteaArray{Status: pgtype.Null},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
var r pgtype.ByteaArray
|
||||
err := r.Set(tt.source)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(r, tt.result) {
|
||||
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteaArrayAssignTo(t *testing.T) {
|
||||
var byteByteSlice [][]byte
|
||||
|
||||
simpleTests := []struct {
|
||||
src pgtype.ByteaArray
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
src: pgtype.ByteaArray{
|
||||
Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &byteByteSlice,
|
||||
expected: [][]byte{{1, 2, 3}},
|
||||
},
|
||||
{
|
||||
src: pgtype.ByteaArray{Status: pgtype.Null},
|
||||
dst: &byteByteSlice,
|
||||
expected: (([][]byte)(nil)),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range simpleTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/jackc/pgx/pgtype/testutil"
|
||||
)
|
||||
|
||||
func TestByteaTranscode(t *testing.T) {
|
||||
testutil.TestSuccessfulTranscode(t, "bytea", []interface{}{
|
||||
&pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
|
||||
&pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present},
|
||||
&pgtype.Bytea{Bytes: nil, Status: pgtype.Null},
|
||||
})
|
||||
}
|
||||
|
||||
func TestByteaSet(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
source interface{}
|
||||
result pgtype.Bytea
|
||||
}{
|
||||
{source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}},
|
||||
{source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}},
|
||||
{source: []byte(nil), result: pgtype.Bytea{Status: pgtype.Null}},
|
||||
{source: _byteSlice{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}},
|
||||
{source: _byteSlice(nil), result: pgtype.Bytea{Status: pgtype.Null}},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
var r pgtype.Bytea
|
||||
err := r.Set(tt.source)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(r, tt.result) {
|
||||
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteaAssignTo(t *testing.T) {
|
||||
var buf []byte
|
||||
var _buf _byteSlice
|
||||
var pbuf *[]byte
|
||||
var _pbuf *_byteSlice
|
||||
|
||||
simpleTests := []struct {
|
||||
src pgtype.Bytea
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &buf, expected: []byte{1, 2, 3}},
|
||||
{src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_buf, expected: _byteSlice{1, 2, 3}},
|
||||
{src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &pbuf, expected: &[]byte{1, 2, 3}},
|
||||
{src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_pbuf, expected: &_byteSlice{1, 2, 3}},
|
||||
{src: pgtype.Bytea{Status: pgtype.Null}, dst: &pbuf, expected: ((*[]byte)(nil))},
|
||||
{src: pgtype.Bytea{Status: pgtype.Null}, dst: &_pbuf, expected: ((*_byteSlice)(nil))},
|
||||
}
|
||||
|
||||
for i, tt := range simpleTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
)
|
||||
|
||||
// CID is PostgreSQL's Command Identifier type.
|
||||
//
|
||||
// When one does
|
||||
//
|
||||
// select cmin, cmax, * from some_table;
|
||||
//
|
||||
// it is the data type of the cmin and cmax hidden system columns.
|
||||
//
|
||||
// It is currently implemented as an unsigned four byte integer.
|
||||
// Its definition can be found in src/include/c.h as CommandId
|
||||
// in the PostgreSQL sources.
|
||||
type CID pguint32
|
||||
|
||||
// Set converts from src to dst. Note that as CID is not a general
|
||||
// number type Set does not do automatic type conversion as other number
|
||||
// types do.
|
||||
func (dst *CID) Set(src interface{}) error {
|
||||
return (*pguint32)(dst).Set(src)
|
||||
}
|
||||
|
||||
func (dst *CID) Get() interface{} {
|
||||
return (*pguint32)(dst).Get()
|
||||
}
|
||||
|
||||
// AssignTo assigns from src to dst. Note that as CID is not a general number
|
||||
// type AssignTo does not do automatic type conversion as other number types do.
|
||||
func (src *CID) AssignTo(dst interface{}) error {
|
||||
return (*pguint32)(src).AssignTo(dst)
|
||||
}
|
||||
|
||||
func (dst *CID) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
return (*pguint32)(dst).DecodeText(ci, src)
|
||||
}
|
||||
|
||||
func (dst *CID) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||
return (*pguint32)(dst).DecodeBinary(ci, src)
|
||||
}
|
||||
|
||||
func (src *CID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
return (*pguint32)(src).EncodeText(ci, buf)
|
||||
}
|
||||
|
||||
func (src *CID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
return (*pguint32)(src).EncodeBinary(ci, buf)
|
||||
}
|
||||
|
||||
// Scan implements the database/sql Scanner interface.
|
||||
func (dst *CID) Scan(src interface{}) error {
|
||||
return (*pguint32)(dst).Scan(src)
|
||||
}
|
||||
|
||||
// Value implements the database/sql/driver Valuer interface.
|
||||
func (src *CID) Value() (driver.Value, error) {
|
||||
return (*pguint32)(src).Value()
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/jackc/pgx/pgtype/testutil"
|
||||
)
|
||||
|
||||
func TestCIDTranscode(t *testing.T) {
|
||||
pgTypeName := "cid"
|
||||
values := []interface{}{
|
||||
&pgtype.CID{Uint: 42, Status: pgtype.Present},
|
||||
&pgtype.CID{Status: pgtype.Null},
|
||||
}
|
||||
eqFunc := func(a, b interface{}) bool {
|
||||
return reflect.DeepEqual(a, b)
|
||||
}
|
||||
|
||||
testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
||||
|
||||
// No direct conversion from int to cid, convert through text
|
||||
testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc)
|
||||
|
||||
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
||||
testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCIDSet(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
source interface{}
|
||||
result pgtype.CID
|
||||
}{
|
||||
{source: uint32(1), result: pgtype.CID{Uint: 1, Status: pgtype.Present}},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
var r pgtype.CID
|
||||
err := r.Set(tt.source)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if r != tt.result {
|
||||
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCIDAssignTo(t *testing.T) {
|
||||
var ui32 uint32
|
||||
var pui32 *uint32
|
||||
|
||||
simpleTests := []struct {
|
||||
src pgtype.CID
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)},
|
||||
{src: pgtype.CID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))},
|
||||
}
|
||||
|
||||
for i, tt := range simpleTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
|
||||
pointerAllocTests := []struct {
|
||||
src pgtype.CID
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)},
|
||||
}
|
||||
|
||||
for i, tt := range pointerAllocTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
|
||||
errorTests := []struct {
|
||||
src pgtype.CID
|
||||
dst interface{}
|
||||
}{
|
||||
{src: pgtype.CID{Status: pgtype.Null}, dst: &ui32},
|
||||
}
|
||||
|
||||
for i, tt := range errorTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err == nil {
|
||||
t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package pgtype
|
||||
|
||||
type CIDR Inet
|
||||
|
||||
func (dst *CIDR) Set(src interface{}) error {
|
||||
return (*Inet)(dst).Set(src)
|
||||
}
|
||||
|
||||
func (dst *CIDR) Get() interface{} {
|
||||
return (*Inet)(dst).Get()
|
||||
}
|
||||
|
||||
func (src *CIDR) AssignTo(dst interface{}) error {
|
||||
return (*Inet)(src).AssignTo(dst)
|
||||
}
|
||||
|
||||
func (dst *CIDR) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
return (*Inet)(dst).DecodeText(ci, src)
|
||||
}
|
||||
|
||||
func (dst *CIDR) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||
return (*Inet)(dst).DecodeBinary(ci, src)
|
||||
}
|
||||
|
||||
func (src *CIDR) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
return (*Inet)(src).EncodeText(ci, buf)
|
||||
}
|
||||
|
||||
func (src *CIDR) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
return (*Inet)(src).EncodeBinary(ci, buf)
|
||||
}
|
|
@ -0,0 +1,323 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type CIDRArray struct {
|
||||
Elements []CIDR
|
||||
Dimensions []ArrayDimension
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (dst *CIDRArray) Set(src interface{}) error {
|
||||
switch value := src.(type) {
|
||||
|
||||
case []*net.IPNet:
|
||||
if value == nil {
|
||||
*dst = CIDRArray{Status: Null}
|
||||
} else if len(value) == 0 {
|
||||
*dst = CIDRArray{Status: Present}
|
||||
} else {
|
||||
elements := make([]CIDR, len(value))
|
||||
for i := range value {
|
||||
if err := elements[i].Set(value[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*dst = CIDRArray{
|
||||
Elements: elements,
|
||||
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||
Status: Present,
|
||||
}
|
||||
}
|
||||
|
||||
case []net.IP:
|
||||
if value == nil {
|
||||
*dst = CIDRArray{Status: Null}
|
||||
} else if len(value) == 0 {
|
||||
*dst = CIDRArray{Status: Present}
|
||||
} else {
|
||||
elements := make([]CIDR, len(value))
|
||||
for i := range value {
|
||||
if err := elements[i].Set(value[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*dst = CIDRArray{
|
||||
Elements: elements,
|
||||
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||
Status: Present,
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
if originalSrc, ok := underlyingSliceType(src); ok {
|
||||
return dst.Set(originalSrc)
|
||||
}
|
||||
return errors.Errorf("cannot convert %v to CIDR", value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *CIDRArray) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case Present:
|
||||
return dst
|
||||
case Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *CIDRArray) AssignTo(dst interface{}) error {
|
||||
switch src.Status {
|
||||
case Present:
|
||||
switch v := dst.(type) {
|
||||
|
||||
case *[]*net.IPNet:
|
||||
*v = make([]*net.IPNet, len(src.Elements))
|
||||
for i := range src.Elements {
|
||||
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case *[]net.IP:
|
||||
*v = make([]net.IP, len(src.Elements))
|
||||
for i := range src.Elements {
|
||||
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||
return src.AssignTo(nextDst)
|
||||
}
|
||||
}
|
||||
case Null:
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot decode %v into %T", src, dst)
|
||||
}
|
||||
|
||||
func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = CIDRArray{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
uta, err := ParseUntypedTextArray(string(src))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var elements []CIDR
|
||||
|
||||
if len(uta.Elements) > 0 {
|
||||
elements = make([]CIDR, len(uta.Elements))
|
||||
|
||||
for i, s := range uta.Elements {
|
||||
var elem CIDR
|
||||
var elemSrc []byte
|
||||
if s != "NULL" {
|
||||
elemSrc = []byte(s)
|
||||
}
|
||||
err = elem.DecodeText(ci, elemSrc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
elements[i] = elem
|
||||
}
|
||||
}
|
||||
|
||||
*dst = CIDRArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = CIDRArray{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
var arrayHeader ArrayHeader
|
||||
rp, err := arrayHeader.DecodeBinary(ci, src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(arrayHeader.Dimensions) == 0 {
|
||||
*dst = CIDRArray{Dimensions: arrayHeader.Dimensions, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
elementCount := arrayHeader.Dimensions[0].Length
|
||||
for _, d := range arrayHeader.Dimensions[1:] {
|
||||
elementCount *= d.Length
|
||||
}
|
||||
|
||||
elements := make([]CIDR, elementCount)
|
||||
|
||||
for i := range elements {
|
||||
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
var elemSrc []byte
|
||||
if elemLen >= 0 {
|
||||
elemSrc = src[rp : rp+elemLen]
|
||||
rp += elemLen
|
||||
}
|
||||
err = elements[i].DecodeBinary(ci, elemSrc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
*dst = CIDRArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
if len(src.Dimensions) == 0 {
|
||||
return append(buf, '{', '}'), nil
|
||||
}
|
||||
|
||||
buf = EncodeTextArrayDimensions(buf, src.Dimensions)
|
||||
|
||||
// dimElemCounts is the multiples of elements that each array lies on. For
|
||||
// example, a single dimension array of length 4 would have a dimElemCounts of
|
||||
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
|
||||
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
|
||||
// or '}'.
|
||||
dimElemCounts := make([]int, len(src.Dimensions))
|
||||
dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
|
||||
for i := len(src.Dimensions) - 2; i > -1; i-- {
|
||||
dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
|
||||
}
|
||||
|
||||
inElemBuf := make([]byte, 0, 32)
|
||||
for i, elem := range src.Elements {
|
||||
if i > 0 {
|
||||
buf = append(buf, ',')
|
||||
}
|
||||
|
||||
for _, dec := range dimElemCounts {
|
||||
if i%dec == 0 {
|
||||
buf = append(buf, '{')
|
||||
}
|
||||
}
|
||||
|
||||
elemBuf, err := elem.EncodeText(ci, inElemBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if elemBuf == nil {
|
||||
buf = append(buf, `NULL`...)
|
||||
} else {
|
||||
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
|
||||
}
|
||||
|
||||
for _, dec := range dimElemCounts {
|
||||
if (i+1)%dec == 0 {
|
||||
buf = append(buf, '}')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (src *CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
arrayHeader := ArrayHeader{
|
||||
Dimensions: src.Dimensions,
|
||||
}
|
||||
|
||||
if dt, ok := ci.DataTypeForName("cidr"); ok {
|
||||
arrayHeader.ElementOID = int32(dt.OID)
|
||||
} else {
|
||||
return nil, errors.Errorf("unable to find oid for type name %v", "cidr")
|
||||
}
|
||||
|
||||
for i := range src.Elements {
|
||||
if src.Elements[i].Status == Null {
|
||||
arrayHeader.ContainsNull = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
buf = arrayHeader.EncodeBinary(ci, buf)
|
||||
|
||||
for i := range src.Elements {
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
elemBuf, err := src.Elements[i].EncodeBinary(ci, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if elemBuf != nil {
|
||||
buf = elemBuf
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
||||
}
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Scan implements the database/sql Scanner interface.
|
||||
func (dst *CIDRArray) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
return dst.DecodeText(nil, nil)
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case string:
|
||||
return dst.DecodeText(nil, []byte(src))
|
||||
case []byte:
|
||||
srcCopy := make([]byte, len(src))
|
||||
copy(srcCopy, src)
|
||||
return dst.DecodeText(nil, srcCopy)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot scan %T", src)
|
||||
}
|
||||
|
||||
// Value implements the database/sql/driver Valuer interface.
|
||||
func (src *CIDRArray) Value() (driver.Value, error) {
|
||||
buf, err := src.EncodeText(nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if buf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return string(buf), nil
|
||||
}
|
|
@ -0,0 +1,165 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/jackc/pgx/pgtype/testutil"
|
||||
)
|
||||
|
||||
func TestCIDRArrayTranscode(t *testing.T) {
|
||||
testutil.TestSuccessfulTranscode(t, "cidr[]", []interface{}{
|
||||
&pgtype.CIDRArray{
|
||||
Elements: nil,
|
||||
Dimensions: nil,
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.CIDRArray{
|
||||
Elements: []pgtype.CIDR{
|
||||
pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present},
|
||||
pgtype.CIDR{Status: pgtype.Null},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.CIDRArray{Status: pgtype.Null},
|
||||
&pgtype.CIDRArray{
|
||||
Elements: []pgtype.CIDR{
|
||||
pgtype.CIDR{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present},
|
||||
pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present},
|
||||
pgtype.CIDR{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present},
|
||||
pgtype.CIDR{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present},
|
||||
pgtype.CIDR{Status: pgtype.Null},
|
||||
pgtype.CIDR{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.CIDRArray{
|
||||
Elements: []pgtype.CIDR{
|
||||
pgtype.CIDR{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present},
|
||||
pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present},
|
||||
pgtype.CIDR{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present},
|
||||
pgtype.CIDR{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{
|
||||
{Length: 2, LowerBound: 4},
|
||||
{Length: 2, LowerBound: 2},
|
||||
},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestCIDRArraySet(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
source interface{}
|
||||
result pgtype.CIDRArray
|
||||
}{
|
||||
{
|
||||
source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")},
|
||||
result: pgtype.CIDRArray{
|
||||
Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
source: (([]*net.IPNet)(nil)),
|
||||
result: pgtype.CIDRArray{Status: pgtype.Null},
|
||||
},
|
||||
{
|
||||
source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP},
|
||||
result: pgtype.CIDRArray{
|
||||
Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
source: (([]net.IP)(nil)),
|
||||
result: pgtype.CIDRArray{Status: pgtype.Null},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
var r pgtype.CIDRArray
|
||||
err := r.Set(tt.source)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(r, tt.result) {
|
||||
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCIDRArrayAssignTo(t *testing.T) {
|
||||
var ipnetSlice []*net.IPNet
|
||||
var ipSlice []net.IP
|
||||
|
||||
simpleTests := []struct {
|
||||
src pgtype.CIDRArray
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
src: pgtype.CIDRArray{
|
||||
Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &ipnetSlice,
|
||||
expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")},
|
||||
},
|
||||
{
|
||||
src: pgtype.CIDRArray{
|
||||
Elements: []pgtype.CIDR{{Status: pgtype.Null}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &ipnetSlice,
|
||||
expected: []*net.IPNet{nil},
|
||||
},
|
||||
{
|
||||
src: pgtype.CIDRArray{
|
||||
Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &ipSlice,
|
||||
expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP},
|
||||
},
|
||||
{
|
||||
src: pgtype.CIDRArray{
|
||||
Elements: []pgtype.CIDR{{Status: pgtype.Null}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &ipSlice,
|
||||
expected: []net.IP{nil},
|
||||
},
|
||||
{
|
||||
src: pgtype.CIDRArray{Status: pgtype.Null},
|
||||
dst: &ipnetSlice,
|
||||
expected: (([]*net.IPNet)(nil)),
|
||||
},
|
||||
{
|
||||
src: pgtype.CIDRArray{Status: pgtype.Null},
|
||||
dst: &ipSlice,
|
||||
expected: (([]net.IP)(nil)),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range simpleTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,146 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Circle struct {
|
||||
P Vec2
|
||||
R float64
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (dst *Circle) Set(src interface{}) error {
|
||||
return errors.Errorf("cannot convert %v to Circle", src)
|
||||
}
|
||||
|
||||
func (dst *Circle) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case Present:
|
||||
return dst
|
||||
case Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *Circle) AssignTo(dst interface{}) error {
|
||||
return errors.Errorf("cannot assign %v to %T", src, dst)
|
||||
}
|
||||
|
||||
func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Circle{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(src) < 9 {
|
||||
return errors.Errorf("invalid length for Circle: %v", len(src))
|
||||
}
|
||||
|
||||
str := string(src[2:])
|
||||
end := strings.IndexByte(str, ',')
|
||||
x, err := strconv.ParseFloat(str[:end], 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
str = str[end+1:]
|
||||
end = strings.IndexByte(str, ')')
|
||||
|
||||
y, err := strconv.ParseFloat(str[:end], 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
str = str[end+2 : len(str)-1]
|
||||
|
||||
r, err := strconv.ParseFloat(str, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*dst = Circle{P: Vec2{x, y}, R: r, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Circle{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(src) != 24 {
|
||||
return errors.Errorf("invalid length for Circle: %v", len(src))
|
||||
}
|
||||
|
||||
x := binary.BigEndian.Uint64(src)
|
||||
y := binary.BigEndian.Uint64(src[8:])
|
||||
r := binary.BigEndian.Uint64(src[16:])
|
||||
|
||||
*dst = Circle{
|
||||
P: Vec2{math.Float64frombits(x), math.Float64frombits(y)},
|
||||
R: math.Float64frombits(r),
|
||||
Status: Present,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
buf = append(buf, fmt.Sprintf(`<(%f,%f),%f>`, src.P.X, src.P.Y, src.R)...)
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (src *Circle) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Null:
|
||||
return nil, nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
|
||||
buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X))
|
||||
buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y))
|
||||
buf = pgio.AppendUint64(buf, math.Float64bits(src.R))
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Scan implements the database/sql Scanner interface.
|
||||
func (dst *Circle) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
*dst = Circle{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case string:
|
||||
return dst.DecodeText(nil, []byte(src))
|
||||
case []byte:
|
||||
srcCopy := make([]byte, len(src))
|
||||
copy(srcCopy, src)
|
||||
return dst.DecodeText(nil, srcCopy)
|
||||
}
|
||||
|
||||
return errors.Errorf("cannot scan %T", src)
|
||||
}
|
||||
|
||||
// Value implements the database/sql/driver Valuer interface.
|
||||
func (src *Circle) Value() (driver.Value, error) {
|
||||
return EncodeValueText(src)
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue