Merge branch 'pgconnimport' into v5-dev

query-exec-mode
Jack Christensen 2021-12-04 13:51:50 -06:00
commit 5c36639f09
29 changed files with 7822 additions and 0 deletions

81
pgconn/.github/workflows/ci.yml vendored Normal file
View File

@ -0,0 +1,81 @@
name: CI
on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
jobs:
test:
name: Test
runs-on: ubuntu-18.04
strategy:
matrix:
go-version: [1.15, 1.16]
pg-version: [9.6, 10, 11, 12, 13, cockroachdb]
include:
- pg-version: 9.6
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
- pg-version: 10
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
- pg-version: 11
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
- pg-version: 12
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
- pg-version: 13
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
- pg-version: cockroachdb
pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go-version }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- name: Setup database server for testing
run: ci/setup_test.bash
env:
PGVERSION: ${{ matrix.pg-version }}
- name: Test
run: go test -v -race ./...
env:
PGX_TEST_CONN_STRING: ${{ matrix.pgx-test-conn-string }}
PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}

3
pgconn/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
.envrc
vendor/
.vscode

129
pgconn/CHANGELOG.md Normal file
View File

@ -0,0 +1,129 @@
# 1.10.1 (November 20, 2021)
* Close without waiting for response (Kei Kamikawa)
* Save waiting for network round-trip in CopyFrom (Rueian)
* Fix concurrency issue with ContextWatcher
* LRU.Get always checks context for cancellation / expiration (Georges Varouchas)
# 1.10.0 (July 24, 2021)
* net.Timeout errors are no longer returned when a query is canceled via context. A wrapped context error is returned.
# 1.9.0 (July 10, 2021)
* pgconn.Timeout only is true for errors originating in pgconn (Michael Darr)
* Add defaults for sslcert, sslkey, and sslrootcert (Joshua Brindle)
* Solve issue with 'sslmode=verify-full' when there are multiple hosts (mgoddard)
* Fix default host when parsing URL without host but with port
* Allow dbname query parameter in URL conn string
* Update underlying dependencies
# 1.8.1 (March 25, 2021)
* Better connection string sanitization (ip.novikov)
* Use proper pgpass location on Windows (Moshe Katz)
* Use errors instead of golang.org/x/xerrors
* Resume fallback on server error in Connect (Andrey Borodin)
# 1.8.0 (December 3, 2020)
* Add StatementErrored method to stmtcache.Cache. This allows the cache to purge invalidated prepared statements. (Ethan Pailes)
# 1.7.2 (November 3, 2020)
* Fix data value slices into work buffer with capacities larger than length.
# 1.7.1 (October 31, 2020)
* Do not asyncClose after receiving FATAL error from PostgreSQL server
# 1.7.0 (September 26, 2020)
* Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded
* Add ReceiveResults (Sebastiaan Mannem)
* Fix parsing DSN connection with bad backslash
* Add PgConn.CleanupDone so connection pools can determine when async close is complete
# 1.6.4 (July 29, 2020)
* Fix deadlock on error after CommandComplete but before ReadyForQuery
* Fix panic on parsing DSN with trailing '='
# 1.6.3 (July 22, 2020)
* Fix error message after AppendCertsFromPEM failure (vahid-sohrabloo)
# 1.6.2 (July 14, 2020)
* Update pgservicefile library
# 1.6.1 (June 27, 2020)
* Update golang.org/x/crypto to latest
* Update golang.org/x/text to 0.3.3
* Fix error handling for bad PGSERVICE definition
* Redact passwords in ParseConfig errors (Lukas Vogel)
# 1.6.0 (June 6, 2020)
* Fix panic when closing conn during cancellable query
* Fix behavior of sslmode=require with sslrootcert present (Petr Jediný)
* Fix field descriptions available after command concluded (Tobias Salzmann)
* Support connect_timeout (georgysavva)
* Handle IPv6 in connection URLs (Lukas Vogel)
* Fix ValidateConnect with cancelable context
* Improve CopyFrom performance
* Add Config.Copy (georgysavva)
# 1.5.0 (March 30, 2020)
* Update golang.org/x/crypto for security fix
* Implement "verify-ca" SSL mode (Greg Curtis)
# 1.4.0 (March 7, 2020)
* Fix ExecParams and ExecPrepared handling of empty query.
* Support reading config from PostgreSQL service files.
# 1.3.2 (February 14, 2020)
* Update chunkreader to v2.0.1 for optimized default buffer size.
# 1.3.1 (February 5, 2020)
* Fix CopyFrom deadlock when multiple NoticeResponse received during copy
# 1.3.0 (January 23, 2020)
* Add Hijack and Construct.
* Update pgproto3 to v2.0.1.
# 1.2.1 (January 13, 2020)
* Fix data race in context cancellation introduced in v1.2.0.
# 1.2.0 (January 11, 2020)
## Features
* Add Insert(), Update(), Delete(), and Select() statement type query methods to CommandTag.
* Add PgError.SQLState method. This could be used for compatibility with other drivers and databases.
## Performance
* Improve performance when context.Background() is used. (bakape)
* CommandTag.RowsAffected is faster and does not allocate.
## Fixes
* Try to cancel any in-progress query when a conn is closed by ctx cancel.
* Handle NoticeResponse during CopyFrom.
* Ignore errors sending Terminate message while closing connection. This mimics the behavior of libpq PGfinish.
# 1.1.0 (October 12, 2019)
* Add PgConn.IsBusy() method.
# 1.0.1 (September 19, 2019)
* Fix statement cache not properly cleaning discarded statements.

22
pgconn/LICENSE Normal file
View File

@ -0,0 +1,22 @@
Copyright (c) 2019-2021 Jack Christensen
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

56
pgconn/README.md Normal file
View File

@ -0,0 +1,56 @@
[![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn)
![CI](https://github.com/jackc/pgconn/workflows/CI/badge.svg)
# pgconn
Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq.
It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx.
Applications should handle normal queries with a higher level library and only use pgconn directly when required for
low-level access to PostgreSQL functionality.
## Example Usage
```go
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL"))
if err != nil {
log.Fatalln("pgconn failed to connect:", err)
}
defer pgConn.Close(context.Background())
result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil)
for result.NextRow() {
fmt.Println("User 123 has email:", string(result.Values()[0]))
}
_, err = result.Close()
if err != nil {
log.Fatalln("failed reading result:", err)
}
```
## Testing
The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING`
environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*`
environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify
environment variable handling.
### Example Test Environment
Connect to your PostgreSQL server and run:
```
create database pgx_test;
```
Now you can run the tests:
```bash
PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./...
```
### Connection and Authentication Tests
Pgconn supports multiple connection types and means of authentication. These tests are optional. They
will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being
skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change
authentication code.

266
pgconn/auth_scram.go Normal file
View File

@ -0,0 +1,266 @@
// SCRAM-SHA-256 authentication
//
// Resources:
// https://tools.ietf.org/html/rfc5802
// https://tools.ietf.org/html/rfc8265
// https://www.postgresql.org/docs/current/sasl-authentication.html
//
// Inspiration drawn from other implementations:
// https://github.com/lib/pq/pull/608
// https://github.com/lib/pq/pull/788
// https://github.com/lib/pq/pull/833
package pgconn
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"strconv"
"github.com/jackc/pgproto3/v2"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/text/secure/precis"
)
const clientNonceLen = 18
// Perform SCRAM authentication.
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
if err != nil {
return err
}
// Send client-first-message in a SASLInitialResponse
saslInitialResponse := &pgproto3.SASLInitialResponse{
AuthMechanism: "SCRAM-SHA-256",
Data: sc.clientFirstMessage(),
}
_, err = c.conn.Write(saslInitialResponse.Encode(nil))
if err != nil {
return err
}
// Receive server-first-message payload in a AuthenticationSASLContinue.
saslContinue, err := c.rxSASLContinue()
if err != nil {
return err
}
err = sc.recvServerFirstMessage(saslContinue.Data)
if err != nil {
return err
}
// Send client-final-message in a SASLResponse
saslResponse := &pgproto3.SASLResponse{
Data: []byte(sc.clientFinalMessage()),
}
_, err = c.conn.Write(saslResponse.Encode(nil))
if err != nil {
return err
}
// Receive server-final-message payload in a AuthenticationSASLFinal.
saslFinal, err := c.rxSASLFinal()
if err != nil {
return err
}
return sc.recvServerFinalMessage(saslFinal.Data)
}
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue)
if ok {
return saslContinue, nil
}
return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message")
}
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal)
if ok {
return saslFinal, nil
}
return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message")
}
type scramClient struct {
serverAuthMechanisms []string
password []byte
clientNonce []byte
clientFirstMessageBare []byte
serverFirstMessage []byte
clientAndServerNonce []byte
salt []byte
iterations int
saltedPassword []byte
authMessage []byte
}
func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
sc := &scramClient{
serverAuthMechanisms: serverAuthMechanisms,
}
// Ensure server supports SCRAM-SHA-256
hasScramSHA256 := false
for _, mech := range sc.serverAuthMechanisms {
if mech == "SCRAM-SHA-256" {
hasScramSHA256 = true
break
}
}
if !hasScramSHA256 {
return nil, errors.New("server does not support SCRAM-SHA-256")
}
// precis.OpaqueString is equivalent to SASLprep for password.
var err error
sc.password, err = precis.OpaqueString.Bytes([]byte(password))
if err != nil {
// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
sc.password = []byte(password)
}
buf := make([]byte, clientNonceLen)
_, err = rand.Read(buf)
if err != nil {
return nil, err
}
sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
base64.RawStdEncoding.Encode(sc.clientNonce, buf)
return sc, nil
}
func (sc *scramClient) clientFirstMessage() []byte {
sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
}
func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
sc.serverFirstMessage = serverFirstMessage
buf := serverFirstMessage
if !bytes.HasPrefix(buf, []byte("r=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
}
buf = buf[2:]
idx := bytes.IndexByte(buf, ',')
if idx == -1 {
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
}
sc.clientAndServerNonce = buf[:idx]
buf = buf[idx+1:]
if !bytes.HasPrefix(buf, []byte("s=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
}
buf = buf[2:]
idx = bytes.IndexByte(buf, ',')
if idx == -1 {
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
}
saltStr := buf[:idx]
buf = buf[idx+1:]
if !bytes.HasPrefix(buf, []byte("i=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
}
buf = buf[2:]
iterationsStr := buf
var err error
sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
if err != nil {
return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
}
sc.iterations, err = strconv.Atoi(string(iterationsStr))
if err != nil || sc.iterations <= 0 {
return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
}
if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
return errors.New("invalid SCRAM nonce: did not start with client nonce")
}
if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
return errors.New("invalid SCRAM nonce: did not include server nonce")
}
return nil
}
func (sc *scramClient) clientFinalMessage() string {
clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
}
func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
return errors.New("invalid SCRAM server-final-message received from server")
}
serverSignature := serverFinalMessage[2:]
if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
return errors.New("invalid SCRAM ServerSignature received from server")
}
return nil
}
func computeHMAC(key, msg []byte) []byte {
mac := hmac.New(sha256.New, key)
mac.Write(msg)
return mac.Sum(nil)
}
func computeClientProof(saltedPassword, authMessage []byte) []byte {
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
storedKey := sha256.Sum256(clientKey)
clientSignature := computeHMAC(storedKey[:], authMessage)
clientProof := make([]byte, len(clientSignature))
for i := 0; i < len(clientSignature); i++ {
clientProof[i] = clientKey[i] ^ clientSignature[i]
}
buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
base64.StdEncoding.Encode(buf, clientProof)
return buf
}
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
serverSignature := computeHMAC(serverKey, authMessage)
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
base64.StdEncoding.Encode(buf, serverSignature)
return buf
}

