Make v3 main release

pull/291/merge
Jack Christensen 2017-07-24 08:51:34 -05:00
commit f79e52f1ee
232 changed files with 31747 additions and 7798 deletions

1
.gitignore vendored
View File

@ -22,3 +22,4 @@ _testmain.go
*.exe
conn_config_test.go
.envrc

View File

@ -1,8 +1,7 @@
language: go
go:
- 1.7.4
- 1.6.4
- 1.8
- tip
# Derived from https://github.com/lib/pq/blob/master/.travis.yml
@ -28,6 +27,8 @@ before_install:
- sudo /etc/init.d/postgresql restart
env:
global:
- PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test
matrix:
- PGVERSION=9.6
- PGVERSION=9.5
@ -40,7 +41,7 @@ env:
before_script:
- mv conn_config_test.go.travis conn_config_test.go
- psql -U postgres -c 'create database pgx_test'
- "[[ \"${PGVERSION}\" = '9.0' ]] && psql -U postgres -f /usr/share/postgresql/9.0/contrib/hstore.sql pgx_test || psql -U postgres pgx_test -c 'create extension hstore'"
- psql -U postgres pgx_test -c 'create extension hstore'
- psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'"
- psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'"
- psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'"
@ -51,9 +52,14 @@ install:
- go get -u github.com/shopspring/decimal
- go get -u gopkg.in/inconshreveable/log15.v2
- go get -u github.com/jackc/fake
- go get -u github.com/lib/pq
- go get -u github.com/hashicorp/go-version
- go get -u github.com/satori/go.uuid
- go get -u github.com/Sirupsen/logrus
- go get -u github.com/pkg/errors
script:
- go test -v -race -short ./...
- go test -v -race ./...
matrix:
allow_failures:

View File

@ -1,4 +1,62 @@
# Unreleased
# Unreleased V3
## Changes
* Renamed Pid to PID in accordance with Go naming conventions.
* Conn.Pid changed to accessor method Conn.PID()
* Conn.SecretKey removed
* Remove Conn.TxStatus
* Logger interface reduced to single Log method.
* Replace BeginIso with BeginEx. BeginEx adds support for read/write mode and deferrable mode.
* Transaction isolation level constants are now typed strings instead of bare strings.
* Conn.WaitForNotification now takes context.Context instead of time.Duration for cancellation support.
* Conn.WaitForNotification no longer automatically pings internally every 15 seconds.
* ReplicationConn.WaitForReplicationMessage now takes context.Context instead of time.Duration for cancellation support.
* Reject scanning binary format values into a string (e.g. binary encoded timestamptz to string). See https://github.com/jackc/pgx/issues/219 and https://github.com/jackc/pgx/issues/228
* No longer can read raw bytes of any value into a []byte. Use pgtype.GenericBinary if this functionality is needed.
* Remove CopyTo (functionality is now in CopyFrom)
* OID constants moved from pgx to pgtype package
* Replaced Scanner, Encoder, and PgxScanner interfaces with pgtype system
* Removed ValueReader
* ConnPool.Close no longer waits for all acquired connections to be released. Instead, it immediately closes all available connections, and closes acquired connections when they are released in the same manner as ConnPool.Reset.
* Removed Rows.Fatal(error)
* Removed Rows.AfterClose()
* Removed Rows.Conn()
* Removed Tx.AfterClose()
* Removed Tx.Conn()
* Use Go casing convention for OID, UUID, JSON(B), ACLItem, CID, TID, XID, and CIDR
* Replaced stdlib.OpenFromConnPool with DriverConfig system
## Features
* Entirely revamped pluggable type system that supports approximately 60 PostgreSQL types.
* Types support database/sql interfaces and therefore can be used with other drivers
* Added context methods supporting cancellation where appropriate
* Added simple query protocol support
* Added single round-trip query mode
* Added batch query operations
* Added OnNotice
* github.com/pkg/errors used where possible for errors
* Added stdlib.DriverConfig which directly allows full configuration of underlying pgx connections without needing to use a pgx.ConnPool
* Added AcquireConn and ReleaseConn to stdlib to allow acquiring a connection from a database/sql connection.
# 2.11.0 (June 5, 2017)
## Fixes
* Fix race with concurrent execution of stdlib.OpenFromConnPool (Terin Stock)
## Features
* .pgpass support (j7b)
* Add missing CopyFrom delegators to Tx and ConnPool (Jack Christensen)
* Add ParseConnectionString (James Lawrence)
## Performance
* Optimize HStore encoding (René Kroon)
# 2.10.0 (March 17, 2017)
## Fixes

View File

@ -1,63 +1,57 @@
[![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://godoc.org/github.com/jackc/pgx)
# Pgx
# pgx - PostgreSQL Driver and Toolkit
## Master Branch
This is the `master` branch which tracks the stable release of the current
version. At the moment this is `v2`. The `v3` branch is currently in beta.
General release is planned for July. `v3` is considered to be stable in the
sense of lack of known bugs, but the API is not considered stable until general
release. No further changes are planned, but the beta process may surface
desirable changes. If possible API changes are acceptable, then `v3` is the
recommended branch for new development. Regardless, please lock to the `v2` or
`v3` branch as when `v3` is released breaking changes will be applied to the
master branch.
Pgx is a pure Go database connection library designed specifically for
PostgreSQL. Pgx is different from other drivers such as
[pq](http://godoc.org/github.com/lib/pq) because, while it can operate as a
database/sql compatible driver, pgx is primarily intended to be used directly.
It offers a native interface similar to database/sql that offers better
performance and more features.
pgx is a pure Go driver and toolkit for PostgreSQL. pgx is different from other drivers such as [pq](http://godoc.org/github.com/lib/pq) because, while it can operate as a database/sql compatible driver, pgx is also usable directly. It offers a native interface similar to database/sql that offers better performance and more features.
## Features
Pgx supports many additional features beyond what is available through database/sql.
pgx supports many additional features beyond what is available through database/sql.
* Listen / notify
* Transaction isolation level control
* Support for approximately 60 different PostgreSQL types
* Batch queries
* Single round-trip query mode
* Full TLS connection control
* Binary format support for custom types (can be much faster)
* Copy from protocol support for faster bulk data loads
* Logging support
* Configurable connection pool with after connect hooks to do arbitrary connection setup
* Copy protocol support for faster bulk data loads
* Extendable logging support including built-in support for log15 and logrus
* Connection pool with after connect hook to do arbitrary connection setup
* Listen / notify
* PostgreSQL array to Go slice mapping for integers, floats, and strings
* Hstore support
* JSON and JSONB support
* Maps inet and cidr PostgreSQL types to net.IPNet and net.IP
* Large object support
* Null mapping to Null* struct or pointer to pointer.
* NULL mapping to Null* struct or pointer to pointer.
* Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types
* Logical replication connections, including receiving WAL and sending standby status updates
* Notice response handling (this is different than listen / notify)
## Performance
Pgx performs roughly equivalent to [pq](http://godoc.org/github.com/lib/pq) and
[go-pg](https://github.com/go-pg/pg) for selecting a single column from a single
row, but it is substantially faster when selecting multiple entire rows (6893
queries/sec for pgx vs. 3968 queries/sec for pq -- 73% faster).
pgx performs roughly equivalent to [go-pg](https://github.com/go-pg/pg) and is almost always faster than [pq](http://godoc.org/github.com/lib/pq). When parsing large result sets the percentage difference can be significant (16483 queries/sec for pgx vs. 10106 queries/sec for pq -- 63% faster).
See this [gist](https://gist.github.com/jackc/d282f39e088b495fba3e) for the
underlying benchmark results or checkout
[go_db_bench](https://github.com/jackc/go_db_bench) to run tests for yourself.
In many use cases a significant cause of latency is network round trips between the application and the server. pgx supports query batching to bundle multiple queries into a single round trip. Even in the case of a connection with the lowest possible latency, a local Unix domain socket, batching as few as three queries together can yield an improvement of 57%. With a typical network connection the results can be even more substantial.
## database/sql
See this [gist](https://gist.github.com/jackc/4996e8648a0c59839bff644f49d6e434) for the underlying benchmark results or checkout [go_db_bench](https://github.com/jackc/go_db_bench) to run tests for yourself.
Import the ```github.com/jackc/pgx/stdlib``` package to use pgx as a driver for
database/sql. It is possible to retrieve a pgx connection from database/sql on
demand. This allows using the database/sql interface in most places, but using
pgx directly when more performance or PostgreSQL specific features are needed.
In addition to the native driver, pgx also includes a number of packages that provide additional functionality.
## github.com/jackc/pgxstdlib
database/sql compatibility layer for pgx. pgx can be used as a normal database/sql driver, but at any time the native interface may be acquired for more performance or PostgreSQL specific functionality.
## github.com/jackc/pgx/pgtype
Approximately 60 PostgreSQL types are supported including uuid, hstore, json, bytea, numeric, interval, inet, and arrays. These types support database/sql interfaces and are usable even outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver.
## github.com/jackc/pgx/pgproto3
pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling.
## github.com/jackc/pgx/pgmock
pgmock offers the ability to create a server that mocks the PostgreSQL wire protocol. This is used internally to test pgx by purposely inducing unusual errors. pgproto3 and pgmock together provide most of the foundational tooling required to implement a PostgreSQL proxy or MitM (such as for a custom connection pooler).
## Documentation
@ -132,6 +126,8 @@ Change the following settings in your postgresql.conf:
max_wal_senders=5
max_replication_slots=5
Set `replicationConnConfig` appropriately in `conn_config_test.go`.
## Version Policy
pgx follows semantic versioning for the documented public API on stable releases. Branch `v2` is the latest stable release. `master` can contain new features or behavior that will change or be removed before being merged to the stable `v2` branch (in practice, this occurs very rarely).
pgx follows semantic versioning for the documented public API on stable releases. Branch `v3` is the latest stable release. `master` can contain new features or behavior that will change or be removed before being merged to the stable `v3` branch (in practice, this occurs very rarely). `v2` is the previous stable release.

View File

@ -1,126 +0,0 @@
package pgx
import (
"reflect"
"testing"
)
// TestEscapeAclItem verifies that escapeAclItem backslash-escapes the
// characters that are special inside a PostgreSQL aclitem array literal.
func TestEscapeAclItem(t *testing.T) {
	cases := []struct {
		input string
		want  string
	}{
		{"foo", "foo"},
		{`foo, "\}`, `foo\, \"\\\}`},
	}

	for i, c := range cases {
		got, err := escapeAclItem(c.input)
		if err != nil {
			t.Errorf("%d. Unexpected error %v", i, err)
		}
		if got != c.want {
			t.Errorf("%d.\nexpected: %s,\nactual: %s", i, c.want, got)
		}
	}
}
// TestParseAclItemArray exercises parseAclItemArray against empty, quoted,
// unquoted, mixed, escaped, and malformed aclitem array literals. A
// non-empty errMsg means the input is expected to fail with exactly that
// error message.
func TestParseAclItemArray(t *testing.T) {
	tests := []struct {
		input    string
		expected []AclItem
		errMsg   string
	}{
		// empty input yields an empty (non-nil) slice
		{
			"",
			[]AclItem{},
			"",
		},
		{
			"one",
			[]AclItem{"one"},
			"",
		},
		{
			`"one"`,
			[]AclItem{"one"},
			"",
		},
		{
			"one,two,three",
			[]AclItem{"one", "two", "three"},
			"",
		},
		{
			`"one","two","three"`,
			[]AclItem{"one", "two", "three"},
			"",
		},
		// quoted and unquoted elements may be mixed freely
		{
			`"one",two,"three"`,
			[]AclItem{"one", "two", "three"},
			"",
		},
		{
			`one,two,"three"`,
			[]AclItem{"one", "two", "three"},
			"",
		},
		{
			`"one","two",three`,
			[]AclItem{"one", "two", "three"},
			"",
		},
		// quoting preserves embedded spaces
		{
			`"one","t w o",three`,
			[]AclItem{"one", "t w o", "three"},
			"",
		},
		// backslash escapes inside quotes: \" \} \\ and a literal comma
		{
			`"one","t, w o\"\}\\",three`,
			[]AclItem{"one", `t, w o"}\`, "three"},
			"",
		},
		// a trailing quote on an unquoted element is taken literally
		{
			`"one","two",three"`,
			[]AclItem{"one", "two", `three"`},
			"",
		},
		// malformed: junk after a closing quote
		{
			`"one","two,"three"`,
			nil,
			"unexpected rune after quoted value",
		},
		// malformed: quoted value never closed
		{
			`"one","two","three`,
			nil,
			"unexpected end of quoted value",
		},
	}
	for i, tt := range tests {
		actual, err := parseAclItemArray(tt.input)
		if err != nil {
			if tt.errMsg == "" {
				t.Errorf("%d. Unexpected error %v", i, err)
			} else if err.Error() != tt.errMsg {
				t.Errorf("%d. Expected error %v did not match actual error %v", i, tt.errMsg, err.Error())
			}
		} else if tt.errMsg != "" {
			t.Errorf("%d. Expected error not returned: \"%v\"", i, tt.errMsg)
		}
		if !reflect.DeepEqual(actual, tt.expected) {
			t.Errorf("%d. Expected %v did not match actual %v", i, tt.expected, actual)
		}
	}
}

246
batch.go Normal file
View File

@ -0,0 +1,246 @@
package pgx
import (
"context"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
// batchItem holds one queued query together with the metadata needed to
// encode it into the extended-protocol message stream.
type batchItem struct {
	query             string        // SQL text, or the name of a prepared statement
	arguments         []interface{} // values bound to the query's placeholders
	parameterOIDs     []pgtype.OID  // parameter OIDs; required when query is not prepared
	resultFormatCodes []int16       // requested format (text/binary) per result column
}
// Batch queries are a way of bundling multiple queries together to avoid
// unnecessary network round trips.
type Batch struct {
	conn        *Conn           // connection the batch is written to and read from
	connPool    *ConnPool       // pool to release conn to, when conn came from a pool
	items       []*batchItem    // queued queries, in execution order
	resultsRead int             // number of query results consumed so far
	sent        bool            // set once Send has transmitted the batch
	ctx         context.Context // context supplied to Send; in effect until Close
	err         error           // first fatal error; once set the batch is unusable
}
// BeginBatch returns a *Batch query for c.
func (c *Conn) BeginBatch() *Batch {
	batch := &Batch{conn: c}
	return batch
}
// Conn returns the underlying connection that b will or was performed on.
func (b *Batch) Conn() *Conn {
	conn := b.conn
	return conn
}
// Queue queues a query to batch b. parameterOIDs are required if there are
// parameters and query is not the name of a prepared statement.
// resultFormatCodes are required if there is a result.
func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgtype.OID, resultFormatCodes []int16) {
	item := &batchItem{
		query:             query,
		arguments:         arguments,
		parameterOIDs:     parameterOIDs,
		resultFormatCodes: resultFormatCodes,
	}
	b.items = append(b.items, item)
}
// Send sends all queued queries to the server at once. All queries are wrapped
// in a transaction. The transaction can optionally be configured with
// txOptions. The context is in effect until the Batch is closed.
func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
	if b.err != nil {
		return b.err
	}

	b.ctx = ctx

	// Wait for any in-flight cancel request from a previous query to finish
	// before writing to the connection.
	err := b.conn.waitForPreviousCancelQuery(ctx)
	if err != nil {
		return err
	}

	if err := b.conn.ensureConnectionReadyForQuery(); err != nil {
		return err
	}

	err = b.conn.initContext(ctx)
	if err != nil {
		return err
	}

	// NOTE(review): callers pass txOptions == nil for default transaction
	// behavior, so beginSQL must tolerate a nil receiver — confirm against
	// TxOptions.beginSQL.
	buf := appendQuery(b.conn.wbuf, txOptions.beginSQL())

	// Encode Parse/Bind/Describe/Execute for every queued item. A query that
	// names an existing prepared statement reuses it; anything else is parsed
	// as an unnamed statement.
	for _, bi := range b.items {
		var psName string
		var psParameterOIDs []pgtype.OID

		if ps, ok := b.conn.preparedStatements[bi.query]; ok {
			psName = ps.Name
			psParameterOIDs = ps.ParameterOIDs
		} else {
			psParameterOIDs = bi.parameterOIDs
			buf = appendParse(buf, "", bi.query, psParameterOIDs)
		}

		var err error
		buf, err = appendBind(buf, "", psName, b.conn.ConnInfo, psParameterOIDs, bi.arguments, bi.resultFormatCodes)
		if err != nil {
			return err
		}

		buf = appendDescribe(buf, 'P', "")
		buf = appendExecute(buf, "", 0)
	}

	buf = appendSync(buf)
	buf = appendQuery(buf, "commit")

	n, err := b.conn.conn.Write(buf)
	if err != nil {
		if fatalWriteErr(n, err) {
			b.conn.die(err)
		}
		return err
	}

	// expect ReadyForQuery from sync and from commit
	b.conn.pendingReadyForQueryCount = b.conn.pendingReadyForQueryCount + 2

	b.sent = true

	// Consume protocol messages until the ReadyForQuery that acknowledges the
	// implicit "begin"; per-query results are read later by the *Results
	// methods. Every path through this loop returns, so no trailing return is
	// needed (the previous unreachable `return nil` has been removed).
	for {
		msg, err := b.conn.rxMsg()
		if err != nil {
			return err
		}

		switch msg := msg.(type) {
		case *pgproto3.ReadyForQuery:
			return nil
		default:
			if err := b.conn.processContextFreeMsg(msg); err != nil {
				return err
			}
		}
	}
}
// ExecResults reads the results from the next query in the batch as if the
// query has been sent with Exec.
func (b *Batch) ExecResults() (CommandTag, error) {
	if b.err != nil {
		return "", b.err
	}

	// Non-blocking cancellation check on the context passed to Send.
	select {
	case <-b.ctx.Done():
		b.die(b.ctx.Err())
		return "", b.ctx.Err()
	default:
	}

	b.resultsRead++

	// Read protocol messages until this query's CommandComplete arrives.
	for {
		msg, err := b.conn.rxMsg()
		if err != nil {
			return "", err
		}

		if cc, ok := msg.(*pgproto3.CommandComplete); ok {
			return CommandTag(cc.CommandTag), nil
		}

		if err := b.conn.processContextFreeMsg(msg); err != nil {
			return "", err
		}
	}
}
// QueryResults reads the results from the next query in the batch as if the
// query has been sent with Query.
func (b *Batch) QueryResults() (*Rows, error) {
	rows := b.conn.getRows("batch query", nil)

	// A previously recorded fatal error poisons the whole batch: hand back a
	// pre-failed Rows so callers see the error on first use.
	if b.err != nil {
		rows.fatal(b.err)
		return rows, b.err
	}

	// Non-blocking cancellation check on the context passed to Send.
	select {
	case <-b.ctx.Done():
		b.die(b.ctx.Err())
		rows.fatal(b.err)
		return rows, b.ctx.Err()
	default:
	}

	b.resultsRead++

	// Skip protocol messages until this query's RowDescription arrives.
	fieldDescriptions, err := b.conn.readUntilRowDescription()
	if err != nil {
		b.die(err)
		rows.fatal(b.err)
		return rows, err
	}

	rows.batch = b
	rows.fields = fieldDescriptions
	return rows, nil
}
// QueryRowResults reads the results from the next query in the batch as if the
// query has been sent with QueryRow.
func (b *Batch) QueryRowResults() *Row {
	rows, _ := b.QueryResults()
	row := (*Row)(rows)
	return row
}
// Close closes the batch operation. Any error that occurred during a batch
// operation may have made it impossible to resynchronize the connection with
// the server. In this case the underlying connection will have been closed.
func (b *Batch) Close() (err error) {
	if b.err != nil {
		return b.err
	}

	// On the way out: terminate the context and, if the connection came from
	// a pool, release it — regardless of how Close exits. err is a named
	// return so termContext can replace it.
	defer func() {
		err = b.conn.termContext(err)
		if b.conn != nil && b.connPool != nil {
			b.connPool.Release(b.conn)
		}
	}()

	// Drain any unread results so the connection is left in a usable state.
	for i := b.resultsRead; i < len(b.items); i++ {
		if _, err = b.ExecResults(); err != nil {
			return err
		}
	}

	if err = b.conn.ensureConnectionReadyForQuery(); err != nil {
		return err
	}

	return nil
}
// die records err as the batch's fatal error, kills the underlying
// connection, and releases it back to the pool when one was used. Only the
// first error is retained; subsequent calls are no-ops.
func (b *Batch) die(err error) {
	if b.err != nil {
		return
	}

	b.err = err

	// Check b.conn for nil before dereferencing it. The original called
	// b.conn.die(err) unconditionally and only afterwards tested
	// b.conn != nil, which would have panicked first if conn were ever nil.
	if b.conn != nil {
		b.conn.die(err)
		if b.connPool != nil {
			b.connPool.Release(b.conn)
		}
	}
}

478
batch_test.go Normal file
View File

@ -0,0 +1,478 @@
package pgx_test
import (
"context"
"testing"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgtype"
)
// TestConnBeginBatch queues three inserts and two selects, sends them in a
// single round trip, and verifies the results are readable in queue order.
func TestConnBeginBatch(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
	mustExec(t, conn, sql)

	batch := conn.BeginBatch()
	batch.Queue("insert into ledger(description, amount) values($1, $2)",
		[]interface{}{"q1", 1},
		[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
		nil,
	)
	batch.Queue("insert into ledger(description, amount) values($1, $2)",
		[]interface{}{"q2", 2},
		[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
		nil,
	)
	batch.Queue("insert into ledger(description, amount) values($1, $2)",
		[]interface{}{"q3", 3},
		[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
		nil,
	)
	batch.Queue("select id, description, amount from ledger order by id",
		nil,
		nil,
		[]int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode},
	)
	batch.Queue("select sum(amount) from ledger",
		nil,
		nil,
		[]int16{pgx.BinaryFormatCode},
	)

	err := batch.Send(context.Background(), nil)
	if err != nil {
		t.Fatal(err)
	}

	// NOTE(review): only two of the three insert results are read explicitly
	// here; QueryResults appears to skip ahead to the next RowDescription,
	// implicitly discarding the third insert's result — confirm against
	// readUntilRowDescription.
	ct, err := batch.ExecResults()
	if err != nil {
		t.Error(err)
	}
	if ct.RowsAffected() != 1 {
		t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
	}

	ct, err = batch.ExecResults()
	if err != nil {
		t.Error(err)
	}
	if ct.RowsAffected() != 1 {
		t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
	}

	rows, err := batch.QueryResults()
	if err != nil {
		t.Error(err)
	}

	var id int32
	var description string
	var amount int32

	// Each of the three ledger rows comes back in insertion order.
	if !rows.Next() {
		t.Fatal("expected a row to be available")
	}
	if err := rows.Scan(&id, &description, &amount); err != nil {
		t.Fatal(err)
	}
	if id != 1 {
		t.Errorf("id => %v, want %v", id, 1)
	}
	if description != "q1" {
		t.Errorf("description => %v, want %v", description, "q1")
	}
	if amount != 1 {
		t.Errorf("amount => %v, want %v", amount, 1)
	}

	if !rows.Next() {
		t.Fatal("expected a row to be available")
	}
	if err := rows.Scan(&id, &description, &amount); err != nil {
		t.Fatal(err)
	}
	if id != 2 {
		t.Errorf("id => %v, want %v", id, 2)
	}
	if description != "q2" {
		t.Errorf("description => %v, want %v", description, "q2")
	}
	if amount != 2 {
		t.Errorf("amount => %v, want %v", amount, 2)
	}

	if !rows.Next() {
		t.Fatal("expected a row to be available")
	}
	if err := rows.Scan(&id, &description, &amount); err != nil {
		t.Fatal(err)
	}
	if id != 3 {
		t.Errorf("id => %v, want %v", id, 3)
	}
	if description != "q3" {
		t.Errorf("description => %v, want %v", description, "q3")
	}
	if amount != 3 {
		t.Errorf("amount => %v, want %v", amount, 3)
	}

	if rows.Next() {
		t.Fatal("did not expect a row to be available")
	}

	if rows.Err() != nil {
		t.Fatal(rows.Err())
	}

	// Final query: sum of the three inserted amounts.
	err = batch.QueryRowResults().Scan(&amount)
	if err != nil {
		t.Error(err)
	}
	if amount != 6 {
		t.Errorf("amount => %v, want %v", amount, 6)
	}

	err = batch.Close()
	if err != nil {
		t.Fatal(err)
	}

	ensureConnValid(t, conn)
}
// TestConnBeginBatchWithPreparedStatement verifies a batch can queue the
// same prepared statement multiple times and read each result set.
func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	_, err := conn.Prepare("ps1", "select n from generate_series(0,$1::int) n")
	if err != nil {
		t.Fatal(err)
	}

	batch := conn.BeginBatch()

	queryCount := 3
	for i := 0; i < queryCount; i++ {
		// parameterOIDs are nil: they come from the prepared statement.
		batch.Queue("ps1",
			[]interface{}{5},
			nil,
			[]int16{pgx.BinaryFormatCode},
		)
	}

	err = batch.Send(context.Background(), nil)
	if err != nil {
		t.Fatal(err)
	}

	for i := 0; i < queryCount; i++ {
		rows, err := batch.QueryResults()
		if err != nil {
			t.Fatal(err)
		}

		// Each result set is the series 0..5.
		for k := 0; rows.Next(); k++ {
			var n int
			if err := rows.Scan(&n); err != nil {
				t.Fatal(err)
			}
			if n != k {
				t.Fatalf("n => %v, want %v", n, k)
			}
		}

		if rows.Err() != nil {
			t.Fatal(rows.Err())
		}
	}

	err = batch.Close()
	if err != nil {
		t.Fatal(err)
	}

	ensureConnValid(t, conn)
}
// TestConnBeginBatchContextCancelBeforeExecResults cancels the batch context
// after Send but before reading results and expects ExecResults to return
// context.Canceled and the connection to be killed. (No defer closeConn:
// the connection is expected to be dead at the end of the test.)
func TestConnBeginBatchContextCancelBeforeExecResults(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)

	sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
	mustExec(t, conn, sql)

	batch := conn.BeginBatch()
	batch.Queue("insert into ledger(description, amount) values($1, $2)",
		[]interface{}{"q1", 1},
		[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
		nil,
	)
	// pg_sleep keeps the server busy so the cancel lands mid-batch.
	batch.Queue("select pg_sleep(2)",
		nil,
		nil,
		nil,
	)

	ctx, cancelFn := context.WithCancel(context.Background())

	err := batch.Send(ctx, nil)
	if err != nil {
		t.Fatal(err)
	}

	cancelFn()

	_, err = batch.ExecResults()
	if err != context.Canceled {
		t.Errorf("err => %v, want %v", err, context.Canceled)
	}

	if conn.IsAlive() {
		t.Error("conn should be dead, but was alive")
	}
}
// TestConnBeginBatchContextCancelBeforeQueryResults cancels the batch context
// after Send and expects QueryResults to return context.Canceled and the
// connection to be killed. (No defer closeConn: the connection is expected
// to be dead at the end of the test.)
func TestConnBeginBatchContextCancelBeforeQueryResults(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)

	batch := conn.BeginBatch()
	batch.Queue("select pg_sleep(2)",
		nil,
		nil,
		nil,
	)
	batch.Queue("select pg_sleep(2)",
		nil,
		nil,
		nil,
	)

	ctx, cancelFn := context.WithCancel(context.Background())

	err := batch.Send(ctx, nil)
	if err != nil {
		t.Fatal(err)
	}

	cancelFn()

	_, err = batch.QueryResults()
	if err != context.Canceled {
		t.Errorf("err => %v, want %v", err, context.Canceled)
	}

	if conn.IsAlive() {
		t.Error("conn should be dead, but was alive")
	}
}
// TestConnBeginBatchContextCancelBeforeFinish cancels the batch context after
// Send and expects Close to surface context.Canceled while draining, with the
// connection killed. (No defer closeConn: the connection is expected to be
// dead at the end of the test.)
func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)

	batch := conn.BeginBatch()
	batch.Queue("select pg_sleep(2)",
		nil,
		nil,
		nil,
	)
	batch.Queue("select pg_sleep(2)",
		nil,
		nil,
		nil,
	)

	ctx, cancelFn := context.WithCancel(context.Background())

	err := batch.Send(ctx, nil)
	if err != nil {
		t.Fatal(err)
	}

	cancelFn()

	err = batch.Close()
	if err != context.Canceled {
		t.Errorf("err => %v, want %v", err, context.Canceled)
	}

	if conn.IsAlive() {
		t.Error("conn should be dead, but was alive")
	}
}
// TestConnBeginBatchCloseRowsPartiallyRead verifies that closing a result set
// after reading only part of it still allows the next result set — and the
// connection — to be used normally.
func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	batch := conn.BeginBatch()
	batch.Queue("select n from generate_series(0,5) n",
		nil,
		nil,
		[]int16{pgx.BinaryFormatCode},
	)
	batch.Queue("select n from generate_series(0,5) n",
		nil,
		nil,
		[]int16{pgx.BinaryFormatCode},
	)

	err := batch.Send(context.Background(), nil)
	if err != nil {
		t.Fatal(err)
	}

	rows, err := batch.QueryResults()
	if err != nil {
		t.Error(err)
	}

	// Read only the first three of six rows, then close early.
	for i := 0; i < 3; i++ {
		if !rows.Next() {
			t.Error("expected a row to be available")
		}

		var n int
		if err := rows.Scan(&n); err != nil {
			t.Error(err)
		}
		if n != i {
			t.Errorf("n => %v, want %v", n, i)
		}
	}

	rows.Close()

	// The second result set must still be fully readable.
	rows, err = batch.QueryResults()
	if err != nil {
		t.Error(err)
	}

	for i := 0; rows.Next(); i++ {
		var n int
		if err := rows.Scan(&n); err != nil {
			t.Error(err)
		}
		if n != i {
			t.Errorf("n => %v, want %v", n, i)
		}
	}

	if rows.Err() != nil {
		t.Error(rows.Err())
	}

	err = batch.Close()
	if err != nil {
		t.Fatal(err)
	}

	ensureConnValid(t, conn)
}
// TestConnBeginBatchQueryError verifies that a runtime error (division by
// zero, SQLSTATE 22012) raised mid-result-set is reported by rows.Err() and
// by Close, and kills the connection.
func TestConnBeginBatchQueryError(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	batch := conn.BeginBatch()
	// Fails when n reaches 5: 100/(5-5) divides by zero.
	batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0",
		nil,
		nil,
		[]int16{pgx.BinaryFormatCode},
	)
	batch.Queue("select n from generate_series(0,5) n",
		nil,
		nil,
		[]int16{pgx.BinaryFormatCode},
	)

	err := batch.Send(context.Background(), nil)
	if err != nil {
		t.Fatal(err)
	}

	rows, err := batch.QueryResults()
	if err != nil {
		t.Error(err)
	}

	for i := 0; rows.Next(); i++ {
		var n int
		if err := rows.Scan(&n); err != nil {
			t.Error(err)
		}
		if n != i {
			t.Errorf("n => %v, want %v", n, i)
		}
	}

	// 22012 is PostgreSQL's division_by_zero SQLSTATE.
	if pgErr, ok := rows.Err().(pgx.PgError); !(ok && pgErr.Code == "22012") {
		t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012)
	}

	err = batch.Close()
	if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "22012") {
		t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
	}

	if conn.IsAlive() {
		t.Error("conn should be dead, but was alive")
	}
}
// TestConnBeginBatchQuerySyntaxError verifies that a syntax error (SQLSTATE
// 42601) in a queued query is surfaced when reading results, that Close also
// errors, and that the connection is killed.
func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	batch := conn.BeginBatch()
	// "select 1 1" is intentionally invalid SQL.
	batch.Queue("select 1 1",
		nil,
		nil,
		[]int16{pgx.BinaryFormatCode},
	)

	err := batch.Send(context.Background(), nil)
	if err != nil {
		t.Fatal(err)
	}

	var n int32
	err = batch.QueryRowResults().Scan(&n)
	// 42601 is PostgreSQL's syntax_error SQLSTATE.
	if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "42601") {
		t.Errorf("rows.Err() => %v, want error code %v", err, 42601)
	}

	err = batch.Close()
	if err == nil {
		t.Error("Expected error")
	}

	if conn.IsAlive() {
		t.Error("conn should be dead, but was alive")
	}
}

55
bench-tmp_test.go Normal file
View File

@ -0,0 +1,55 @@
package pgx_test
import (
"testing"
)
// BenchmarkPgtypeInt4ParseBinary measures decoding 100 binary-format int4
// values per query through a prepared statement.
func BenchmarkPgtypeInt4ParseBinary(b *testing.B) {
	conn := mustConnect(b, *defaultConnConfig)
	defer closeConn(b, conn)

	if _, err := conn.Prepare("selectBinary", "select n::int4 from generate_series(1, 100) n"); err != nil {
		b.Fatal(err)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		var n int32

		rows, err := conn.Query("selectBinary")
		if err != nil {
			b.Fatal(err)
		}

		for rows.Next() {
			if err := rows.Scan(&n); err != nil {
				b.Fatal(err)
			}
		}

		if rows.Err() != nil {
			b.Fatal(rows.Err())
		}
	}
}
// BenchmarkPgtypeInt4EncodeBinary measures encoding seven int4 parameters per
// query through a prepared statement.
func BenchmarkPgtypeInt4EncodeBinary(b *testing.B) {
	conn := mustConnect(b, *defaultConnConfig)
	defer closeConn(b, conn)

	if _, err := conn.Prepare("encodeBinary", "select $1::int4, $2::int4, $3::int4, $4::int4, $5::int4, $6::int4, $7::int4"); err != nil {
		b.Fatal(err)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v := int32(i)
		rows, err := conn.Query("encodeBinary", v, v, v, v, v, v, v)
		if err != nil {
			b.Fatal(err)
		}
		rows.Close()
	}
}

View File

@ -2,13 +2,14 @@ package pgx_test
import (
"bytes"
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/jackc/pgx"
log "gopkg.in/inconshreveable/log15.v2"
"github.com/jackc/pgx/pgtype"
)
func BenchmarkConnPool(b *testing.B) {
@ -50,126 +51,6 @@ func BenchmarkConnPoolQueryRow(b *testing.B) {
}
}
// BenchmarkNullXWithNullValues measures scanning NULL columns into the pgx
// Null* wrapper types and checking their Valid flags.
func BenchmarkNullXWithNullValues(b *testing.B) {
	conn := mustConnect(b, *defaultConnConfig)
	defer closeConn(b, conn)

	_, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', null::text, null::text, null::text, null::date, null::timestamptz")
	if err != nil {
		b.Fatal(err)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		var record struct {
			id            int32
			userName      string
			email         pgx.NullString
			name          pgx.NullString
			sex           pgx.NullString
			birthDate     pgx.NullTime
			lastLoginTime pgx.NullTime
		}

		err = conn.QueryRow("selectNulls").Scan(
			&record.id,
			&record.userName,
			&record.email,
			&record.name,
			&record.sex,
			&record.birthDate,
			&record.lastLoginTime,
		)
		if err != nil {
			b.Fatal(err)
		}

		// These checks both ensure that the correct data was returned
		// and provide a benchmark of accessing the returned values.
		if record.id != 1 {
			b.Fatalf("bad value for id: %v", record.id)
		}
		if record.userName != "johnsmith" {
			b.Fatalf("bad value for userName: %v", record.userName)
		}
		if record.email.Valid {
			b.Fatalf("bad value for email: %v", record.email)
		}
		if record.name.Valid {
			b.Fatalf("bad value for name: %v", record.name)
		}
		if record.sex.Valid {
			b.Fatalf("bad value for sex: %v", record.sex)
		}
		if record.birthDate.Valid {
			b.Fatalf("bad value for birthDate: %v", record.birthDate)
		}
		if record.lastLoginTime.Valid {
			b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
		}
	}
}
// BenchmarkNullXWithPresentValues measures scanning non-NULL columns into the
// pgx Null* wrapper types and validating both Valid flags and payloads.
func BenchmarkNullXWithPresentValues(b *testing.B) {
	conn := mustConnect(b, *defaultConnConfig)
	defer closeConn(b, conn)

	_, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz")
	if err != nil {
		b.Fatal(err)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		var record struct {
			id            int32
			userName      string
			email         pgx.NullString
			name          pgx.NullString
			sex           pgx.NullString
			birthDate     pgx.NullTime
			lastLoginTime pgx.NullTime
		}

		err = conn.QueryRow("selectNulls").Scan(
			&record.id,
			&record.userName,
			&record.email,
			&record.name,
			&record.sex,
			&record.birthDate,
			&record.lastLoginTime,
		)
		if err != nil {
			b.Fatal(err)
		}

		// These checks both ensure that the correct data was returned
		// and provide a benchmark of accessing the returned values.
		if record.id != 1 {
			b.Fatalf("bad value for id: %v", record.id)
		}
		if record.userName != "johnsmith" {
			b.Fatalf("bad value for userName: %v", record.userName)
		}
		if !record.email.Valid || record.email.String != "johnsmith@example.com" {
			b.Fatalf("bad value for email: %v", record.email)
		}
		if !record.name.Valid || record.name.String != "John Smith" {
			b.Fatalf("bad value for name: %v", record.name)
		}
		if !record.sex.Valid || record.sex.String != "male" {
			b.Fatalf("bad value for sex: %v", record.sex)
		}
		// time.Local comparison assumes the server session time zone matches
		// the test host's local zone — presumably configured by the test
		// setup; verify if this benchmark is revived.
		if !record.birthDate.Valid || record.birthDate.Time != time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local) {
			b.Fatalf("bad value for birthDate: %v", record.birthDate)
		}
		if !record.lastLoginTime.Valid || record.lastLoginTime.Time != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) {
			b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
		}
	}
}
func BenchmarkPointerPointerWithNullValues(b *testing.B) {
conn := mustConnect(b, *defaultConnConfig)
defer closeConn(b, conn)
@ -297,71 +178,51 @@ func BenchmarkSelectWithoutLogging(b *testing.B) {
benchmarkSelectWithLog(b, conn)
}
func BenchmarkSelectWithLoggingTraceWithLog15(b *testing.B) {
connConfig := *defaultConnConfig
type discardLogger struct{}
logger := log.New()
lvl, err := log.LvlFromString("debug")
if err != nil {
b.Fatal(err)
}
logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))
connConfig.Logger = logger
connConfig.LogLevel = pgx.LogLevelTrace
conn := mustConnect(b, connConfig)
func (dl discardLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {}
func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) {
conn := mustConnect(b, *defaultConnConfig)
defer closeConn(b, conn)
var logger discardLogger
conn.SetLogger(logger)
conn.SetLogLevel(pgx.LogLevelTrace)
benchmarkSelectWithLog(b, conn)
}
func BenchmarkSelectWithLoggingDebugWithLog15(b *testing.B) {
connConfig := *defaultConnConfig
logger := log.New()
lvl, err := log.LvlFromString("debug")
if err != nil {
b.Fatal(err)
}
logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))
connConfig.Logger = logger
connConfig.LogLevel = pgx.LogLevelDebug
conn := mustConnect(b, connConfig)
func BenchmarkSelectWithLoggingDebugWithDiscard(b *testing.B) {
conn := mustConnect(b, *defaultConnConfig)
defer closeConn(b, conn)
var logger discardLogger
conn.SetLogger(logger)
conn.SetLogLevel(pgx.LogLevelDebug)
benchmarkSelectWithLog(b, conn)
}
func BenchmarkSelectWithLoggingInfoWithLog15(b *testing.B) {
connConfig := *defaultConnConfig
logger := log.New()
lvl, err := log.LvlFromString("info")
if err != nil {
b.Fatal(err)
}
logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))
connConfig.Logger = logger
connConfig.LogLevel = pgx.LogLevelInfo
conn := mustConnect(b, connConfig)
func BenchmarkSelectWithLoggingInfoWithDiscard(b *testing.B) {
conn := mustConnect(b, *defaultConnConfig)
defer closeConn(b, conn)
var logger discardLogger
conn.SetLogger(logger)
conn.SetLogLevel(pgx.LogLevelInfo)
benchmarkSelectWithLog(b, conn)
}
func BenchmarkSelectWithLoggingErrorWithLog15(b *testing.B) {
connConfig := *defaultConnConfig
logger := log.New()
lvl, err := log.LvlFromString("error")
if err != nil {
b.Fatal(err)
}
logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))
connConfig.Logger = logger
connConfig.LogLevel = pgx.LogLevelError
conn := mustConnect(b, connConfig)
func BenchmarkSelectWithLoggingErrorWithDiscard(b *testing.B) {
conn := mustConnect(b, *defaultConnConfig)
defer closeConn(b, conn)
var logger discardLogger
conn.SetLogger(logger)
conn.SetLogLevel(pgx.LogLevelError)
benchmarkSelectWithLog(b, conn)
}
@ -422,20 +283,6 @@ func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) {
}
}
func BenchmarkLog15Discard(b *testing.B) {
logger := log.New()
lvl, err := log.LvlFromString("error")
if err != nil {
b.Fatal(err)
}
logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Debug("benchmark", "i", i, "b.N", b.N)
}
}
const benchmarkWriteTableCreateSQL = `drop table if exists t;
create table t(
@ -510,12 +357,12 @@ func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource {
row: []interface{}{
"varchar_1",
"varchar_2",
pgx.NullString{},
pgtype.Text{},
time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local),
pgx.NullTime{},
pgtype.Date{},
1,
2,
pgx.NullInt32{},
pgtype.Int4{},
time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local),
time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local),
true,
@ -763,3 +610,92 @@ func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) {
func BenchmarkWrite10000RowsViaCopy(b *testing.B) {
benchmarkWriteNRowsViaCopy(b, 10000)
}
// BenchmarkMultipleQueriesNonBatch measures issuing several small queries one
// round trip at a time through a connection pool. It is the baseline against
// which BenchmarkMultipleQueriesBatch is compared.
func BenchmarkMultipleQueriesNonBatch(b *testing.B) {
	pool, err := pgx.NewConnPool(pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5})
	if err != nil {
		b.Fatalf("Unable to create connection pool: %v", err)
	}
	defer pool.Close()

	const queryCount = 3

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for j := 0; j < queryCount; j++ {
			rows, err := pool.Query("select n from generate_series(0, 5) n")
			if err != nil {
				b.Fatal(err)
			}

			// generate_series yields 0..5 in order; verify each value.
			expected := 0
			for rows.Next() {
				var n int
				if err := rows.Scan(&n); err != nil {
					b.Fatal(err)
				}
				if n != expected {
					b.Fatalf("n => %v, want %v", n, expected)
				}
				expected++
			}
			if rows.Err() != nil {
				b.Fatal(rows.Err())
			}
		}
	}
}
// BenchmarkMultipleQueriesBatch measures issuing the same queries as
// BenchmarkMultipleQueriesNonBatch, but queued into a single batch so they
// share one network round trip.
func BenchmarkMultipleQueriesBatch(b *testing.B) {
	pool, err := pgx.NewConnPool(pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5})
	if err != nil {
		b.Fatalf("Unable to create connection pool: %v", err)
	}
	defer pool.Close()

	const queryCount = 3

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		batch := pool.BeginBatch()
		for j := 0; j < queryCount; j++ {
			batch.Queue("select n from generate_series(0,5) n",
				nil,
				nil,
				[]int16{pgx.BinaryFormatCode},
			)
		}

		if err := batch.Send(context.Background(), nil); err != nil {
			b.Fatal(err)
		}

		// Each queued query produces one result set, delivered in order.
		for j := 0; j < queryCount; j++ {
			rows, err := batch.QueryResults()
			if err != nil {
				b.Fatal(err)
			}

			expected := 0
			for rows.Next() {
				var n int
				if err := rows.Scan(&n); err != nil {
					b.Fatal(err)
				}
				if n != expected {
					b.Fatalf("n => %v, want %v", n, expected)
				}
				expected++
			}
			if rows.Err() != nil {
				b.Fatal(rows.Err())
			}
		}

		if err := batch.Close(); err != nil {
			b.Fatal(err)
		}
	}
}

View File

@ -0,0 +1,89 @@
package chunkreader
import (
"io"
)
type ChunkReader struct {
r io.Reader
buf []byte
rp, wp int // buf read position and write position
options Options
}
type Options struct {
MinBufLen int // Minimum buffer length
}
func NewChunkReader(r io.Reader) *ChunkReader {
cr, err := NewChunkReaderEx(r, Options{})
if err != nil {
panic("default options can't be bad")
}
return cr
}
func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) {
if options.MinBufLen == 0 {
options.MinBufLen = 4096
}
return &ChunkReader{
r: r,
buf: make([]byte, options.MinBufLen),
options: options,
}, nil
}
// Next returns buf filled with the next n bytes. If an error occurs, buf will
// be nil.
func (r *ChunkReader) Next(n int) (buf []byte, err error) {
// n bytes already in buf
if (r.wp - r.rp) >= n {
buf = r.buf[r.rp : r.rp+n]
r.rp += n
return buf, err
}
// available space in buf is less than n
if len(r.buf) < n {
r.copyBufContents(r.newBuf(n))
}
// buf is large enough, but need to shift filled area to start to make enough contiguous space
minReadCount := n - (r.wp - r.rp)
if (len(r.buf) - r.wp) < minReadCount {
newBuf := r.newBuf(n)
r.copyBufContents(newBuf)
}
if err := r.appendAtLeast(minReadCount); err != nil {
return nil, err
}
buf = r.buf[r.rp : r.rp+n]
r.rp += n
return buf, nil
}
func (r *ChunkReader) appendAtLeast(fillLen int) error {
n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen)
r.wp += n
return err
}
func (r *ChunkReader) newBuf(size int) []byte {
if size < r.options.MinBufLen {
size = r.options.MinBufLen
}
return make([]byte, size)
}
func (r *ChunkReader) copyBufContents(dest []byte) {
r.wp = copy(dest, r.buf[r.rp:r.wp])
r.rp = 0
r.buf = dest
}

View File

@ -0,0 +1,96 @@
package chunkreader
import (
"bytes"
"testing"
)
// TestChunkReaderNextDoesNotReadIfAlreadyBuffered verifies that Next serves a
// request entirely from the existing buffer when enough bytes are buffered,
// leaving the buffer and its positions exactly as a pure in-buffer read would.
func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
	server := &bytes.Buffer{}
	r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
	if err != nil {
		t.Fatal(err)
	}

	src := []byte{1, 2, 3, 4}
	server.Write(src)

	n1, err := r.Next(2)
	if err != nil {
		t.Fatal(err)
	}
	// bytes.Equal is the idiomatic equality test (instead of Compare != 0).
	if !bytes.Equal(n1, src[0:2]) {
		t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1)
	}

	n2, err := r.Next(2)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(n2, src[2:4]) {
		t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2)
	}

	// The original backing buffer must still hold the full source bytes.
	if !bytes.Equal(r.buf, src) {
		t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf)
	}
	if r.rp != 4 {
		t.Fatalf("Expected r.rp to be %v, but it was %v", 4, r.rp)
	}
	if r.wp != 4 {
		t.Fatalf("Expected r.wp to be %v, but it was %v", 4, r.wp)
	}
}
// TestChunkReaderNextExpandsBufAsNeeded verifies that a request larger than
// the current buffer replaces it with one sized exactly for the request.
func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) {
	server := &bytes.Buffer{}
	r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
	if err != nil {
		t.Fatal(err)
	}

	src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
	server.Write(src)

	n1, err := r.Next(5)
	if err != nil {
		t.Fatal(err)
	}
	// bytes.Equal is the idiomatic equality test (instead of Compare != 0).
	if !bytes.Equal(n1, src[0:5]) {
		t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1)
	}
	if len(r.buf) != 5 {
		t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf))
	}
}
// TestChunkReaderDoesNotReuseBuf verifies that when the buffer must be
// replaced, slices previously returned by Next are not overwritten by later
// reads.
func TestChunkReaderDoesNotReuseBuf(t *testing.T) {
	server := &bytes.Buffer{}
	r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
	if err != nil {
		t.Fatal(err)
	}

	src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
	server.Write(src)

	n1, err := r.Next(4)
	if err != nil {
		t.Fatal(err)
	}
	// bytes.Equal is the idiomatic equality test (instead of Compare != 0).
	if !bytes.Equal(n1, src[0:4]) {
		t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1)
	}

	n2, err := r.Next(4)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(n2, src[4:8]) {
		t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2)
	}

	// The original message referenced a nonexistent "KeepLast" API; the
	// assertion actually checks that the second Next did not clobber n1.
	if !bytes.Equal(n1, src[0:4]) {
		t.Fatalf("Expected Next to not overwrite previously returned buf, expected %v but it was %v", src[0:4], n1)
	}
}

1342
conn.go

File diff suppressed because it is too large Load Diff

View File

@ -23,3 +23,5 @@ var replicationConnConfig *pgx.ConnConfig = nil
// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
// var replicationConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_replication", Password: "secret", Database: "pgx_test"}

View File

@ -1,9 +1,13 @@
package pgx
import (
"errors"
"context"
"sync"
"time"
"github.com/pkg/errors"
"github.com/jackc/pgx/pgtype"
)
type ConnPoolConfig struct {
@ -27,11 +31,7 @@ type ConnPool struct {
closed bool
preparedStatements map[string]*PreparedStatement
acquireTimeout time.Duration
pgTypes map[Oid]PgType
pgsqlAfInet *byte
pgsqlAfInet6 *byte
txAfterClose func(tx *Tx)
rowsAfterClose func(rows *Rows)
connInfo *pgtype.ConnInfo
}
type ConnPoolStat struct {
@ -48,6 +48,7 @@ var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool")
func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) {
p = new(ConnPool)
p.config = config.ConnConfig
p.connInfo = minimalConnInfo
p.maxConnections = config.MaxConnections
if p.maxConnections == 0 {
p.maxConnections = 5
@ -73,14 +74,6 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) {
p.logLevel = LogLevelNone
}
p.txAfterClose = func(tx *Tx) {
p.Release(tx.Conn())
}
p.rowsAfterClose = func(rows *Rows) {
p.Release(rows.Conn())
}
p.allConnections = make([]*Conn, 0, p.maxConnections)
p.availableConnections = make([]*Conn, 0, p.maxConnections)
p.preparedStatements = make(map[string]*PreparedStatement)
@ -94,6 +87,7 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) {
}
p.allConnections = append(p.allConnections, c)
p.availableConnections = append(p.availableConnections, c)
p.connInfo = c.ConnInfo.DeepCopy()
return
}
@ -161,7 +155,7 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
}
// All connections are in use and we cannot create more
if p.logLevel >= LogLevelWarn {
p.logger.Warn("All connections in pool are busy - waiting...")
p.logger.Log(LogLevelWarn, "waiting for available connection", nil)
}
// Wait until there is an available connection OR room to create a new connection
@ -181,7 +175,11 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
// Release gives up use of a connection.
func (p *ConnPool) Release(conn *Conn) {
if conn.TxStatus != 'I' {
if conn.ctxInProgress {
panic("should never release when context is in progress")
}
if conn.txStatus != 'I' {
conn.Exec("rollback")
}
@ -223,25 +221,21 @@ func (p *ConnPool) removeFromAllConnections(conn *Conn) bool {
return false
}
// Close ends the use of a connection pool. It prevents any new connections
// from being acquired, waits until all acquired connections are released,
// then closes all underlying connections.
// Close ends the use of a connection pool. It prevents any new connections from
// being acquired and closes available underlying connections. Any acquired
// connections will be closed when they are released.
func (p *ConnPool) Close() {
p.cond.L.Lock()
defer p.cond.L.Unlock()
p.closed = true
// Wait until all connections are released
if len(p.availableConnections) != len(p.allConnections) {
for len(p.availableConnections) != len(p.allConnections) {
p.cond.Wait()
}
}
for _, c := range p.allConnections {
for _, c := range p.availableConnections {
_ = c.Close()
}
// This will cause any checked out connections to be closed on release
p.resetCount++
}
// Reset closes all open connections, but leaves the pool open. It is intended
@ -289,7 +283,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) {
}
func (p *ConnPool) createConnection() (*Conn, error) {
c, err := connect(p.config, p.pgTypes, p.pgsqlAfInet, p.pgsqlAfInet6)
c, err := connect(p.config, p.connInfo)
if err != nil {
return nil, err
}
@ -324,10 +318,6 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) {
// afterConnectionCreated executes (if it is) afterConnect() callback and prepares
// all the known statements for the new connection.
func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
p.pgTypes = c.PgTypes
p.pgsqlAfInet = c.pgsqlAfInet
p.pgsqlAfInet6 = c.pgsqlAfInet6
if p.afterConnect != nil {
err := p.afterConnect(c)
if err != nil {
@ -357,6 +347,16 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman
return c.Exec(sql, arguments...)
}
// ExecEx acquires a connection from the pool, delegates the call to that
// connection's ExecEx, and releases the connection when the call returns.
func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
	var c *Conn
	if c, err = p.Acquire(); err != nil {
		return
	}
	// Safe to release on return: ExecEx does not hand the connection or any
	// live result object back to the caller.
	defer p.Release(c)
	return c.ExecEx(ctx, sql, options, arguments...)
}
// Query acquires a connection and delegates the call to that connection. When
// *Rows are closed, the connection is released automatically.
func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
@ -372,7 +372,25 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
return rows, err
}
rows.AfterClose(p.rowsAfterClose)
rows.connPool = p
return rows, nil
}
// QueryEx acquires a connection and delegates the call to that connection's
// QueryEx. The connection is released when the returned *Rows are closed
// (via the connPool reference stored on the rows). If acquiring a connection
// fails, a closed *Rows carrying the error is returned alongside the error so
// callers that defer error checking to Rows.Err still work.
func (p *ConnPool) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) {
	c, err := p.Acquire()
	if err != nil {
		// Because checking for errors can be deferred to the *Rows, build one with the error
		return &Rows{closed: true, err: err}, err
	}
	rows, err := c.QueryEx(ctx, sql, options, args...)
	if err != nil {
		// Query failed before producing usable rows; return the connection now.
		p.Release(c)
		return rows, err
	}
	rows.connPool = p
	return rows, nil
}
@ -385,10 +403,15 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row {
return (*Row)(rows)
}
// QueryRowEx acquires a connection and delegates the call to that connection's
// QueryRowEx. The QueryEx error (if any) is intentionally discarded here: it
// is carried inside the returned *Rows and surfaced when the *Row is scanned.
func (p *ConnPool) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row {
	rows, _ := p.QueryEx(ctx, sql, options, args...)
	return (*Row)(rows)
}
// Begin acquires a connection and begins a transaction on it. When the
// transaction is closed the connection will be automatically released.
func (p *ConnPool) Begin() (*Tx, error) {
return p.BeginIso("")
return p.BeginEx(context.Background(), nil)
}
// Prepare creates a prepared statement on a connection in the pool to test the
@ -403,7 +426,7 @@ func (p *ConnPool) Begin() (*Tx, error) {
// the same name and sql arguments. This allows a code path to Prepare and
// Query/Exec/PrepareEx without concern for if the statement has already been prepared.
func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) {
return p.PrepareEx(name, sql, nil)
return p.PrepareEx(context.Background(), name, sql, nil)
}
// PrepareEx creates a prepared statement on a connection in the pool to test the
@ -417,7 +440,7 @@ func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) {
// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec/Prepare without
// concern for if the statement has already been prepared.
func (p *ConnPool) PrepareEx(name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) {
func (p *ConnPool) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) {
p.cond.L.Lock()
defer p.cond.L.Unlock()
@ -439,13 +462,13 @@ func (p *ConnPool) PrepareEx(name, sql string, opts *PrepareExOptions) (*Prepare
return ps, nil
}
ps, err := c.PrepareEx(name, sql, opts)
ps, err := c.PrepareEx(ctx, name, sql, opts)
if err != nil {
return nil, err
}
for _, c := range p.availableConnections {
_, err := c.PrepareEx(name, sql, opts)
_, err := c.PrepareEx(ctx, name, sql, opts)
if err != nil {
return nil, err
}
@ -474,17 +497,17 @@ func (p *ConnPool) Deallocate(name string) (err error) {
return nil
}
// BeginIso acquires a connection and begins a transaction in isolation mode iso
// on it. When the transaction is closed the connection will be automatically
// released.
func (p *ConnPool) BeginIso(iso string) (*Tx, error) {
// BeginEx acquires a connection and starts a transaction with txOptions
// determining the transaction mode. When the transaction is closed the
// connection will be automatically released.
func (p *ConnPool) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) {
for {
c, err := p.Acquire()
if err != nil {
return nil, err
}
tx, err := c.BeginIso(iso)
tx, err := c.BeginEx(ctx, txOptions)
if err != nil {
alive := c.IsAlive()
p.Release(c)
@ -499,25 +522,12 @@ func (p *ConnPool) BeginIso(iso string) (*Tx, error) {
continue
}
tx.AfterClose(p.txAfterClose)
tx.connPool = p
return tx, nil
}
}
// Deprecated. Use CopyFrom instead. CopyTo acquires a connection, delegates the
// call to that connection, and releases the connection.
func (p *ConnPool) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) {
c, err := p.Acquire()
if err != nil {
return 0, err
}
defer p.Release(c)
return c.CopyTo(tableName, columnNames, rowSrc)
}
// CopyFrom acquires a connection, delegates the call to that connection, and
// releases the connection.
// CopyFrom acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
c, err := p.Acquire()
if err != nil {
@ -527,3 +537,10 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C
return c.CopyFrom(tableName, columnNames, rowSrc)
}
// BeginBatch acquires a connection and begins a batch on that connection. When
// *Batch is finished, the connection is released automatically.
// NOTE(review): if Acquire fails, conn is nil and the error is stored on the
// returned *Batch — presumably surfaced by subsequent Batch method calls;
// confirm against Batch's implementation.
func (p *ConnPool) BeginBatch() *Batch {
	c, err := p.Acquire()
	return &Batch{conn: c, connPool: p, err: err}
}

View File

@ -1,13 +1,15 @@
package pgx_test
import (
"errors"
"context"
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/pkg/errors"
"github.com/jackc/pgx"
)
@ -297,7 +299,7 @@ func TestPoolWithoutAcquireTimeoutSet(t *testing.T) {
// ... then try to consume 1 more. It should hang forever.
// To unblock it we release the previously taken connection in a goroutine.
stopDeadWaitTimeout := 5 * time.Second
timer := time.AfterFunc(stopDeadWaitTimeout, func() {
timer := time.AfterFunc(stopDeadWaitTimeout+100*time.Millisecond, func() {
releaseAllConnections(pool, allConnections)
})
defer timer.Stop()
@ -328,14 +330,14 @@ func TestPoolReleaseWithTransactions(t *testing.T) {
t.Fatal("Did not receive expected error")
}
if conn.TxStatus != 'E' {
t.Fatalf("Expected TxStatus to be 'E', instead it was '%c'", conn.TxStatus)
if conn.TxStatus() != 'E' {
t.Fatalf("Expected TxStatus to be 'E', instead it was '%c'", conn.TxStatus())
}
pool.Release(conn)
if conn.TxStatus != 'I' {
t.Fatalf("Expected release to rollback errored transaction, but it did not: '%c'", conn.TxStatus)
if conn.TxStatus() != 'I' {
t.Fatalf("Expected release to rollback errored transaction, but it did not: '%c'", conn.TxStatus())
}
conn, err = pool.Acquire()
@ -343,14 +345,14 @@ func TestPoolReleaseWithTransactions(t *testing.T) {
t.Fatalf("Unable to acquire connection: %v", err)
}
mustExec(t, conn, "begin")
if conn.TxStatus != 'T' {
t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.TxStatus)
if conn.TxStatus() != 'T' {
t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.TxStatus())
}
pool.Release(conn)
if conn.TxStatus != 'I' {
t.Fatalf("Expected release to rollback uncommitted transaction, but it did not: '%c'", conn.TxStatus)
if conn.TxStatus() != 'I' {
t.Fatalf("Expected release to rollback uncommitted transaction, but it did not: '%c'", conn.TxStatus())
}
}
@ -428,7 +430,7 @@ func TestPoolReleaseDiscardsDeadConnections(t *testing.T) {
}
}()
if _, err = c2.Exec("select pg_terminate_backend($1)", c1.Pid); err != nil {
if _, err = c2.Exec("select pg_terminate_backend($1)", c1.PID()); err != nil {
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
}
@ -635,9 +637,9 @@ func TestConnPoolTransactionIso(t *testing.T) {
pool := createConnPool(t, 2)
defer pool.Close()
tx, err := pool.BeginIso(pgx.Serializable)
tx, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
if err != nil {
t.Fatalf("pool.Begin failed: %v", err)
t.Fatalf("pool.BeginEx failed: %v", err)
}
defer tx.Rollback()
@ -674,7 +676,7 @@ func TestConnPoolBeginRetry(t *testing.T) {
pool.Release(victimConn)
// Terminate connection that was released to pool
if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.Pid); err != nil {
if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.PID()); err != nil {
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
}
@ -686,13 +688,13 @@ func TestConnPoolBeginRetry(t *testing.T) {
}
defer tx.Rollback()
var txPid int32
err = tx.QueryRow("select pg_backend_pid()").Scan(&txPid)
var txPID uint32
err = tx.QueryRow("select pg_backend_pid()").Scan(&txPID)
if err != nil {
t.Fatalf("tx.QueryRow Scan failed: %v", err)
}
if txPid == victimConn.Pid {
t.Error("Expected txPid to defer from killed conn pid, but it didn't")
if txPID == victimConn.PID() {
t.Error("Expected txPID to defer from killed conn pid, but it didn't")
}
}()
}
@ -980,3 +982,70 @@ func TestConnPoolPrepareWhenConnIsAlreadyAcquired(t *testing.T) {
t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err)
}
}
// TestConnPoolBeginBatch verifies that a batch started from a pool queues
// multiple queries, sends them in one round trip, and yields each result set
// in order. The result-consumption code was duplicated verbatim for each
// queued query; it is folded into a loop.
func TestConnPoolBeginBatch(t *testing.T) {
	t.Parallel()

	pool := createConnPool(t, 2)
	defer pool.Close()

	const queryCount = 2

	batch := pool.BeginBatch()
	for j := 0; j < queryCount; j++ {
		batch.Queue("select n from generate_series(0,5) n",
			nil,
			nil,
			[]int16{pgx.BinaryFormatCode},
		)
	}

	err := batch.Send(context.Background(), nil)
	if err != nil {
		t.Fatal(err)
	}

	// Each queued query produces one result set; consume them in order.
	for j := 0; j < queryCount; j++ {
		rows, err := batch.QueryResults()
		if err != nil {
			t.Error(err)
		}

		for i := 0; rows.Next(); i++ {
			var n int
			if err := rows.Scan(&n); err != nil {
				t.Error(err)
			}
			if n != i {
				t.Errorf("n => %v, want %v", n, i)
			}
		}

		if rows.Err() != nil {
			t.Error(rows.Err())
		}
	}

	err = batch.Close()
	if err != nil {
		t.Fatal(err)
	}
}

View File

@ -1,6 +1,7 @@
package pgx_test
import (
"context"
"crypto/tls"
"fmt"
"net"
@ -13,6 +14,7 @@ import (
"time"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgtype"
)
func TestConnect(t *testing.T) {
@ -27,14 +29,10 @@ func TestConnect(t *testing.T) {
t.Error("Runtime parameters not stored")
}
if conn.Pid == 0 {
if conn.PID() == 0 {
t.Error("Backend PID not stored")
}
if conn.SecretKey == 0 {
t.Error("Backend secret key not stored")
}
var currentDB string
err = conn.QueryRow("select current_database()").Scan(&currentDB)
if err != nil {
@ -752,7 +750,7 @@ func TestParseConnectionString(t *testing.T) {
}
func TestParseEnvLibpq(t *testing.T) {
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME"}
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE"}
savedEnv := make(map[string]string)
for _, n := range pgEnvvars {
@ -1035,6 +1033,169 @@ func TestExecFailure(t *testing.T) {
}
}
// TestExecExContextWithoutCancelation runs ExecEx under a live context that
// is never canceled before the statement completes.
func TestExecExContextWithoutCancelation(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	switch commandTag, err := conn.ExecEx(ctx, "create temporary table foo(id integer primary key);", nil); {
	case err != nil:
		t.Fatal(err)
	case commandTag != "CREATE TABLE":
		t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
	}
}
// TestExecExContextFailureWithoutCancelation checks that a SQL syntax error
// from ExecEx does not leave the connection in a broken state.
func TestExecExContextFailureWithoutCancelation(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil {
		t.Fatal("Expected SQL syntax error")
	}

	// A follow-up query must still succeed on the same connection.
	rows, _ := conn.Query("select 1")
	rows.Close()
	if err := rows.Err(); err != nil {
		t.Fatalf("ExecEx failure appears to have broken connection: %v", err)
	}
}
// TestExecExContextCancelationCancelsQuery verifies that canceling the
// context aborts a long-running statement with context.Canceled and leaves
// the connection usable.
func TestExecExContextCancelationCancelsQuery(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	ctx, cancel := context.WithCancel(context.Background())
	// Cancel mid-statement; pg_sleep(60) far outlasts the 500ms delay.
	go func() {
		time.Sleep(500 * time.Millisecond)
		cancel()
	}()

	if _, err := conn.ExecEx(ctx, "select pg_sleep(60)", nil); err != context.Canceled {
		t.Fatalf("Expected context.Canceled err, got %v", err)
	}

	ensureConnValid(t, conn)
}
// TestExecExExtendedProtocol runs DDL and a parameterized insert through
// ExecEx with default options (extended protocol) and checks command tags.
func TestExecExExtendedProtocol(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	switch commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil); {
	case err != nil:
		t.Fatal(err)
	case commandTag != "CREATE TABLE":
		t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
	}

	switch commandTag, err := conn.ExecEx(
		ctx,
		"insert into foo(name) values($1);",
		nil,
		"bar",
	); {
	case err != nil:
		t.Fatal(err)
	case commandTag != "INSERT 0 1":
		t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
	}

	ensureConnValid(t, conn)
}
// TestExecExSimpleProtocol runs ExecEx with SimpleProtocol enabled, using an
// argument containing a SQL-injection attempt to exercise value sanitization.
func TestExecExSimpleProtocol(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	switch commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil); {
	case err != nil:
		t.Fatal(err)
	case commandTag != "CREATE TABLE":
		t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
	}

	switch commandTag, err := conn.ExecEx(
		ctx,
		"insert into foo(name) values($1);",
		&pgx.QueryExOptions{SimpleProtocol: true},
		"bar'; drop table foo;--",
	); {
	case err != nil:
		t.Fatal(err)
	case commandTag != "INSERT 0 1":
		t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
	}
}
// TestConnExecExSuppliedCorrectParameterOIDs checks that ExecEx succeeds when
// the caller supplies a parameter OID matching the argument's type.
func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	mustExec(t, conn, "create temporary table foo(name varchar primary key);")

	switch commandTag, err := conn.ExecEx(
		context.Background(),
		"insert into foo(name) values($1);",
		&pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.VarcharOID}},
		"bar'; drop table foo;--",
	); {
	case err != nil:
		t.Fatal(err)
	case commandTag != "INSERT 0 1":
		t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
	}
}
// TestConnExecExSuppliedIncorrectParameterOIDs checks that ExecEx fails when
// the caller-supplied parameter OID does not match the argument's value.
func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	mustExec(t, conn, "create temporary table foo(name varchar primary key);")

	if _, err := conn.ExecEx(
		context.Background(),
		"insert into foo(name) values($1);",
		&pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.Int4OID}},
		"bar'; drop table foo;--",
	); err == nil {
		t.Fatal("expected error but got none")
	}
}
func TestPrepare(t *testing.T) {
t.Parallel()
@ -1206,7 +1367,7 @@ func TestPrepareEx(t *testing.T) {
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
_, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgx.Oid{pgx.TextOid}})
_, err := conn.PrepareEx(context.Background(), "test", "select $1", &pgx.PrepareExOptions{ParameterOIDs: []pgtype.OID{pgtype.TextOID}})
if err != nil {
t.Errorf("Unable to prepare statement: %v", err)
return
@ -1244,7 +1405,7 @@ func TestListenNotify(t *testing.T) {
mustExec(t, notifier, "notify chat")
// when notification is waiting on the socket to be read
notification, err := listener.WaitForNotification(time.Second)
notification, err := listener.WaitForNotification(context.Background())
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
@ -1259,7 +1420,10 @@ func TestListenNotify(t *testing.T) {
if rows.Err() != nil {
t.Fatalf("Unexpected error on Query: %v", rows.Err())
}
notification, err = listener.WaitForNotification(0)
ctx, cancelFn := context.WithCancel(context.Background())
cancelFn()
notification, err = listener.WaitForNotification(ctx)
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
@ -1268,8 +1432,9 @@ func TestListenNotify(t *testing.T) {
}
// when timeout occurs
notification, err = listener.WaitForNotification(time.Millisecond)
if err != pgx.ErrNotificationTimeout {
ctx, _ = context.WithTimeout(context.Background(), time.Millisecond)
notification, err = listener.WaitForNotification(ctx)
if err != context.DeadlineExceeded {
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
}
if notification != nil {
@ -1278,7 +1443,7 @@ func TestListenNotify(t *testing.T) {
// listener can listen again after a timeout
mustExec(t, notifier, "notify chat")
notification, err = listener.WaitForNotification(time.Second)
notification, err = listener.WaitForNotification(context.Background())
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
@ -1303,7 +1468,7 @@ func TestUnlistenSpecificChannel(t *testing.T) {
mustExec(t, notifier, "notify unlisten_test")
// when notification is waiting on the socket to be read
notification, err := listener.WaitForNotification(time.Second)
notification, err := listener.WaitForNotification(context.Background())
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
@ -1323,8 +1488,10 @@ func TestUnlistenSpecificChannel(t *testing.T) {
if rows.Err() != nil {
t.Fatalf("Unexpected error on Query: %v", rows.Err())
}
notification, err = listener.WaitForNotification(100 * time.Millisecond)
if err != pgx.ErrNotificationTimeout {
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
notification, err = listener.WaitForNotification(ctx)
if err != context.DeadlineExceeded {
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
}
}
@ -1376,13 +1543,9 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
}
}()
notifierDone := make(chan bool)
go func() {
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
defer func() {
notifierDone <- true
}()
for i := 0; i < 100000; i++ {
mustExec(t, conn, "notify busysafe, 'hello'")
@ -1406,7 +1569,8 @@ func TestListenNotifySelfNotification(t *testing.T) {
// Notify self and WaitForNotification immediately
mustExec(t, conn, "notify self")
notification, err := conn.WaitForNotification(time.Second)
ctx, _ := context.WithTimeout(context.Background(), time.Second)
notification, err := conn.WaitForNotification(ctx)
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
@ -1423,7 +1587,8 @@ func TestListenNotifySelfNotification(t *testing.T) {
t.Fatalf("Unexpected error on Query: %v", rows.Err())
}
notification, err = conn.WaitForNotification(time.Second)
ctx, _ = context.WithTimeout(context.Background(), time.Second)
notification, err = conn.WaitForNotification(ctx)
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
@ -1474,7 +1639,7 @@ func TestFatalRxError(t *testing.T) {
}
defer otherConn.Close()
if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.Pid); err != nil {
if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.PID()); err != nil {
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
}
@ -1500,7 +1665,7 @@ func TestFatalTxError(t *testing.T) {
}
defer otherConn.Close()
_, err = otherConn.Exec("select pg_terminate_backend($1)", conn.Pid)
_, err = otherConn.Exec("select pg_terminate_backend($1)", conn.PID())
if err != nil {
t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
}
@ -1611,26 +1776,17 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) {
}
type testLog struct {
lvl int
lvl pgx.LogLevel
msg string
ctx []interface{}
data map[string]interface{}
}
type testLogger struct {
logs []testLog
}
func (l *testLogger) Debug(msg string, ctx ...interface{}) {
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelDebug, msg: msg, ctx: ctx})
}
func (l *testLogger) Info(msg string, ctx ...interface{}) {
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelInfo, msg: msg, ctx: ctx})
}
func (l *testLogger) Warn(msg string, ctx ...interface{}) {
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelWarn, msg: msg, ctx: ctx})
}
func (l *testLogger) Error(msg string, ctx ...interface{}) {
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelError, msg: msg, ctx: ctx})
func (l *testLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data})
}
func TestSetLogger(t *testing.T) {
@ -1742,3 +1898,30 @@ func TestIdentifierSanitize(t *testing.T) {
}
}
}
func TestConnOnNotice(t *testing.T) {
t.Parallel()
var msg string
connConfig := *defaultConnConfig
connConfig.OnNotice = func(c *pgx.Conn, notice *pgx.Notice) {
msg = notice.Message
}
conn := mustConnect(t, connConfig)
defer closeConn(t, conn)
_, err := conn.Exec(`do $$
begin
raise notice 'hello, world';
end$$;`)
if err != nil {
t.Fatal(err)
}
if msg != "hello, world" {
t.Errorf("msg => %v, want %v", msg, "hello, world")
}
ensureConnValid(t, conn)
}

View File

@ -3,6 +3,10 @@ package pgx
import (
"bytes"
"fmt"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
"github.com/pkg/errors"
)
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
@ -54,25 +58,25 @@ type copyFrom struct {
func (ct *copyFrom) readUntilReadyForQuery() {
for {
t, r, err := ct.conn.rxMsg()
msg, err := ct.conn.rxMsg()
if err != nil {
ct.readerErrChan <- err
close(ct.readerErrChan)
return
}
switch t {
case readyForQuery:
ct.conn.rxReadyForQuery(r)
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
ct.conn.rxReadyForQuery(msg)
close(ct.readerErrChan)
return
case commandComplete:
case errorResponse:
ct.readerErrChan <- ct.conn.rxErrorResponse(r)
case *pgproto3.CommandComplete:
case *pgproto3.ErrorResponse:
ct.readerErrChan <- ct.conn.rxErrorResponse(msg)
default:
err = ct.conn.processContextFreeMsg(t, r)
err = ct.conn.processContextFreeMsg(msg)
if err != nil {
ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r)
ct.readerErrChan <- ct.conn.processContextFreeMsg(msg)
}
}
}
@ -87,14 +91,14 @@ func (ct *copyFrom) waitForReaderDone() error {
func (ct *copyFrom) run() (int, error) {
quotedTableName := ct.tableName.Sanitize()
buf := &bytes.Buffer{}
cbuf := &bytes.Buffer{}
for i, cn := range ct.columnNames {
if i != 0 {
buf.WriteString(", ")
cbuf.WriteString(", ")
}
buf.WriteString(quoteIdentifier(cn))
cbuf.WriteString(quoteIdentifier(cn))
}
quotedColumnNames := buf.String()
quotedColumnNames := cbuf.String()
ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
if err != nil {
@ -114,11 +118,14 @@ func (ct *copyFrom) run() (int, error) {
go ct.readUntilReadyForQuery()
defer ct.waitForReaderDone()
wbuf := newWriteBuf(ct.conn, copyData)
buf := ct.conn.wbuf
buf = append(buf, copyData)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000"))
wbuf.WriteInt32(0)
wbuf.WriteInt32(0)
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
buf = pgio.AppendInt32(buf, 0)
buf = pgio.AppendInt32(buf, 0)
var sentCount int
@ -129,18 +136,16 @@ func (ct *copyFrom) run() (int, error) {
default:
}
if len(wbuf.buf) > 65536 {
wbuf.closeMsg()
_, err = ct.conn.conn.Write(wbuf.buf)
if len(buf) > 65536 {
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err = ct.conn.conn.Write(buf)
if err != nil {
ct.conn.die(err)
return 0, err
}
// Directly manipulate wbuf to reset to reuse the same buffer
wbuf.buf = wbuf.buf[0:5]
wbuf.buf[0] = copyData
wbuf.sizeIdx = 1
buf = buf[0:5]
}
sentCount++
@ -152,12 +157,12 @@ func (ct *copyFrom) run() (int, error) {
}
if len(values) != len(ct.columnNames) {
ct.cancelCopyIn()
return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
return 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
}
wbuf.WriteInt16(int16(len(ct.columnNames)))
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
for i, val := range values {
err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val)
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
if err != nil {
ct.cancelCopyIn()
return 0, err
@ -171,11 +176,13 @@ func (ct *copyFrom) run() (int, error) {
return 0, ct.rowSrc.Err()
}
wbuf.WriteInt16(-1) // terminate the copy stream
buf = pgio.AppendInt16(buf, -1) // terminate the copy stream
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
wbuf.startMsg(copyDone)
wbuf.closeMsg()
_, err = ct.conn.conn.Write(wbuf.buf)
buf = append(buf, copyDone)
buf = pgio.AppendInt32(buf, 4)
_, err = ct.conn.conn.Write(buf)
if err != nil {
ct.conn.die(err)
return 0, err
@ -190,18 +197,16 @@ func (ct *copyFrom) run() (int, error) {
func (c *Conn) readUntilCopyInResponse() error {
for {
var t byte
var r *msgReader
t, r, err := c.rxMsg()
msg, err := c.rxMsg()
if err != nil {
return err
}
switch t {
case copyInResponse:
switch msg := msg.(type) {
case *pgproto3.CopyInResponse:
return nil
default:
err = c.processContextFreeMsg(t, r)
err = c.processContextFreeMsg(msg)
if err != nil {
return err
}
@ -210,10 +215,15 @@ func (c *Conn) readUntilCopyInResponse() error {
}
func (ct *copyFrom) cancelCopyIn() error {
wbuf := newWriteBuf(ct.conn, copyFail)
wbuf.WriteCString("client error: abort")
wbuf.closeMsg()
_, err := ct.conn.conn.Write(wbuf.buf)
buf := ct.conn.wbuf
buf = append(buf, copyFail)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, "client error: abort"...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err := ct.conn.conn.Write(buf)
if err != nil {
ct.conn.die(err)
return err

View File

@ -1,12 +1,12 @@
package pgx_test
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/jackc/pgx"
"github.com/pkg/errors"
)
func TestConnCopyFromSmall(t *testing.T) {
@ -26,7 +26,7 @@ func TestConnCopyFromSmall(t *testing.T) {
)`)
inputRows := [][]interface{}{
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)},
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)},
{nil, nil, nil, nil, nil, nil, nil},
}
@ -83,7 +83,7 @@ func TestConnCopyFromLarge(t *testing.T) {
inputRows := [][]interface{}{}
for i := 0; i < 10000; i++ {
inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}})
inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}})
}
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
@ -125,8 +125,8 @@ func TestConnCopyFromJSON(t *testing.T) {
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} {
if _, ok := conn.PgTypes[oid]; !ok {
for _, typeName := range []string{"json", "jsonb"} {
if _, ok := conn.ConnInfo.DataTypeForName(typeName); !ok {
return // No JSON/JSONB type -- must be running against old PostgreSQL
}
}
@ -174,6 +174,28 @@ func TestConnCopyFromJSON(t *testing.T) {
ensureConnValid(t, conn)
}
type clientFailSource struct {
count int
err error
}
func (cfs *clientFailSource) Next() bool {
cfs.count++
return cfs.count < 100
}
func (cfs *clientFailSource) Values() ([]interface{}, error) {
if cfs.count == 3 {
cfs.err = errors.Errorf("client error")
return nil, cfs.err
}
return []interface{}{make([]byte, 100000)}, nil
}
func (cfs *clientFailSource) Err() error {
return cfs.err
}
func TestConnCopyFromFailServerSideMidway(t *testing.T) {
t.Parallel()
@ -302,28 +324,6 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
ensureConnValid(t, conn)
}
type clientFailSource struct {
count int
err error
}
func (cfs *clientFailSource) Next() bool {
cfs.count++
return cfs.count < 100
}
func (cfs *clientFailSource) Values() ([]interface{}, error) {
if cfs.count == 3 {
cfs.err = fmt.Errorf("client error")
return nil, cfs.err
}
return []interface{}{make([]byte, 100000)}, nil
}
func (cfs *clientFailSource) Err() error {
return cfs.err
}
func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
t.Parallel()
@ -381,7 +381,7 @@ func (cfs *clientFinalErrSource) Values() ([]interface{}, error) {
}
func (cfs *clientFinalErrSource) Err() error {
return fmt.Errorf("final error")
return errors.Errorf("final error")
}
func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {

View File

@ -1,222 +0,0 @@
package pgx
import (
"bytes"
"fmt"
)
// Deprecated. Use CopyFromRows instead. CopyToRows returns a CopyToSource
// interface over the provided rows slice making it usable by *Conn.CopyTo.
func CopyToRows(rows [][]interface{}) CopyToSource {
return &copyToRows{rows: rows, idx: -1}
}
type copyToRows struct {
rows [][]interface{}
idx int
}
func (ctr *copyToRows) Next() bool {
ctr.idx++
return ctr.idx < len(ctr.rows)
}
func (ctr *copyToRows) Values() ([]interface{}, error) {
return ctr.rows[ctr.idx], nil
}
func (ctr *copyToRows) Err() error {
return nil
}
// Deprecated. Use CopyFromSource instead. CopyToSource is the interface used by
// *Conn.CopyTo as the source for copy data.
type CopyToSource interface {
// Next returns true if there is another row and makes the next row data
// available to Values(). When there are no more rows available or an error
// has occurred it returns false.
Next() bool
// Values returns the values for the current row.
Values() ([]interface{}, error)
// Err returns any error that has been encountered by the CopyToSource. If
// this is not nil *Conn.CopyTo will abort the copy.
Err() error
}
type copyTo struct {
conn *Conn
tableName string
columnNames []string
rowSrc CopyToSource
readerErrChan chan error
}
func (ct *copyTo) readUntilReadyForQuery() {
for {
t, r, err := ct.conn.rxMsg()
if err != nil {
ct.readerErrChan <- err
close(ct.readerErrChan)
return
}
switch t {
case readyForQuery:
ct.conn.rxReadyForQuery(r)
close(ct.readerErrChan)
return
case commandComplete:
case errorResponse:
ct.readerErrChan <- ct.conn.rxErrorResponse(r)
default:
err = ct.conn.processContextFreeMsg(t, r)
if err != nil {
ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r)
}
}
}
}
func (ct *copyTo) waitForReaderDone() error {
var err error
for err = range ct.readerErrChan {
}
return err
}
func (ct *copyTo) run() (int, error) {
quotedTableName := quoteIdentifier(ct.tableName)
buf := &bytes.Buffer{}
for i, cn := range ct.columnNames {
if i != 0 {
buf.WriteString(", ")
}
buf.WriteString(quoteIdentifier(cn))
}
quotedColumnNames := buf.String()
ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
if err != nil {
return 0, err
}
err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
if err != nil {
return 0, err
}
err = ct.conn.readUntilCopyInResponse()
if err != nil {
return 0, err
}
go ct.readUntilReadyForQuery()
defer ct.waitForReaderDone()
wbuf := newWriteBuf(ct.conn, copyData)
wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000"))
wbuf.WriteInt32(0)
wbuf.WriteInt32(0)
var sentCount int
for ct.rowSrc.Next() {
select {
case err = <-ct.readerErrChan:
return 0, err
default:
}
if len(wbuf.buf) > 65536 {
wbuf.closeMsg()
_, err = ct.conn.conn.Write(wbuf.buf)
if err != nil {
ct.conn.die(err)
return 0, err
}
// Directly manipulate wbuf to reset to reuse the same buffer
wbuf.buf = wbuf.buf[0:5]
wbuf.buf[0] = copyData
wbuf.sizeIdx = 1
}
sentCount++
values, err := ct.rowSrc.Values()
if err != nil {
ct.cancelCopyIn()
return 0, err
}
if len(values) != len(ct.columnNames) {
ct.cancelCopyIn()
return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
}
wbuf.WriteInt16(int16(len(ct.columnNames)))
for i, val := range values {
err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val)
if err != nil {
ct.cancelCopyIn()
return 0, err
}
}
}
if ct.rowSrc.Err() != nil {
ct.cancelCopyIn()
return 0, ct.rowSrc.Err()
}
wbuf.WriteInt16(-1) // terminate the copy stream
wbuf.startMsg(copyDone)
wbuf.closeMsg()
_, err = ct.conn.conn.Write(wbuf.buf)
if err != nil {
ct.conn.die(err)
return 0, err
}
err = ct.waitForReaderDone()
if err != nil {
return 0, err
}
return sentCount, nil
}
func (ct *copyTo) cancelCopyIn() error {
wbuf := newWriteBuf(ct.conn, copyFail)
wbuf.WriteCString("client error: abort")
wbuf.closeMsg()
_, err := ct.conn.conn.Write(wbuf.buf)
if err != nil {
ct.conn.die(err)
return err
}
return nil
}
// Deprecated. Use CopyFrom instead. CopyTo uses the PostgreSQL copy protocol to
// perform bulk data insertion. It returns the number of rows copied and an
// error.
//
// CopyTo requires all values use the binary format. Almost all types
// implemented by pgx use the binary format by default. Types implementing
// Encoder can only be used if they encode to the binary format.
func (c *Conn) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) {
ct := &copyTo{
conn: c,
tableName: tableName,
columnNames: columnNames,
rowSrc: rowSrc,
readerErrChan: make(chan error),
}
return ct.run()
}