322
pgconn/benchmark_test.go Normal file
View File

@ -0,0 +1,322 @@
package pgconn_test
import (
"bytes"
"context"
"os"
"strings"
"testing"
"github.com/jackc/pgconn"
"github.com/stretchr/testify/require"
)
func BenchmarkConnect(b *testing.B) {
benchmarks := []struct {
name string
env string
}{
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
}
for _, bm := range benchmarks {
bm := bm
b.Run(bm.name, func(b *testing.B) {
connString := os.Getenv(bm.env)
if connString == "" {
b.Skipf("Skipping due to missing environment variable %v", bm.env)
}
for i := 0; i < b.N; i++ {
conn, err := pgconn.Connect(context.Background(), connString)
require.Nil(b, err)
err = conn.Close(context.Background())
require.Nil(b, err)
}
})
}
}
func BenchmarkExec(b *testing.B) {
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
benchmarks := []struct {
name string
ctx context.Context
}{
// Using an empty context other than context.Background() to compare
// performance
{"background context", context.Background()},
{"empty context", context.TODO()},
}
for _, bm := range benchmarks {
bm := bm
b.Run(bm.name, func(b *testing.B) {
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.Nil(b, err)
defer closeConn(b, conn)
b.ResetTimer()
for i := 0; i < b.N; i++ {
mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date")
for mrr.NextResult() {
rr := mrr.ResultReader()
rowCount := 0
for rr.NextRow() {
rowCount++
if len(rr.Values()) != len(expectedValues) {
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
}
for i := range rr.Values() {
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
}
}
}
_, err = rr.Close()
if err != nil {
b.Fatal(err)
}
if rowCount != 1 {
b.Fatalf("unexpected rowCount: %d", rowCount)
}
}
err := mrr.Close()
if err != nil {
b.Fatal(err)
}
}
})
}
}
func BenchmarkExecPossibleToCancel(b *testing.B) {
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
require.Nil(b, err)
defer closeConn(b, conn)
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
b.ResetTimer()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for i := 0; i < b.N; i++ {
mrr := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date")
for mrr.NextResult() {
rr := mrr.ResultReader()
rowCount := 0
for rr.NextRow() {
rowCount++
if len(rr.Values()) != len(expectedValues) {
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
}
for i := range rr.Values() {
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
}
}
}
_, err = rr.Close()
if err != nil {
b.Fatal(err)
}
if rowCount != 1 {
b.Fatalf("unexpected rowCount: %d", rowCount)
}
}
err := mrr.Close()
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkExecPrepared(b *testing.B) {
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
benchmarks := []struct {
name string
ctx context.Context
}{
// Using an empty context other than context.Background() to compare
// performance
{"background context", context.Background()},
{"empty context", context.TODO()},
}
for _, bm := range benchmarks {
bm := bm
b.Run(bm.name, func(b *testing.B) {
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.Nil(b, err)
defer closeConn(b, conn)
_, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
require.Nil(b, err)
b.ResetTimer()
for i := 0; i < b.N; i++ {
rr := conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil)
rowCount := 0
for rr.NextRow() {
rowCount++
if len(rr.Values()) != len(expectedValues) {
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
}
for i := range rr.Values() {
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
}
}
}
_, err = rr.Close()
if err != nil {
b.Fatal(err)
}
if rowCount != 1 {
b.Fatalf("unexpected rowCount: %d", rowCount)
}
}
})
}
}
func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
require.Nil(b, err)
defer closeConn(b, conn)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
require.Nil(b, err)
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
b.ResetTimer()
for i := 0; i < b.N; i++ {
rr := conn.ExecPrepared(ctx, "ps1", nil, nil, nil)
rowCount := 0
for rr.NextRow() {
rowCount += 1
if len(rr.Values()) != len(expectedValues) {
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
}
for i := range rr.Values() {
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
}
}
}
_, err = rr.Close()
if err != nil {
b.Fatal(err)
}
if rowCount != 1 {
b.Fatalf("unexpected rowCount: %d", rowCount)
}
}
}
// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) {
// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
// require.Nil(b, err)
// defer closeConn(b, conn)
// ctx, cancel := context.WithCancel(context.Background())
// defer cancel()
// b.ResetTimer()
// for i := 0; i < b.N; i++ {
// conn.ChanToSetDeadline().Watch(ctx)
// conn.ChanToSetDeadline().Ignore()
// }
// }
func BenchmarkCommandTagRowsAffected(b *testing.B) {
benchmarks := []struct {
commandTag string
rowsAffected int64
}{
{"UPDATE 1", 1},
{"UPDATE 123456789", 123456789},
{"INSERT 0 1", 1},
{"INSERT 0 123456789", 123456789},
}
for _, bm := range benchmarks {
ct := pgconn.CommandTag(bm.commandTag)
b.Run(bm.commandTag, func(b *testing.B) {
var n int64
for i := 0; i < b.N; i++ {
n = ct.RowsAffected()
}
if n != bm.rowsAffected {
b.Errorf("expected %d got %d", bm.rowsAffected, n)
}
})
}
}
func BenchmarkCommandTagTypeFromString(b *testing.B) {
ct := pgconn.CommandTag("UPDATE 1")
var update bool
for i := 0; i < b.N; i++ {
update = strings.HasPrefix(ct.String(), "UPDATE")
}
if !update {
b.Error("expected update")
}
}
func BenchmarkCommandTagInsert(b *testing.B) {
benchmarks := []struct {
commandTag string
is bool
}{
{"INSERT 1", true},
{"INSERT 1234567890", true},
{"UPDATE 1", false},
{"UPDATE 1234567890", false},
{"DELETE 1", false},
{"DELETE 1234567890", false},
{"SELECT 1", false},
{"SELECT 1234567890", false},
{"UNKNOWN 1234567890", false},
}
for _, bm := range benchmarks {
ct := pgconn.CommandTag(bm.commandTag)
b.Run(bm.commandTag, func(b *testing.B) {
var is bool
for i := 0; i < b.N; i++ {
is = ct.Insert()
}
if is != bm.is {
b.Errorf("expected %v got %v", bm.is, is)
}
})
}
}

10
pgconn/ci/script.bash Executable file
View File

@ -0,0 +1,10 @@
#!/usr/bin/env bash
set -eux
if [ "${PGVERSION-}" != "" ]
then
go test -v -race ./...
elif [ "${CRATEVERSION-}" != "" ]
then
go test -v -race -run 'TestCrateDBConnect'
fi

59
pgconn/ci/setup_test.bash Executable file
View File