View File

@ -1,367 +0,0 @@
package pgx_test
import (
"reflect"
"testing"
"time"
"github.com/jackc/pgx"
)
func TestConnCopyToSmall(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int2,
b int4,
c int8,
d varchar,
e text,
f date,
g timestamptz
)`)
inputRows := [][]interface{}{
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)},
{nil, nil, nil, nil, nil, nil, nil},
}
copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyToRows(inputRows))
if err != nil {
t.Errorf("Unexpected error for CopyTo: %v", err)
}
if copyCount != len(inputRows) {
t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err := conn.Query("select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
var outputRows [][]interface{}
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
if !reflect.DeepEqual(inputRows, outputRows) {
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
}
ensureConnValid(t, conn)
}
func TestConnCopyToLarge(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int2,
b int4,
c int8,
d varchar,
e text,
f date,
g timestamptz,
h bytea
)`)
inputRows := [][]interface{}{}
for i := 0; i < 10000; i++ {
inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}})
}
copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyToRows(inputRows))
if err != nil {
t.Errorf("Unexpected error for CopyTo: %v", err)
}
if copyCount != len(inputRows) {
t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err := conn.Query("select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
var outputRows [][]interface{}
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
if !reflect.DeepEqual(inputRows, outputRows) {
t.Errorf("Input rows and output rows do not equal")
}
ensureConnValid(t, conn)
}
func TestConnCopyToJSON(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} {
if _, ok := conn.PgTypes[oid]; !ok {
return // No JSON/JSONB type -- must be running against old PostgreSQL
}
}
mustExec(t, conn, `create temporary table foo(
a json,
b jsonb
)`)
inputRows := [][]interface{}{
{map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}},
{nil, nil},
}
copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows))
if err != nil {
t.Errorf("Unexpected error for CopyTo: %v", err)
}
if copyCount != len(inputRows) {
t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err := conn.Query("select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
var outputRows [][]interface{}
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
if !reflect.DeepEqual(inputRows, outputRows) {
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
}
ensureConnValid(t, conn)
}
func TestConnCopyToFailServerSideMidway(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int4,
b varchar not null
)`)
inputRows := [][]interface{}{
{int32(1), "abc"},
{int32(2), nil}, // this row should trigger a failure
{int32(3), "def"},
}
copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows))
if err == nil {
t.Errorf("Expected CopyTo return error, but it did not")
}
if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err)
}
if copyCount != 0 {
t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
}
rows, err := conn.Query("select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
var outputRows [][]interface{}
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
if len(outputRows) != 0 {
t.Errorf("Expected 0 rows, but got %v", outputRows)
}
ensureConnValid(t, conn)
}
func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a bytea not null
)`)
startTime := time.Now()
copyCount, err := conn.CopyTo("foo", []string{"a"}, &failSource{})
if err == nil {
t.Errorf("Expected CopyTo return error, but it did not")
}
if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err)
}
if copyCount != 0 {
t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
}
endTime := time.Now()
copyTime := endTime.Sub(startTime)
if copyTime > time.Second {
t.Errorf("Failing CopyTo shouldn't have taken so long: %v", copyTime)
}
rows, err := conn.Query("select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
var outputRows [][]interface{}
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
if len(outputRows) != 0 {
t.Errorf("Expected 0 rows, but got %v", outputRows)
}
ensureConnValid(t, conn)
}
func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a bytea not null
)`)
copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFailSource{})
if err == nil {
t.Errorf("Expected CopyTo return error, but it did not")
}
if copyCount != 0 {
t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
}
rows, err := conn.Query("select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
var outputRows [][]interface{}
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
if len(outputRows) != 0 {
t.Errorf("Expected 0 rows, but got %v", outputRows)
}
ensureConnValid(t, conn)
}
func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a bytea not null
)`)
copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFinalErrSource{})
if err == nil {
t.Errorf("Expected CopyTo return error, but it did not")
}
if copyCount != 0 {
t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
}
rows, err := conn.Query("select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
var outputRows [][]interface{}
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
if len(outputRows) != 0 {
t.Errorf("Expected 0 rows, but got %v", outputRows)
}
ensureConnValid(t, conn)
}

76
doc.go
View File

@ -62,17 +62,15 @@ Use Exec to execute a query that does not return a result set.
Connection Pool
Connection pool usage is explicit and configurable. In pgx, a connection can
be created and managed directly, or a connection pool with a configurable
maximum connections can be used. Also, the connection pool offers an after
connect hook that allows every connection to be automatically setup before
being made available in the connection pool. This is especially useful to
ensure all connections have the same prepared statements available or to
change any other connection settings.
Connection pool usage is explicit and configurable. In pgx, a connection can be
created and managed directly, or a connection pool with a configurable maximum
connections can be used. The connection pool offers an after connect hook that
allows every connection to be automatically setup before being made available in
the connection pool.
It delegates Query, QueryRow, Exec, and Begin functions to an automatically
checked out and released connection so you can avoid manually acquiring and
releasing connections when you do not need that level of control.
It delegates methods such as QueryRow to an automatically checked out and
released connection so you can avoid manually acquiring and releasing
connections when you do not need that level of control.
var name string
var weight int64
@ -117,11 +115,11 @@ particular:
Null Mapping
pgx can map nulls in two ways. The first is Null* types that have a data field
and a valid field. They work in a similar fashion to database/sql. The second
is to use a pointer to a pointer.
pgx can map nulls in two ways. The first is package pgtype provides types that
have a data field and a status field. They work in a similar fashion to
database/sql. The second is to use a pointer to a pointer.
var foo pgx.NullString
var foo pgtype.Varchar
var bar *string
err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&a, &b)
if err != nil {
@ -133,20 +131,15 @@ Array Mapping
pgx maps between int16, int32, int64, float32, float64, and string Go slices
and the equivalent PostgreSQL array type. Go slices of native types do not
support nulls, so if a PostgreSQL array that contains a null is read into a
native Go slice an error will occur.
Hstore Mapping
pgx includes an Hstore type and a NullHstore type. Hstore is simply a
map[string]string and is preferred when the hstore contains no nulls. NullHstore
follows the Null* pattern and supports null values.
native Go slice an error will occur. The pgtype package includes many more
array types for PostgreSQL types that do not directly map to native Go types.
JSON and JSONB Mapping
pgx includes built-in support to marshal and unmarshal between Go types and
the PostgreSQL JSON and JSONB.
Inet and Cidr Mapping
Inet and CIDR Mapping
pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In
addition, as a convenience pgx will encode from a net.IP; it will assume a /32
@ -155,25 +148,10 @@ netmask for IPv4 and a /128 for IPv6.
Custom Type Support
pgx includes support for the common data types like integers, floats, strings,
dates, and times that have direct mappings between Go and SQL. Support can be
added for additional types like point, hstore, numeric, etc. that do not have
direct mappings in Go by the types implementing ScannerPgx and Encoder.
Custom types can support text or binary formats. Binary format can provide a
large performance increase. The natural place for deciding the format for a
value would be in ScannerPgx as it is responsible for decoding the returned
data. However, that is impossible as the query has already been sent by the time
the ScannerPgx is invoked. The solution to this is the global
DefaultTypeFormats. If a custom type prefers binary format it should register it
there.
pgx.DefaultTypeFormats["point"] = pgx.BinaryFormatCode
Note that the type is referred to by name, not by OID. This is because custom
PostgreSQL types like hstore will have different OIDs on different servers. When
pgx establishes a connection it queries the pg_type table for all types. It then
matches the names in DefaultTypeFormats with the returned OIDs and stores it in
Conn.PgTypes.
dates, and times that have direct mappings between Go and SQL. In addition,
pgx uses the github.com/jackc/pgx/pgtype library to support more types. See
documention for that library for instructions on how to implement custom
types.
See example_custom_type_test.go for an example of a custom type for the
PostgreSQL point type.
@ -184,15 +162,12 @@ and database/sql/driver.Valuer interfaces.
Raw Bytes Mapping
[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified
to PostgreSQL. In like manner, a *[]byte passed to Scan will be filled with
the raw bytes returned by PostgreSQL. This can be especially useful for reading
varchar, text, json, and jsonb values directly into a []byte and avoiding the
type conversion from string.
to PostgreSQL.
Transactions
Transactions are started by calling Begin or BeginIso. The BeginIso variant
creates a transaction with a specified isolation level.
Transactions are started by calling Begin or BeginEx. The BeginEx variant
can create a transaction with a specified isolation level.
tx, err := conn.Begin()
if err != nil {
@ -257,9 +232,8 @@ connection.
Logging
pgx defines a simple logger interface. Connections optionally accept a logger
that satisfies this interface. The log15 package
(http://gopkg.in/inconshreveable/log15.v2) satisfies this interface and it is
simple to define adapters for other loggers. Set LogLevel to control logging
verbosity.
that satisfies this interface. Set LogLevel to control logging verbosity.
Adapters for github.com/inconshreveable/log15, github.com/Sirupsen/logrus, and
the testing log are provided in the log directory.
*/
package pgx

View File

@ -1,81 +1,74 @@
package pgx_test
import (
"errors"
"fmt"
"regexp"
"strconv"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgtype"
"github.com/pkg/errors"
)
var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`)
// NullPoint represents a point that may be null.
//
// If Valid is false then the value is NULL.
type NullPoint struct {
// Point represents a point that may be null.
type Point struct {
X, Y float64 // Coordinates of point
Valid bool // Valid is true if not NULL
Status pgtype.Status
}
func (p *NullPoint) ScanPgx(vr *pgx.ValueReader) error {
if vr.Type().DataTypeName != "point" {
return pgx.SerializationError(fmt.Sprintf("NullPoint.Scan cannot decode %s (OID %d)", vr.Type().DataTypeName, vr.Type().DataType))
}
func (dst *Point) Set(src interface{}) error {
return errors.Errorf("cannot convert %v to Point", src)
}
if vr.Len() == -1 {
p.X, p.Y, p.Valid = 0, 0, false
func (dst *Point) Get() interface{} {
switch dst.Status {
case pgtype.Present:
return dst
case pgtype.Null:
return nil
default:
return dst.Status
}
}
func (src *Point) AssignTo(dst interface{}) error {
return errors.Errorf("cannot assign %v to %T", src, dst)
}
func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
if src == nil {
*dst = Point{Status: pgtype.Null}
return nil
}
switch vr.Type().FormatCode {
case pgx.TextFormatCode:
s := vr.ReadString(vr.Len())
s := string(src)
match := pointRegexp.FindStringSubmatch(s)
if match == nil {
return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s))
return errors.Errorf("Received invalid point: %v", s)
}
var err error
p.X, err = strconv.ParseFloat(match[1], 64)
x, err := strconv.ParseFloat(match[1], 64)
if err != nil {
return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s))
return errors.Errorf("Received invalid point: %v", s)
}
p.Y, err = strconv.ParseFloat(match[2], 64)
y, err := strconv.ParseFloat(match[2], 64)
if err != nil {
return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s))
}
case pgx.BinaryFormatCode:
return errors.New("binary format not implemented")
default:
return fmt.Errorf("unknown format %v", vr.Type().FormatCode)
return errors.Errorf("Received invalid point: %v", s)
}
p.Valid = true
return vr.Err()
}
func (p NullPoint) FormatCode() int16 { return pgx.TextFormatCode }
func (p NullPoint) Encode(w *pgx.WriteBuf, oid pgx.Oid) error {
if !p.Valid {
w.WriteInt32(-1)
return nil
}
s := fmt.Sprintf("(%v,%v)", p.X, p.Y)
w.WriteInt32(int32(len(s)))
w.WriteBytes([]byte(s))
*dst = Point{X: x, Y: y, Status: pgtype.Present}
return nil
}
func (p NullPoint) String() string {
if p.Valid {
return fmt.Sprintf("%v, %v", p.X, p.Y)
}
func (src *Point) String() string {
if src.Status == pgtype.Null {
return "null point"
}
return fmt.Sprintf("%.1f, %.1f", src.X, src.Y)
}
func Example_CustomType() {
@ -85,22 +78,22 @@ func Example_CustomType() {
return
}
var p NullPoint
err = conn.QueryRow("select null::point").Scan(&p)
// Override registered handler for point
conn.ConnInfo.RegisterDataType(pgtype.DataType{
Value: &Point{},
Name: "point",
OID: 600,
})
p := &Point{}
err = conn.QueryRow("select null::point").Scan(p)
if err != nil {
fmt.Println(err)
return
}
fmt.Println(p)
err = conn.QueryRow("select point(1.5,2.5)").Scan(&p)
if err != nil {
fmt.Println(err)
return
}
fmt.Println(p)
err = conn.QueryRow("select $1::point", &NullPoint{X: 0.5, Y: 0.75, Valid: true}).Scan(&p)
err = conn.QueryRow("select point(1.5,2.5)").Scan(p)
if err != nil {
fmt.Println(err)
return
@ -109,5 +102,4 @@ func Example_CustomType() {
// Output:
// null point
// 1.5, 2.5
// 0.5, 0.75
}

View File

@ -2,6 +2,7 @@ package pgx_test
import (
"fmt"
"github.com/jackc/pgx"
)
@ -12,13 +13,6 @@ func Example_JSON() {
return
}
if _, ok := conn.PgTypes[pgx.JsonOid]; !ok {
// No JSON type -- must be running against very old PostgreSQL
// Pretend it works
fmt.Println("John", 42)
return
}
type person struct {
Name string `json:"name"`
Age int `json:"age"`

View File

@ -2,10 +2,11 @@ package main
import (
"bufio"
"context"
"fmt"
"github.com/jackc/pgx"
"os"
"time"
"github.com/jackc/pgx"
)
var pool *pgx.ConnPool
@ -58,16 +59,13 @@ func listen() {
conn.Listen("chat")
for {
notification, err := conn.WaitForNotification(time.Second)
if err == pgx.ErrNotificationTimeout {
continue
}
notification, err := conn.WaitForNotification(context.Background())
if err != nil {
fmt.Fprintln(os.Stderr, "Error waiting for notification:", err)
os.Exit(1)
}
fmt.Println("PID:", notification.Pid, "Channel:", notification.Channel, "Payload:", notification.Payload)
fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload)
}
}

View File

@ -1,11 +1,13 @@
package main
import (
"github.com/jackc/pgx"
log "gopkg.in/inconshreveable/log15.v2"
"io/ioutil"
"net/http"
"os"
"github.com/jackc/pgx"
"github.com/jackc/pgx/log/log15adapter"
log "gopkg.in/inconshreveable/log15.v2"
)
var pool *pgx.ConnPool
@ -89,6 +91,8 @@ func urlHandler(w http.ResponseWriter, req *http.Request) {
}
func main() {
logger := log15adapter.NewLogger(log.New("module", "pgx"))
var err error
connPoolConfig := pgx.ConnPoolConfig{
ConnConfig: pgx.ConnConfig{
@ -96,7 +100,7 @@ func main() {
User: "jack",
Password: "jack",
Database: "url_shortener",
Logger: log.New("module", "pgx"),
Logger: logger,
},
MaxConnections: 5,
AfterConnect: afterConnect,

View File

@ -2,29 +2,33 @@ package pgx
import (
"encoding/binary"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
func newFastpath(cn *Conn) *fastpath {
return &fastpath{cn: cn, fns: make(map[string]Oid)}
return &fastpath{cn: cn, fns: make(map[string]pgtype.OID)}
}
type fastpath struct {
cn *Conn
fns map[string]Oid
fns map[string]pgtype.OID
}
func (f *fastpath) functionOID(name string) Oid {
func (f *fastpath) functionOID(name string) pgtype.OID {
return f.fns[name]
}
func (f *fastpath) addFunction(name string, oid Oid) {
func (f *fastpath) addFunction(name string, oid pgtype.OID) {
f.fns[name] = oid
}
func (f *fastpath) addFunctions(rows *Rows) error {
for rows.Next() {
var name string
var oid Oid
var oid pgtype.OID
if err := rows.Scan(&name, &oid); err != nil {
return err
}
@ -47,41 +51,46 @@ func fpInt64Arg(n int64) fpArg {
return res
}
func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) {
wbuf := newWriteBuf(f.cn, 'F') // function call
wbuf.WriteInt32(int32(oid)) // function object id
wbuf.WriteInt16(1) // # of argument format codes
wbuf.WriteInt16(1) // format code: binary
wbuf.WriteInt16(int16(len(args))) // # of arguments
for _, arg := range args {
wbuf.WriteInt32(int32(len(arg))) // length of argument
wbuf.WriteBytes(arg) // argument value
func (f *fastpath) Call(oid pgtype.OID, args []fpArg) (res []byte, err error) {
if err := f.cn.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
wbuf.WriteInt16(1) // response format code (binary)
wbuf.closeMsg()
if _, err := f.cn.conn.Write(wbuf.buf); err != nil {
buf := f.cn.wbuf
buf = append(buf, 'F') // function call
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = pgio.AppendInt32(buf, int32(oid)) // function object id
buf = pgio.AppendInt16(buf, 1) // # of argument format codes
buf = pgio.AppendInt16(buf, 1) // format code: binary
buf = pgio.AppendInt16(buf, int16(len(args))) // # of arguments
for _, arg := range args {
buf = pgio.AppendInt32(buf, int32(len(arg))) // length of argument
buf = append(buf, arg...) // argument value
}
buf = pgio.AppendInt16(buf, 1) // response format code (binary)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
if _, err := f.cn.conn.Write(buf); err != nil {
return nil, err
}
for {
var t byte
var r *msgReader
t, r, err = f.cn.rxMsg()
msg, err := f.cn.rxMsg()
if err != nil {
return nil, err
}
switch t {
case 'V': // FunctionCallResponse
data := r.readBytes(r.readInt32())
res = make([]byte, len(data))
copy(res, data)
case 'Z': // Ready for query
f.cn.rxReadyForQuery(r)
switch msg := msg.(type) {
case *pgproto3.FunctionCallResponse:
res = make([]byte, len(msg.Result))
copy(res, msg.Result)
case *pgproto3.ReadyForQuery:
f.cn.rxReadyForQuery(msg)
// done
return
return res, err
default:
if err := f.cn.processContextFreeMsg(t, r); err != nil {
if err := f.cn.processContextFreeMsg(msg); err != nil {
return nil, err
}
}

View File

@ -1,8 +1,9 @@
package pgx_test
import (
"github.com/jackc/pgx"
"testing"
"github.com/jackc/pgx"
)
func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn {
@ -21,7 +22,6 @@ func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.Replicatio
return conn
}
func closeConn(t testing.TB, conn *pgx.Conn) {
err := conn.Close()
if err != nil {

222
hstore.go
View File

@ -1,222 +0,0 @@
package pgx
import (
"bytes"
"errors"
"fmt"
"unicode"
"unicode/utf8"
)
const (
hsPre = iota
hsKey
hsSep
hsVal
hsNul
hsNext
)
// hstoreParser is a rune-oriented cursor over the textual representation of
// an hstore value, used by ParseHstore.
type hstoreParser struct {
	str string // input being parsed
	pos int    // current byte offset into str
}
// newHSP returns a parser positioned at the start of in.
func newHSP(in string) *hstoreParser {
	parser := &hstoreParser{str: in}
	return parser
}
// Consume returns the next rune in the input and advances past it.
// end is true when the input is exhausted.
func (p *hstoreParser) Consume() (r rune, end bool) {
	if p.pos < len(p.str) {
		var width int
		r, width = utf8.DecodeRuneInString(p.str[p.pos:])
		p.pos += width
		return r, false
	}
	return 0, true
}
// Peek returns the next rune without advancing. end is true at end of input.
func (p *hstoreParser) Peek() (r rune, end bool) {
	if p.pos >= len(p.str) {
		return 0, true
	}
	next, _ := utf8.DecodeRuneInString(p.str[p.pos:])
	return next, false
}
// parseHstoreToMap parses s into a plain map. Any NULL value in the hstore is
// an error, since map[string]string cannot represent it.
func parseHstoreToMap(s string) (map[string]string, error) {
	keys, values, err := ParseHstore(s)
	if err != nil {
		return nil, err
	}

	result := make(map[string]string, len(keys))
	for i, key := range keys {
		if !values[i].Valid {
			return nil, fmt.Errorf("key '%s' has NULL value", key)
		}
		result[key] = values[i].String
	}
	return result, nil
}
// parseHstoreToNullHstore parses s into a map of NullString values, which
// preserves NULL hstore values.
func parseHstoreToNullHstore(s string) (map[string]NullString, error) {
	keys, values, err := ParseHstore(s)
	if err != nil {
		return nil, err
	}

	result := make(map[string]NullString, len(keys))
	for i, key := range keys {
		result[key] = values[i]
	}
	return result, nil
}
// ParseHstore parses the string representation of an hstore column (the same
// you would get from an ordinary SELECT) into two slices of keys and values.
// It is used internally in the default parsing of hstores, but is exported
// for use in handling custom data structures backed by an hstore column
// without the overhead of creating a map[string]string.
//
// The parser is a state machine over the hs* states declared above; keys and
// values accumulate into parallel slices so NULL values can be represented.
func ParseHstore(s string) (k []string, v []NullString, err error) {
	if s == "" {
		return
	}

	buf := bytes.Buffer{}
	keys := []string{}
	values := []NullString{}
	p := newHSP(s)

	r, end := p.Consume()
	state := hsPre

	for !end {
		switch state {
		case hsPre:
			// The input must open with the first key's quote.
			if r == '"' {
				state = hsKey
			} else {
				err = errors.New("String does not begin with \"")
			}
		case hsKey:
			switch r {
			case '"': // End of the key
				if buf.Len() == 0 {
					err = errors.New("Empty Key is invalid")
				} else {
					keys = append(keys, buf.String())
					buf = bytes.Buffer{}
					state = hsSep
				}
			case '\\': // Potential escaped character
				n, end := p.Consume()
				switch {
				case end:
					err = errors.New("Found EOS in key, expecting character or \"")
				case n == '"', n == '\\':
					buf.WriteRune(n)
				default:
					// Not a recognized escape; keep both runes verbatim.
					buf.WriteRune(r)
					buf.WriteRune(n)
				}
			default: // Any other character
				buf.WriteRune(r)
			}
		case hsSep:
			// Expect the two-rune separator "=>" between key and value.
			if r == '=' {
				r, end = p.Consume()
				switch {
				case end:
					err = errors.New("Found EOS after '=', expecting '>'")
				case r == '>':
					r, end = p.Consume()
					switch {
					case end:
						err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'")
					case r == '"':
						state = hsVal
					case r == 'N':
						state = hsNul
					default:
						err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r)
					}
				default:
					err = fmt.Errorf("Invalid character after '=', expecting '>'")
				}
			} else {
				err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r)
			}
		case hsVal:
			switch r {
			case '"': // End of the value
				values = append(values, NullString{String: buf.String(), Valid: true})
				buf = bytes.Buffer{}
				state = hsNext
			case '\\': // Potential escaped character
				n, end := p.Consume()
				switch {
				case end:
					err = errors.New("Found EOS in key, expecting character or \"")
				case n == '"', n == '\\':
					buf.WriteRune(n)
				default:
					buf.WriteRune(r)
					buf.WriteRune(n)
				}
			default: // Any other character
				buf.WriteRune(r)
			}
		case hsNul:
			// 'N' was already consumed; the next three runes must spell "ULL".
			nulBuf := make([]rune, 3)
			nulBuf[0] = r
			for i := 1; i < 3; i++ {
				r, end = p.Consume()
				if end {
					err = errors.New("Found EOS in NULL value")
					return
				}
				nulBuf[i] = r
			}
			if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' {
				values = append(values, NullString{String: "", Valid: false})
				state = hsNext
			} else {
				err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf))
			}
		case hsNext:
			// Pairs are separated by a comma followed by whitespace.
			if r == ',' {
				r, end = p.Consume()
				switch {
				case end:
					// Fixed typo: "expcting" -> "expecting".
					err = errors.New("Found EOS after ',', expecting space")
				case (unicode.IsSpace(r)):
					r, end = p.Consume()
					state = hsKey
				default:
					err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r)
				}
			} else {
				err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r)
			}
		}

		if err != nil {
			return
		}
		r, end = p.Consume()
	}
	if state != hsNext {
		err = errors.New("Improperly formatted hstore")
		return
	}
	k = keys
	v = values
	return
}

View File

@ -1,181 +0,0 @@
package pgx_test
import (
"github.com/jackc/pgx"
"testing"
)
// TestHstoreTranscode round-trips pgx.Hstore values through the server
// ("select $1::hstore") and verifies every key/value survives, including
// strings containing quotes, backslashes, the => separator, and spaces.
func TestHstoreTranscode(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	type test struct {
		hstore      pgx.Hstore
		description string
	}

	// Baseline cases, including keys/values that spell the literal "NULL".
	tests := []test{
		{pgx.Hstore{}, "empty"},
		{pgx.Hstore{"foo": "bar"}, "single key/value"},
		{pgx.Hstore{"foo": "bar", "baz": "quz"}, "multiple key/values"},
		{pgx.Hstore{"NULL": "bar"}, `string "NULL" key`},
		{pgx.Hstore{"foo": "NULL"}, `string "NULL" value`},
	}

	// Each special string is exercised at the beginning, middle, and end of
	// both keys and values, and alone.
	specialStringTests := []struct {
		input       string
		description string
	}{
		{`"`, `double quote (")`},
		{`'`, `single quote (')`},
		{`\`, `backslash (\)`},
		{`\\`, `multiple backslashes (\\)`},
		{`=>`, `separator (=>)`},
		{` `, `space`},
		{`\ / / \\ => " ' " '`, `multiple special characters`},
	}
	for _, sst := range specialStringTests {
		tests = append(tests, test{pgx.Hstore{sst.input + "foo": "bar"}, "key with " + sst.description + " at beginning"})
		tests = append(tests, test{pgx.Hstore{"foo" + sst.input + "foo": "bar"}, "key with " + sst.description + " in middle"})
		tests = append(tests, test{pgx.Hstore{"foo" + sst.input: "bar"}, "key with " + sst.description + " at end"})
		tests = append(tests, test{pgx.Hstore{sst.input: "bar"}, "key is " + sst.description})
		tests = append(tests, test{pgx.Hstore{"foo": sst.input + "bar"}, "value with " + sst.description + " at beginning"})
		tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input + "bar"}, "value with " + sst.description + " in middle"})
		tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input}, "value with " + sst.description + " at end"})
		tests = append(tests, test{pgx.Hstore{"foo": sst.input}, "value is " + sst.description})
	}

	for _, tt := range tests {
		var result pgx.Hstore
		err := conn.QueryRow("select $1::hstore", tt.hstore).Scan(&result)
		if err != nil {
			t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err)
		}

		for key, inValue := range tt.hstore {
			outValue, ok := result[key]
			if ok {
				if inValue != outValue {
					t.Errorf(`%s: Key %s mismatch - expected %s, received %s`, tt.description, key, inValue, outValue)
				}
			} else {
				t.Errorf(`%s: Missing key %s`, tt.description, key)
			}
		}

		// The connection must remain usable after each round trip.
		ensureConnValid(t, conn)
	}
}
// TestNullHstoreTranscode round-trips pgx.NullHstore values through the
// server, covering a NULL hstore, an empty hstore, NULL values inside the
// hstore, and keys/values containing special characters.
func TestNullHstoreTranscode(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	type test struct {
		nullHstore  pgx.NullHstore
		description string
	}

	tests := []test{
		{pgx.NullHstore{}, "null"},
		{pgx.NullHstore{Valid: true}, "empty"},
		{pgx.NullHstore{
			Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}},
			Valid:  true},
			"single key/value"},
		{pgx.NullHstore{
			Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}, "baz": {String: "quz", Valid: true}},
			Valid:  true},
			"multiple key/values"},
		{pgx.NullHstore{
			Hstore: map[string]pgx.NullString{"NULL": {String: "bar", Valid: true}},
			Valid:  true},
			`string "NULL" key`},
		{pgx.NullHstore{
			Hstore: map[string]pgx.NullString{"foo": {String: "NULL", Valid: true}},
			Valid:  true},
			`string "NULL" value`},
		{pgx.NullHstore{
			Hstore: map[string]pgx.NullString{"foo": {String: "", Valid: false}},
			Valid:  true},
			`NULL value`},
	}

	// Each special string is exercised in every position of keys and values.
	specialStringTests := []struct {
		input       string
		description string
	}{
		{`"`, `double quote (")`},
		{`'`, `single quote (')`},
		{`\`, `backslash (\)`},
		{`\\`, `multiple backslashes (\\)`},
		{`=>`, `separator (=>)`},
		{` `, `space`},
		{`\ / / \\ => " ' " '`, `multiple special characters`},
	}
	for _, sst := range specialStringTests {
		tests = append(tests, test{pgx.NullHstore{
			Hstore: map[string]pgx.NullString{sst.input + "foo": {String: "bar", Valid: true}},
			Valid:  true},
			"key with " + sst.description + " at beginning"})
		tests = append(tests, test{pgx.NullHstore{
			Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": {String: "bar", Valid: true}},
			Valid:  true},
			"key with " + sst.description + " in middle"})
		tests = append(tests, test{pgx.NullHstore{
			Hstore: map[string]pgx.NullString{"foo" + sst.input: {String: "bar", Valid: true}},
			Valid:  true},
			"key with " + sst.description + " at end"})
		tests = append(tests, test{pgx.NullHstore{
			Hstore: map[string]pgx.NullString{sst.input: {String: "bar", Valid: true}},
			Valid:  true},
			"key is " + sst.description})
		tests = append(tests, test{pgx.NullHstore{
			Hstore: map[string]pgx.NullString{"foo": {String: sst.input + "bar", Valid: true}},
			Valid:  true},
			"value with " + sst.description + " at beginning"})
		tests = append(tests, test{pgx.NullHstore{
			Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input + "bar", Valid: true}},
			Valid:  true},
			"value with " + sst.description + " in middle"})
		tests = append(tests, test{pgx.NullHstore{
			Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input, Valid: true}},
			Valid:  true},
			"value with " + sst.description + " at end"})
		tests = append(tests, test{pgx.NullHstore{
			Hstore: map[string]pgx.NullString{"foo": {String: sst.input, Valid: true}},
			Valid:  true},
			"value is " + sst.description})
	}

	for _, tt := range tests {
		var result pgx.NullHstore
		err := conn.QueryRow("select $1::hstore", tt.nullHstore).Scan(&result)
		if err != nil {
			t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err)
		}

		if result.Valid != tt.nullHstore.Valid {
			t.Errorf(`%s: Valid mismatch - expected %v, received %v`, tt.description, tt.nullHstore.Valid, result.Valid)
		}

		for key, inValue := range tt.nullHstore.Hstore {
			outValue, ok := result.Hstore[key]
			if ok {
				if inValue != outValue {
					t.Errorf(`%s: Key %s mismatch - expected %v, received %v`, tt.description, key, inValue, outValue)
				}
			} else {
				t.Errorf(`%s: Missing key %s`, tt.description, key)
			}
		}

		// The connection must remain usable after each round trip.
		ensureConnValid(t, conn)
	}
}