@ -0,0 +1,59 @@
#!/usr/bin/env bash
set -eux
if [[ "${PGVERSION-}" =~ ^[0-9.]+$ ]]
then
sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common
sudo rm -rf /var/lib/postgresql
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list"
sudo apt-get update -qq
sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION
sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf
echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf
echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf
if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then
echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
fi
sudo /etc/init.d/postgresql restart
# The tricky test user, below, has to actually exist so that it can be used in a test
# of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles.
psql -U postgres -c 'create database pgx_test'
psql -U postgres pgx_test -c 'create extension hstore'
psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)'
psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'"
psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'"
psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'"
psql -U postgres -c "create user `whoami`"
psql -U postgres -c "create user pgx_replication with replication password 'secret'"
psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'"
fi
if [[ "${PGVERSION-}" =~ ^cockroach ]]
then
wget -qO- https://binaries.cockroachdb.com/cockroach-v20.2.5.linux-amd64.tgz | tar xvz
sudo mv cockroach-v20.2.5.linux-amd64/cockroach /usr/local/bin/
cockroach start-single-node --insecure --background --listen-addr=localhost
cockroach sql --insecure -e 'create database pgx_test'
fi
if [ "${CRATEVERSION-}" != "" ]
then
docker run \
-p "6543:5432" \
-d \
crate:"$CRATEVERSION" \
crate \
-Cnetwork.host=0.0.0.0 \
-Ctransport.host=localhost \
-Clicense.enterprise=false
fi

729
pgconn/config.go Normal file
View File

@ -0,0 +1,729 @@
package pgconn
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"net"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/jackc/chunkreader/v2"
"github.com/jackc/pgpassfile"
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgservicefile"
)
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A
// manually initialized Config will cause ConnectConfig to panic.
type Config struct {
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
Port uint16
Database string
User string
Password string
TLSConfig *tls.Config // nil disables TLS
ConnectTimeout time.Duration
DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
Fallbacks []*FallbackConfig
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
ValidateConnect ValidateConnectFunc
// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
// or prepare statements). If this returns an error the connection attempt fails.
AfterConnect AfterConnectFunc
// OnNotice is a callback function called when a notice response is received.
OnNotice NoticeHandler
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
// Copy returns a deep copy of the config that is safe to use and modify.
// The only exception is the TLSConfig field:
// according to the tls.Config docs it must not be modified after creation.
func (c *Config) Copy() *Config {
newConf := new(Config)
*newConf = *c
if newConf.TLSConfig != nil {
newConf.TLSConfig = c.TLSConfig.Clone()
}
if newConf.RuntimeParams != nil {
newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
for k, v := range c.RuntimeParams {
newConf.RuntimeParams[k] = v
}
}
if newConf.Fallbacks != nil {
newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
for i, fallback := range c.Fallbacks {
newFallback := new(FallbackConfig)
*newFallback = *fallback
if newFallback.TLSConfig != nil {
newFallback.TLSConfig = fallback.TLSConfig.Clone()
}
newConf.Fallbacks[i] = newFallback
}
}
return newConf
}
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
type FallbackConfig struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
Port uint16
TLSConfig *tls.Config // nil disables TLS
}
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
// net.Dial.
func NetworkAddress(host string, port uint16) (network, address string) {
if strings.HasPrefix(host, "/") {
network = "unix"
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
} else {
network = "tcp"
address = net.JoinHostPort(host, strconv.Itoa(int(port)))
}
return network, address
}
// ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same
// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches
// the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
//
// # Example DSN
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
//
// # Example URL
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
//
// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done
// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be
// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should
// not be modified individually. They should all be modified or all left unchanged.
//
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
// values that will be tried in order. This can be used as part of a high availability system. See
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
//
// # Example URL
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
//
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
// via database URL or DSN:
//
// PGHOST
// PGPORT
// PGDATABASE
// PGUSER
// PGPASSWORD
// PGPASSFILE
// PGSERVICE
// PGSERVICEFILE
// PGSSLMODE
// PGSSLCERT
// PGSSLKEY
// PGSSLROOTCERT
// PGAPPNAME
// PGCONNECT_TIMEOUT
// PGTARGETSESSIONATTRS
//
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
//
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
// usually but not always the environment variable name downcased and without the "PG" prefix.
//
// Important Security Notes:
//
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
// not set.
//
// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
// security each sslmode provides.
//
// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of
// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
// TLCConfig.
//
// Other known differences with libpq:
//
// If a host name resolves into multiple addresses, libpq will try all addresses. pgconn will only try the first.
//
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
// does not.
//
// In addition, ParseConfig accepts the following options:
//
// min_read_buffer_size
// The minimum size of the internal read buffer. Default 8192.
// servicefile
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
// part of the connection string.
func ParseConfig(connString string) (*Config, error) {
defaultSettings := defaultSettings()
envSettings := parseEnvSettings()
connStringSettings := make(map[string]string)
if connString != "" {
var err error
// connString may be a database URL or a DSN
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
}
} else {
connStringSettings, err = parseDSNSettings(connString)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
}
}
}
settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
if service, present := settings["service"]; present {
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
}
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
}
minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
}
config := &Config{
createdByParseConfig: true,
Database: settings["database"],
User: settings["user"],
Password: settings["password"],
RuntimeParams: make(map[string]string),
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
}
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
}
config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
} else {
defaultDialer := makeDefaultDialer()
config.DialFunc = defaultDialer.DialContext
}
config.LookupFunc = makeDefaultResolver().LookupHost
notRuntimeParams := map[string]struct{}{
"host": struct{}{},
"port": struct{}{},
"database": struct{}{},
"user": struct{}{},
"password": struct{}{},
"passfile": struct{}{},
"connect_timeout": struct{}{},
"sslmode": struct{}{},
"sslkey": struct{}{},
"sslcert": struct{}{},
"sslrootcert": struct{}{},
"target_session_attrs": struct{}{},
"min_read_buffer_size": struct{}{},
"service": struct{}{},
"servicefile": struct{}{},
}
for k, v := range settings {
if _, present := notRuntimeParams[k]; present {
continue
}
config.RuntimeParams[k] = v
}
fallbacks := []*FallbackConfig{}
hosts := strings.Split(settings["host"], ",")
ports := strings.Split(settings["port"], ",")
for i, host := range hosts {
var portStr string
if i < len(ports) {
portStr = ports[i]
} else {
portStr = ports[0]
}
port, err := parsePort(portStr)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
}
var tlsConfigs []*tls.Config
// Ignore TLS settings if Unix domain socket like libpq
if network, _ := NetworkAddress(host, port); network == "unix" {
tlsConfigs = append(tlsConfigs, nil)
} else {
var err error
tlsConfigs, err = configTLS(settings, host)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
}
}
for _, tlsConfig := range tlsConfigs {
fallbacks = append(fallbacks, &FallbackConfig{
Host: host,
Port: port,
TLSConfig: tlsConfig,
})
}
}
config.Host = fallbacks[0].Host
config.Port = fallbacks[0].Port
config.TLSConfig = fallbacks[0].TLSConfig
config.Fallbacks = fallbacks[1:]
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
if err == nil {
if config.Password == "" {
host := config.Host
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
host = "localhost"
}
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
}
}
if settings["target_session_attrs"] == "read-write" {
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
} else if settings["target_session_attrs"] != "any" {
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])}
}
return config, nil
}
func mergeSettings(settingSets ...map[string]string) map[string]string {
settings := make(map[string]string)
for _, s2 := range settingSets {
for k, v := range s2 {
settings[k] = v
}
}
return settings
}
func parseEnvSettings() map[string]string {
settings := make(map[string]string)
nameMap := map[string]string{
"PGHOST": "host",
"PGPORT": "port",
"PGDATABASE": "database",
"PGUSER": "user",
"PGPASSWORD": "password",
"PGPASSFILE": "passfile",
"PGAPPNAME": "application_name",
"PGCONNECT_TIMEOUT": "connect_timeout",
"PGSSLMODE": "sslmode",
"PGSSLKEY": "sslkey",
"PGSSLCERT": "sslcert",
"PGSSLROOTCERT": "sslrootcert",
"PGTARGETSESSIONATTRS": "target_session_attrs",
"PGSERVICE": "service",
"PGSERVICEFILE": "servicefile",
}
for envname, realname := range nameMap {
value := os.Getenv(envname)
if value != "" {
settings[realname] = value
}
}
return settings
}
func parseURLSettings(connString string) (map[string]string, error) {
settings := make(map[string]string)
url, err := url.Parse(connString)
if err != nil {
return nil, err
}
if url.User != nil {
settings["user"] = url.User.Username()
if password, present := url.User.Password(); present {
settings["password"] = password
}
}
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
var hosts []string
var ports []string
for _, host := range strings.Split(url.Host, ",") {
if host == "" {
continue
}
if isIPOnly(host) {
hosts = append(hosts, strings.Trim(host, "[]"))
continue
}
h, p, err := net.SplitHostPort(host)
if err != nil {
return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
}
if h != "" {
hosts = append(hosts, h)
}
if p != "" {
ports = append(ports, p)
}
}
if len(hosts) > 0 {
settings["host"] = strings.Join(hosts, ",")
}
if len(ports) > 0 {
settings["port"] = strings.Join(ports, ",")
}
database := strings.TrimLeft(url.Path, "/")
if database != "" {
settings["database"] = database
}
nameMap := map[string]string{
"dbname": "database",
}
for k, v := range url.Query() {
if k2, present := nameMap[k]; present {
k = k2
}
settings[k] = v[0]
}
return settings, nil
}
func isIPOnly(host string) bool {
return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
}
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
func parseDSNSettings(s string) (map[string]string, error) {
settings := make(map[string]string)
nameMap := map[string]string{
"dbname": "database",
}
for len(s) > 0 {
var key, val string
eqIdx := strings.IndexRune(s, '=')
if eqIdx < 0 {
return nil, errors.New("invalid dsn")
}
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
if len(s) == 0 {
} else if s[0] != '\'' {
end := 0
for ; end < len(s); end++ {
if asciiSpace[s[end]] == 1 {
break
}
if s[end] == '\\' {
end++
if end == len(s) {
return nil, errors.New("invalid backslash")
}
}
}
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
if end == len(s) {
s = ""
} else {
s = s[end+1:]
}
} else { // quoted string
s = s[1:]
end := 0
for ; end < len(s); end++ {
if s[end] == '\'' {
break
}
if s[end] == '\\' {
end++
}
}
if end == len(s) {
return nil, errors.New("unterminated quoted string in connection info string")
}
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
if end == len(s) {
s = ""
} else {
s = s[end+1:]
}
}
if k, ok := nameMap[key]; ok {
key = k
}
if key == "" {
return nil, errors.New("invalid dsn")
}
settings[key] = val
}
return settings, nil
}
func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
if err != nil {
return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
}
service, err := servicefile.GetService(serviceName)
if err != nil {
return nil, fmt.Errorf("unable to find service: %v", serviceName)
}
nameMap := map[string]string{
"dbname": "database",
}
settings := make(map[string]string, len(service.Settings))
for k, v := range service.Settings {
if k2, present := nameMap[k]; present {
k = k2
}
settings[k] = v
}
return settings, nil
}
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
// necessary to allow returning multiple TLS configs as sslmode "allow" and
// "prefer" allow fallback.
func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, error) {
host := thisHost
sslmode := settings["sslmode"]
sslrootcert := settings["sslrootcert"]
sslcert := settings["sslcert"]
sslkey := settings["sslkey"]
// Match libpq default behavior
if sslmode == "" {
sslmode = "prefer"
}
tlsConfig := &tls.Config{}
switch sslmode {
case "disable":
return []*tls.Config{nil}, nil
case "allow", "prefer":
tlsConfig.InsecureSkipVerify = true
case "require":
// According to PostgreSQL documentation, if a root CA file exists,
// the behavior of sslmode=require should be the same as that of verify-ca
//
// See https://www.postgresql.org/docs/12/libpq-ssl.html
if sslrootcert != "" {
goto nextCase
}
tlsConfig.InsecureSkipVerify = true
break
nextCase:
fallthrough
case "verify-ca":
// Don't perform the default certificate verification because it
// will verify the hostname. Instead, verify the server's
// certificate chain ourselves in VerifyPeerCertificate and
// ignore the server name. This emulates libpq's verify-ca
// behavior.
//
// See https://github.com/golang/go/issues/21971#issuecomment-332693931
// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
// for more info.
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
certs := make([]*x509.Certificate, len(certificates))
for i, asn1Data := range certificates {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
return errors.New("failed to parse certificate from server: " + err.Error())
}
certs[i] = cert
}
// Leave DNSName empty to skip hostname verification.
opts := x509.VerifyOptions{
Roots: tlsConfig.RootCAs,
Intermediates: x509.NewCertPool(),
}
// Skip the first cert because it's the leaf. All others
// are intermediates.
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
_, err := certs[0].Verify(opts)
return err
}
case "verify-full":
tlsConfig.ServerName = host
default:
return nil, errors.New("sslmode is invalid")
}
if sslrootcert != "" {
caCertPool := x509.NewCertPool()
caPath := sslrootcert
caCert, err := ioutil.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("unable to read CA file: %w", err)
}
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, errors.New("unable to add CA to cert pool")
}
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
}
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
}
if sslcert != "" && sslkey != "" {
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
if err != nil {
return nil, fmt.Errorf("unable to read cert: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
switch sslmode {
case "allow":
return []*tls.Config{nil, tlsConfig}, nil
case "prefer":
return []*tls.Config{tlsConfig, nil}, nil
case "require", "verify-ca", "verify-full":
return []*tls.Config{tlsConfig}, nil
default:
panic("BUG: bad sslmode should already have been caught")
}
}
func parsePort(s string) (uint16, error) {
port, err := strconv.ParseUint(s, 10, 16)
if err != nil {
return 0, err
}
if port < 1 || port > math.MaxUint16 {
return 0, errors.New("outside range")
}
return uint16(port), nil
}
func makeDefaultDialer() *net.Dialer {
return &net.Dialer{KeepAlive: 5 * time.Minute}
}
func makeDefaultResolver() *net.Resolver {
return net.DefaultResolver
}
func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
return func(r io.Reader, w io.Writer) Frontend {
cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
if err != nil {
panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err))
}
frontend := pgproto3.NewFrontend(cr, w)
return frontend
}
}
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
timeout, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return 0, err
}
if timeout < 0 {
return 0, errors.New("negative timeout")
}
return time.Duration(timeout) * time.Second, nil
}
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
d := makeDefaultDialer()
d.Timeout = timeout
return d.DialContext
}
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-write.
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) == "on" {
return errors.New("read only connection")
}
return nil
}