View File

@ -0,0 +1,237 @@
package sanitize
import (
"bytes"
"encoding/hex"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/pkg/errors"
)
// Part is either a string or an int. A string is raw SQL. An int is an
// argument placeholder (1-based, matching $1, $2, ...).
type Part interface{}

// Query is a SQL statement decomposed by NewQuery into raw-SQL Parts and
// placeholder Parts, ready to be interpolated by Sanitize.
type Query struct {
	Parts []Part
}
// Sanitize interpolates args into the query's placeholders, quoting and
// escaping each value as appropriate for its Go type, and returns the
// resulting SQL string. It is an error for any argument to go unused or for
// an argument to have an unsupported type.
func (q *Query) Sanitize(args ...interface{}) (string, error) {
	consumed := make([]bool, len(args))
	var out bytes.Buffer

	for _, part := range q.Parts {
		var fragment string

		switch part := part.(type) {
		case string:
			fragment = part
		case int:
			idx := part - 1
			if idx >= len(args) {
				return "", errors.Errorf("insufficient arguments")
			}
			switch arg := args[idx].(type) {
			case nil:
				fragment = "null"
			case int64:
				fragment = strconv.FormatInt(arg, 10)
			case float64:
				fragment = strconv.FormatFloat(arg, 'f', -1, 64)
			case bool:
				fragment = strconv.FormatBool(arg)
			case []byte:
				fragment = QuoteBytes(arg)
			case string:
				fragment = QuoteString(arg)
			case time.Time:
				fragment = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
			default:
				return "", errors.Errorf("invalid arg type: %T", arg)
			}
			consumed[idx] = true
		default:
			return "", errors.Errorf("invalid Part type: %T", part)
		}

		out.WriteString(fragment)
	}

	for i, used := range consumed {
		if !used {
			return "", errors.Errorf("unused argument: %d", i)
		}
	}
	return out.String(), nil
}
// NewQuery lexes sql into a Query, splitting it into raw-SQL fragments and
// numeric placeholder references. Quoted regions are left untouched.
func NewQuery(sql string) (*Query, error) {
	lexer := &sqlLexer{src: sql, stateFn: rawState}

	// Drive the lexer state machine until a state returns nil (end of input).
	for lexer.stateFn != nil {
		lexer.stateFn = lexer.stateFn(lexer)
	}

	return &Query{Parts: lexer.parts}, nil
}
// QuoteString returns str as a single-quoted SQL string literal with any
// embedded single quotes doubled.
func QuoteString(str string) string {
	escaped := strings.ReplaceAll(str, "'", "''")
	return "'" + escaped + "'"
}
// QuoteBytes returns buf as a PostgreSQL hex-format bytea literal: '\x...'.
func QuoteBytes(buf []byte) string {
	var b strings.Builder
	b.WriteString(`'\x`)
	b.WriteString(hex.EncodeToString(buf))
	b.WriteByte('\'')
	return b.String()
}
// sqlLexer scans src once, accumulating raw-SQL string parts and integer
// placeholder parts into parts.
type sqlLexer struct {
	src     string
	start   int // byte offset where the current un-emitted raw span begins
	pos     int // current scan position (byte offset)
	stateFn stateFn
	parts   []Part
}

// stateFn handles one lexical context (raw SQL, quoted string, placeholder,
// ...) and returns the next state, or nil at end of input.
type stateFn func(*sqlLexer) stateFn
// rawState scans ordinary SQL text, dispatching to the specialized states on
// quote characters and flushing completed raw spans into l.parts.
func rawState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += width

		switch r {
		case 'e', 'E':
			// e'...' / E'...' introduces a PostgreSQL escape string, which
			// has its own backslash rules.
			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune == '\'' {
				l.pos += width
				return escapeStringState
			}
		case '\'':
			return singleQuoteState
		case '"':
			return doubleQuoteState
		case '$':
			// Only $<digit> starts a placeholder; other uses of $ stay raw.
			nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
			if '0' <= nextRune && nextRune <= '9' {
				// Emit the raw SQL up to, but not including, the '$'.
				if l.pos-l.start > 0 {
					l.parts = append(l.parts, l.src[l.start:l.pos-width])
				}
				l.start = l.pos
				return placeholderState
			}
		case utf8.RuneError:
			// End of input (or invalid UTF-8): flush any pending raw span.
			if l.pos-l.start > 0 {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
	}
}
// singleQuoteState consumes a standard single-quoted SQL string. A doubled
// quote ('') is an escaped quote and stays inside the string.
func singleQuoteState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += width

		if r == '\'' {
			next, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if next != '\'' {
				// Closing quote; resume raw scanning.
				return rawState
			}
			l.pos += nextWidth
			continue
		}
		if r == utf8.RuneError {
			// End of input: flush whatever has accumulated.
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
	}
}
// doubleQuoteState consumes a double-quoted SQL identifier. A doubled quote
// ("") is an escaped quote and stays inside the identifier.
func doubleQuoteState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += width

		if r == '"' {
			next, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if next != '"' {
				// Closing quote; resume raw scanning.
				return rawState
			}
			l.pos += nextWidth
			continue
		}
		if r == utf8.RuneError {
			// End of input: flush whatever has accumulated.
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
	}
}
// placeholderState consumes a placeholder number and emits it as an int
// Part. The leading $ has already been consumed; the first rune must be a
// digit.
func placeholderState(l *sqlLexer) stateFn {
	value := 0
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += width

		if r < '0' || r > '9' {
			// First non-digit ends the placeholder; put it back and return
			// to raw scanning.
			l.pos -= width
			l.parts = append(l.parts, value)
			l.start = l.pos
			return rawState
		}
		value = value*10 + int(r-'0')
	}
}
// escapeStringState consumes a PostgreSQL escape string (e'...'), in which a
// backslash escapes the following rune and '' is an escaped quote.
func escapeStringState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += width

		switch r {
		case '\\':
			// Skip the escaped rune, whatever it is.
			_, escWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			l.pos += escWidth
		case '\'':
			next, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if next != '\'' {
				return rawState
			}
			l.pos += nextWidth
		case utf8.RuneError:
			// End of input: flush whatever has accumulated.
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
	}
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes
// args as necessary. This function is only safe when the server setting
// standard_conforming_strings is on.
func SanitizeSQL(sql string, args ...interface{}) (string, error) {
	q, err := NewQuery(sql)
	if err != nil {
		return "", err
	}
	return q.Sanitize(args...)
}

View File

@ -0,0 +1,175 @@
package sanitize_test
import (
"testing"
"github.com/jackc/pgx/internal/sanitize"
)
// TestNewQuery verifies the lexer splits SQL into the expected raw/placeholder
// Parts, including placeholders hidden inside single quotes, double quotes,
// doubled-quote escapes, and e''/E'' escape strings (which must NOT be split).
func TestNewQuery(t *testing.T) {
	successTests := []struct {
		sql      string
		expected sanitize.Query
	}{
		{
			sql:      "select 42",
			expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
		},
		{
			sql:      "select $1",
			expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
		},
		{
			sql:      "select 'quoted $42', $1",
			expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}},
		},
		{
			sql:      `select "doubled quoted $42", $1`,
			expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}},
		},
		{
			sql:      "select 'foo''bar', $1",
			expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}},
		},
		{
			sql:      `select "foo""bar", $1`,
			expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}},
		},
		{
			sql:      "select '''', $1",
			expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}},
		},
		{
			sql:      `select """", $1`,
			expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}},
		},
		{
			sql:      "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11",
			expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}},
		},
		{
			sql:      `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`,
			expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}},
		},
		{
			sql:      `select E'escape string\' $42', $1`,
			expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}},
		},
		{
			sql:      `select e'escape string\' $42', $1`,
			expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
		},
	}

	for i, tt := range successTests {
		query, err := sanitize.NewQuery(tt.sql)
		if err != nil {
			t.Errorf("%d. %v", i, err)
		}

		// Compare part-by-part so mismatches report the offending index.
		if len(query.Parts) == len(tt.expected.Parts) {
			for j := range query.Parts {
				if query.Parts[j] != tt.expected.Parts[j] {
					t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j])
				}
			}
		} else {
			t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts)
		}
	}
}
// TestQuerySanitize verifies argument interpolation for every supported Go
// type (int64, float64, bool, []byte, nil, string with quoting) and that the
// documented error cases are reported.
func TestQuerySanitize(t *testing.T) {
	successfulTests := []struct {
		query    sanitize.Query
		args     []interface{}
		expected string
	}{
		{
			query:    sanitize.Query{Parts: []sanitize.Part{"select 42"}},
			args:     []interface{}{},
			expected: `select 42`,
		},
		{
			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
			args:     []interface{}{int64(42)},
			expected: `select 42`,
		},
		{
			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
			args:     []interface{}{float64(1.23)},
			expected: `select 1.23`,
		},
		{
			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
			args:     []interface{}{true},
			expected: `select true`,
		},
		{
			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
			args:     []interface{}{[]byte{0, 1, 2, 3, 255}},
			expected: `select '\x00010203ff'`,
		},
		{
			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
			args:     []interface{}{nil},
			expected: `select null`,
		},
		{
			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
			args:     []interface{}{"foobar"},
			expected: `select 'foobar'`,
		},
		{
			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
			args:     []interface{}{"foo'bar"},
			expected: `select 'foo''bar'`,
		},
		{
			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
			args:     []interface{}{`foo\'bar`},
			expected: `select 'foo\''bar'`,
		},
	}

	for i, tt := range successfulTests {
		actual, err := tt.query.Sanitize(tt.args...)
		if err != nil {
			t.Errorf("%d. %v", i, err)
			continue
		}

		if tt.expected != actual {
			t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual)
		}
	}

	// Error cases: too few args, unused args, and unsupported arg types.
	errorTests := []struct {
		query    sanitize.Query
		args     []interface{}
		expected string
	}{
		{
			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
			args:     []interface{}{int64(42)},
			expected: `insufficient arguments`,
		},
		{
			query:    sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}},
			args:     []interface{}{int64(42)},
			expected: `unused argument: 0`,
		},
		{
			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
			args:     []interface{}{42},
			expected: `invalid arg type: int`,
		},
	}

	for i, tt := range errorTests {
		_, err := tt.query.Sanitize(tt.args...)
		if err == nil || err.Error() != tt.expected {
			t.Errorf("%d. expected error %v, got %v", i, tt.expected, err)
		}
	}
}

View File

@ -2,6 +2,8 @@ package pgx
import (
"io"
"github.com/jackc/pgx/pgtype"
)
// LargeObjects is a structure used to access the large objects API. It is only
@ -60,19 +62,19 @@ const (
// Create creates a new large object. If id is zero, the server assigns an
// unused OID.
func (o *LargeObjects) Create(id Oid) (Oid, error) {
newOid, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))}))
return Oid(newOid), err
// Create creates a new large object via the lo_create fast-path function. If
// id is zero the server assigns an unused OID; the OID of the new object is
// returned either way.
func (o *LargeObjects) Create(id pgtype.OID) (pgtype.OID, error) {
	newOID, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))}))
	return pgtype.OID(newOID), err
}
// Open opens an existing large object with the given mode.
func (o *LargeObjects) Open(oid Oid, mode LargeObjectMode) (*LargeObject, error) {
// Open opens an existing large object with the given mode via lo_open and
// returns a handle wrapping the server-side descriptor.
// NOTE(review): a non-nil *LargeObject is returned even when err is non-nil;
// callers must check err before using the handle.
func (o *LargeObjects) Open(oid pgtype.OID, mode LargeObjectMode) (*LargeObject, error) {
	fd, err := fpInt32(o.fp.CallFn("lo_open", []fpArg{fpIntArg(int32(oid)), fpIntArg(int32(mode))}))
	return &LargeObject{fd: fd, lo: o}, err
}
// Unlink removes a large object from the database.
func (o *LargeObjects) Unlink(oid Oid) error {
// Unlink removes a large object from the database via lo_unlink.
func (o *LargeObjects) Unlink(oid pgtype.OID) error {
	_, err := o.fp.CallFn("lo_unlink", []fpArg{fpIntArg(int32(oid))})
	return err
}

View File

@ -0,0 +1,47 @@
// Package log15adapter provides a logger that writes to a github.com/inconshreveable/log15.Logger
// log.
package log15adapter
import (
"github.com/jackc/pgx"
)
// Log15Logger interface defines the subset of
// github.com/inconshreveable/log15.Logger that this adapter uses. Each method
// logs a message with optional key/value context pairs at the named level.
type Log15Logger interface {
	Debug(msg string, ctx ...interface{})
	Info(msg string, ctx ...interface{})
	Warn(msg string, ctx ...interface{})
	Error(msg string, ctx ...interface{})
	Crit(msg string, ctx ...interface{})
}
// Logger adapts a Log15Logger to the pgx.Logger interface.
type Logger struct {
	l Log15Logger // underlying log15-style logger that receives all output
}
// NewLogger wraps l in an adapter that satisfies the pgx.Logger interface.
func NewLogger(l Log15Logger) *Logger {
	adapter := &Logger{}
	adapter.l = l
	return adapter
}
// Log implements pgx.Logger by flattening data into log15 key/value context
// pairs and dispatching to the level-appropriate method of the wrapped
// logger. Unknown levels are logged as errors with an extra marker pair.
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
	ctx := make([]interface{}, 0, len(data))
	for key, value := range data {
		ctx = append(ctx, key, value)
	}

	switch level {
	case pgx.LogLevelTrace:
		// log15 has no trace level; record it as debug with the pgx level attached.
		l.l.Debug(msg, append(ctx, "PGX_LOG_LEVEL", level)...)
	case pgx.LogLevelDebug:
		l.l.Debug(msg, ctx...)
	case pgx.LogLevelInfo:
		l.l.Info(msg, ctx...)
	case pgx.LogLevelWarn:
		l.l.Warn(msg, ctx...)
	case pgx.LogLevelError:
		l.l.Error(msg, ctx...)
	default:
		l.l.Error(msg, append(ctx, "INVALID_PGX_LOG_LEVEL", level)...)
	}
}

View File

@ -0,0 +1,40 @@
// Package logrusadapter provides a logger that writes to a github.com/Sirupsen/logrus.Logger
// log.
package logrusadapter
import (
"github.com/Sirupsen/logrus"
"github.com/jackc/pgx"
)
// Logger adapts a *logrus.Logger to the pgx.Logger interface.
type Logger struct {
	l *logrus.Logger // underlying logrus logger that receives all output
}
// NewLogger wraps l in an adapter that satisfies the pgx.Logger interface.
func NewLogger(l *logrus.Logger) *Logger {
	adapter := &Logger{}
	adapter.l = l
	return adapter
}
// Log implements pgx.Logger by attaching data as logrus structured fields and
// dispatching to the level-appropriate logrus method. Unknown levels are
// logged as errors with an extra marker field.
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
	var entry logrus.FieldLogger = l.l
	if data != nil {
		entry = l.l.WithFields(data)
	}

	switch level {
	case pgx.LogLevelTrace:
		// logrus has no trace level here; record it as debug with the pgx level attached.
		entry.WithField("PGX_LOG_LEVEL", level).Debug(msg)
	case pgx.LogLevelDebug:
		entry.Debug(msg)
	case pgx.LogLevelInfo:
		entry.Info(msg)
	case pgx.LogLevelWarn:
		entry.Warn(msg)
	case pgx.LogLevelError:
		entry.Error(msg)
	default:
		entry.WithField("INVALID_PGX_LOG_LEVEL", level).Error(msg)
	}
}

View File

@ -0,0 +1,32 @@
// Package testingadapter provides a logger that writes to a test or benchmark
// log.
package testingadapter
import (
"fmt"
"github.com/jackc/pgx"
)
// TestingLogger interface defines the subset of testing.TB methods used by
// this adapter (both *testing.T and *testing.B satisfy it).
type TestingLogger interface {
	Log(args ...interface{})
}
// Logger adapts a TestingLogger (e.g. *testing.T) to the pgx.Logger interface.
type Logger struct {
	l TestingLogger // test/benchmark log that receives all output
}
// NewLogger wraps l in an adapter that satisfies the pgx.Logger interface.
func NewLogger(l TestingLogger) *Logger {
	adapter := &Logger{}
	adapter.l = l
	return adapter
}
// Log implements pgx.Logger by writing the level, message, and each data pair
// (formatted as "key=value") to the test log in a single call.
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
	args := make([]interface{}, 0, 2+len(data))
	args = append(args, level, msg)
	for k, v := range data {
		args = append(args, fmt.Sprintf("%s=%v", k, v))
	}
	l.l.Log(args...)
}

View File

@ -2,13 +2,13 @@ package pgx
import (
"encoding/hex"
"errors"
"fmt"
"github.com/pkg/errors"
)
// The values for log levels are chosen such that the zero value means that no
// log level was specified and we can default to LogLevelDebug to preserve
// the behavior that existed prior to log level introduction.
// log level was specified.
const (
LogLevelTrace = 6
LogLevelDebug = 5
@ -18,16 +18,33 @@ const (
LogLevelNone = 1
)
// LogLevel represents the pgx logging level. See LogLevel* constants for
// possible values. The zero value means no level was specified.
type LogLevel int
// String returns the lowercase textual name of the log level, or a
// diagnostic string for values outside the defined constants.
func (ll LogLevel) String() string {
	switch ll {
	case LogLevelNone:
		return "none"
	case LogLevelError:
		return "error"
	case LogLevelWarn:
		return "warn"
	case LogLevelInfo:
		return "info"
	case LogLevelDebug:
		return "debug"
	case LogLevelTrace:
		return "trace"
	}
	return fmt.Sprintf("invalid level %d", ll)
}
// Logger is the interface used to get logging from pgx internals.
// https://github.com/inconshreveable/log15 is the recommended logging package.
// This logging interface was extracted from there. However, it should be simple
// to adapt any logger to this interface.
type Logger interface {
// Log a message at the given level with context key/value pairs
Debug(msg string, ctx ...interface{})
Info(msg string, ctx ...interface{})
Warn(msg string, ctx ...interface{})
Error(msg string, ctx ...interface{})
// Log a message at the given level with data key/value pairs. data may be nil.
Log(level LogLevel, msg string, data map[string]interface{})
}
// LogLevelFromString converts log level string to constant
@ -39,7 +56,7 @@ type Logger interface {
// warn
// error
// none
func LogLevelFromString(s string) (int, error) {
func LogLevelFromString(s string) (LogLevel, error) {
switch s {
case "trace":
return LogLevelTrace, nil

View File

@ -1,66 +1,24 @@
package pgx
import (
"encoding/binary"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgtype"
)
const (
	// protocolVersionNumber is the PostgreSQL frontend/backend protocol
	// version sent in the startup message: 3.0 (196608 == 3 << 16).
	protocolVersionNumber = 196608 // 3.0
)
// Single-byte message type identifiers: the first byte of each PostgreSQL
// wire protocol message exchanged with the backend.
const (
	backendKeyData       = 'K'
	authenticationX      = 'R'
	readyForQuery        = 'Z'
	rowDescription       = 'T'
	dataRow              = 'D'
	commandComplete      = 'C'
	errorResponse        = 'E'
	noticeResponse       = 'N'
	parseComplete        = '1'
	parameterDescription = 't'
	bindComplete         = '2'
	notificationResponse = 'A'
	emptyQueryResponse   = 'I'
	noData               = 'n'
	closeComplete        = '3'
	flush                = 'H'
	copyInResponse       = 'G'
	copyData             = 'd'
	copyFail             = 'f'
	copyDone             = 'c'
)
// startupMessage holds the key/value options sent to the server in the
// protocol startup packet; Bytes serializes it to the wire format.
type startupMessage struct {
	options map[string]string
}
// newStartupMessage returns an empty startup message ready to have options set.
func newStartupMessage() *startupMessage {
	return &startupMessage{options: make(map[string]string)}
}
// Bytes serializes the startup message to the PostgreSQL wire format: a
// big-endian total length, the protocol version, NUL-terminated key/value
// option pairs, and a final NUL terminator.
func (s *startupMessage) Bytes() []byte {
	// Reserve 4 bytes for the length and 4 for the protocol version.
	buf := make([]byte, 8, 128)
	binary.BigEndian.PutUint32(buf[4:8], uint32(protocolVersionNumber))
	for key, value := range s.options {
		buf = append(buf, key...)
		buf = append(buf, 0)
		buf = append(buf, value...)
		buf = append(buf, 0)
	}
	buf = append(buf, 0)
	// Patch the total message length (including the length word itself).
	binary.BigEndian.PutUint32(buf[0:4], uint32(len(buf)))
	return buf
}
type FieldDescription struct {
Name string
Table Oid
AttributeNumber int16
DataType Oid
Table pgtype.OID
AttributeNumber uint16
DataType pgtype.OID
DataTypeSize int16
DataTypeName string
Modifier int32
Modifier uint32
FormatCode int16
}
@ -91,69 +49,114 @@ func (pe PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
func newWriteBuf(c *Conn, t byte) *WriteBuf {
buf := append(c.wbuf[0:0], t, 0, 0, 0, 0)
c.writeBuf = WriteBuf{buf: buf, sizeIdx: 1, conn: c}
return &c.writeBuf
// Notice represents a notice response message reported by the PostgreSQL
// server. Be aware that this is distinct from LISTEN/NOTIFY notification.
// Its fields have the same meaning as the corresponding fields of PgError.
type Notice PgError
// appendParse appends a PostgreSQL wire protocol parse ('P') message to buf
// and returns it. name is the prepared statement name, query the SQL text,
// and parameterOIDs the declared parameter types.
func appendParse(buf []byte, name string, query string, parameterOIDs []pgtype.OID) []byte {
	buf = append(buf, 'P')
	sp := len(buf)
	buf = pgio.AppendInt32(buf, -1) // length placeholder, patched below
	buf = append(buf, name...)
	buf = append(buf, 0) // C-string terminator
	buf = append(buf, query...)
	buf = append(buf, 0) // C-string terminator
	buf = pgio.AppendInt16(buf, int16(len(parameterOIDs)))
	for _, oid := range parameterOIDs {
		buf = pgio.AppendUint32(buf, uint32(oid))
	}
	pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) // patch in the real message length
	return buf
}
// WriteBuf is used build messages to send to the PostgreSQL server. It is used
// by the Encoder interface when implementing custom encoders.
type WriteBuf struct {
buf []byte
sizeIdx int
conn *Conn
// appendDescribe appends a PostgreSQL wire protocol describe ('D') message to
// buf and returns it. objectType selects what name refers to ('S' for a
// prepared statement, 'P' for a portal, per the PostgreSQL protocol).
func appendDescribe(buf []byte, objectType byte, name string) []byte {
	buf = append(buf, 'D')
	lenIdx := len(buf)
	buf = pgio.AppendInt32(buf, -1) // length placeholder, patched below
	buf = append(buf, objectType)
	buf = append(buf, name...)
	buf = append(buf, 0) // C-string terminator
	pgio.SetInt32(buf[lenIdx:], int32(len(buf[lenIdx:]))) // patch in the real length
	return buf
}
func (wb *WriteBuf) startMsg(t byte) {
wb.closeMsg()
wb.buf = append(wb.buf, t, 0, 0, 0, 0)
wb.sizeIdx = len(wb.buf) - 4
// appendSync appends a PostgreSQL wire protocol sync ('S') message to buf and
// returns it. The message has no payload, so its length is always 4.
func appendSync(buf []byte) []byte {
	buf = append(buf, 'S')
	return pgio.AppendInt32(buf, 4)
}
func (wb *WriteBuf) closeMsg() {
binary.BigEndian.PutUint32(wb.buf[wb.sizeIdx:wb.sizeIdx+4], uint32(len(wb.buf)-wb.sizeIdx))
// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it.
// The message binds arguments to preparedStatement, producing destinationPortal.
// connInfo is consulted to choose text/binary format per parameter; an error from
// encoding an argument is returned with a nil buffer.
func appendBind(
	buf []byte,
	destinationPortal,
	preparedStatement string,
	connInfo *pgtype.ConnInfo,
	parameterOIDs []pgtype.OID,
	arguments []interface{},
	resultFormatCodes []int16,
) ([]byte, error) {
	buf = append(buf, 'B')
	sp := len(buf) // position of the length placeholder, backfilled below
	buf = pgio.AppendInt32(buf, -1)

	// Null-terminated portal and prepared statement names.
	buf = append(buf, destinationPortal...)
	buf = append(buf, 0)
	buf = append(buf, preparedStatement...)
	buf = append(buf, 0)

	// One format code per parameter, chosen from the OID/argument pairing.
	buf = pgio.AppendInt16(buf, int16(len(parameterOIDs)))
	for i, oid := range parameterOIDs {
		buf = pgio.AppendInt16(buf, chooseParameterFormatCode(connInfo, oid, arguments[i]))
	}

	// Parameter values. NOTE(review): the count written is len(arguments) but
	// the loop ranges parameterOIDs while indexing arguments — presumably the
	// two slices are always the same length; confirm against callers.
	buf = pgio.AppendInt16(buf, int16(len(arguments)))
	for i, oid := range parameterOIDs {
		var err error
		buf, err = encodePreparedStatementArgument(connInfo, buf, oid, arguments[i])
		if err != nil {
			return nil, err
		}
	}

	// Desired format code for each result column.
	buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes)))
	for _, fc := range resultFormatCodes {
		buf = pgio.AppendInt16(buf, fc)
	}

	// The length field covers itself but not the leading type byte.
	pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))

	return buf, nil
}
func (wb *WriteBuf) WriteByte(b byte) {
wb.buf = append(wb.buf, b)
// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it.
func appendExecute(buf []byte, portal string, maxRows uint32) []byte {
	buf = append(buf, 'E')
	lenPos := len(buf)
	buf = pgio.AppendInt32(buf, -1) // length placeholder, backfilled below

	// Null-terminated portal name followed by the row limit.
	buf = append(buf, portal...)
	buf = append(buf, 0)
	buf = pgio.AppendUint32(buf, maxRows)

	pgio.SetInt32(buf[lenPos:], int32(len(buf[lenPos:])))
	return buf
}
func (wb *WriteBuf) WriteCString(s string) {
wb.buf = append(wb.buf, []byte(s)...)
wb.buf = append(wb.buf, 0)
}
// appendQuery appends a PostgreSQL wire protocol query message to buf and returns it.
func appendQuery(buf []byte, query string) []byte {
buf = append(buf, 'Q')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, query...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
func (wb *WriteBuf) WriteInt16(n int16) {
b := make([]byte, 2)
binary.BigEndian.PutUint16(b, uint16(n))
wb.buf = append(wb.buf, b...)
}
func (wb *WriteBuf) WriteUint16(n uint16) {
b := make([]byte, 2)
binary.BigEndian.PutUint16(b, n)
wb.buf = append(wb.buf, b...)
}
func (wb *WriteBuf) WriteInt32(n int32) {
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(n))
wb.buf = append(wb.buf, b...)
}
func (wb *WriteBuf) WriteUint32(n uint32) {
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, n)
wb.buf = append(wb.buf, b...)
}
func (wb *WriteBuf) WriteInt64(n int64) {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, uint64(n))
wb.buf = append(wb.buf, b...)
}
func (wb *WriteBuf) WriteBytes(b []byte) {
wb.buf = append(wb.buf, b...)
return buf
}

View File

@ -1,316 +0,0 @@
package pgx
import (
"bufio"
"encoding/binary"
"errors"
"io"
)
// msgReader is a helper that reads values from a PostgreSQL message.
type msgReader struct {
reader *bufio.Reader
msgBytesRemaining int32
err error
log func(lvl int, msg string, ctx ...interface{})
shouldLog func(lvl int) bool
}
// Err returns any error that the msgReader has experienced
func (r *msgReader) Err() error {
return r.err
}
// fatal tells rc that a Fatal error has occurred
func (r *msgReader) fatal(err error) {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining)
}
r.err = err
}
// rxMsg reads the type and size of the next message.
// It returns the message type byte and primes msgBytesRemaining with the
// body size so subsequent readXxx calls can enforce message bounds.
func (r *msgReader) rxMsg() (byte, error) {
	if r.err != nil {
		return 0, r.err
	}

	// Skip any unread remainder of the previous message so the header read
	// below starts on a message boundary.
	if r.msgBytesRemaining > 0 {
		if r.shouldLog(LogLevelTrace) {
			r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
		}

		// NOTE(review): this error is returned without calling r.fatal, unlike
		// every other read failure in this type — confirm whether intentional.
		_, err := r.reader.Discard(int(r.msgBytesRemaining))
		if err != nil {
			return 0, err
		}
	}

	b, err := r.reader.Peek(5)
	if err != nil {
		r.fatal(err)
		return 0, err
	}
	msgType := b[0]
	// The wire length includes the 4 length bytes themselves; subtract them so
	// msgBytesRemaining counts only body bytes.
	r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4
	r.reader.Discard(5)
	return msgType, nil
}
// readByte consumes one byte of the current message body. On any failure the
// reader enters a fatal error state and the zero value is returned.
func (r *msgReader) readByte() byte {
	if r.err != nil {
		return 0
	}

	r.msgBytesRemaining--
	if r.msgBytesRemaining < 0 {
		r.fatal(errors.New("read past end of message"))
		return 0
	}

	value, err := r.reader.ReadByte()
	if err != nil {
		r.fatal(err)
		return 0
	}

	if r.shouldLog(LogLevelTrace) {
		r.log(LogLevelTrace, "msgReader.readByte", "value", value, "byteAsString", string(value), "msgBytesRemaining", r.msgBytesRemaining)
	}

	return value
}
func (r *msgReader) readInt16() int16 {
if r.err != nil {
return 0
}
r.msgBytesRemaining -= 2
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(2)
if err != nil {
r.fatal(err)
return 0
}
n := int16(binary.BigEndian.Uint16(b))
r.reader.Discard(2)
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
}
return n
}
func (r *msgReader) readInt32() int32 {
if r.err != nil {
return 0
}
r.msgBytesRemaining -= 4
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(4)
if err != nil {
r.fatal(err)
return 0
}
n := int32(binary.BigEndian.Uint32(b))
r.reader.Discard(4)
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
}
return n
}
func (r *msgReader) readUint16() uint16 {
if r.err != nil {
return 0
}
r.msgBytesRemaining -= 2
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(2)
if err != nil {
r.fatal(err)
return 0
}
n := uint16(binary.BigEndian.Uint16(b))
r.reader.Discard(2)
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
}
return n
}
func (r *msgReader) readUint32() uint32 {
if r.err != nil {
return 0
}
r.msgBytesRemaining -= 4
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(4)
if err != nil {
r.fatal(err)
return 0
}
n := uint32(binary.BigEndian.Uint32(b))
r.reader.Discard(4)
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
}
return n
}
func (r *msgReader) readInt64() int64 {
if r.err != nil {
return 0
}
r.msgBytesRemaining -= 8
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(8)
if err != nil {
r.fatal(err)
return 0
}
n := int64(binary.BigEndian.Uint64(b))
r.reader.Discard(8)
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
}
return n
}
// readOid reads an Oid, which travels on the wire as a 32-bit integer.
func (r *msgReader) readOid() Oid {
	n := r.readInt32()
	return Oid(n)
}
// readCString reads a null terminated string
func (r *msgReader) readCString() string {
	if r.err != nil {
		return ""
	}

	raw, err := r.reader.ReadBytes(0)
	if err != nil {
		r.fatal(err)
		return ""
	}

	// Account for every byte consumed, including the terminator.
	r.msgBytesRemaining -= int32(len(raw))
	if r.msgBytesRemaining < 0 {
		r.fatal(errors.New("read past end of message"))
		return ""
	}

	s := string(raw[:len(raw)-1]) // drop the trailing null

	if r.shouldLog(LogLevelTrace) {
		r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
	}

	return s
}
// readString reads count bytes and returns as string
func (r *msgReader) readString(countI32 int32) string {
if r.err != nil {
return ""
}
r.msgBytesRemaining -= countI32
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message"))
return ""
}
count := int(countI32)
var s string
if r.reader.Buffered() >= count {
buf, _ := r.reader.Peek(count)
s = string(buf)
r.reader.Discard(count)
} else {
buf := make([]byte, count)
_, err := io.ReadFull(r.reader, buf)
if err != nil {
r.fatal(err)
return ""
}
s = string(buf)
}
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
}
return s
}
// readBytes reads count bytes and returns as []byte
func (r *msgReader) readBytes(count int32) []byte {
if r.err != nil {
return nil
}
r.msgBytesRemaining -= count
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message"))
return nil
}
b := make([]byte, int(count))
_, err := io.ReadFull(r.reader, b)
if err != nil {
r.fatal(err)
return nil
}
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining)
}
return b
}

6
pgio/doc.go Normal file
View File

@ -0,0 +1,6 @@
// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
/*
pgio provides functions for appending integers to a []byte while doing byte
order conversion.
*/
package pgio

40
pgio/write.go Normal file
View File

@ -0,0 +1,40 @@
package pgio
import "encoding/binary"
// AppendUint16 appends n to buf in big-endian byte order and returns buf.
func AppendUint16(buf []byte, n uint16) []byte {
	var b [2]byte
	binary.BigEndian.PutUint16(b[:], n)
	return append(buf, b[:]...)
}
// AppendUint32 appends n to buf in big-endian byte order and returns buf.
func AppendUint32(buf []byte, n uint32) []byte {
	var b [4]byte
	binary.BigEndian.PutUint32(b[:], n)
	return append(buf, b[:]...)
}
// AppendUint64 appends n to buf in big-endian byte order and returns buf.
func AppendUint64(buf []byte, n uint64) []byte {
	var b [8]byte
	binary.BigEndian.PutUint64(b[:], n)
	return append(buf, b[:]...)
}
// AppendInt16 appends n to buf in big-endian byte order and returns buf.
func AppendInt16(buf []byte, n int16) []byte {
	return AppendUint16(buf, uint16(n))
}

// AppendInt32 appends n to buf in big-endian byte order and returns buf.
func AppendInt32(buf []byte, n int32) []byte {
	return AppendUint32(buf, uint32(n))
}

// AppendInt64 appends n to buf in big-endian byte order and returns buf.
func AppendInt64(buf []byte, n int64) []byte {
	return AppendUint64(buf, uint64(n))
}
// SetInt32 overwrites the first 4 bytes of buf with n in big-endian byte
// order. buf must be at least 4 bytes long.
func SetInt32(buf []byte, n int32) {
	binary.BigEndian.PutUint32(buf[:4], uint32(n))
}

78
pgio/write_test.go Normal file
View File

@ -0,0 +1,78 @@
package pgio
import (
"reflect"
"testing"
)
func TestAppendUint16NilBuf(t *testing.T) {
buf := AppendUint16(nil, 1)
if !reflect.DeepEqual(buf, []byte{0, 1}) {
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
}
}
// TestAppendUint16EmptyBuf verifies appending to an empty (but non-nil) buffer.
func TestAppendUint16EmptyBuf(t *testing.T) {
	buf := []byte{}
	buf = AppendUint16(buf, 1)
	if !reflect.DeepEqual(buf, []byte{0, 1}) {
		// Fixed failure message: this case appends to an empty buf, not nil.
		t.Errorf("AppendUint16([]byte{}, 1) => %v, want %v", buf, []byte{0, 1})
	}
}
// TestAppendUint16BufWithCapacityDoesNotAllocate verifies that appending into
// existing capacity writes through to the original backing array rather than
// allocating a new one.
func TestAppendUint16BufWithCapacityDoesNotAllocate(t *testing.T) {
	buf := make([]byte, 0, 4)
	AppendUint16(buf, 1) // result deliberately discarded; we re-slice the original below
	buf = buf[0:2]
	if !reflect.DeepEqual(buf, []byte{0, 1}) {
		// Fixed failure message: this case uses a preallocated buf, not nil.
		t.Errorf("AppendUint16(buf, 1) => %v, want %v", buf, []byte{0, 1})
	}
}
func TestAppendUint32NilBuf(t *testing.T) {
buf := AppendUint32(nil, 1)
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
}
}
// TestAppendUint32EmptyBuf verifies appending to an empty (but non-nil) buffer.
func TestAppendUint32EmptyBuf(t *testing.T) {
	buf := []byte{}
	buf = AppendUint32(buf, 1)
	if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
		// Fixed failure message: this case appends to an empty buf, not nil.
		t.Errorf("AppendUint32([]byte{}, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
	}
}
// TestAppendUint32BufWithCapacityDoesNotAllocate verifies that appending into
// existing capacity writes through to the original backing array rather than
// allocating a new one.
func TestAppendUint32BufWithCapacityDoesNotAllocate(t *testing.T) {
	buf := make([]byte, 0, 4)
	AppendUint32(buf, 1) // result deliberately discarded; we re-slice the original below
	buf = buf[0:4]
	if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
		// Fixed failure message: this case uses a preallocated buf, not nil.
		t.Errorf("AppendUint32(buf, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
	}
}
func TestAppendUint64NilBuf(t *testing.T) {
buf := AppendUint64(nil, 1)
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
}
}
// TestAppendUint64EmptyBuf verifies appending to an empty (but non-nil) buffer.
func TestAppendUint64EmptyBuf(t *testing.T) {
	buf := []byte{}
	buf = AppendUint64(buf, 1)
	if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
		// Fixed failure message: this case appends to an empty buf, not nil.
		t.Errorf("AppendUint64([]byte{}, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
	}
}
// TestAppendUint64BufWithCapacityDoesNotAllocate verifies that appending into
// existing capacity writes through to the original backing array rather than
// allocating a new one.
func TestAppendUint64BufWithCapacityDoesNotAllocate(t *testing.T) {
	buf := make([]byte, 0, 8)
	AppendUint64(buf, 1) // result deliberately discarded; we re-slice the original below
	buf = buf[0:8]
	if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
		// Fixed failure message: this case uses a preallocated buf, not nil.
		t.Errorf("AppendUint64(buf, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
	}
}