900
pgconn/config_test.go Normal file
View File

@ -0,0 +1,900 @@
package pgconn_test
import (
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"os"
"os/user"
"runtime"
"strings"
"testing"
"time"
"github.com/jackc/pgconn"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseConfig(t *testing.T) {
t.Parallel()
var osUserName string
osUser, err := user.Current()
if err == nil {
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
// but the libpq default is just the `user` portion, so we strip off the first part.
if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") {
osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:]
} else {
osUserName = osUser.Username
}
}
config, err := pgconn.ParseConfig("")
require.NoError(t, err)
defaultHost := config.Host
tests := []struct {
name string
connString string
config *pgconn.Config
}{
// Test all sslmodes
{
name: "sslmode not set (prefer)",
connString: "postgres://jack:secret@localhost:5432/mydb",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{
Host: "localhost",
Port: 5432,
TLSConfig: nil,
},
},
},
},
{
name: "sslmode disable",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "sslmode allow",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{
Host: "localhost",
Port: 5432,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
},
},
},
{
name: "sslmode prefer",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{
Host: "localhost",
Port: 5432,
TLSConfig: nil,
},
},
},
},
{
name: "sslmode require",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=require",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
{
name: "sslmode verify-ca",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
{
name: "sslmode verify-full",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{ServerName: "localhost"},
RuntimeParams: map[string]string{},
},
},
{
name: "database url everything",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
ConnectTimeout: 5 * time.Second,
RuntimeParams: map[string]string{
"application_name": "pgxtest",
"search_path": "myschema",
},
},
},
{
name: "database url missing password",
connString: "postgres://jack@localhost:5432/mydb?sslmode=disable",
config: &pgconn.Config{
User: "jack",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "database url missing user and password",
connString: "postgres://localhost:5432/mydb?sslmode=disable",
config: &pgconn.Config{
User: osUserName,
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "database url missing port",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "database url unix domain socket host",
connString: "postgres:///foo?host=/tmp",
config: &pgconn.Config{
User: osUserName,
Host: "/tmp",
Port: 5432,
Database: "foo",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "database url dbname",
connString: "postgres://localhost/?dbname=foo&sslmode=disable",
config: &pgconn.Config{
User: osUserName,
Host: "localhost",
Port: 5432,
Database: "foo",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "database url postgresql protocol",
connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable",
config: &pgconn.Config{
User: "jack",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "database url IPv4 with port",
connString: "postgresql://jack@127.0.0.1:5433/mydb?sslmode=disable",
config: &pgconn.Config{
User: "jack",
Host: "127.0.0.1",
Port: 5433,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "database url IPv6 with port",
connString: "postgresql://jack@[2001:db8::1]:5433/mydb?sslmode=disable",
config: &pgconn.Config{
User: "jack",
Host: "2001:db8::1",
Port: 5433,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "database url IPv6 no port",
connString: "postgresql://jack@[2001:db8::1]/mydb?sslmode=disable",
config: &pgconn.Config{
User: "jack",
Host: "2001:db8::1",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "DSN everything",
connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
ConnectTimeout: 5 * time.Second,
RuntimeParams: map[string]string{
"application_name": "pgxtest",
"search_path": "myschema",
},
},
},
{
name: "DSN with escaped single quote",
connString: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb sslmode=disable",
config: &pgconn.Config{
User: "jack's",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "DSN with escaped backslash",
connString: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb sslmode=disable",
config: &pgconn.Config{
User: "jack",
Password: "sooper\\secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "DSN with single quoted values",
connString: "user='jack' host='localhost' dbname='mydb' sslmode='disable'",
config: &pgconn.Config{
User: "jack",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "DSN with single quoted value with escaped single quote",
connString: "user='jack\\'s' host='localhost' dbname='mydb' sslmode='disable'",
config: &pgconn.Config{
User: "jack's",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "DSN with empty single quoted value",
connString: "user='jack' password='' host='localhost' dbname='mydb' sslmode='disable'",
config: &pgconn.Config{
User: "jack",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "DSN with space between key and value",
connString: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb' sslmode='disable'",
config: &pgconn.Config{
User: "jack",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "URL multiple hosts",
connString: "postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "foo",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{
Host: "bar",
Port: 5432,
TLSConfig: nil,
},
&pgconn.FallbackConfig{
Host: "baz",
Port: 5432,
TLSConfig: nil,
},
},
},
},
{
name: "URL multiple hosts and ports",
connString: "postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "foo",
Port: 1,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{
Host: "bar",
Port: 2,
TLSConfig: nil,
},
&pgconn.FallbackConfig{
Host: "baz",
Port: 3,
TLSConfig: nil,
},
},
},
},
// https://github.com/jackc/pgconn/issues/72
{
name: "URL without host but with port still uses default host",
connString: "postgres://jack:secret@:1/mydb?sslmode=disable",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: defaultHost,
Port: 1,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "DSN multiple hosts one port",
connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "foo",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{
Host: "bar",
Port: 5432,
TLSConfig: nil,
},
&pgconn.FallbackConfig{
Host: "baz",
Port: 5432,
TLSConfig: nil,
},
},
},
},
{
name: "DSN multiple hosts multiple ports",
connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 dbname=mydb sslmode=disable",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "foo",
Port: 1,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{
Host: "bar",
Port: 2,
TLSConfig: nil,
},
&pgconn.FallbackConfig{
Host: "baz",
Port: 3,
TLSConfig: nil,
},
},
},
},
{
name: "multiple hosts and fallback tsl",
connString: "user=jack password=secret host=foo,bar,baz dbname=mydb sslmode=prefer",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "foo",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{
Host: "foo",
Port: 5432,
TLSConfig: nil,
},
&pgconn.FallbackConfig{
Host: "bar",
Port: 5432,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
}},
&pgconn.FallbackConfig{
Host: "bar",
Port: 5432,
TLSConfig: nil,
},
&pgconn.FallbackConfig{
Host: "baz",
Port: 5432,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
}},
&pgconn.FallbackConfig{
Host: "baz",
Port: 5432,
TLSConfig: nil,
},
},
},
},
{
name: "target_session_attrs",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
},
},
}
for i, tt := range tests {
config, err := pgconn.ParseConfig(tt.connString)
if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) {
continue
}
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
}
}
// https://github.com/jackc/pgconn/issues/47
func TestParseConfigDSNWithTrailingEmptyEqualDoesNotPanic(t *testing.T) {
_, err := pgconn.ParseConfig("host= user= password= port= database=")
require.NoError(t, err)
}
func TestParseConfigDSNLeadingEqual(t *testing.T) {
_, err := pgconn.ParseConfig("= user=jack")
require.Error(t, err)
}
// https://github.com/jackc/pgconn/issues/49
func TestParseConfigDSNTrailingBackslash(t *testing.T) {
_, err := pgconn.ParseConfig(`x=x\`)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid backslash")
}
func TestConfigCopyReturnsEqualConfig(t *testing.T) {
connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5"
original, err := pgconn.ParseConfig(connString)
require.NoError(t, err)
copied := original.Copy()
assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config")
}
func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) {
connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5&sslmode=prefer"
original, err := pgconn.ParseConfig(connString)
require.NoError(t, err)
copied := original.Copy()
assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config")
copied.Port = uint16(5433)
copied.RuntimeParams["foo"] = "bar"
copied.Fallbacks[0].Port = uint16(5433)
assert.Equal(t, uint16(5432), original.Port)
assert.Equal(t, "", original.RuntimeParams["foo"])
assert.Equal(t, uint16(5432), original.Fallbacks[0].Port)
}
func TestConfigCopyCanBeUsedToConnect(t *testing.T) {
connString := os.Getenv("PGX_TEST_CONN_STRING")
original, err := pgconn.ParseConfig(connString)
require.NoError(t, err)
copied := original.Copy()
assert.NotPanics(t, func() {
_, err = pgconn.ConnectConfig(context.Background(), copied)
})
assert.NoError(t, err)
}
func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) {
if !assert.NotNil(t, expected) {
return
}
if !assert.NotNil(t, actual) {
return
}
assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
// Can't test function equality, so just test that they are set or not.
assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
if expected.TLSConfig != nil {
assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
}
}
if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
for i := range expected.Fallbacks {
assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
if expected.Fallbacks[i].TLSConfig != nil {
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
}
}
}
}
}
func TestParseConfigEnvLibpq(t *testing.T) {
var osUserName string
osUser, err := user.Current()
if err == nil {
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
// but the libpq default is just the `user` portion, so we strip off the first part.
if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") {
osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:]
} else {
osUserName = osUser.Username
}
}
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"}
savedEnv := make(map[string]string)
for _, n := range pgEnvvars {
savedEnv[n] = os.Getenv(n)
}
defer func() {
for k, v := range savedEnv {
err := os.Setenv(k, v)
if err != nil {
t.Fatalf("Unable to restore environment: %v", err)
}
}
}()
tests := []struct {
name string
envvars map[string]string
config *pgconn.Config
}{
{
// not testing no environment at all as that would use default host and that can vary.
name: "PGHOST only",
envvars: map[string]string{"PGHOST": "123.123.123.123"},
config: &pgconn.Config{
User: osUserName,
Host: "123.123.123.123",
Port: 5432,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{
Host: "123.123.123.123",
Port: 5432,
TLSConfig: nil,
},
},
},
},
{
name: "All non-TLS environment",
envvars: map[string]string{
"PGHOST": "123.123.123.123",
"PGPORT": "7777",
"PGDATABASE": "foo",
"PGUSER": "bar",
"PGPASSWORD": "baz",
"PGCONNECT_TIMEOUT": "10",
"PGSSLMODE": "disable",
"PGAPPNAME": "pgxtest",
},
config: &pgconn.Config{
Host: "123.123.123.123",
Port: 7777,
Database: "foo",
User: "bar",
Password: "baz",
ConnectTimeout: 10 * time.Second,
TLSConfig: nil,
RuntimeParams: map[string]string{"application_name": "pgxtest"},
},
},
}
for i, tt := range tests {
for _, n := range pgEnvvars {
err := os.Unsetenv(n)
require.NoError(t, err)
}
for k, v := range tt.envvars {
err := os.Setenv(k, v)
require.NoError(t, err)
}
config, err := pgconn.ParseConfig("")
if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) {
continue
}
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
}
}
func TestParseConfigReadsPgPassfile(t *testing.T) {
t.Parallel()
tf, err := ioutil.TempFile("", "")
require.NoError(t, err)
defer tf.Close()
defer os.Remove(tf.Name())
_, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk"))
require.NoError(t, err)
connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name())
expected := &pgconn.Config{
User: "curly",
Password: "nyuknyuknyuk",
Host: "test1",
Port: 5432,
Database: "curlydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
}
actual, err := pgconn.ParseConfig(connString)
assert.NoError(t, err)
assertConfigsEqual(t, expected, actual, "passfile")
}
func TestParseConfigReadsPgServiceFile(t *testing.T) {
t.Parallel()
tf, err := ioutil.TempFile("", "")
require.NoError(t, err)
defer tf.Close()
defer os.Remove(tf.Name())
_, err = tf.Write([]byte(`
[abc]
host=abc.example.com
port=9999
dbname=abcdb
user=abcuser
[def]
host = def.example.com
dbname = defdb
user = defuser
application_name = spaced string
`))
require.NoError(t, err)
tests := []struct {
name string
connString string
config *pgconn.Config
}{
{
name: "abc",
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "abc"),
config: &pgconn.Config{
Host: "abc.example.com",
Database: "abcdb",
User: "abcuser",
Port: 9999,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{
Host: "abc.example.com",
Port: 9999,
TLSConfig: nil,
},
},
},
},
{
name: "def",
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "def"),
config: &pgconn.Config{
Host: "def.example.com",
Port: 5432,
Database: "defdb",
User: "defuser",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{"application_name": "spaced string"},
Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{
Host: "def.example.com",
Port: 5432,
TLSConfig: nil,
},
},
},
},
{
name: "conn string has precedence",
connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tf.Name(), "abc"),
config: &pgconn.Config{
Host: "other.example.com",
Database: "abcdb",
User: "abcuser",
Port: 7777,
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
}
for i, tt := range tests {
config, err := pgconn.ParseConfig(tt.connString)
if !assert.NoErrorf(t, err, "Test %d (%s)", i, tt.name) {
continue
}
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
}
}
func TestParseConfigExtractsMinReadBufferSize(t *testing.T) {
t.Parallel()
config, err := pgconn.ParseConfig("min_read_buffer_size=0")
require.NoError(t, err)
_, present := config.RuntimeParams["min_read_buffer_size"]
require.False(t, present)
// The buffer size is internal so there isn't much that can be done to test it other than see that the runtime param
// was removed.
}

64
pgconn/defaults.go Normal file
View File

@ -0,0 +1,64 @@
// +build !windows
package pgconn
import (
"os"
"os/user"
"path/filepath"
)
func defaultSettings() map[string]string {
settings := make(map[string]string)
settings["host"] = defaultHost()
settings["port"] = "5432"
// Default to the OS user name. Purposely ignoring err getting user name from
// OS. The client application will simply have to specify the user in that
// case (which they typically will be doing anyway).
user, err := user.Current()
if err == nil {
settings["user"] = user.Username
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
if _, err := os.Stat(sslcert); err == nil {
if _, err := os.Stat(sslkey); err == nil {
// Both the cert and key must be present to use them, or do not use either
settings["sslcert"] = sslcert
settings["sslkey"] = sslkey
}
}
sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt")
if _, err := os.Stat(sslrootcert); err == nil {
settings["sslrootcert"] = sslrootcert
}
}
settings["target_session_attrs"] = "any"
settings["min_read_buffer_size"] = "8192"
return settings
}
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
// checks the existence of common locations.
func defaultHost() string {
candidatePaths := []string{
"/var/run/postgresql", // Debian
"/private/tmp", // OSX - homebrew
"/tmp", // standard PostgreSQL
}
for _, path := range candidatePaths {
if _, err := os.Stat(path); err == nil {
return path
}
}
return "localhost"
}

View File

@ -0,0 +1,59 @@
package pgconn
import (
"os"
"os/user"
"path/filepath"
"strings"
)
func defaultSettings() map[string]string {
settings := make(map[string]string)
settings["host"] = defaultHost()
settings["port"] = "5432"
// Default to the OS user name. Purposely ignoring err getting user name from
// OS. The client application will simply have to specify the user in that
// case (which they typically will be doing anyway).
user, err := user.Current()
appData := os.Getenv("APPDATA")
if err == nil {
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
// but the libpq default is just the `user` portion, so we strip off the first part.
username := user.Username
if strings.Contains(username, "\\") {
username = username[strings.LastIndex(username, "\\")+1:]
}
settings["user"] = username
settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf")
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
sslcert := filepath.Join(appData, "postgresql", "postgresql.crt")
sslkey := filepath.Join(appData, "postgresql", "postgresql.key")
if _, err := os.Stat(sslcert); err == nil {
if _, err := os.Stat(sslkey); err == nil {
// Both the cert and key must be present to use them, or do not use either
settings["sslcert"] = sslcert
settings["sslkey"] = sslkey
}
}
sslrootcert := filepath.Join(appData, "postgresql", "root.crt")
if _, err := os.Stat(sslrootcert); err == nil {
settings["sslrootcert"] = sslrootcert
}
}
settings["target_session_attrs"] = "any"
settings["min_read_buffer_size"] = "8192"
return settings
}
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
// checks the existence of common locations.
func defaultHost() string {
return "localhost"
}

29
pgconn/doc.go Normal file
View File

@ -0,0 +1,29 @@
// Package pgconn is a low-level PostgreSQL database driver.
/*
pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at
nearly the same level is the C library libpq.
Establishing a Connection
Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
libpq style environment variables.
Executing a Query
ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method
reads all rows into memory.
Executing Multiple Queries in a Single Round Trip
Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
result. The ReadAll method reads all query results into memory.
Context Support
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
method immediately returns. In most circumstances, this will close the underlying connection.
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
client to abort.
*/
package pgconn

221
pgconn/errors.go Normal file
View File

@ -0,0 +1,221 @@
package pgconn
import (
"context"
"errors"
"fmt"
"net"
"net/url"
"regexp"
"strings"
)
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
func SafeToRetry(err error) bool {
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
return e.SafeToRetry()
}
return false
}
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
func Timeout(err error) bool {
var timeoutErr *errTimeout
return errors.As(err, &timeoutErr)
}
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe *PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// SQLState returns the SQLState of the error.
func (pe *PgError) SQLState() string {
return pe.Code
}
type connectError struct {
config *Config
msg string
err error
}
func (e *connectError) Error() string {
sb := &strings.Builder{}
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
if e.err != nil {
fmt.Fprintf(sb, " (%s)", e.err.Error())
}
return sb.String()
}
func (e *connectError) Unwrap() error {
return e.err
}
type connLockError struct {
status string
}
func (e *connLockError) SafeToRetry() bool {
return true // a lock failure by definition happens before the connection is used.
}
func (e *connLockError) Error() string {
return e.status
}
type parseConfigError struct {
connString string
msg string
err error
}
func (e *parseConfigError) Error() string {
connString := redactPW(e.connString)
if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
}
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
}
func (e *parseConfigError) Unwrap() error {
return e.err
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return &errTimeout{err: ctx.Err()}
}
return err
}
type pgconnError struct {
msg string
err error
safeToRetry bool
}
func (e *pgconnError) Error() string {
if e.msg == "" {
return e.err.Error()
}
if e.err == nil {
return e.msg
}
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
}
func (e *pgconnError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *pgconnError) Unwrap() error {
return e.err
}
// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is
// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true.
type errTimeout struct {
err error
}
func (e *errTimeout) Error() string {
return fmt.Sprintf("timeout: %s", e.err.Error())
}
func (e *errTimeout) SafeToRetry() bool {
return SafeToRetry(e.err)
}
func (e *errTimeout) Unwrap() error {
return e.err
}
type contextAlreadyDoneError struct {
err error
}
func (e *contextAlreadyDoneError) Error() string {
return fmt.Sprintf("context already done: %s", e.err.Error())
}
func (e *contextAlreadyDoneError) SafeToRetry() bool {
return true
}
func (e *contextAlreadyDoneError) Unwrap() error {
return e.err
}
// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`.
func newContextAlreadyDoneError(ctx context.Context) (err error) {
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
}
type writeError struct {
err error
safeToRetry bool
}
func (e *writeError) Error() string {
return fmt.Sprintf("write failed: %s", e.err.Error())
}
func (e *writeError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *writeError) Unwrap() error {
return e.err
}
func redactPW(connString string) string {
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
if u, err := url.Parse(connString); err == nil {
return redactURL(u)
}
}
quotedDSN := regexp.MustCompile(`password='[^']*'`)
connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
plainDSN := regexp.MustCompile(`password=[^ ]*`)
connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
brokenURL := regexp.MustCompile(`:[^:@]+?@`)
connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
return connString
}
func redactURL(u *url.URL) string {
if u == nil {
return ""
}
if _, pwSet := u.User.Password(); pwSet {
u.User = url.UserPassword(u.User.Username(), "xxxxx")
}
return u.String()
}

54
pgconn/errors_test.go Normal file
View File

@ -0,0 +1,54 @@
package pgconn_test
import (
"testing"
"github.com/jackc/pgconn"
"github.com/stretchr/testify/assert"
)
func TestConfigError(t *testing.T) {
tests := []struct {
name string
err error
expectedMsg string
}{
{
name: "url with password",
err: pgconn.NewParseConfigError("postgresql://foo:password@host", "msg", nil),
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg",
},
{
name: "dsn with password unquoted",
err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil),
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
},
{
name: "dsn with password quoted",
err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil),
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
},
{
name: "weird url",
err: pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil),
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg",
},
{
name: "weird url with slash in password",
err: pgconn.NewParseConfigError("postgres://user:pass/word@host:5432/db_name", "msg", nil),
expectedMsg: "cannot parse `postgres://user:xxxxxx@host:5432/db_name`: msg",
},
{
name: "url without password",
err: pgconn.NewParseConfigError("postgresql://other@host/db", "msg", nil),
expectedMsg: "cannot parse `postgresql://other@host/db`: msg",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.EqualError(t, tt.err, tt.expectedMsg)
})
}
}