501
pgmock/pgmock.go Normal file
View File

@ -0,0 +1,501 @@
package pgmock
import (
"io"
"net"
"reflect"
"github.com/pkg/errors"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
type Server struct {
ln net.Listener
controller Controller
}
// NewServer creates a Server listening on an OS-assigned port on 127.0.0.1
// whose connections are handled by controller.
func NewServer(controller Controller) (*Server, error) {
	ln, err := net.Listen("tcp", "127.0.0.1:")
	if err != nil {
		return nil, err
	}

	return &Server{ln: ln, controller: controller}, nil
}
// Addr returns the address the server's listener is bound to, so callers can
// connect to the OS-assigned port chosen in NewServer.
func (s *Server) Addr() net.Addr {
	return s.ln.Addr()
}
// ServeOne accepts a single connection, stops listening for further
// connections, and serves the accepted connection with the controller.
func (s *Server) ServeOne() error {
	conn, err := s.ln.Accept()
	if err != nil {
		return err
	}
	defer conn.Close()
	s.Close() // only one connection is served; stop accepting more

	backend, err := pgproto3.NewBackend(conn, conn)
	if err != nil {
		// conn is closed by the deferred Close above; the previous explicit
		// conn.Close() here caused a redundant double close.
		return err
	}

	return s.controller.Serve(backend)
}
// Close stops the server from accepting further connections.
func (s *Server) Close() error {
	// Return the listener's error directly; the previous
	// if err != nil { return err }; return nil was redundant.
	return s.ln.Close()
}
type Controller interface {
Serve(backend *pgproto3.Backend) error
}
type Step interface {
Step(*pgproto3.Backend) error
}
// Script is a Controller (and a Step) that runs a fixed sequence of steps in
// order, stopping at the first error.
type Script struct {
	Steps []Step
}

// Run executes each step in order. It is equivalent to Serve; previously both
// methods carried duplicate loop bodies.
func (s *Script) Run(backend *pgproto3.Backend) error {
	return s.Serve(backend)
}

// Serve implements Controller by executing each step in order.
func (s *Script) Serve(backend *pgproto3.Backend) error {
	for _, step := range s.Steps {
		if err := step.Step(backend); err != nil {
			return err
		}
	}
	return nil
}

// Step implements Step so a Script can be nested inside another Script.
func (s *Script) Step(backend *pgproto3.Backend) error {
	return s.Serve(backend)
}
// expectMessageStep is a Step that fails unless the next frontend message
// matches want — by deep equality, or by type alone when any is set.
type expectMessageStep struct {
	want pgproto3.FrontendMessage
	any  bool
}

// Step receives the next frontend message and compares it against want.
func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
	msg, err := backend.Receive()
	if err != nil {
		return err
	}

	if e.any {
		if reflect.TypeOf(msg) == reflect.TypeOf(e.want) {
			return nil
		}
	}

	if reflect.DeepEqual(msg, e.want) {
		return nil
	}

	return errors.Errorf("msg => %#v, e.want => %#v", msg, e.want)
}
// expectStartupMessageStep is a Step that reads the startup message and fails
// unless it matches want; when any is set, receiving any startup message at
// all is sufficient.
type expectStartupMessageStep struct {
	want *pgproto3.StartupMessage
	any  bool
}

// Step receives the startup message and compares it against want.
func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
	msg, err := backend.ReceiveStartupMessage()
	if err != nil {
		return err
	}

	if e.any {
		return nil
	}

	if reflect.DeepEqual(msg, e.want) {
		return nil
	}

	return errors.Errorf("msg => %#v, e.want => %#v", msg, e.want)
}
// ExpectMessage returns a Step that fails unless the next message received
// is deeply equal to want.
func ExpectMessage(want pgproto3.FrontendMessage) Step {
	return expectMessage(want, false)
}

// ExpectAnyMessage returns a Step that only requires the next message to have
// the same type as want.
func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
	return expectMessage(want, true)
}

// expectMessage chooses the startup-specific step when want is a
// StartupMessage, since that message is received differently from all others.
func expectMessage(want pgproto3.FrontendMessage, any bool) Step {
	if sm, ok := want.(*pgproto3.StartupMessage); ok {
		return &expectStartupMessageStep{want: sm, any: any}
	}

	return &expectMessageStep{want: want, any: any}
}
// sendMessageStep is a Step that writes a single backend message to the frontend.
type sendMessageStep struct {
	msg pgproto3.BackendMessage
}

// Step sends the configured message to the connected frontend.
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
	return backend.Send(e.msg)
}

// SendMessage returns a Step that sends msg to the frontend.
func SendMessage(msg pgproto3.BackendMessage) Step {
	return &sendMessageStep{msg: msg}
}
// waitForCloseMessageStep is a Step that drains messages until the client
// ends the session.
type waitForCloseMessageStep struct{}

// Step consumes messages until the client sends Terminate or closes the
// connection (io.EOF); any other receive error is returned.
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
	for {
		msg, err := backend.Receive()
		switch {
		case err == io.EOF:
			return nil
		case err != nil:
			return err
		}

		if _, ok := msg.(*pgproto3.Terminate); ok {
			return nil
		}
	}
}

// WaitForClose returns a Step that waits for the client to end the session.
func WaitForClose() Step {
	return &waitForCloseMessageStep{}
}
func AcceptUnauthenticatedConnRequestSteps() []Step {
return []Step{
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
SendMessage(&pgproto3.Authentication{Type: pgproto3.AuthTypeOk}),
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
}
}
func PgxInitSteps() []Step {
steps := []Step{
ExpectMessage(&pgproto3.Parse{
Query: "select t.oid, t.typname\nfrom pg_type t\nleft join pg_type base_type on t.typelem=base_type.oid\nwhere (\n\t t.typtype in('b', 'p', 'r', 'e')\n\t and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))\n\t)",
}),
ExpectMessage(&pgproto3.Describe{
ObjectType: 'S',
}),
ExpectMessage(&pgproto3.Sync{}),
SendMessage(&pgproto3.ParseComplete{}),
SendMessage(&pgproto3.ParameterDescription{}),
SendMessage(&pgproto3.RowDescription{
Fields: []pgproto3.FieldDescription{
{Name: "oid",
TableOID: 1247,
TableAttributeNumber: 65534,
DataTypeOID: 26,
DataTypeSize: 4,
TypeModifier: 4294967295,
Format: 0,
},
{Name: "typname",
TableOID: 1247,
TableAttributeNumber: 1,
DataTypeOID: 19,
DataTypeSize: 64,
TypeModifier: 4294967295,
Format: 0,
},
},
}),
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
ExpectMessage(&pgproto3.Bind{
ResultFormatCodes: []int16{1, 1},
}),
ExpectMessage(&pgproto3.Execute{}),
ExpectMessage(&pgproto3.Sync{}),
SendMessage(&pgproto3.BindComplete{}),
}
rowVals := []struct {
oid pgtype.OID
name string
}{
{16, "bool"},
{17, "bytea"},
{18, "char"},
{19, "name"},
{20, "int8"},
{21, "int2"},
{22, "int2vector"},
{23, "int4"},
{24, "regproc"},
{25, "text"},
{26, "oid"},
{27, "tid"},
{28, "xid"},
{29, "cid"},
{30, "oidvector"},
{114, "json"},
{142, "xml"},
{143, "_xml"},
{199, "_json"},
{194, "pg_node_tree"},
{32, "pg_ddl_command"},
{210, "smgr"},
{600, "point"},
{601, "lseg"},
{602, "path"},
{603, "box"},
{604, "polygon"},
{628, "line"},
{629, "_line"},
{700, "float4"},
{701, "float8"},
{702, "abstime"},
{703, "reltime"},
{704, "tinterval"},
{705, "unknown"},
{718, "circle"},
{719, "_circle"},
{790, "money"},
{791, "_money"},
{829, "macaddr"},
{869, "inet"},
{650, "cidr"},
{1000, "_bool"},
{1001, "_bytea"},
{1002, "_char"},
{1003, "_name"},
{1005, "_int2"},
{1006, "_int2vector"},
{1007, "_int4"},
{1008, "_regproc"},
{1009, "_text"},
{1028, "_oid"},
{1010, "_tid"},
{1011, "_xid"},
{1012, "_cid"},
{1013, "_oidvector"},
{1014, "_bpchar"},
{1015, "_varchar"},
{1016, "_int8"},
{1017, "_point"},
{1018, "_lseg"},
{1019, "_path"},
{1020, "_box"},
{1021, "_float4"},
{1022, "_float8"},
{1023, "_abstime"},
{1024, "_reltime"},
{1025, "_tinterval"},
{1027, "_polygon"},
{1033, "aclitem"},
{1034, "_aclitem"},
{1040, "_macaddr"},
{1041, "_inet"},
{651, "_cidr"},
{1263, "_cstring"},
{1042, "bpchar"},
{1043, "varchar"},
{1082, "date"},
{1083, "time"},
{1114, "timestamp"},
{1115, "_timestamp"},
{1182, "_date"},
{1183, "_time"},
{1184, "timestamptz"},
{1185, "_timestamptz"},
{1186, "interval"},
{1187, "_interval"},
{1231, "_numeric"},
{1266, "timetz"},
{1270, "_timetz"},
{1560, "bit"},
{1561, "_bit"},
{1562, "varbit"},
{1563, "_varbit"},
{1700, "numeric"},
{1790, "refcursor"},
{2201, "_refcursor"},
{2202, "regprocedure"},
{2203, "regoper"},
{2204, "regoperator"},
{2205, "regclass"},
{2206, "regtype"},
{4096, "regrole"},
{4089, "regnamespace"},
{2207, "_regprocedure"},
{2208, "_regoper"},
{2209, "_regoperator"},
{2210, "_regclass"},
{2211, "_regtype"},
{4097, "_regrole"},
{4090, "_regnamespace"},
{2950, "uuid"},
{2951, "_uuid"},
{3220, "pg_lsn"},
{3221, "_pg_lsn"},
{3614, "tsvector"},
{3642, "gtsvector"},
{3615, "tsquery"},
{3734, "regconfig"},
{3769, "regdictionary"},
{3643, "_tsvector"},
{3644, "_gtsvector"},
{3645, "_tsquery"},
{3735, "_regconfig"},
{3770, "_regdictionary"},
{3802, "jsonb"},
{3807, "_jsonb"},
{2970, "txid_snapshot"},
{2949, "_txid_snapshot"},
{3904, "int4range"},
{3905, "_int4range"},
{3906, "numrange"},
{3907, "_numrange"},
{3908, "tsrange"},
{3909, "_tsrange"},
{3910, "tstzrange"},
{3911, "_tstzrange"},
{3912, "daterange"},
{3913, "_daterange"},
{3926, "int8range"},
{3927, "_int8range"},
{2249, "record"},
{2287, "_record"},
{2275, "cstring"},
{2276, "any"},
{2277, "anyarray"},
{2278, "void"},
{2279, "trigger"},
{3838, "event_trigger"},
{2280, "language_handler"},
{2281, "internal"},
{2282, "opaque"},
{2283, "anyelement"},
{2776, "anynonarray"},
{3500, "anyenum"},
{3115, "fdw_handler"},
{325, "index_am_handler"},
{3310, "tsm_handler"},
{3831, "anyrange"},
{51367, "gbtreekey4"},
{51370, "_gbtreekey4"},
{51371, "gbtreekey8"},
{51374, "_gbtreekey8"},
{51375, "gbtreekey16"},
{51378, "_gbtreekey16"},
{51379, "gbtreekey32"},
{51382, "_gbtreekey32"},
{51383, "gbtreekey_var"},
{51386, "_gbtreekey_var"},
{51921, "hstore"},
{51926, "_hstore"},
{52005, "ghstore"},
{52008, "_ghstore"},
}
for _, rv := range rowVals {
step := SendMessage(mustBuildDataRow([]interface{}{rv.oid, rv.name}, []int16{pgproto3.BinaryFormat}))
steps = append(steps, step)
}
steps = append(steps, SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT 163"}))
steps = append(steps, SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
return steps
}
type dataRowValue struct {
Value interface{}
FormatCode int16
}
// mustBuildDataRow is like buildDataRow but panics if the values cannot be
// encoded. Intended for constructing fixed mock scripts.
func mustBuildDataRow(values []interface{}, formatCodes []int16) *pgproto3.DataRow {
	dr, err := buildDataRow(values, formatCodes)
	if err != nil {
		panic(err)
	}

	return dr
}
// buildDataRow encodes values into a DataRow message. formatCodes must have
// either one element (applied to every value) or one element per value.
func buildDataRow(values []interface{}, formatCodes []int16) (*pgproto3.DataRow, error) {
	dr := &pgproto3.DataRow{
		Values: make([][]byte, len(values)),
	}

	// A single format code applies to all values.
	if len(formatCodes) == 1 {
		for i := 1; i < len(values); i++ {
			formatCodes = append(formatCodes, formatCodes[0])
		}
	}

	// Promote bare Go values to pgtype values that can encode themselves.
	// NOTE(review): this mutates the caller's values slice in place.
	for i := range values {
		switch v := values[i].(type) {
		case string:
			values[i] = &pgtype.Text{String: v, Status: pgtype.Present}
		case int16:
			values[i] = &pgtype.Int2{Int: v, Status: pgtype.Present}
		case int32:
			values[i] = &pgtype.Int4{Int: v, Status: pgtype.Present}
		case int64:
			values[i] = &pgtype.Int8{Int: v, Status: pgtype.Present}
		}
	}

	for i := range values {
		switch formatCodes[i] {
		case pgproto3.TextFormat:
			e, ok := values[i].(pgtype.TextEncoder)
			if !ok {
				// Fixed typo in the error message: "TextExcoder" -> "TextEncoder".
				return nil, errors.Errorf("values[%d] does not implement TextEncoder", i)
			}
			buf, err := e.EncodeText(nil, nil)
			if err != nil {
				return nil, errors.Errorf("failed to encode values[%d]", i)
			}
			dr.Values[i] = buf
		case pgproto3.BinaryFormat:
			e, ok := values[i].(pgtype.BinaryEncoder)
			if !ok {
				return nil, errors.Errorf("values[%d] does not implement BinaryEncoder", i)
			}
			buf, err := e.EncodeBinary(nil, nil)
			if err != nil {
				return nil, errors.Errorf("failed to encode values[%d]", i)
			}
			dr.Values[i] = buf
		default:
			return nil, errors.New("unknown FormatCode")
		}
	}

	return dr, nil
}

View File

@ -0,0 +1,54 @@
package pgproto3
import (
"encoding/binary"
"github.com/jackc/pgx/pgio"
"github.com/pkg/errors"
)
// Authentication type codes handled by this package (values from the
// PostgreSQL frontend/backend protocol).
const (
	AuthTypeOk                = 0
	AuthTypeCleartextPassword = 3
	AuthTypeMD5Password       = 5
)

// Authentication is the backend message that tells the frontend which
// authentication it requires, or that authentication succeeded.
type Authentication struct {
	Type uint32

	// MD5Password fields
	Salt [4]byte // populated only when Type == AuthTypeMD5Password (see Decode)
}

// Backend identifies this message as one sent by the backend.
func (*Authentication) Backend() {}

// Decode decodes src into dst. src must contain only the message body (the
// bytes after the 'R' type byte and the 4-byte length).
func (dst *Authentication) Decode(src []byte) error {
	*dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])}

	switch dst.Type {
	case AuthTypeOk:
	case AuthTypeCleartextPassword:
	case AuthTypeMD5Password:
		copy(dst.Salt[:], src[4:8])
	default:
		return errors.Errorf("unknown authentication type: %d", dst.Type)
	}

	return nil
}

// Encode appends the wire representation of src to dst and returns it.
func (src *Authentication) Encode(dst []byte) []byte {
	dst = append(dst, 'R')
	sp := len(dst)
	dst = pgio.AppendInt32(dst, -1) // length placeholder, backfilled below
	dst = pgio.AppendUint32(dst, src.Type)

	switch src.Type {
	case AuthTypeMD5Password:
		dst = append(dst, src.Salt[:]...)
	}

	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

	return dst
}

101
pgproto3/backend.go Normal file
View File

@ -0,0 +1,101 @@
package pgproto3
import (
"encoding/binary"
"io"
"github.com/jackc/pgx/chunkreader"
"github.com/pkg/errors"
)
type Backend struct {
cr *chunkreader.ChunkReader
w io.Writer
// Frontend message flyweights
bind Bind
_close Close
describe Describe
execute Execute
flush Flush
parse Parse
passwordMessage PasswordMessage
query Query
startupMessage StartupMessage
sync Sync
terminate Terminate
}
// NewBackend creates a Backend that reads frontend messages from r and writes
// backend messages to w.
func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
	return &Backend{cr: chunkreader.NewChunkReader(r), w: w}, nil
}
func (b *Backend) Send(msg BackendMessage) error {
_, err := b.w.Write(msg.Encode(nil))
return err
}
// ReceiveStartupMessage receives the initial connection message. It is
// separate from Receive because the startup message has no leading type byte.
// The returned message is a flyweight owned by the Backend.
func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
	lenBytes, err := b.cr.Next(4)
	if err != nil {
		return nil, err
	}

	// The wire length includes the 4 length bytes themselves.
	body, err := b.cr.Next(int(binary.BigEndian.Uint32(lenBytes) - 4))
	if err != nil {
		return nil, err
	}

	if err := b.startupMessage.Decode(body); err != nil {
		return nil, err
	}

	return &b.startupMessage, nil
}
func (b *Backend) Receive() (FrontendMessage, error) {
header, err := b.cr.Next(5)
if err != nil {
return nil, err
}
msgType := header[0]
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
var msg FrontendMessage
switch msgType {
case 'B':
msg = &b.bind
case 'C':
msg = &b._close
case 'D':
msg = &b.describe
case 'E':
msg = &b.execute
case 'H':
msg = &b.flush
case 'P':
msg = &b.parse
case 'p':
msg = &b.passwordMessage
case 'Q':
msg = &b.query
case 'S':
msg = &b.sync
case 'X':
msg = &b.terminate
default:
return nil, errors.Errorf("unknown message type: %c", msgType)
}
msgBody, err := b.cr.Next(bodyLen)
if err != nil {
return nil, err
}
err = msg.Decode(msgBody)
return msg, err
}

View File

@ -0,0 +1,46 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// BackendKeyData conveys the process ID and secret key that identify this
// session for cancellation requests.
type BackendKeyData struct {
	ProcessID uint32
	SecretKey uint32
}

// Backend identifies this message as one sent by the backend.
func (*BackendKeyData) Backend() {}

// Decode decodes src into dst. The body is exactly two 32-bit integers.
func (dst *BackendKeyData) Decode(src []byte) error {
	if len(src) != 8 {
		return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
	}

	dst.ProcessID = binary.BigEndian.Uint32(src)
	dst.SecretKey = binary.BigEndian.Uint32(src[4:8])

	return nil
}

// Encode appends the wire representation of src to dst and returns it.
func (src *BackendKeyData) Encode(dst []byte) []byte {
	dst = append(dst, 'K')
	dst = pgio.AppendUint32(dst, 12) // fixed length: 4 (len) + 4 (PID) + 4 (key)
	dst = pgio.AppendUint32(dst, src.ProcessID)
	return pgio.AppendUint32(dst, src.SecretKey)
}

// MarshalJSON renders the message for the pgmock/debugging JSON format.
func (src *BackendKeyData) MarshalJSON() ([]byte, error) {
	msg := struct {
		Type      string
		ProcessID uint32
		SecretKey uint32
	}{"BackendKeyData", src.ProcessID, src.SecretKey}
	return json.Marshal(msg)
}

37
pgproto3/big_endian.go Normal file
View File

@ -0,0 +1,37 @@
package pgproto3
import (
"encoding/binary"
)
// BigEndianBuf is a reusable scratch buffer for encoding fixed-size integers
// in big-endian byte order. Each method returns a slice of a copy of the
// receiver's storage holding the encoded value.
type BigEndianBuf [8]byte

// Int16 returns n encoded as 2 big-endian bytes.
func (b BigEndianBuf) Int16(n int16) []byte {
	return b.Uint16(uint16(n))
}

// Uint16 returns n encoded as 2 big-endian bytes.
func (b BigEndianBuf) Uint16(n uint16) []byte {
	s := b[:2]
	binary.BigEndian.PutUint16(s, n)
	return s
}

// Int32 returns n encoded as 4 big-endian bytes.
func (b BigEndianBuf) Int32(n int32) []byte {
	return b.Uint32(uint32(n))
}

// Uint32 returns n encoded as 4 big-endian bytes.
func (b BigEndianBuf) Uint32(n uint32) []byte {
	s := b[:4]
	binary.BigEndian.PutUint32(s, n)
	return s
}

// Int64 returns n encoded as 8 big-endian bytes.
func (b BigEndianBuf) Int64(n int64) []byte {
	s := b[:8]
	binary.BigEndian.PutUint64(s, uint64(n))
	return s
}

171
pgproto3/bind.go Normal file
View File

@ -0,0 +1,171 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/hex"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// Bind is the frontend message that binds parameter values to a prepared
// statement, producing a portal ready for execution.
type Bind struct {
	DestinationPortal    string
	PreparedStatement    string
	ParameterFormatCodes []int16  // 0 = text, 1 = binary (see MarshalJSON)
	Parameters           [][]byte // a nil element is transmitted as length -1 (NULL)
	ResultFormatCodes    []int16
}

// Frontend identifies this message as one sent by the frontend.
func (*Bind) Frontend() {}

// Decode decodes src into dst. src must contain only the message body. rp is
// the read position, advanced as each field is consumed; every read is
// preceded by a bounds check so malformed input yields an error, not a panic.
func (dst *Bind) Decode(src []byte) error {
	*dst = Bind{}

	// Null-terminated destination portal name.
	idx := bytes.IndexByte(src, 0)
	if idx < 0 {
		return &invalidMessageFormatErr{messageType: "Bind"}
	}
	dst.DestinationPortal = string(src[:idx])
	rp := idx + 1

	// Null-terminated prepared statement name.
	idx = bytes.IndexByte(src[rp:], 0)
	if idx < 0 {
		return &invalidMessageFormatErr{messageType: "Bind"}
	}
	dst.PreparedStatement = string(src[rp : rp+idx])
	rp += idx + 1

	// Parameter format codes, preceded by their 16-bit count.
	if len(src[rp:]) < 2 {
		return &invalidMessageFormatErr{messageType: "Bind"}
	}
	parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
	rp += 2

	if parameterFormatCodeCount > 0 {
		dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)

		if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
			return &invalidMessageFormatErr{messageType: "Bind"}
		}
		for i := 0; i < parameterFormatCodeCount; i++ {
			dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
			rp += 2
		}
	}

	// Parameter values: 16-bit count, then for each a 32-bit length followed
	// by that many bytes. Length -1 marks NULL (element left nil).
	if len(src[rp:]) < 2 {
		return &invalidMessageFormatErr{messageType: "Bind"}
	}
	parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
	rp += 2

	if parameterCount > 0 {
		dst.Parameters = make([][]byte, parameterCount)

		for i := 0; i < parameterCount; i++ {
			if len(src[rp:]) < 4 {
				return &invalidMessageFormatErr{messageType: "Bind"}
			}

			msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
			rp += 4

			// null
			if msgSize == -1 {
				continue
			}

			if len(src[rp:]) < msgSize {
				return &invalidMessageFormatErr{messageType: "Bind"}
			}

			// Note: the element aliases src rather than copying it.
			dst.Parameters[i] = src[rp : rp+msgSize]
			rp += msgSize
		}
	}

	// Result column format codes, preceded by their 16-bit count.
	if len(src[rp:]) < 2 {
		return &invalidMessageFormatErr{messageType: "Bind"}
	}
	resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
	rp += 2

	dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
	if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
		return &invalidMessageFormatErr{messageType: "Bind"}
	}
	for i := 0; i < resultFormatCodeCount; i++ {
		dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
		rp += 2
	}

	return nil
}
// Encode appends the wire representation of src to dst and returns it.
func (src *Bind) Encode(dst []byte) []byte {
	dst = append(dst, 'B')
	lenPos := len(dst)
	dst = pgio.AppendInt32(dst, -1) // length placeholder, backfilled below

	dst = append(dst, src.DestinationPortal...)
	dst = append(dst, 0)
	dst = append(dst, src.PreparedStatement...)
	dst = append(dst, 0)

	dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
	for _, code := range src.ParameterFormatCodes {
		dst = pgio.AppendInt16(dst, code)
	}

	dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
	for _, param := range src.Parameters {
		if param == nil {
			// NULL is transmitted as length -1 with no value bytes.
			dst = pgio.AppendInt32(dst, -1)
			continue
		}

		dst = pgio.AppendInt32(dst, int32(len(param)))
		dst = append(dst, param...)
	}

	dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
	for _, code := range src.ResultFormatCodes {
		dst = pgio.AppendInt16(dst, code)
	}

	pgio.SetInt32(dst[lenPos:], int32(len(dst[lenPos:])))

	return dst
}
func (src *Bind) MarshalJSON() ([]byte, error) {
formattedParameters := make([]map[string]string, len(src.Parameters))
for i, p := range src.Parameters {
if p == nil {
continue
}
if src.ParameterFormatCodes[i] == 0 {
formattedParameters[i] = map[string]string{"text": string(p)}
} else {
formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
}
}
return json.Marshal(struct {
Type string
DestinationPortal string
PreparedStatement string
ParameterFormatCodes []int16
Parameters []map[string]string
ResultFormatCodes []int16
}{
Type: "Bind",
DestinationPortal: src.DestinationPortal,
PreparedStatement: src.PreparedStatement,
ParameterFormatCodes: src.ParameterFormatCodes,
Parameters: formattedParameters,
ResultFormatCodes: src.ResultFormatCodes,
})
}

29
pgproto3/bind_complete.go Normal file
View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
// BindComplete indicates a Bind message completed successfully.
type BindComplete struct{}

// Backend identifies this message as one sent by the backend.
func (*BindComplete) Backend() {}

// Decode decodes src into dst. BindComplete has an empty body.
func (dst *BindComplete) Decode(src []byte) error {
	if n := len(src); n != 0 {
		return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: n}
	}
	return nil
}

// Encode appends the fixed 5-byte wire form ('2' plus length 4) to dst.
func (src *BindComplete) Encode(dst []byte) []byte {
	return append(dst, '2', 0, 0, 0, 4)
}

// MarshalJSON renders the message for the pgmock/debugging JSON format.
func (src *BindComplete) MarshalJSON() ([]byte, error) {
	msg := struct{ Type string }{Type: "BindComplete"}
	return json.Marshal(msg)
}

59
pgproto3/close.go Normal file
View File

@ -0,0 +1,59 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
type Close struct {
ObjectType byte // 'S' = prepared statement, 'P' = portal
Name string
}
func (*Close) Frontend() {}
func (dst *Close) Decode(src []byte) error {
if len(src) < 2 {
return &invalidMessageFormatErr{messageType: "Close"}
}
dst.ObjectType = src[0]
rp := 1
idx := bytes.IndexByte(src[rp:], 0)
if idx != len(src[rp:])-1 {
return &invalidMessageFormatErr{messageType: "Close"}
}
dst.Name = string(src[rp : len(src)-1])
return nil
}
func (src *Close) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
func (src *Close) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ObjectType string
Name string
}{
Type: "Close",
ObjectType: string(src.ObjectType),
Name: src.Name,
})
}

View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
// CloseComplete is the backend CloseComplete message (type byte '3'); its
// body is empty.
type CloseComplete struct{}

// Backend identifies CloseComplete as a message sent by the backend.
func (*CloseComplete) Backend() {}

// Decode verifies that the message body in src is empty.
func (dst *CloseComplete) Decode(src []byte) error {
	if len(src) == 0 {
		return nil
	}
	return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
}

// Encode appends the wire form of the message to dst.
func (src *CloseComplete) Encode(dst []byte) []byte {
	wire := [5]byte{'3', 0, 0, 0, 4}
	return append(dst, wire[:]...)
}

// MarshalJSON implements json.Marshaler.
func (src *CloseComplete) MarshalJSON() ([]byte, error) {
	var v struct {
		Type string
	}
	v.Type = "CloseComplete"
	return json.Marshal(v)
}

View File

@ -0,0 +1,48 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// CommandComplete is the backend CommandComplete message (type byte 'C')
// reporting the tag of a completed SQL command (e.g. "SELECT 5").
type CommandComplete struct {
	CommandTag string
}

// Backend identifies CommandComplete as a message sent by the backend.
func (*CommandComplete) Backend() {}

// Decode parses the message body in src: a single NUL-terminated tag string.
func (dst *CommandComplete) Decode(src []byte) error {
	idx := bytes.IndexByte(src, 0)
	// BUG FIX: an empty src yields idx == -1 == len(src)-1, which previously
	// slipped past the check and panicked on src[:-1]. Reject idx < 0
	// explicitly.
	if idx < 0 || idx != len(src)-1 {
		return &invalidMessageFormatErr{messageType: "CommandComplete"}
	}

	dst.CommandTag = string(src[:idx])

	return nil
}

// Encode appends the wire form of the message to dst.
func (src *CommandComplete) Encode(dst []byte) []byte {
	dst = append(dst, 'C')
	sp := len(dst)
	dst = pgio.AppendInt32(dst, -1) // length placeholder, patched below

	dst = append(dst, src.CommandTag...)
	dst = append(dst, 0)

	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

	return dst
}

// MarshalJSON implements json.Marshaler.
func (src *CommandComplete) MarshalJSON() ([]byte, error) {
	return json.Marshal(struct {
		Type       string
		CommandTag string
	}{
		Type:       "CommandComplete",
		CommandTag: src.CommandTag,
	})
}

View File

@ -0,0 +1,65 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// CopyBothResponse is the backend CopyBothResponse message (type byte 'W'),
// beginning copy-in/copy-out mode (used by streaming replication).
type CopyBothResponse struct {
	OverallFormat     byte     // 0 = textual, 1 = binary
	ColumnFormatCodes []uint16 // per-column format codes
}

// Backend identifies CopyBothResponse as a message sent by the backend.
func (*CopyBothResponse) Backend() {}

// Decode parses the message body in src.
func (dst *CopyBothResponse) Decode(src []byte) error {
	buf := bytes.NewBuffer(src)

	// Minimum body: format byte (1) + column count (2).
	if buf.Len() < 3 {
		return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
	}

	overallFormat := buf.Next(1)[0]

	columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
	if buf.Len() != columnCount*2 {
		return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
	}
	columnFormatCodes := make([]uint16, columnCount)
	for i := 0; i < columnCount; i++ {
		columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
	}

	*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}

	return nil
}

// Encode appends the wire form of the message to dst.
func (src *CopyBothResponse) Encode(dst []byte) []byte {
	dst = append(dst, 'W')
	sp := len(dst)
	dst = pgio.AppendInt32(dst, -1) // length placeholder, patched below

	// BUG FIX: the overall format byte must precede the column count;
	// previously it was omitted, so Encode produced a body Decode rejects.
	dst = append(dst, src.OverallFormat)

	dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
	for _, fc := range src.ColumnFormatCodes {
		dst = pgio.AppendUint16(dst, fc)
	}

	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

	return dst
}

// MarshalJSON implements json.Marshaler.
func (src *CopyBothResponse) MarshalJSON() ([]byte, error) {
	return json.Marshal(struct {
		Type              string
		ColumnFormatCodes []uint16
	}{
		Type:              "CopyBothResponse",
		ColumnFormatCodes: src.ColumnFormatCodes,
	})
}

37
pgproto3/copy_data.go Normal file
View File

@ -0,0 +1,37 @@
package pgproto3
import (
"encoding/hex"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// CopyData is the CopyData message (type byte 'd'), sent in both directions
// to carry a chunk of COPY payload.
type CopyData struct {
	Data []byte
}

// Backend identifies CopyData as a message sent by the backend.
func (*CopyData) Backend() {}

// Frontend identifies CopyData as a message sent by the frontend.
func (*CopyData) Frontend() {}

// Decode parses the message body in src. Data retains a reference into src
// (permitted by the Message contract).
func (dst *CopyData) Decode(src []byte) error {
	dst.Data = src
	return nil
}

// Encode appends the wire form of the message to dst.
func (src *CopyData) Encode(dst []byte) []byte {
	dst = append(dst, 'd')
	// Body length: self (4) + payload.
	dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
	return append(dst, src.Data...)
}

// MarshalJSON implements json.Marshaler. The payload is hex-encoded since it
// may be arbitrary binary.
func (src *CopyData) MarshalJSON() ([]byte, error) {
	payload := struct {
		Type string
		Data string
	}{
		Type: "CopyData",
		Data: hex.EncodeToString(src.Data),
	}
	return json.Marshal(payload)
}

View File

@ -0,0 +1,65 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// CopyInResponse is the backend CopyInResponse message (type byte 'G'),
// beginning copy-in mode.
type CopyInResponse struct {
	OverallFormat     byte     // 0 = textual, 1 = binary
	ColumnFormatCodes []uint16 // per-column format codes
}

// Backend identifies CopyInResponse as a message sent by the backend.
func (*CopyInResponse) Backend() {}

// Decode parses the message body in src.
func (dst *CopyInResponse) Decode(src []byte) error {
	buf := bytes.NewBuffer(src)

	// Minimum body: format byte (1) + column count (2).
	if buf.Len() < 3 {
		return &invalidMessageFormatErr{messageType: "CopyInResponse"}
	}

	overallFormat := buf.Next(1)[0]

	columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
	if buf.Len() != columnCount*2 {
		return &invalidMessageFormatErr{messageType: "CopyInResponse"}
	}
	columnFormatCodes := make([]uint16, columnCount)
	for i := 0; i < columnCount; i++ {
		columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
	}

	*dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}

	return nil
}

// Encode appends the wire form of the message to dst.
func (src *CopyInResponse) Encode(dst []byte) []byte {
	dst = append(dst, 'G')
	sp := len(dst)
	dst = pgio.AppendInt32(dst, -1) // length placeholder, patched below

	// BUG FIX: the overall format byte must precede the column count;
	// previously it was omitted, so Encode produced a body Decode rejects.
	dst = append(dst, src.OverallFormat)

	dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
	for _, fc := range src.ColumnFormatCodes {
		dst = pgio.AppendUint16(dst, fc)
	}

	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

	return dst
}

// MarshalJSON implements json.Marshaler.
func (src *CopyInResponse) MarshalJSON() ([]byte, error) {
	return json.Marshal(struct {
		Type              string
		ColumnFormatCodes []uint16
	}{
		Type:              "CopyInResponse",
		ColumnFormatCodes: src.ColumnFormatCodes,
	})
}

View File

@ -0,0 +1,65 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// CopyOutResponse is the backend CopyOutResponse message (type byte 'H'),
// beginning copy-out mode.
type CopyOutResponse struct {
	OverallFormat     byte     // 0 = textual, 1 = binary
	ColumnFormatCodes []uint16 // per-column format codes
}

// Backend identifies CopyOutResponse as a message sent by the backend.
func (*CopyOutResponse) Backend() {}

// Decode parses the message body in src.
func (dst *CopyOutResponse) Decode(src []byte) error {
	buf := bytes.NewBuffer(src)

	// Minimum body: format byte (1) + column count (2).
	if buf.Len() < 3 {
		return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
	}

	overallFormat := buf.Next(1)[0]

	columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
	if buf.Len() != columnCount*2 {
		return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
	}
	columnFormatCodes := make([]uint16, columnCount)
	for i := 0; i < columnCount; i++ {
		columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
	}

	*dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}

	return nil
}

// Encode appends the wire form of the message to dst.
func (src *CopyOutResponse) Encode(dst []byte) []byte {
	dst = append(dst, 'H')
	sp := len(dst)
	dst = pgio.AppendInt32(dst, -1) // length placeholder, patched below

	// BUG FIX: the overall format byte must precede the column count;
	// previously it was omitted, so Encode produced a body Decode rejects.
	dst = append(dst, src.OverallFormat)

	dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
	for _, fc := range src.ColumnFormatCodes {
		dst = pgio.AppendUint16(dst, fc)
	}

	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

	return dst
}

// MarshalJSON implements json.Marshaler.
func (src *CopyOutResponse) MarshalJSON() ([]byte, error) {
	return json.Marshal(struct {
		Type              string
		ColumnFormatCodes []uint16
	}{
		Type:              "CopyOutResponse",
		ColumnFormatCodes: src.ColumnFormatCodes,
	})
}

112
pgproto3/data_row.go Normal file
View File

@ -0,0 +1,112 @@
package pgproto3
import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// DataRow is the backend DataRow message (type byte 'D') carrying the column
// values of one result row.
type DataRow struct {
	// Values holds one entry per column; a nil entry marks SQL NULL. Entries
	// reference the buffer passed to Decode.
	Values [][]byte
}

// Backend identifies DataRow as a message sent by the backend.
func (*DataRow) Backend() {}

// Decode parses the message body in src: a 16-bit column count followed by
// (int32 length, bytes) pairs, where length -1 denotes NULL. The decoded
// Values retain references into src (permitted by the Message contract).
func (dst *DataRow) Decode(src []byte) error {
	if len(src) < 2 {
		return &invalidMessageFormatErr{messageType: "DataRow"}
	}
	rp := 0
	fieldCount := int(binary.BigEndian.Uint16(src[rp:]))
	rp += 2
	// If the capacity of the values slice is too small OR substantially too
	// large, reallocate. This is to avoid one row with many columns
	// permanently pinning a large allocation; otherwise the existing backing
	// array is reused across rows.
	if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 {
		newCap := 32
		if newCap < fieldCount {
			newCap = fieldCount
		}
		dst.Values = make([][]byte, fieldCount, newCap)
	} else {
		dst.Values = dst.Values[:fieldCount]
	}
	for i := 0; i < fieldCount; i++ {
		// Each value is preceded by its signed 32-bit length.
		if len(src[rp:]) < 4 {
			return &invalidMessageFormatErr{messageType: "DataRow"}
		}
		msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
		rp += 4
		// null
		if msgSize == -1 {
			dst.Values[i] = nil
		} else {
			if len(src[rp:]) < msgSize {
				return &invalidMessageFormatErr{messageType: "DataRow"}
			}
			dst.Values[i] = src[rp : rp+msgSize]
			rp += msgSize
		}
	}
	return nil
}

// Encode appends the wire form of the message to dst.
func (src *DataRow) Encode(dst []byte) []byte {
	dst = append(dst, 'D')
	sp := len(dst)
	// Length placeholder, patched below once the body size is known.
	dst = pgio.AppendInt32(dst, -1)
	dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
	for _, v := range src.Values {
		if v == nil {
			dst = pgio.AppendInt32(dst, -1)
			continue
		}
		dst = pgio.AppendInt32(dst, int32(len(v)))
		dst = append(dst, v...)
	}
	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
	return dst
}

// MarshalJSON implements json.Marshaler. Values containing control bytes are
// rendered hex-encoded under "binary"; otherwise as "text".
func (src *DataRow) MarshalJSON() ([]byte, error) {
	formattedValues := make([]map[string]string, len(src.Values))
	for i, v := range src.Values {
		if v == nil {
			continue
		}
		var hasNonPrintable bool
		for _, b := range v {
			if b < 32 {
				hasNonPrintable = true
				break
			}
		}
		if hasNonPrintable {
			formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)}
		} else {
			formattedValues[i] = map[string]string{"text": string(v)}
		}
	}
	return json.Marshal(struct {
		Type   string
		Values []map[string]string
	}{
		Type:   "DataRow",
		Values: formattedValues,
	})
}

59
pgproto3/describe.go Normal file
View File

@ -0,0 +1,59 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// Describe is the frontend Describe message (type byte 'D') requesting a
// description of a prepared statement or portal.
type Describe struct {
	ObjectType byte // 'S' = prepared statement, 'P' = portal
	Name       string
}

// Frontend identifies Describe as a message sent by the frontend.
func (*Describe) Frontend() {}

// Decode parses the message body in src.
func (dst *Describe) Decode(src []byte) error {
	if len(src) < 2 {
		return &invalidMessageFormatErr{messageType: "Describe"}
	}

	body := src[1:]
	// The name must occupy the remainder of the body and be terminated by
	// exactly one trailing zero byte.
	if bytes.IndexByte(body, 0) != len(body)-1 {
		return &invalidMessageFormatErr{messageType: "Describe"}
	}

	dst.ObjectType = src[0]
	dst.Name = string(body[:len(body)-1])
	return nil
}

// Encode appends the wire form of the message to dst.
func (src *Describe) Encode(dst []byte) []byte {
	dst = append(dst, 'D')
	// Body length: self (4) + object type (1) + name + NUL (1).
	dst = pgio.AppendInt32(dst, int32(4+1+len(src.Name)+1))
	dst = append(dst, src.ObjectType)
	dst = append(dst, src.Name...)
	dst = append(dst, 0)
	return dst
}

// MarshalJSON implements json.Marshaler.
func (src *Describe) MarshalJSON() ([]byte, error) {
	payload := struct {
		Type       string
		ObjectType string
		Name       string
	}{
		Type:       "Describe",
		ObjectType: string(src.ObjectType),
		Name:       src.Name,
	}
	return json.Marshal(payload)
}

View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
// EmptyQueryResponse is the backend EmptyQueryResponse message (type byte
// 'I'), sent in place of CommandComplete for an empty query string; its body
// is empty.
type EmptyQueryResponse struct{}

// Backend identifies EmptyQueryResponse as a message sent by the backend.
func (*EmptyQueryResponse) Backend() {}

// Decode verifies that the message body in src is empty.
func (dst *EmptyQueryResponse) Decode(src []byte) error {
	if len(src) == 0 {
		return nil
	}
	return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)}
}

// Encode appends the wire form of the message to dst.
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
	wire := [5]byte{'I', 0, 0, 0, 4}
	return append(dst, wire[:]...)
}

// MarshalJSON implements json.Marshaler.
func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) {
	var v struct {
		Type string
	}
	v.Type = "EmptyQueryResponse"
	return json.Marshal(v)
}

197
pgproto3/error_response.go Normal file
View File

@ -0,0 +1,197 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"strconv"
)
// ErrorResponse is the backend ErrorResponse message (type byte 'E'). Each
// struct field corresponds to a single-byte field code in the wire format;
// codes this package does not understand are collected in UnknownFields.
type ErrorResponse struct {
	Severity         string
	Code             string
	Message          string
	Detail           string
	Hint             string
	Position         int32
	InternalPosition int32
	InternalQuery    string
	Where            string
	SchemaName       string
	TableName        string
	ColumnName       string
	DataTypeName     string
	ConstraintName   string
	File             string
	Line             int32
	Routine          string

	UnknownFields map[byte]string
}

// Backend identifies ErrorResponse as a message sent by the backend.
func (*ErrorResponse) Backend() {}

// Decode parses the message body in src: a sequence of (field code byte,
// NUL-terminated value) pairs ended by a single zero byte.
func (dst *ErrorResponse) Decode(src []byte) error {
	*dst = ErrorResponse{}

	buf := bytes.NewBuffer(src)

	for {
		k, err := buf.ReadByte()
		if err != nil {
			return err
		}
		if k == 0 {
			break
		}

		vb, err := buf.ReadBytes(0)
		if err != nil {
			return err
		}
		v := string(vb[:len(vb)-1])

		switch k {
		case 'S':
			dst.Severity = v
		case 'C':
			dst.Code = v
		case 'M':
			dst.Message = v
		case 'D':
			dst.Detail = v
		case 'H':
			dst.Hint = v
		case 'P':
			// Numeric fields arrive as decimal strings; parse errors leave
			// the zero value (matching prior behavior).
			n, _ := strconv.ParseInt(v, 10, 32)
			dst.Position = int32(n)
		case 'p':
			n, _ := strconv.ParseInt(v, 10, 32)
			dst.InternalPosition = int32(n)
		case 'q':
			dst.InternalQuery = v
		case 'W':
			dst.Where = v
		case 's':
			dst.SchemaName = v
		case 't':
			dst.TableName = v
		case 'c':
			dst.ColumnName = v
		case 'd':
			dst.DataTypeName = v
		case 'n':
			dst.ConstraintName = v
		case 'F':
			dst.File = v
		case 'L':
			n, _ := strconv.ParseInt(v, 10, 32)
			dst.Line = int32(n)
		case 'R':
			dst.Routine = v
		default:
			if dst.UnknownFields == nil {
				dst.UnknownFields = make(map[byte]string)
			}
			dst.UnknownFields[k] = v
		}
	}

	return nil
}

// Encode appends the wire form of the message to dst.
func (src *ErrorResponse) Encode(dst []byte) []byte {
	return append(dst, src.marshalBinary('E')...)
}

// marshalBinary renders the full message with the given type byte ('E' for
// ErrorResponse, 'N' for NoticeResponse).
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
	var bigEndian BigEndianBuf
	buf := &bytes.Buffer{}

	buf.WriteByte(typeByte)
	buf.Write(bigEndian.Uint32(0)) // length placeholder, patched below

	// writeField emits one (code, value, NUL) triple.
	//
	// BUG FIX: the field code byte was previously omitted for the known
	// fields (and unknown fields were written with a stray zero byte after
	// the code), so the encoded body could not be parsed by Decode or by a
	// real PostgreSQL client.
	writeField := func(code byte, value string) {
		buf.WriteByte(code)
		buf.WriteString(value)
		buf.WriteByte(0)
	}

	if src.Severity != "" {
		writeField('S', src.Severity)
	}
	if src.Code != "" {
		writeField('C', src.Code)
	}
	if src.Message != "" {
		writeField('M', src.Message)
	}
	if src.Detail != "" {
		writeField('D', src.Detail)
	}
	if src.Hint != "" {
		writeField('H', src.Hint)
	}
	if src.Position != 0 {
		writeField('P', strconv.Itoa(int(src.Position)))
	}
	if src.InternalPosition != 0 {
		writeField('p', strconv.Itoa(int(src.InternalPosition)))
	}
	if src.InternalQuery != "" {
		writeField('q', src.InternalQuery)
	}
	if src.Where != "" {
		writeField('W', src.Where)
	}
	if src.SchemaName != "" {
		writeField('s', src.SchemaName)
	}
	if src.TableName != "" {
		writeField('t', src.TableName)
	}
	if src.ColumnName != "" {
		writeField('c', src.ColumnName)
	}
	if src.DataTypeName != "" {
		writeField('d', src.DataTypeName)
	}
	if src.ConstraintName != "" {
		writeField('n', src.ConstraintName)
	}
	if src.File != "" {
		writeField('F', src.File)
	}
	if src.Line != 0 {
		writeField('L', strconv.Itoa(int(src.Line)))
	}
	if src.Routine != "" {
		writeField('R', src.Routine)
	}

	for k, v := range src.UnknownFields {
		writeField(k, v)
	}

	buf.WriteByte(0) // list terminator

	// Patch the length (everything after the type byte, including itself).
	binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))

	return buf.Bytes()
}

60
pgproto3/execute.go Normal file
View File

@ -0,0 +1,60 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// Execute is the frontend Execute message (type byte 'E') requesting
// execution of a previously-bound portal.
type Execute struct {
	Portal  string // portal name ("" = unnamed portal)
	MaxRows uint32 // row limit; 0 means no limit
}

// Frontend identifies Execute as a message sent by the frontend.
func (*Execute) Frontend() {}

// Decode parses the message body in src.
func (dst *Execute) Decode(src []byte) error {
	buf := bytes.NewBuffer(src)

	portal, err := buf.ReadBytes(0)
	if err != nil {
		return err
	}
	dst.Portal = string(portal[:len(portal)-1])

	if buf.Len() < 4 {
		return &invalidMessageFormatErr{messageType: "Execute"}
	}
	dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4))

	return nil
}

// Encode appends the wire form of the message to dst.
func (src *Execute) Encode(dst []byte) []byte {
	dst = append(dst, 'E')
	// Body length: self (4) + portal + NUL (1) + max rows (4).
	dst = pgio.AppendInt32(dst, int32(4+len(src.Portal)+1+4))
	dst = append(dst, src.Portal...)
	dst = append(dst, 0)
	return pgio.AppendUint32(dst, src.MaxRows)
}

// MarshalJSON implements json.Marshaler.
func (src *Execute) MarshalJSON() ([]byte, error) {
	payload := struct {
		Type    string
		Portal  string
		MaxRows uint32
	}{
		Type:    "Execute",
		Portal:  src.Portal,
		MaxRows: src.MaxRows,
	}
	return json.Marshal(payload)
}

29
pgproto3/flush.go Normal file
View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
// Flush is the frontend Flush message (type byte 'H') asking the backend to
// deliver any pending output; its body is empty.
type Flush struct{}

// Frontend identifies Flush as a message sent by the frontend.
func (*Flush) Frontend() {}

// Decode verifies that the message body in src is empty.
func (dst *Flush) Decode(src []byte) error {
	if len(src) == 0 {
		return nil
	}
	return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)}
}

// Encode appends the wire form of the message to dst.
func (src *Flush) Encode(dst []byte) []byte {
	wire := [5]byte{'H', 0, 0, 0, 4}
	return append(dst, wire[:]...)
}

// MarshalJSON implements json.Marshaler.
func (src *Flush) MarshalJSON() ([]byte, error) {
	var v struct {
		Type string
	}
	v.Type = "Flush"
	return json.Marshal(v)
}

113
pgproto3/frontend.go Normal file
View File

@ -0,0 +1,113 @@
package pgproto3
import (
"encoding/binary"
"io"
"github.com/jackc/pgx/chunkreader"
"github.com/pkg/errors"
)
// Frontend acts as a client for the PostgreSQL wire protocol version 3: it
// sends FrontendMessages over w and reads BackendMessages from cr.
type Frontend struct {
	cr *chunkreader.ChunkReader
	w  io.Writer

	// Backend message flyweights. Receive decodes into these shared
	// instances, so a returned message is only valid until the next call to
	// Receive.
	authentication       Authentication
	backendKeyData       BackendKeyData
	bindComplete         BindComplete
	closeComplete        CloseComplete
	commandComplete      CommandComplete
	copyBothResponse     CopyBothResponse
	copyData             CopyData
	copyInResponse       CopyInResponse
	copyOutResponse      CopyOutResponse
	dataRow              DataRow
	emptyQueryResponse   EmptyQueryResponse
	errorResponse        ErrorResponse
	functionCallResponse FunctionCallResponse
	noData               NoData
	noticeResponse       NoticeResponse
	notificationResponse NotificationResponse
	parameterDescription ParameterDescription
	parameterStatus      ParameterStatus
	parseComplete        ParseComplete
	readyForQuery        ReadyForQuery
	rowDescription       RowDescription
}

// NewFrontend creates a Frontend that reads backend messages from r and
// writes frontend messages to w.
func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
	cr := chunkreader.NewChunkReader(r)
	return &Frontend{cr: cr, w: w}, nil
}

// Send encodes msg and writes it to the backend.
func (b *Frontend) Send(msg FrontendMessage) error {
	_, err := b.w.Write(msg.Encode(nil))
	return err
}

// Receive reads and decodes the next backend message. The returned message
// is a shared flyweight and is invalidated by the next Receive call.
func (b *Frontend) Receive() (BackendMessage, error) {
	// Every backend message starts with a 1-byte type and a 4-byte length
	// that includes the length field itself but not the type byte.
	header, err := b.cr.Next(5)
	if err != nil {
		return nil, err
	}

	msgType := header[0]
	bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4

	var msg BackendMessage
	switch msgType {
	case '1':
		msg = &b.parseComplete
	case '2':
		msg = &b.bindComplete
	case '3':
		msg = &b.closeComplete
	case 'A':
		msg = &b.notificationResponse
	case 'C':
		msg = &b.commandComplete
	case 'd':
		msg = &b.copyData
	case 'D':
		msg = &b.dataRow
	case 'E':
		msg = &b.errorResponse
	case 'G':
		msg = &b.copyInResponse
	case 'H':
		msg = &b.copyOutResponse
	case 'I':
		msg = &b.emptyQueryResponse
	case 'K':
		msg = &b.backendKeyData
	case 'n':
		msg = &b.noData
	case 'N':
		msg = &b.noticeResponse
	case 'R':
		msg = &b.authentication
	case 'S':
		msg = &b.parameterStatus
	case 't':
		msg = &b.parameterDescription
	case 'T':
		msg = &b.rowDescription
	case 'V':
		msg = &b.functionCallResponse
	case 'W':
		msg = &b.copyBothResponse
	case 'Z':
		msg = &b.readyForQuery
	default:
		return nil, errors.Errorf("unknown message type: %c", msgType)
	}

	msgBody, err := b.cr.Next(bodyLen)
	if err != nil {
		return nil, err
	}

	err = msg.Decode(msgBody)
	return msg, err
}

View File

@ -0,0 +1,78 @@
package pgproto3
import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// FunctionCallResponse is the backend FunctionCallResponse message (type
// byte 'V') carrying the result of a fast-path function call.
type FunctionCallResponse struct {
	Result []byte // nil means the function returned SQL NULL
}

// Backend identifies FunctionCallResponse as a message sent by the backend.
func (*FunctionCallResponse) Backend() {}

// Decode parses the message body in src. Result retains a reference into src
// (permitted by the Message contract).
func (dst *FunctionCallResponse) Decode(src []byte) error {
	if len(src) < 4 {
		return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
	}
	rp := 0
	// BUG FIX: the size is a signed int32 on the wire. Converting the uint32
	// straight to int meant the -1 (NULL) sentinel compared as 4294967295 on
	// 64-bit platforms and was never matched; route through int32 first.
	resultSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
	rp += 4

	if resultSize == -1 {
		dst.Result = nil
		return nil
	}

	if len(src[rp:]) != resultSize {
		return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
	}

	dst.Result = src[rp:]
	return nil
}

// Encode appends the wire form of the message to dst.
func (src *FunctionCallResponse) Encode(dst []byte) []byte {
	dst = append(dst, 'V')
	sp := len(dst)
	dst = pgio.AppendInt32(dst, -1) // length placeholder, patched below

	if src.Result == nil {
		dst = pgio.AppendInt32(dst, -1)
	} else {
		dst = pgio.AppendInt32(dst, int32(len(src.Result)))
		dst = append(dst, src.Result...)
	}

	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

	return dst
}

// MarshalJSON implements json.Marshaler. Results containing control bytes
// are rendered hex-encoded under "binary"; otherwise as "text".
func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) {
	var formattedValue map[string]string
	var hasNonPrintable bool
	for _, b := range src.Result {
		if b < 32 {
			hasNonPrintable = true
			break
		}
	}

	if hasNonPrintable {
		formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)}
	} else {
		formattedValue = map[string]string{"text": string(src.Result)}
	}

	return json.Marshal(struct {
		Type   string
		Result map[string]string
	}{
		Type:   "FunctionCallResponse",
		Result: formattedValue,
	})
}

29
pgproto3/no_data.go Normal file
View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
// NoData is the backend NoData message (type byte 'n'), sent when a
// described statement or portal returns no rows; its body is empty.
type NoData struct{}

// Backend identifies NoData as a message sent by the backend.
func (*NoData) Backend() {}

// Decode verifies that the message body in src is empty.
func (dst *NoData) Decode(src []byte) error {
	if len(src) == 0 {
		return nil
	}
	return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)}
}

// Encode appends the wire form of the message to dst.
func (src *NoData) Encode(dst []byte) []byte {
	wire := [5]byte{'n', 0, 0, 0, 4}
	return append(dst, wire[:]...)
}

// MarshalJSON implements json.Marshaler.
func (src *NoData) MarshalJSON() ([]byte, error) {
	var v struct {
		Type string
	}
	v.Type = "NoData"
	return json.Marshal(v)
}

View File

@ -0,0 +1,13 @@
package pgproto3
// NoticeResponse is the backend NoticeResponse message (type byte 'N'). It
// shares its wire format and fields with ErrorResponse.
type NoticeResponse ErrorResponse

// Backend identifies NoticeResponse as a message sent by the backend.
func (*NoticeResponse) Backend() {}

// Decode parses the message body in src by delegating to ErrorResponse,
// which uses the identical field-list format.
func (dst *NoticeResponse) Decode(src []byte) error {
	return (*ErrorResponse)(dst).Decode(src)
}

// Encode appends the wire form of the message to dst, using the
// ErrorResponse encoder with the 'N' type byte.
func (src *NoticeResponse) Encode(dst []byte) []byte {
	return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
}

View File

@ -0,0 +1,67 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
type NotificationResponse struct {
PID uint32
Channel string
Payload string
}
func (*NotificationResponse) Backend() {}
func (dst *NotificationResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
pid := binary.BigEndian.Uint32(buf.Next(4))
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
channel := string(b[:len(b)-1])
b, err = buf.ReadBytes(0)
if err != nil {
return err
}
payload := string(b[:len(b)-1])
*dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload}
return nil
}
func (src *NotificationResponse) Encode(dst []byte) []byte {
dst = append(dst, 'A')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Channel...)
dst = append(dst, 0)
dst = append(dst, src.Payload...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
func (src *NotificationResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
PID uint32
Channel string
Payload string
}{
Type: "NotificationResponse",
PID: src.PID,
Channel: src.Channel,
Payload: src.Payload,
})
}

View File

@ -0,0 +1,61 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// ParameterDescription is the backend ParameterDescription message (type
// byte 't') listing the type OIDs of a prepared statement's parameters.
type ParameterDescription struct {
	ParameterOIDs []uint32
}

// Backend identifies ParameterDescription as a message sent by the backend.
func (*ParameterDescription) Backend() {}

// Decode parses the message body in src.
func (dst *ParameterDescription) Decode(src []byte) error {
	buf := bytes.NewBuffer(src)

	if buf.Len() < 2 {
		return &invalidMessageFormatErr{messageType: "ParameterDescription"}
	}

	// Reported parameter count will be incorrect when the number of args is
	// greater than uint16; skip it and infer the count from the remaining
	// body size instead.
	buf.Next(2)
	oidCount := buf.Len() / 4

	oids := make([]uint32, oidCount)
	for i := range oids {
		oids[i] = binary.BigEndian.Uint32(buf.Next(4))
	}
	*dst = ParameterDescription{ParameterOIDs: oids}

	return nil
}

// Encode appends the wire form of the message to dst.
func (src *ParameterDescription) Encode(dst []byte) []byte {
	dst = append(dst, 't')
	// Body length: self (4) + count (2) + 4 bytes per OID.
	dst = pgio.AppendInt32(dst, int32(4+2+4*len(src.ParameterOIDs)))
	dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
	for _, oid := range src.ParameterOIDs {
		dst = pgio.AppendUint32(dst, oid)
	}
	return dst
}

// MarshalJSON implements json.Marshaler.
func (src *ParameterDescription) MarshalJSON() ([]byte, error) {
	payload := struct {
		Type          string
		ParameterOIDs []uint32
	}{
		Type:          "ParameterDescription",
		ParameterOIDs: src.ParameterOIDs,
	}
	return json.Marshal(payload)
}

View File

@ -0,0 +1,61 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// ParameterStatus is the backend ParameterStatus message (type byte 'S')
// reporting the current value of a run-time server parameter.
type ParameterStatus struct {
	Name  string
	Value string
}

// Backend identifies ParameterStatus as a message sent by the backend.
func (*ParameterStatus) Backend() {}

// Decode parses the message body in src: two NUL-terminated strings.
func (dst *ParameterStatus) Decode(src []byte) error {
	buf := bytes.NewBuffer(src)

	nameBytes, err := buf.ReadBytes(0)
	if err != nil {
		return err
	}

	valueBytes, err := buf.ReadBytes(0)
	if err != nil {
		return err
	}

	*dst = ParameterStatus{
		Name:  string(nameBytes[:len(nameBytes)-1]),
		Value: string(valueBytes[:len(valueBytes)-1]),
	}
	return nil
}

// Encode appends the wire form of the message to dst.
func (src *ParameterStatus) Encode(dst []byte) []byte {
	dst = append(dst, 'S')
	// Body length: self (4) + name + NUL (1) + value + NUL (1).
	dst = pgio.AppendInt32(dst, int32(4+len(src.Name)+1+len(src.Value)+1))
	dst = append(dst, src.Name...)
	dst = append(dst, 0)
	dst = append(dst, src.Value...)
	dst = append(dst, 0)
	return dst
}

// MarshalJSON implements json.Marshaler.
func (ps *ParameterStatus) MarshalJSON() ([]byte, error) {
	payload := struct {
		Type  string
		Name  string
		Value string
	}{
		Type:  "ParameterStatus",
		Name:  ps.Name,
		Value: ps.Value,
	}
	return json.Marshal(payload)
}

83
pgproto3/parse.go Normal file
View File

@ -0,0 +1,83 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// Parse is the frontend Parse message (type byte 'P') requesting preparation
// of a statement.
type Parse struct {
	Name          string   // destination prepared-statement name ("" = unnamed)
	Query         string   // SQL text to prepare
	ParameterOIDs []uint32 // pre-specified parameter type OIDs (may be empty)
}

// Frontend identifies Parse as a message sent by the frontend.
func (*Parse) Frontend() {}

// Decode parses the message body in src.
func (dst *Parse) Decode(src []byte) error {
	*dst = Parse{}

	buf := bytes.NewBuffer(src)

	name, err := buf.ReadBytes(0)
	if err != nil {
		return err
	}
	dst.Name = string(name[:len(name)-1])

	query, err := buf.ReadBytes(0)
	if err != nil {
		return err
	}
	dst.Query = string(query[:len(query)-1])

	if buf.Len() < 2 {
		return &invalidMessageFormatErr{messageType: "Parse"}
	}
	oidCount := int(binary.BigEndian.Uint16(buf.Next(2)))

	for i := 0; i < oidCount; i++ {
		if buf.Len() < 4 {
			return &invalidMessageFormatErr{messageType: "Parse"}
		}
		dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4)))
	}

	return nil
}

// Encode appends the wire form of the message to dst.
func (src *Parse) Encode(dst []byte) []byte {
	dst = append(dst, 'P')
	// Body length: self (4) + name + NUL + query + NUL + count (2) + OIDs.
	bodyLen := 4 + len(src.Name) + 1 + len(src.Query) + 1 + 2 + 4*len(src.ParameterOIDs)
	dst = pgio.AppendInt32(dst, int32(bodyLen))

	dst = append(dst, src.Name...)
	dst = append(dst, 0)
	dst = append(dst, src.Query...)
	dst = append(dst, 0)

	dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
	for _, oid := range src.ParameterOIDs {
		dst = pgio.AppendUint32(dst, oid)
	}

	return dst
}

// MarshalJSON implements json.Marshaler.
func (src *Parse) MarshalJSON() ([]byte, error) {
	payload := struct {
		Type          string
		Name          string
		Query         string
		ParameterOIDs []uint32
	}{
		Type:          "Parse",
		Name:          src.Name,
		Query:         src.Query,
		ParameterOIDs: src.ParameterOIDs,
	}
	return json.Marshal(payload)
}

View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
// ParseComplete is the backend ParseComplete message (type byte '1'); its
// body is empty.
type ParseComplete struct{}

// Backend identifies ParseComplete as a message sent by the backend.
func (*ParseComplete) Backend() {}

// Decode verifies that the message body in src is empty.
func (dst *ParseComplete) Decode(src []byte) error {
	if len(src) == 0 {
		return nil
	}
	return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)}
}

// Encode appends the wire form of the message to dst.
func (src *ParseComplete) Encode(dst []byte) []byte {
	wire := [5]byte{'1', 0, 0, 0, 4}
	return append(dst, wire[:]...)
}

// MarshalJSON implements json.Marshaler.
func (src *ParseComplete) MarshalJSON() ([]byte, error) {
	var v struct {
		Type string
	}
	v.Type = "ParseComplete"
	return json.Marshal(v)
}

View File

@ -0,0 +1,46 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// PasswordMessage is the frontend PasswordMessage message (type byte 'p')
// carrying a cleartext or MD5-hashed password in response to an
// authentication request.
type PasswordMessage struct {
	Password string
}

// Frontend identifies PasswordMessage as a message sent by the frontend.
func (*PasswordMessage) Frontend() {}

// Decode parses the message body in src: one NUL-terminated string.
func (dst *PasswordMessage) Decode(src []byte) error {
	buf := bytes.NewBuffer(src)

	pw, err := buf.ReadBytes(0)
	if err != nil {
		return err
	}
	dst.Password = string(pw[:len(pw)-1])

	return nil
}

// Encode appends the wire form of the message to dst.
func (src *PasswordMessage) Encode(dst []byte) []byte {
	dst = append(dst, 'p')
	// Body length: self (4) + password + NUL (1).
	dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
	dst = append(dst, src.Password...)
	return append(dst, 0)
}

// MarshalJSON implements json.Marshaler.
func (src *PasswordMessage) MarshalJSON() ([]byte, error) {
	payload := struct {
		Type     string
		Password string
	}{
		Type:     "PasswordMessage",
		Password: src.Password,
	}
	return json.Marshal(payload)
}

42
pgproto3/pgproto3.go Normal file
View File

@ -0,0 +1,42 @@
package pgproto3
import "fmt"
// Message is the interface implemented by an object that can decode and encode
// a particular PostgreSQL message.
type Message interface {
	// Decode is allowed and expected to retain a reference to data after
	// returning (unlike encoding.BinaryUnmarshaler).
	Decode(data []byte) error

	// Encode appends itself to dst and returns the new buffer.
	Encode(dst []byte) []byte
}

// FrontendMessage is a Message sent by the frontend (client).
type FrontendMessage interface {
	Message
	Frontend() // no-op method to distinguish frontend from backend methods
}

// BackendMessage is a Message sent by the backend (server).
type BackendMessage interface {
	Message
	Backend() // no-op method to distinguish frontend from backend methods
}

// invalidMessageLenErr reports a message body whose length differs from the
// fixed length its type requires.
type invalidMessageLenErr struct {
	messageType string
	expectedLen int
	actualLen   int
}

// Error implements the error interface.
func (e *invalidMessageLenErr) Error() string {
	return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen)
}

// invalidMessageFormatErr reports a structurally malformed message body.
type invalidMessageFormatErr struct {
	messageType string
}

// Error implements the error interface.
func (e *invalidMessageFormatErr) Error() string {
	return fmt.Sprintf("%s body is invalid", e.messageType)
}

45
pgproto3/query.go Normal file
View File

@ -0,0 +1,45 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// Query is the frontend Query message (type byte 'Q') submitting a SQL
// string via the simple query protocol.
type Query struct {
	String string
}

// Frontend identifies Query as a message sent by the frontend.
func (*Query) Frontend() {}

// Decode parses the message body in src: one NUL-terminated string.
func (dst *Query) Decode(src []byte) error {
	i := bytes.IndexByte(src, 0)
	// BUG FIX: an empty src yields i == -1 == len(src)-1, which previously
	// slipped past the check and panicked on src[:-1]. Reject i < 0
	// explicitly.
	if i < 0 || i != len(src)-1 {
		return &invalidMessageFormatErr{messageType: "Query"}
	}

	dst.String = string(src[:i])

	return nil
}

// Encode appends the wire form of the message to dst.
func (src *Query) Encode(dst []byte) []byte {
	dst = append(dst, 'Q')
	dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))

	dst = append(dst, src.String...)
	dst = append(dst, 0)

	return dst
}

// MarshalJSON implements json.Marshaler.
func (src *Query) MarshalJSON() ([]byte, error) {
	return json.Marshal(struct {
		Type   string
		String string
	}{
		Type:   "Query",
		String: src.String,
	})
}

View File

@ -0,0 +1,35 @@
package pgproto3
import (
"encoding/json"
)
// ReadyForQuery is the backend ReadyForQuery message (type byte 'Z'),
// signalling that the backend is ready for a new query cycle.
type ReadyForQuery struct {
	TxStatus byte // 'I' = idle, 'T' = in transaction, 'E' = failed transaction
}

// Backend identifies ReadyForQuery as a message sent by the backend.
func (*ReadyForQuery) Backend() {}

// Decode parses the message body in src: a single transaction-status byte.
func (dst *ReadyForQuery) Decode(src []byte) error {
	if len(src) == 1 {
		dst.TxStatus = src[0]
		return nil
	}
	return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)}
}

// Encode appends the wire form of the message to dst.
func (src *ReadyForQuery) Encode(dst []byte) []byte {
	wire := [6]byte{'Z', 0, 0, 0, 5, src.TxStatus}
	return append(dst, wire[:]...)
}

// MarshalJSON implements json.Marshaler.
func (src *ReadyForQuery) MarshalJSON() ([]byte, error) {
	payload := struct {
		Type     string
		TxStatus string
	}{
		Type:     "ReadyForQuery",
		TxStatus: string(src.TxStatus),
	}
	return json.Marshal(payload)
}

100
pgproto3/row_description.go Normal file
View File

@ -0,0 +1,100 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
// Format codes used in FieldDescription.Format and COPY format fields.
const (
	TextFormat   = 0
	BinaryFormat = 1
)

// FieldDescription describes one column of a result set as reported in a
// RowDescription message.
type FieldDescription struct {
	Name                 string
	TableOID             uint32 // OID of the source table, or 0
	TableAttributeNumber uint16 // attribute number within the table, or 0
	DataTypeOID          uint32
	DataTypeSize         int16 // negative values indicate variable-width types
	// TypeModifier is decoded as uint32 here; NOTE(review): the protocol
	// documents this field as a signed int32 — confirm before relying on
	// large values.
	TypeModifier uint32
	Format       int16 // TextFormat or BinaryFormat
}

// RowDescription is the backend RowDescription message (type byte 'T')
// describing the columns of the rows that will follow.
type RowDescription struct {
	Fields []FieldDescription
}

// Backend identifies RowDescription as a message sent by the backend.
func (*RowDescription) Backend() {}

// Decode parses the message body in src: a 16-bit field count followed by
// one variable-length field description per column.
func (dst *RowDescription) Decode(src []byte) error {
	buf := bytes.NewBuffer(src)

	if buf.Len() < 2 {
		return &invalidMessageFormatErr{messageType: "RowDescription"}
	}
	fieldCount := int(binary.BigEndian.Uint16(buf.Next(2)))

	*dst = RowDescription{Fields: make([]FieldDescription, fieldCount)}

	for i := 0; i < fieldCount; i++ {
		var fd FieldDescription
		bName, err := buf.ReadBytes(0)
		if err != nil {
			return err
		}
		fd.Name = string(bName[:len(bName)-1])

		// Since buf.Next() doesn't return an error if we hit the end of the
		// buffer, check Len ahead of time for the 18 fixed-width bytes.
		if buf.Len() < 18 {
			return &invalidMessageFormatErr{messageType: "RowDescription"}
		}

		fd.TableOID = binary.BigEndian.Uint32(buf.Next(4))
		fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2))
		fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4))
		fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2)))
		fd.TypeModifier = binary.BigEndian.Uint32(buf.Next(4))
		fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2)))

		dst.Fields[i] = fd
	}

	return nil
}

// Encode appends the wire form of the message to dst.
func (src *RowDescription) Encode(dst []byte) []byte {
	dst = append(dst, 'T')
	sp := len(dst)
	// Length placeholder, patched below once the body size is known.
	dst = pgio.AppendInt32(dst, -1)

	dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
	for _, fd := range src.Fields {
		dst = append(dst, fd.Name...)
		dst = append(dst, 0)

		dst = pgio.AppendUint32(dst, fd.TableOID)
		dst = pgio.AppendUint16(dst, fd.TableAttributeNumber)
		dst = pgio.AppendUint32(dst, fd.DataTypeOID)
		dst = pgio.AppendInt16(dst, fd.DataTypeSize)
		dst = pgio.AppendUint32(dst, fd.TypeModifier)
		dst = pgio.AppendInt16(dst, fd.Format)
	}

	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

	return dst
}

// MarshalJSON implements json.Marshaler.
func (src *RowDescription) MarshalJSON() ([]byte, error) {
	return json.Marshal(struct {
		Type   string
		Fields []FieldDescription
	}{
		Type:   "RowDescription",
		Fields: src.Fields,
	})
}

View File

@ -0,0 +1,97 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/pgio"
"github.com/pkg/errors"
)
const (
ProtocolVersionNumber = 196608 // 3.0
sslRequestNumber = 80877103
)
type StartupMessage struct {
ProtocolVersion uint32
Parameters map[string]string
}
func (*StartupMessage) Frontend() {}
func (dst *StartupMessage) Decode(src []byte) error {
if len(src) < 4 {
return errors.Errorf("startup message too short")
}
dst.ProtocolVersion = binary.BigEndian.Uint32(src)
rp := 4
if dst.ProtocolVersion == sslRequestNumber {
return errors.Errorf("can't handle ssl connection request")
}
if dst.ProtocolVersion != ProtocolVersionNumber {
return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
}
dst.Parameters = make(map[string]string)
for {
idx := bytes.IndexByte(src[rp:], 0)
if idx < 0 {
return &invalidMessageFormatErr{messageType: "StartupMesage"}
}
key := string(src[rp : rp+idx])
rp += idx + 1
idx = bytes.IndexByte(src[rp:], 0)
if idx < 0 {
return &invalidMessageFormatErr{messageType: "StartupMesage"}
}
value := string(src[rp : rp+idx])
rp += idx + 1
dst.Parameters[key] = value
if len(src[rp:]) == 1 {
if src[rp] != 0 {
return errors.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp])
}
break
}
}
return nil
}
func (src *StartupMessage) Encode(dst []byte) []byte {
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, src.ProtocolVersion)
for k, v := range src.Parameters {
dst = append(dst, k...)
dst = append(dst, 0)
dst = append(dst, v...)
dst = append(dst, 0)
}
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON renders the message as a tagged JSON object for debug output.
func (src *StartupMessage) MarshalJSON() ([]byte, error) {
	type debugMsg struct {
		Type            string
		ProtocolVersion uint32
		Parameters      map[string]string
	}
	return json.Marshal(debugMsg{
		Type:            "StartupMessage",
		ProtocolVersion: src.ProtocolVersion,
		Parameters:      src.Parameters,
	})
}

29
pgproto3/sync.go Normal file
View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
// Sync is the frontend Sync message. It carries no payload.
type Sync struct{}

// Frontend identifies this message as one sent by the frontend.
func (*Sync) Frontend() {}

// Decode verifies that the message body is empty, as Sync has no fields.
func (dst *Sync) Decode(src []byte) error {
	if len(src) == 0 {
		return nil
	}
	return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)}
}

// Encode appends the wire form of Sync: type byte 'S' and length 4.
func (src *Sync) Encode(dst []byte) []byte {
	return append(dst, 'S', 0, 0, 0, 4)
}

// MarshalJSON renders the message as a tagged JSON object for debug output.
func (src *Sync) MarshalJSON() ([]byte, error) {
	return json.Marshal(struct {
		Type string
	}{Type: "Sync"})
}

29
pgproto3/terminate.go Normal file
View File

@ -0,0 +1,29 @@
package pgproto3
import (
"encoding/json"
)
// Terminate is the frontend Terminate message. It carries no payload.
type Terminate struct{}

// Frontend identifies this message as one sent by the frontend.
func (*Terminate) Frontend() {}

// Decode verifies that the message body is empty, as Terminate has no fields.
func (dst *Terminate) Decode(src []byte) error {
	if len(src) == 0 {
		return nil
	}
	return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)}
}

// Encode appends the wire form of Terminate: type byte 'X' and length 4.
func (src *Terminate) Encode(dst []byte) []byte {
	return append(dst, 'X', 0, 0, 0, 4)
}

// MarshalJSON renders the message as a tagged JSON object for debug output.
func (src *Terminate) MarshalJSON() ([]byte, error) {
	return json.Marshal(struct {
		Type string
	}{Type: "Terminate"})
}

126
pgtype/aclitem.go Normal file
View File

@ -0,0 +1,126 @@
package pgtype
import (
"database/sql/driver"
"github.com/pkg/errors"
)
// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem
// might look like this:
//
//	postgres=arwdDxt/postgres
//
// Note, however, that because the user/role name part of an aclitem is
// an identifier, it follows all the usual formatting rules for SQL
// identifiers: if it contains spaces and other special characters,
// it should appear in double-quotes:
//
//	postgres=arwdDxt/"role with spaces"
//
type ACLItem struct {
	String string // raw aclitem text as received from / sent to PostgreSQL
	Status Status // Present, Null, or Undefined
}

// Set converts src to an ACLItem. Accepted types are string and *string
// (a nil *string maps to Null); any other type is retried through its
// underlying string type, if it has one.
func (dst *ACLItem) Set(src interface{}) error {
	switch value := src.(type) {
	case string:
		*dst = ACLItem{String: value, Status: Present}
	case *string:
		if value == nil {
			*dst = ACLItem{Status: Null}
		} else {
			*dst = ACLItem{String: *value, Status: Present}
		}
	default:
		if originalSrc, ok := underlyingStringType(src); ok {
			return dst.Set(originalSrc)
		}
		return errors.Errorf("cannot convert %v to ACLItem", value)
	}

	return nil
}