11
pgconn/export_test.go Normal file
View File

@ -0,0 +1,11 @@
// File export_test exports some methods for better testing.
package pgconn
func NewParseConfigError(conn, msg string, err error) error {
return &parseConfigError{
connString: conn,
msg: msg,
err: err,
}
}

70
pgconn/frontend_test.go Normal file
View File

@ -0,0 +1,70 @@
package pgconn_test
import (
"context"
"io"
"os"
"testing"
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// frontendWrapper allows to hijack a regular frontend, and inject a specific response
type frontendWrapper struct {
front pgconn.Frontend
msg pgproto3.BackendMessage
}
// frontendWrapper implements the pgconn.Frontend interface
var _ pgconn.Frontend = (*frontendWrapper)(nil)
func (f *frontendWrapper) Receive() (pgproto3.BackendMessage, error) {
if f.msg != nil {
return f.msg, nil
}
return f.front.Receive()
}
func TestFrontendFatalErrExec(t *testing.T) {
t.Parallel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
buildFrontend := config.BuildFrontend
var front *frontendWrapper
config.BuildFrontend = func(r io.Reader, w io.Writer) pgconn.Frontend {
wrapped := buildFrontend(r, w)
front = &frontendWrapper{wrapped, nil}
return front
}
conn, err := pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, conn)
require.NotNil(t, front)
// set frontend to return a "FATAL" message on next call
front.msg = &pgproto3.ErrorResponse{Severity: "FATAL", Message: "unit testing fatal error"}
_, err = conn.Exec(context.Background(), "SELECT 1").ReadAll()
assert.Error(t, err)
err = conn.Close(context.Background())
assert.NoError(t, err)
select {
case <-conn.CleanupDone():
t.Log("ok, CleanupDone() is not blocking")
default:
assert.Fail(t, "connection closed but CleanupDone() still blocking")
}
}

15
pgconn/go.mod Normal file
View File

@ -0,0 +1,15 @@
module github.com/jackc/pgconn
go 1.12
require (
github.com/jackc/chunkreader/v2 v2.0.1
github.com/jackc/pgio v1.0.0
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65
github.com/jackc/pgpassfile v1.0.0
github.com/jackc/pgproto3/v2 v2.1.1
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b
github.com/stretchr/testify v1.7.0
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97
golang.org/x/text v0.3.6
)

130
pgconn/go.sum Normal file
View File

@ -0,0 +1,130 @@
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA=
github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE=
github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s=
github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o=
github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY=
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c=
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc=
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI=
github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg=
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc=
github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw=
github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM=
github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc=
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

36
pgconn/helper_test.go Normal file
View File

@ -0,0 +1,36 @@
package pgconn_test
import (
"context"
"testing"
"time"
"github.com/jackc/pgconn"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func closeConn(t testing.TB, conn *pgconn.PgConn) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
require.NoError(t, conn.Close(ctx))
select {
case <-conn.CleanupDone():
case <-time.After(5 * time.Second):
t.Fatal("Connection cleanup exceeded maximum time")
}
}
// Do a simple query to ensure the connection is still usable
func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read()
cancel()
require.Nil(t, result.Err)
assert.Equal(t, 3, len(result.Rows))
assert.Equal(t, "1", string(result.Rows[0][0]))
assert.Equal(t, "2", string(result.Rows[1][0]))
assert.Equal(t, "3", string(result.Rows[2][0]))
}