// Get returns the underlying value: the string when Present, nil when
// Null, and the raw Status otherwise.
func (dst *ACLItem) Get() interface{} {
	switch dst.Status {
	case Present:
		return dst.String
	case Null:
		return nil
	default:
		return dst.Status
	}
}

// AssignTo assigns src to dst, which must be a *string or a type that
// GetAssignToDstType can reduce to one. Null values are delegated to
// NullAssignTo.
func (src *ACLItem) AssignTo(dst interface{}) error {
	switch src.Status {
	case Present:
		switch v := dst.(type) {
		case *string:
			*v = src.String
			return nil
		default:
			if nextDst, retry := GetAssignToDstType(dst); retry {
				return src.AssignTo(nextDst)
			}
		}
	case Null:
		return NullAssignTo(dst)
	}

	return errors.Errorf("cannot decode %v into %T", src, dst)
}

// DecodeText decodes src from the PostgreSQL text format. A nil src
// represents SQL NULL.
func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error {
	if src == nil {
		*dst = ACLItem{Status: Null}
		return nil
	}

	*dst = ACLItem{String: string(src), Status: Present}
	return nil
}

// EncodeText appends the text-format encoding of src to buf. A nil byte
// slice with a nil error signals SQL NULL to the caller.
func (src *ACLItem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
	switch src.Status {
	case Null:
		return nil, nil
	case Undefined:
		return nil, errUndefined
	}

	return append(buf, src.String...), nil
}

// Scan implements the database/sql Scanner interface.
func (dst *ACLItem) Scan(src interface{}) error {
	if src == nil {
		*dst = ACLItem{Status: Null}
		return nil
	}

	switch src := src.(type) {
	case string:
		return dst.DecodeText(nil, []byte(src))
	case []byte:
		// Copy: the driver may reuse src's backing array after Scan returns.
		srcCopy := make([]byte, len(src))
		copy(srcCopy, src)
		return dst.DecodeText(nil, srcCopy)
	}

	return errors.Errorf("cannot scan %T", src)
}

// Value implements the database/sql/driver Valuer interface.
func (src *ACLItem) Value() (driver.Value, error) {
	switch src.Status {
	case Present:
		return src.String, nil
	case Null:
		return nil, nil
	default:
		return nil, errUndefined
	}
}

206
pgtype/aclitem_array.go Normal file
View File

@ -0,0 +1,206 @@
package pgtype
import (
"database/sql/driver"
"github.com/pkg/errors"
)
type ACLItemArray struct {
Elements []ACLItem
Dimensions []ArrayDimension
Status Status
}
func (dst *ACLItemArray) Set(src interface{}) error {
switch value := src.(type) {
case []string:
if value == nil {
*dst = ACLItemArray{Status: Null}
} else if len(value) == 0 {
*dst = ACLItemArray{Status: Present}
} else {
elements := make([]ACLItem, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = ACLItemArray{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Status: Present,
}
}
default:
if originalSrc, ok := underlyingSliceType(src); ok {
return dst.Set(originalSrc)
}
return errors.Errorf("cannot convert %v to ACLItem", value)
}
return nil
}
func (dst *ACLItemArray) Get() interface{} {
switch dst.Status {
case Present:
return dst
case Null:
return nil
default:
return dst.Status
}
}
func (src *ACLItemArray) AssignTo(dst interface{}) error {
switch src.Status {
case Present:
switch v := dst.(type) {
case *[]string:
*v = make([]string, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
default:
if nextDst, retry := GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
}
}
case Null:
return NullAssignTo(dst)
}
return errors.Errorf("cannot decode %v into %T", src, dst)
}
func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = ACLItemArray{Status: Null}
return nil
}
uta, err := ParseUntypedTextArray(string(src))
if err != nil {
return err
}
var elements []ACLItem
if len(uta.Elements) > 0 {
elements = make([]ACLItem, len(uta.Elements))
for i, s := range uta.Elements {
var elem ACLItem
var elemSrc []byte
if s != "NULL" {
elemSrc = []byte(s)
}
err = elem.DecodeText(ci, elemSrc)
if err != nil {
return err
}
elements[i] = elem
}
}
*dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
return nil
}
func (src *ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
if len(src.Dimensions) == 0 {
return append(buf, '{', '}'), nil
}
buf = EncodeTextArrayDimensions(buf, src.Dimensions)
// dimElemCounts is the multiples of elements that each array lies on. For
// example, a single dimension array of length 4 would have a dimElemCounts of
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
// or '}'.
dimElemCounts := make([]int, len(src.Dimensions))
dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
for i := len(src.Dimensions) - 2; i > -1; i-- {
dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
}
inElemBuf := make([]byte, 0, 32)
for i, elem := range src.Elements {
if i > 0 {
buf = append(buf, ',')
}
for _, dec := range dimElemCounts {
if i%dec == 0 {
buf = append(buf, '{')
}
}
elemBuf, err := elem.EncodeText(ci, inElemBuf)
if err != nil {
return nil, err
}
if elemBuf == nil {
buf = append(buf, `NULL`...)
} else {
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
}
for _, dec := range dimElemCounts {
if (i+1)%dec == 0 {
buf = append(buf, '}')
}
}
}
return buf, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *ACLItemArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
srcCopy := make([]byte, len(src))
copy(srcCopy, src)
return dst.DecodeText(nil, srcCopy)
}
return errors.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *ACLItemArray) Value() (driver.Value, error) {
buf, err := src.EncodeText(nil, nil)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}
return string(buf), nil
}

View File

@ -0,0 +1,152 @@
package pgtype_test
import (
"reflect"
"testing"
"github.com/jackc/pgx/pgtype"
"github.com/jackc/pgx/pgtype/testutil"
)
func TestACLItemArrayTranscode(t *testing.T) {
testutil.TestSuccessfulTranscode(t, "aclitem[]", []interface{}{
&pgtype.ACLItemArray{
Elements: nil,
Dimensions: nil,
Status: pgtype.Present,
},
&pgtype.ACLItemArray{
Elements: []pgtype.ACLItem{
pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present},
pgtype.ACLItem{Status: pgtype.Null},
},
Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}},
Status: pgtype.Present,
},
&pgtype.ACLItemArray{Status: pgtype.Null},
&pgtype.ACLItemArray{
Elements: []pgtype.ACLItem{
pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present},
pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present},
pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present},
pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present},
pgtype.ACLItem{Status: pgtype.Null},
pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present},
},
Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}},
Status: pgtype.Present,
},
&pgtype.ACLItemArray{
Elements: []pgtype.ACLItem{
pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present},
pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present},
pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present},
pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present},
},
Dimensions: []pgtype.ArrayDimension{
{Length: 2, LowerBound: 4},
{Length: 2, LowerBound: 2},
},
Status: pgtype.Present,
},
})
}
func TestACLItemArraySet(t *testing.T) {
successfulTests := []struct {
source interface{}
result pgtype.ACLItemArray
}{
{
source: []string{"=r/postgres"},
result: pgtype.ACLItemArray{
Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}},
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
Status: pgtype.Present},
},
{
source: (([]string)(nil)),
result: pgtype.ACLItemArray{Status: pgtype.Null},
},
}
for i, tt := range successfulTests {
var r pgtype.ACLItemArray
err := r.Set(tt.source)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if !reflect.DeepEqual(r, tt.result) {
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
}
}
}
func TestACLItemArrayAssignTo(t *testing.T) {
var stringSlice []string
type _stringSlice []string
var namedStringSlice _stringSlice
simpleTests := []struct {
src pgtype.ACLItemArray
dst interface{}
expected interface{}
}{
{
src: pgtype.ACLItemArray{
Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}},
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
Status: pgtype.Present,
},
dst: &stringSlice,
expected: []string{"=r/postgres"},
},
{
src: pgtype.ACLItemArray{
Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}},
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
Status: pgtype.Present,
},
dst: &namedStringSlice,
expected: _stringSlice{"=r/postgres"},
},
{
src: pgtype.ACLItemArray{Status: pgtype.Null},
dst: &stringSlice,
expected: (([]string)(nil)),
},
}
for i, tt := range simpleTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
errorTests := []struct {
src pgtype.ACLItemArray
dst interface{}
}{
{
src: pgtype.ACLItemArray{
Elements: []pgtype.ACLItem{{Status: pgtype.Null}},
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
Status: pgtype.Present,
},
dst: &stringSlice,
},
}
for i, tt := range errorTests {
err := tt.src.AssignTo(tt.dst)
if err == nil {
t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst)
}
}
}

97
pgtype/aclitem_test.go Normal file
View File

@ -0,0 +1,97 @@
package pgtype_test
import (
"reflect"
"testing"
"github.com/jackc/pgx/pgtype"
"github.com/jackc/pgx/pgtype/testutil"
)
func TestACLItemTranscode(t *testing.T) {
testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{
&pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present},
&pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present},
&pgtype.ACLItem{Status: pgtype.Null},
})
}
func TestACLItemSet(t *testing.T) {
successfulTests := []struct {
source interface{}
result pgtype.ACLItem
}{
{source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}},
{source: (*string)(nil), result: pgtype.ACLItem{Status: pgtype.Null}},
}
for i, tt := range successfulTests {
var d pgtype.ACLItem
err := d.Set(tt.source)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if d != tt.result {
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d)
}
}
}
func TestACLItemAssignTo(t *testing.T) {
var s string
var ps *string
simpleTests := []struct {
src pgtype.ACLItem
dst interface{}
expected interface{}
}{
{src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"},
{src: pgtype.ACLItem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))},
}
for i, tt := range simpleTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
pointerAllocTests := []struct {
src pgtype.ACLItem
dst interface{}
expected interface{}
}{
{src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"},
}
for i, tt := range pointerAllocTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
errorTests := []struct {
src pgtype.ACLItem
dst interface{}
}{
{src: pgtype.ACLItem{Status: pgtype.Null}, dst: &s},
}
for i, tt := range errorTests {
err := tt.src.AssignTo(tt.dst)
if err == nil {
t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst)
}
}
}

352
pgtype/array.go Normal file
View File

@ -0,0 +1,352 @@
package pgtype
import (
"bytes"
"encoding/binary"
"io"
"strconv"
"strings"
"unicode"
"github.com/jackc/pgx/pgio"
"github.com/pkg/errors"
)
// Information on the internals of PostgreSQL arrays can be found in
// src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of
// particular interest is the array_send function.

// ArrayHeader is the fixed prefix of PostgreSQL's binary array format:
// dimension count, contains-null flag, and element OID, followed by one
// (length, lower bound) pair per dimension.
type ArrayHeader struct {
	ContainsNull bool
	ElementOID   int32
	Dimensions   []ArrayDimension
}

// ArrayDimension describes one dimension of an array: its length and the
// index of its first element (PostgreSQL defaults to lower bound 1).
type ArrayDimension struct {
	Length     int32
	LowerBound int32
}
// DecodeBinary parses the binary-format array header at the start of src,
// filling in dst and returning the number of bytes consumed so the caller
// can continue reading element data at src[rp:].
func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) {
	if len(src) < 12 {
		return 0, errors.Errorf("array header too short: %d", len(src))
	}

	rp := 0

	numDims := int(binary.BigEndian.Uint32(src[rp:]))
	rp += 4

	dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1
	rp += 4

	dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:]))
	rp += 4

	// Validate the claimed dimension count against the available bytes
	// BEFORE allocating, so a corrupt or hostile header cannot force a
	// huge allocation. (Previously the make() ran first.)
	if len(src) < 12+numDims*8 {
		return 0, errors.Errorf("array header too short for %d dimensions: %d", numDims, len(src))
	}
	if numDims > 0 {
		dst.Dimensions = make([]ArrayDimension, numDims)
	}

	for i := range dst.Dimensions {
		dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:]))
		rp += 4

		dst.Dimensions[i].LowerBound = int32(binary.BigEndian.Uint32(src[rp:]))
		rp += 4
	}

	return rp, nil
}
// EncodeBinary appends the binary wire-format array header for src to buf
// and returns the extended buffer.
func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte {
	buf = pgio.AppendInt32(buf, int32(len(src.Dimensions)))

	nullFlag := int32(0)
	if src.ContainsNull {
		nullFlag = 1
	}
	buf = pgio.AppendInt32(buf, nullFlag)
	buf = pgio.AppendInt32(buf, src.ElementOID)

	for _, dim := range src.Dimensions {
		buf = pgio.AppendInt32(buf, dim.Length)
		buf = pgio.AppendInt32(buf, dim.LowerBound)
	}

	return buf
}
// UntypedTextArray holds the elements and dimensions of a text-format
// array whose element strings have not yet been decoded into a concrete
// element type.
type UntypedTextArray struct {
	Elements   []string
	Dimensions []ArrayDimension
}
// ParseUntypedTextArray parses the PostgreSQL text representation of an
// array (e.g. `{a,b}` or `[4:5][2:3]={{a,b},{c,d}}`) into its raw element
// strings and dimensions, without interpreting the element values.
func ParseUntypedTextArray(src string) (*UntypedTextArray, error) {
	dst := &UntypedTextArray{}

	buf := bytes.NewBufferString(src)

	skipWhitespace(buf)

	r, _, err := buf.ReadRune()
	if err != nil {
		return nil, errors.Errorf("invalid array: %v", err)
	}

	var explicitDimensions []ArrayDimension

	// Array has explicit dimensions: a "[lower:upper]..." prefix ending in '='.
	if r == '[' {
		buf.UnreadRune()

		for {
			r, _, err = buf.ReadRune()
			if err != nil {
				return nil, errors.Errorf("invalid array: %v", err)
			}

			if r == '=' {
				break
			} else if r != '[' {
				return nil, errors.Errorf("invalid array, expected '[' or '=' got %v", r)
			}

			lower, err := arrayParseInteger(buf)
			if err != nil {
				return nil, errors.Errorf("invalid array: %v", err)
			}

			r, _, err = buf.ReadRune()
			if err != nil {
				return nil, errors.Errorf("invalid array: %v", err)
			}

			if r != ':' {
				return nil, errors.Errorf("invalid array, expected ':' got %v", r)
			}

			upper, err := arrayParseInteger(buf)
			if err != nil {
				return nil, errors.Errorf("invalid array: %v", err)
			}

			r, _, err = buf.ReadRune()
			if err != nil {
				return nil, errors.Errorf("invalid array: %v", err)
			}

			if r != ']' {
				return nil, errors.Errorf("invalid array, expected ']' got %v", r)
			}

			explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1})
		}

		r, _, err = buf.ReadRune()
		if err != nil {
			return nil, errors.Errorf("invalid array: %v", err)
		}
	}

	if r != '{' {
		// Report the offending rune; previously this formatted err, which is
		// always nil on this path, so the message read "expected '{': <nil>".
		return nil, errors.Errorf("invalid array, expected '{': %v", r)
	}

	implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}}

	// Consume all initial opening brackets. This provides number of dimensions.
	for {
		r, _, err = buf.ReadRune()
		if err != nil {
			return nil, errors.Errorf("invalid array: %v", err)
		}

		if r == '{' {
			implicitDimensions[len(implicitDimensions)-1].Length = 1
			implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1})
		} else {
			buf.UnreadRune()
			break
		}
	}

	// currentDim is the nesting depth of the cursor; counterDim is the
	// shallowest depth revisited so far — element counts are only
	// accumulated while the two agree, so each dimension's length is
	// counted exactly once.
	currentDim := len(implicitDimensions) - 1
	counterDim := currentDim

	for {
		r, _, err = buf.ReadRune()
		if err != nil {
			return nil, errors.Errorf("invalid array: %v", err)
		}

		switch r {
		case '{':
			if currentDim == counterDim {
				implicitDimensions[currentDim].Length++
			}
			currentDim++
		case ',':
		case '}':
			currentDim--
			if currentDim < counterDim {
				counterDim = currentDim
			}
		default:
			buf.UnreadRune()
			value, err := arrayParseValue(buf)
			if err != nil {
				return nil, errors.Errorf("invalid array value: %v", err)
			}
			if currentDim == counterDim {
				implicitDimensions[currentDim].Length++
			}
			dst.Elements = append(dst.Elements, value)
		}

		// The closing brace of the outermost dimension drops below 0.
		if currentDim < 0 {
			break
		}
	}

	skipWhitespace(buf)
	if buf.Len() > 0 {
		return nil, errors.Errorf("unexpected trailing data: %v", buf.String())
	}

	if len(dst.Elements) == 0 {
		dst.Dimensions = nil
	} else if len(explicitDimensions) > 0 {
		dst.Dimensions = explicitDimensions
	} else {
		dst.Dimensions = implicitDimensions
	}

	return dst, nil
}
func skipWhitespace(buf *bytes.Buffer) {
var r rune
var err error
for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() {
}
if err != io.EOF {
buf.UnreadRune()
}
}
func arrayParseValue(buf *bytes.Buffer) (string, error) {
r, _, err := buf.ReadRune()
if err != nil {
return "", err
}
if r == '"' {
return arrayParseQuotedValue(buf)
}
buf.UnreadRune()
s := &bytes.Buffer{}
for {
r, _, err := buf.ReadRune()
if err != nil {
return "", err
}
switch r {
case ',', '}':
buf.UnreadRune()
return s.String(), nil
}
s.WriteRune(r)
}
}
func arrayParseQuotedValue(buf *bytes.Buffer) (string, error) {
s := &bytes.Buffer{}
for {
r, _, err := buf.ReadRune()
if err != nil {
return "", err
}
switch r {
case '\\':
r, _, err = buf.ReadRune()
if err != nil {
return "", err
}
case '"':
r, _, err = buf.ReadRune()
if err != nil {
return "", err
}
buf.UnreadRune()
return s.String(), nil
}
s.WriteRune(r)
}
}
func arrayParseInteger(buf *bytes.Buffer) (int32, error) {
s := &bytes.Buffer{}
for {
r, _, err := buf.ReadRune()
if err != nil {
return 0, err
}
if '0' <= r && r <= '9' {
s.WriteRune(r)
} else {
buf.UnreadRune()
n, err := strconv.ParseInt(s.String(), 10, 32)
if err != nil {
return 0, err
}
return int32(n), nil
}
}
}
// EncodeTextArrayDimensions appends the explicit dimension prefix (e.g.
// "[4:5][2:3]=") to buf when any dimension has a lower bound other than 1.
// Arrays with all-default lower bounds need no prefix, so buf is returned
// unchanged in that case.
func EncodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte {
	needPrefix := false
	for _, d := range dimensions {
		if d.LowerBound != 1 {
			needPrefix = true
			break
		}
	}
	if !needPrefix {
		return buf
	}

	for _, d := range dimensions {
		buf = append(buf, '[')
		buf = strconv.AppendInt(buf, int64(d.LowerBound), 10)
		buf = append(buf, ':')
		buf = strconv.AppendInt(buf, int64(d.LowerBound+d.Length-1), 10)
		buf = append(buf, ']')
	}

	return append(buf, '=')
}
// quoteArrayReplacer escapes the two characters that are special inside a
// quoted array element: backslash and double quote.
var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)

// quoteArrayElement wraps src in double quotes, escaping embedded
// backslashes and quotes.
func quoteArrayElement(src string) string {
	return `"` + quoteArrayReplacer.Replace(src) + `"`
}

// QuoteArrayElementIfNeeded returns src quoted as a text-format array
// element when quoting is required — empty string, the word NULL in any
// case, leading or trailing space, or any of { } , " \ — and src unchanged
// otherwise.
func QuoteArrayElementIfNeeded(src string) string {
	needsQuote := src == "" ||
		(len(src) == 4 && strings.ToLower(src) == "null") ||
		src[0] == ' ' || src[len(src)-1] == ' ' ||
		strings.ContainsAny(src, `{},"\`)
	if needsQuote {
		return quoteArrayElement(src)
	}
	return src
}

105
pgtype/array_test.go Normal file
View File

@ -0,0 +1,105 @@
package pgtype_test
import (
"reflect"
"testing"
"github.com/jackc/pgx/pgtype"
)
func TestParseUntypedTextArray(t *testing.T) {
tests := []struct {
source string
result pgtype.UntypedTextArray
}{
{
source: "{}",
result: pgtype.UntypedTextArray{
Elements: nil,
Dimensions: nil,
},
},
{
source: "{1}",
result: pgtype.UntypedTextArray{
Elements: []string{"1"},
Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}},
},
},
{
source: "{a,b}",
result: pgtype.UntypedTextArray{
Elements: []string{"a", "b"},
Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}},
},
},
{
source: `{"NULL"}`,
result: pgtype.UntypedTextArray{
Elements: []string{"NULL"},
Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}},
},
},
{
source: `{""}`,
result: pgtype.UntypedTextArray{
Elements: []string{""},
Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}},
},
},
{
source: `{"He said, \"Hello.\""}`,
result: pgtype.UntypedTextArray{
Elements: []string{`He said, "Hello."`},
Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}},
},
},
{
source: "{{a,b},{c,d},{e,f}}",
result: pgtype.UntypedTextArray{
Elements: []string{"a", "b", "c", "d", "e", "f"},
Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}},
},
},
{
source: "{{{a,b},{c,d},{e,f}},{{a,b},{c,d},{e,f}}}",
result: pgtype.UntypedTextArray{
Elements: []string{"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f"},
Dimensions: []pgtype.ArrayDimension{
{Length: 2, LowerBound: 1},
{Length: 3, LowerBound: 1},
{Length: 2, LowerBound: 1},
},
},
},
{
source: "[4:4]={1}",
result: pgtype.UntypedTextArray{
Elements: []string{"1"},
Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 4}},
},
},
{
source: "[4:5][2:3]={{a,b},{c,d}}",
result: pgtype.UntypedTextArray{
Elements: []string{"a", "b", "c", "d"},
Dimensions: []pgtype.ArrayDimension{
{Length: 2, LowerBound: 4},
{Length: 2, LowerBound: 2},
},
},
},
}
for i, tt := range tests {
r, err := pgtype.ParseUntypedTextArray(tt.source)
if err != nil {
t.Errorf("%d: %v", i, err)
continue
}
if !reflect.DeepEqual(*r, tt.result) {
t.Errorf("%d: expected %+v to be parsed to %+v, but it was %+v", i, tt.source, tt.result, *r)
}
}
}

159
pgtype/bool.go Normal file
View File

@ -0,0 +1,159 @@
package pgtype
import (
"database/sql/driver"
"strconv"
"github.com/pkg/errors"
)
type Bool struct {
Bool bool
Status Status
}
func (dst *Bool) Set(src interface{}) error {
switch value := src.(type) {
case bool:
*dst = Bool{Bool: value, Status: Present}
case string:
bb, err := strconv.ParseBool(value)
if err != nil {
return err
}
*dst = Bool{Bool: bb, Status: Present}
default:
if originalSrc, ok := underlyingBoolType(src); ok {
return dst.Set(originalSrc)
}
return errors.Errorf("cannot convert %v to Bool", value)
}
return nil
}
func (dst *Bool) Get() interface{} {
switch dst.Status {
case Present:
return dst.Bool
case Null:
return nil
default:
return dst.Status
}
}
func (src *Bool) AssignTo(dst interface{}) error {
switch src.Status {
case Present:
switch v := dst.(type) {
case *bool:
*v = src.Bool
return nil
default:
if nextDst, retry := GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
}
}
case Null:
return NullAssignTo(dst)
}
return errors.Errorf("cannot decode %v into %T", src, dst)
}
func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Bool{Status: Null}
return nil
}
if len(src) != 1 {
return errors.Errorf("invalid length for bool: %v", len(src))
}
*dst = Bool{Bool: src[0] == 't', Status: Present}
return nil
}
func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Bool{Status: Null}
return nil
}
if len(src) != 1 {
return errors.Errorf("invalid length for bool: %v", len(src))
}
*dst = Bool{Bool: src[0] == 1, Status: Present}
return nil
}
func (src *Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
if src.Bool {
buf = append(buf, 't')
} else {
buf = append(buf, 'f')
}
return buf, nil
}
func (src *Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
if src.Bool {
buf = append(buf, 1)
} else {
buf = append(buf, 0)
}
return buf, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *Bool) Scan(src interface{}) error {
if src == nil {
*dst = Bool{Status: Null}
return nil
}
switch src := src.(type) {
case bool:
*dst = Bool{Bool: src, Status: Present}
return nil
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
srcCopy := make([]byte, len(src))
copy(srcCopy, src)
return dst.DecodeText(nil, srcCopy)
}
return errors.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *Bool) Value() (driver.Value, error) {
switch src.Status {
case Present:
return src.Bool, nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

294
pgtype/bool_array.go Normal file
View File

@ -0,0 +1,294 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"github.com/jackc/pgx/pgio"
"github.com/pkg/errors"
)
type BoolArray struct {
Elements []Bool
Dimensions []ArrayDimension
Status Status
}
func (dst *BoolArray) Set(src interface{}) error {
switch value := src.(type) {
case []bool:
if value == nil {
*dst = BoolArray{Status: Null}
} else if len(value) == 0 {
*dst = BoolArray{Status: Present}
} else {
elements := make([]Bool, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = BoolArray{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Status: Present,
}
}
default:
if originalSrc, ok := underlyingSliceType(src); ok {
return dst.Set(originalSrc)
}
return errors.Errorf("cannot convert %v to Bool", value)
}
return nil
}
func (dst *BoolArray) Get() interface{} {
switch dst.Status {
case Present:
return dst
case Null:
return nil
default:
return dst.Status
}
}
func (src *BoolArray) AssignTo(dst interface{}) error {
switch src.Status {
case Present:
switch v := dst.(type) {
case *[]bool:
*v = make([]bool, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
default:
if nextDst, retry := GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
}
}
case Null:
return NullAssignTo(dst)
}
return errors.Errorf("cannot decode %v into %T", src, dst)
}
func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = BoolArray{Status: Null}
return nil
}
uta, err := ParseUntypedTextArray(string(src))
if err != nil {
return err
}
var elements []Bool
if len(uta.Elements) > 0 {
elements = make([]Bool, len(uta.Elements))
for i, s := range uta.Elements {
var elem Bool
var elemSrc []byte
if s != "NULL" {
elemSrc = []byte(s)
}
err = elem.DecodeText(ci, elemSrc)
if err != nil {
return err
}
elements[i] = elem
}
}
*dst = BoolArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
return nil
}
func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = BoolArray{Status: Null}
return nil
}
var arrayHeader ArrayHeader
rp, err := arrayHeader.DecodeBinary(ci, src)
if err != nil {
return err
}
if len(arrayHeader.Dimensions) == 0 {
*dst = BoolArray{Dimensions: arrayHeader.Dimensions, Status: Present}
return nil
}
elementCount := arrayHeader.Dimensions[0].Length
for _, d := range arrayHeader.Dimensions[1:] {
elementCount *= d.Length
}
elements := make([]Bool, elementCount)
for i := range elements {
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
var elemSrc []byte
if elemLen >= 0 {
elemSrc = src[rp : rp+elemLen]
rp += elemLen
}
err = elements[i].DecodeBinary(ci, elemSrc)
if err != nil {
return err
}
}
*dst = BoolArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present}
return nil
}
func (src *BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
if len(src.Dimensions) == 0 {
return append(buf, '{', '}'), nil
}
buf = EncodeTextArrayDimensions(buf, src.Dimensions)
// dimElemCounts is the multiples of elements that each array lies on. For
// example, a single dimension array of length 4 would have a dimElemCounts of
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
// or '}'.
dimElemCounts := make([]int, len(src.Dimensions))
dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
for i := len(src.Dimensions) - 2; i > -1; i-- {
dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
}
inElemBuf := make([]byte, 0, 32)
for i, elem := range src.Elements {
if i > 0 {
buf = append(buf, ',')
}
for _, dec := range dimElemCounts {
if i%dec == 0 {
buf = append(buf, '{')
}
}
elemBuf, err := elem.EncodeText(ci, inElemBuf)
if err != nil {
return nil, err
}
if elemBuf == nil {
buf = append(buf, `NULL`...)
} else {
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
}
for _, dec := range dimElemCounts {
if (i+1)%dec == 0 {
buf = append(buf, '}')
}
}
}
return buf, nil
}
func (src *BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
arrayHeader := ArrayHeader{
Dimensions: src.Dimensions,
}
if dt, ok := ci.DataTypeForName("bool"); ok {
arrayHeader.ElementOID = int32(dt.OID)
} else {
return nil, errors.Errorf("unable to find oid for type name %v", "bool")
}
for i := range src.Elements {
if src.Elements[i].Status == Null {
arrayHeader.ContainsNull = true
break
}
}
buf = arrayHeader.EncodeBinary(ci, buf)
for i := range src.Elements {
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
elemBuf, err := src.Elements[i].EncodeBinary(ci, buf)
if err != nil {
return nil, err
}
if elemBuf != nil {
buf = elemBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
}
return buf, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *BoolArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
srcCopy := make([]byte, len(src))
copy(srcCopy, src)
return dst.DecodeText(nil, srcCopy)
}
return errors.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *BoolArray) Value() (driver.Value, error) {
buf, err := src.EncodeText(nil, nil)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}
return string(buf), nil
}

153
pgtype/bool_array_test.go Normal file
View File

@ -0,0 +1,153 @@
package pgtype_test
import (
"reflect"
"testing"
"github.com/jackc/pgx/pgtype"
"github.com/jackc/pgx/pgtype/testutil"
)
func TestBoolArrayTranscode(t *testing.T) {
testutil.TestSuccessfulTranscode(t, "bool[]", []interface{}{
&pgtype.BoolArray{
Elements: nil,
Dimensions: nil,
Status: pgtype.Present,
},
&pgtype.BoolArray{
Elements: []pgtype.Bool{
pgtype.Bool{Bool: true, Status: pgtype.Present},
pgtype.Bool{Status: pgtype.Null},
},
Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}},
Status: pgtype.Present,
},
&pgtype.BoolArray{Status: pgtype.Null},
&pgtype.BoolArray{
Elements: []pgtype.Bool{
pgtype.Bool{Bool: true, Status: pgtype.Present},
pgtype.Bool{Bool: true, Status: pgtype.Present},
pgtype.Bool{Bool: false, Status: pgtype.Present},
pgtype.Bool{Bool: true, Status: pgtype.Present},
pgtype.Bool{Status: pgtype.Null},
pgtype.Bool{Bool: false, Status: pgtype.Present},
},
Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}},
Status: pgtype.Present,
},
&pgtype.BoolArray{
Elements: []pgtype.Bool{
pgtype.Bool{Bool: true, Status: pgtype.Present},
pgtype.Bool{Bool: false, Status: pgtype.Present},
pgtype.Bool{Bool: true, Status: pgtype.Present},
pgtype.Bool{Bool: false, Status: pgtype.Present},
},
Dimensions: []pgtype.ArrayDimension{
{Length: 2, LowerBound: 4},
{Length: 2, LowerBound: 2},
},
Status: pgtype.Present,
},
})
}
func TestBoolArraySet(t *testing.T) {
successfulTests := []struct {
source interface{}
result pgtype.BoolArray
}{
{
source: []bool{true},
result: pgtype.BoolArray{
Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}},
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
Status: pgtype.Present},
},
{
source: (([]bool)(nil)),
result: pgtype.BoolArray{Status: pgtype.Null},
},
}
for i, tt := range successfulTests {
var r pgtype.BoolArray
err := r.Set(tt.source)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if !reflect.DeepEqual(r, tt.result) {
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
}
}
}
func TestBoolArrayAssignTo(t *testing.T) {
var boolSlice []bool
type _boolSlice []bool
var namedBoolSlice _boolSlice
simpleTests := []struct {
src pgtype.BoolArray
dst interface{}
expected interface{}
}{
{
src: pgtype.BoolArray{
Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}},
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
Status: pgtype.Present,
},
dst: &boolSlice,
expected: []bool{true},
},
{
src: pgtype.BoolArray{
Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}},
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
Status: pgtype.Present,
},
dst: &namedBoolSlice,
expected: _boolSlice{true},
},
{
src: pgtype.BoolArray{Status: pgtype.Null},
dst: &boolSlice,
expected: (([]bool)(nil)),
},
}
for i, tt := range simpleTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
errorTests := []struct {
src pgtype.BoolArray
dst interface{}
}{
{
src: pgtype.BoolArray{
Elements: []pgtype.Bool{{Status: pgtype.Null}},
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
Status: pgtype.Present,
},
dst: &boolSlice,
},
}
for i, tt := range errorTests {
err := tt.src.AssignTo(tt.dst)
if err == nil {
t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst)
}
}
}

96
pgtype/bool_test.go Normal file
View File

@ -0,0 +1,96 @@
package pgtype_test
import (
"reflect"
"testing"
"github.com/jackc/pgx/pgtype"
"github.com/jackc/pgx/pgtype/testutil"
)
func TestBoolTranscode(t *testing.T) {
testutil.TestSuccessfulTranscode(t, "bool", []interface{}{
&pgtype.Bool{Bool: false, Status: pgtype.Present},
&pgtype.Bool{Bool: true, Status: pgtype.Present},
&pgtype.Bool{Bool: false, Status: pgtype.Null},
})
}
// TestBoolSet verifies that Bool.Set accepts native bools, the bool-ish
// strings "true"/"false"/"t"/"f", and named bool types, producing the
// expected pgtype.Bool value in each case.
func TestBoolSet(t *testing.T) {
	successfulTests := []struct {
		source interface{}
		result pgtype.Bool
	}{
		{source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
		{source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}},
		{source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
		{source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}},
		{source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
		{source: "f", result: pgtype.Bool{Bool: false, Status: pgtype.Present}},
		{source: _bool(true), result: pgtype.Bool{Bool: true, Status: pgtype.Present}},
		{source: _bool(false), result: pgtype.Bool{Bool: false, Status: pgtype.Present}},
	}

	for i, tt := range successfulTests {
		var r pgtype.Bool
		err := r.Set(tt.source)
		if err != nil {
			t.Errorf("%d: %v", i, err)
		}

		if r != tt.result {
			t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
		}
	}
}
func TestBoolAssignTo(t *testing.T) {
var b bool
var _b _bool
var pb *bool
var _pb *_bool
simpleTests := []struct {
src pgtype.Bool
dst interface{}
expected interface{}
}{
{src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &b, expected: false},
{src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &b, expected: true},
{src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &_b, expected: _bool(false)},
{src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_b, expected: _bool(true)},
{src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &pb, expected: ((*bool)(nil))},
{src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &_pb, expected: ((*_bool)(nil))},
}
for i, tt := range simpleTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
pointerAllocTests := []struct {
src pgtype.Bool
dst interface{}
expected interface{}
}{
{src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &pb, expected: true},
{src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_pb, expected: _bool(true)},
}
for i, tt := range pointerAllocTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
}

162
pgtype/box.go Normal file
View File

@ -0,0 +1,162 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"math"
"strconv"
"strings"
"github.com/jackc/pgx/pgio"
"github.com/pkg/errors"
)
type Box struct {
P [2]Vec2
Status Status
}
func (dst *Box) Set(src interface{}) error {
return errors.Errorf("cannot convert %v to Box", src)
}
func (dst *Box) Get() interface{} {
switch dst.Status {
case Present:
return dst
case Null:
return nil
default:
return dst.Status
}
}
func (src *Box) AssignTo(dst interface{}) error {
return errors.Errorf("cannot assign %v to %T", src, dst)
}
// DecodeText decodes the PostgreSQL text representation of a box,
// "(x1,y1),(x2,y2)", into dst. A nil src decodes as SQL NULL.
//
// Unlike the previous version, every strings.IndexByte result and slice
// bound is validated so malformed input returns an error instead of
// panicking with a slice-bounds violation.
func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error {
	if src == nil {
		*dst = Box{Status: Null}
		return nil
	}

	// Shortest possible valid value is "(0,0),(0,0)" — 11 bytes.
	if len(src) < 11 {
		return errors.Errorf("invalid length for Box: %v", len(src))
	}

	str := string(src[1:])

	end := strings.IndexByte(str, ',')
	if end == -1 {
		return errors.Errorf("invalid format for Box: %q", string(src))
	}
	x1, err := strconv.ParseFloat(str[:end], 64)
	if err != nil {
		return err
	}

	str = str[end+1:]
	end = strings.IndexByte(str, ')')
	if end == -1 {
		return errors.Errorf("invalid format for Box: %q", string(src))
	}
	y1, err := strconv.ParseFloat(str[:end], 64)
	if err != nil {
		return err
	}

	// Skip the "),(" separator between the two points.
	if end+3 > len(str) {
		return errors.Errorf("invalid format for Box: %q", string(src))
	}
	str = str[end+3:]
	end = strings.IndexByte(str, ',')
	if end == -1 {
		return errors.Errorf("invalid format for Box: %q", string(src))
	}
	x2, err := strconv.ParseFloat(str[:end], 64)
	if err != nil {
		return err
	}

	// The second coordinate runs up to the trailing ')'.
	if end+1 > len(str)-1 {
		return errors.Errorf("invalid format for Box: %q", string(src))
	}
	str = str[end+1 : len(str)-1]
	y2, err := strconv.ParseFloat(str, 64)
	if err != nil {
		return err
	}

	*dst = Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present}
	return nil
}
// DecodeBinary decodes the PostgreSQL binary representation of a box into
// dst: four consecutive big-endian float64 values (x1, y1, x2, y2). A nil
// src decodes as SQL NULL.
func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error {
	if src == nil {
		*dst = Box{Status: Null}
		return nil
	}

	// Exactly 4 float64 values = 32 bytes.
	if len(src) != 32 {
		return errors.Errorf("invalid length for Box: %v", len(src))
	}

	x1 := binary.BigEndian.Uint64(src)
	y1 := binary.BigEndian.Uint64(src[8:])
	x2 := binary.BigEndian.Uint64(src[16:])
	y2 := binary.BigEndian.Uint64(src[24:])

	*dst = Box{
		P: [2]Vec2{
			{math.Float64frombits(x1), math.Float64frombits(y1)},
			{math.Float64frombits(x2), math.Float64frombits(y2)},
		},
		Status: Present,
	}
	return nil
}
// EncodeText encodes src in the PostgreSQL text format "(x1,y1),(x2,y2)"
// and appends it to buf.
//
// strconv.FormatFloat with precision -1 emits the shortest decimal string
// that round-trips the float64 exactly; the previous %f verb truncated to
// six decimal places and silently lost precision.
func (src *Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
	switch src.Status {
	case Null:
		return nil, nil
	case Undefined:
		return nil, errUndefined
	}

	buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`,
		strconv.FormatFloat(src.P[0].X, 'f', -1, 64),
		strconv.FormatFloat(src.P[0].Y, 'f', -1, 64),
		strconv.FormatFloat(src.P[1].X, 'f', -1, 64),
		strconv.FormatFloat(src.P[1].Y, 'f', -1, 64),
	)...)
	return buf, nil
}
func (src *Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X))
buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y))
buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X))
buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y))
return buf, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *Box) Scan(src interface{}) error {
	switch value := src.(type) {
	case nil:
		*dst = Box{Status: Null}
		return nil
	case string:
		return dst.DecodeText(nil, []byte(value))
	case []byte:
		// Copy the buffer; the driver may reuse it after Scan returns.
		dup := make([]byte, len(value))
		copy(dup, value)
		return dst.DecodeText(nil, dup)
	default:
		return errors.Errorf("cannot scan %T", src)
	}
}
// Value implements the database/sql/driver Valuer interface.
func (src *Box) Value() (driver.Value, error) {
return EncodeValueText(src)
}

34
pgtype/box_test.go Normal file
View File

@ -0,0 +1,34 @@
package pgtype_test
import (
"testing"
"github.com/jackc/pgx/pgtype"
"github.com/jackc/pgx/pgtype/testutil"
)
func TestBoxTranscode(t *testing.T) {
testutil.TestSuccessfulTranscode(t, "box", []interface{}{
&pgtype.Box{
P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}},
Status: pgtype.Present,
},
&pgtype.Box{
P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}},
Status: pgtype.Present,
},
&pgtype.Box{Status: pgtype.Null},
})
}
func TestBoxNormalize(t *testing.T) {
testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{
{
SQL: "select '3.14, 1.678, 7.1, 5.234'::box",
Value: &pgtype.Box{
P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}},
Status: pgtype.Present,
},
},
})
}

156
pgtype/bytea.go Normal file
View File

@ -0,0 +1,156 @@
package pgtype
import (
"database/sql/driver"
"encoding/hex"
"github.com/pkg/errors"
)
type Bytea struct {
Bytes []byte
Status Status
}
// Set converts src to a Bytea. A nil interface or nil []byte maps to SQL
// NULL; any type whose underlying type is []byte is accepted via
// underlyingBytesType.
func (dst *Bytea) Set(src interface{}) error {
	if src == nil {
		*dst = Bytea{Status: Null}
		return nil
	}

	value, ok := src.([]byte)
	if !ok {
		// Unwrap named byte-slice types and retry.
		if originalSrc, ok := underlyingBytesType(src); ok {
			return dst.Set(originalSrc)
		}
		return errors.Errorf("cannot convert %v to Bytea", src)
	}

	if value == nil {
		*dst = Bytea{Status: Null}
	} else {
		*dst = Bytea{Bytes: value, Status: Present}
	}
	return nil
}
func (dst *Bytea) Get() interface{} {
switch dst.Status {
case Present:
return dst.Bytes
case Null:
return nil
default:
return dst.Status
}
}
// AssignTo copies src into dst. dst must be a *[]byte or a type that
// GetAssignToDstType can unwrap to one. The bytes are copied so the
// destination does not alias src.Bytes. A Null src assigns via
// NullAssignTo.
func (src *Bytea) AssignTo(dst interface{}) error {
	switch src.Status {
	case Present:
		switch v := dst.(type) {
		case *[]byte:
			buf := make([]byte, len(src.Bytes))
			copy(buf, src.Bytes)
			*v = buf
			return nil
		default:
			// Unwrap pointers / named types and retry.
			if nextDst, retry := GetAssignToDstType(dst); retry {
				return src.AssignTo(nextDst)
			}
		}
	case Null:
		return NullAssignTo(dst)
	}

	return errors.Errorf("cannot decode %v into %T", src, dst)
}
// DecodeText only supports the hex format. This has been the default since
// PostgreSQL 9.0. Expected input is `\x` followed by an even number of hex
// digits; a nil src decodes as SQL NULL.
func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error {
	if src == nil {
		*dst = Bytea{Status: Null}
		return nil
	}

	if len(src) < 2 || src[0] != '\\' || src[1] != 'x' {
		return errors.Errorf("invalid hex format")
	}

	// Each pair of hex digits after the "\x" prefix decodes to one byte.
	buf := make([]byte, (len(src)-2)/2)
	_, err := hex.Decode(buf, src[2:])
	if err != nil {
		return err
	}

	*dst = Bytea{Bytes: buf, Status: Present}
	return nil
}
func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Bytea{Status: Null}
return nil
}
*dst = Bytea{Bytes: src, Status: Present}
return nil
}
func (src *Bytea) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
buf = append(buf, `\x`...)
buf = append(buf, hex.EncodeToString(src.Bytes)...)
return buf, nil
}
func (src *Bytea) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
return append(buf, src.Bytes...), nil
}
// Scan implements the database/sql Scanner interface.
func (dst *Bytea) Scan(src interface{}) error {
	switch value := src.(type) {
	case nil:
		*dst = Bytea{Status: Null}
		return nil
	case string:
		return dst.DecodeText(nil, []byte(value))
	case []byte:
		// Copy the buffer; the driver may reuse it after Scan returns.
		dup := make([]byte, len(value))
		copy(dup, value)
		*dst = Bytea{Bytes: dup, Status: Present}
		return nil
	default:
		return errors.Errorf("cannot scan %T", src)
	}
}
// Value implements the database/sql/driver Valuer interface.
func (src *Bytea) Value() (driver.Value, error) {
switch src.Status {
case Present:
return src.Bytes, nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

294
pgtype/bytea_array.go Normal file
View File

@ -0,0 +1,294 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"github.com/jackc/pgx/pgio"
"github.com/pkg/errors"
)
type ByteaArray struct {
Elements []Bytea
Dimensions []ArrayDimension
Status Status
}
func (dst *ByteaArray) Set(src interface{}) error {
switch value := src.(type) {
case [][]byte:
if value == nil {
*dst = ByteaArray{Status: Null}
} else if len(value) == 0 {
*dst = ByteaArray{Status: Present}
} else {
elements := make([]Bytea, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = ByteaArray{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Status: Present,
}
}
default:
if originalSrc, ok := underlyingSliceType(src); ok {
return dst.Set(originalSrc)
}
return errors.Errorf("cannot convert %v to Bytea", value)
}
return nil
}
func (dst *ByteaArray) Get() interface{} {
switch dst.Status {
case Present:
return dst
case Null:
return nil
default:
return dst.Status
}
}
func (src *ByteaArray) AssignTo(dst interface{}) error {
switch src.Status {
case Present:
switch v := dst.(type) {
case *[][]byte:
*v = make([][]byte, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
default:
if nextDst, retry := GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
}
}
case Null:
return NullAssignTo(dst)
}
return errors.Errorf("cannot decode %v into %T", src, dst)
}
func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = ByteaArray{Status: Null}
return nil
}
uta, err := ParseUntypedTextArray(string(src))
if err != nil {
return err
}
var elements []Bytea
if len(uta.Elements) > 0 {
elements = make([]Bytea, len(uta.Elements))
for i, s := range uta.Elements {
var elem Bytea
var elemSrc []byte
if s != "NULL" {
elemSrc = []byte(s)
}
err = elem.DecodeText(ci, elemSrc)
if err != nil {
return err
}
elements[i] = elem
}
}
*dst = ByteaArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
return nil
}
// DecodeBinary decodes the PostgreSQL binary array wire format from src
// into dst. A nil src decodes as SQL NULL.
func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error {
	if src == nil {
		*dst = ByteaArray{Status: Null}
		return nil
	}

	var arrayHeader ArrayHeader
	rp, err := arrayHeader.DecodeBinary(ci, src)
	if err != nil {
		return err
	}

	// Zero dimensions means an empty array: no element data follows.
	if len(arrayHeader.Dimensions) == 0 {
		*dst = ByteaArray{Dimensions: arrayHeader.Dimensions, Status: Present}
		return nil
	}

	// Total element count is the product of all dimension lengths.
	elementCount := arrayHeader.Dimensions[0].Length
	for _, d := range arrayHeader.Dimensions[1:] {
		elementCount *= d.Length
	}

	elements := make([]Bytea, elementCount)

	for i := range elements {
		// Each element is a 4-byte signed length followed by that many
		// bytes; a length of -1 denotes NULL (elemSrc stays nil).
		elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
		rp += 4
		var elemSrc []byte
		if elemLen >= 0 {
			elemSrc = src[rp : rp+elemLen]
			rp += elemLen
		}
		err = elements[i].DecodeBinary(ci, elemSrc)
		if err != nil {
			return err
		}
	}

	*dst = ByteaArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present}
	return nil
}
// EncodeText encodes src in the PostgreSQL array text format (e.g.
// "{\\x01,NULL}") and appends it to buf. Explicit dimension bounds are
// emitted first when needed by EncodeTextArrayDimensions. Returns nil, nil
// for Null and errUndefined for Undefined.
func (src *ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
	switch src.Status {
	case Null:
		return nil, nil
	case Undefined:
		return nil, errUndefined
	}

	if len(src.Dimensions) == 0 {
		return append(buf, '{', '}'), nil
	}

	buf = EncodeTextArrayDimensions(buf, src.Dimensions)

	// dimElemCounts is the multiples of elements that each array lies on. For
	// example, a single dimension array of length 4 would have a dimElemCounts of
	// [4]. A multi-dimensional array of lengths [3,5,2] would have a
	// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
	// or '}'.
	dimElemCounts := make([]int, len(src.Dimensions))
	dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
	for i := len(src.Dimensions) - 2; i > -1; i-- {
		dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
	}

	inElemBuf := make([]byte, 0, 32)
	for i, elem := range src.Elements {
		if i > 0 {
			buf = append(buf, ',')
		}

		// Open a brace for every dimension boundary this index starts.
		for _, dec := range dimElemCounts {
			if i%dec == 0 {
				buf = append(buf, '{')
			}
		}

		elemBuf, err := elem.EncodeText(ci, inElemBuf)
		if err != nil {
			return nil, err
		}
		if elemBuf == nil {
			// A nil element buffer means SQL NULL.
			buf = append(buf, `NULL`...)
		} else {
			buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
		}

		// Close a brace for every dimension boundary this index ends.
		for _, dec := range dimElemCounts {
			if (i+1)%dec == 0 {
				buf = append(buf, '}')
			}
		}
	}

	return buf, nil
}
func (src *ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
arrayHeader := ArrayHeader{
Dimensions: src.Dimensions,
}
if dt, ok := ci.DataTypeForName("bytea"); ok {
arrayHeader.ElementOID = int32(dt.OID)
} else {
return nil, errors.Errorf("unable to find oid for type name %v", "bytea")
}
for i := range src.Elements {
if src.Elements[i].Status == Null {
arrayHeader.ContainsNull = true
break
}
}
buf = arrayHeader.EncodeBinary(ci, buf)
for i := range src.Elements {
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
elemBuf, err := src.Elements[i].EncodeBinary(ci, buf)
if err != nil {
return nil, err
}
if elemBuf != nil {
buf = elemBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
}
return buf, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *ByteaArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
srcCopy := make([]byte, len(src))
copy(srcCopy, src)
return dst.DecodeText(nil, srcCopy)
}
return errors.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *ByteaArray) Value() (driver.Value, error) {
buf, err := src.EncodeText(nil, nil)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}
return string(buf), nil
}

120
pgtype/bytea_array_test.go Normal file
View File

@ -0,0 +1,120 @@
package pgtype_test
import (
"reflect"
"testing"
"github.com/jackc/pgx/pgtype"
"github.com/jackc/pgx/pgtype/testutil"
)
func TestByteaArrayTranscode(t *testing.T) {
testutil.TestSuccessfulTranscode(t, "bytea[]", []interface{}{
&pgtype.ByteaArray{
Elements: nil,
Dimensions: nil,
Status: pgtype.Present,
},
&pgtype.ByteaArray{
Elements: []pgtype.Bytea{
pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
pgtype.Bytea{Status: pgtype.Null},
},
Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}},
Status: pgtype.Present,
},
&pgtype.ByteaArray{Status: pgtype.Null},
&pgtype.ByteaArray{
Elements: []pgtype.Bytea{
pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present},
pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
pgtype.Bytea{Status: pgtype.Null},
pgtype.Bytea{Bytes: []byte{1}, Status: pgtype.Present},
},
Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}},
Status: pgtype.Present,
},
&pgtype.ByteaArray{
Elements: []pgtype.Bytea{
pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present},
pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
pgtype.Bytea{Bytes: []byte{1}, Status: pgtype.Present},
},
Dimensions: []pgtype.ArrayDimension{
{Length: 2, LowerBound: 4},
{Length: 2, LowerBound: 2},
},
Status: pgtype.Present,
},
})
}
func TestByteaArraySet(t *testing.T) {
successfulTests := []struct {
source interface{}
result pgtype.ByteaArray
}{
{
source: [][]byte{{1, 2, 3}},
result: pgtype.ByteaArray{
Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}},
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
Status: pgtype.Present},
},
{
source: (([][]byte)(nil)),
result: pgtype.ByteaArray{Status: pgtype.Null},
},
}
for i, tt := range successfulTests {
var r pgtype.ByteaArray
err := r.Set(tt.source)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if !reflect.DeepEqual(r, tt.result) {
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
}
}
}
func TestByteaArrayAssignTo(t *testing.T) {
var byteByteSlice [][]byte
simpleTests := []struct {
src pgtype.ByteaArray
dst interface{}
expected interface{}
}{
{
src: pgtype.ByteaArray{
Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}},
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
Status: pgtype.Present,
},
dst: &byteByteSlice,
expected: [][]byte{{1, 2, 3}},
},
{
src: pgtype.ByteaArray{Status: pgtype.Null},
dst: &byteByteSlice,
expected: (([][]byte)(nil)),
},
}
for i, tt := range simpleTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
}

73
pgtype/bytea_test.go Normal file
View File

@ -0,0 +1,73 @@
package pgtype_test
import (
"reflect"
"testing"
"github.com/jackc/pgx/pgtype"
"github.com/jackc/pgx/pgtype/testutil"
)
func TestByteaTranscode(t *testing.T) {
testutil.TestSuccessfulTranscode(t, "bytea", []interface{}{
&pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
&pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present},
&pgtype.Bytea{Bytes: nil, Status: pgtype.Null},
})
}
func TestByteaSet(t *testing.T) {
successfulTests := []struct {
source interface{}
result pgtype.Bytea
}{
{source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}},
{source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}},
{source: []byte(nil), result: pgtype.Bytea{Status: pgtype.Null}},
{source: _byteSlice{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}},
{source: _byteSlice(nil), result: pgtype.Bytea{Status: pgtype.Null}},
}
for i, tt := range successfulTests {
var r pgtype.Bytea
err := r.Set(tt.source)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if !reflect.DeepEqual(r, tt.result) {
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
}
}
}
func TestByteaAssignTo(t *testing.T) {
var buf []byte
var _buf _byteSlice
var pbuf *[]byte
var _pbuf *_byteSlice
simpleTests := []struct {
src pgtype.Bytea
dst interface{}
expected interface{}
}{
{src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &buf, expected: []byte{1, 2, 3}},
{src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_buf, expected: _byteSlice{1, 2, 3}},
{src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &pbuf, expected: &[]byte{1, 2, 3}},
{src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_pbuf, expected: &_byteSlice{1, 2, 3}},
{src: pgtype.Bytea{Status: pgtype.Null}, dst: &pbuf, expected: ((*[]byte)(nil))},
{src: pgtype.Bytea{Status: pgtype.Null}, dst: &_pbuf, expected: ((*_byteSlice)(nil))},
}
for i, tt := range simpleTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
}

61
pgtype/cid.go Normal file
View File

@ -0,0 +1,61 @@
package pgtype
import (
"database/sql/driver"
)
// CID is PostgreSQL's Command Identifier type.
//
// When one does
//
// select cmin, cmax, * from some_table;
//
// it is the data type of the cmin and cmax hidden system columns.
//
// It is currently implemented as an unsigned four byte integer.
// Its definition can be found in src/include/c.h as CommandId
// in the PostgreSQL sources.
type CID pguint32

// Set converts from src to dst. Note that as CID is not a general
// number type Set does not do automatic type conversion as other number
// types do.
func (dst *CID) Set(src interface{}) error {
	return (*pguint32)(dst).Set(src)
}

// Get returns the underlying value via the shared pguint32 implementation.
func (dst *CID) Get() interface{} {
	return (*pguint32)(dst).Get()
}

// AssignTo assigns from src to dst. Note that as CID is not a general number
// type AssignTo does not do automatic type conversion as other number types do.
func (src *CID) AssignTo(dst interface{}) error {
	return (*pguint32)(src).AssignTo(dst)
}

// DecodeText delegates to the shared pguint32 text decoder.
func (dst *CID) DecodeText(ci *ConnInfo, src []byte) error {
	return (*pguint32)(dst).DecodeText(ci, src)
}

// DecodeBinary delegates to the shared pguint32 binary decoder.
func (dst *CID) DecodeBinary(ci *ConnInfo, src []byte) error {
	return (*pguint32)(dst).DecodeBinary(ci, src)
}

// EncodeText delegates to the shared pguint32 text encoder.
func (src *CID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
	return (*pguint32)(src).EncodeText(ci, buf)
}

// EncodeBinary delegates to the shared pguint32 binary encoder.
func (src *CID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
	return (*pguint32)(src).EncodeBinary(ci, buf)
}

// Scan implements the database/sql Scanner interface.
func (dst *CID) Scan(src interface{}) error {
	return (*pguint32)(dst).Scan(src)
}

// Value implements the database/sql/driver Valuer interface.
func (src *CID) Value() (driver.Value, error) {
	return (*pguint32)(src).Value()
}

108
pgtype/cid_test.go Normal file
View File

@ -0,0 +1,108 @@
package pgtype_test
import (
"reflect"
"testing"
"github.com/jackc/pgx/pgtype"
"github.com/jackc/pgx/pgtype/testutil"
)
func TestCIDTranscode(t *testing.T) {
pgTypeName := "cid"
values := []interface{}{
&pgtype.CID{Uint: 42, Status: pgtype.Present},
&pgtype.CID{Status: pgtype.Null},
}
eqFunc := func(a, b interface{}) bool {
return reflect.DeepEqual(a, b)
}
testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
// No direct conversion from int to cid, convert through text
testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc)
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
}
}
func TestCIDSet(t *testing.T) {
successfulTests := []struct {
source interface{}
result pgtype.CID
}{
{source: uint32(1), result: pgtype.CID{Uint: 1, Status: pgtype.Present}},
}
for i, tt := range successfulTests {
var r pgtype.CID
err := r.Set(tt.source)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if r != tt.result {
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
}
}
}
func TestCIDAssignTo(t *testing.T) {
var ui32 uint32
var pui32 *uint32
simpleTests := []struct {
src pgtype.CID
dst interface{}
expected interface{}
}{
{src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)},
{src: pgtype.CID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))},
}
for i, tt := range simpleTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
pointerAllocTests := []struct {
src pgtype.CID
dst interface{}
expected interface{}
}{
{src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)},
}
for i, tt := range pointerAllocTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
errorTests := []struct {
src pgtype.CID
dst interface{}
}{
{src: pgtype.CID{Status: pgtype.Null}, dst: &ui32},
}
for i, tt := range errorTests {
err := tt.src.AssignTo(tt.dst)
if err == nil {
t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst)
}
}
}

31
pgtype/cidr.go Normal file
View File

@ -0,0 +1,31 @@
package pgtype
// CIDR is the PostgreSQL cidr network type. It shares its representation
// with Inet; every method delegates to the Inet implementation through a
// type conversion.
type CIDR Inet

// Set delegates to (*Inet).Set.
func (dst *CIDR) Set(src interface{}) error {
	return (*Inet)(dst).Set(src)
}

// Get delegates to (*Inet).Get.
func (dst *CIDR) Get() interface{} {
	return (*Inet)(dst).Get()
}

// AssignTo delegates to (*Inet).AssignTo.
func (src *CIDR) AssignTo(dst interface{}) error {
	return (*Inet)(src).AssignTo(dst)
}

// DecodeText delegates to (*Inet).DecodeText.
func (dst *CIDR) DecodeText(ci *ConnInfo, src []byte) error {
	return (*Inet)(dst).DecodeText(ci, src)
}

// DecodeBinary delegates to (*Inet).DecodeBinary.
func (dst *CIDR) DecodeBinary(ci *ConnInfo, src []byte) error {
	return (*Inet)(dst).DecodeBinary(ci, src)
}

// EncodeText delegates to (*Inet).EncodeText.
func (src *CIDR) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
	return (*Inet)(src).EncodeText(ci, buf)
}

// EncodeBinary delegates to (*Inet).EncodeBinary.
func (src *CIDR) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
	return (*Inet)(src).EncodeBinary(ci, buf)
}

323
pgtype/cidr_array.go Normal file
View File

@ -0,0 +1,323 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"net"
"github.com/jackc/pgx/pgio"
"github.com/pkg/errors"
)
type CIDRArray struct {
Elements []CIDR
Dimensions []ArrayDimension
Status Status
}
func (dst *CIDRArray) Set(src interface{}) error {
switch value := src.(type) {
case []*net.IPNet:
if value == nil {
*dst = CIDRArray{Status: Null}
} else if len(value) == 0 {
*dst = CIDRArray{Status: Present}
} else {
elements := make([]CIDR, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = CIDRArray{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Status: Present,
}
}
case []net.IP:
if value == nil {
*dst = CIDRArray{Status: Null}
} else if len(value) == 0 {
*dst = CIDRArray{Status: Present}
} else {
elements := make([]CIDR, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = CIDRArray{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Status: Present,
}
}
default:
if originalSrc, ok := underlyingSliceType(src); ok {
return dst.Set(originalSrc)
}
return errors.Errorf("cannot convert %v to CIDR", value)
}
return nil
}
func (dst *CIDRArray) Get() interface{} {
switch dst.Status {
case Present:
return dst
case Null:
return nil
default:
return dst.Status
}
}
// AssignTo copies src into dst, which must be a *[]*net.IPNet, a
// *[]net.IP, or a type GetAssignToDstType can unwrap to one of those.
// A Null src assigns via NullAssignTo. Note: multi-dimensional structure
// is flattened into the destination slice.
func (src *CIDRArray) AssignTo(dst interface{}) error {
	switch src.Status {
	case Present:
		switch v := dst.(type) {
		case *[]*net.IPNet:
			*v = make([]*net.IPNet, len(src.Elements))
			for i := range src.Elements {
				if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
					return err
				}
			}
			return nil

		case *[]net.IP:
			*v = make([]net.IP, len(src.Elements))
			for i := range src.Elements {
				if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
					return err
				}
			}
			return nil

		default:
			// Unwrap pointers / named types and retry.
			if nextDst, retry := GetAssignToDstType(dst); retry {
				return src.AssignTo(nextDst)
			}
		}
	case Null:
		return NullAssignTo(dst)
	}

	return errors.Errorf("cannot decode %v into %T", src, dst)
}
// DecodeText decodes the PostgreSQL text-format array representation in
// src into dst. A nil src represents SQL NULL.
func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error {
	if src == nil {
		*dst = CIDRArray{Status: Null}
		return nil
	}

	uta, err := ParseUntypedTextArray(string(src))
	if err != nil {
		return err
	}

	var elems []CIDR
	if len(uta.Elements) > 0 {
		elems = make([]CIDR, len(uta.Elements))
		for i, raw := range uta.Elements {
			// The literal string "NULL" marks a NULL element; signal it to
			// the element decoder with a nil byte slice.
			var elemSrc []byte
			if raw != "NULL" {
				elemSrc = []byte(raw)
			}
			if err := elems[i].DecodeText(ci, elemSrc); err != nil {
				return err
			}
		}
	}

	*dst = CIDRArray{Elements: elems, Dimensions: uta.Dimensions, Status: Present}
	return nil
}
// DecodeBinary decodes the PostgreSQL binary array format in src into
// dst. A nil src represents SQL NULL.
func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error {
	if src == nil {
		*dst = CIDRArray{Status: Null}
		return nil
	}
	var arrayHeader ArrayHeader
	// rp is the read position in src immediately after the header.
	rp, err := arrayHeader.DecodeBinary(ci, src)
	if err != nil {
		return err
	}
	// A header with no dimensions means an empty array.
	if len(arrayHeader.Dimensions) == 0 {
		*dst = CIDRArray{Dimensions: arrayHeader.Dimensions, Status: Present}
		return nil
	}
	// Total element count is the product of all dimension lengths.
	elementCount := arrayHeader.Dimensions[0].Length
	for _, d := range arrayHeader.Dimensions[1:] {
		elementCount *= d.Length
	}
	elements := make([]CIDR, elementCount)
	for i := range elements {
		// Each element is prefixed with its int32 byte length; -1 marks a
		// NULL element, in which case elemSrc stays nil.
		elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
		rp += 4
		var elemSrc []byte
		if elemLen >= 0 {
			elemSrc = src[rp : rp+elemLen]
			rp += elemLen
		}
		err = elements[i].DecodeBinary(ci, elemSrc)
		if err != nil {
			return err
		}
	}
	*dst = CIDRArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present}
	return nil
}
// EncodeText appends the PostgreSQL text-format representation of the
// array to buf. It returns nil, nil for SQL NULL and errUndefined for an
// Undefined status.
func (src *CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
	switch src.Status {
	case Null:
		return nil, nil
	case Undefined:
		return nil, errUndefined
	}
	// A zero-dimension array renders as the empty array literal.
	if len(src.Dimensions) == 0 {
		return append(buf, '{', '}'), nil
	}
	// Emit the explicit dimension specification (with lower bounds) first.
	buf = EncodeTextArrayDimensions(buf, src.Dimensions)
	// dimElemCounts is the multiples of elements that each array lies on. For
	// example, a single dimension array of length 4 would have a dimElemCounts of
	// [4]. A multi-dimensional array of lengths [3,5,2] would have a
	// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
	// or '}'.
	dimElemCounts := make([]int, len(src.Dimensions))
	dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
	for i := len(src.Dimensions) - 2; i > -1; i-- {
		dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
	}
	// Reusable scratch buffer passed to each element's EncodeText.
	inElemBuf := make([]byte, 0, 32)
	for i, elem := range src.Elements {
		if i > 0 {
			buf = append(buf, ',')
		}
		// Open a brace for every dimension boundary this index starts.
		for _, dec := range dimElemCounts {
			if i%dec == 0 {
				buf = append(buf, '{')
			}
		}
		elemBuf, err := elem.EncodeText(ci, inElemBuf)
		if err != nil {
			return nil, err
		}
		// A nil element buffer signals a NULL element.
		if elemBuf == nil {
			buf = append(buf, `NULL`...)
		} else {
			buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
		}
		// Close a brace for every dimension boundary this index ends.
		for _, dec := range dimElemCounts {
			if (i+1)%dec == 0 {
				buf = append(buf, '}')
			}
		}
	}
	return buf, nil
}
// EncodeBinary appends the PostgreSQL binary array representation of the
// array to buf. It returns nil, nil for SQL NULL and errUndefined for an
// Undefined status.
func (src *CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
	if src.Status == Null {
		return nil, nil
	}
	if src.Status == Undefined {
		return nil, errUndefined
	}

	hdr := ArrayHeader{
		Dimensions: src.Dimensions,
	}

	// Resolve the element OID for cidr from the connection's type map.
	dt, ok := ci.DataTypeForName("cidr")
	if !ok {
		return nil, errors.Errorf("unable to find oid for type name %v", "cidr")
	}
	hdr.ElementOID = int32(dt.OID)

	for i := range src.Elements {
		if src.Elements[i].Status == Null {
			hdr.ContainsNull = true
			break
		}
	}

	buf = hdr.EncodeBinary(ci, buf)

	for i := range src.Elements {
		// Write a length placeholder of -1 (the NULL marker); overwrite it
		// with the real byte length only if the element produced output.
		lenPos := len(buf)
		buf = pgio.AppendInt32(buf, -1)

		elemBuf, err := src.Elements[i].EncodeBinary(ci, buf)
		if err != nil {
			return nil, err
		}
		if elemBuf != nil {
			buf = elemBuf
			pgio.SetInt32(buf[lenPos:], int32(len(buf[lenPos:])-4))
		}
	}

	return buf, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *CIDRArray) Scan(src interface{}) error {
	if src == nil {
		return dst.DecodeText(nil, nil)
	}

	if s, ok := src.(string); ok {
		return dst.DecodeText(nil, []byte(s))
	}
	if b, ok := src.([]byte); ok {
		// Copy first: the driver may reuse the byte slice after Scan returns.
		return dst.DecodeText(nil, append([]byte(nil), b...))
	}

	return errors.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *CIDRArray) Value() (driver.Value, error) {
	buf, err := src.EncodeText(nil, nil)
	switch {
	case err != nil:
		return nil, err
	case buf == nil:
		// EncodeText returns a nil buffer for SQL NULL.
		return nil, nil
	default:
		return string(buf), nil
	}
}

165
pgtype/cidr_array_test.go Normal file
View File

@ -0,0 +1,165 @@
package pgtype_test
import (
"net"
"reflect"
"testing"
"github.com/jackc/pgx/pgtype"
"github.com/jackc/pgx/pgtype/testutil"
)
// TestCIDRArrayTranscode round-trips CIDRArray values through the cidr[]
// PostgreSQL type, covering the empty array, NULL elements, a NULL array,
// multi-dimensional arrays, and non-default lower bounds. The redundant
// pgtype.CIDR element types in the composite literals have been removed
// (gofmt -s simplification).
func TestCIDRArrayTranscode(t *testing.T) {
	testutil.TestSuccessfulTranscode(t, "cidr[]", []interface{}{
		&pgtype.CIDRArray{
			Elements:   nil,
			Dimensions: nil,
			Status:     pgtype.Present,
		},
		&pgtype.CIDRArray{
			Elements: []pgtype.CIDR{
				{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present},
				{Status: pgtype.Null},
			},
			Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}},
			Status:     pgtype.Present,
		},
		&pgtype.CIDRArray{Status: pgtype.Null},
		&pgtype.CIDRArray{
			Elements: []pgtype.CIDR{
				{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present},
				{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present},
				{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present},
				{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present},
				{Status: pgtype.Null},
				{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present},
			},
			Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}},
			Status:     pgtype.Present,
		},
		&pgtype.CIDRArray{
			Elements: []pgtype.CIDR{
				{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present},
				{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present},
				{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present},
				{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present},
			},
			Dimensions: []pgtype.ArrayDimension{
				{Length: 2, LowerBound: 4},
				{Length: 2, LowerBound: 2},
			},
			Status: pgtype.Present,
		},
	})
}
// TestCIDRArraySet verifies that Set converts []*net.IPNet and []net.IP
// values (including nil slices, which map to SQL NULL) into the expected
// CIDRArray representation.
func TestCIDRArraySet(t *testing.T) {
	cases := []struct {
		source interface{}
		result pgtype.CIDRArray
	}{
		{
			source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")},
			result: pgtype.CIDRArray{
				Elements:   []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
				Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
				Status:     pgtype.Present},
		},
		{
			source: (([]*net.IPNet)(nil)),
			result: pgtype.CIDRArray{Status: pgtype.Null},
		},
		{
			source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP},
			result: pgtype.CIDRArray{
				Elements:   []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
				Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
				Status:     pgtype.Present},
		},
		{
			source: (([]net.IP)(nil)),
			result: pgtype.CIDRArray{Status: pgtype.Null},
		},
	}

	for i, tc := range cases {
		var got pgtype.CIDRArray
		if err := got.Set(tc.source); err != nil {
			t.Errorf("%d: %v", i, err)
		}
		if !reflect.DeepEqual(got, tc.result) {
			t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tc.source, tc.result, got)
		}
	}
}
// TestCIDRArrayAssignTo verifies that AssignTo populates *[]*net.IPNet
// and *[]net.IP destinations, mapping NULL elements to nil entries and a
// NULL array to a nil slice.
func TestCIDRArrayAssignTo(t *testing.T) {
	var ipnetSlice []*net.IPNet
	var ipSlice []net.IP
	simpleTests := []struct {
		src      pgtype.CIDRArray
		dst      interface{}
		expected interface{}
	}{
		{
			src: pgtype.CIDRArray{
				Elements:   []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
				Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
				Status:     pgtype.Present,
			},
			dst:      &ipnetSlice,
			expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")},
		},
		{
			// A NULL element becomes a nil *net.IPNet entry.
			src: pgtype.CIDRArray{
				Elements:   []pgtype.CIDR{{Status: pgtype.Null}},
				Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
				Status:     pgtype.Present,
			},
			dst:      &ipnetSlice,
			expected: []*net.IPNet{nil},
		},
		{
			src: pgtype.CIDRArray{
				Elements:   []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
				Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
				Status:     pgtype.Present,
			},
			dst:      &ipSlice,
			expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP},
		},
		{
			// A NULL element becomes a nil net.IP entry.
			src: pgtype.CIDRArray{
				Elements:   []pgtype.CIDR{{Status: pgtype.Null}},
				Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
				Status:     pgtype.Present,
			},
			dst:      &ipSlice,
			expected: []net.IP{nil},
		},
		{
			// A NULL array assigns a nil slice.
			src:      pgtype.CIDRArray{Status: pgtype.Null},
			dst:      &ipnetSlice,
			expected: (([]*net.IPNet)(nil)),
		},
		{
			src:      pgtype.CIDRArray{Status: pgtype.Null},
			dst:      &ipSlice,
			expected: (([]net.IP)(nil)),
		},
	}
	for i, tt := range simpleTests {
		err := tt.src.AssignTo(tt.dst)
		if err != nil {
			t.Errorf("%d: %v", i, err)
		}
		if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
			t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
		}
	}
}

146
pgtype/circle.go Normal file
View File

@ -0,0 +1,146 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"math"
"strconv"
"strings"
"github.com/jackc/pgx/pgio"
"github.com/pkg/errors"
)
// Circle represents the PostgreSQL circle type: a center point and a
// radius.
type Circle struct {
	P      Vec2    // center of the circle
	R      float64 // radius
	Status Status
}
// Set is not implemented for Circle; it always returns an error.
func (dst *Circle) Set(src interface{}) error {
	return errors.Errorf("cannot convert %v to Circle", src)
}
// Get returns the underlying value: the Circle itself when present, nil
// for SQL NULL, and the raw Status for any other state.
func (dst *Circle) Get() interface{} {
	if dst.Status == Present {
		return dst
	}
	if dst.Status == Null {
		return nil
	}
	return dst.Status
}
// AssignTo is not implemented for Circle; it always returns an error.
func (src *Circle) AssignTo(dst interface{}) error {
	return errors.Errorf("cannot assign %v to %T", src, dst)
}
// DecodeText decodes the PostgreSQL text representation of a circle,
// which has the form <(x,y),r>. A nil src represents SQL NULL.
//
// Fix: the original used strings.IndexByte results without checking for
// -1, so malformed input (a missing ',' or ')') caused a slice-bounds
// panic instead of returning an error.
func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error {
	if src == nil {
		*dst = Circle{Status: Null}
		return nil
	}

	// "<(0,0),0>" is the shortest valid input.
	if len(src) < 9 {
		return errors.Errorf("invalid length for Circle: %v", len(src))
	}

	// Strip the leading "<(" and walk the remaining delimiters.
	str := string(src[2:])
	end := strings.IndexByte(str, ',')
	if end == -1 {
		return errors.Errorf("invalid format for Circle")
	}

	x, err := strconv.ParseFloat(str[:end], 64)
	if err != nil {
		return err
	}

	str = str[end+1:]
	end = strings.IndexByte(str, ')')
	if end == -1 {
		return errors.Errorf("invalid format for Circle")
	}

	y, err := strconv.ParseFloat(str[:end], 64)
	if err != nil {
		return err
	}

	// Skip the ")," and drop the trailing '>' to isolate the radius.
	if end+2 > len(str)-1 {
		return errors.Errorf("invalid format for Circle")
	}
	str = str[end+2 : len(str)-1]

	r, err := strconv.ParseFloat(str, 64)
	if err != nil {
		return err
	}

	*dst = Circle{P: Vec2{x, y}, R: r, Status: Present}
	return nil
}
// DecodeBinary decodes the 24-byte binary representation of a circle:
// three big-endian float64 values (x, y, radius). A nil src represents
// SQL NULL.
func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error {
	if src == nil {
		*dst = Circle{Status: Null}
		return nil
	}

	if len(src) != 24 {
		return errors.Errorf("invalid length for Circle: %v", len(src))
	}

	x := math.Float64frombits(binary.BigEndian.Uint64(src[0:8]))
	y := math.Float64frombits(binary.BigEndian.Uint64(src[8:16]))
	r := math.Float64frombits(binary.BigEndian.Uint64(src[16:24]))

	*dst = Circle{P: Vec2{x, y}, R: r, Status: Present}
	return nil
}
// EncodeText appends the text representation <(x,y),r> to buf. It returns
// nil, nil for SQL NULL and errUndefined for an Undefined status.
//
// Fix: the original formatted with %f, which truncates to six decimal
// places and silently loses precision on round-trip. FormatFloat with
// precision -1 emits the minimum digits that parse back to the same
// float64.
func (src *Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
	switch src.Status {
	case Null:
		return nil, nil
	case Undefined:
		return nil, errUndefined
	}

	buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`,
		strconv.FormatFloat(src.P.X, 'f', -1, 64),
		strconv.FormatFloat(src.P.Y, 'f', -1, 64),
		strconv.FormatFloat(src.R, 'f', -1, 64),
	)...)
	return buf, nil
}
// EncodeBinary appends the 24-byte binary representation of the circle
// (x, y, radius as big-endian float64 bits) to buf. It returns nil, nil
// for SQL NULL and errUndefined for an Undefined status.
func (src *Circle) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
	if src.Status == Null {
		return nil, nil
	}
	if src.Status == Undefined {
		return nil, errUndefined
	}

	for _, f := range []float64{src.P.X, src.P.Y, src.R} {
		buf = pgio.AppendUint64(buf, math.Float64bits(f))
	}
	return buf, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *Circle) Scan(src interface{}) error {
	if src == nil {
		*dst = Circle{Status: Null}
		return nil
	}

	if s, ok := src.(string); ok {
		return dst.DecodeText(nil, []byte(s))
	}
	if b, ok := src.([]byte); ok {
		// Copy first: the driver may reuse the byte slice after Scan returns.
		return dst.DecodeText(nil, append([]byte(nil), b...))
	}

	return errors.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
// Delegates to EncodeValueText, which presumably renders the text format
// produced by EncodeText — its definition is not visible here; confirm.
func (src *Circle) Value() (driver.Value, error) {
	return EncodeValueText(src)
}

Some files were not shown because too many files have changed in this diff Show More