View File

@ -0,0 +1,73 @@
package ctxwatch
import (
"context"
"sync"
)
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
// time.
type ContextWatcher struct {
onCancel func()
onUnwatchAfterCancel func()
unwatchChan chan struct{}
lock sync.Mutex
watchInProgress bool
onCancelWasCalled bool
}
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
// onCancel called.
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
cw := &ContextWatcher{
onCancel: onCancel,
onUnwatchAfterCancel: onUnwatchAfterCancel,
unwatchChan: make(chan struct{}),
}
return cw
}
// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called.
func (cw *ContextWatcher) Watch(ctx context.Context) {
cw.lock.Lock()
defer cw.lock.Unlock()
if cw.watchInProgress {
panic("Watch already in progress")
}
cw.onCancelWasCalled = false
if ctx.Done() != nil {
cw.watchInProgress = true
go func() {
select {
case <-ctx.Done():
cw.onCancel()
cw.onCancelWasCalled = true
<-cw.unwatchChan
case <-cw.unwatchChan:
}
}()
} else {
cw.watchInProgress = false
}
}
// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was
// called then onUnwatchAfterCancel will also be called.
func (cw *ContextWatcher) Unwatch() {
cw.lock.Lock()
defer cw.lock.Unlock()
if cw.watchInProgress {
cw.unwatchChan <- struct{}{}
if cw.onCancelWasCalled {
cw.onUnwatchAfterCancel()
}
cw.watchInProgress = false
}
}

View File

@ -0,0 +1,163 @@
package ctxwatch_test
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/jackc/pgconn/internal/ctxwatch"
"github.com/stretchr/testify/require"
)
func TestContextWatcherContextCancelled(t *testing.T) {
canceledChan := make(chan struct{})
cleanupCalled := false
cw := ctxwatch.NewContextWatcher(func() {
canceledChan <- struct{}{}
}, func() {
cleanupCalled = true
})
ctx, cancel := context.WithCancel(context.Background())
cw.Watch(ctx)
cancel()
select {
case <-canceledChan:
case <-time.NewTimer(time.Second).C:
t.Fatal("Timed out waiting for cancel func to be called")
}
cw.Unwatch()
require.True(t, cleanupCalled, "Cleanup func was not called")
}
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
cw := ctxwatch.NewContextWatcher(func() {
t.Error("cancel func should not have been called")
}, func() {
t.Error("cleanup func should not have been called")
})
ctx, cancel := context.WithCancel(context.Background())
cw.Watch(ctx)
cw.Unwatch()
cancel()
}
func TestContextWatcherMultipleWatchPanics(t *testing.T) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cw.Watch(ctx)
ctx2, cancel2 := context.WithCancel(context.Background())
defer cancel2()
require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times")
}
func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
cw.Unwatch() // unwatch when not / never watching
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cw.Watch(ctx)
cw.Unwatch()
cw.Unwatch() // double unwatch
}
func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
cw.Watch(ctx)
go cw.Unwatch()
go cw.Unwatch()
<-ctx.Done()
}
func TestContextWatcherStress(t *testing.T) {
var cancelFuncCalls int64
var cleanupFuncCalls int64
cw := ctxwatch.NewContextWatcher(func() {
atomic.AddInt64(&cancelFuncCalls, 1)
}, func() {
atomic.AddInt64(&cleanupFuncCalls, 1)
})
cycleCount := 100000
for i := 0; i < cycleCount; i++ {
ctx, cancel := context.WithCancel(context.Background())
cw.Watch(ctx)
if i%2 == 0 {
cancel()
}
// Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix.
if i%3 == 0 {
time.Sleep(time.Nanosecond)
}
cw.Unwatch()
if i%2 == 1 {
cancel()
}
}
actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls)
actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls)
if actualCancelFuncCalls == 0 {
t.Fatal("actualCancelFuncCalls == 0")
}
maxCancelFuncCalls := int64(cycleCount) / 2
if actualCancelFuncCalls > maxCancelFuncCalls {
t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls)
}
if actualCancelFuncCalls != actualCleanupFuncCalls {
t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls)
}
}
func BenchmarkContextWatcherUncancellable(b *testing.B) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
for i := 0; i < b.N; i++ {
cw.Watch(context.Background())
cw.Unwatch()
}
}
func BenchmarkContextWatcherCancelled(b *testing.B) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
for i := 0; i < b.N; i++ {
ctx, cancel := context.WithCancel(context.Background())
cw.Watch(ctx)
cancel()
cw.Unwatch()
}
}
func BenchmarkContextWatcherCancellable(b *testing.B) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for i := 0; i < b.N; i++ {
cw.Watch(ctx)
cw.Unwatch()
}
}

1703
pgconn/pgconn.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,90 @@
package pgconn_test
import (
"context"
"math/rand"
"os"
"runtime"
"strconv"
"testing"
"github.com/jackc/pgconn"
"github.com/stretchr/testify/require"
)
func TestConnStress(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer closeConn(t, pgConn)
actionCount := 10000
if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" {
stressFactor, err := strconv.ParseInt(s, 10, 64)
require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR")
actionCount *= int(stressFactor)
}
setupStressDB(t, pgConn)
actions := []struct {
name string
fn func(*pgconn.PgConn) error
}{
{"Exec Select", stressExecSelect},
{"ExecParams Select", stressExecParamsSelect},
{"Batch", stressBatch},
}
for i := 0; i < actionCount; i++ {
action := actions[rand.Intn(len(actions))]
err := action.fn(pgConn)
require.Nilf(t, err, "%d: %s", i, action.name)
}
// Each call with a context starts a goroutine. Ensure they are cleaned up when context is not canceled.
numGoroutine := runtime.NumGoroutine()
require.Truef(t, numGoroutine < 1000, "goroutines appear to be orphaned: %d in process", numGoroutine)
}
func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) {
_, err := pgConn.Exec(context.Background(), `
create temporary table widgets(
id serial primary key,
name varchar not null,
description text,
creation_time timestamptz default now()
);
insert into widgets(name, description) values
('Foo', 'bar'),
('baz', 'Something really long Something really long Something really long Something really long Something really long'),
('a', 'b')`).ReadAll()
require.NoError(t, err)
}
func stressExecSelect(pgConn *pgconn.PgConn) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := pgConn.Exec(ctx, "select * from widgets").ReadAll()
return err
}
func stressExecParamsSelect(pgConn *pgconn.PgConn) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
result := pgConn.ExecParams(ctx, "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read()
return result.Err
}
func stressBatch(pgConn *pgconn.PgConn) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
batch := &pgconn.Batch{}
batch.ExecParams("select * from widgets", nil, nil, nil, nil)
batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil)
_, err := pgConn.ExecBatch(ctx, batch).ReadAll()
return err
}

2012
pgconn/pgconn_test.go Normal file

File diff suppressed because it is too large Load Diff

165
pgconn/stmtcache/lru.go Normal file
View File

@ -0,0 +1,165 @@
package stmtcache
import (
"container/list"
"context"
"fmt"
"sync/atomic"
"github.com/jackc/pgconn"
)
var lruCount uint64
// LRU implements Cache with a Least Recently Used (LRU) cache.
type LRU struct {
conn *pgconn.PgConn
mode int
cap int
prepareCount int
m map[string]*list.Element
l *list.List
psNamePrefix string
stmtsToClear []string
}
// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache.
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
mustBeValidMode(mode)
mustBeValidCap(cap)
n := atomic.AddUint64(&lruCount, 1)
return &LRU{
conn: conn,
mode: mode,
cap: cap,
m: make(map[string]*list.Element),
l: list.New(),
psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
}
}
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
if ctx != context.Background() {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
// flush an outstanding bad statements
txStatus := c.conn.TxStatus()
if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 {
for _, stmt := range c.stmtsToClear {
err := c.clearStmt(ctx, stmt)
if err != nil {
return nil, err
}
}
}
if el, ok := c.m[sql]; ok {
c.l.MoveToFront(el)
return el.Value.(*pgconn.StatementDescription), nil
}
if c.l.Len() == c.cap {
err := c.removeOldest(ctx)
if err != nil {
return nil, err
}
}
psd, err := c.prepare(ctx, sql)
if err != nil {
return nil, err
}
el := c.l.PushFront(psd)
c.m[sql] = el
return psd, nil
}
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
func (c *LRU) Clear(ctx context.Context) error {
for c.l.Len() > 0 {
err := c.removeOldest(ctx)
if err != nil {
return err
}
}
return nil
}
func (c *LRU) StatementErrored(sql string, err error) {
pgErr, ok := err.(*pgconn.PgError)
if !ok {
return
}
isInvalidCachedPlanError := pgErr.Severity == "ERROR" &&
pgErr.Code == "0A000" &&
pgErr.Message == "cached plan must not change result type"
if isInvalidCachedPlanError {
c.stmtsToClear = append(c.stmtsToClear, sql)
}
}
func (c *LRU) clearStmt(ctx context.Context, sql string) error {
elem, inMap := c.m[sql]
if !inMap {
// The statement probably fell off the back of the list. In that case, we've
// ensured that it isn't in the cache, so we can declare victory.
return nil
}
c.l.Remove(elem)
psd := elem.Value.(*pgconn.StatementDescription)
delete(c.m, psd.SQL)
if c.mode == ModePrepare {
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
}
return nil
}
// Len returns the number of cached prepared statement descriptions.
func (c *LRU) Len() int {
return c.l.Len()
}
// Cap returns the maximum number of cached prepared statement descriptions.
func (c *LRU) Cap() int {
return c.cap
}
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
func (c *LRU) Mode() int {
return c.mode
}
func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
var name string
if c.mode == ModePrepare {
name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
c.prepareCount += 1
}
return c.conn.Prepare(ctx, name, sql, nil)
}
func (c *LRU) removeOldest(ctx context.Context) error {
oldest := c.l.Back()
c.l.Remove(oldest)
psd := oldest.Value.(*pgconn.StatementDescription)
delete(c.m, psd.SQL)
if c.mode == ModePrepare {
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
}
return nil
}

View File

@ -0,0 +1,292 @@
package stmtcache_test
import (
"context"
"fmt"
"math/rand"
"os"
"regexp"
"testing"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgconn/stmtcache"
"github.com/stretchr/testify/require"
)
func TestLRUModePrepare(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
require.EqualValues(t, 0, cache.Len())
require.EqualValues(t, 2, cache.Cap())
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
psd, err := cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 2")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 3")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn))
err = cache.Clear(ctx)
require.NoError(t, err)
require.EqualValues(t, 0, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
}
func TestLRUStmtInvalidation(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
// we construct a fake error because its not super straightforward to actually call
// a prepared statement from the LRU cache without the helper routines which live
// in pgx proper.
fakeInvalidCachePlanError := &pgconn.PgError{
Severity: "ERROR",
Code: "0A000",
Message: "cached plan must not change result type",
}
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
//
// outside of a transaction, we eagerly flush the statement
//
_, err = cache.Get(ctx, "select 1")
require.NoError(t, err)
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
cache.StatementErrored("select 1", fakeInvalidCachePlanError)
_, err = cache.Get(ctx, "select 2")
require.NoError(t, err)
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))
err = cache.Clear(ctx)
require.NoError(t, err)
//
// within an errored transaction, we defer the flush to after the first get
// that happens after the transaction is rolled back
//
_, err = cache.Get(ctx, "select 1")
require.NoError(t, err)
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
res := conn.Exec(ctx, "begin")
require.NoError(t, res.Close())
require.Equal(t, byte('T'), conn.TxStatus())
res = conn.Exec(ctx, "selec")
require.Error(t, res.Close())
require.Equal(t, byte('E'), conn.TxStatus())
cache.StatementErrored("select 1", fakeInvalidCachePlanError)
require.EqualValues(t, 1, cache.Len())
res = conn.Exec(ctx, "rollback")
require.NoError(t, res.Close())
_, err = cache.Get(ctx, "select 2")
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))
}
func TestLRUStmtInvalidationIntegration(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
result := conn.ExecParams(ctx, "create temporary table stmtcache_table (a text)", nil, nil, nil, nil).Read()
require.NoError(t, result.Err)
sql := "select * from stmtcache_table"
sd1, err := cache.Get(ctx, sql)
require.NoError(t, err)
result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
require.NoError(t, result.Err)
result = conn.ExecParams(ctx, "alter table stmtcache_table add column b text", nil, nil, nil, nil).Read()
require.NoError(t, result.Err)
result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)")
cache.StatementErrored(sql, result.Err)
sd2, err := cache.Get(ctx, sql)
require.NoError(t, err)
require.NotEqual(t, sd1.Name, sd2.Name)
result = conn.ExecPrepared(ctx, sd2.Name, nil, nil, nil).Read()
require.NoError(t, result.Err)
}
func TestLRUModePrepareStress(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8)
require.EqualValues(t, 0, cache.Len())
require.EqualValues(t, 8, cache.Cap())
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
for i := 0; i < 1000; i++ {
psd, err := cache.Get(ctx, fmt.Sprintf("select %d", rand.Intn(50)))
require.NoError(t, err)
require.NotNil(t, psd)
result := conn.ExecPrepared(ctx, psd.Name, nil, nil, nil).Read()
require.NoError(t, result.Err)
}
}
func TestLRUModeDescribe(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
require.EqualValues(t, 0, cache.Len())
require.EqualValues(t, 2, cache.Cap())
require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode())
psd, err := cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 2")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 3")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
err = cache.Clear(ctx)
require.NoError(t, err)
require.EqualValues(t, 0, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
}
func TestLRUContext(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
// test 1 : getting a value for the first time with a cancelled context returns an error
ctx1, cancel1 := context.WithCancel(ctx)
cancel1()
desc, err := cache.Get(ctx1, "SELECT 1")
require.Error(t, err)
require.Nil(t, desc)
// test 2 : when querying for the 2nd time a cached value, if the context is canceled return an error
ctx2, cancel2 := context.WithCancel(ctx)
desc, err = cache.Get(ctx2, "SELECT 2")
require.NoError(t, err)
require.NotNil(t, desc)
cancel2()
desc, err = cache.Get(ctx2, "SELECT 2")
require.Error(t, err)
require.Nil(t, desc)
}
func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string {
result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read()
require.NoError(t, result.Err)
var statements []string
for _, r := range result.Rows {
statement := string(r[0])
if conn.ParameterStatus("crdb_version") != "" {
if statement == "PREPARE AS select statement from pg_prepared_statements" {
// CockroachDB includes the currently running unnamed prepared statement while PostgreSQL does not. Ignore it.
continue
}
// CockroachDB includes the "PREPARE ... AS" text in the statement even if it was prepared through the extended
// protocol will PostgreSQL does not. Normalize the statement.
re := regexp.MustCompile(`^PREPARE lrupsc[0-9_]+ AS `)
statement = re.ReplaceAllString(statement, "")
}
statements = append(statements, statement)
}
return statements
}

View File

@ -0,0 +1,58 @@
// Package stmtcache is a cache that can be used to implement lazy prepared statements.
package stmtcache
import (
"context"
"github.com/jackc/pgconn"
)
const (
ModePrepare = iota // Cache should prepare named statements.
ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement.
)
// Cache prepares and caches prepared statement descriptions.
type Cache interface {
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error)
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
Clear(ctx context.Context) error
// StatementErrored informs the cache that the given statement resulted in an error when it
// was last used against the database. In some cases, this will cause the cache to maer that
// statement as bad. The bad statement will instead be flushed during the next call to Get
// that occurs outside of a failed transaction.
StatementErrored(sql string, err error)
// Len returns the number of cached prepared statement descriptions.
Len() int
// Cap returns the maximum number of cached prepared statement descriptions.
Cap() int
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
Mode() int
}
// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is
// the maximum size of the cache.
func New(conn *pgconn.PgConn, mode int, cap int) Cache {
mustBeValidMode(mode)
mustBeValidCap(cap)
return NewLRU(conn, mode, cap)
}
func mustBeValidMode(mode int) {
if mode != ModePrepare && mode != ModeDescribe {
panic("mode must be ModePrepare or ModeDescribe")
}
}
func mustBeValidCap(cap int) {
if cap < 1 {
panic("cache must have cap of >= 1")
}
}