Compare commits


No commits in common. "master" and "v3.4.0" have entirely different histories.

442 changed files with 38805 additions and 59575 deletions


@@ -1,54 +0,0 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: bug
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
If possible, please provide a runnable example such as:
```go
package main

import (
	"context"
	"log"
	"os"

	"github.com/jackc/pgx/v5"
)

func main() {
	conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(context.Background())

	// Your code here...
}
```
Please run your example with the race detector enabled. For example, `go run -race main.go` or `go test -race`.
**Expected behavior**
A clear and concise description of what you expected to happen.
**Actual behavior**
A clear and concise description of what actually happened.
**Version**
- Go: `$ go version` -> [e.g. go version go1.18.3 darwin/amd64]
- PostgreSQL: `$ psql --no-psqlrc --tuples-only -c 'select version()'` -> [e.g. PostgreSQL 14.4 on x86_64-apple-darwin21.5.0, compiled by Apple clang version 13.1.6 (clang-1316.0.21.2.5), 64-bit]
- pgx: `$ grep 'github.com/jackc/pgx/v[0-9]' go.mod` -> [e.g. v4.16.1]
**Additional context**
Add any other context about the problem here.


@@ -1,20 +0,0 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.


@@ -1,10 +0,0 @@
---
name: Other issues
about: Any issue that is not a bug or a feature request
title: ''
labels: ''
assignees: ''
---
Please describe the issue in detail. If this is a question about how to use pgx please use discussions instead.


@@ -1,156 +0,0 @@
name: CI

on:
  push:
    branches: [master]
  pull_request:
    branches: [master]

jobs:
  test:
    name: Test
    runs-on: ubuntu-22.04
    strategy:
      matrix:
        go-version: ["1.23", "1.24"]
        pg-version: [13, 14, 15, 16, 17, cockroachdb]
        include:
          - pg-version: 13
            pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
            pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
            pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
            pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
            pgx-ssl-password: certpw
            pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
          - pg-version: 14
            pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
            pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
            pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
            pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
            pgx-ssl-password: certpw
            pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
          - pg-version: 15
            pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
            pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
            pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
            pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
            pgx-ssl-password: certpw
            pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
          - pg-version: 16
            pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
            pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
            pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
            pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
            pgx-ssl-password: certpw
            pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
          - pg-version: 17
            pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
            pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
            pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
            pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
            pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
            pgx-ssl-password: certpw
            pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
          - pg-version: cockroachdb
            pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
    steps:
      - name: Check out code into the Go module directory
        uses: actions/checkout@v4
      - name: Set up Go ${{ matrix.go-version }}
        uses: actions/setup-go@v5
        with:
          go-version: ${{ matrix.go-version }}
      - name: Setup database server for testing
        run: ci/setup_test.bash
        env:
          PGVERSION: ${{ matrix.pg-version }}
      # - name: Setup upterm session
      #   uses: lhotari/action-upterm@v1
      #   with:
      #     ## limits ssh access and adds the ssh public key for the user which triggered the workflow
      #     limit-access-to-actor: true
      #   env:
      #     PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
      #     PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
      #     PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
      #     PGX_TEST_SCRAM_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-scram-password-conn-string }}
      #     PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
      #     PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}
      #     PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
      #     PGX_SSL_PASSWORD: ${{ matrix.pgx-ssl-password }}
      #     PGX_TEST_TLS_CLIENT_CONN_STRING: ${{ matrix.pgx-test-tls-client-conn-string }}
      - name: Check formatting
        run: |
          gofmt -l -s -w .
          git status
          git diff --exit-code
      - name: Test
        # parallel testing is disabled because somehow parallel testing causes GitHub Actions to kill the runner.
        run: go test -parallel=1 -race ./...
        env:
          PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
          PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
          PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
          PGX_TEST_SCRAM_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-scram-password-conn-string }}
          PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
          PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}
          # TestConnectTLS fails. However, it succeeds if I connect to the CI server with upterm and run it. Give up on that test for now.
          # PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
          PGX_SSL_PASSWORD: ${{ matrix.pgx-ssl-password }}
          PGX_TEST_TLS_CLIENT_CONN_STRING: ${{ matrix.pgx-test-tls-client-conn-string }}

  test-windows:
    name: Test Windows
    runs-on: windows-latest
    strategy:
      matrix:
        go-version: ["1.23", "1.24"]
    steps:
      - name: Setup PostgreSQL
        id: postgres
        uses: ikalnytskyi/action-setup-postgres@v4
        with:
          database: pgx_test
      - name: Check out code into the Go module directory
        uses: actions/checkout@v4
      - name: Set up Go ${{ matrix.go-version }}
        uses: actions/setup-go@v5
        with:
          go-version: ${{ matrix.go-version }}
      - name: Initialize test database
        run: |
          psql -f testsetup/postgresql_setup.sql pgx_test
        env:
          PGSERVICE: ${{ steps.postgres.outputs.service-name }}
        shell: bash
      - name: Test
        # parallel testing is disabled because somehow parallel testing causes GitHub Actions to kill the runner.
        run: go test -parallel=1 -race ./...
        env:
          PGX_TEST_DATABASE: ${{ steps.postgres.outputs.connection-uri }}

.gitignore

@@ -21,7 +21,5 @@ _testmain.go
*.exe
conn_config_test.go
.envrc
/.testdb
.DS_Store


@@ -1,21 +0,0 @@
# See for configurations: https://golangci-lint.run/usage/configuration/
version: 2

# See: https://golangci-lint.run/usage/formatters/
formatters:
  default: none
  enable:
    - gofmt # https://pkg.go.dev/cmd/gofmt
    - gofumpt # https://github.com/mvdan/gofumpt
  settings:
    gofmt:
      simplify: true # Simplify code: gofmt with `-s` option.
    gofumpt:
      # Module path which contains the source code being formatted.
      # Default: ""
      module-path: github.com/jackc/pgx/v5 # Should match with module in go.mod
      # Choose whether to use the extra rules.
      # Default: false
      extra-rules: true

.travis.yml

@@ -0,0 +1,33 @@
language: go

go:
  - 1.x
  - tip

# Derived from https://github.com/lib/pq/blob/master/.travis.yml

before_install:
  - ./travis/before_install.bash

env:
  global:
    - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test
  matrix:
    - CRATEVERSION=2.1
    - PGVERSION=10
    - PGVERSION=9.6
    - PGVERSION=9.5
    - PGVERSION=9.4
    - PGVERSION=9.3

before_script:
  - ./travis/before_script.bash

install:
  - ./travis/install.bash

script:
  - ./travis/script.bash

matrix:
  allow_failures:
    - go: tip


@@ -1,462 +1,384 @@
# 5.7.5 (May 17, 2025)
* Support sslnegotiation connection option (divyam234)
* Update golang.org/x/crypto to v0.37.0. This placates security scanners that were unable to see that pgx did not use the behavior affected by https://pkg.go.dev/vuln/GO-2025-3487.
* TraceLog now logs Acquire and Release at the debug level (dave sinclair)
* Add support for PGTZ environment variable
* Add support for PGOPTIONS environment variable
* Unpin memory used by Rows quicker
* Remove PlanScan memoization. This resolves a rare issue where scanning could be broken for one type by first scanning another. The problem was in the memoization system and benchmarking revealed that memoization was not providing any meaningful benefit.
# 5.7.4 (March 24, 2025)
* Fix / revert change to scanning JSON `null` (Felix Röhrich)
# 5.7.3 (March 21, 2025)
* Expose EmptyAcquireWaitTime in pgxpool.Stat (vamshiaruru32)
* Improve SQL sanitizer performance (ninedraft)
* Fix Scan confusion with json(b), sql.Scanner, and automatic dereferencing (moukoublen, felix-roehrich)
* Fix Values() for xml type always returning nil instead of []byte
* Add ability to send Flush message in pipeline mode (zenkovev)
* Fix pgtype.Timestamp's JSON behavior to match PostgreSQL (pconstantinou)
* Better error messages when scanning structs (logicbomb)
* Fix handling of error on batch write (bonnefoa)
* Match libpq's connection fallback behavior more closely (felix-roehrich)
* Add MinIdleConns to pgxpool (djahandarie)
# 5.7.2 (December 21, 2024)
* Fix prepared statement already exists on batch prepare failure
* Add commit query to tx options (Lucas Hild)
* Fix pgtype.Timestamp json unmarshal (Shean de Montigny-Desautels)
* Add message body size limits in frontend and backend (zene)
* Add xid8 type
* Ensure planning encodes and scans cannot infinitely recurse
* Implement pgtype.UUID.String() (Konstantin Grachev)
* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev)
* Update golang.org/x/crypto
* Fix json(b) columns prefer sql.Scanner interface like database/sql (Ludovico Russo)
# 5.7.1 (September 10, 2024)
* Fix data race in tracelog.TraceLog
* Update puddle to v2.2.2. This removes the import of nanotime via linkname.
* Update golang.org/x/crypto and golang.org/x/text
# 5.7.0 (September 7, 2024)
* Add support for sslrootcert=system (Yann Soubeyrand)
* Add LoadTypes to load multiple types in a single SQL query (Nick Farrell)
* Add XMLCodec, which supports encoding and scanning the XML column type like json (nickcruess-soda)
* Add MultiTrace (Stepan Rabotkin)
* Add TraceLogConfig with customizable TimeKey (stringintech)
* pgx.ErrNoRows wraps sql.ErrNoRows to aid in database/sql compatibility with native pgx functions (merlin)
* Support scanning binary formatted uint32 into string / TextScanner (jennifersp)
* Fix interval encoding to allow 0s and avoid extra spaces (Carlos Pérez-Aradros Herce)
* Update pgservicefile - fixes panic when parsing invalid file
* Better error message when reading past end of batch
* Don't print url when url.Parse returns an error (Kevin Biju)
* Fix snake case name normalization collision in RowToStructByName with db tag (nolandseigler)
* Fix: Scan and encode types with underlying types of arrays
# 5.6.0 (May 25, 2024)
* Add StrictNamedArgs (Tomas Zahradnicek)
* Add support for macaddr8 type (Carlos Pérez-Aradros Herce)
* Add SeverityUnlocalized field to PgError / Notice
* Performance optimization of RowToStructByPos/Name (Zach Olstein)
* Allow customizing context canceled behavior for pgconn
* Add ScanLocation to pgtype.Timestamp[tz]Codec
* Add custom data to pgconn.PgConn
* Fix ResultReader.Read() to handle nil values
* Do not encode interval microseconds when they are 0 (Carlos Pérez-Aradros Herce)
* pgconn.SafeToRetry checks for wrapped errors (tjasko)
* Failed connection attempts include all errors
* Optimize LargeObject.Read (Mitar)
* Add tracing for connection acquire and release from pool (ngavinsir)
* Fix encode driver.Valuer not called when nil
* Add support for custom JSON marshal and unmarshal (Mitar)
* Use Go default keepalive for TCP connections (Hans-Joachim Kliemeck)
# 5.5.5 (March 9, 2024)
Use spaces instead of parentheses for SQL sanitization.
This still solves the problem of negative numbers creating a line comment, but avoids breaking edge cases such as
`set foo to $1` where the substitution takes place in a location where an arbitrary expression is not allowed.
# 5.5.4 (March 4, 2024)
Fix CVE-2024-27304
SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer
overflow in the calculated message size can cause the one large message to be sent as multiple messages under the
attacker's control.
Thanks to Paul Gerste for reporting this issue.
* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix)
* Fix simple protocol encoding of json.RawMessage
* Fix *Pipeline.getResults should close pipeline on error
* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman)
* Fix deallocation of invalidated cached statements in a transaction
* Handle invalid sslkey file
* Fix scan float4 into sql.Scanner
* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads.
# 5.5.3 (February 3, 2024)
* Fix: prepared statement already exists
* Improve CopyFrom auto-conversion of text-ish values
* Add ltree type support (Florent Viel)
* Make some properties of Batch and QueuedQuery public (Pavlo Golub)
* Add AppendRows function (Edoardo Spadolini)
* Optimize convert UUID [16]byte to string (Kirill Malikov)
* Fix: LargeObject Read and Write of more than ~1GB at a time (Mitar)
# 5.5.2 (January 13, 2024)
* Allow NamedArgs to start with underscore
* pgproto3: Maximum message body length support (jeremy.spriet)
* Upgrade golang.org/x/crypto to v0.17.0
* Add snake_case support to RowToStructByName (Tikhon Fedulov)
* Fix: update description cache after exec prepare (James Hartig)
* Fix: pipeline checks if it is closed (James Hartig and Ryan Fowler)
* Fix: normalize timeout / context errors during TLS startup (Samuel Stauffer)
* Add OnPgError for easier centralized error handling (James Hartig)
# 5.5.1 (December 9, 2023)
* Add CopyFromFunc helper function. (robford)
* Add PgConn.Deallocate method that uses PostgreSQL protocol Close message.
* pgx uses new PgConn.Deallocate method. This allows deallocating statements to work in a failed transaction. This fixes a case where the prepared statement map could become invalid.
* Fix: Prefer driver.Valuer over json.Marshaler for json fields. (Jacopo)
* Fix: simple protocol SQL sanitizer previously panicked if an invalid $0 placeholder was used. This now returns an error instead. (maksymnevajdev)
* Add pgtype.Numeric.ScanScientific (Eshton Robateau)
# 5.5.0 (November 4, 2023)
* Add CollectExactlyOneRow. (Julien GOTTELAND)
* Add OpenDBFromPool to create *database/sql.DB from *pgxpool.Pool. (Lev Zakharov)
* Prepare can automatically choose statement name based on sql. This makes it easier to explicitly manage prepared statements.
* Statement cache now uses deterministic, stable statement names.
* database/sql prepared statement names are deterministically generated.
* Fix: SendBatch wasn't respecting context cancellation.
* Fix: Timeout error from pipeline is now normalized.
* Fix: database/sql encoding json.RawMessage to []byte.
* CancelRequest: Wait for the cancel request to be acknowledged by the server. This should improve PgBouncer compatibility. (Anton Levakin)
* stdlib: Use Ping instead of CheckConn in ResetSession
* Add json.Marshaler and json.Unmarshaler for Float4, Float8 (Kirill Mironov)
# 5.4.3 (August 5, 2023)
* Fix: QCharArrayOID was defined with the wrong OID (Christoph Engelbert)
* Fix: connect_timeout for sslmode=allow|prefer (smaher-edb)
* Fix: pgxpool: background health check cannot overflow pool
* Fix: Check for nil in defer when sending batch (recover properly from panic)
* Fix: json scan of non-string pointer to pointer
* Fix: zeronull.Timestamptz should use pgtype.Timestamptz
* Fix: NewConnsCount was not correctly counting connections created by Acquire directly. (James Hartig)
* RowTo(AddrOf)StructByPos ignores fields with "-" db tag
* Optimization: improve text format numeric parsing (horpto)
# 5.4.2 (July 11, 2023)
* Fix: RowScanner errors are fatal to Rows
* Fix: Enable failover efforts when pg_hba.conf disallows non-ssl connections (Brandon Kauffman)
* Hstore text codec internal improvements (Evan Jones)
* Fix: Stop timers for background reader when not in use. Fixes memory leak when closing connections (Adrian-Stefan Mares)
* Fix: Stop background reader as soon as possible.
* Add PgConn.SyncConn(). This combined with the above fix makes it safe to directly use the underlying net.Conn.
# 5.4.1 (June 18, 2023)
* Fix: concurrency bug with pgtypeDefaultMap and simple protocol (Lev Zakharov)
* Add TxOptions.BeginQuery to allow overriding the default BEGIN query
# 5.4.0 (June 14, 2023)
* Replace platform specific syscalls for non-blocking IO with more traditional goroutines and deadlines. This returns to the v4 approach with some additional improvements and fixes. This restores the ability to use a pgx.Conn over an ssh.Conn as well as other non-TCP or Unix socket connections. In addition, it is a significantly simpler implementation that is less likely to have cross platform issues.
* Optimization: The default type registrations are now shared among all connections. This saves about 100KB of memory per connection. `pgtype.Type` and `pgtype.Codec` values are now required to be immutable after registration. This was already necessary in most cases but wasn't documented until now. (Lev Zakharov)
* Fix: Ensure pgxpool.Pool.QueryRow.Scan releases connection on panic
* CancelRequest: don't try to read the reply (Nicola Murino)
* Fix: correctly handle bool type aliases (Wichert Akkerman)
* Fix: pgconn.CancelRequest: Fix unix sockets: don't use RemoteAddr()
* Fix: pgx.Conn memory leak with prepared statement caching (Evan Jones)
* Add BeforeClose to pgxpool.Pool (Evan Cordell)
* Fix: various hstore fixes and optimizations (Evan Jones)
* Fix: RowToStructByPos with embedded unexported struct
* Support different bool string representations (Lev Zakharov)
* Fix: error when using BatchResults.Exec on a select that returns an error after some rows.
* Fix: pipelineBatchResults.Exec() not returning error from ResultReader
* Fix: pipeline batch results not closing pipeline when error occurs while reading directly from results instead of using
a callback.
* Fix: scanning a table type into a struct
* Fix: scan array of record to pointer to slice of struct
* Fix: handle null for json (Cemre Mengu)
* Batch Query callback is called even when there is an error
* Add RowTo(AddrOf)StructByNameLax (Audi P. Risa P)
# 5.3.1 (February 27, 2023)
* Fix: Support v4 and v5 stdlib in same program (Tomáš Procházka)
* Fix: sql.Scanner not being used in certain cases
* Add text format jsonpath support
* Fix: fake non-blocking read adaptive wait time
# 5.3.0 (February 11, 2023)
* Fix: json values work with sql.Scanner
* Fixed / improved error messages (Mark Chambers and Yevgeny Pats)
* Fix: support scan into single dimensional arrays
* Fix: MaxConnLifetimeJitter setting actually jitter (Ben Weintraub)
* Fix: driver.Value representation of bytea should be []byte not string
* Fix: better handling of unregistered OIDs
* CopyFrom can use query cache to avoid extra round trip to get OIDs (Alejandro Do Nascimento Mora)
* Fix: encode to json ignoring driver.Valuer
* Support sql.Scanner on renamed base type
* Fix: pgtype.Numeric text encoding of negative numbers (Mark Chambers)
* Fix: connect with multiple hostnames when one can't be resolved
* Upgrade puddle to remove dependency on uber/atomic and fix alignment issue on 32-bit platform
* Fix: scanning json column into **string
* Multiple reductions in memory allocations
* Fake non-blocking read adapts its max wait time
* Improve CopyFrom performance and reduce memory usage
* Fix: encode []any to array
* Fix: LoadType for composite with dropped attributes (Felix Röhrich)
* Support v4 and v5 stdlib in same program
* Fix: text format array decoding with string of "NULL"
* Prefer binary format for arrays
# 5.2.0 (December 5, 2022)
* `tracelog.TraceLog` implements the pgx.PrepareTracer interface. (Vitalii Solodilov)
* Optimize creating begin transaction SQL string (Petr Evdokimov and ksco)
* `Conn.LoadType` supports range and multirange types (Vitalii Solodilov)
* Fix scan `uint` and `uint64` `ScanNumeric`. This resolves a PostgreSQL `numeric` being incorrectly scanned into `uint` and `uint64`.
# 3.4.0 (May 3, 2019)
## Features
* Improved .pgpass handling (Dmitry Smal)
* Adds RowsAffected for CopyToWriter and CopyFromReader (Nikolay Vorobev)
* Support binding of []int type to array integer (David Bariod)
* Expose registered driver instance to aid integration with other libraries (PLATEL Kévin)
* Allow normal queries on replication connections (Jan Vcelak)
* Add support for creating a DB from pgx.Pool (fzerorubigd)
* SCRAM authentication
* pgtype.Date JSON marshal/unmarshal (Andrey Kuzmin)
## Fixes
* Fix encoding of ErrorResponse (Josh Leverette)
* Use more detailed error output of unknown field (Ilya Sivanev)
* "Temporary" Write errors no longer silently break connections.
* Fix PreferSimpleProtocol overwrite (Ilya Sinelnikov)
* Fix enum handling (Robert Lin)
* Copy protocol fixes (Andrey)
## Changes
* Do not attempt recovery from any Write error.
* Use LogLevel type instead of int for conn config
# 3.3.0 (December 1, 2018)
## Features
* Add CopyFromReader and CopyToWriter (Murat Kabilov)
* Add MacaddrArray (Anthony Regeda)
* Add float types to FieldDescription.Type (David Yamnitsky)
* Add CheckedOutConnections helper method (MOZGIII)
* Add host query parameter to support Unix sockets (Jörg Thalheim)
* Custom cancelation hook for use with PostgreSQL-like databases (James Hartig)
* Added LastStmtSent for safe retry logic (James Hartig)
## Fixes
* Do not silently ignore assigning NULL to \*string
* Fix issue with JSON and driver.Valuer conversion
* Fix race with stdlib Driver.configs Open (Greg Curtis)
## Changes
* Connection pool uses connections in queue order instead of stack. This
minimizes the time any one connection is idle relative to the others.
(Anthony Regeda)
* FieldDescription.Modifier is int32 instead of uint32
* tls: stop sending ssl_renegotiation_limit in startup message (Tejas Manohar)
# 3.2.0 (August 7, 2018)
## Features
* Support sslkey, sslcert, and sslrootcert URI params (Sean Chittenden)
* Allow any scheme in ParseURI (for convenience with cockroachdb) (Sean Chittenden)
* Add support for domain types
* Add zerolog logging adaptor (Justin Reagor)
* Add driver.Connector support / Go 1.10 support (James Lawrence)
* Allow nested database/sql/driver.Drivers (Jackson Owens)
* Support int64 and uint64 numeric array (Anthony Regeda)
* Add null support to pgtype.Bool (Tarik Demirci)
* Add types to decode error messages (Damir Vandic)
## Fixes
* Fix Rows.Values returning same value for multiple columns of same complex type
* Fix StartReplication() syntax (steampunkcoder)
* Fix precision loss for text format geometric types
* Allows scanning jsonb column into `*json.RawMessage`
* Allow recovery to savepoint in failed transaction
* Fix deadlock when CopyFromSource panics
* Include PreferSimpleProtocol in config Merge (Murat Kabilov)
## Changes
* pgtype.JSON(B).Value now returns []byte instead of string. This allows
database/sql to scan json(b) into \*json.RawMessage. This is a tiny behavior
change, but database/sql Scan should automatically convert []byte to string, so
there shouldn't be any incompatibility.
# 3.1.0 (January 15, 2018)
## Features
* Add QueryEx, QueryRowEx, ExecEx, and RollbackEx to Tx
* Add more ColumnType support (Timothée Peignier)
* Add UUIDArray type (Kelsey Francis)
* Add zap log adapter (Kelsey Francis)
* Add CreateReplicationSlotEx that returns consistent_point and snapshot_name (Mark Fletcher)
* Add BeginBatch to Tx (Gaspard Douady)
* Support CrateDB (Felix Geisendörfer)
* Allow use of logrus logger with fields configured (André Bierlein)
* Add array of enum support
* Add support for bit type
* Handle timeout parameters (Timothée Peignier)
* Allow overriding connection info (James Lawrence)
* Add support for bpchar type (Iurii Krasnoshchok)
* Add ConnConfig.PreferSimpleProtocol
## Fixes
* Fix numeric EncodeBinary bug (Wei Congrui)
* Fix logrus updated package name (Damir Vandic)
* Fix some invalid one round trip execs failing to return non-nil error. (Kelsey Francis)
* Return ErrClosedPool when Acquire() with closed pool (Mike Graf)
* Fix decoding row with same type values
* Always return non-nil \*Rows from Query to fix QueryRow (Kelsey Francis)
* Fix pgtype types that can Set a database/sql/driver.Valuer
* Prefix types in namespaces other than pg_catalog or public (Kelsey Francis)
* Fix incomplete selects during batch (Gaspard Douady and Jack Christensen)
* Support nil pointers to value implementing driver.Valuer
* Fix time logging for QueryEx
* Fix ranges with text format where end is unbounded
* Detect erroneous JSON(B) encoding
* Fix missing interval mapping
* ConnPool begin should not retry if ctx is done (Gaspard Douady)
* Fix reading interrupted messages could break connection
* Return error on unknown oid while decoding record instead of panic (Iurii Krasnoshchok)
# 5.1.1 (November 17, 2022)
* Fix simple query sanitizer where query text contains a Unicode replacement character.
* Remove erroneous `name` argument from `DeallocateAll()`. Technically, this is a breaking change, but given that method was only added 5 days ago this change was accepted. (Bodo Kaiser)
* Align sslmode "require" more closely to libpq (Johan Brandhorst)
# 5.1.0 (November 12, 2022)
* Update puddle to v2.1.2. This resolves a race condition and a deadlock in pgxpool.
* `QueryRewriter.RewriteQuery` now returns an error. Technically, this is a breaking change for any external implementers, but given the minimal likelihood that there are actually any external implementers this change was accepted.
* Expose `GetSSLPassword` support to pgx.
* Fix encode `ErrorResponse` unknown field handling. This would only affect pgproto3 being used directly as a proxy with a non-PostgreSQL server that included additional error fields.
* Fix date text format encoding with 5 digit years.
* Fix date values passed to a `sql.Scanner` as `string` instead of `time.Time`.
* DateCodec.DecodeValue can return `pgtype.InfinityModifier` instead of `string` for infinite values. This now matches the behavior of the timestamp types.
* Add domain type support to `Conn.LoadType()`.
* Add `RowToStructByName` and `RowToAddrOfStructByName`. (Pavlo Golub)
* Add `Conn.DeallocateAll()` to clear all prepared statements including the statement cache. (Bodo Kaiser)
# 5.0.4 (October 24, 2022)
* Fix: CollectOneRow prefers PostgreSQL error over pgx.ErrNoRows
* Fix: some reflect Kind checks to first check for nil
* Bump golang.org/x/text dependency to placate snyk
* Fix: RowToStructByPos on structs with multiple anonymous sub-structs (Baptiste Fontaine)
* Fix: Exec checks if tx is closed
# 5.0.3 (October 14, 2022)
* Fix `driver.Valuer` handling edge cases that could cause infinite loop or crash
# v5.0.2 (October 8, 2022)
* Fix date encoding in text format to always use 2 digits for month and day
* Prefer driver.Valuer over wrap plans when encoding
* Fix scan to pointer to pointer to renamed type
* Allow scanning NULL even if PG and Go types are incompatible
# v5.0.1 (September 24, 2022)
* Fix 32-bit atomic usage
* Add MarshalJSON for Float8 (yogipristiawan)
* Add `[` and `]` to text encoding of `Lseg`
* Fix sqlScannerWrapper NULL handling
# v5.0.0 (September 17, 2022)
## Merged Packages
`github.com/jackc/pgtype`, `github.com/jackc/pgconn`, and `github.com/jackc/pgproto3` are now included in the main
`github.com/jackc/pgx` repository. Previously there was confusion as to where issues should be reported, additional
release work due to releasing multiple packages, and less clear changelogs.
## pgconn
`CommandTag` is now an opaque type instead of directly exposing an underlying `[]byte`.
The return value `ResultReader.Values()` is no longer safe to retain a reference to after a subsequent call to `NextRow()` or `Close()`.
`Trace()` method adds low level message tracing similar to the `PQtrace` function in `libpq`.
pgconn now uses non-blocking IO. This is a significant internal restructuring, but it should not cause any visible changes on its own. However, it is important in implementing other new features.
`CheckConn()` checks a connection's liveness by doing a non-blocking read. This can be used to detect database restarts or network interruptions without executing a query or a ping.
pgconn now supports pipeline mode.
`*PgConn.ReceiveResults` removed. Use pipeline mode instead.
`Timeout()` no longer considers `context.Canceled` as a timeout error. `context.DeadlineExceeded` still is considered a timeout error.
## pgxpool
`Connect` and `ConnectConfig` have been renamed to `New` and `NewWithConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect.
## pgtype
The `pgtype` package has been significantly changed.
### NULL Representation
Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a
`Valid` `bool` field to harmonize with how `database/sql` represents `NULL` and to make the zero value useable.
Previously, a type that implemented `driver.Valuer` would have the `Value` method called even on a nil pointer. All nils,
whether typed or untyped, now represent `NULL`.
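As a minimal illustration (a sketch, not part of the original release notes), the zero value of a `pgtype` value such as `pgtype.Text` now represents NULL:

```go
package main

import (
	"fmt"

	"github.com/jackc/pgx/v5/pgtype"
)

func main() {
	var s pgtype.Text     // zero value: Valid is false, i.e. NULL
	fmt.Println(s.Valid)  // false

	s = pgtype.Text{String: "hello", Valid: true} // a present (non-NULL) value
	fmt.Println(s.String, s.Valid)                // hello true
}
```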
### Codec and Value Split
Previously, the type system combined decoding and encoding values with the value types. e.g. Type `Int8` both handled
encoding and decoding the PostgreSQL representation and acted as a value object. This caused some difficulties when
there was not an exact 1 to 1 relationship between the Go types and the PostgreSQL types. For example, scanning a
PostgreSQL binary `numeric` into a Go `float64` was awkward (see https://github.com/jackc/pgtype/issues/147). These
concepts have been separated. A `Codec` only has responsibility for encoding and decoding values. Value types are
generally defined by implementing an interface that a particular `Codec` understands (e.g. `PointScanner` and
`PointValuer` for the PostgreSQL `point` type).
### Array Types
All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This also
means that less common array types such as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional
arrays.
### Composite Types
Composite types must be registered before use. `CompositeFields` may still be used to construct and destruct composite
values, but any type may now implement `CompositeIndexGetter` and `CompositeIndexScanner` to be used as a composite.
### Range Types
Range types are now handled with types `RangeCodec` and `Range[T]`. This allows additional user defined range types to
easily be handled. Multirange types are handled similarly with `MultirangeCodec` and `Multirange[T]`.
### pgxtype
`LoadDataType` moved to `*Conn` as `LoadType`.
### Bytea
The `Bytea` and `GenericBinary` types have been replaced. Use the following instead:
* `[]byte` - For normal usage directly use `[]byte`.
* `DriverBytes` - Uses driver memory only available until next database method call. Avoids a copy and an allocation.
* `PreallocBytes` - Uses preallocated byte slice to avoid an allocation.
* `UndecodedBytes` - Avoids any decoding. Allows working with raw bytes.
### Dropped lib/pq Support
`pgtype` previously supported and was tested against [lib/pq](https://github.com/lib/pq). While it will continue to work
in most cases this is no longer supported.
### database/sql Scan
Previously, most `Scan` implementations would convert `[]byte` to `string` automatically to decode a text value. Now
only `string` is handled. This is to allow the possibility of future binary support in `database/sql` mode by
considering `[]byte` to be binary format and `string` text format. This change should have no effect for any use with
`pgx`. The previous behavior was only necessary for `lib/pq` compatibility.
Added `*Map.SQLScanner` to create a `sql.Scanner` for types such as `[]int32` and `Range[T]` that do not implement
`sql.Scanner` directly.
### Number Type Fields Include Bit Size
`Int2`, `Int4`, `Int8`, `Float4`, `Float8`, and `Uint32` fields now include bit size. e.g. `Int` is renamed to `Int64`.
This matches the convention set by `database/sql`. In addition, for comparable types like `pgtype.Int8` and
`sql.NullInt64` the structures are identical. This means they can be directly converted one to another.
### 3rd Party Type Integrations
* Extracted integrations with https://github.com/shopspring/decimal and https://github.com/gofrs/uuid to
https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims
the pgx dependency tree.
### Other Changes
* `Bit` and `Varbit` are both replaced by the `Bits` type.
* `CID`, `OID`, `OIDValue`, and `XID` are replaced by the `Uint32` type.
* `Hstore` is now defined as `map[string]*string`.
* `JSON` and `JSONB` types removed. Use `[]byte` or `string` directly.
* `QChar` type removed. Use `rune` or `byte` directly.
* `Inet` and `Cidr` types removed. Use `netip.Addr` and `netip.Prefix` directly. These types are more memory efficient than the previous `net.IPNet`.
* `Macaddr` type removed. Use `net.HardwareAddr` directly.
* Renamed `pgtype.ConnInfo` to `pgtype.Map`.
* Renamed `pgtype.DataType` to `pgtype.Type`.
* Renamed `pgtype.None` to `pgtype.Finite`.
* `RegisterType` now accepts a `*Type` instead of `Type`.
* Assorted array helper methods and types made private.
## stdlib
* Removed `AcquireConn` and `ReleaseConn` as that functionality has been built in since Go 1.13.
## Reduced Memory Usage by Reusing Read Buffers
Previously, the connection read buffer would allocate large chunks of memory and never reuse them. This allowed
transferring ownership to anything such as scanned values without incurring an additional allocation and memory copy.
However, this came at the cost of overall increased memory allocation size. Worse, it was also possible to pin large
chunks of memory by retaining a reference to a small value that originally came directly from the read buffer. Now
ownership remains with the read buffer and anything needing to retain a value must make a copy.
## Query Execution Modes
Control over automatic prepared statement caching and simple protocol use are now combined into query execution mode.
See documentation for `QueryExecMode`.
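For example, a minimal sketch of selecting a mode at connection time (the `DATABASE_URL` environment variable is a placeholder):

```go
package main

import (
	"context"
	"log"
	"os"

	"github.com/jackc/pgx/v5"
)

func main() {
	config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
	if err != nil {
		log.Fatal(err)
	}
	// Opt out of automatic prepared statements, e.g. when connecting
	// through a pooler in transaction pooling mode.
	config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol

	conn, err := pgx.ConnectConfig(context.Background(), config)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(context.Background())
}
```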
## QueryRewriter Interface and NamedArgs
pgx now supports named arguments with the `NamedArgs` type. This is implemented via the new `QueryRewriter` interface which
allows arbitrary rewriting of query SQL and arguments.
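A short sketch of `NamedArgs` in use (the table and column names are hypothetical):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// findWidgets shows NamedArgs: the @color and @max_weight placeholders are
// rewritten to positional parameters before the query is sent.
func findWidgets(ctx context.Context, conn *pgx.Conn) (pgx.Rows, error) {
	return conn.Query(ctx,
		"select id, name from widgets where color = @color and weight < @max_weight",
		pgx.NamedArgs{"color": "blue", "max_weight": 100},
	)
}
```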
## RowScanner Interface
The `RowScanner` interface allows a single argument to Rows.Scan to scan the entire row.
## Rows Result Helpers
* `CollectRows` and `RowTo*` functions simplify collecting results into a slice.
* `CollectOneRow` collects one row using `RowTo*` functions.
* `ForEachRow` simplifies scanning each row and executing code using the scanned values. `ForEachRow` replaces `QueryFunc`.
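For example, a hypothetical `Widget` query collected into a slice with `CollectRows` and `RowToStructByName`:

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

type Widget struct {
	ID   int64
	Name string
}

// loadWidgets maps each row to a Widget by column name. CollectRows closes
// rows and returns the accumulated slice.
func loadWidgets(ctx context.Context, conn *pgx.Conn) ([]Widget, error) {
	rows, err := conn.Query(ctx, "select id, name from widgets")
	if err != nil {
		return nil, err
	}
	return pgx.CollectRows(rows, pgx.RowToStructByName[Widget])
}
```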
## Tx Helpers
Rather than every type that implements `Begin` or `BeginTx` also needing to implement `BeginFunc` and
`BeginTxFunc`, these methods have been converted to functions that take a db that implements `Begin` or `BeginTx`.
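A sketch of the function form (the `accounts` table is hypothetical): `pgx.BeginFunc` commits if the passed function returns nil and rolls back otherwise:

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// transfer runs two updates in one transaction. An error from the inner
// function triggers a rollback; otherwise the transaction is committed.
func transfer(ctx context.Context, conn *pgx.Conn) error {
	return pgx.BeginFunc(ctx, conn, func(tx pgx.Tx) error {
		if _, err := tx.Exec(ctx, "update accounts set balance = balance - 100 where id = $1", 1); err != nil {
			return err
		}
		_, err := tx.Exec(ctx, "update accounts set balance = balance + 100 where id = $1", 2)
		return err
	})
}
```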
## Improved Batch Query Ergonomics
Previously, the code for building a batch went in one place before the call to `SendBatch`, and the code for reading the
results went in one place after the call to `SendBatch`. This could make it difficult to match up the query and the code
to handle the results. Now `Queue` returns a `QueuedQuery` which has methods `Query`, `QueryRow`, and `Exec` which can
be used to register a callback function that will handle the result. Callback functions are called automatically when
`BatchResults.Close` is called.
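A sketch of the callback style (hypothetical `widgets` table):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// countAfterInsert queues an insert and a count in one batch. The QueryRow
// callback runs when BatchResults.Close processes the results.
func countAfterInsert(ctx context.Context, conn *pgx.Conn) (int64, error) {
	var n int64
	batch := &pgx.Batch{}
	batch.Queue("insert into widgets (name) values ($1)", "lever")
	batch.Queue("select count(*) from widgets").QueryRow(func(row pgx.Row) error {
		return row.Scan(&n)
	})
	err := conn.SendBatch(ctx, batch).Close()
	return n, err
}
```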
## SendBatch Uses Pipeline Mode When Appropriate
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1
for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements
in a single network round trip. So it would only take 2 round trips.
## Tracing and Logging
Internal logging support has been replaced with tracing hooks. This allows custom tracing integration with tools like OpenTelemetry. Package tracelog provides an adapter for pgx v4 loggers to act as a tracer.
All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency
tree.
# 3.0.1 (August 12, 2017)
## Fixes
* Fix compilation on 32-bit platform
* Fix invalid MarshalJSON of types with status Undefined
* Fix pid logging
# 3.0.0 (July 24, 2017)
## Changes
* Pid to PID in accordance with Go naming conventions.
* Conn.Pid changed to accessor method Conn.PID()
* Conn.SecretKey removed
* Remove Conn.TxStatus
* Logger interface reduced to single Log method.
* Replace BeginIso with BeginEx. BeginEx adds support for read/write mode and deferrable mode.
* Transaction isolation level constants are now typed strings instead of bare strings.
* Conn.WaitForNotification now takes context.Context instead of time.Duration for cancellation support.
* Conn.WaitForNotification no longer automatically pings internally every 15 seconds.
* ReplicationConn.WaitForReplicationMessage now takes context.Context instead of time.Duration for cancellation support.
* Reject scanning binary format values into a string (e.g. binary encoded timestamptz to string). See https://github.com/jackc/pgx/issues/219 and https://github.com/jackc/pgx/issues/228
* No longer can read raw bytes of any value into a []byte. Use pgtype.GenericBinary if this functionality is needed.
* Remove CopyTo (functionality is now in CopyFrom)
* OID constants moved from pgx to pgtype package
* Replaced Scanner, Encoder, and PgxScanner interfaces with pgtype system
* Removed ValueReader
* ConnPool.Close no longer waits for all acquired connections to be released. Instead, it immediately closes all available connections, and closes acquired connections when they are released in the same manner as ConnPool.Reset.
* Removed Rows.Fatal(error)
* Removed Rows.AfterClose()
* Removed Rows.Conn()
* Removed Tx.AfterClose()
* Removed Tx.Conn()
* Use Go casing convention for OID, UUID, JSON(B), ACLItem, CID, TID, XID, and CIDR
* Replaced stdlib.OpenFromConnPool with DriverConfig system
## Features
* Entirely revamped pluggable type system that supports approximately 60 PostgreSQL types.
* Types support database/sql interfaces and therefore can be used with other drivers
* Added context methods supporting cancellation where appropriate
* Added simple query protocol support
* Added single round-trip query mode
* Added batch query operations
* Added OnNotice
* github.com/pkg/errors used where possible for errors
* Added stdlib.DriverConfig which directly allows full configuration of underlying pgx connections without needing to use a pgx.ConnPool
* Added AcquireConn and ReleaseConn to stdlib to allow acquiring a connection from a database/sql connection.
# 2.11.0 (June 5, 2017)
## Fixes
* Fix race with concurrent execution of stdlib.OpenFromConnPool (Terin Stock)
## Features
* .pgpass support (j7b)
* Add missing CopyFrom delegators to Tx and ConnPool (Jack Christensen)
* Add ParseConnectionString (James Lawrence)
## Performance
* Optimize HStore encoding (René Kroon)
# 2.10.0 (March 17, 2017)
## Fixes
* database/sql driver created through stdlib.OpenFromConnPool closes connections when requested by database/sql rather than releasing them to the underlying connection pool.
* Oid underlying type changed to uint32, previously it was incorrectly int32 (Manni Wood)
* Explicitly close checked-in connections on ConnPool.Reset, previously they were closed by GC
## Features
* Add xid type support (Manni Wood)
* Add cid type support (Manni Wood)
* Add tid type support (Manni Wood)
* Add "char" type support (Manni Wood)
* Add NullOid type (Manni Wood)
* Add json/jsonb binary support to allow use with CopyTo
* Add named error ErrAcquireTimeout (Alexander Staubo)
* Add logical replication decoding (Kris Wehner)
* Add PgxScanner interface to allow types to simultaneously support database/sql and pgx (Jack Christensen)
* Add CopyFrom with schema support (Jack Christensen)
## Compatibility
* jsonb now defaults to binary format. This means passing a []byte to a jsonb column will no longer work.
* CopyTo is now deprecated but will continue to work.
# 2.9.0 (August 26, 2016)
## Fixes
* Fix *ConnPool.Deallocate() not deleting prepared statement from map
* Fix stdlib not logging unprepared query SQL (Krzysztof Dryś)
* Fix Rows.Values() with varchar binary format
* Concurrent ConnPool.Acquire calls with Dialer timeouts now timeout in the expected amount of time (Konstantin Dzreev)
## Features
* Add CopyTo
* Add PrepareEx
* Add basic record to []interface{} decoding
* Encode and decode between all Go and PostgreSQL integer types with bounds checking
* Decode inet/cidr to net.IP
* Encode/decode [][]byte to/from bytea[]
* Encode/decode named types whose underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64
## Performance
* Substantial reduction in memory allocations
# 2.8.1 (March 24, 2016)
## Features
* Scan accepts nil argument to ignore a column
## Fixes
* Fix compilation on 32-bit architecture
* Fix Tx.status not being set on error on Commit
* Fix Listen/Unlisten with special characters
# 2.8.0 (March 18, 2016)
## Fixes
* Fix unrecognized commit failure
* Fix msgReader.rxMsg bug when msgReader already has error
* Go float64 can no longer be encoded to a PostgreSQL float4
* Fix connection corruption when query with error is closed early
## Features
This release adds multiple extension points helpful when wrapping pgx with
custom application behavior. pgx can now use custom types designed for the
standard database/sql package such as
[github.com/shopspring/decimal](https://github.com/shopspring/decimal).
* Add *Tx.AfterClose() hook
* Add *Tx.Conn()
* Add *Tx.Status()
* Add *Tx.Err()
* Add *Rows.AfterClose() hook
* Add *Rows.Conn()
* Add *Conn.SetLogger() to allow changing logger
* Add *Conn.SetLogLevel() to allow changing log level
* Add ConnPool.Reset method
* Add support for database/sql.Scanner and database/sql/driver.Valuer interfaces
* Rows.Scan errors now include which argument caused error
* Add Encode() to allow custom Encoders to reuse internal encoding functionality
* Add Decode() to allow custom Decoders to reuse internal decoding functionality
* Add ConnPool.Prepare method
* Add ConnPool.Deallocate method
* Add Scan to uint32 and uint64 (utrack)
* Add encode and decode to []uint16, []uint32, and []uint64 (Max Musatov)
## Performance
* []byte skips encoding/decoding
# 2.7.1 (October 26, 2015)
* Disable SSL renegotiation
# 2.7.0 (October 16, 2015)
* Add RuntimeParams to ConnConfig
* ParseURI extracts RuntimeParams
* ParseDSN extracts RuntimeParams
* ParseEnvLibpq extracts PGAPPNAME
* Prepare is now idempotent
* Rows.Values now supports oid type
* ConnPool.Release automatically unlistens connections (Joseph Glanville)
* Add trace log level
* Add more efficient log leveling
* Retry automatically on ConnPool.Begin (Joseph Glanville)
* Encode from net.IP to inet and cidr
* Generalize encoding pointer to string to any PostgreSQL type
* Add UUID encoding from pointer to string (Joseph Glanville)
* Add null mapping to pointer to pointer (Jonathan Rudenberg)
* Add JSON and JSONB type support (Joseph Glanville)
# 2.6.0 (September 3, 2015)
* Add inet and cidr type support
* Add binary decoding to TimestampOid in stdlib driver (Samuel Stauffer)
* Add support for specifying sslmode in connection strings (Rick Snyder)
* Allow ConnPool to have MaxConnections of 1
* Add basic PGSSLMODE support to ParseEnvLibpq
* Add fallback TLS config
* Expose specific error for TLS refused
* More error details exposed in PgError
* Support custom dialer (Lewis Marshall)
# 2.5.0 (April 15, 2015)
* Fix stdlib nil support (Blaž Hrastnik)
* Support custom Scanner not reading entire value
* Fix empty array scanning (Laurent Debacker)
* Add ParseDSN (deoxxa)
* Add timestamp support to NullTime
* Remove unused text format scanners
* Return error when too many parameters on Prepare
* Add Travis CI integration (Jonathan Rudenberg)
* Large object support (Jonathan Rudenberg)
* Fix reading null byte arrays (Karl Seguin)
* Add timestamptz[] support
* Add timestamp[] support (Karl Seguin)
* Add bool[] support (Karl Seguin)
* Allow writing []byte into text and varchar columns without type conversion (Hari Bhaskaran)
* Fix ConnPool Close panic
* Add Listen / notify example
* Reduce memory allocations (Karl Seguin)
# 2.4.0 (October 3, 2014)
* Add per connection oid to name map
* Add Hstore support (Andy Walker)
* Move introductory docs to godoc from readme
* Fix documentation references to TextEncoder and BinaryEncoder
* Add keep-alive to TCP connections (Andy Walker)
* Add support for EmptyQueryResponse / Allow no-op Exec (Andy Walker)
* Allow reading any type into []byte
* WaitForNotification detects lost connections quicker
# 2.3.0 (September 16, 2014)
* Truncate logged strings and byte slices
* Extract more error information from PostgreSQL
* Fix data race with Rows and ConnPool


@@ -1,121 +0,0 @@
# Contributing
## Discuss Significant Changes
Before you invest a significant amount of time on a change, please create a discussion or issue describing your
proposal. This will help to ensure your proposed change has a reasonable chance of being merged.
## Avoid Dependencies
Adding a dependency is a big deal. While on occasion a new dependency may be accepted, the default answer to any change
that adds a dependency is no.
## Development Environment Setup
pgx tests naturally require a PostgreSQL database. The tests connect to the database specified in the `PGX_TEST_DATABASE`
environment variable, which can be either a URL or key-value pairs. In addition,
the standard `PG*` environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to
simplify environment variable handling.
### Using an Existing PostgreSQL Cluster
If you already have a PostgreSQL development server this is the quickest way to start and run the majority of the pgx
test suite. Some tests will be skipped that require server configuration changes (e.g. those testing different
authentication methods).
Create and set up a test database:
```
export PGDATABASE=pgx_test
createdb
psql -c 'create extension hstore;'
psql -c 'create extension ltree;'
psql -c 'create domain uint64 as numeric(20,0);'
```
Ensure a `postgres` user exists. This happens by default in normal PostgreSQL installs, but some installation methods,
such as Homebrew, do not create one.
```
createuser -s postgres
```
Ensure your `PGX_TEST_DATABASE` environment variable points to the database you just created and run the tests.
```
export PGX_TEST_DATABASE="host=/private/tmp database=pgx_test"
go test ./...
```
This will run the vast majority of the tests, but some tests will be skipped (e.g. those testing different connection methods).
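As an illustrative sketch (this helper is not part of the actual test suite), a test can connect using the same variable:

```go
package example

import (
	"context"
	"os"
	"testing"

	"github.com/jackc/pgx/v5"
)

// connectForTest connects to whatever database PGX_TEST_DATABASE points at
// and closes the connection when the test finishes.
func connectForTest(t *testing.T) *pgx.Conn {
	conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { conn.Close(context.Background()) })
	return conn
}
```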
### Creating a New PostgreSQL Cluster Exclusively for Testing
The following environment variables need to be set both for initial setup and whenever the tests are run. (direnv is
highly recommended). Depending on your platform, you may need to change the host for `PGX_TEST_UNIX_SOCKET_CONN_STRING`.
```
export PGPORT=5015
export PGUSER=postgres
export PGDATABASE=pgx_test
export POSTGRESQL_DATA_DIR=postgresql
export PGX_TEST_DATABASE="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/private/tmp database=pgx_test"
export PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test"
export PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_pw password=secret"
export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem"
export PGX_SSL_PASSWORD=certpw
export PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test sslcert=`pwd`/.testdb/pgx_sslcert.crt sslkey=`pwd`/.testdb/pgx_sslcert.key"
```
Create a new database cluster.
```
initdb --locale=en_US -E UTF-8 --username=postgres .testdb/$POSTGRESQL_DATA_DIR
echo "listen_addresses = '127.0.0.1'" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
echo "port = $PGPORT" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
cat testsetup/postgresql_ssl.conf >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
cp testsetup/pg_hba.conf .testdb/$POSTGRESQL_DATA_DIR/pg_hba.conf
cd .testdb
# Generate CA, server, and encrypted client certificates.
go run ../testsetup/generate_certs.go
# Copy certificates to server directory and set permissions.
cp ca.pem $POSTGRESQL_DATA_DIR/root.crt
cp localhost.key $POSTGRESQL_DATA_DIR/server.key
chmod 600 $POSTGRESQL_DATA_DIR/server.key
cp localhost.crt $POSTGRESQL_DATA_DIR/server.crt
cd ..
```
Start the new cluster. This will be necessary whenever you are running pgx tests.
```
postgres -D .testdb/$POSTGRESQL_DATA_DIR
```
Set up the test database in the new cluster.
```
createdb
psql --no-psqlrc -f testsetup/postgresql_setup.sql
```
### PgBouncer
There are tests specific to PgBouncer that will be executed if `PGX_TEST_PGBOUNCER_CONN_STRING` is set.
### Optional Tests
pgx supports multiple connection types and means of authentication. These tests are optional. They will only run if the
appropriate environment variables are set. In addition, there may be tests specific to particular PostgreSQL versions,
non-PostgreSQL servers (e.g. CockroachDB), or connection poolers (e.g. PgBouncer). Run `go test ./... -v | grep SKIP` to
see if any tests are being skipped.


@@ -1,4 +1,4 @@
Copyright (c) 2013-2021 Jack Christensen
Copyright (c) 2013 Jack Christensen
MIT License
@@ -19,4 +19,4 @@ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

README.md

@@ -1,191 +1,152 @@
[![Go Reference](https://pkg.go.dev/badge/github.com/jackc/pgx/v5.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5)
[![Build Status](https://github.com/jackc/pgx/actions/workflows/ci.yml/badge.svg)](https://github.com/jackc/pgx/actions/workflows/ci.yml)
[![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://godoc.org/github.com/jackc/pgx)
[![Build Status](https://travis-ci.org/jackc/pgx.svg)](https://travis-ci.org/jackc/pgx)
# pgx - PostgreSQL Driver and Toolkit
pgx is a pure Go driver and toolkit for PostgreSQL.
The pgx driver is a low-level, high performance interface that exposes PostgreSQL-specific features such as `LISTEN` /
`NOTIFY` and `COPY`. It also includes an adapter for the standard `database/sql` interface.
The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol
and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers,
proxies, load balancers, logical replication clients, etc.
pgx is different from other drivers such as [pq](http://godoc.org/github.com/lib/pq) because, while it can operate as a database/sql compatible driver, pgx is also usable directly. It offers a native interface similar to database/sql that offers better performance and more features.
## Example Usage
```go
package main
import (
"context"
"fmt"
"os"
"github.com/jackc/pgx/v5"
)
func main() {
// urlExample := "postgres://username:password@localhost:5432/database_name"
conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to connect to database: %v\n", err)
os.Exit(1)
}
defer conn.Close(context.Background())
var name string
var weight int64
err = conn.QueryRow(context.Background(), "select name, weight from widgets where id=$1", 42).Scan(&name, &weight)
if err != nil {
fmt.Fprintf(os.Stderr, "QueryRow failed: %v\n", err)
os.Exit(1)
}
fmt.Println(name, weight)
}
```
See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information.
## Features
pgx supports many additional features beyond what is available through `database/sql`:

* Support for approximately 70 different PostgreSQL types
* Automatic statement preparation and caching
* Batch queries
* Single-round trip query mode
* Full TLS connection control
* Binary format support for custom types (allows for much quicker encoding/decoding)
* `COPY` protocol support for faster bulk data loads
* Tracing and logging support
* Connection pool with after-connect hook for arbitrary connection setup
* `LISTEN` / `NOTIFY`
* Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings
* `hstore` support
* `json` and `jsonb` support
* Maps `inet` and `cidr` PostgreSQL types to `netip.Addr` and `netip.Prefix`
* Large object support
* NULL mapping to pointer to pointer
* Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types
* Notice response handling
* Simulated nested transactions with savepoints
## Choosing Between the pgx and database/sql Interfaces

The pgx interface is faster. Many PostgreSQL specific features such as `LISTEN` / `NOTIFY` and `COPY` are not available
through the `database/sql` interface.

The pgx interface is recommended when:

1. The application only targets PostgreSQL.
2. No other libraries that require `database/sql` are in use.

It is also possible to use the `database/sql` interface and convert a connection to the lower-level pgx interface as needed.
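For example, the standard library interface can be used via the stdlib adapter (a minimal sketch; `github.com/jackc/pgx/v5/stdlib` registers the `pgx` driver name, and `DATABASE_URL` is an assumed environment variable):

```go
package main

import (
	"database/sql"
	"log"
	"os"

	_ "github.com/jackc/pgx/v5/stdlib" // registers the "pgx" database/sql driver
)

func main() {
	db, err := sql.Open("pgx", os.Getenv("DATABASE_URL"))
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	var greeting string
	if err := db.QueryRow("select 'hello'").Scan(&greeting); err != nil {
		log.Fatal(err)
	}
	log.Println(greeting)
}
```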
## Testing

See [CONTRIBUTING.md](./CONTRIBUTING.md) for setup instructions.

## Architecture

See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.com/watch?v=sXMSWhcHCf8) for a description of pgx architecture.

## Supported Go and PostgreSQL Versions

pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.23 and higher and PostgreSQL 13 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).

## Version Policy

pgx follows semantic versioning for the documented public API on stable releases. `v5` is the latest stable major version.

## PGX Family Libraries

### [github.com/jackc/pglogrepl](https://github.com/jackc/pglogrepl)

pglogrepl provides functionality to act as a client for PostgreSQL logical replication.

### [github.com/jackc/pgmock](https://github.com/jackc/pgmock)

pgmock offers the ability to create a server that mocks the PostgreSQL wire protocol. This is used internally to test pgx by purposely inducing unusual errors. pgproto3 and pgmock together provide most of the foundational tooling required to implement a PostgreSQL proxy or MitM (such as for a custom connection pooler).

### [github.com/jackc/tern](https://github.com/jackc/tern)

tern is a stand-alone SQL migration system.

### [github.com/jackc/pgerrcode](https://github.com/jackc/pgerrcode)

pgerrcode contains constants for the PostgreSQL error codes.

## Adapters for 3rd Party Types

* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
* [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos))
* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)

## Adapters for 3rd Party Tracers

* [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer)
* [github.com/exaring/otelpgx](https://github.com/exaring/otelpgx)

## Adapters for 3rd Party Loggers

These adapters can be used with the tracelog package.

* [github.com/jackc/pgx-go-kit-log](https://github.com/jackc/pgx-go-kit-log)
* [github.com/jackc/pgx-log15](https://github.com/jackc/pgx-log15)
* [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus)
* [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap)
* [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog)
* [github.com/mcosta74/pgx-slog](https://github.com/mcosta74/pgx-slog)
* [github.com/kataras/pgx-golog](https://github.com/kataras/pgx-golog)

## 3rd Party Libraries with PGX Support

### [github.com/pashagolub/pgxmock](https://github.com/pashagolub/pgxmock)

pgxmock is a mock library implementing pgx interfaces. pgxmock has one and only purpose - to simulate pgx behavior in tests, without needing a real database connection.

### [github.com/georgysavva/scany](https://github.com/georgysavva/scany)

Library for scanning data from a database into Go structs and more.

### [github.com/vingarcia/ksql](https://github.com/vingarcia/ksql)

A carefully designed SQL client for making using SQL easier, more productive, and less error-prone on Golang.

### [github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)

Adds GSSAPI / Kerberos authentication support.

### [github.com/wcamarao/pmx](https://github.com/wcamarao/pmx)

Explicit data mapping and scanning library for Go structs and slices.

### [github.com/stephenafamo/scan](https://github.com/stephenafamo/scan)

Type safe and flexible package for scanning database data into Go types. Supports structs, maps, slices and custom mapping functions.

### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)

Code first migration library for native pgx (no database/sql abstraction).

### [github.com/amirsalarsafaei/sqlc-pgx-monitoring](https://github.com/amirsalarsafaei/sqlc-pgx-monitoring)

A database monitoring/metrics library for pgx and sqlc. Trace, log and monitor your sqlc query performance using OpenTelemetry.

### [github.com/nikolayk812/pgx-outbox](https://github.com/nikolayk812/pgx-outbox)

Simple Golang implementation of the transactional outbox pattern for PostgreSQL using the jackc/pgx driver.

### [github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy)

Simplifies working with the pgx library, providing convenient scanning of nested structures.

### [github.com/KoNekoD/pgx-colon-query-rewriter](https://github.com/KoNekoD/pgx-colon-query-rewriter)

Implementation of a pgx query rewriter to use ':' instead of '@' in named query parameters.


@ -1,18 +0,0 @@
require "erb"
rule '.go' => '.go.erb' do |task|
erb = ERB.new(File.read(task.source))
File.write(task.name, "// Code generated from #{task.source}. DO NOT EDIT.\n\n" + erb.result(binding))
sh "goimports", "-w", task.name
end
generated_code_files = [
"pgtype/int.go",
"pgtype/int_test.go",
"pgtype/integration_benchmark_test.go",
"pgtype/zeronull/int.go",
"pgtype/zeronull/int_test.go"
]
desc "Generate code"
task generate: generated_code_files


@ -10,7 +10,7 @@
// https://github.com/lib/pq/pull/788
// https://github.com/lib/pq/pull/833
package pgconn
package pgx
import (
"bytes"
@ -22,7 +22,7 @@ import (
"fmt"
"strconv"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/jackc/pgx/pgproto3"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/text/secure/precis"
)
@ -30,7 +30,7 @@ import (
const clientNonceLen = 18
// Perform SCRAM authentication.
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
func (c *Conn) scramAuth(serverAuthMechanisms []string) error {
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
if err != nil {
return err
@ -41,18 +41,17 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
AuthMechanism: "SCRAM-SHA-256",
Data: sc.clientFirstMessage(),
}
c.frontend.Send(saslInitialResponse)
err = c.flushWithPotentialWriteReadDeadlock()
_, err = c.conn.Write(saslInitialResponse.Encode(nil))
if err != nil {
return err
}
// Receive server-first-message payload in an AuthenticationSASLContinue.
saslContinue, err := c.rxSASLContinue()
// Receive server-first-message payload in a AuthenticationSASLContinue.
authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue)
if err != nil {
return err
}
err = sc.recvServerFirstMessage(saslContinue.Data)
err = sc.recvServerFirstMessage(authMsg.SASLData)
if err != nil {
return err
}
@ -61,48 +60,33 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
saslResponse := &pgproto3.SASLResponse{
Data: []byte(sc.clientFinalMessage()),
}
c.frontend.Send(saslResponse)
err = c.flushWithPotentialWriteReadDeadlock()
_, err = c.conn.Write(saslResponse.Encode(nil))
if err != nil {
return err
}
// Receive server-final-message payload in an AuthenticationSASLFinal.
saslFinal, err := c.rxSASLFinal()
// Receive server-final-message payload in a AuthenticationSASLFinal.
authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal)
if err != nil {
return err
}
return sc.recvServerFinalMessage(saslFinal.Data)
return sc.recvServerFinalMessage(authMsg.SASLData)
}
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
msg, err := c.receiveMessage()
func (c *Conn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) {
msg, err := c.rxMsg()
if err != nil {
return nil, err
}
switch m := msg.(type) {
case *pgproto3.AuthenticationSASLContinue:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
authMsg, ok := msg.(*pgproto3.Authentication)
if !ok {
return nil, errors.New("unexpected message type")
}
if authMsg.Type != typ {
return nil, errors.New("unexpected auth type")
}
return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
}
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
switch m := msg.(type) {
case *pgproto3.AuthenticationSASLFinal:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
return authMsg, nil
}
type scramClient struct {
@ -198,12 +182,12 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
var err error
sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
if err != nil {
return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
return fmt.Errorf("invalid SCRAM salt received from server: %v", err)
}
sc.iterations, err = strconv.Atoi(string(iterationsStr))
if err != nil || sc.iterations <= 0 {
return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
return fmt.Errorf("invalid SCRAM iteration count received from server: %s", iterationsStr)
}
if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
@ -263,9 +247,9 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte {
return buf
}
func computeServerSignature(saltedPassword, authMessage []byte) []byte {
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
serverSignature := computeHMAC(serverKey, authMessage)
serverSignature := computeHMAC(serverKey[:], authMessage)
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
base64.StdEncoding.Encode(buf, serverSignature)
return buf
}
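For reference, the computation above follows RFC 5802: ServerSignature = HMAC(ServerKey, AuthMessage), where ServerKey = HMAC(SaltedPassword, "Server Key"). A standalone sketch using the standard library (computeHMAC itself is not shown in this hunk, so HMAC-SHA-256 is assumed, and the inputs are illustrative):

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
)

func hmacSHA256(key, msg []byte) []byte {
	h := hmac.New(sha256.New, key)
	h.Write(msg)
	return h.Sum(nil)
}

func main() {
	// Illustrative inputs; in the real client saltedPassword comes from PBKDF2
	// and authMessage from the concatenated SCRAM messages.
	saltedPassword := []byte("example salted password")
	authMessage := []byte("example auth message")

	serverKey := hmacSHA256(saltedPassword, []byte("Server Key"))
	serverSignature := hmacSHA256(serverKey, authMessage)
	fmt.Println(base64.StdEncoding.EncodeToString(serverSignature))
}
```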

batch.go

@ -2,468 +2,310 @@ package pgx
import (
"context"
"errors"
"fmt"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
// QueuedQuery is a query that has been queued for execution via a Batch.
type QueuedQuery struct {
SQL string
Arguments []any
Fn batchItemFunc
sd *pgconn.StatementDescription
}
type batchItemFunc func(br BatchResults) error
// Query sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
qq.Fn = func(br BatchResults) error {
rows, _ := br.Query()
defer rows.Close()
err := fn(rows)
if err != nil {
return err
}
rows.Close()
return rows.Err()
}
}
// QueryRow sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
qq.Fn = func(br BatchResults) error {
row := br.QueryRow()
return fn(row)
}
}
// Exec sets fn to be called when the response to qq is received.
//
// Note: for simple batch insert uses where it is not required to handle
// each potential error individually, it's sufficient to not set any callbacks,
// and just handle the return value of BatchResults.Close.
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
qq.Fn = func(br BatchResults) error {
ct, err := br.Exec()
if err != nil {
return err
}
return fn(ct)
}
}
type batchItem struct {
query string
arguments []interface{}
parameterOIDs []pgtype.OID
resultFormatCodes []int16
}
// Batch queries are a way of bundling multiple queries together to avoid
// unnecessary network round trips. A Batch must only be sent once.
type Batch struct {
QueuedQueries []*QueuedQuery
conn *Conn
connPool *ConnPool
items []*batchItem
resultsRead int
pendingCommandComplete bool
ctx context.Context
err error
inTx bool
}
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. The only pgx option
// argument that is supported is QueryRewriter. Queries are executed using the connection's DefaultQueryExecMode.
// BeginBatch returns a *Batch query for c.
func (c *Conn) BeginBatch() *Batch {
return &Batch{conn: c}
}
// BeginBatch returns a *Batch query for tx. Since this *Batch is already part
// of a transaction it will not automatically be wrapped in a transaction.
func (tx *Tx) BeginBatch() *Batch {
return &Batch{conn: tx.conn, inTx: true}
}
// Conn returns the underlying connection that b will or was performed on.
func (b *Batch) Conn() *Conn {
return b.conn
}
// Queue queues a query to batch b. parameterOIDs are required if there are
// parameters and query is not the name of a prepared statement.
// resultFormatCodes are required if there is a result.
func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgtype.OID, resultFormatCodes []int16) {
b.items = append(b.items, &batchItem{
query: query,
arguments: arguments,
parameterOIDs: parameterOIDs,
resultFormatCodes: resultFormatCodes,
})
}
// Send sends all queued queries to the server at once.
// If the batch is created from a conn object then all queries are wrapped
// in a transaction. The transaction can optionally be configured with
// txOptions. The context is in effect until the Batch is closed.
//
// While query can contain multiple statements if the connection's DefaultQueryExecMode is QueryModeSimple, this should
// be avoided. QueuedQuery.Fn must not be set as it will only be called for the first query. That is, QueuedQuery.Query,
// QueuedQuery.QueryRow, and QueuedQuery.Exec must not be called. In addition, any error messages or tracing that
// include the current query may reference the wrong query.
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
qq := &QueuedQuery{
SQL: query,
Arguments: arguments,
}
b.QueuedQueries = append(b.QueuedQueries, qq)
return qq
}
// Len returns number of queries that have been queued so far.
func (b *Batch) Len() int {
return len(b.QueuedQueries)
}
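To make the callback API above concrete, here is a minimal usage sketch. It assumes `conn` is an open v5 *pgx.Conn, `ctx` is a context, a `widgets` table exists, and that the batch is submitted with `Conn.SendBatch` (which is not part of this file):

```go
b := &pgx.Batch{}
b.Queue("insert into widgets(name) values($1)", "sprocket") // no callback; errors surface from Close
b.Queue("select count(*) from widgets").QueryRow(func(row pgx.Row) error {
	var n int64
	return row.Scan(&n) // invoked while results are being read
})

// SendBatch submits the batch; Close reads every result and runs the
// callbacks registered above, returning the first error encountered.
if err := conn.SendBatch(ctx, b).Close(); err != nil {
	log.Fatal(err)
}
```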
type BatchResults interface {
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer
// calling Exec on the QueuedQuery, or just calling Close.
Exec() (pgconn.CommandTag, error)
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer
// calling Query on the QueuedQuery.
Query() (Rows, error)
// QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow.
// Prefer calling QueryRow on the QueuedQuery.
QueryRow() Row
// Close closes the batch operation. All unread results are read and any callback functions registered with
// QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an
// error or the batch encounters an error subsequent callback functions will not be called.
//
// For simple batch inserts inside a transaction or similar queries, it's sufficient to not set any callbacks,
// and just handle the return value of Close.
//
// Close must be called before the underlying connection can be used again. Any error that occurred during a batch
// operation may have made it impossible to resynchronize the connection with the server. In this case the underlying
// connection will have been closed.
//
// Close is safe to call multiple times. If it returns an error subsequent calls will return the same error. Callback
// functions will not be rerun.
Close() error
}
type batchResults struct {
ctx context.Context
conn *Conn
mrr *pgconn.MultiResultReader
err error
b *Batch
qqIdx int
closed bool
endTraced bool
}
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
func (br *batchResults) Exec() (pgconn.CommandTag, error) {
if br.err != nil {
return pgconn.CommandTag{}, br.err
}
if br.closed {
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
// Warning: Send writes all queued queries before reading any results. This can
// cause a deadlock if an excessive number of queries are queued. It is highly
// advisable to use a timeout context to protect against this possibility.
// Unfortunately, this excessive number can vary based on operating system,
// connection type (TCP or Unix domain socket), and type of query. Unix domain
// sockets seem to be much more susceptible to this issue than TCP connections.
// However, it usually is at least several thousand.
//
// The deadlock occurs when the batched queries to be sent are so large that the
// PostgreSQL server cannot receive it all at once. PostgreSQL received some of
// the queued queries and starts executing them. As PostgreSQL executes the
// queries it sends responses back. pgx will not read any of these responses
// until it has finished sending. Therefore, if all network buffers are full pgx
// will not be able to finish sending the queries and PostgreSQL will not be
// able to finish sending the responses.
//
// See https://github.com/jackc/pgx/issues/374.
func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
if b.err != nil {
return b.err
}
query, arguments, _ := br.nextQueryAndArgs()
b.ctx = ctx
if !br.mrr.NextResult() {
err := br.mrr.Close()
if err == nil {
err = errors.New("no more results in batch")
}
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
Err: err,
})
}
return pgconn.CommandTag{}, err
}
commandTag, err := br.mrr.ResultReader().Close()
err := b.conn.waitForPreviousCancelQuery(ctx)
if err != nil {
br.err = err
br.mrr.Close()
return err
}
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
CommandTag: commandTag,
Err: br.err,
})
if err := b.conn.ensureConnectionReadyForQuery(); err != nil {
return err
}
return commandTag, br.err
buf := b.conn.wbuf
if !b.inTx {
buf = appendQuery(buf, txOptions.beginSQL())
}
err = b.conn.initContext(ctx)
if err != nil {
return err
}
for _, bi := range b.items {
var psName string
var psParameterOIDs []pgtype.OID
if ps, ok := b.conn.preparedStatements[bi.query]; ok {
psName = ps.Name
psParameterOIDs = ps.ParameterOIDs
} else {
psParameterOIDs = bi.parameterOIDs
buf = appendParse(buf, "", bi.query, psParameterOIDs)
}
var err error
buf, err = appendBind(buf, "", psName, b.conn.ConnInfo, psParameterOIDs, bi.arguments, bi.resultFormatCodes)
if err != nil {
return err
}
buf = appendDescribe(buf, 'P', "")
buf = appendExecute(buf, "", 0)
}
buf = appendSync(buf)
b.conn.pendingReadyForQueryCount++
if !b.inTx {
buf = appendQuery(buf, "commit")
b.conn.pendingReadyForQueryCount++
}
_, err = b.conn.conn.Write(buf)
if err != nil {
b.conn.die(err)
return err
}
for !b.inTx {
msg, err := b.conn.rxMsg()
if err != nil {
return err
}
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
return nil
default:
if err := b.conn.processContextFreeMsg(msg); err != nil {
return err
}
}
}
return nil
}
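Given the warning above, a caller might bound a large batch with a timeout context. A sketch against the v3-era API in this file (it assumes `conn` is an open *pgx.Conn, a table `t` exists, and that `pgtype.Int4OID` is the appropriate parameter OID; error handling is abbreviated):

```go
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

batch := conn.BeginBatch()
for i := 0; i < 10000; i++ {
	batch.Queue("insert into t(n) values($1)",
		[]interface{}{i},
		[]pgtype.OID{pgtype.Int4OID},
		nil, // no result rows expected
	)
}

// The timeout context protects against the send/receive deadlock described above.
if err := batch.Send(ctx, &pgx.TxOptions{}); err != nil {
	log.Fatal(err)
}
if err := batch.Close(); err != nil {
	log.Fatal(err)
}
```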
// Query reads the results from the next query in the batch as if the query has been sent with Query.
func (br *batchResults) Query() (Rows, error) {
query, arguments, ok := br.nextQueryAndArgs()
if !ok {
query = "batch query"
// ExecResults reads the results from the next query in the batch as if the
// query has been sent with Exec.
func (b *Batch) ExecResults() (CommandTag, error) {
if b.err != nil {
return "", b.err
}
if br.err != nil {
return &baseRows{err: br.err, closed: true}, br.err
select {
case <-b.ctx.Done():
b.die(b.ctx.Err())
return "", b.ctx.Err()
default:
}
if br.closed {
alreadyClosedErr := fmt.Errorf("batch already closed")
return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
if err := b.ensureCommandComplete(); err != nil {
b.die(err)
return "", err
}
rows := br.conn.getRows(br.ctx, query, arguments)
rows.batchTracer = br.conn.batchTracer
b.resultsRead++
if !br.mrr.NextResult() {
rows.err = br.mrr.Close()
if rows.err == nil {
rows.err = errors.New("no more results in batch")
}
rows.closed = true
b.pendingCommandComplete = true
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
Err: rows.err,
})
for {
msg, err := b.conn.rxMsg()
if err != nil {
return "", err
}
return rows, rows.err
switch msg := msg.(type) {
case *pgproto3.CommandComplete:
b.pendingCommandComplete = false
return CommandTag(msg.CommandTag), nil
default:
if err := b.conn.processContextFreeMsg(msg); err != nil {
return "", err
}
}
}
}
// QueryResults reads the results from the next query in the batch as if the
// query has been sent with Query.
func (b *Batch) QueryResults() (*Rows, error) {
rows := b.conn.getRows("batch query", nil)
if b.err != nil {
rows.fatal(b.err)
return rows, b.err
}
rows.resultReader = br.mrr.ResultReader()
select {
case <-b.ctx.Done():
b.die(b.ctx.Err())
rows.fatal(b.err)
return rows, b.ctx.Err()
default:
}
if err := b.ensureCommandComplete(); err != nil {
b.die(err)
rows.fatal(err)
return rows, err
}
b.resultsRead++
b.pendingCommandComplete = true
fieldDescriptions, err := b.conn.readUntilRowDescription()
if err != nil {
b.die(err)
rows.fatal(b.err)
return rows, err
}
rows.batch = b
rows.fields = fieldDescriptions
return rows, nil
}
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
func (br *batchResults) QueryRow() Row {
rows, _ := br.Query()
return (*connRow)(rows.(*baseRows))
// QueryRowResults reads the results from the next query in the batch as if the
// query has been sent with QueryRow.
func (b *Batch) QueryRowResults() *Row {
rows, _ := b.QueryResults()
return (*Row)(rows)
}
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
// resynchronize the connection with the server. In this case the underlying connection will have been closed.
func (br *batchResults) Close() error {
defer func() {
if !br.endTraced {
if br.conn != nil && br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
}
br.endTraced = true
}
// Close closes the batch operation. Any error that occurred during a batch
// operation may have made it impossible to resynchronize the connection with the
// server. In this case the underlying connection will have been closed.
func (b *Batch) Close() (err error) {
if b.err != nil {
return b.err
}
invalidateCachesOnBatchResultsError(br.conn, br.b, br.err)
defer func() {
err = b.conn.termContext(err)
if b.conn != nil && b.connPool != nil {
b.connPool.Release(b.conn)
}
}()
if br.err != nil {
return br.err
}
if br.closed {
return nil
}
// Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
if br.b.QueuedQueries[br.qqIdx].Fn != nil {
err := br.b.QueuedQueries[br.qqIdx].Fn(br)
if err != nil {
br.err = err
}
} else {
br.Exec()
for i := b.resultsRead; i < len(b.items); i++ {
if _, err = b.ExecResults(); err != nil {
return err
}
}
br.closed = true
err := br.mrr.Close()
if br.err == nil {
br.err = err
if err = b.conn.ensureConnectionReadyForQuery(); err != nil {
return err
}
return br.err
return nil
}
func (br *batchResults) earlyError() error {
return br.err
func (b *Batch) die(err error) {
if b.err != nil {
return
}
b.err = err
b.conn.die(err)
if b.conn != nil && b.connPool != nil {
b.connPool.Release(b.conn)
}
}
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
bi := br.b.QueuedQueries[br.qqIdx]
query = bi.SQL
args = bi.Arguments
ok = true
br.qqIdx++
}
return
}
type pipelineBatchResults struct {
ctx context.Context
conn *Conn
pipeline *pgconn.Pipeline
lastRows *baseRows
err error
b *Batch
qqIdx int
closed bool
endTraced bool
}
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
if br.err != nil {
return pgconn.CommandTag{}, br.err
}
if br.closed {
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
}
if br.lastRows != nil && br.lastRows.err != nil {
return pgconn.CommandTag{}, br.err
}
query, arguments, err := br.nextQueryAndArgs()
if err != nil {
return pgconn.CommandTag{}, err
}
results, err := br.pipeline.GetResults()
if err != nil {
br.err = err
return pgconn.CommandTag{}, br.err
}
var commandTag pgconn.CommandTag
switch results := results.(type) {
case *pgconn.ResultReader:
commandTag, br.err = results.Close()
default:
return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results)
}
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
CommandTag: commandTag,
Err: br.err,
})
}
return commandTag, br.err
}
// Query reads the results from the next query in the batch as if the query has been sent with Query.
func (br *pipelineBatchResults) Query() (Rows, error) {
if br.err != nil {
return &baseRows{err: br.err, closed: true}, br.err
}
if br.closed {
alreadyClosedErr := fmt.Errorf("batch already closed")
return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
}
if br.lastRows != nil && br.lastRows.err != nil {
br.err = br.lastRows.err
return &baseRows{err: br.err, closed: true}, br.err
}
query, arguments, err := br.nextQueryAndArgs()
if err != nil {
return &baseRows{err: err, closed: true}, err
}
rows := br.conn.getRows(br.ctx, query, arguments)
rows.batchTracer = br.conn.batchTracer
br.lastRows = rows
results, err := br.pipeline.GetResults()
if err != nil {
br.err = err
rows.err = err
rows.closed = true
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
Err: err,
})
}
} else {
switch results := results.(type) {
case *pgconn.ResultReader:
rows.resultReader = results
default:
err = fmt.Errorf("unexpected pipeline result: %T", results)
br.err = err
rows.err = err
rows.closed = true
}
}
return rows, rows.err
}
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
func (br *pipelineBatchResults) QueryRow() Row {
rows, _ := br.Query()
return (*connRow)(rows.(*baseRows))
}
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
// resynchronize the connection with the server. In this case the underlying connection will have been closed.
func (br *pipelineBatchResults) Close() error {
defer func() {
if !br.endTraced {
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
}
br.endTraced = true
}
invalidateCachesOnBatchResultsError(br.conn, br.b, br.err)
}()
if br.err == nil && br.lastRows != nil && br.lastRows.err != nil {
br.err = br.lastRows.err
return br.err
}
if br.closed {
return br.err
}
// Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
if br.b.QueuedQueries[br.qqIdx].Fn != nil {
err := br.b.QueuedQueries[br.qqIdx].Fn(br)
if err != nil {
br.err = err
}
} else {
br.Exec()
}
}
br.closed = true
err := br.pipeline.Close()
if br.err == nil {
br.err = err
}
return br.err
}
func (br *pipelineBatchResults) earlyError() error {
return br.err
}
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, err error) {
if br.b == nil {
return "", nil, errors.New("no reference to batch")
}
if br.qqIdx >= len(br.b.QueuedQueries) {
return "", nil, errors.New("no more results in batch")
}
bi := br.b.QueuedQueries[br.qqIdx]
br.qqIdx++
return bi.SQL, bi.Arguments, nil
}
// invalidates statement and description caches on batch results error
func invalidateCachesOnBatchResultsError(conn *Conn, b *Batch, err error) {
if err != nil && conn != nil && b != nil {
if sc := conn.statementCache; sc != nil {
for _, bi := range b.QueuedQueries {
sc.Invalidate(bi.SQL)
}
}
if sc := conn.descriptionCache; sc != nil {
for _, bi := range b.QueuedQueries {
sc.Invalidate(bi.SQL)
}
}
}
}

File diff suppressed because it is too large.

bench-tmp_test.go

@ -0,0 +1,55 @@
package pgx_test
import (
"testing"
)
func BenchmarkPgtypeInt4ParseBinary(b *testing.B) {
conn := mustConnect(b, *defaultConnConfig)
defer closeConn(b, conn)
_, err := conn.Prepare("selectBinary", "select n::int4 from generate_series(1, 100) n")
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
var n int32
rows, err := conn.Query("selectBinary")
if err != nil {
b.Fatal(err)
}
for rows.Next() {
err := rows.Scan(&n)
if err != nil {
b.Fatal(err)
}
}
if rows.Err() != nil {
b.Fatal(rows.Err())
}
}
}
func BenchmarkPgtypeInt4EncodeBinary(b *testing.B) {
conn := mustConnect(b, *defaultConnConfig)
defer closeConn(b, conn)
_, err := conn.Prepare("encodeBinary", "select $1::int4, $2::int4, $3::int4, $4::int4, $5::int4, $6::int4, $7::int4")
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
rows, err := conn.Query("encodeBinary", int32(i), int32(i), int32(i), int32(i), int32(i), int32(i), int32(i))
if err != nil {
b.Fatal(err)
}
rows.Close()
}
}

File diff suppressed because it is too large.

View File

@ -0,0 +1,89 @@
package chunkreader
import (
"io"
)
type ChunkReader struct {
r io.Reader
buf []byte
rp, wp int // buf read position and write position
options Options
}
type Options struct {
MinBufLen int // Minimum buffer length
}
func NewChunkReader(r io.Reader) *ChunkReader {
cr, err := NewChunkReaderEx(r, Options{})
if err != nil {
panic("default options can't be bad")
}
return cr
}
func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) {
if options.MinBufLen == 0 {
options.MinBufLen = 4096
}
return &ChunkReader{
r: r,
buf: make([]byte, options.MinBufLen),
options: options,
}, nil
}
// Next returns buf filled with the next n bytes. If an error occurs, buf will
// be nil.
func (r *ChunkReader) Next(n int) (buf []byte, err error) {
// n bytes already in buf
if (r.wp - r.rp) >= n {
buf = r.buf[r.rp : r.rp+n]
r.rp += n
return buf, err
}
// buf is not large enough to hold n bytes
if len(r.buf) < n {
r.copyBufContents(r.newBuf(n))
}
// buf is large enough, but need to shift filled area to start to make enough contiguous space
minReadCount := n - (r.wp - r.rp)
if (len(r.buf) - r.wp) < minReadCount {
newBuf := r.newBuf(n)
r.copyBufContents(newBuf)
}
if err := r.appendAtLeast(minReadCount); err != nil {
return nil, err
}
buf = r.buf[r.rp : r.rp+n]
r.rp += n
return buf, nil
}
func (r *ChunkReader) appendAtLeast(fillLen int) error {
n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen)
r.wp += n
return err
}
func (r *ChunkReader) newBuf(size int) []byte {
if size < r.options.MinBufLen {
size = r.options.MinBufLen
}
return make([]byte, size)
}
func (r *ChunkReader) copyBufContents(dest []byte) {
r.wp = copy(dest, r.buf[r.rp:r.wp])
r.rp = 0
r.buf = dest
}
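A runnable sketch of how ChunkReader might frame a simple length-prefixed message (the frame layout and the import path are assumptions for illustration):

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"log"

	"github.com/jackc/chunkreader" // assumed import path for this package
)

func main() {
	// Fake stream: 1-byte message type, 4-byte big-endian length (self-inclusive), then body.
	var stream bytes.Buffer
	stream.WriteByte('R')
	binary.Write(&stream, binary.BigEndian, uint32(4+5))
	stream.WriteString("hello")

	cr := chunkreader.NewChunkReader(&stream)

	header, err := cr.Next(5)
	if err != nil {
		log.Fatal(err)
	}
	bodyLen := int(binary.BigEndian.Uint32(header[1:5])) - 4

	body, err := cr.Next(bodyLen)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("type=%c body=%s\n", header[0], body)
}
```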


@ -0,0 +1,96 @@
package chunkreader
import (
"bytes"
"testing"
)
func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
server := &bytes.Buffer{}
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
if err != nil {
t.Fatal(err)
}
src := []byte{1, 2, 3, 4}
server.Write(src)
n1, err := r.Next(2)
if err != nil {
t.Fatal(err)
}
if bytes.Compare(n1, src[0:2]) != 0 {
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1)
}
n2, err := r.Next(2)
if err != nil {
t.Fatal(err)
}
if bytes.Compare(n2, src[2:4]) != 0 {
t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2)
}
if bytes.Compare(r.buf, src) != 0 {
t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf)
}
if r.rp != 4 {
t.Fatalf("Expected r.rp to be %v, but it was %v", 4, r.rp)
}
if r.wp != 4 {
t.Fatalf("Expected r.wp to be %v, but it was %v", 4, r.wp)
}
}
func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) {
server := &bytes.Buffer{}
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
if err != nil {
t.Fatal(err)
}
src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
server.Write(src)
n1, err := r.Next(5)
if err != nil {
t.Fatal(err)
}
if bytes.Compare(n1, src[0:5]) != 0 {
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1)
}
if len(r.buf) != 5 {
t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf))
}
}
func TestChunkReaderDoesNotReuseBuf(t *testing.T) {
server := &bytes.Buffer{}
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
if err != nil {
t.Fatal(err)
}
src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
server.Write(src)
n1, err := r.Next(4)
if err != nil {
t.Fatal(err)
}
if bytes.Compare(n1, src[0:4]) != 0 {
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1)
}
n2, err := r.Next(4)
if err != nil {
t.Fatal(err)
}
if bytes.Compare(n2, src[4:8]) != 0 {
t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2)
}
if bytes.Compare(n1, src[0:4]) != 0 {
t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1)
}
}


@ -1,61 +0,0 @@
#!/usr/bin/env bash
set -eux
if [[ "${PGVERSION-}" =~ ^[0-9.]+$ ]]
then
sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common
sudo rm -rf /var/lib/postgresql
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list"
sudo apt-get update -qq
sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION
sudo cp testsetup/pg_hba.conf /etc/postgresql/$PGVERSION/main/pg_hba.conf
sudo sh -c "echo \"listen_addresses = '127.0.0.1'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
sudo sh -c "cat testsetup/postgresql_ssl.conf >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
cd testsetup
# Generate CA, server, and encrypted client certificates.
go run generate_certs.go
# Copy certificates to server directory and set permissions.
sudo cp ca.pem /var/lib/postgresql/$PGVERSION/main/root.crt
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/root.crt
sudo cp localhost.key /var/lib/postgresql/$PGVERSION/main/server.key
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.key
sudo chmod 600 /var/lib/postgresql/$PGVERSION/main/server.key
sudo cp localhost.crt /var/lib/postgresql/$PGVERSION/main/server.crt
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.crt
cp ca.pem /tmp
cp pgx_sslcert.key /tmp
cp pgx_sslcert.crt /tmp
cd ..
sudo /etc/init.d/postgresql restart
createdb -U postgres pgx_test
psql -U postgres -f testsetup/postgresql_setup.sql pgx_test
fi
if [[ "${PGVERSION-}" =~ ^cockroach ]]
then
wget -qO- https://binaries.cockroachdb.com/cockroach-v24.3.3.linux-amd64.tgz | tar xvz
sudo mv cockroach-v24.3.3.linux-amd64/cockroach /usr/local/bin/
cockroach start-single-node --insecure --background --listen-addr=localhost
cockroach sql --insecure -e 'create database pgx_test'
fi
if [ "${CRATEVERSION-}" != "" ]
then
docker run \
-p "6543:5432" \
-d \
crate:"$CRATEVERSION" \
crate \
-Cnetwork.host=0.0.0.0 \
-Ctransport.host=localhost \
-Clicense.enterprise=false
fi

conn.go

File diff suppressed because it is too large.


@ -0,0 +1,79 @@
package pgx_test
import (
// "crypto/tls"
// "crypto/x509"
// "fmt"
// "go/build"
// "io/ioutil"
// "path"
"github.com/jackc/pgx"
)
var defaultConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
// To skip tests for specific connection / authentication types set that connection param to nil
var tcpConnConfig *pgx.ConnConfig = nil
var unixSocketConnConfig *pgx.ConnConfig = nil
var md5ConnConfig *pgx.ConnConfig = nil
var plainPasswordConnConfig *pgx.ConnConfig = nil
var invalidUserConnConfig *pgx.ConnConfig = nil
var tlsConnConfig *pgx.ConnConfig = nil
var customDialerConnConfig *pgx.ConnConfig = nil
var replicationConnConfig *pgx.ConnConfig = nil
var cratedbConnConfig *pgx.ConnConfig = nil
// var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"}
// var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
// var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"}
// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
// var replicationConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_replication", Password: "secret", Database: "pgx_test"}
// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
//
//// or to test client certs:
//
// var tlsConnConfig *pgx.ConnConfig
//
// func init() {
// homeDir := build.Default.GOPATH
// tlsConnConfig = &pgx.ConnConfig{
// Host: "127.0.0.1",
// User: "pgx_md5",
// Password: "secret",
// Database: "pgx_test",
// TLSConfig: &tls.Config{
// InsecureSkipVerify: true,
// },
// }
// caCertPool := x509.NewCertPool()
//
// caPath := path.Join(homeDir, "/src/github.com/jackc/pgx/rootCA.pem")
// caCert, err := ioutil.ReadFile(caPath)
// if err != nil {
// panic(fmt.Sprintf("unable to read CA file: %v", err))
// }
//
// if !caCertPool.AppendCertsFromPEM(caCert) {
// panic("unable to add CA to cert pool")
// }
//
// tlsConnConfig.TLSConfig.RootCAs = caCertPool
// tlsConnConfig.TLSConfig.ClientCAs = caCertPool
//
// sslCert := path.Join(homeDir, "/src/github.com/jackc/pgx/pg_md5.crt")
// sslKey := path.Join(homeDir, "/src/github.com/jackc/pgx/pg_md5.key")
// if (sslCert != "" && sslKey == "") || (sslCert == "" && sslKey != "") {
// panic(`both "sslcert" and "sslkey" are required`)
// }
//
// cert, err := tls.LoadX509KeyPair(sslCert, sslKey)
// if err != nil {
// panic(fmt.Sprintf("unable to read cert: %v", err))
// }
//
// tlsConnConfig.TLSConfig.Certificates = []tls.Certificate{cert}
// }


@ -0,0 +1,36 @@
package pgx_test
import (
"crypto/tls"
"github.com/jackc/pgx"
"os"
"strconv"
)
var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
var tcpConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
var unixSocketConnConfig = &pgx.ConnConfig{Host: "/var/run/postgresql", User: "postgres", Database: "pgx_test"}
var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
var plainPasswordConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"}
var invalidUserConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
var tlsConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_ssl", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
var customDialerConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
var replicationConnConfig *pgx.ConnConfig = nil
var cratedbConnConfig *pgx.ConnConfig = nil
func init() {
pgVersion := os.Getenv("PGVERSION")
if len(pgVersion) > 0 {
v, err := strconv.ParseFloat(pgVersion, 64)
if err == nil && v >= 9.6 {
replicationConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_replication", Password: "secret", Database: "pgx_test"}
}
}
crateVersion := os.Getenv("CRATEVERSION")
if crateVersion != "" {
cratedbConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", Port: 6543, User: "pgx", Password: "", Database: "pgx_test"}
}
}


@ -1,55 +0,0 @@
package pgx
import (
"context"
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func mustParseConfig(t testing.TB, connString string) *ConnConfig {
config, err := ParseConfig(connString)
require.Nil(t, err)
return config
}
func mustConnect(t testing.TB, config *ConnConfig) *Conn {
conn, err := ConnectConfig(context.Background(), config)
if err != nil {
t.Fatalf("Unable to establish connection: %v", err)
}
return conn
}
// Ensures the connection limits the size of its cached objects.
// This test examines the internals of *Conn so must be in the same package.
func TestStmtCacheSizeLimit(t *testing.T) {
const cacheLimit = 16
connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
connConfig.StatementCacheCapacity = cacheLimit
conn := mustConnect(t, connConfig)
defer func() {
err := conn.Close(context.Background())
if err != nil {
t.Fatal(err)
}
}()
// run a set of unique queries that should overflow the cache
ctx := context.Background()
for i := 0; i < cacheLimit*2; i++ {
uniqueString := fmt.Sprintf("unique %d", i)
uniqueSQL := fmt.Sprintf("select '%s'", uniqueString)
var output string
err := conn.QueryRow(ctx, uniqueSQL).Scan(&output)
require.NoError(t, err)
require.Equal(t, uniqueString, output)
}
// preparedStatements contains cacheLimit+1 because deallocation happens before the query
assert.Len(t, conn.preparedStatements, cacheLimit+1)
assert.Equal(t, cacheLimit, conn.statementCache.Len())
}

conn_pool.go

@ -0,0 +1,582 @@
package pgx
import (
"context"
"io"
"sync"
"time"
"github.com/pkg/errors"
"github.com/jackc/pgx/pgtype"
)
type ConnPoolConfig struct {
ConnConfig
MaxConnections int // max simultaneous connections to use, default 5, must be at least 2
AfterConnect func(*Conn) error // function to call on every new connection
AcquireTimeout time.Duration // max wait time when all connections are busy (0 means no timeout)
}
type ConnPool struct {
allConnections []*Conn
availableConnections []*Conn
cond *sync.Cond
config ConnConfig // config used when establishing connection
inProgressConnects int
maxConnections int
resetCount int
afterConnect func(*Conn) error
logger Logger
logLevel LogLevel
closed bool
preparedStatements map[string]*PreparedStatement
acquireTimeout time.Duration
connInfo *pgtype.ConnInfo
}
type ConnPoolStat struct {
MaxConnections int // max simultaneous connections to use
CurrentConnections int // current live connections
AvailableConnections int // unused live connections
}
// CheckedOutConnections returns the number of connections that are currently
// checked out from the pool.
func (stat *ConnPoolStat) CheckedOutConnections() int {
return stat.CurrentConnections - stat.AvailableConnections
}
// ErrAcquireTimeout occurs when an attempt to acquire a connection times out.
var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool")
// ErrClosedPool occurs on an attempt to acquire a connection from a closed pool.
var ErrClosedPool = errors.New("cannot acquire from closed pool")
// NewConnPool creates a new ConnPool. config.ConnConfig is passed through to
// Connect directly.
func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) {
p = new(ConnPool)
p.config = config.ConnConfig
p.connInfo = minimalConnInfo
p.maxConnections = config.MaxConnections
if p.maxConnections == 0 {
p.maxConnections = 5
}
if p.maxConnections < 1 {
return nil, errors.New("MaxConnections must be at least 1")
}
p.acquireTimeout = config.AcquireTimeout
if p.acquireTimeout < 0 {
return nil, errors.New("AcquireTimeout must be equal to or greater than 0")
}
p.afterConnect = config.AfterConnect
if config.LogLevel != 0 {
p.logLevel = config.LogLevel
} else {
// Preserve pre-LogLevel behavior by defaulting to LogLevelDebug
p.logLevel = LogLevelDebug
}
p.logger = config.Logger
if p.logger == nil {
p.logLevel = LogLevelNone
}
p.allConnections = make([]*Conn, 0, p.maxConnections)
p.availableConnections = make([]*Conn, 0, p.maxConnections)
p.preparedStatements = make(map[string]*PreparedStatement)
p.cond = sync.NewCond(new(sync.Mutex))
// Initially establish one connection
var c *Conn
c, err = p.createConnection()
if err != nil {
return
}
p.allConnections = append(p.allConnections, c)
p.availableConnections = append(p.availableConnections, c)
p.connInfo = c.ConnInfo.DeepCopy()
return
}
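A minimal usage sketch of the pool API defined here (the connection parameters are illustrative, and error handling is abbreviated):

```go
pool, err := pgx.NewConnPool(pgx.ConnPoolConfig{
	ConnConfig:     pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"},
	MaxConnections: 5,
})
if err != nil {
	log.Fatal(err)
}
defer pool.Close()

conn, err := pool.Acquire() // exclusive use until released
if err != nil {
	log.Fatal(err)
}
defer pool.Release(conn)
```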
// Acquire takes exclusive use of a connection until it is released.
func (p *ConnPool) Acquire() (*Conn, error) {
p.cond.L.Lock()
c, err := p.acquire(nil)
p.cond.L.Unlock()
return c, err
}
// deadlinePassed returns true if the given deadline has passed.
func (p *ConnPool) deadlinePassed(deadline *time.Time) bool {
return deadline != nil && time.Now().After(*deadline)
}
// acquire performs acquisition assuming the pool is already locked
func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
if p.closed {
return nil, ErrClosedPool
}
// A connection is available
// The pool works like a queue. Available connection will be returned
// from the head. A new connection will be added to the tail.
numAvailable := len(p.availableConnections)
if numAvailable > 0 {
c := p.availableConnections[0]
c.poolResetCount = p.resetCount
copy(p.availableConnections, p.availableConnections[1:])
p.availableConnections = p.availableConnections[:numAvailable-1]
return c, nil
}
// Set initial timeout/deadline value. If the method (acquire) happens to
// recursively call itself the deadline should retain its value.
if deadline == nil && p.acquireTimeout > 0 {
tmp := time.Now().Add(p.acquireTimeout)
deadline = &tmp
}
// Make sure the deadline (if there is one) has not passed yet
if p.deadlinePassed(deadline) {
return nil, ErrAcquireTimeout
}
// If there is a deadline then start a timeout timer
var timer *time.Timer
if deadline != nil {
timer = time.AfterFunc(deadline.Sub(time.Now()), func() {
p.cond.Broadcast()
})
defer timer.Stop()
}
// No connections are available, but we can create more
if len(p.allConnections)+p.inProgressConnects < p.maxConnections {
// Create a new connection.
// Careful here: createConnectionUnlocked() removes the current lock,
// creates a connection and then locks it back.
c, err := p.createConnectionUnlocked()
if err != nil {
return nil, err
}
c.poolResetCount = p.resetCount
p.allConnections = append(p.allConnections, c)
return c, nil
}
// All connections are in use and we cannot create more
if p.logLevel >= LogLevelWarn {
p.logger.Log(LogLevelWarn, "waiting for available connection", nil)
}
// Wait until there is an available connection OR room to create a new connection
for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections {
if p.deadlinePassed(deadline) {
return nil, ErrAcquireTimeout
}
p.cond.Wait()
}
// Stop the timer so that we do not spawn it on every acquire call.
if timer != nil {
timer.Stop()
}
return p.acquire(deadline)
}
// Release gives up use of a connection.
func (p *ConnPool) Release(conn *Conn) {
if conn.ctxInProgress {
panic("should never release when context is in progress")
}
if conn.txStatus != 'I' {
conn.Exec("rollback")
}
if len(conn.channels) > 0 {
if err := conn.Unlisten("*"); err != nil {
conn.die(err)
}
conn.channels = make(map[string]struct{})
}
conn.notifications = nil
p.cond.L.Lock()
if conn.poolResetCount != p.resetCount {
conn.Close()
p.cond.L.Unlock()
p.cond.Signal()
return
}
if conn.IsAlive() {
p.availableConnections = append(p.availableConnections, conn)
} else {
p.removeFromAllConnections(conn)
}
p.cond.L.Unlock()
p.cond.Signal()
}
// removeFromAllConnections removes the given connection from the list.
// It returns true if the connection was found and removed or false otherwise.
func (p *ConnPool) removeFromAllConnections(conn *Conn) bool {
for i, c := range p.allConnections {
if conn == c {
p.allConnections = append(p.allConnections[:i], p.allConnections[i+1:]...)
return true
}
}
return false
}
// Close ends the use of a connection pool. It prevents any new connections from
// being acquired and closes available underlying connections. Any acquired
// connections will be closed when they are released.
func (p *ConnPool) Close() {
p.cond.L.Lock()
defer p.cond.L.Unlock()
p.closed = true
for _, c := range p.availableConnections {
_ = c.Close()
}
// This will cause any checked out connections to be closed on release
p.resetCount++
}
// Reset closes all open connections, but leaves the pool open. It is intended
// for use when an error is detected that would disrupt all connections (such as
// a network interruption or a server state change).
//
// It is safe to reset a pool while connections are checked out. Those
// connections will be closed when they are returned to the pool.
func (p *ConnPool) Reset() {
p.cond.L.Lock()
defer p.cond.L.Unlock()
p.resetCount++
p.allConnections = p.allConnections[0:0]
for _, conn := range p.availableConnections {
conn.Close()
}
p.availableConnections = p.availableConnections[0:0]
}
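For example, an application that detects a network interruption might reset the pool; a sketch, where `networkFailure` is a hypothetical error classifier:

```go
// Discard all pooled connections after a suspected network interruption.
// (networkFailure is a hypothetical helper, not part of this package.)
if networkFailure(err) {
	// Idle connections close immediately; checked-out connections
	// are closed as they are released back to the pool.
	pool.Reset()
}
```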
// invalidateAcquired causes all acquired connections to be closed when released.
// The pool must already be locked.
func (p *ConnPool) invalidateAcquired() {
p.resetCount++
for _, c := range p.availableConnections {
c.poolResetCount = p.resetCount
}
p.allConnections = p.allConnections[:len(p.availableConnections)]
copy(p.allConnections, p.availableConnections)
}
// Stat returns connection pool statistics
func (p *ConnPool) Stat() (s ConnPoolStat) {
p.cond.L.Lock()
defer p.cond.L.Unlock()
s.MaxConnections = p.maxConnections
s.CurrentConnections = len(p.allConnections)
s.AvailableConnections = len(p.availableConnections)
return
}
func (p *ConnPool) createConnection() (*Conn, error) {
c, err := connect(p.config, p.connInfo)
if err != nil {
return nil, err
}
return p.afterConnectionCreated(c)
}
// createConnectionUnlocked releases the lock, creates a new connection, and
// then re-acquires the lock.
// Here is the point: let's say our pool dialer's OpenTimeout is set to 3 seconds,
// and we have a pool with 20 connections in it that we try to acquire all at
// startup.
// If the remote server is not accessible, the first connection in the pool
// blocks all the others for 3 secs before it gets the timeout. Then connection
// #2 holds the lock and blocks everything for the next 3 secs until it gets the
// OpenTimeout err, etc. So the very last, 20th connection fails only after
// 3 * 20 = 60 secs.
// To avoid this we call Connect(p.config) outside of the lock (it is thread
// safe), which allows all 20 connections to be made in parallel (more or less).
func (p *ConnPool) createConnectionUnlocked() (*Conn, error) {
p.inProgressConnects++
p.cond.L.Unlock()
c, err := Connect(p.config)
p.cond.L.Lock()
p.inProgressConnects--
if err != nil {
return nil, err
}
return p.afterConnectionCreated(c)
}
// afterConnectionCreated executes the afterConnect() callback (if one is set)
// and prepares all the known statements for the new connection.
func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
if p.afterConnect != nil {
err := p.afterConnect(c)
if err != nil {
c.die(err)
return nil, err
}
}
for _, ps := range p.preparedStatements {
if _, err := c.Prepare(ps.Name, ps.SQL); err != nil {
c.die(err)
return nil, err
}
}
return c, nil
}
// Exec acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
var c *Conn
if c, err = p.Acquire(); err != nil {
return
}
defer p.Release(c)
return c.Exec(sql, arguments...)
}
func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
var c *Conn
if c, err = p.Acquire(); err != nil {
return
}
defer p.Release(c)
return c.ExecEx(ctx, sql, options, arguments...)
}
// Query acquires a connection and delegates the call to that connection. When
// *Rows are closed, the connection is released automatically.
func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
c, err := p.Acquire()
if err != nil {
// Because checking for errors can be deferred to the *Rows, build one with the error
return &Rows{closed: true, err: err}, err
}
rows, err := c.Query(sql, args...)
if err != nil {
p.Release(c)
return rows, err
}
rows.connPool = p
return rows, nil
}
func (p *ConnPool) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) {
c, err := p.Acquire()
if err != nil {
// Because checking for errors can be deferred to the *Rows, build one with the error
return &Rows{closed: true, err: err}, err
}
rows, err := c.QueryEx(ctx, sql, options, args...)
if err != nil {
p.Release(c)
return rows, err
}
rows.connPool = p
return rows, nil
}
// QueryRow acquires a connection and delegates the call to that connection. The
// connection is released automatically after Scan is called on the returned
// *Row.
func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row {
rows, _ := p.Query(sql, args...)
return (*Row)(rows)
}
func (p *ConnPool) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row {
rows, _ := p.QueryEx(ctx, sql, options, args...)
return (*Row)(rows)
}
// Begin acquires a connection and begins a transaction on it. When the
// transaction is closed the connection will be automatically released.
func (p *ConnPool) Begin() (*Tx, error) {
return p.BeginEx(context.Background(), nil)
}
// Prepare creates a prepared statement on a connection in the pool to test the
// statement is valid. If it succeeds all connections accessed through the pool
// will have the statement available.
//
// Prepare creates a prepared statement with name and sql. sql can contain
// placeholders for bound parameters. These placeholders are referenced
// positionally as $1, $2, etc.
//
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with
// the same name and sql arguments. This allows a code path to Prepare and
// Query/Exec/PrepareEx without concern for whether the statement has already been prepared.
func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) {
return p.PrepareEx(context.Background(), name, sql, nil)
}
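A usage sketch (statement name and SQL are illustrative; in the v3 API a pool-prepared statement can be executed by passing its name as the sql argument):

```go
// Sketch: prepare once on the pool, then execute by statement name.
func widgetName(pool *pgx.ConnPool, id int32) (string, error) {
	// Prepare is idempotent, so calling it on every code path is safe.
	if _, err := pool.Prepare("getWidgetName", "select name from widgets where id = $1"); err != nil {
		return "", err
	}
	var name string
	err := pool.QueryRow("getWidgetName", id).Scan(&name)
	return name, err
}
```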
// PrepareEx creates a prepared statement on a connection in the pool to test the
// statement is valid. If it succeeds all connections accessed through the pool
// will have the statement available.
//
// PrepareEx creates a prepared statement with name and sql. sql can contain placeholders
// for bound parameters. These placeholders are referenced positionally as $1, $2, etc.
// It differs from Prepare as it allows additional options (such as parameter OIDs) to be passed via a struct.
//
// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec/Prepare without
// concern for whether the statement has already been prepared.
func (p *ConnPool) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) {
p.cond.L.Lock()
defer p.cond.L.Unlock()
if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql {
return ps, nil
}
c, err := p.acquire(nil)
if err != nil {
return nil, err
}
p.availableConnections = append(p.availableConnections, c)
// Double check that the statement was not prepared by someone else
// while we were acquiring the connection (since acquire is not fully
// blocking now, see createConnectionUnlocked())
if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql {
return ps, nil
}
ps, err := c.PrepareEx(ctx, name, sql, opts)
if err != nil {
return nil, err
}
for _, c := range p.availableConnections {
_, err := c.PrepareEx(ctx, name, sql, opts)
if err != nil {
return nil, err
}
}
p.invalidateAcquired()
p.preparedStatements[name] = ps
return ps, err
}
// Deallocate releases a prepared statement from all connections in the pool.
func (p *ConnPool) Deallocate(name string) (err error) {
p.cond.L.Lock()
defer p.cond.L.Unlock()
for _, c := range p.availableConnections {
if err := c.Deallocate(name); err != nil {
return err
}
}
p.invalidateAcquired()
delete(p.preparedStatements, name)
return nil
}
// BeginEx acquires a connection and starts a transaction with txOptions
// determining the transaction mode. When the transaction is closed the
// connection will be automatically released.
func (p *ConnPool) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) {
for {
c, err := p.Acquire()
if err != nil {
return nil, err
}
tx, err := c.BeginEx(ctx, txOptions)
if err != nil {
alive := c.IsAlive()
p.Release(c)
// If connection is still alive then the error is not something trying
// again on a new connection would fix, so just return the error. But
// if the connection is dead try to acquire a new connection and try
// again.
if alive || ctx.Err() != nil {
return nil, err
}
continue
}
tx.connPool = p
return tx, nil
}
}
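A sketch of BeginEx with an explicit isolation level (table and values are illustrative):

```go
// Sketch: run a serializable transaction through the pool.
func debit(ctx context.Context, pool *pgx.ConnPool, id int32) error {
	tx, err := pool.BeginEx(ctx, &pgx.TxOptions{IsoLevel: pgx.Serializable})
	if err != nil {
		return err
	}
	defer tx.Rollback() // no-op if Commit already succeeded

	if _, err := tx.Exec("update accounts set balance = balance - 1 where id = $1", id); err != nil {
		return err
	}
	return tx.Commit()
}
```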
// CopyFrom acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
c, err := p.Acquire()
if err != nil {
return 0, err
}
defer p.Release(c)
return c.CopyFrom(tableName, columnNames, rowSrc)
}
// CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
c, err := p.Acquire()
if err != nil {
return "", err
}
defer p.Release(c)
return c.CopyFromReader(r, sql)
}
// CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) {
c, err := p.Acquire()
if err != nil {
return "", err
}
defer p.Release(c)
return c.CopyToWriter(w, sql, args...)
}
// BeginBatch acquires a connection and begins a batch on that connection. When
// *Batch is finished, the connection is released automatically.
func (p *ConnPool) BeginBatch() *Batch {
c, err := p.Acquire()
return &Batch{conn: c, connPool: p, err: err}
}
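A sketch of the batch flow, assuming the v3 Batch API (Queue/Send/ExecResults/Close); the statements and values are illustrative:

```go
// Sketch: queue two statements and send them in a single round trip.
func insertPair(ctx context.Context, pool *pgx.ConnPool) error {
	batch := pool.BeginBatch()
	defer batch.Close() // also releases the underlying connection

	batch.Queue("insert into ledger(amount) values($1)", []interface{}{int32(1)}, nil, nil)
	batch.Queue("insert into ledger(amount) values($1)", []interface{}{int32(2)}, nil, nil)

	if err := batch.Send(ctx, nil); err != nil {
		return err
	}
	// Read one result per queued statement.
	for i := 0; i < 2; i++ {
		if _, err := batch.ExecResults(); err != nil {
			return err
		}
	}
	return nil
}
```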

conn_pool_private_test.go

@ -0,0 +1,44 @@
package pgx
import (
"testing"
)
func compareConnSlices(slice1, slice2 []*Conn) bool {
if len(slice1) != len(slice2) {
return false
}
for i, c := range slice1 {
if c != slice2[i] {
return false
}
}
return true
}
func TestConnPoolRemoveFromAllConnections(t *testing.T) {
t.Parallel()
pool := ConnPool{}
conn1 := &Conn{}
conn2 := &Conn{}
conn3 := &Conn{}
// First element
pool.allConnections = []*Conn{conn1, conn2, conn3}
pool.removeFromAllConnections(conn1)
if !compareConnSlices(pool.allConnections, []*Conn{conn2, conn3}) {
t.Fatal("First element test failed")
}
// Element somewhere in the middle
pool.allConnections = []*Conn{conn1, conn2, conn3}
pool.removeFromAllConnections(conn2)
if !compareConnSlices(pool.allConnections, []*Conn{conn1, conn3}) {
t.Fatal("Middle element test failed")
}
// Last element
pool.allConnections = []*Conn{conn1, conn2, conn3}
pool.removeFromAllConnections(conn3)
if !compareConnSlices(pool.allConnections, []*Conn{conn1, conn2}) {
t.Fatal("Last element test failed")
}
}

conn_pool_test.go
File diff suppressed because it is too large

File diff suppressed because it is too large

@ -2,22 +2,22 @@ package pgx
import (
"bytes"
"context"
"fmt"
"io"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
"github.com/pkg/errors"
)
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
// making it usable by *Conn.CopyFrom.
func CopyFromRows(rows [][]any) CopyFromSource {
func CopyFromRows(rows [][]interface{}) CopyFromSource {
return &copyFromRows{rows: rows, idx: -1}
}
type copyFromRows struct {
rows [][]any
rows [][]interface{}
idx int
}
@ -26,7 +26,7 @@ func (ctr *copyFromRows) Next() bool {
return ctr.idx < len(ctr.rows)
}
func (ctr *copyFromRows) Values() ([]any, error) {
func (ctr *copyFromRows) Values() ([]interface{}, error) {
return ctr.rows[ctr.idx], nil
}
@ -34,63 +34,6 @@ func (ctr *copyFromRows) Err() error {
return nil
}
// CopyFromSlice returns a CopyFromSource interface over a dynamic func
// making it usable by *Conn.CopyFrom.
func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource {
return &copyFromSlice{next: next, idx: -1, len: length}
}
type copyFromSlice struct {
next func(int) ([]any, error)
idx int
len int
err error
}
func (cts *copyFromSlice) Next() bool {
cts.idx++
return cts.idx < cts.len
}
func (cts *copyFromSlice) Values() ([]any, error) {
values, err := cts.next(cts.idx)
if err != nil {
cts.err = err
}
return values, err
}
func (cts *copyFromSlice) Err() error {
return cts.err
}
// CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values.
// nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil,
// or it returns an error. If nxtf returns an error, the copy is aborted.
func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource {
return &copyFromFunc{next: nxtf}
}
type copyFromFunc struct {
next func() ([]any, error)
valueRow []any
err error
}
func (g *copyFromFunc) Next() bool {
g.valueRow, g.err = g.next()
// only return true if valueRow exists and no error
return g.valueRow != nil && g.err == nil
}
func (g *copyFromFunc) Values() ([]any, error) {
return g.valueRow, g.err
}
func (g *copyFromFunc) Err() error {
return g.err
}
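A usage sketch for the v5-side CopyFromFunc (table and data are illustrative):

```go
// Sketch: feed CopyFrom from a generator; returning (nil, nil) ends the copy.
func copyThree(ctx context.Context, conn *pgx.Conn) (int64, error) {
	n := 0
	src := pgx.CopyFromFunc(func() ([]any, error) {
		if n == 3 {
			return nil, nil // signal end of data
		}
		n++
		return []any{int64(n)}, nil
	})
	return conn.CopyFrom(ctx, pgx.Identifier{"numbers"}, []string{"n"}, src)
}
```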
// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
type CopyFromSource interface {
// Next returns true if there is another row and makes the next row data
@ -99,7 +42,7 @@ type CopyFromSource interface {
Next() bool
// Values returns the values for the current row.
Values() ([]any, error)
Values() ([]interface{}, error)
// Err returns any error that has been encountered by the CopyFromSource. If
// this is not nil *Conn.CopyFrom will abort the copy.
@ -112,17 +55,42 @@ type copyFrom struct {
columnNames []string
rowSrc CopyFromSource
readerErrChan chan error
mode QueryExecMode
}
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
if ct.conn.copyFromTracer != nil {
ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{
TableName: ct.tableName,
ColumnNames: ct.columnNames,
})
}
func (ct *copyFrom) readUntilReadyForQuery() {
for {
msg, err := ct.conn.rxMsg()
if err != nil {
ct.readerErrChan <- err
close(ct.readerErrChan)
return
}
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
ct.conn.rxReadyForQuery(msg)
close(ct.readerErrChan)
return
case *pgproto3.CommandComplete:
case *pgproto3.ErrorResponse:
ct.readerErrChan <- ct.conn.rxErrorResponse(msg)
default:
err = ct.conn.processContextFreeMsg(msg)
if err != nil {
ct.readerErrChan <- err
}
}
}
}
func (ct *copyFrom) waitForReaderDone() error {
var err error
for err = range ct.readerErrChan {
}
return err
}
func (ct *copyFrom) run() (int, error) {
quotedTableName := ct.tableName.Sanitize()
cbuf := &bytes.Buffer{}
for i, cn := range ct.columnNames {
@ -133,144 +101,238 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
}
quotedColumnNames := cbuf.String()
var sd *pgconn.StatementDescription
switch ct.mode {
case QueryExecModeExec, QueryExecModeSimpleProtocol:
// These modes don't support the binary format. Before the inclusion of the
// QueryExecModes, Conn.Prepare was called on every COPY operation to get
// the OIDs. These prepared statements were not cached.
//
// Since that's the same behavior provided by QueryExecModeDescribeExec,
// we'll default to that mode.
ct.mode = QueryExecModeDescribeExec
fallthrough
case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec:
var err error
sd, err = ct.conn.getStatementDescription(
ctx,
ct.mode,
fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName),
)
if err != nil {
return 0, fmt.Errorf("statement description failed: %w", err)
}
default:
return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
if err != nil {
return 0, err
}
r, w := io.Pipe()
doneChan := make(chan struct{})
err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
if err != nil {
return 0, err
}
go func() {
defer close(doneChan)
err = ct.conn.readUntilCopyInResponse()
if err != nil {
return 0, err
}
// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
buf := ct.conn.wbuf
panicked := true
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
buf = pgio.AppendInt32(buf, 0)
buf = pgio.AppendInt32(buf, 0)
moreRows := true
for moreRows {
var err error
moreRows, buf, err = ct.buildCopyBuf(buf, sd)
if err != nil {
w.CloseWithError(err)
return
}
if ct.rowSrc.Err() != nil {
w.CloseWithError(ct.rowSrc.Err())
return
}
if len(buf) > 0 {
_, err = w.Write(buf)
if err != nil {
w.Close()
return
}
}
buf = buf[:0]
go ct.readUntilReadyForQuery()
defer ct.waitForReaderDone()
defer func() {
if panicked {
ct.conn.die(errors.New("panic while in copy from"))
}
w.Close()
}()
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
buf := ct.conn.wbuf
buf = append(buf, copyData)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
r.Close()
<-doneChan
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
buf = pgio.AppendInt32(buf, 0)
buf = pgio.AppendInt32(buf, 0)
var sentCount int
moreRows := true
for moreRows {
select {
case err = <-ct.readerErrChan:
panicked = false
return 0, err
default:
}
var addedRows int
var err error
moreRows, buf, addedRows, err = ct.buildCopyBuf(buf, ps)
if err != nil {
panicked = false
ct.cancelCopyIn()
return 0, err
}
sentCount += addedRows
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err = ct.conn.conn.Write(buf)
if err != nil {
panicked = false
ct.conn.die(err)
return 0, err
}
// Directly manipulate wbuf to reset to reuse the same buffer
buf = buf[0:5]
if ct.conn.copyFromTracer != nil {
ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{
CommandTag: commandTag,
Err: err,
})
}
return commandTag.RowsAffected(), err
if ct.rowSrc.Err() != nil {
panicked = false
ct.cancelCopyIn()
return 0, ct.rowSrc.Err()
}
buf = pgio.AppendInt16(buf, -1) // terminate the copy stream
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
buf = append(buf, copyDone)
buf = pgio.AppendInt32(buf, 4)
_, err = ct.conn.conn.Write(buf)
if err != nil {
panicked = false
ct.conn.die(err)
return 0, err
}
err = ct.waitForReaderDone()
if err != nil {
panicked = false
return 0, err
}
panicked = false
return sentCount, nil
}
func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
const sendBufSize = 65536 - 5 // The packet has a 5-byte header
lastBufLen := 0
largestRowLen := 0
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, int, error) {
var rowCount int
for ct.rowSrc.Next() {
lastBufLen = len(buf)
values, err := ct.rowSrc.Values()
if err != nil {
return false, nil, err
return false, nil, 0, err
}
if len(values) != len(ct.columnNames) {
return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
return false, nil, 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
}
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
for i, val := range values {
buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val)
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
if err != nil {
return false, nil, err
return false, nil, 0, err
}
}
rowLen := len(buf) - lastBufLen
if rowLen > largestRowLen {
largestRowLen = rowLen
}
rowCount++
// Try not to overflow size of the buffer PgConn.CopyFrom will be reading into. If that happens then the nature of
// io.Pipe means that the next Read will be short. This can lead to pathological send sizes such as 65531, 13, 65531
// 13, 65531, 13, 65531, 13.
if len(buf) > sendBufSize-largestRowLen {
return true, buf, nil
if len(buf) > 65536 {
return true, buf, rowCount, nil
}
}
return false, buf, nil
return false, buf, rowCount, nil
}
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. It returns the number of rows copied and
// an error.
func (c *Conn) readUntilCopyInResponse() error {
for {
msg, err := c.rxMsg()
if err != nil {
return err
}
switch msg := msg.(type) {
case *pgproto3.CopyInResponse:
return nil
default:
err = c.processContextFreeMsg(msg)
if err != nil {
return err
}
}
}
}
func (ct *copyFrom) cancelCopyIn() error {
buf := ct.conn.wbuf
buf = append(buf, copyFail)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, "client error: abort"...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err := ct.conn.conn.Write(buf)
if err != nil {
ct.conn.die(err)
return err
}
return nil
}
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
// It returns the number of rows copied and an error.
//
// CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered
// for the type of each column. Almost all types implemented by pgx support the binary format.
//
// Even though enum types appear to be strings they still must be registered to use with CopyFrom. This can be done with
// Conn.LoadType and pgtype.Map.RegisterType.
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
// CopyFrom requires all values use the binary format. Almost all types
// implemented by pgx use the binary format by default. Types implementing
// Encoder can only be used if they encode to the binary format.
func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
ct := &copyFrom{
conn: c,
tableName: tableName,
columnNames: columnNames,
rowSrc: rowSrc,
readerErrChan: make(chan error),
mode: c.config.DefaultQueryExecMode,
}
return ct.run(ctx)
return ct.run()
}
// CopyFromReader uses the PostgreSQL textual format of the copy protocol
func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
if err := c.sendSimpleQuery(sql); err != nil {
return "", err
}
if err := c.readUntilCopyInResponse(); err != nil {
return "", err
}
buf := c.wbuf
buf = append(buf, copyData)
sp := len(buf)
for {
n, err := r.Read(buf[5:cap(buf)])
if err == io.EOF && n == 0 {
break
}
buf = buf[0 : n+5]
pgio.SetInt32(buf[sp:], int32(n+4))
if _, err := c.conn.Write(buf); err != nil {
return "", err
}
}
buf = buf[:0]
buf = append(buf, copyDone)
buf = pgio.AppendInt32(buf, 4)
if _, err := c.conn.Write(buf); err != nil {
return "", err
}
for {
msg, err := c.rxMsg()
if err != nil {
return "", err
}
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
return "", err
case *pgproto3.CommandComplete:
return CommandTag(msg.CommandTag), nil
case *pgproto3.ErrorResponse:
return "", c.rxErrorResponse(msg)
default:
return "", c.processContextFreeMsg(msg)
}
}
}
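A usage sketch for the textual copy path (table and data are illustrative; assumes `strings` is imported):

```go
// Sketch: stream tab-separated rows into a table.
func loadRows(conn *pgx.Conn) (pgx.CommandTag, error) {
	data := strings.NewReader("1\talpha\n2\tbeta\n")
	return conn.CopyFromReader(data, "copy foo(id, name) from stdin")
}
```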

File diff suppressed because it is too large

copy_to.go

@ -0,0 +1,63 @@
package pgx
import (
"io"
"github.com/jackc/pgx/pgproto3"
)
func (c *Conn) readUntilCopyOutResponse() error {
for {
msg, err := c.rxMsg()
if err != nil {
return err
}
switch msg := msg.(type) {
case *pgproto3.CopyOutResponse:
return nil
default:
err = c.processContextFreeMsg(msg)
if err != nil {
return err
}
}
}
}
func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) {
if err := c.sendSimpleQuery(sql, args...); err != nil {
return "", err
}
if err := c.readUntilCopyOutResponse(); err != nil {
return "", err
}
for {
msg, err := c.rxMsg()
if err != nil {
return "", err
}
switch msg := msg.(type) {
case *pgproto3.CopyDone:
break
case *pgproto3.CopyData:
_, err := w.Write(msg.Data)
if err != nil {
c.die(err)
return "", err
}
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
return "", nil
case *pgproto3.CommandComplete:
return CommandTag(msg.CommandTag), nil
case *pgproto3.ErrorResponse:
return "", c.rxErrorResponse(msg)
default:
return "", c.processContextFreeMsg(msg)
}
}
}

copy_to_test.go

@ -0,0 +1,115 @@
package pgx_test
import (
"bytes"
"testing"
"github.com/jackc/pgx"
)
func TestConnCopyToWriterSmall(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int2,
b int4,
c int8,
d varchar,
e text,
f date,
g json
)`)
mustExec(t, conn, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`)
mustExec(t, conn, `insert into foo values (null, null, null, null, null, null, null)`)
inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" +
"\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout")
if err != nil {
t.Errorf("Unexpected error for CopyToWriter: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != 2 {
t.Errorf("Expected CopyToWriter to return 2 copied rows, but got %d", copyCount)
}
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
t.Errorf("Input rows and output rows do not equal:\n%q\n%q", string(inputBytes), string(outputWriter.Bytes()))
}
ensureConnValid(t, conn)
}
func TestConnCopyToWriterLarge(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int2,
b int4,
c int8,
d varchar,
e text,
f date,
g json,
h bytea
)`)
inputBytes := make([]byte, 0)
for i := 0; i < 1000; i++ {
mustExec(t, conn, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`)
inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...)
}
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout")
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != 1000 {
t.Errorf("Expected CopyToWriter to return 1 copied rows, but got %d", copyCount)
}
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
t.Errorf("Input rows and output rows do not equal")
}
ensureConnValid(t, conn)
}
func TestConnCopyToWriterQueryError(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
outputWriter := bytes.NewBuffer(make([]byte, 0))
res, err := conn.CopyToWriter(outputWriter, "cropy foo to stdout")
if err == nil {
t.Errorf("Expected CopyToWriter return error, but it did not")
}
if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyToWriter return pgx.PgError, but instead it returned: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != 0 {
t.Errorf("Expected CopyToWriter to return 0 copied rows, but got %d", copyCount)
}
ensureConnValid(t, conn)
}


@ -1,256 +0,0 @@
package pgx
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/jackc/pgx/v5/pgtype"
)
/*
buildLoadDerivedTypesSQL generates the correct query for retrieving type information.
pgVersion: the major version of the PostgreSQL server
typeNames: the names of the types to load. If nil, load all types.
*/
func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string {
supportsMultirange := (pgVersion >= 14)
var typeNamesClause string
if typeNames == nil {
// This should not occur; this will not return any types
typeNamesClause = "= ''"
} else {
typeNamesClause = "= ANY($1)"
}
parts := make([]string, 0, 10)
// Each of the type names provided might be found in pg_class or pg_type.
// Additionally, it may or may not include a schema portion.
parts = append(parts, `
WITH RECURSIVE
-- find the OIDs in pg_class which match one of the provided type names
selected_classes(oid,reltype) AS (
-- this query uses the namespace search path, so will match type names without a schema prefix
SELECT pg_class.oid, pg_class.reltype
FROM pg_catalog.pg_class
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace
WHERE pg_catalog.pg_table_is_visible(pg_class.oid)
AND relname `, typeNamesClause, `
UNION ALL
-- this query will only match type names which include the schema prefix
SELECT pg_class.oid, pg_class.reltype
FROM pg_class
INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid)
WHERE nspname || '.' || relname `, typeNamesClause, `
),
selected_types(oid) AS (
-- collect the OIDs from pg_types which correspond to the selected classes
SELECT reltype AS oid
FROM selected_classes
UNION ALL
-- as well as any other type names which match our criteria
SELECT pg_type.oid
FROM pg_type
LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid)
WHERE typname `, typeNamesClause, `
OR nspname || '.' || typname `, typeNamesClause, `
),
-- this builds a parent/child mapping of objects, allowing us to know
-- all the child (ie: dependent) types that a parent (type) requires
-- As can be seen, there are 3 ways this can occur (the last of which
-- is due to being a composite class, where the composite fields are children)
pc(parent, child) AS (
SELECT parent.oid, parent.typelem
FROM pg_type parent
WHERE parent.typtype = 'b' AND parent.typelem != 0
UNION ALL
SELECT parent.oid, parent.typbasetype
FROM pg_type parent
WHERE parent.typtypmod = -1 AND parent.typbasetype != 0
UNION ALL
SELECT pg_type.oid, atttypid
FROM pg_attribute
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
WHERE NOT attisdropped
AND attnum > 0
),
-- Now construct a recursive query which includes a 'depth' element.
-- This is used to ensure that the "youngest" children are registered before
-- their parents.
relationships(parent, child, depth) AS (
SELECT DISTINCT 0::OID, selected_types.oid, 0
FROM selected_types
UNION ALL
SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1
FROM selected_classes c
inner join pg_type ON (c.reltype = pg_type.oid)
inner join pg_attribute on (c.oid = pg_attribute.attrelid)
UNION ALL
SELECT pc.parent, pc.child, relationships.depth + 1
FROM pc
INNER JOIN relationships ON (pc.parent = relationships.child)
),
-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration
composite AS (
SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids
FROM pg_attribute
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
WHERE NOT attisdropped
AND attnum > 0
GROUP BY pg_type.oid
)
-- Bring together this information, showing all the information which might possibly be required
-- to complete the registration, applying filters to only show the items which relate to the selected
-- types/classes.
SELECT typname,
pg_namespace.nspname,
typtype,
typbasetype,
typelem,
pg_type.oid,`)
if supportsMultirange {
parts = append(parts, `
COALESCE(multirange.rngtypid, 0) AS rngtypid,`)
} else {
parts = append(parts, `
0 AS rngtypid,`)
}
parts = append(parts, `
COALESCE(pg_range.rngsubtype, 0) AS rngsubtype,
attnames, atttypids
FROM relationships
INNER JOIN pg_type ON (pg_type.oid = relationships.child)
LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`)
if supportsMultirange {
parts = append(parts, `
LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`)
}
parts = append(parts, `
LEFT OUTER JOIN composite USING (oid)
LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid)
WHERE NOT (typtype = 'b' AND typelem = 0)`)
parts = append(parts, `
GROUP BY typname, pg_namespace.nspname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`)
if supportsMultirange {
parts = append(parts, `
multirange.rngtypid,`)
}
parts = append(parts, `
attnames, atttypids
ORDER BY MAX(depth) desc, typname;`)
return strings.Join(parts, "")
}
type derivedTypeInfo struct {
Oid, Typbasetype, Typelem, Rngsubtype, Rngtypid uint32
TypeName, Typtype, NspName string
Attnames []string
Atttypids []uint32
}
// LoadTypes performs a single (complex) query, returning all the required
// information to register the named types, as well as any other types directly
// or indirectly required to complete the registration.
// The result of this call can be passed into RegisterTypes to complete the process.
func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) {
m := c.TypeMap()
if len(typeNames) == 0 {
return nil, fmt.Errorf("No type names were supplied.")
}
// Disregard server version errors. This will result in
// the SQL not supporting recent structures such as multirange.
serverVersion, _ := serverVersion(c)
sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
if err != nil {
return nil, fmt.Errorf("While generating load types query: %w", err)
}
defer rows.Close()
result := make([]*pgtype.Type, 0, 100)
for rows.Next() {
ti := derivedTypeInfo{}
err = rows.Scan(&ti.TypeName, &ti.NspName, &ti.Typtype, &ti.Typbasetype, &ti.Typelem, &ti.Oid, &ti.Rngtypid, &ti.Rngsubtype, &ti.Attnames, &ti.Atttypids)
if err != nil {
return nil, fmt.Errorf("While scanning type information: %w", err)
}
var type_ *pgtype.Type
switch ti.Typtype {
case "b": // array
dt, ok := m.TypeForOID(ti.Typelem)
if !ok {
return nil, fmt.Errorf("Array element OID %v not registered while loading pgtype %q", ti.Typelem, ti.TypeName)
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}
case "c": // composite
var fields []pgtype.CompositeCodecField
for i, fieldName := range ti.Attnames {
dt, ok := m.TypeForOID(ti.Atttypids[i])
if !ok {
return nil, fmt.Errorf("Unknown field for composite type %q: field %q (OID %v) is not already registered.", ti.TypeName, fieldName, ti.Atttypids[i])
}
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.CompositeCodec{Fields: fields}}
case "d": // domain
dt, ok := m.TypeForOID(ti.Typbasetype)
if !ok {
return nil, fmt.Errorf("Domain base type OID %v was not already registered, needed for %q", ti.Typbasetype, ti.TypeName)
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: dt.Codec}
case "e": // enum
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.EnumCodec{}}
case "r": // range
dt, ok := m.TypeForOID(ti.Rngsubtype)
if !ok {
return nil, fmt.Errorf("Range element OID %v was not already registered, needed for %q", ti.Rngsubtype, ti.TypeName)
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.RangeCodec{ElementType: dt}}
case "m": // multirange
dt, ok := m.TypeForOID(ti.Rngtypid)
if !ok {
return nil, fmt.Errorf("Multirange element OID %v was not already registered, needed for %q", ti.Rngtypid, ti.TypeName)
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}
default:
return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
}
// type_ cannot be nil here
m.RegisterType(type_)
if ti.NspName != "" {
nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
m.RegisterType(nspType)
result = append(result, nspType)
}
result = append(result, type_)
}
return result, nil
}
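A usage sketch (the type name is illustrative; assumes the type map's RegisterTypes method from the same API generation):

```go
// Sketch: load a user-defined type and everything it depends on, then
// register the results so the type can be used in queries.
func registerMyType(ctx context.Context, conn *pgx.Conn) error {
	types, err := conn.LoadTypes(ctx, []string{"my_composite"})
	if err != nil {
		return err
	}
	conn.TypeMap().RegisterTypes(types)
	return nil
}
```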
// serverVersion returns the postgresql server version.
func serverVersion(c *Conn) (int64, error) {
serverVersionStr := c.PgConn().ParameterStatus("server_version")
serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
// if no leading digits are found, the server is probably not PostgreSQL
if serverVersionStr == "" {
return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr)
}
version, err := strconv.ParseInt(serverVersionStr, 10, 64)
if err != nil {
return 0, fmt.Errorf("postgres version parsing failed: %w", err)
}
return version, nil
}


@ -1,40 +0,0 @@
package pgx_test
import (
"context"
"testing"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/require"
)
func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `
drop type if exists dtype_test;
drop domain if exists anotheruint64;
create domain anotheruint64 as numeric(20,0);
create type dtype_test as (
a text,
b int4,
c anotheruint64,
d anotheruint64[]
);`)
require.NoError(t, err)
defer conn.Exec(ctx, "drop type dtype_test")
defer conn.Exec(ctx, "drop domain anotheruint64")
types, err := conn.LoadTypes(ctx, []string{"dtype_test"})
require.NoError(t, err)
require.Len(t, types, 6)
require.Equal(t, types[0].Name, "public.anotheruint64")
require.Equal(t, types[1].Name, "anotheruint64")
require.Equal(t, types[2].Name, "public._anotheruint64")
require.Equal(t, types[3].Name, "_anotheruint64")
require.Equal(t, types[4].Name, "public.dtype_test")
require.Equal(t, types[5].Name, "dtype_test")
})
}

doc.go

@ -1,66 +1,58 @@
// Package pgx is a PostgreSQL database driver.
/*
pgx provides a native PostgreSQL driver and can act as a database/sql driver. The native PostgreSQL interface is similar
to the database/sql interface while providing better speed and access to PostgreSQL specific features. Use
github.com/jackc/pgx/v5/stdlib to use pgx as a database/sql compatible driver. See that package's documentation for
details.
Establishing a Connection
The primary way of establishing a connection is with [pgx.Connect]:
conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
The database connection string can be in URL or key/value format. Both PostgreSQL settings and pgx settings can be
specified here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the
connection with [ConnectConfig] to configure settings such as tracing that cannot be configured with a connection
string.
Connection Pool
[*pgx.Conn] represents a single connection to the database and is not concurrency safe. Use package
github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool.
pgx provides lower level access to PostgreSQL than the standard database/sql.
It remains as similar to the database/sql interface as possible while
providing better speed and access to PostgreSQL specific features. Import
github.com/jackc/pgx/stdlib to use pgx as a database/sql compatible driver.
Query Interface
pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and
ForEachRow that are a simpler and safer way of processing rows than manually calling defer rows.Close(), rows.Next(),
rows.Scan, and rows.Err().
pgx implements Query and Scan in the familiar database/sql style.
CollectRows can be used to collect all returned rows into a slice.
var sum int32
rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 5)
numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32])
// Send the query to the server. The returned rows MUST be closed
// before conn can be used again.
rows, err := conn.Query("select generate_series(1,$1)", 10)
if err != nil {
return err
return err
}
// numbers => [1 2 3 4 5]
ForEachRow can be used to execute a callback function for every row. This is often easier than iterating over rows
directly.
// rows.Close is called by rows.Next when all rows are read
// or an error occurs in Next or Scan. So it may optionally be
// omitted if nothing in the rows.Next loop can panic. It is
// safe to close rows multiple times.
defer rows.Close()
var sum, n int32
rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
_, err := pgx.ForEachRow(rows, []any{&n}, func() error {
sum += n
return nil
})
if err != nil {
return err
// Iterate through the result set
for rows.Next() {
var n int32
err = rows.Scan(&n)
if err != nil {
return err
}
sum += n
}
// Any errors encountered by rows.Next or rows.Scan will be returned here
if rows.Err() != nil {
return err
}
// No errors found - do something with sum
pgx also implements QueryRow in the same style as database/sql.
var name string
var weight int64
err := conn.QueryRow(context.Background(), "select name, weight from widgets where id=$1", 42).Scan(&name, &weight)
err := conn.QueryRow("select name, weight from widgets where id=$1", 42).Scan(&name, &weight)
if err != nil {
return err
}
Use Exec to execute a query that does not return a result set.
commandTag, err := conn.Exec(context.Background(), "delete from widgets where id=$1", 42)
commandTag, err := conn.Exec("delete from widgets where id=$1", 42)
if err != nil {
return err
}
@ -68,127 +60,194 @@ Use Exec to execute a query that does not return a result set.
return errors.New("No row found to delete")
}
PostgreSQL Data Types
Connection Pool
pgx uses the pgtype package to converting Go values to and from PostgreSQL values. It supports many PostgreSQL types
directly and is customizable and extendable. User defined data types such as enums, domains, and composite types may
require type registration. See that package's documentation for details.
Connection pool usage is explicit and configurable. In pgx, a connection can be
created and managed directly, or a connection pool with a configurable maximum
connections can be used. The connection pool offers an after connect hook that
allows every connection to be automatically setup before being made available in
the connection pool.
It delegates methods such as QueryRow to an automatically checked out and
released connection so you can avoid manually acquiring and releasing
connections when you do not need that level of control.
var name string
var weight int64
err := pool.QueryRow("select name, weight from widgets where id=$1", 42).Scan(&name, &weight)
if err != nil {
return err
}
Base Type Mapping
pgx maps between all common base types directly between Go and PostgreSQL. In
particular:
Go PostgreSQL
-----------------------
string varchar
text
// Integers are automatically converted to any other integer type if
// it can be done without overflow or underflow.
int8
int16 smallint
int32 int
int64 bigint
int
uint8
uint16
uint32
uint64
uint
// Floats are strict and do not automatically convert like integers.
float32 float4
float64 float8
time.Time date
timestamp
timestamptz
[]byte bytea
Null Mapping
pgx can map nulls in two ways. The first is package pgtype provides types that
have a data field and a status field. They work in a similar fashion to
database/sql. The second is to use a pointer to a pointer.
var foo pgtype.Varchar
var bar *string
err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&foo, &bar)
if err != nil {
return err
}
Array Mapping
pgx maps between int16, int32, int64, float32, float64, and string Go slices
and the equivalent PostgreSQL array type. Go slices of native types do not
support nulls, so if a PostgreSQL array that contains a null is read into a
native Go slice an error will occur. The pgtype package includes many more
array types for PostgreSQL types that do not directly map to native Go types.
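For example (an illustrative sketch), an array column without nulls can be read
directly into a Go slice:

    var numbers []int32
    err := conn.QueryRow("select array[1, 2, 3]::int4[]").Scan(&numbers)
    if err != nil {
        return err
    }
    // numbers => [1 2 3]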
JSON and JSONB Mapping
pgx includes built-in support to marshal and unmarshal between Go types and
the PostgreSQL JSON and JSONB.
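For example (sketch), a struct can be scanned directly from a json value:

    type widget struct {
        Name   string `json:"name"`
        Weight int    `json:"weight"`
    }

    var w widget
    err := conn.QueryRow(`select '{"name":"sprocket","weight":3}'::json`).Scan(&w)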
Inet and CIDR Mapping
pgx maps between net.IPNet and the inet and cidr PostgreSQL types. In
addition, as a convenience pgx will encode from a net.IP; it will assume a /32
netmask for IPv4 and a /128 for IPv6.
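For example (sketch):

    var ipnet net.IPNet
    err := conn.QueryRow("select '192.168.0.0/24'::cidr").Scan(&ipnet)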
Custom Type Support
pgx includes support for the common data types like integers, floats, strings,
dates, and times that have direct mappings between Go and SQL. In addition,
pgx uses the github.com/jackc/pgx/pgtype library to support more types. See
documention for that library for instructions on how to implement custom
types.
See example_custom_type_test.go for an example of a custom type for the
PostgreSQL point type.
pgx also includes support for custom types implementing the database/sql.Scanner
and database/sql/driver.Valuer interfaces.
If pgx cannot natively encode a type and that type is a renamed type (e.g.
type MyTime time.Time) pgx will attempt to encode the underlying type. While
this is usually desired behavior it can produce surprising behavior if one of
the underlying type and the renamed type implements database/sql interfaces and
the other implements pgx interfaces. It is recommended that this situation be
avoided by implementing pgx interfaces on the renamed type.
Raw Bytes Mapping
[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified
to PostgreSQL.
Transactions
Transactions are started by calling Begin.
Transactions are started by calling Begin or BeginEx. The BeginEx variant
can create a transaction with a specified isolation level.
tx, err := conn.Begin(context.Background())
tx, err := conn.Begin()
if err != nil {
return err
}
// Rollback is safe to call even if the tx is already closed, so if
// the tx commits successfully, this is a no-op
defer tx.Rollback(context.Background())
defer tx.Rollback()
_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
_, err = tx.Exec("insert into foo(id) values (1)")
if err != nil {
return err
}
err = tx.Commit(context.Background())
err = tx.Commit()
if err != nil {
return err
}
The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions.
These are internally implemented with savepoints.
Use BeginTx to control the transaction mode. BeginTx also can be used to ensure a new transaction is created instead of
a pseudo nested transaction.
BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the
transaction depending on the return value of the function. These can be simpler and less error prone to use.
err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
return err
})
if err != nil {
return err
}
Prepared Statements
Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx
includes an automatic statement cache by default. Queries run through the normal Query, QueryRow, and Exec functions are
automatically prepared on first execution and the prepared statement is reused on subsequent executions. See ParseConfig
for information on how to customize or disable the statement cache.
Copy Protocol
Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a
CopyFromSource interface. If the data is already in a [][]any use CopyFromRows to wrap it in a CopyFromSource interface.
Or implement CopyFromSource to avoid buffering the entire data set in memory.
Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL
copy protocol. CopyFrom accepts a CopyFromSource interface. If the data is already
in a [][]interface{} use CopyFromRows to wrap it in a CopyFromSource interface. Or
implement CopyFromSource to avoid buffering the entire data set in memory.
rows := [][]any{
rows := [][]interface{}{
{"John", "Smith", int32(36)},
{"Jane", "Doe", int32(29)},
}
copyCount, err := conn.CopyFrom(
context.Background(),
pgx.Identifier{"people"},
[]string{"first_name", "last_name", "age"},
pgx.CopyFromRows(rows),
)
When you already have a typed array using CopyFromSlice can be more convenient.
rows := []User{
{"John", "Smith", 36},
{"Jane", "Doe", 29},
}
copyCount, err := conn.CopyFrom(
context.Background(),
pgx.Identifier{"people"},
[]string{"first_name", "last_name", "age"},
pgx.CopyFromSlice(len(rows), func(i int) ([]any, error) {
return []any{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil
}),
)
CopyFrom can be faster than an insert with as few as 5 rows.
Listen and Notify
pgx can listen to the PostgreSQL notification system with the `Conn.WaitForNotification` method. It blocks until a
notification is received or the context is canceled.
pgx can listen to the PostgreSQL notification system with the
WaitForNotification function. It takes a maximum time to wait for a
notification.
_, err := conn.Exec(context.Background(), "listen channelname")
err := conn.Listen("channelname")
if err != nil {
return err
return nil
}
notification, err := conn.WaitForNotification(context.Background())
if err != nil {
return err
if notification, err := conn.WaitForNotification(time.Second); err == nil {
// do something with notification
}
// do something with notification
TLS
Tracing and Logging
The pgx ConnConfig struct has a TLSConfig field. If this field is
nil, then TLS will be disabled. If it is present, then it will be used to
configure the TLS connection. This allows total configuration of the TLS
connection.
pgx supports tracing by setting ConnConfig.Tracer. To combine several tracers you can use the multitracer.Tracer.
pgx has never explicitly supported Postgres < 9.6's `ssl_renegotiation` option.
As of v3.3.0, it doesn't send `ssl_renegotiation: 0` either to support Redshift
(https://github.com/jackc/pgx/pull/476). If you need TLS Renegotiation,
consider supplying `ConnConfig.TLSConfig` with a non-zero `Renegotiation`
value and if it's not the default on your server, set `ssl_renegotiation`
via `ConnConfig.RuntimeParams`.
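For example (sketch), a minimal verified TLS configuration (the server name is
illustrative and must match the server certificate):

    config.TLSConfig = &tls.Config{
        ServerName: "db.example.com",
    }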
In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer.
Logging
For debug tracing of the actual PostgreSQL wire protocol messages see github.com/jackc/pgx/v5/pgproto3.
Lower Level PostgreSQL Functionality
github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn is
implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer.
PgBouncer
By default pgx automatically uses prepared statements. Prepared statements are incompatible with PgBouncer. This can be
disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode.
pgx defines a simple logger interface. Connections optionally accept a logger
that satisfies this interface. Set LogLevel to control logging verbosity.
Adapters for github.com/inconshreveable/log15, github.com/sirupsen/logrus, and
the testing log are provided in the log directory.
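For example (sketch), attaching a log15 logger at info verbosity:

    config.Logger = log15adapter.NewLogger(log.New("module", "pgx"))
    config.LogLevel = pgx.LogLevelInfo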
*/
package pgx

example_custom_type_test.go

@ -0,0 +1,105 @@
package pgx_test
import (
"fmt"
"regexp"
"strconv"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgtype"
"github.com/pkg/errors"
)
var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`)
// Point represents a point that may be null.
type Point struct {
X, Y float64 // Coordinates of point
Status pgtype.Status
}
func (dst *Point) Set(src interface{}) error {
return errors.Errorf("cannot convert %v to Point", src)
}
func (dst *Point) Get() interface{} {
switch dst.Status {
case pgtype.Present:
return dst
case pgtype.Null:
return nil
default:
return dst.Status
}
}
func (src *Point) AssignTo(dst interface{}) error {
return errors.Errorf("cannot assign %v to %T", src, dst)
}
func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
if src == nil {
*dst = Point{Status: pgtype.Null}
return nil
}
s := string(src)
match := pointRegexp.FindStringSubmatch(s)
if match == nil {
return errors.Errorf("Received invalid point: %v", s)
}
x, err := strconv.ParseFloat(match[1], 64)
if err != nil {
return errors.Errorf("Received invalid point: %v", s)
}
y, err := strconv.ParseFloat(match[2], 64)
if err != nil {
return errors.Errorf("Received invalid point: %v", s)
}
*dst = Point{X: x, Y: y, Status: pgtype.Present}
return nil
}
func (src *Point) String() string {
if src.Status == pgtype.Null {
return "null point"
}
return fmt.Sprintf("%.1f, %.1f", src.X, src.Y)
}
func Example_CustomType() {
conn, err := pgx.Connect(*defaultConnConfig)
if err != nil {
fmt.Printf("Unable to establish connection: %v", err)
return
}
// Override registered handler for point
conn.ConnInfo.RegisterDataType(pgtype.DataType{
Value: &Point{},
Name: "point",
OID: 600,
})
p := &Point{}
err = conn.QueryRow("select null::point").Scan(p)
if err != nil {
fmt.Println(err)
return
}
fmt.Println(p)
err = conn.QueryRow("select point(1.5,2.5)").Scan(p)
if err != nil {
fmt.Println(err)
return
}
fmt.Println(p)
// Output:
// null point
// 1.5, 2.5
}


@ -1,15 +1,13 @@
package pgtype_test
package pgx_test
import (
"context"
"fmt"
"os"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx"
)
func Example_json() {
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
func Example_JSON() {
conn, err := pgx.Connect(*defaultConnConfig)
if err != nil {
fmt.Printf("Unable to establish connection: %v", err)
return
@ -27,7 +25,7 @@ func Example_json() {
var output person
err = conn.QueryRow(context.Background(), "select $1::json", input).Scan(&output)
err = conn.QueryRow("select $1::json", input).Scan(&output)
if err != nil {
fmt.Println(err)
return


@ -3,5 +3,5 @@
* chat is a command line chat program using listen/notify.
* todo is a command line todo list that demonstrates basic CRUD actions.
* url_shortener contains a simple example of using pgx in a web context.
* [Tern](https://github.com/jackc/tern) is a migration tool that uses pgx.
* [Tern](https://github.com/jackc/tern) is a migration tool that uses pgx (uses v1 of pgx).
* [The Pithy Reader](https://github.com/jackc/tpr) is a RSS aggregator that uses pgx.


@ -8,7 +8,12 @@ between them.
## Connection configuration
The database connection is configured via DATABASE_URL and standard PostgreSQL environment variables (PGHOST, PGUSER, etc.)
The database connection is configured via the standard PostgreSQL environment variables.
* PGHOST - defaults to localhost
* PGUSER - defaults to current OS user
* PGPASSWORD - defaults to empty string
* PGDATABASE - defaults to user name
You can either export them then run chat:


@ -6,14 +6,19 @@ import (
"fmt"
"os"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx"
)
var pool *pgxpool.Pool
var pool *pgx.ConnPool
func main() {
var err error
pool, err = pgxpool.New(context.Background(), os.Getenv("DATABASE_URL"))
config, err := pgx.ParseEnvLibpq()
if err != nil {
fmt.Fprintln(os.Stderr, "Unable to parse environment:", err)
os.Exit(1)
}
pool, err = pgx.NewConnPool(pgx.ConnPoolConfig{ConnConfig: config})
if err != nil {
fmt.Fprintln(os.Stderr, "Unable to connect to database:", err)
os.Exit(1)
@ -35,7 +40,7 @@ Type "exit" to quit.`)
os.Exit(0)
}
_, err = pool.Exec(context.Background(), "select pg_notify('chat', $1)", msg)
_, err = pool.Exec("select pg_notify('chat', $1)", msg)
if err != nil {
fmt.Fprintln(os.Stderr, "Error sending notification:", err)
os.Exit(1)
@ -48,21 +53,17 @@ Type "exit" to quit.`)
}
func listen() {
conn, err := pool.Acquire(context.Background())
conn, err := pool.Acquire()
if err != nil {
fmt.Fprintln(os.Stderr, "Error acquiring connection:", err)
os.Exit(1)
}
defer conn.Release()
defer pool.Release(conn)
_, err = conn.Exec(context.Background(), "listen chat")
if err != nil {
fmt.Fprintln(os.Stderr, "Error listening to chat channel:", err)
os.Exit(1)
}
conn.Listen("chat")
for {
notification, err := conn.Conn().WaitForNotification(context.Background())
notification, err := conn.WaitForNotification(context.Background())
if err != nil {
fmt.Fprintln(os.Stderr, "Error waiting for notification:", err)
os.Exit(1)


@ -19,7 +19,12 @@ Build todo:
## Connection configuration
The database connection is configured via DATABASE_URL and standard PostgreSQL environment variables (PGHOST, PGUSER, etc.)
The database connection is configured via environment variables.
* PGHOST - defaults to localhost
* PGUSER - defaults to current OS user
* PGPASSWORD - defaults to empty string
* PGDATABASE - defaults to user name
You can either export them then run todo:
@ -40,7 +45,7 @@ Or you can prefix the todo execution with the environment variables:
## Update a task
./todo update 1 'Learn more go'
./todo add 1 'Learn more go'
## Delete a task


@ -1,19 +1,22 @@
package main
import (
"context"
"fmt"
"github.com/jackc/pgx"
"os"
"strconv"
"github.com/jackc/pgx/v5"
)
var conn *pgx.Conn
func main() {
var err error
conn, err = pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
config, err := pgx.ParseEnvLibpq()
if err != nil {
fmt.Fprintln(os.Stderr, "Unable to parse environment:", err)
os.Exit(1)
}
conn, err = pgx.Connect(config)
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to connection to database: %v\n", err)
os.Exit(1)
@ -71,7 +74,7 @@ func main() {
}
func listTasks() error {
rows, _ := conn.Query(context.Background(), "select * from tasks")
rows, _ := conn.Query("select * from tasks")
for rows.Next() {
var id int32
@ -87,17 +90,17 @@ func listTasks() error {
}
func addTask(description string) error {
_, err := conn.Exec(context.Background(), "insert into tasks(description) values($1)", description)
_, err := conn.Exec("insert into tasks(description) values($1)", description)
return err
}
func updateTask(itemNum int32, description string) error {
_, err := conn.Exec(context.Background(), "update tasks set description=$1 where id=$2", description, itemNum)
_, err := conn.Exec("update tasks set description=$1 where id=$2", description, itemNum)
return err
}
func removeTask(itemNum int32) error {
_, err := conn.Exec(context.Background(), "delete from tasks where id=$1", itemNum)
_, err := conn.Exec("delete from tasks where id=$1", itemNum)
return err
}


@ -6,28 +6,20 @@ This is a sample REST URL shortener service implemented using pgx as the connect
Create a PostgreSQL database and run structure.sql against it to create the necessary data schema.
Configure the database connection with `DATABASE_URL` or standard PostgreSQL (`PG*`) environment variables or
Edit connectionOptions in main.go with the location and credentials for your database.
Run main.go:
```
go run main.go
```
go run main.go
## Create or Update a Shortened URL
```
curl -X PUT -d 'http://www.google.com' http://localhost:8080/google
```
curl -X PUT -d 'http://www.google.com' http://localhost:8080/google
## Get a Shortened URL
```
curl http://localhost:8080/google
```
curl http://localhost:8080/google
## Delete a Shortened URL
```
curl -X DELETE http://localhost:8080/google
```
curl -X DELETE http://localhost:8080/google


@ -1,21 +1,43 @@
package main
import (
"context"
"io"
"log"
"io/ioutil"
"net/http"
"os"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx"
"github.com/jackc/pgx/log/log15adapter"
log "gopkg.in/inconshreveable/log15.v2"
)
var db *pgxpool.Pool
var pool *pgx.ConnPool
// afterConnect creates the prepared statements that this application uses
func afterConnect(conn *pgx.Conn) (err error) {
_, err = conn.Prepare("getUrl", `
select url from shortened_urls where id=$1
`)
if err != nil {
return
}
_, err = conn.Prepare("deleteUrl", `
delete from shortened_urls where id=$1
`)
if err != nil {
return
}
_, err = conn.Prepare("putUrl", `
insert into shortened_urls(id, url) values ($1, $2)
on conflict (id) do update set url=excluded.url
`)
return
}
func getUrlHandler(w http.ResponseWriter, req *http.Request) {
var url string
err := db.QueryRow(context.Background(), "select url from shortened_urls where id=$1", req.URL.Path).Scan(&url)
err := pool.QueryRow("getUrl", req.URL.Path).Scan(&url)
switch err {
case nil:
http.Redirect(w, req, url, http.StatusSeeOther)
@ -29,15 +51,14 @@ func getUrlHandler(w http.ResponseWriter, req *http.Request) {
func putUrlHandler(w http.ResponseWriter, req *http.Request) {
id := req.URL.Path
var url string
if body, err := io.ReadAll(req.Body); err == nil {
if body, err := ioutil.ReadAll(req.Body); err == nil {
url = string(body)
} else {
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
if _, err := db.Exec(context.Background(), `insert into shortened_urls(id, url) values ($1, $2)
on conflict (id) do update set url=excluded.url`, id, url); err == nil {
if _, err := pool.Exec("putUrl", id, url); err == nil {
w.WriteHeader(http.StatusOK)
} else {
http.Error(w, "Internal server error", http.StatusInternalServerError)
@ -45,7 +66,7 @@ func putUrlHandler(w http.ResponseWriter, req *http.Request) {
}
func deleteUrlHandler(w http.ResponseWriter, req *http.Request) {
if _, err := db.Exec(context.Background(), "delete from shortened_urls where id=$1", req.URL.Path); err == nil {
if _, err := pool.Exec("deleteUrl", req.URL.Path); err == nil {
w.WriteHeader(http.StatusOK)
} else {
http.Error(w, "Internal server error", http.StatusInternalServerError)
@ -70,21 +91,32 @@ func urlHandler(w http.ResponseWriter, req *http.Request) {
}
func main() {
poolConfig, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL"))
if err != nil {
log.Fatalln("Unable to parse DATABASE_URL:", err)
}
logger := log15adapter.NewLogger(log.New("module", "pgx"))
db, err = pgxpool.NewWithConfig(context.Background(), poolConfig)
var err error
connPoolConfig := pgx.ConnPoolConfig{
ConnConfig: pgx.ConnConfig{
Host: "127.0.0.1",
User: "jack",
Password: "jack",
Database: "url_shortener",
Logger: logger,
},
MaxConnections: 5,
AfterConnect: afterConnect,
}
pool, err = pgx.NewConnPool(connPoolConfig)
if err != nil {
log.Fatalln("Unable to create connection pool:", err)
log.Crit("Unable to create connection pool", "error", err)
os.Exit(1)
}
http.HandleFunc("/", urlHandler)
log.Println("Starting URL shortener on localhost:8080")
log.Info("Starting URL shortener on localhost:8080")
err = http.ListenAndServe("localhost:8080", nil)
if err != nil {
log.Fatalln("Unable to start web server:", err)
log.Crit("Unable to start web server", "error", err)
os.Exit(1)
}
}


@ -1,146 +0,0 @@
package pgx
import (
"fmt"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
)
// ExtendedQueryBuilder is used to choose the parameter formats, to format the parameters and to choose the result
// formats for an extended query.
type ExtendedQueryBuilder struct {
ParamValues [][]byte
paramValueBytes []byte
ParamFormats []int16
ResultFormats []int16
}
// Build sets ParamValues, ParamFormats, and ResultFormats for use with *PgConn.ExecParams or *PgConn.ExecPrepared. If
// sd is nil then QueryExecModeExec behavior will be used.
func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
eqb.reset()
if sd == nil {
for i := range args {
err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i])
if err != nil {
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
return err
}
}
return nil
}
if len(sd.ParamOIDs) != len(args) {
return fmt.Errorf("mismatched param and argument count")
}
for i := range args {
err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i])
if err != nil {
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
return err
}
}
for i := range sd.Fields {
eqb.appendResultFormat(m.FormatCodeForOID(sd.Fields[i].DataTypeOID))
}
return nil
}
// appendParam appends a parameter to the query. format may be -1 to automatically choose the format. If arg is nil it
// must be an untyped nil.
func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error {
if format == -1 {
preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg)
preferredErr := eqb.appendParam(m, oid, preferredFormat, arg)
if preferredErr == nil {
return nil
}
var otherFormat int16
if preferredFormat == TextFormatCode {
otherFormat = BinaryFormatCode
} else {
otherFormat = TextFormatCode
}
otherErr := eqb.appendParam(m, oid, otherFormat, arg)
if otherErr == nil {
return nil
}
return preferredErr // return the error from the preferred format
}
v, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
if err != nil {
return err
}
eqb.ParamFormats = append(eqb.ParamFormats, format)
eqb.ParamValues = append(eqb.ParamValues, v)
return nil
}
// appendResultFormat appends a result format to the query.
func (eqb *ExtendedQueryBuilder) appendResultFormat(format int16) {
eqb.ResultFormats = append(eqb.ResultFormats, format)
}
// reset readies eqb to build another query.
func (eqb *ExtendedQueryBuilder) reset() {
eqb.ParamValues = eqb.ParamValues[0:0]
eqb.paramValueBytes = eqb.paramValueBytes[0:0]
eqb.ParamFormats = eqb.ParamFormats[0:0]
eqb.ResultFormats = eqb.ResultFormats[0:0]
if cap(eqb.ParamValues) > 64 {
eqb.ParamValues = make([][]byte, 0, 64)
}
if cap(eqb.paramValueBytes) > 256 {
eqb.paramValueBytes = make([]byte, 0, 256)
}
if cap(eqb.ParamFormats) > 64 {
eqb.ParamFormats = make([]int16, 0, 64)
}
if cap(eqb.ResultFormats) > 64 {
eqb.ResultFormats = make([]int16, 0, 64)
}
}
func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
if eqb.paramValueBytes == nil {
eqb.paramValueBytes = make([]byte, 0, 128)
}
pos := len(eqb.paramValueBytes)
buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}
eqb.paramValueBytes = buf
return eqb.paramValueBytes[pos:], nil
}
// chooseParameterFormatCode determines the correct format code for an
// argument to a prepared statement. It defaults to TextFormatCode if no
// determination can be made.
func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 {
switch arg.(type) {
case string, *string:
return TextFormatCode
}
return m.FormatCodeForOID(oid)
}
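For orientation, this is roughly how Build's outputs plug into pgconn's extended protocol, per the doc comment above (a hedged sketch, not pgx's internal call path; the statement name "ps1" and the surrounding function are illustrative):

```go
package main

import (
	"context"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgtype"
)

func execPrepared(ctx context.Context, pgConn *pgconn.PgConn) error {
	// Prepare once, then execute with pgx-encoded parameters.
	sd, err := pgConn.Prepare(ctx, "ps1", "select $1::int8", nil)
	if err != nil {
		return err
	}

	var eqb pgx.ExtendedQueryBuilder
	// Build encodes args against the statement's parameter OIDs and
	// chooses per-parameter and per-column format codes.
	if err := eqb.Build(pgtype.NewMap(), sd, []any{int64(42)}); err != nil {
		return err
	}

	rr := pgConn.ExecPrepared(ctx, "ps1", eqb.ParamValues, eqb.ParamFormats, eqb.ResultFormats)
	_, err = rr.Close()
	return err
}
```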

fastpath.go (new file, 119 lines)

@ -0,0 +1,119 @@
package pgx
import (
"encoding/binary"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
func newFastpath(cn *Conn) *fastpath {
return &fastpath{cn: cn, fns: make(map[string]pgtype.OID)}
}
type fastpath struct {
cn *Conn
fns map[string]pgtype.OID
}
func (f *fastpath) functionOID(name string) pgtype.OID {
return f.fns[name]
}
func (f *fastpath) addFunction(name string, oid pgtype.OID) {
f.fns[name] = oid
}
func (f *fastpath) addFunctions(rows *Rows) error {
for rows.Next() {
var name string
var oid pgtype.OID
if err := rows.Scan(&name, &oid); err != nil {
return err
}
f.addFunction(name, oid)
}
return rows.Err()
}
type fpArg []byte
func fpIntArg(n int32) fpArg {
res := make([]byte, 4)
binary.BigEndian.PutUint32(res, uint32(n))
return res
}
func fpInt64Arg(n int64) fpArg {
res := make([]byte, 8)
binary.BigEndian.PutUint64(res, uint64(n))
return res
}
func (f *fastpath) Call(oid pgtype.OID, args []fpArg) (res []byte, err error) {
if err := f.cn.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
buf := f.cn.wbuf
buf = append(buf, 'F') // function call
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = pgio.AppendInt32(buf, int32(oid)) // function object id
buf = pgio.AppendInt16(buf, 1) // # of argument format codes
buf = pgio.AppendInt16(buf, 1) // format code: binary
buf = pgio.AppendInt16(buf, int16(len(args))) // # of arguments
for _, arg := range args {
buf = pgio.AppendInt32(buf, int32(len(arg))) // length of argument
buf = append(buf, arg...) // argument value
}
buf = pgio.AppendInt16(buf, 1) // response format code (binary)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
if _, err := f.cn.conn.Write(buf); err != nil {
return nil, err
}
f.cn.pendingReadyForQueryCount++
for {
msg, err := f.cn.rxMsg()
if err != nil {
return nil, err
}
switch msg := msg.(type) {
case *pgproto3.FunctionCallResponse:
res = make([]byte, len(msg.Result))
copy(res, msg.Result)
case *pgproto3.ReadyForQuery:
f.cn.rxReadyForQuery(msg)
// done
return res, err
default:
if err := f.cn.processContextFreeMsg(msg); err != nil {
return nil, err
}
}
}
}
func (f *fastpath) CallFn(fn string, args []fpArg) ([]byte, error) {
return f.Call(f.functionOID(fn), args)
}
func fpInt32(data []byte, err error) (int32, error) {
if err != nil {
return 0, err
}
n := int32(binary.BigEndian.Uint32(data))
return n, nil
}
func fpInt64(data []byte, err error) (int64, error) {
if err != nil {
return 0, err
}
return int64(binary.BigEndian.Uint64(data)), nil
}
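Together these helpers let code inside the package call a server-side function by name; a hypothetical one-liner (fastpath is unexported, so this is only meaningful within pgx itself):

```go
// createLO invokes lo_create(0) over the fast-path protocol and
// decodes the int32 result from the binary FunctionCallResponse.
func createLO(f *fastpath) (int32, error) {
	return fpInt32(f.CallFn("lo_create", []fpArg{fpIntArg(0)}))
}
```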

go.mod (21 lines)

@ -1,21 +0,0 @@
module github.com/jackc/pgx/v5
go 1.23.0
require (
github.com/jackc/pgpassfile v1.0.0
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761
github.com/jackc/puddle/v2 v2.2.2
github.com/stretchr/testify v1.8.1
golang.org/x/crypto v0.37.0
golang.org/x/sync v0.13.0
golang.org/x/text v0.24.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/kr/pretty v0.3.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

go.sum (45 lines)

@ -1,45 +0,0 @@
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

go_stdlib.go (new file, 61 lines)

@ -0,0 +1,61 @@
package pgx
import (
"database/sql/driver"
"reflect"
)
// This file contains code copied from the Go standard library due to the
// required function not being public.
// Copyright (c) 2009 The Go Authors. All rights reserved.
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// From database/sql/convert.go
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
// callValuerValue returns vr.Value(), with one exception:
// If vr.Value is an auto-generated method on a pointer type and the
// pointer is nil, it would panic at runtime in the panicwrap
// method. Treat it like nil instead.
// Issue 8415.
//
// This is so people can implement driver.Value on value types and
// still use nil pointers to those types to mean nil/NULL, just like
// string/*string.
//
// This function is mirrored in the database/sql/driver package.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
rv.IsNil() &&
rv.Type().Elem().Implements(valuerReflectType) {
return nil, nil
}
return vr.Value()
}
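The case it guards against looks like this (orderID is a hypothetical type for illustration; driver is the database/sql/driver import already present in this file):

```go
type orderID struct{ n int64 }

// Value is declared on the value type, so *orderID only has an
// auto-generated pointer wrapper for it.
func (id orderID) Value() (driver.Value, error) { return id.n, nil }

func example() {
	var p *orderID // nil pointer intended to mean SQL NULL
	// p.Value() would panic inside the generated wrapper; instead,
	// callValuerValue detects the nil pointer and returns (nil, nil).
	v, err := callValuerValue(p)
	_, _ = v, err
}
```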


@ -1,45 +1,21 @@
package pgx_test
import (
"context"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxtest"
"github.com/stretchr/testify/require"
"github.com/jackc/pgx"
)
var defaultConnTestRunner pgxtest.ConnTestRunner
func init() {
defaultConnTestRunner = pgxtest.DefaultConnTestRunner()
defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
return config
}
}
func mustConnectString(t testing.TB, connString string) *pgx.Conn {
conn, err := pgx.Connect(context.Background(), connString)
func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn {
conn, err := pgx.Connect(config)
if err != nil {
t.Fatalf("Unable to establish connection: %v", err)
}
return conn
}
func mustParseConfig(t testing.TB, connString string) *pgx.ConnConfig {
config, err := pgx.ParseConfig(connString)
require.Nil(t, err)
return config
}
func mustConnect(t testing.TB, config *pgx.ConnConfig) *pgx.Conn {
conn, err := pgx.ConnectConfig(context.Background(), config)
func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.ReplicationConn {
conn, err := pgx.ReplicationConnect(config)
if err != nil {
t.Fatalf("Unable to establish connection: %v", err)
}
@ -47,25 +23,32 @@ func mustConnect(t testing.TB, config *pgx.ConnConfig) *pgx.Conn {
}
func closeConn(t testing.TB, conn *pgx.Conn) {
err := conn.Close(context.Background())
err := conn.Close()
if err != nil {
t.Fatalf("conn.Close unexpectedly failed: %v", err)
}
}
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...any) (commandTag pgconn.CommandTag) {
func closeReplicationConn(t testing.TB, conn *pgx.ReplicationConn) {
err := conn.Close()
if err != nil {
t.Fatalf("conn.Close unexpectedly failed: %v", err)
}
}
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgx.CommandTag) {
var err error
if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil {
if commandTag, err = conn.Exec(sql, arguments...); err != nil {
t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)
}
return
}
// Do a simple query to ensure the connection is still usable
func ensureConnValid(t testing.TB, conn *pgx.Conn) {
func ensureConnValid(t *testing.T, conn *pgx.Conn) {
var sum, rowCount int32
rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
rows, err := conn.Query("select generate_series(1,$1)", 10)
if err != nil {
t.Fatalf("conn.Query failed: %v", err)
}
@ -79,7 +62,7 @@ func ensureConnValid(t testing.TB, conn *pgx.Conn) {
}
if rows.Err() != nil {
t.Fatalf("conn.Query failed: %v", rows.Err())
t.Fatalf("conn.Query failed: %v", err)
}
if rowCount != 10 {
@ -89,50 +72,3 @@ func ensureConnValid(t testing.TB, conn *pgx.Conn) {
t.Error("Wrong values returned")
}
}
func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) {
if !assert.NotNil(t, expected) {
return
}
if !assert.NotNil(t, actual) {
return
}
assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName)
assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName)
assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName)
assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName)
assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
// Can't test function equality, so just test that they are set or not.
assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
if expected.TLSConfig != nil {
assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
}
}
if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
for i := range expected.Fallbacks {
assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
if expected.Fallbacks[i].TLSConfig != nil {
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
}
}
}
}
}


@ -1,70 +0,0 @@
// Package iobufpool implements a global segregated-fit pool of buffers for IO.
//
// It uses *[]byte instead of []byte to avoid the sync.Pool allocation with Put. Unfortunately, using a pointer to avoid
// an allocation is purposely not documented. https://github.com/golang/go/issues/16323
package iobufpool
import "sync"
const minPoolExpOf2 = 8
var pools [18]*sync.Pool
func init() {
for i := range pools {
bufLen := 1 << (minPoolExpOf2 + i)
pools[i] = &sync.Pool{
New: func() any {
buf := make([]byte, bufLen)
return &buf
},
}
}
}
// Get gets a []byte of len size with cap <= size*2.
func Get(size int) *[]byte {
i := getPoolIdx(size)
if i >= len(pools) {
buf := make([]byte, size)
return &buf
}
ptrBuf := (pools[i].Get().(*[]byte))
*ptrBuf = (*ptrBuf)[:size]
return ptrBuf
}
func getPoolIdx(size int) int {
size--
size >>= minPoolExpOf2
i := 0
for size > 0 {
size >>= 1
i++
}
return i
}
// Put returns buf to the pool.
func Put(buf *[]byte) {
i := putPoolIdx(cap(*buf))
if i < 0 {
return
}
pools[i].Put(buf)
}
func putPoolIdx(size int) int {
minPoolSize := 1 << minPoolExpOf2
for i := range pools {
if size == minPoolSize<<i {
return i
}
}
return -1
}
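A typical round trip through the pool (a sketch; iobufpool is internal to the module, so this only compiles inside pgx):

```go
// Get returns a *[]byte with len == 1000 whose backing array comes
// from the 1024-byte pool, the smallest power of two that fits.
buf := iobufpool.Get(1000)

// ... use (*buf) as a scratch buffer for a read or write ...

// Put recycles the buffer. Only exact power-of-two capacities are
// pooled; anything else is dropped and left to the garbage collector.
iobufpool.Put(buf)
```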


@ -1,36 +0,0 @@
package iobufpool
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestPoolIdx(t *testing.T) {
tests := []struct {
size int
expected int
}{
{size: 0, expected: 0},
{size: 1, expected: 0},
{size: 255, expected: 0},
{size: 256, expected: 0},
{size: 257, expected: 1},
{size: 511, expected: 1},
{size: 512, expected: 1},
{size: 513, expected: 2},
{size: 1023, expected: 2},
{size: 1024, expected: 2},
{size: 1025, expected: 3},
{size: 2047, expected: 3},
{size: 2048, expected: 3},
{size: 2049, expected: 4},
{size: 8388607, expected: 15},
{size: 8388608, expected: 15},
{size: 8388609, expected: 16},
}
for _, tt := range tests {
idx := getPoolIdx(tt.size)
assert.Equalf(t, tt.expected, idx, "size: %d", tt.size)
}
}


@ -1,85 +0,0 @@
package iobufpool_test
import (
"testing"
"github.com/jackc/pgx/v5/internal/iobufpool"
"github.com/stretchr/testify/assert"
)
func TestGetCap(t *testing.T) {
tests := []struct {
requestedLen int
expectedCap int
}{
{requestedLen: 0, expectedCap: 256},
{requestedLen: 128, expectedCap: 256},
{requestedLen: 255, expectedCap: 256},
{requestedLen: 256, expectedCap: 256},
{requestedLen: 257, expectedCap: 512},
{requestedLen: 511, expectedCap: 512},
{requestedLen: 512, expectedCap: 512},
{requestedLen: 513, expectedCap: 1024},
{requestedLen: 1023, expectedCap: 1024},
{requestedLen: 1024, expectedCap: 1024},
{requestedLen: 33554431, expectedCap: 33554432},
{requestedLen: 33554432, expectedCap: 33554432},
// Above 32 MiB skip the pool and allocate exactly the requested size.
{requestedLen: 33554433, expectedCap: 33554433},
}
for _, tt := range tests {
buf := iobufpool.Get(tt.requestedLen)
assert.Equalf(t, tt.requestedLen, len(*buf), "bad len for requestedLen: %d", len(*buf), tt.requestedLen)
assert.Equalf(t, tt.expectedCap, cap(*buf), "bad cap for requestedLen: %d", tt.requestedLen)
}
}
func TestPutHandlesWrongSizedBuffers(t *testing.T) {
for _, putBufSize := range []int{0, 1, 128, 250, 256, 257, 1023, 1024, 1025, 1 << 28} {
putBuf := make([]byte, putBufSize)
iobufpool.Put(&putBuf)
tests := []struct {
requestedLen int
expectedCap int
}{
{requestedLen: 0, expectedCap: 256},
{requestedLen: 128, expectedCap: 256},
{requestedLen: 255, expectedCap: 256},
{requestedLen: 256, expectedCap: 256},
{requestedLen: 257, expectedCap: 512},
{requestedLen: 511, expectedCap: 512},
{requestedLen: 512, expectedCap: 512},
{requestedLen: 513, expectedCap: 1024},
{requestedLen: 1023, expectedCap: 1024},
{requestedLen: 1024, expectedCap: 1024},
{requestedLen: 33554431, expectedCap: 33554432},
{requestedLen: 33554432, expectedCap: 33554432},
// Above 32 MiB skip the pool and allocate exactly the requested size.
{requestedLen: 33554433, expectedCap: 33554433},
}
for _, tt := range tests {
getBuf := iobufpool.Get(tt.requestedLen)
assert.Equalf(t, tt.requestedLen, len(*getBuf), "len(putBuf): %d, requestedLen: %d", len(putBuf), tt.requestedLen)
assert.Equalf(t, tt.expectedCap, cap(*getBuf), "cap(putBuf): %d, requestedLen: %d", cap(putBuf), tt.requestedLen)
}
}
}
func TestPutGetBufferReuse(t *testing.T) {
// There is no way to guarantee a buffer will be reused. It should be, but a GC between the Put and the Get will cause
// it not to be. So try many times.
for i := 0; i < 100000; i++ {
buf := iobufpool.Get(4)
(*buf)[0] = 1
iobufpool.Put(buf)
buf = iobufpool.Get(4)
if (*buf)[0] == 1 {
return
}
}
t.Error("buffer was never reused")
}


@ -1,6 +0,0 @@
# pgio
Package pgio is a low-level toolkit for building messages in the PostgreSQL wire protocol.
pgio provides functions for appending integers to a []byte while doing byte
order conversion.
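The append-and-backfill pattern it supports is the one visible in fastpath.go above (a sketch; in v5 the package lives under internal/, so the import path varies by major version):

```go
buf := make([]byte, 0, 64)
buf = append(buf, 'Q')          // message type byte: simple Query
lenPos := len(buf)              // remember where the length word goes
buf = pgio.AppendInt32(buf, -1) // length placeholder, backfilled below
buf = append(buf, "select 1"...)
buf = append(buf, 0) // null-terminate the query string
pgio.SetInt32(buf[lenPos:], int32(len(buf[lenPos:]))) // backfill length
```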


@ -1,136 +0,0 @@
// Package pgmock provides the ability to mock a PostgreSQL server.
package pgmock
import (
"fmt"
"io"
"reflect"
"github.com/jackc/pgx/v5/pgproto3"
)
type Step interface {
Step(*pgproto3.Backend) error
}
type Script struct {
Steps []Step
}
func (s *Script) Run(backend *pgproto3.Backend) error {
for _, step := range s.Steps {
err := step.Step(backend)
if err != nil {
return err
}
}
return nil
}
func (s *Script) Step(backend *pgproto3.Backend) error {
return s.Run(backend)
}
type expectMessageStep struct {
want pgproto3.FrontendMessage
any bool
}
func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
msg, err := backend.Receive()
if err != nil {
return err
}
if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) {
return nil
}
if !reflect.DeepEqual(msg, e.want) {
return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
}
return nil
}
type expectStartupMessageStep struct {
want *pgproto3.StartupMessage
any bool
}
func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
msg, err := backend.ReceiveStartupMessage()
if err != nil {
return err
}
if e.any {
return nil
}
if !reflect.DeepEqual(msg, e.want) {
return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
}
return nil
}
func ExpectMessage(want pgproto3.FrontendMessage) Step {
return expectMessage(want, false)
}
func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
return expectMessage(want, true)
}
func expectMessage(want pgproto3.FrontendMessage, any bool) Step {
if want, ok := want.(*pgproto3.StartupMessage); ok {
return &expectStartupMessageStep{want: want, any: any}
}
return &expectMessageStep{want: want, any: any}
}
type sendMessageStep struct {
msg pgproto3.BackendMessage
}
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
backend.Send(e.msg)
return backend.Flush()
}
func SendMessage(msg pgproto3.BackendMessage) Step {
return &sendMessageStep{msg: msg}
}
type waitForCloseMessageStep struct{}
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
for {
msg, err := backend.Receive()
if err == io.EOF {
return nil
} else if err != nil {
return err
}
if _, ok := msg.(*pgproto3.Terminate); ok {
return nil
}
}
}
func WaitForClose() Step {
return &waitForCloseMessageStep{}
}
func AcceptUnauthenticatedConnRequestSteps() []Step {
return []Step{
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
SendMessage(&pgproto3.AuthenticationOk{}),
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
}
}


@ -1,91 +0,0 @@
package pgmock_test
import (
"context"
"fmt"
"net"
"strings"
"testing"
"time"
"github.com/jackc/pgx/v5/internal/pgmock"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestScript(t *testing.T) {
script := &pgmock.Script{
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
}
script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "select 42"}))
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.RowDescription{
Fields: []pgproto3.FieldDescription{
{
Name: []byte("?column?"),
TableOID: 0,
TableAttributeNumber: 0,
DataTypeOID: 23,
DataTypeSize: 4,
TypeModifier: -1,
Format: 0,
},
},
}))
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.DataRow{
Values: [][]byte{[]byte("42")},
}))
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}))
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Terminate{}))
ln, err := net.Listen("tcp", "127.0.0.1:")
require.NoError(t, err)
defer ln.Close()
serverErrChan := make(chan error, 1)
go func() {
defer close(serverErrChan)
conn, err := ln.Accept()
if err != nil {
serverErrChan <- err
return
}
defer conn.Close()
err = conn.SetDeadline(time.Now().Add(time.Second))
if err != nil {
serverErrChan <- err
return
}
err = script.Run(pgproto3.NewBackend(conn, conn))
if err != nil {
serverErrChan <- err
return
}
}()
host, port, _ := strings.Cut(ln.Addr().String(), ":")
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
pgConn, err := pgconn.Connect(ctx, connStr)
require.NoError(t, err)
results, err := pgConn.Exec(ctx, "select 42").ReadAll()
assert.NoError(t, err)
assert.Len(t, results, 1)
assert.Nil(t, results[0].Err)
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
assert.Len(t, results[0].Rows, 1)
assert.Equal(t, "42", string(results[0].Rows[0][0]))
pgConn.Close(ctx)
assert.NoError(t, <-serverErrChan)
}


@ -1,60 +0,0 @@
#!/usr/bin/env bash
current_branch=$(git rev-parse --abbrev-ref HEAD)
if [ "$current_branch" == "HEAD" ]; then
current_branch=$(git rev-parse HEAD)
fi
restore_branch() {
echo "Restoring original branch/commit: $current_branch"
git checkout "$current_branch"
}
trap restore_branch EXIT
# Check if there are uncommitted changes
if ! git diff --quiet || ! git diff --cached --quiet; then
echo "There are uncommitted changes. Please commit or stash them before running this script."
exit 1
fi
# Ensure that at least one commit argument is passed
if [ "$#" -lt 1 ]; then
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
exit 1
fi
commits=("$@")
benchmarks_dir=benchmarks
if ! mkdir -p "${benchmarks_dir}"; then
echo "Unable to create dir for benchmarks data"
exit 1
fi
# Benchmark results
bench_files=()
# Run benchmark for each listed commit
for i in "${!commits[@]}"; do
commit="${commits[i]}"
git checkout "$commit" || {
echo "Failed to checkout $commit"
exit 1
}
# Sanitized commit message
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
# Benchmark data will go there
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
if ! go test -bench=. -count=10 >"$bench_file"; then
echo "Benchmarking failed for commit $commit"
exit 1
fi
bench_files+=("$bench_file")
done
# go install golang.org/x/perf/cmd/benchstat[@latest]
benchstat "${bench_files[@]}"


@ -3,209 +3,97 @@ package sanitize
import (
"bytes"
"encoding/hex"
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
"github.com/pkg/errors"
)
// Part is either a string or an int. A string is raw SQL. An int is an
// argument placeholder.
type Part any
type Part interface{}
type Query struct {
Parts []Part
}
// utf8.DecodeRune returns utf8.RuneError for errors. But that is actually rune U+FFFD -- the unicode replacement
// character. utf8.RuneError is not an error if it is also width 3.
//
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3
const maxBufSize = 16384 // 16 Ki
var bufPool = &pool[*bytes.Buffer]{
new: func() *bytes.Buffer {
return &bytes.Buffer{}
},
reset: func(b *bytes.Buffer) bool {
n := b.Len()
b.Reset()
return n < maxBufSize
},
}
var null = []byte("null")
func (q *Query) Sanitize(args ...any) (string, error) {
func (q *Query) Sanitize(args ...interface{}) (string, error) {
argUse := make([]bool, len(args))
buf := bufPool.get()
defer bufPool.put(buf)
buf := &bytes.Buffer{}
for _, part := range q.Parts {
var str string
switch part := part.(type) {
case string:
buf.WriteString(part)
str = part
case int:
argIdx := part - 1
var p []byte
if argIdx < 0 {
return "", fmt.Errorf("first sql argument must be > 0")
}
if argIdx >= len(args) {
return "", fmt.Errorf("insufficient arguments")
return "", errors.Errorf("insufficient arguments")
}
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
buf.WriteByte(' ')
arg := args[argIdx]
switch arg := arg.(type) {
case nil:
p = null
str = "null"
case int64:
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
str = strconv.FormatInt(arg, 10)
case float64:
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
str = strconv.FormatFloat(arg, 'f', -1, 64)
case bool:
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
str = strconv.FormatBool(arg)
case []byte:
p = QuoteBytes(buf.AvailableBuffer(), arg)
str = QuoteBytes(arg)
case string:
p = QuoteString(buf.AvailableBuffer(), arg)
str = QuoteString(arg)
case time.Time:
p = arg.Truncate(time.Microsecond).
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
default:
return "", fmt.Errorf("invalid arg type: %T", arg)
return "", errors.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true
buf.Write(p)
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
buf.WriteByte(' ')
default:
return "", fmt.Errorf("invalid Part type: %T", part)
return "", errors.Errorf("invalid Part type: %T", part)
}
buf.WriteString(str)
}
for i, used := range argUse {
if !used {
return "", fmt.Errorf("unused argument: %d", i)
return "", errors.Errorf("unused argument: %d", i)
}
}
return buf.String(), nil
}
func NewQuery(sql string) (*Query, error) {
query := &Query{}
query.init(sql)
return query, nil
}
var sqlLexerPool = &pool[*sqlLexer]{
new: func() *sqlLexer {
return &sqlLexer{}
},
reset: func(sl *sqlLexer) bool {
*sl = sqlLexer{}
return true
},
}
func (q *Query) init(sql string) {
parts := q.Parts[:0]
if parts == nil {
// dirty, but fast heuristic to preallocate for ~90% of use cases
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
parts = make([]Part, 0, n)
l := &sqlLexer{
src: sql,
stateFn: rawState,
}
l := sqlLexerPool.get()
defer sqlLexerPool.put(l)
l.src = sql
l.stateFn = rawState
l.parts = parts
for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
q.Parts = l.parts
query := &Query{Parts: l.parts}
return query, nil
}
func QuoteString(dst []byte, str string) []byte {
const quote = '\''
// Preallocate space for the worst case scenario
dst = slices.Grow(dst, len(str)*2+2)
// Add opening quote
dst = append(dst, quote)
// Iterate through the string without allocating
for i := 0; i < len(str); i++ {
if str[i] == quote {
dst = append(dst, quote, quote)
} else {
dst = append(dst, str[i])
}
}
// Add closing quote
dst = append(dst, quote)
return dst
func QuoteString(str string) string {
return "'" + strings.Replace(str, "'", "''", -1) + "'"
}
func QuoteBytes(dst, buf []byte) []byte {
if len(buf) == 0 {
return append(dst, `'\x'`...)
}
// Calculate required length
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
// Ensure dst has enough capacity
if cap(dst)-len(dst) < requiredLen {
newDst := make([]byte, len(dst), len(dst)+requiredLen)
copy(newDst, dst)
dst = newDst
}
// Record original length and extend slice
origLen := len(dst)
dst = dst[:origLen+requiredLen]
// Add prefix
dst[origLen] = '\''
dst[origLen+1] = '\\'
dst[origLen+2] = 'x'
// Encode bytes directly into dst
hex.Encode(dst[origLen+3:len(dst)-1], buf)
// Add suffix
dst[len(dst)-1] = '\''
return dst
func QuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
}
type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []Part
}
@ -237,26 +125,12 @@ func rawState(l *sqlLexer) stateFn {
l.start = l.pos
return placeholderState
}
case '-':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '-' {
l.pos += width
return oneLineCommentState
}
case '/':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '*' {
l.pos += width
return multilineCommentState
}
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
@ -274,13 +148,11 @@ func singleQuoteState(l *sqlLexer) stateFn {
}
l.pos += width
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
@ -298,13 +170,11 @@ func doubleQuoteState(l *sqlLexer) stateFn {
}
l.pos += width
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
@ -346,115 +216,22 @@ func escapeStringState(l *sqlLexer) stateFn {
}
l.pos += width
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func oneLineCommentState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\n', '\r':
return rawState
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
}
func multilineCommentState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '/':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '*' {
l.pos += width
l.nested++
}
case '*':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '/' {
continue
}
l.pos += width
if l.nested == 0 {
return rawState
}
l.nested--
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
}
var queryPool = &pool[*Query]{
new: func() *Query {
return &Query{}
},
reset: func(q *Query) bool {
n := len(q.Parts)
q.Parts = q.Parts[:0]
return n < 64 // drop too large queries
},
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func SanitizeSQL(sql string, args ...any) (string, error) {
query := queryPool.get()
query.init(sql)
defer queryPool.put(query)
func SanitizeSQL(sql string, args ...interface{}) (string, error) {
query, err := NewQuery(sql)
if err != nil {
return "", err
}
return query.Sanitize(args...)
}
type pool[E any] struct {
p sync.Pool
new func() E
reset func(E) bool
}
func (pool *pool[E]) get() E {
v, ok := pool.p.Get().(E)
if !ok {
v = pool.new()
}
return v
}
func (p *pool[E]) put(v E) {
if p.reset(v) {
p.p.Put(v)
}
}
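End to end, the master-side code renders each argument with a guard space on both sides (a sketch; sanitize is an internal package, so this only runs inside the module, e.g. as a test, and the test name is illustrative):

```go
package sanitize_test

import (
	"testing"

	"github.com/jackc/pgx/v5/internal/sanitize"
)

func TestSanitizeSQLPadding(t *testing.T) {
	sql, err := sanitize.SanitizeSQL("select * from users where id =$1", int64(7))
	if err != nil {
		t.Fatal(err)
	}
	// The guard spaces keep a substituted value from fusing with
	// adjacent text into a line comment (GHSA-m7wr-2xf7-cm9p).
	want := "select * from users where id = 7 "
	if sql != want {
		t.Fatalf("got %q, want %q", sql, want)
	}
}
```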


@ -1,62 +0,0 @@
// sanitize_benchmark_test.go
package sanitize_test
import (
"testing"
"time"
"github.com/jackc/pgx/v5/internal/sanitize"
)
var benchmarkSanitizeResult string
const benchmarkQuery = "" +
`SELECT *
FROM "water_containers"
WHERE NOT "id" = $1 -- int64
AND "tags" NOT IN $2 -- nil
AND "volume" > $3 -- float64
AND "transportable" = $4 -- bool
AND position($5 IN "sign") -- bytes
AND "label" LIKE $6 -- string
AND "created_at" > $7; -- time.Time`
var benchmarkArgs = []any{
int64(12345),
nil,
float64(500),
true,
[]byte("8BADF00D"),
"kombucha's han'dy awokowa",
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
}
func BenchmarkSanitize(b *testing.B) {
query, err := sanitize.NewQuery(benchmarkQuery)
if err != nil {
b.Fatalf("failed to create query: %v", err)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize query: %v", err)
}
}
}
var benchmarkNewSQLResult string
func BenchmarkSanitizeSQL(b *testing.B) {
b.ReportAllocs()
var err error
for i := 0; i < b.N; i++ {
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize SQL: %v", err)
}
}
}


@ -1,55 +0,0 @@
package sanitize_test
import (
"strings"
"testing"
"github.com/jackc/pgx/v5/internal/sanitize"
)
func FuzzQuoteString(f *testing.F) {
const prefix = "prefix"
f.Add("new\nline")
f.Add("sample text")
f.Add("sample q'u'o't'e's")
f.Add("select 'quoted $42', $1")
f.Fuzz(func(t *testing.T, input string) {
got := string(sanitize.QuoteString([]byte(prefix), input))
want := oldQuoteString(input)
quoted, ok := strings.CutPrefix(got, prefix)
if !ok {
t.Fatalf("result has no prefix")
}
if want != quoted {
t.Errorf("got %q", got)
t.Fatalf("want %q", want)
}
})
}
func FuzzQuoteBytes(f *testing.F) {
const prefix = "prefix"
f.Add([]byte(nil))
f.Add([]byte("\n"))
f.Add([]byte("sample text"))
f.Add([]byte("sample q'u'o't'e's"))
f.Add([]byte("select 'quoted $42', $1"))
f.Fuzz(func(t *testing.T, input []byte) {
got := string(sanitize.QuoteBytes([]byte(prefix), input))
want := oldQuoteBytes(input)
quoted, ok := strings.CutPrefix(got, prefix)
if !ok {
t.Fatalf("result has no prefix")
}
if want != quoted {
t.Errorf("got %q", got)
t.Fatalf("want %q", want)
}
})
}


@ -1,12 +1,9 @@
package sanitize_test
import (
"encoding/hex"
"strings"
"testing"
"time"
"github.com/jackc/pgx/v5/internal/sanitize"
"github.com/jackc/pgx/internal/sanitize"
)
func TestNewQuery(t *testing.T) {
@ -62,44 +59,6 @@ func TestNewQuery(t *testing.T) {
sql: `select e'escape string\' $42', $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
},
{
sql: `select /* a baby's toy */ 'barbie', $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select /* a baby's toy */ 'barbie', `, 1}},
},
{
sql: `select /* *_* */ $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select /* *_* */ `, 1}},
},
{
sql: `select 42 /* /* /* 42 */ */ */, $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select 42 /* /* /* 42 */ */ */, `, 1}},
},
{
sql: "select -- a baby's toy\n'barbie', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select -- a baby's toy\n'barbie', ", 1}},
},
{
sql: "select 42 -- is a Deep Thought's favorite number",
expected: sanitize.Query{Parts: []sanitize.Part{"select 42 -- is a Deep Thought's favorite number"}},
},
{
sql: "select 42, -- \\nis a Deep Thought's favorite number\n$1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\n", 1}},
},
{
sql: "select 42, -- \\nis a Deep Thought's favorite number\r$1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\r", 1}},
},
{
// https://github.com/jackc/pgx/issues/1380
sql: "select 'hello w<>rld'",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello w<>rld'"}},
},
{
// Unterminated quoted string
sql: "select 'hello world",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello world"}},
},
}
for i, tt := range successTests {
@ -123,68 +82,53 @@ func TestNewQuery(t *testing.T) {
func TestQuerySanitize(t *testing.T) {
successfulTests := []struct {
query sanitize.Query
args []any
args []interface{}
expected string
}{
{
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
args: []any{},
args: []interface{}{},
expected: `select 42`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{int64(42)},
expected: `select 42 `,
args: []interface{}{int64(42)},
expected: `select 42`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{float64(1.23)},
expected: `select 1.23 `,
args: []interface{}{float64(1.23)},
expected: `select 1.23`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{true},
expected: `select true `,
args: []interface{}{true},
expected: `select true`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{[]byte{0, 1, 2, 3, 255}},
expected: `select '\x00010203ff' `,
args: []interface{}{[]byte{0, 1, 2, 3, 255}},
expected: `select '\x00010203ff'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{nil},
expected: `select null `,
args: []interface{}{nil},
expected: `select null`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{"foobar"},
expected: `select 'foobar' `,
args: []interface{}{"foobar"},
expected: `select 'foobar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{"foo'bar"},
expected: `select 'foo''bar' `,
args: []interface{}{"foo'bar"},
expected: `select 'foo''bar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{`foo\'bar`},
expected: `select 'foo\''bar' `,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}},
args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)},
expected: `insert '2020-03-01 23:59:59.999999Z' `,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
args: []any{int64(-1)},
expected: `select 1- -1 `,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
args: []any{float64(-1)},
expected: `select 1- -1 `,
args: []interface{}{`foo\'bar`},
expected: `select 'foo\''bar'`,
},
}
@ -202,22 +146,22 @@ func TestQuerySanitize(t *testing.T) {
errorTests := []struct {
query sanitize.Query
args []any
args []interface{}
expected string
}{
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
args: []any{int64(42)},
args: []interface{}{int64(42)},
expected: `insufficient arguments`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}},
args: []any{int64(42)},
args: []interface{}{int64(42)},
expected: `unused argument: 0`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{42},
args: []interface{}{42},
expected: `invalid arg type: int`,
},
}
@ -229,55 +173,3 @@ func TestQuerySanitize(t *testing.T) {
}
}
}
func TestQuoteString(t *testing.T) {
tc := func(name, input string) {
t.Run(name, func(t *testing.T) {
t.Parallel()
got := string(sanitize.QuoteString(nil, input))
want := oldQuoteString(input)
if got != want {
t.Errorf("got: %s", got)
t.Fatalf("want: %s", want)
}
})
}
tc("empty", "")
tc("text", "abcd")
tc("with quotes", `one's hat is always a cat`)
}
// This function was used before optimizations.
// You should keep for testing purposes - we want to ensure there are no breaking changes.
func oldQuoteString(str string) string {
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
}
func TestQuoteBytes(t *testing.T) {
tc := func(name string, input []byte) {
t.Run(name, func(t *testing.T) {
t.Parallel()
got := string(sanitize.QuoteBytes(nil, input))
want := oldQuoteBytes(input)
if got != want {
t.Errorf("got: %s", got)
t.Fatalf("want: %s", want)
}
})
}
tc("nil", nil)
tc("empty", []byte{})
tc("text", []byte("abcd"))
}
// This function was used before optimizations.
// You should keep for testing purposes - we want to ensure there are no breaking changes.
func oldQuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
}


@ -1,111 +0,0 @@
package stmtcache
import (
"container/list"
"github.com/jackc/pgx/v5/pgconn"
)
// LRUCache implements Cache with a Least Recently Used (LRU) cache.
type LRUCache struct {
cap int
m map[string]*list.Element
l *list.List
invalidStmts []*pgconn.StatementDescription
}
// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache.
func NewLRUCache(cap int) *LRUCache {
return &LRUCache{
cap: cap,
m: make(map[string]*list.Element),
l: list.New(),
}
}
// Get returns the statement description for sql. Returns nil if not found.
func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
if el, ok := c.m[key]; ok {
c.l.MoveToFront(el)
return el.Value.(*pgconn.StatementDescription)
}
return nil
}
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or
// sd.SQL has been invalidated and HandleInvalidated has not been called yet.
func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
if sd.SQL == "" {
panic("cannot store statement description with empty SQL")
}
if _, present := c.m[sd.SQL]; present {
return
}
// The statement may have been invalidated but not yet handled. Do not readd it to the cache.
for _, invalidSD := range c.invalidStmts {
if invalidSD.SQL == sd.SQL {
return
}
}
if c.l.Len() == c.cap {
c.invalidateOldest()
}
el := c.l.PushFront(sd)
c.m[sd.SQL] = el
}
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
func (c *LRUCache) Invalidate(sql string) {
if el, ok := c.m[sql]; ok {
delete(c.m, sql)
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
c.l.Remove(el)
}
}
// InvalidateAll invalidates all statement descriptions.
func (c *LRUCache) InvalidateAll() {
el := c.l.Front()
for el != nil {
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
el = el.Next()
}
c.m = make(map[string]*list.Element)
c.l = list.New()
}
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
return c.invalidStmts
}
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *LRUCache) RemoveInvalidated() {
c.invalidStmts = nil
}
// Len returns the number of cached prepared statement descriptions.
func (c *LRUCache) Len() int {
return c.l.Len()
}
// Cap returns the maximum number of cached prepared statement descriptions.
func (c *LRUCache) Cap() int {
return c.cap
}
func (c *LRUCache) invalidateOldest() {
oldest := c.l.Back()
sd := oldest.Value.(*pgconn.StatementDescription)
c.invalidStmts = append(c.invalidStmts, sd)
delete(c.m, sd.SQL)
c.l.Remove(oldest)
}
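The intended call pattern, sketched (describeCached and prepare are hypothetical names; prepare stands in for the server round trip that yields a description):

```go
package stmtcache

import "github.com/jackc/pgx/v5/pgconn"

// describeCached sketches the fast path / slow path split around the cache.
func describeCached(cache *LRUCache, sql string,
	prepare func(string) (*pgconn.StatementDescription, error),
) (*pgconn.StatementDescription, error) {
	if sd := cache.Get(sql); sd != nil { // a hit refreshes the LRU position
		return sd, nil
	}
	sd, err := prepare(sql)
	if err != nil {
		return nil, err
	}
	cache.Put(sd)
	return sd, nil
}
```

When a statement goes stale (for example after a schema change), the owner calls Invalidate or InvalidateAll, later drains GetInvalidated to close those statements server-side, and finishes with RemoveInvalidated.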


@ -1,45 +0,0 @@
// Package stmtcache is a cache for statement descriptions.
package stmtcache
import (
"crypto/sha256"
"encoding/hex"
"github.com/jackc/pgx/v5/pgconn"
)
// StatementName returns a statement name that will be stable for sql across multiple connections and program
// executions.
func StatementName(sql string) string {
digest := sha256.Sum256([]byte(sql))
return "stmtcache_" + hex.EncodeToString(digest[0:24])
}
// Cache caches statement descriptions.
type Cache interface {
// Get returns the statement description for sql. Returns nil if not found.
Get(sql string) *pgconn.StatementDescription
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
Put(sd *pgconn.StatementDescription)
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
Invalidate(sql string)
// InvalidateAll invalidates all statement descriptions.
InvalidateAll()
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
GetInvalidated() []*pgconn.StatementDescription
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
RemoveInvalidated()
// Len returns the number of cached prepared statement descriptions.
Len() int
// Cap returns the maximum number of cached prepared statement descriptions.
Cap() int
}
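Because the name is a pure function of the SQL, every connection and every program run derives the same statement name for the same text, for example:

```go
name := stmtcache.StatementName("select * from users where id = $1")
// "stmtcache_" plus the first 48 hex characters (24 bytes) of the
// SQL's SHA-256 digest; identical on every run.
```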


@ -1,77 +0,0 @@
package stmtcache
import (
"math"
"github.com/jackc/pgx/v5/pgconn"
)
// UnlimitedCache implements Cache with no capacity limit.
type UnlimitedCache struct {
m map[string]*pgconn.StatementDescription
invalidStmts []*pgconn.StatementDescription
}
// NewUnlimitedCache creates a new UnlimitedCache.
func NewUnlimitedCache() *UnlimitedCache {
return &UnlimitedCache{
m: make(map[string]*pgconn.StatementDescription),
}
}
// Get returns the statement description for sql. Returns nil if not found.
func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription {
return c.m[sql]
}
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) {
if sd.SQL == "" {
panic("cannot store statement description with empty SQL")
}
if _, present := c.m[sd.SQL]; present {
return
}
c.m[sd.SQL] = sd
}
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
func (c *UnlimitedCache) Invalidate(sql string) {
if sd, ok := c.m[sql]; ok {
delete(c.m, sql)
c.invalidStmts = append(c.invalidStmts, sd)
}
}
// InvalidateAll invalidates all statement descriptions.
func (c *UnlimitedCache) InvalidateAll() {
for _, sd := range c.m {
c.invalidStmts = append(c.invalidStmts, sd)
}
c.m = make(map[string]*pgconn.StatementDescription)
}
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription {
return c.invalidStmts
}
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache may be made between a
// call to GetInvalidated and RemoveInvalidated, or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *UnlimitedCache) RemoveInvalidated() {
c.invalidStmts = nil
}
// Len returns the number of cached prepared statement descriptions.
func (c *UnlimitedCache) Len() int {
return len(c.m)
}
// Cap returns the maximum number of cached prepared statement descriptions.
func (c *UnlimitedCache) Cap() int {
return math.MaxInt
}
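
A short round trip through the UnlimitedCache above (illustrative only, since the package is internal to pgx; the query string is hypothetical):

```go
sql := "select * from widgets where id = $1" // hypothetical query
c := stmtcache.NewUnlimitedCache()
c.Put(&pgconn.StatementDescription{Name: stmtcache.StatementName(sql), SQL: sql})

sd := c.Get(sql) // non-nil: the description stored above
c.Invalidate(sql)
// Get(sql) now returns nil; sd is returned by GetInvalidated() until
// RemoveInvalidated() is called.
_ = sd
```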

View File

@ -1,24 +1,56 @@
package pgx
import (
"context"
"errors"
"io"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/pgtype"
)
// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of
// PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data
// in the message, maxLargeObjectMessageLength should be no larger than 1 GB - 1 KB.
var maxLargeObjectMessageLength = 1024*1024*1024 - 1024
// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it
// was created.
// LargeObjects is a structure used to access the large objects API. It is only
// valid within the transaction where it was created.
//
// For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html
type LargeObjects struct {
tx Tx
// Has64 is true if the server is capable of working with 64-bit numbers
Has64 bool
fp *fastpath
}
const largeObjectFns = `select proname, oid from pg_catalog.pg_proc
where proname in (
'lo_open',
'lo_close',
'lo_create',
'lo_unlink',
'lo_lseek',
'lo_lseek64',
'lo_tell',
'lo_tell64',
'lo_truncate',
'lo_truncate64',
'loread',
'lowrite')
and pronamespace = (select oid from pg_catalog.pg_namespace where nspname = 'pg_catalog')`
// LargeObjects returns a LargeObjects instance for the transaction.
func (tx *Tx) LargeObjects() (*LargeObjects, error) {
if tx.conn.fp == nil {
tx.conn.fp = newFastpath(tx.conn)
}
if _, exists := tx.conn.fp.fns["lo_open"]; !exists {
res, err := tx.Query(largeObjectFns)
if err != nil {
return nil, err
}
if err := tx.conn.fp.addFunctions(res); err != nil {
return nil, err
}
}
lo := &LargeObjects{fp: tx.conn.fp}
_, lo.Has64 = lo.fp.fns["lo_lseek64"]
return lo, nil
}
type LargeObjectMode int32
@ -28,134 +60,90 @@ const (
LargeObjectModeRead LargeObjectMode = 0x40000
)
// Create creates a new large object. If oid is zero, the server assigns an unused OID.
func (o *LargeObjects) Create(ctx context.Context, oid uint32) (uint32, error) {
err := o.tx.QueryRow(ctx, "select lo_create($1)", oid).Scan(&oid)
return oid, err
// Create creates a new large object. If id is zero, the server assigns an
// unused OID.
func (o *LargeObjects) Create(id pgtype.OID) (pgtype.OID, error) {
newOID, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))}))
return pgtype.OID(newOID), err
}
// Open opens an existing large object with the given mode. ctx will also be used for all operations on the opened large
// object.
func (o *LargeObjects) Open(ctx context.Context, oid uint32, mode LargeObjectMode) (*LargeObject, error) {
var fd int32
err := o.tx.QueryRow(ctx, "select lo_open($1, $2)", oid, mode).Scan(&fd)
if err != nil {
return nil, err
}
return &LargeObject{fd: fd, tx: o.tx, ctx: ctx}, nil
// Open opens an existing large object with the given mode.
func (o *LargeObjects) Open(oid pgtype.OID, mode LargeObjectMode) (*LargeObject, error) {
fd, err := fpInt32(o.fp.CallFn("lo_open", []fpArg{fpIntArg(int32(oid)), fpIntArg(int32(mode))}))
return &LargeObject{fd: fd, lo: o}, err
}
// Unlink removes a large object from the database.
func (o *LargeObjects) Unlink(ctx context.Context, oid uint32) error {
var result int32
err := o.tx.QueryRow(ctx, "select lo_unlink($1)", oid).Scan(&result)
if err != nil {
return err
}
if result != 1 {
return errors.New("failed to remove large object")
}
return nil
func (o *LargeObjects) Unlink(oid pgtype.OID) error {
_, err := o.fp.CallFn("lo_unlink", []fpArg{fpIntArg(int32(oid))})
return err
}
// A LargeObject is a large object stored on the server. It is only valid within the transaction that it was initialized
// in. It uses the context it was initialized with for all operations. It implements these interfaces:
// A LargeObject is a large object stored on the server. It is only valid within
// the transaction that it was initialized in. It implements these interfaces:
//
// io.Writer
// io.Reader
// io.Seeker
// io.Closer
// io.Writer
// io.Reader
// io.Seeker
// io.Closer
type LargeObject struct {
ctx context.Context
tx Tx
fd int32
fd int32
lo *LargeObjects
}
// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written.
// Write writes p to the large object and returns the number of bytes written
// and an error if not all of p was written.
func (o *LargeObject) Write(p []byte) (int, error) {
nTotal := 0
for {
expected := len(p) - nTotal
if expected == 0 {
break
} else if expected > maxLargeObjectMessageLength {
expected = maxLargeObjectMessageLength
}
var n int
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p[nTotal:nTotal+expected]).Scan(&n)
if err != nil {
return nTotal, err
}
if n < 0 {
return nTotal, errors.New("failed to write to large object")
}
nTotal += n
if n < expected {
return nTotal, errors.New("short write to large object")
} else if n > expected {
return nTotal, errors.New("invalid write to large object")
}
}
return nTotal, nil
n, err := fpInt32(o.lo.fp.CallFn("lowrite", []fpArg{fpIntArg(o.fd), p}))
return int(n), err
}
// Read reads up to len(p) bytes into p returning the number of bytes read.
func (o *LargeObject) Read(p []byte) (int, error) {
nTotal := 0
for {
expected := len(p) - nTotal
if expected == 0 {
break
} else if expected > maxLargeObjectMessageLength {
expected = maxLargeObjectMessageLength
}
res := pgtype.PreallocBytes(p[nTotal:])
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res)
// We compute expected so that it always fits into p, so it should never happen
// that PreallocBytes's ScanBytes had to allocate a new slice.
nTotal += len(res)
if err != nil {
return nTotal, err
}
if len(res) < expected {
return nTotal, io.EOF
} else if len(res) > expected {
return nTotal, errors.New("invalid read of large object")
}
res, err := o.lo.fp.CallFn("loread", []fpArg{fpIntArg(o.fd), fpIntArg(int32(len(p)))})
if len(res) < len(p) {
err = io.EOF
}
return nTotal, nil
return copy(p, res), err
}
// Seek moves the current location pointer to the new location specified by offset.
func (o *LargeObject) Seek(offset int64, whence int) (n int64, err error) {
err = o.tx.QueryRow(o.ctx, "select lo_lseek64($1, $2, $3)", o.fd, offset, whence).Scan(&n)
return n, err
if o.lo.Has64 {
n, err = fpInt64(o.lo.fp.CallFn("lo_lseek64", []fpArg{fpIntArg(o.fd), fpInt64Arg(offset), fpIntArg(int32(whence))}))
} else {
var n32 int32
n32, err = fpInt32(o.lo.fp.CallFn("lo_lseek", []fpArg{fpIntArg(o.fd), fpIntArg(int32(offset)), fpIntArg(int32(whence))}))
n = int64(n32)
}
return
}
// Tell returns the current read or write location of the large object descriptor.
// Tell returns the current read or write location of the large object
// descriptor.
func (o *LargeObject) Tell() (n int64, err error) {
err = o.tx.QueryRow(o.ctx, "select lo_tell64($1)", o.fd).Scan(&n)
return n, err
if o.lo.Has64 {
n, err = fpInt64(o.lo.fp.CallFn("lo_tell64", []fpArg{fpIntArg(o.fd)}))
} else {
var n32 int32
n32, err = fpInt32(o.lo.fp.CallFn("lo_tell", []fpArg{fpIntArg(o.fd)}))
n = int64(n32)
}
return
}
// Truncate the large object to size.
// Truncates the large object to size.
func (o *LargeObject) Truncate(size int64) (err error) {
_, err = o.tx.Exec(o.ctx, "select lo_truncate64($1, $2)", o.fd, size)
return err
if o.lo.Has64 {
_, err = o.lo.fp.CallFn("lo_truncate64", []fpArg{fpIntArg(o.fd), fpInt64Arg(size)})
} else {
_, err = o.lo.fp.CallFn("lo_truncate", []fpArg{fpIntArg(o.fd), fpIntArg(int32(size))})
}
return
}
// Close the large object descriptor.
// Close closes the large object descriptor.
func (o *LargeObject) Close() error {
_, err := o.tx.Exec(o.ctx, "select lo_close($1)", o.fd)
_, err := o.lo.fp.CallFn("lo_close", []fpArg{fpIntArg(o.fd)})
return err
}
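
On the incoming (v3-style) side of this diff, everything hangs off the transaction. A hedged sketch of a write round trip against that API, with abbreviated error handling:

```go
func writeLargeObject(conn *pgx.Conn, data []byte) (pgtype.OID, error) {
	tx, err := conn.Begin() // large objects are only valid inside a transaction
	if err != nil {
		return 0, err
	}
	defer tx.Rollback() // no-op once Commit succeeds

	lo, err := tx.LargeObjects()
	if err != nil {
		return 0, err
	}
	oid, err := lo.Create(0) // 0 lets the server assign an unused OID
	if err != nil {
		return 0, err
	}
	obj, err := lo.Open(oid, pgx.LargeObjectModeWrite)
	if err != nil {
		return 0, err
	}
	if _, err := obj.Write(data); err != nil {
		return 0, err
	}
	if err := obj.Close(); err != nil {
		return 0, err
	}
	return oid, tx.Commit()
}
```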

View File

@ -1,20 +0,0 @@
package pgx
import (
"testing"
)
// SetMaxLargeObjectMessageLength sets internal maxLargeObjectMessageLength variable
// to the given length for the duration of the test.
//
// Tests using this helper should not use t.Parallel().
func SetMaxLargeObjectMessageLength(t *testing.T, length int) {
t.Helper()
original := maxLargeObjectMessageLength
t.Cleanup(func() {
maxLargeObjectMessageLength = original
})
maxLargeObjectMessageLength = length
}
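
The tests below lean on this helper to exercise the chunking loops in Write and Read; a minimal usage sketch:

```go
func TestChunkedLargeObjectIO(t *testing.T) {
	// Force 2-byte protocol messages so Write/Read must loop; the original
	// value is restored via t.Cleanup, so no t.Parallel() here.
	pgx.SetMaxLargeObjectMessageLength(t, 2)
	// ... connect, begin a transaction, and exercise LargeObject I/O ...
}
```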

View File

@ -1,77 +1,36 @@
package pgx_test
import (
"context"
"io"
"os"
"testing"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxtest"
"github.com/jackc/pgx"
)
func TestLargeObjects(t *testing.T) {
// We use a very short limit to test chunking logic.
pgx.SetMaxLargeObjectMessageLength(t, 2)
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
conn, err := pgx.Connect(*defaultConnConfig)
if err != nil {
t.Fatal(err)
}
pgxtest.SkipCockroachDB(t, conn, "Server does not support large objects")
tx, err := conn.Begin(ctx)
tx, err := conn.Begin()
if err != nil {
t.Fatal(err)
}
testLargeObjects(t, ctx, tx)
}
func TestLargeObjectsSimpleProtocol(t *testing.T) {
// We use a very short limit to test chunking logic.
pgx.SetMaxLargeObjectMessageLength(t, 2)
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
lo, err := tx.LargeObjects()
if err != nil {
t.Fatal(err)
}
config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
conn, err := pgx.ConnectConfig(ctx, config)
id, err := lo.Create(0)
if err != nil {
t.Fatal(err)
}
pgxtest.SkipCockroachDB(t, conn, "Server does not support large objects")
tx, err := conn.Begin(ctx)
if err != nil {
t.Fatal(err)
}
testLargeObjects(t, ctx, tx)
}
func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) {
lo := tx.LargeObjects()
id, err := lo.Create(ctx, 0)
if err != nil {
t.Fatal(err)
}
obj, err := lo.Open(ctx, id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite)
obj, err := lo.Open(id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite)
if err != nil {
t.Fatal(err)
}
@ -150,44 +109,41 @@ func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) {
t.Fatal(err)
}
err = lo.Unlink(ctx, id)
err = lo.Unlink(id)
if err != nil {
t.Fatal(err)
}
_, err = lo.Open(ctx, id, pgx.LargeObjectModeRead)
if e, ok := err.(*pgconn.PgError); !ok || e.Code != "42704" {
_, err = lo.Open(id, pgx.LargeObjectModeRead)
if e, ok := err.(pgx.PgError); !ok || e.Code != "42704" {
t.Errorf("Expected undefined_object error (42704), got %#v", err)
}
}
func TestLargeObjectsMultipleTransactions(t *testing.T) {
// We use a very short limit to test chunking logic.
pgx.SetMaxLargeObjectMessageLength(t, 2)
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
conn, err := pgx.Connect(*defaultConnConfig)
if err != nil {
t.Fatal(err)
}
pgxtest.SkipCockroachDB(t, conn, "Server does not support large objects")
tx, err := conn.Begin(ctx)
tx, err := conn.Begin()
if err != nil {
t.Fatal(err)
}
lo := tx.LargeObjects()
id, err := lo.Create(ctx, 0)
lo, err := tx.LargeObjects()
if err != nil {
t.Fatal(err)
}
obj, err := lo.Open(ctx, id, pgx.LargeObjectModeWrite)
id, err := lo.Create(0)
if err != nil {
t.Fatal(err)
}
obj, err := lo.Open(id, pgx.LargeObjectModeWrite)
if err != nil {
t.Fatal(err)
}
@ -201,29 +157,32 @@ func TestLargeObjectsMultipleTransactions(t *testing.T) {
}
// Commit the first transaction
err = tx.Commit(ctx)
err = tx.Commit()
if err != nil {
t.Fatal(err)
}
// IMPORTANT: Use the same connection for another query
query := `select n from generate_series(1,10) n`
rows, err := conn.Query(ctx, query)
rows, err := conn.Query(query)
if err != nil {
t.Fatal(err)
}
rows.Close()
// Start a new transaction
tx2, err := conn.Begin(ctx)
tx2, err := conn.Begin()
if err != nil {
t.Fatal(err)
}
lo2 := tx2.LargeObjects()
lo2, err := tx2.LargeObjects()
if err != nil {
t.Fatal(err)
}
// Reopen the large object in the new transaction
obj2, err := lo2.Open(ctx, id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite)
obj2, err := lo2.Open(id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite)
if err != nil {
t.Fatal(err)
}
@ -294,13 +253,13 @@ func TestLargeObjectsMultipleTransactions(t *testing.T) {
t.Fatal(err)
}
err = lo2.Unlink(ctx, id)
err = lo2.Unlink(id)
if err != nil {
t.Fatal(err)
}
_, err = lo2.Open(ctx, id, pgx.LargeObjectModeRead)
if e, ok := err.(*pgconn.PgError); !ok || e.Code != "42704" {
_, err = lo2.Open(id, pgx.LargeObjectModeRead)
if e, ok := err.(pgx.PgError); !ok || e.Code != "42704" {
t.Errorf("Expected undefined_object error (42704), got %#v", err)
}
}

View File

@ -0,0 +1,47 @@
// Package log15adapter provides a logger that writes to a
// github.com/inconshreveable/log15.Logger.
package log15adapter
import (
"github.com/jackc/pgx"
)
// Log15Logger interface defines the subset of
// github.com/inconshreveable/log15.Logger that this adapter uses.
type Log15Logger interface {
Debug(msg string, ctx ...interface{})
Info(msg string, ctx ...interface{})
Warn(msg string, ctx ...interface{})
Error(msg string, ctx ...interface{})
Crit(msg string, ctx ...interface{})
}
type Logger struct {
l Log15Logger
}
func NewLogger(l Log15Logger) *Logger {
return &Logger{l: l}
}
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
logArgs := make([]interface{}, 0, len(data))
for k, v := range data {
logArgs = append(logArgs, k, v)
}
switch level {
case pgx.LogLevelTrace:
l.l.Debug(msg, append(logArgs, "PGX_LOG_LEVEL", level)...)
case pgx.LogLevelDebug:
l.l.Debug(msg, logArgs...)
case pgx.LogLevelInfo:
l.l.Info(msg, logArgs...)
case pgx.LogLevelWarn:
l.l.Warn(msg, logArgs...)
case pgx.LogLevelError:
l.l.Error(msg, logArgs...)
default:
l.l.Error(msg, append(logArgs, "INVALID_PGX_LOG_LEVEL", level)...)
}
}

View File

@ -0,0 +1,40 @@
// Package logrusadapter provides a logger that writes to a
// github.com/sirupsen/logrus.Logger.
package logrusadapter
import (
"github.com/jackc/pgx"
"github.com/sirupsen/logrus"
)
type Logger struct {
l logrus.FieldLogger
}
func NewLogger(l logrus.FieldLogger) *Logger {
return &Logger{l: l}
}
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
var logger logrus.FieldLogger
if data != nil {
logger = l.l.WithFields(data)
} else {
logger = l.l
}
switch level {
case pgx.LogLevelTrace:
logger.WithField("PGX_LOG_LEVEL", level).Debug(msg)
case pgx.LogLevelDebug:
logger.Debug(msg)
case pgx.LogLevelInfo:
logger.Info(msg)
case pgx.LogLevelWarn:
logger.Warn(msg)
case pgx.LogLevelError:
logger.Error(msg)
default:
logger.WithField("INVALID_PGX_LOG_LEVEL", level).Error(msg)
}
}

View File

@ -3,16 +3,15 @@
package testingadapter
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/tracelog"
"github.com/jackc/pgx"
)
// TestingLogger interface defines the subset of testing.TB methods used by this
// adapter.
type TestingLogger interface {
Log(args ...any)
Log(args ...interface{})
}
type Logger struct {
@ -23,8 +22,8 @@ func NewLogger(l TestingLogger) *Logger {
return &Logger{l: l}
}
func (l *Logger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) {
logArgs := make([]any, 0, 2+len(data))
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
logArgs := make([]interface{}, 0, 2+len(data))
logArgs = append(logArgs, level, msg)
for k, v := range data {
logArgs = append(logArgs, fmt.Sprintf("%s=%v", k, v))

40
log/zapadapter/adapter.go Normal file
View File

@ -0,0 +1,40 @@
// Package zapadapter provides a logger that writes to a go.uber.org/zap.Logger.
package zapadapter
import (
"github.com/jackc/pgx"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type Logger struct {
logger *zap.Logger
}
func NewLogger(logger *zap.Logger) *Logger {
return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))}
}
func (pl *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
fields := make([]zapcore.Field, len(data))
i := 0
for k, v := range data {
fields[i] = zap.Reflect(k, v)
i++
}
switch level {
case pgx.LogLevelTrace:
pl.logger.Debug(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...)
case pgx.LogLevelDebug:
pl.logger.Debug(msg, fields...)
case pgx.LogLevelInfo:
pl.logger.Info(msg, fields...)
case pgx.LogLevelWarn:
pl.logger.Warn(msg, fields...)
case pgx.LogLevelError:
pl.logger.Error(msg, fields...)
default:
pl.logger.Error(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...)
}
}

View File

@ -0,0 +1,40 @@
// Package zerologadapter provides a logger that writes to a github.com/rs/zerolog.
package zerologadapter
import (
"github.com/jackc/pgx"
"github.com/rs/zerolog"
)
type Logger struct {
logger zerolog.Logger
}
// NewLogger accepts a zerolog.Logger as input and returns a new custom pgx
// logging facade as output.
func NewLogger(logger zerolog.Logger) *Logger {
return &Logger{
logger: logger.With().Str("module", "pgx").Logger(),
}
}
func (pl *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
var zlevel zerolog.Level
switch level {
case pgx.LogLevelNone:
zlevel = zerolog.NoLevel
case pgx.LogLevelError:
zlevel = zerolog.ErrorLevel
case pgx.LogLevelWarn:
zlevel = zerolog.WarnLevel
case pgx.LogLevelInfo:
zlevel = zerolog.InfoLevel
case pgx.LogLevelDebug:
zlevel = zerolog.DebugLevel
default:
zlevel = zerolog.DebugLevel
}
pgxlog := pl.logger.With().Fields(data).Logger()
pgxlog.WithLevel(zlevel).Msg(msg)
}
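
All of these adapters plug in the same way: construct the backend logger, wrap it, and set it on the connection config together with a log level. A hedged sketch using the zerolog adapter (the connection details are placeholders):

```go
package main

import (
	"os"

	"github.com/jackc/pgx"
	"github.com/jackc/pgx/log/zerologadapter"
	"github.com/rs/zerolog"
)

func main() {
	zl := zerolog.New(os.Stderr).With().Timestamp().Logger()
	conn, err := pgx.Connect(pgx.ConnConfig{
		Host:     "localhost", // placeholder connection details
		User:     "postgres",
		Logger:   zerologadapter.NewLogger(zl),
		LogLevel: pgx.LogLevelInfo, // see logger.go below for the level constants
	})
	if err != nil {
		panic(err)
	}
	defer conn.Close()
}
```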

98
logger.go Normal file
View File

@ -0,0 +1,98 @@
package pgx
import (
"encoding/hex"
"fmt"
"github.com/pkg/errors"
)
// The values for log levels are chosen such that the zero value means that no
// log level was specified.
const (
LogLevelTrace = 6
LogLevelDebug = 5
LogLevelInfo = 4
LogLevelWarn = 3
LogLevelError = 2
LogLevelNone = 1
)
// LogLevel represents the pgx logging level. See LogLevel* constants for
// possible values.
type LogLevel int
func (ll LogLevel) String() string {
switch ll {
case LogLevelTrace:
return "trace"
case LogLevelDebug:
return "debug"
case LogLevelInfo:
return "info"
case LogLevelWarn:
return "warn"
case LogLevelError:
return "error"
case LogLevelNone:
return "none"
default:
return fmt.Sprintf("invalid level %d", ll)
}
}
// Logger is the interface used to get logging from pgx internals.
type Logger interface {
// Log a message at the given level with data key/value pairs. data may be nil.
Log(level LogLevel, msg string, data map[string]interface{})
}
// LogLevelFromString converts a log level string to the corresponding constant.
//
// Valid levels:
// trace
// debug
// info
// warn
// error
// none
func LogLevelFromString(s string) (LogLevel, error) {
switch s {
case "trace":
return LogLevelTrace, nil
case "debug":
return LogLevelDebug, nil
case "info":
return LogLevelInfo, nil
case "warn":
return LogLevelWarn, nil
case "error":
return LogLevelError, nil
case "none":
return LogLevelNone, nil
default:
return 0, errors.New("invalid log level")
}
}
func logQueryArgs(args []interface{}) []interface{} {
logArgs := make([]interface{}, 0, len(args))
for _, a := range args {
switch v := a.(type) {
case []byte:
if len(v) < 64 {
a = hex.EncodeToString(v)
} else {
a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64)
}
case string:
if len(v) > 64 {
a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64)
}
}
logArgs = append(logArgs, a)
}
return logArgs
}
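
LogLevelFromString makes the level configurable at runtime. A small sketch (PGX_LOG_LEVEL is a variable name invented for this example, not one pgx reads on its own):

```go
// logLevelFromEnv maps an environment variable to a pgx.LogLevel,
// falling back to info when the value is unset or invalid.
func logLevelFromEnv() pgx.LogLevel {
	level, err := pgx.LogLevelFromString(os.Getenv("PGX_LOG_LEVEL"))
	if err != nil {
		return pgx.LogLevelInfo
	}
	return level
}
```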

240
messages.go Normal file
View File

@ -0,0 +1,240 @@
package pgx
import (
"database/sql/driver"
"math"
"reflect"
"time"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgtype"
)
const (
copyData = 'd'
copyFail = 'f'
copyDone = 'c'
varHeaderSize = 4
)
type FieldDescription struct {
Name string
Table pgtype.OID
AttributeNumber uint16
DataType pgtype.OID
DataTypeSize int16
DataTypeName string
Modifier int32
FormatCode int16
}
func (fd FieldDescription) Length() (int64, bool) {
switch fd.DataType {
case pgtype.TextOID, pgtype.ByteaOID:
return math.MaxInt64, true
case pgtype.VarcharOID, pgtype.BPCharArrayOID:
return int64(fd.Modifier - varHeaderSize), true
default:
return 0, false
}
}
func (fd FieldDescription) PrecisionScale() (precision, scale int64, ok bool) {
switch fd.DataType {
case pgtype.NumericOID:
mod := fd.Modifier - varHeaderSize
precision = int64((mod >> 16) & 0xffff)
scale = int64(mod & 0xffff)
return precision, scale, true
default:
return 0, 0, false
}
}
func (fd FieldDescription) Type() reflect.Type {
switch fd.DataType {
case pgtype.Float8OID:
return reflect.TypeOf(float64(0))
case pgtype.Float4OID:
return reflect.TypeOf(float32(0))
case pgtype.Int8OID:
return reflect.TypeOf(int64(0))
case pgtype.Int4OID:
return reflect.TypeOf(int32(0))
case pgtype.Int2OID:
return reflect.TypeOf(int16(0))
case pgtype.VarcharOID, pgtype.BPCharArrayOID, pgtype.TextOID:
return reflect.TypeOf("")
case pgtype.BoolOID:
return reflect.TypeOf(false)
case pgtype.NumericOID:
return reflect.TypeOf(float64(0))
case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
return reflect.TypeOf(time.Time{})
case pgtype.ByteaOID:
return reflect.TypeOf([]byte(nil))
default:
return reflect.TypeOf(new(interface{})).Elem()
}
}
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// Notice represents a notice response message reported by the PostgreSQL
// server. Be aware that this is distinct from LISTEN/NOTIFY notification.
type Notice PgError
// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it.
func appendParse(buf []byte, name string, query string, parameterOIDs []pgtype.OID) []byte {
buf = append(buf, 'P')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, name...)
buf = append(buf, 0)
buf = append(buf, query...)
buf = append(buf, 0)
buf = pgio.AppendInt16(buf, int16(len(parameterOIDs)))
for _, oid := range parameterOIDs {
buf = pgio.AppendUint32(buf, uint32(oid))
}
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
return buf
}
// appendDescribe appends a PostgreSQL wire protocol describe message to buf and returns it.
func appendDescribe(buf []byte, objectType byte, name string) []byte {
buf = append(buf, 'D')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, objectType)
buf = append(buf, name...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
return buf
}
// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it.
func appendSync(buf []byte) []byte {
buf = append(buf, 'S')
buf = pgio.AppendInt32(buf, 4)
return buf
}
// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it.
func appendBind(
buf []byte,
destinationPortal,
preparedStatement string,
connInfo *pgtype.ConnInfo,
parameterOIDs []pgtype.OID,
arguments []interface{},
resultFormatCodes []int16,
) ([]byte, error) {
buf = append(buf, 'B')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, destinationPortal...)
buf = append(buf, 0)
buf = append(buf, preparedStatement...)
buf = append(buf, 0)
var err error
arguments, err = convertDriverValuers(arguments)
if err != nil {
return nil, err
}
buf = pgio.AppendInt16(buf, int16(len(parameterOIDs)))
for i, oid := range parameterOIDs {
buf = pgio.AppendInt16(buf, chooseParameterFormatCode(connInfo, oid, arguments[i]))
}
buf = pgio.AppendInt16(buf, int16(len(arguments)))
for i, oid := range parameterOIDs {
var err error
buf, err = encodePreparedStatementArgument(connInfo, buf, oid, arguments[i])
if err != nil {
return nil, err
}
}
buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes)))
for _, fc := range resultFormatCodes {
buf = pgio.AppendInt16(buf, fc)
}
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
return buf, nil
}
func convertDriverValuers(args []interface{}) ([]interface{}, error) {
for i, arg := range args {
switch arg := arg.(type) {
case pgtype.BinaryEncoder:
case pgtype.TextEncoder:
case driver.Valuer:
v, err := callValuerValue(arg)
if err != nil {
return nil, err
}
args[i] = v
}
}
return args, nil
}
// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it.
func appendExecute(buf []byte, portal string, maxRows uint32) []byte {
buf = append(buf, 'E')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, portal...)
buf = append(buf, 0)
buf = pgio.AppendUint32(buf, maxRows)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
return buf
}
// appendQuery appends a PostgreSQL wire protocol query message to buf and returns it.
func appendQuery(buf []byte, query string) []byte {
buf = append(buf, 'Q')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, query...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
return buf
}
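
Every appender above follows the same framing: a one-byte message type, a four-byte big-endian length that counts itself but not the type byte, then the body. A standalone sketch of that pattern using only the standard library (appendMessage is invented for illustration):

```go
import "encoding/binary"

// appendMessage frames body as a PostgreSQL wire message of the given type.
func appendMessage(buf []byte, typ byte, body []byte) []byte {
	buf = append(buf, typ)
	sp := len(buf)
	buf = append(buf, 0, 0, 0, 0) // length placeholder, backfilled below
	buf = append(buf, body...)
	binary.BigEndian.PutUint32(buf[sp:], uint32(len(buf)-sp))
	return buf
}
```

For example, `appendMessage(nil, 'Q', append([]byte("select 1"), 0))` should produce the same bytes as `appendQuery(nil, "select 1")` above.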

View File

@ -1,152 +0,0 @@
// Package multitracer provides a Tracer that can combine several tracers into one.
package multitracer
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// Tracer can combine several tracers into one.
// You can use New to automatically split tracers by interface.
type Tracer struct {
QueryTracers []pgx.QueryTracer
BatchTracers []pgx.BatchTracer
CopyFromTracers []pgx.CopyFromTracer
PrepareTracers []pgx.PrepareTracer
ConnectTracers []pgx.ConnectTracer
PoolAcquireTracers []pgxpool.AcquireTracer
PoolReleaseTracers []pgxpool.ReleaseTracer
}
// New returns a new Tracer that combines the given tracers, automatically splitting them by interface.
func New(tracers ...pgx.QueryTracer) *Tracer {
var t Tracer
for _, tracer := range tracers {
t.QueryTracers = append(t.QueryTracers, tracer)
if batchTracer, ok := tracer.(pgx.BatchTracer); ok {
t.BatchTracers = append(t.BatchTracers, batchTracer)
}
if copyFromTracer, ok := tracer.(pgx.CopyFromTracer); ok {
t.CopyFromTracers = append(t.CopyFromTracers, copyFromTracer)
}
if prepareTracer, ok := tracer.(pgx.PrepareTracer); ok {
t.PrepareTracers = append(t.PrepareTracers, prepareTracer)
}
if connectTracer, ok := tracer.(pgx.ConnectTracer); ok {
t.ConnectTracers = append(t.ConnectTracers, connectTracer)
}
if poolAcquireTracer, ok := tracer.(pgxpool.AcquireTracer); ok {
t.PoolAcquireTracers = append(t.PoolAcquireTracers, poolAcquireTracer)
}
if poolReleaseTracer, ok := tracer.(pgxpool.ReleaseTracer); ok {
t.PoolReleaseTracers = append(t.PoolReleaseTracers, poolReleaseTracer)
}
}
return &t
}
func (t *Tracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
for _, tracer := range t.QueryTracers {
ctx = tracer.TraceQueryStart(ctx, conn, data)
}
return ctx
}
func (t *Tracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
for _, tracer := range t.QueryTracers {
tracer.TraceQueryEnd(ctx, conn, data)
}
}
func (t *Tracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
for _, tracer := range t.BatchTracers {
ctx = tracer.TraceBatchStart(ctx, conn, data)
}
return ctx
}
func (t *Tracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
for _, tracer := range t.BatchTracers {
tracer.TraceBatchQuery(ctx, conn, data)
}
}
func (t *Tracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
for _, tracer := range t.BatchTracers {
tracer.TraceBatchEnd(ctx, conn, data)
}
}
func (t *Tracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
for _, tracer := range t.CopyFromTracers {
ctx = tracer.TraceCopyFromStart(ctx, conn, data)
}
return ctx
}
func (t *Tracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
for _, tracer := range t.CopyFromTracers {
tracer.TraceCopyFromEnd(ctx, conn, data)
}
}
func (t *Tracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
for _, tracer := range t.PrepareTracers {
ctx = tracer.TracePrepareStart(ctx, conn, data)
}
return ctx
}
func (t *Tracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
for _, tracer := range t.PrepareTracers {
tracer.TracePrepareEnd(ctx, conn, data)
}
}
func (t *Tracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
for _, tracer := range t.ConnectTracers {
ctx = tracer.TraceConnectStart(ctx, data)
}
return ctx
}
func (t *Tracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
for _, tracer := range t.ConnectTracers {
tracer.TraceConnectEnd(ctx, data)
}
}
func (t *Tracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context {
for _, tracer := range t.PoolAcquireTracers {
ctx = tracer.TraceAcquireStart(ctx, pool, data)
}
return ctx
}
func (t *Tracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
for _, tracer := range t.PoolAcquireTracers {
tracer.TraceAcquireEnd(ctx, pool, data)
}
}
func (t *Tracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) {
for _, tracer := range t.PoolReleaseTracers {
tracer.TraceRelease(pool, data)
}
}
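
Wiring it up is a one-liner once the individual tracers exist. A hedged sketch (tracerA and tracerB stand in for tracers you already have, e.g. an OpenTelemetry tracer and a metrics tracer; both must implement at least pgx.QueryTracer):

```go
mt := multitracer.New(tracerA, tracerB)

config, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL"))
if err != nil {
	return err
}
// Pool acquire/release tracing is detected via the interfaces above.
config.ConnConfig.Tracer = mt
pool, err := pgxpool.NewWithConfig(ctx, config)
```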

View File

@ -1,115 +0,0 @@
package multitracer_test
import (
"context"
"testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/multitracer"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/require"
)
type testFullTracer struct{}
func (tt *testFullTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
}
func (tt *testFullTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
}
func (tt *testFullTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
}
func (tt *testFullTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
}
func (tt *testFullTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
}
func (tt *testFullTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
}
func (tt *testFullTracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
}
func (tt *testFullTracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) {
}
type testCopyTracer struct{}
func (tt *testCopyTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
return ctx
}
func (tt *testCopyTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
}
func (tt *testCopyTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
return ctx
}
func (tt *testCopyTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
}
func TestNew(t *testing.T) {
t.Parallel()
fullTracer := &testFullTracer{}
copyTracer := &testCopyTracer{}
mt := multitracer.New(fullTracer, copyTracer)
require.Equal(
t,
&multitracer.Tracer{
QueryTracers: []pgx.QueryTracer{
fullTracer,
copyTracer,
},
BatchTracers: []pgx.BatchTracer{
fullTracer,
},
CopyFromTracers: []pgx.CopyFromTracer{
fullTracer,
copyTracer,
},
PrepareTracers: []pgx.PrepareTracer{
fullTracer,
},
ConnectTracers: []pgx.ConnectTracer{
fullTracer,
},
PoolAcquireTracers: []pgxpool.AcquireTracer{
fullTracer,
},
PoolReleaseTracers: []pgxpool.ReleaseTracer{
fullTracer,
},
},
mt,
)
}

View File

@ -1,295 +0,0 @@
package pgx
import (
"context"
"fmt"
"strconv"
"strings"
"unicode/utf8"
)
// NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$'
// ordinal placeholder and construct the appropriate arguments.
//
// For example, the following two queries are equivalent:
//
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2)
//
// Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be
// letters, numbers, or underscores.
type NamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface.
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return rewriteQuery(na, sql, false)
}
// StrictNamedArgs can be used in the same way as NamedArgs, but the provided arguments are also checked: the map must
// supply every named argument that the SQL query uses and must contain no extra arguments.
type StrictNamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface.
func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return rewriteQuery(sna, sql, true)
}
type namedArg string
type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []any
nameToOrdinal map[namedArg]int
}
type stateFn func(*sqlLexer) stateFn
func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
nameToOrdinal: make(map[namedArg]int, len(na)),
}
for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
sb := strings.Builder{}
for _, p := range l.parts {
switch p := p.(type) {
case string:
sb.WriteString(p)
case namedArg:
sb.WriteRune('$')
sb.WriteString(strconv.Itoa(l.nameToOrdinal[p]))
}
}
newArgs = make([]any, len(l.nameToOrdinal))
for name, ordinal := range l.nameToOrdinal {
var found bool
newArgs[ordinal-1], found = na[string(name)]
if isStrict && !found {
return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
}
}
if isStrict {
for name := range na {
if _, found := l.nameToOrdinal[namedArg(name)]; !found {
return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
}
}
}
return sb.String(), newArgs, nil
}
func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case 'e', 'E':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '\'' {
l.pos += width
return escapeStringState
}
case '\'':
return singleQuoteState
case '"':
return doubleQuoteState
case '@':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if isLetter(nextRune) || nextRune == '_' {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}
l.start = l.pos
return namedArgState
}
case '-':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '-' {
l.pos += width
return oneLineCommentState
}
case '/':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '*' {
l.pos += width
return multilineCommentState
}
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func isLetter(r rune) bool {
return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
}
func namedArgState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
if r == utf8.RuneError {
if l.pos-l.start > 0 {
na := namedArg(l.src[l.start:l.pos])
if _, found := l.nameToOrdinal[na]; !found {
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
}
l.parts = append(l.parts, na)
l.start = l.pos
}
return nil
} else if !(isLetter(r) || (r >= '0' && r <= '9') || r == '_') {
l.pos -= width
na := namedArg(l.src[l.start:l.pos])
if _, found := l.nameToOrdinal[na]; !found {
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
}
l.parts = append(l.parts, namedArg(na))
l.start = l.pos
return rawState
}
}
}
func singleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func doubleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '"':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '"' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func escapeStringState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func oneLineCommentState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\n', '\r':
return rawState
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func multilineCommentState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '/':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '*' {
l.pos += width
l.nested++
}
case '*':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '/' {
continue
}
l.pos += width
if l.nested == 0 {
return rawState
}
l.nested--
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
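
In practice the rewriter is passed as the first query argument, as the NamedArgs doc comment above shows. A short sketch (table and column names are hypothetical):

```go
rows, err := conn.Query(ctx,
	"select id, name from widgets where color = @color and weight < @max_weight",
	pgx.NamedArgs{"color": "red", "max_weight": 100},
)
// pgx.StrictNamedArgs behaves the same but returns an error if the map and
// the query's named placeholders do not match exactly.
```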

View File

@ -1,162 +0,0 @@
package pgx_test
import (
"context"
"testing"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNamedArgsRewriteQuery(t *testing.T) {
t.Parallel()
for i, tt := range []struct {
sql string
args []any
namedArgs pgx.NamedArgs
expectedSQL string
expectedArgs []any
}{
{
sql: "select * from users where id = @id",
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: "select * from users where id = $1",
expectedArgs: []any{int32(42)},
},
{
sql: "select * from t where foo < @abc and baz = @def and bar < @abc",
namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)},
expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1",
expectedArgs: []any{int32(42), int32(1)},
},
{
sql: "select @a::int, @b::text",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "select $1::int, $2::text",
expectedArgs: []any{int32(42), "foo"},
},
{
sql: "select @Abc::int, @b_4::text, @_c::int",
namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo", "_c": int32(1)},
expectedSQL: "select $1::int, $2::text, $3::int",
expectedArgs: []any{int32(42), "foo", int32(1)},
},
{
sql: "at end @",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "at end @",
expectedArgs: []any{},
},
{
sql: "ignores without valid character after @ foo bar",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "ignores without valid character after @ foo bar",
expectedArgs: []any{},
},
{
sql: "name cannot start with number @1 foo bar",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "name cannot start with number @1 foo bar",
expectedArgs: []any{},
},
{
sql: `select *, '@foo' as "@bar" from users where id = @id`,
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: `select *, '@foo' as "@bar" from users where id = $1`,
expectedArgs: []any{int32(42)},
},
{
sql: `select * -- @foo
from users -- @single line comments
where id = @id;`,
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: `select * -- @foo
from users -- @single line comments
where id = $1;`,
expectedArgs: []any{int32(42)},
},
{
sql: `select * /* @multi line
@comment
*/
/* /* with @nesting */ */
from users
where id = @id;`,
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: `select * /* @multi line
@comment
*/
/* /* with @nesting */ */
from users
where id = $1;`,
expectedArgs: []any{int32(42)},
},
{
sql: "extra provided argument",
namedArgs: pgx.NamedArgs{"extra": int32(1)},
expectedSQL: "extra provided argument",
expectedArgs: []any{},
},
{
sql: "@missing argument",
namedArgs: pgx.NamedArgs{},
expectedSQL: "$1 argument",
expectedArgs: []any{nil},
},
// test comments and quotes
} {
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
require.NoError(t, err)
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
}
}
func TestStrictNamedArgsRewriteQuery(t *testing.T) {
t.Parallel()
for i, tt := range []struct {
sql string
namedArgs pgx.StrictNamedArgs
expectedSQL string
expectedArgs []any
isExpectedError bool
}{
{
sql: "no arguments",
namedArgs: pgx.StrictNamedArgs{},
expectedSQL: "no arguments",
expectedArgs: []any{},
isExpectedError: false,
},
{
sql: "@all @matches",
namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)},
expectedSQL: "$1 $2",
expectedArgs: []any{int32(1), int32(2)},
isExpectedError: false,
},
{
sql: "extra provided argument",
namedArgs: pgx.StrictNamedArgs{"extra": int32(1)},
isExpectedError: true,
},
{
sql: "@missing argument",
namedArgs: pgx.StrictNamedArgs{},
isExpectedError: true,
},
} {
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil)
if tt.isExpectedError {
assert.Errorf(t, err, "%d", i)
} else {
require.NoErrorf(t, err, "%d", i)
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
}
}
}

View File

@ -1,75 +0,0 @@
package pgx_test
import (
"context"
"os"
"testing"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPgbouncerStatementCacheDescribe(t *testing.T) {
connString := os.Getenv("PGX_TEST_PGBOUNCER_CONN_STRING")
if connString == "" {
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_PGBOUNCER_CONN_STRING")
}
config := mustParseConfig(t, connString)
config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
config.DescriptionCacheCapacity = 1024
testPgbouncer(t, config, 10, 100)
}
func TestPgbouncerSimpleProtocol(t *testing.T) {
connString := os.Getenv("PGX_TEST_PGBOUNCER_CONN_STRING")
if connString == "" {
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_PGBOUNCER_CONN_STRING")
}
config := mustParseConfig(t, connString)
config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
testPgbouncer(t, config, 10, 100)
}
func testPgbouncer(t *testing.T, config *pgx.ConnConfig, workers, iterations int) {
doneChan := make(chan struct{})
for i := 0; i < workers; i++ {
go func() {
defer func() { doneChan <- struct{}{} }()
conn, err := pgx.ConnectConfig(context.Background(), config)
require.Nil(t, err)
defer closeConn(t, conn)
for i := 0; i < iterations; i++ {
var i32 int32
var i64 int64
var f32 float32
var s string
var s2 string
err = conn.QueryRow(context.Background(), "select 1::int4, 2::int8, 3::float4, 'hi'::text").Scan(&i32, &i64, &f32, &s)
require.NoError(t, err)
assert.Equal(t, int32(1), i32)
assert.Equal(t, int64(2), i64)
assert.Equal(t, float32(3), f32)
assert.Equal(t, "hi", s)
err = conn.QueryRow(context.Background(), "select 1::int8, 2::float4, 'bye'::text, 4::int4, 'whatever'::text").Scan(&i64, &f32, &s, &i32, &s2)
require.NoError(t, err)
assert.Equal(t, int64(1), i64)
assert.Equal(t, float32(2), f32)
assert.Equal(t, "bye", s)
assert.Equal(t, int32(4), i32)
assert.Equal(t, "whatever", s2)
}
}()
}
for i := 0; i < workers; i++ {
<-doneChan
}
}

View File

@ -1,29 +0,0 @@
# pgconn
Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq.
It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx.
Applications should handle normal queries with a higher level library and only use pgconn directly when required for
low-level access to PostgreSQL functionality.
## Example Usage
```go
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL"))
if err != nil {
log.Fatalln("pgconn failed to connect:", err)
}
defer pgConn.Close(context.Background())
result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil)
for result.NextRow() {
fmt.Println("User 123 has email:", string(result.Values()[0]))
}
_, err = result.Close()
if err != nil {
log.Fatalln("failed reading result:", err)
}
```
## Testing
See CONTRIBUTING.md for setup instructions.

View File

@ -1,73 +0,0 @@
package pgconn
import (
"strings"
"testing"
)
func BenchmarkCommandTagRowsAffected(b *testing.B) {
benchmarks := []struct {
commandTag string
rowsAffected int64
}{
{"UPDATE 1", 1},
{"UPDATE 123456789", 123456789},
{"INSERT 0 1", 1},
{"INSERT 0 123456789", 123456789},
}
for _, bm := range benchmarks {
ct := CommandTag{s: bm.commandTag}
b.Run(bm.commandTag, func(b *testing.B) {
var n int64
for i := 0; i < b.N; i++ {
n = ct.RowsAffected()
}
if n != bm.rowsAffected {
b.Errorf("expected %d got %d", bm.rowsAffected, n)
}
})
}
}
func BenchmarkCommandTagTypeFromString(b *testing.B) {
ct := CommandTag{s: "UPDATE 1"}
var update bool
for i := 0; i < b.N; i++ {
update = strings.HasPrefix(ct.String(), "UPDATE")
}
if !update {
b.Error("expected update")
}
}
func BenchmarkCommandTagInsert(b *testing.B) {
benchmarks := []struct {
commandTag string
is bool
}{
{"INSERT 1", true},
{"INSERT 1234567890", true},
{"UPDATE 1", false},
{"UPDATE 1234567890", false},
{"DELETE 1", false},
{"DELETE 1234567890", false},
{"SELECT 1", false},
{"SELECT 1234567890", false},
{"UNKNOWN 1234567890", false},
}
for _, bm := range benchmarks {
ct := CommandTag{s: bm.commandTag}
b.Run(bm.commandTag, func(b *testing.B) {
var is bool
for i := 0; i < b.N; i++ {
is = ct.Insert()
}
if is != bm.is {
b.Errorf("expected %v got %v", bm.is, is)
}
})
}
}

View File

@ -1,250 +0,0 @@
package pgconn_test
import (
"bytes"
"context"
"os"
"testing"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/require"
)
func BenchmarkConnect(b *testing.B) {
benchmarks := []struct {
name string
env string
}{
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
}
for _, bm := range benchmarks {
bm := bm
b.Run(bm.name, func(b *testing.B) {
connString := os.Getenv(bm.env)
if connString == "" {
b.Skipf("Skipping due to missing environment variable %v", bm.env)
}
for i := 0; i < b.N; i++ {
conn, err := pgconn.Connect(context.Background(), connString)
require.Nil(b, err)
err = conn.Close(context.Background())
require.Nil(b, err)
}
})
}
}
func BenchmarkExec(b *testing.B) {
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
benchmarks := []struct {
name string
ctx context.Context
}{
// Using an empty context other than context.Background() to compare
// performance
{"background context", context.Background()},
{"empty context", context.TODO()},
}
for _, bm := range benchmarks {
bm := bm
b.Run(bm.name, func(b *testing.B) {
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_DATABASE"))
require.Nil(b, err)
defer closeConn(b, conn)
b.ResetTimer()
for i := 0; i < b.N; i++ {
mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date")
for mrr.NextResult() {
rr := mrr.ResultReader()
rowCount := 0
for rr.NextRow() {
rowCount++
if len(rr.Values()) != len(expectedValues) {
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
}
for i := range rr.Values() {
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
}
}
}
_, err = rr.Close()
if err != nil {
b.Fatal(err)
}
if rowCount != 1 {
b.Fatalf("unexpected rowCount: %d", rowCount)
}
}
err := mrr.Close()
if err != nil {
b.Fatal(err)
}
}
})
}
}
func BenchmarkExecPossibleToCancel(b *testing.B) {
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(b, err)
defer closeConn(b, conn)
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
b.ResetTimer()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for i := 0; i < b.N; i++ {
mrr := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date")
for mrr.NextResult() {
rr := mrr.ResultReader()
rowCount := 0
for rr.NextRow() {
rowCount++
if len(rr.Values()) != len(expectedValues) {
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
}
for i := range rr.Values() {
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
}
}
}
_, err = rr.Close()
if err != nil {
b.Fatal(err)
}
if rowCount != 1 {
b.Fatalf("unexpected rowCount: %d", rowCount)
}
}
err := mrr.Close()
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkExecPrepared(b *testing.B) {
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
benchmarks := []struct {
name string
ctx context.Context
}{
// Using an empty context other than context.Background() to compare
// performance
{"background context", context.Background()},
{"empty context", context.TODO()},
}
for _, bm := range benchmarks {
bm := bm
b.Run(bm.name, func(b *testing.B) {
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_DATABASE"))
require.Nil(b, err)
defer closeConn(b, conn)
_, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
require.Nil(b, err)
b.ResetTimer()
for i := 0; i < b.N; i++ {
rr := conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil)
rowCount := 0
for rr.NextRow() {
rowCount++
if len(rr.Values()) != len(expectedValues) {
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
}
for i := range rr.Values() {
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
}
}
}
_, err = rr.Close()
if err != nil {
b.Fatal(err)
}
if rowCount != 1 {
b.Fatalf("unexpected rowCount: %d", rowCount)
}
}
})
}
}
func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(b, err)
defer closeConn(b, conn)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
require.Nil(b, err)
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
b.ResetTimer()
for i := 0; i < b.N; i++ {
rr := conn.ExecPrepared(ctx, "ps1", nil, nil, nil)
rowCount := 0
for rr.NextRow() {
rowCount += 1
if len(rr.Values()) != len(expectedValues) {
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
}
for i := range rr.Values() {
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
}
}
}
_, err = rr.Close()
if err != nil {
b.Fatal(err)
}
if rowCount != 1 {
b.Fatalf("unexpected rowCount: %d", rowCount)
}
}
}
// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) {
// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
// require.Nil(b, err)
// defer closeConn(b, conn)
// ctx, cancel := context.WithCancel(context.Background())
// defer cancel()
// b.ResetTimer()
// for i := 0; i < b.N; i++ {
// conn.ChanToSetDeadline().Watch(ctx)
// conn.ChanToSetDeadline().Ignore()
// }
// }

View File

@ -1,953 +0,0 @@
package pgconn
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
"math"
"net"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/jackc/pgpassfile"
"github.com/jackc/pgservicefile"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3"
)
type (
AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
GetSSLPasswordFunc func(ctx context.Context) string
)
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A
// manually initialized Config will cause ConnectConfig to panic.
type Config struct {
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
Port uint16
Database string
User string
Password string
TLSConfig *tls.Config // nil disables TLS
ConnectTimeout time.Duration
DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc
// BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called
// when a context passed to a PgConn method is canceled.
BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
KerberosSrvName string
KerberosSpn string
Fallbacks []*FallbackConfig
SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
ValidateConnect ValidateConnectFunc
// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. set session variables
// or prepare statements). If this returns an error the connection attempt fails.
AfterConnect AfterConnectFunc
// OnNotice is a callback function called when a notice response is received.
OnNotice NoticeHandler
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler
// OnPgError is a callback function called when a Postgres error is received from the server. The default handler will close
// the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure
// that you close on FATAL errors by returning false.
OnPgError PgErrorHandler
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
// ParseConfigOptions contains options that control how a config is built such as GetSSLPassword.
type ParseConfigOptions struct {
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the libpq function
// PQsetSSLKeyPassHook_OpenSSL.
GetSSLPassword GetSSLPasswordFunc
}
// Copy returns a deep copy of the config that is safe to use and modify.
// The only exception is the TLSConfig field:
// according to the tls.Config docs it must not be modified after creation.
func (c *Config) Copy() *Config {
newConf := new(Config)
*newConf = *c
if newConf.TLSConfig != nil {
newConf.TLSConfig = c.TLSConfig.Clone()
}
if newConf.RuntimeParams != nil {
newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
for k, v := range c.RuntimeParams {
newConf.RuntimeParams[k] = v
}
}
if newConf.Fallbacks != nil {
newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
for i, fallback := range c.Fallbacks {
newFallback := new(FallbackConfig)
*newFallback = *fallback
if newFallback.TLSConfig != nil {
newFallback.TLSConfig = fallback.TLSConfig.Clone()
}
newConf.Fallbacks[i] = newFallback
}
}
return newConf
}
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
type FallbackConfig struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
Port uint16
TLSConfig *tls.Config // nil disables TLS
}
// connectOneConfig is the configuration for a single attempt to connect to a single host.
type connectOneConfig struct {
network string
address string
originalHostname string // original hostname before resolving
tlsConfig *tls.Config // nil disables TLS
}
// isAbsolutePath checks if the provided value is an absolute path either
// beginning with a forward slash (as on Linux-based systems) or with a capital
// letter A-Z followed by a colon and a backslash, e.g. "C:\" (as on Windows).
func isAbsolutePath(path string) bool {
isWindowsPath := func(p string) bool {
if len(p) < 3 {
return false
}
drive := p[0]
colon := p[1]
backslash := p[2]
if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
return true
}
return false
}
return strings.HasPrefix(path, "/") || isWindowsPath(path)
}
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
// net.Dial.
func NetworkAddress(host string, port uint16) (network, address string) {
if isAbsolutePath(host) {
network = "unix"
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
} else {
network = "tcp"
address = net.JoinHostPort(host, strconv.Itoa(int(port)))
}
return network, address
}
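// Illustrative sketch (not part of the original file) of the mapping
// NetworkAddress performs; db.example.com is an assumed hostname:
//
//	network, address := NetworkAddress("/var/run/postgresql", 5432)
//	// network == "unix", address == "/var/run/postgresql/.s.PGSQL.5432"
//
//	network, address = NetworkAddress("db.example.com", 5432)
//	// network == "tcp", address == "db.example.com:5432"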
// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format. See
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be empty
// to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
//
// # Example Keyword/Value
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
//
// # Example URL
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
//
// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done
// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be
// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should
// not be modified individually. They should all be modified or all left unchanged.
//
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
// values that will be tried in order. This can be used as part of a high availability system. See
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
//
// # Example URL
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
//
// ParseConfig currently recognizes the following environment variables and their parameter key word equivalents passed
// via database URL or keyword/value:
//
// PGHOST
// PGPORT
// PGDATABASE
// PGUSER
// PGPASSWORD
// PGPASSFILE
// PGSERVICE
// PGSERVICEFILE
// PGSSLMODE
// PGSSLCERT
// PGSSLKEY
// PGSSLROOTCERT
// PGSSLPASSWORD
// PGOPTIONS
// PGAPPNAME
// PGCONNECT_TIMEOUT
// PGTARGETSESSIONATTRS
// PGTZ
//
// See http://www.postgresql.org/docs/current/static/libpq-envars.html for details on the meaning of environment variables.
//
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
// usually but not always the environment variable name downcased and without the "PG" prefix.
//
// Important Security Notes:
//
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
// not set.
//
// See http://www.postgresql.org/docs/current/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
// security each sslmode provides.
//
// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of
// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
// TLSConfig.
//
// Other known differences with libpq:
//
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
// does not.
//
// In addition, ParseConfig accepts the following options:
//
// - servicefile.
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
// part of the connection string.
func ParseConfig(connString string) (*Config, error) {
var parseConfigOptions ParseConfigOptions
return ParseConfigWithOptions(connString, parseConfigOptions)
}
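// A minimal usage sketch (illustrative addition, not part of the original
// file). Both URL and keyword/value strings parse to the same *Config, which
// may then be adjusted before connecting with ConnectConfig:
//
//	cfg, err := ParseConfig("postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca")
//	if err != nil {
//		// handle error
//	}
//	cfg.RuntimeParams["application_name"] = "myapp" // safe to modify
//	conn, err := ConnectConfig(context.Background(), cfg)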
// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard
// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to
// get the SSL password.
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) {
defaultSettings := defaultSettings()
envSettings := parseEnvSettings()
connStringSettings := make(map[string]string)
if connString != "" {
var err error
// connString may be a database URL or in PostgreSQL keyword/value format
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
}
} else {
connStringSettings, err = parseKeywordValueSettings(connString)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as keyword/value", err: err}
}
}
}
settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
if service, present := settings["service"]; present {
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err}
}
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
}
config := &Config{
createdByParseConfig: true,
Database: settings["database"],
User: settings["user"],
Password: settings["password"],
RuntimeParams: make(map[string]string),
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
return pgproto3.NewFrontend(r, w)
},
BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler {
return &DeadlineContextWatcherHandler{Conn: pgConn.conn}
},
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
// automatically close the connection on any FATAL errors
if strings.EqualFold(pgErr.Severity, "FATAL") {
return false
}
return true
},
}
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err}
}
config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
} else {
defaultDialer := makeDefaultDialer()
config.DialFunc = defaultDialer.DialContext
}
config.LookupFunc = makeDefaultResolver().LookupHost
notRuntimeParams := map[string]struct{}{
"host": {},
"port": {},
"database": {},
"user": {},
"password": {},
"passfile": {},
"connect_timeout": {},
"sslmode": {},
"sslkey": {},
"sslcert": {},
"sslrootcert": {},
"sslnegotiation": {},
"sslpassword": {},
"sslsni": {},
"krbspn": {},
"krbsrvname": {},
"target_session_attrs": {},
"service": {},
"servicefile": {},
}
// Adding kerberos configuration
if _, present := settings["krbsrvname"]; present {
config.KerberosSrvName = settings["krbsrvname"]
}
if _, present := settings["krbspn"]; present {
config.KerberosSpn = settings["krbspn"]
}
for k, v := range settings {
if _, present := notRuntimeParams[k]; present {
continue
}
config.RuntimeParams[k] = v
}
fallbacks := []*FallbackConfig{}
hosts := strings.Split(settings["host"], ",")
ports := strings.Split(settings["port"], ",")
for i, host := range hosts {
var portStr string
if i < len(ports) {
portStr = ports[i]
} else {
portStr = ports[0]
}
port, err := parsePort(portStr)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err}
}
var tlsConfigs []*tls.Config
// Ignore TLS settings if Unix domain socket like libpq
if network, _ := NetworkAddress(host, port); network == "unix" {
tlsConfigs = append(tlsConfigs, nil)
} else {
var err error
tlsConfigs, err = configTLS(settings, host, options)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err}
}
}
for _, tlsConfig := range tlsConfigs {
fallbacks = append(fallbacks, &FallbackConfig{
Host: host,
Port: port,
TLSConfig: tlsConfig,
})
}
}
config.Host = fallbacks[0].Host
config.Port = fallbacks[0].Port
config.TLSConfig = fallbacks[0].TLSConfig
config.Fallbacks = fallbacks[1:]
config.SSLNegotiation = settings["sslnegotiation"]
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
if err == nil {
if config.Password == "" {
host := config.Host
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
host = "localhost"
}
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
}
}
switch tsa := settings["target_session_attrs"]; tsa {
case "read-write":
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
case "read-only":
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
case "primary":
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
case "standby":
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
case "prefer-standby":
config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
case "any":
// do nothing
default:
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
}
return config, nil
}
func mergeSettings(settingSets ...map[string]string) map[string]string {
settings := make(map[string]string)
for _, s2 := range settingSets {
for k, v := range s2 {
settings[k] = v
}
}
return settings
}
func parseEnvSettings() map[string]string {
settings := make(map[string]string)
nameMap := map[string]string{
"PGHOST": "host",
"PGPORT": "port",
"PGDATABASE": "database",
"PGUSER": "user",
"PGPASSWORD": "password",
"PGPASSFILE": "passfile",
"PGAPPNAME": "application_name",
"PGCONNECT_TIMEOUT": "connect_timeout",
"PGSSLMODE": "sslmode",
"PGSSLKEY": "sslkey",
"PGSSLCERT": "sslcert",
"PGSSLSNI": "sslsni",
"PGSSLROOTCERT": "sslrootcert",
"PGSSLPASSWORD": "sslpassword",
"PGSSLNEGOTIATION": "sslnegotiation",
"PGTARGETSESSIONATTRS": "target_session_attrs",
"PGSERVICE": "service",
"PGSERVICEFILE": "servicefile",
"PGTZ": "timezone",
"PGOPTIONS": "options",
}
for envname, realname := range nameMap {
value := os.Getenv(envname)
if value != "" {
settings[realname] = value
}
}
return settings
}
func parseURLSettings(connString string) (map[string]string, error) {
settings := make(map[string]string)
parsedURL, err := url.Parse(connString)
if err != nil {
if urlErr := new(url.Error); errors.As(err, &urlErr) {
return nil, urlErr.Err
}
return nil, err
}
if parsedURL.User != nil {
settings["user"] = parsedURL.User.Username()
if password, present := parsedURL.User.Password(); present {
settings["password"] = password
}
}
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
var hosts []string
var ports []string
for _, host := range strings.Split(parsedURL.Host, ",") {
if host == "" {
continue
}
if isIPOnly(host) {
hosts = append(hosts, strings.Trim(host, "[]"))
continue
}
h, p, err := net.SplitHostPort(host)
if err != nil {
return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
}
if h != "" {
hosts = append(hosts, h)
}
if p != "" {
ports = append(ports, p)
}
}
if len(hosts) > 0 {
settings["host"] = strings.Join(hosts, ",")
}
if len(ports) > 0 {
settings["port"] = strings.Join(ports, ",")
}
database := strings.TrimLeft(parsedURL.Path, "/")
if database != "" {
settings["database"] = database
}
nameMap := map[string]string{
"dbname": "database",
}
for k, v := range parsedURL.Query() {
if k2, present := nameMap[k]; present {
k = k2
}
settings[k] = v[0]
}
return settings, nil
}
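// Illustrative sketch (not part of the original file): multiple host:port
// pairs in the URL are split into comma separated host and port settings, and
// dbname is normalized to database. Hostnames are assumptions for the example:
//
//	m, _ := parseURLSettings("postgres://jack:secret@foo.example.com:5432,bar.example.com:6432/mydb")
//	// m["host"] == "foo.example.com,bar.example.com"
//	// m["port"] == "5432,6432"
//	// m["database"] == "mydb"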
func isIPOnly(host string) bool {
return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
}
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
func parseKeywordValueSettings(s string) (map[string]string, error) {
settings := make(map[string]string)
nameMap := map[string]string{
"dbname": "database",
}
for len(s) > 0 {
var key, val string
eqIdx := strings.IndexRune(s, '=')
if eqIdx < 0 {
return nil, errors.New("invalid keyword/value")
}
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
if len(s) == 0 {
} else if s[0] != '\'' {
end := 0
for ; end < len(s); end++ {
if asciiSpace[s[end]] == 1 {
break
}
if s[end] == '\\' {
end++
if end == len(s) {
return nil, errors.New("invalid backslash")
}
}
}
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
if end == len(s) {
s = ""
} else {
s = s[end+1:]
}
} else { // quoted string
s = s[1:]
end := 0
for ; end < len(s); end++ {
if s[end] == '\'' {
break
}
if s[end] == '\\' {
end++
}
}
if end == len(s) {
return nil, errors.New("unterminated quoted string in connection info string")
}
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
if end == len(s) {
s = ""
} else {
s = s[end+1:]
}
}
if k, ok := nameMap[key]; ok {
key = k
}
if key == "" {
return nil, errors.New("invalid keyword/value")
}
settings[key] = val
}
return settings, nil
}
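// Illustrative sketch (not part of the original file): how the quoting and
// backslash rules above parse out.
//
//	m, _ := parseKeywordValueSettings(`host=localhost password='pa\'ss word' dbname=mydb`)
//	// m["host"] == "localhost"
//	// m["password"] == "pa'ss word"  (backslash-escaped quote inside a quoted value)
//	// m["database"] == "mydb"        (dbname is normalized to database)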
func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
if err != nil {
return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
}
service, err := servicefile.GetService(serviceName)
if err != nil {
return nil, fmt.Errorf("unable to find service: %v", serviceName)
}
nameMap := map[string]string{
"dbname": "database",
}
settings := make(map[string]string, len(service.Settings))
for k, v := range service.Settings {
if k2, present := nameMap[k]; present {
k = k2
}
settings[k] = v
}
return settings, nil
}
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
// necessary to allow returning multiple TLS configs as sslmode "allow" and
// "prefer" allow fallback.
func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) {
host := thisHost
sslmode := settings["sslmode"]
sslrootcert := settings["sslrootcert"]
sslcert := settings["sslcert"]
sslkey := settings["sslkey"]
sslpassword := settings["sslpassword"]
sslsni := settings["sslsni"]
sslnegotiation := settings["sslnegotiation"]
// Match libpq default behavior
if sslmode == "" {
sslmode = "prefer"
}
if sslsni == "" {
sslsni = "1"
}
tlsConfig := &tls.Config{}
if sslnegotiation == "direct" {
tlsConfig.NextProtos = []string{"postgresql"}
if sslmode == "prefer" {
sslmode = "require"
}
}
if sslrootcert != "" {
var caCertPool *x509.CertPool
if sslrootcert == "system" {
var err error
caCertPool, err = x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("unable to load system certificate pool: %w", err)
}
sslmode = "verify-full"
} else {
caCertPool = x509.NewCertPool()
caPath := sslrootcert
caCert, err := os.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("unable to read CA file: %w", err)
}
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, errors.New("unable to add CA to cert pool")
}
}
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
}
switch sslmode {
case "disable":
return []*tls.Config{nil}, nil
case "allow", "prefer":
tlsConfig.InsecureSkipVerify = true
case "require":
// According to PostgreSQL documentation, if a root CA file exists,
// the behavior of sslmode=require should be the same as that of verify-ca
//
// See https://www.postgresql.org/docs/current/libpq-ssl.html
if sslrootcert != "" {
goto nextCase
}
tlsConfig.InsecureSkipVerify = true
break
nextCase:
fallthrough
case "verify-ca":
// Don't perform the default certificate verification because it
// will verify the hostname. Instead, verify the server's
// certificate chain ourselves in VerifyPeerCertificate and
// ignore the server name. This emulates libpq's verify-ca
// behavior.
//
// See https://github.com/golang/go/issues/21971#issuecomment-332693931
// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
// for more info.
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
certs := make([]*x509.Certificate, len(certificates))
for i, asn1Data := range certificates {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
return errors.New("failed to parse certificate from server: " + err.Error())
}
certs[i] = cert
}
// Leave DNSName empty to skip hostname verification.
opts := x509.VerifyOptions{
Roots: tlsConfig.RootCAs,
Intermediates: x509.NewCertPool(),
}
// Skip the first cert because it's the leaf. All others
// are intermediates.
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
_, err := certs[0].Verify(opts)
return err
}
case "verify-full":
tlsConfig.ServerName = host
default:
return nil, errors.New("sslmode is invalid")
}
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
}
if sslcert != "" && sslkey != "" {
buf, err := os.ReadFile(sslkey)
if err != nil {
return nil, fmt.Errorf("unable to read sslkey: %w", err)
}
block, _ := pem.Decode(buf)
if block == nil {
return nil, errors.New("failed to decode sslkey")
}
var pemKey []byte
var decryptedKey []byte
var decryptedError error
// If PEM is encrypted, attempt to decrypt using pass phrase
if x509.IsEncryptedPEMBlock(block) {
// Attempt decryption with pass phrase
// NOTE: only supports RSA (PKCS#1)
if sslpassword != "" {
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
}
// If sslpassword was not provided, or decryption with it failed, try to
// obtain the password from the GetSSLPassword callback.
if sslpassword == "" || decryptedError != nil {
if parseConfigOptions.GetSSLPassword != nil {
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
}
if sslpassword == "" {
return nil, fmt.Errorf("unable to find sslpassword")
}
}
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
// NOTE: as above, only PKCS#1 (RSA) encrypted PEM blocks are supported.
if decryptedError != nil {
return nil, fmt.Errorf("unable to decrypt key: %w", decryptedError)
}
pemBytes := pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: decryptedKey,
}
pemKey = pem.EncodeToMemory(&pemBytes)
} else {
pemKey = pem.EncodeToMemory(block)
}
certfile, err := os.ReadFile(sslcert)
if err != nil {
return nil, fmt.Errorf("unable to read cert: %w", err)
}
cert, err := tls.X509KeyPair(certfile, pemKey)
if err != nil {
return nil, fmt.Errorf("unable to load cert: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
// Set Server Name Indication (SNI), if enabled by connection parameters.
// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
// or IPv6).
if sslsni == "1" && net.ParseIP(host) == nil {
tlsConfig.ServerName = host
}
switch sslmode {
case "allow":
return []*tls.Config{nil, tlsConfig}, nil
case "prefer":
return []*tls.Config{tlsConfig, nil}, nil
case "require", "verify-ca", "verify-full":
return []*tls.Config{tlsConfig}, nil
default:
panic("BUG: bad sslmode should already have been caught")
}
}
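// Illustrative sketch (not part of the original file): the slice shapes the
// sslmode cases above produce, which ParseConfig expands into Fallbacks.
// db.example.com is an assumed hostname:
//
//	cfgs, _ := configTLS(map[string]string{"sslmode": "prefer"}, "db.example.com", ParseConfigOptions{})
//	// len(cfgs) == 2: cfgs[0] is the TLS config, cfgs[1] is nil (the plaintext fallback)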
func parsePort(s string) (uint16, error) {
port, err := strconv.ParseUint(s, 10, 16)
if err != nil {
return 0, err
}
if port < 1 || port > math.MaxUint16 {
return 0, errors.New("outside range")
}
return uint16(port), nil
}
func makeDefaultDialer() *net.Dialer {
// rely on Go's default keep-alive settings
return &net.Dialer{}
}
func makeDefaultResolver() *net.Resolver {
return net.DefaultResolver
}
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
timeout, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return 0, err
}
if timeout < 0 {
return 0, errors.New("negative timeout")
}
return time.Duration(timeout) * time.Second, nil
}
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
d := makeDefaultDialer()
d.Timeout = timeout
return d.DialContext
}
// ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-write.
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll()
if err != nil {
return err
}
if string(result[0].Rows[0][0]) == "on" {
return errors.New("read only connection")
}
return nil
}
// ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-only.
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll()
if err != nil {
return err
}
if string(result[0].Rows[0][0]) != "on" {
return errors.New("connection is not read only")
}
return nil
}
// ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=standby.
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
if err != nil {
return err
}
if string(result[0].Rows[0][0]) != "t" {
return errors.New("server is not in hot standby mode")
}
return nil
}
// ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=primary.
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
if err != nil {
return err
}
if string(result[0].Rows[0][0]) == "t" {
return errors.New("server is in standby mode")
}
return nil
}
// ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=prefer-standby.
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
if err != nil {
return err
}
if string(result[0].Rows[0][0]) != "t" {
return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
}
return nil
}
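// Illustrative sketch (not part of the original file): a custom
// ValidateConnectFunc in the same style as the validators above, rejecting
// servers older than an assumed minimum version. cfg is an assumed *Config
// obtained from ParseConfig:
//
//	cfg.ValidateConnect = func(ctx context.Context, pgConn *PgConn) error {
//		result, err := pgConn.Exec(ctx, "show server_version_num").ReadAll()
//		if err != nil {
//			return err
//		}
//		v, err := strconv.Atoi(string(result[0].Rows[0][0]))
//		if err != nil {
//			return err
//		}
//		if v < 140000 {
//			return errors.New("server older than PostgreSQL 14")
//		}
//		return nil
//	}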

File diff suppressed because it is too large Load Diff

View File

@ -1,80 +0,0 @@
package ctxwatch
import (
"context"
"sync"
)
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
// time.
type ContextWatcher struct {
handler Handler
unwatchChan chan struct{}
lock sync.Mutex
watchInProgress bool
onCancelWasCalled bool
}
// NewContextWatcher returns a ContextWatcher. handler.HandleCancel will be called when a watched context is canceled.
// handler.HandleUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been
// canceled and HandleCancel called.
func NewContextWatcher(handler Handler) *ContextWatcher {
cw := &ContextWatcher{
handler: handler,
unwatchChan: make(chan struct{}),
}
return cw
}
// Watch starts watching ctx. If ctx is canceled then the handler's HandleCancel will be called.
func (cw *ContextWatcher) Watch(ctx context.Context) {
cw.lock.Lock()
defer cw.lock.Unlock()
if cw.watchInProgress {
panic("Watch already in progress")
}
cw.onCancelWasCalled = false
if ctx.Done() != nil {
cw.watchInProgress = true
go func() {
select {
case <-ctx.Done():
cw.handler.HandleCancel(ctx)
cw.onCancelWasCalled = true
<-cw.unwatchChan
case <-cw.unwatchChan:
}
}()
} else {
cw.watchInProgress = false
}
}
// Unwatch stops watching the previously watched context. If HandleCancel was called then HandleUnwatchAfterCancel will
// also be called.
func (cw *ContextWatcher) Unwatch() {
cw.lock.Lock()
defer cw.lock.Unlock()
if cw.watchInProgress {
cw.unwatchChan <- struct{}{}
if cw.onCancelWasCalled {
cw.handler.HandleUnwatchAfterCancel()
}
cw.watchInProgress = false
}
}
type Handler interface {
// HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the
// context that was canceled.
HandleCancel(canceledCtx context.Context)
// HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched.
HandleUnwatchAfterCancel()
}
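// Illustrative sketch (not part of the original file): a minimal Handler and
// a watch/unwatch cycle. logHandler, ctx, and doBlockingWork are assumed
// example names:
//
//	type logHandler struct{}
//
//	func (logHandler) HandleCancel(ctx context.Context) { log.Println("canceled:", ctx.Err()) }
//	func (logHandler) HandleUnwatchAfterCancel()        { log.Println("unwatched after cancel") }
//
//	cw := NewContextWatcher(logHandler{})
//	cw.Watch(ctx)
//	doBlockingWork() // operation that HandleCancel would interrupt
//	cw.Unwatch()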

View File

@ -1,185 +0,0 @@
package ctxwatch_test
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/stretchr/testify/require"
)
type testHandler struct {
handleCancel func(context.Context)
handleUnwatchAfterCancel func()
}
func (h *testHandler) HandleCancel(ctx context.Context) {
h.handleCancel(ctx)
}
func (h *testHandler) HandleUnwatchAfterCancel() {
h.handleUnwatchAfterCancel()
}
func TestContextWatcherContextCancelled(t *testing.T) {
canceledChan := make(chan struct{})
cleanupCalled := false
cw := ctxwatch.NewContextWatcher(&testHandler{
handleCancel: func(context.Context) {
canceledChan <- struct{}{}
}, handleUnwatchAfterCancel: func() {
cleanupCalled = true
},
})
ctx, cancel := context.WithCancel(context.Background())
cw.Watch(ctx)
cancel()
select {
case <-canceledChan:
case <-time.NewTimer(time.Second).C:
t.Fatal("Timed out waiting for cancel func to be called")
}
cw.Unwatch()
require.True(t, cleanupCalled, "Cleanup func was not called")
}
func TestContextWatcherUnwatchedBeforeContextCancelled(t *testing.T) {
cw := ctxwatch.NewContextWatcher(&testHandler{
handleCancel: func(context.Context) {
t.Error("cancel func should not have been called")
}, handleUnwatchAfterCancel: func() {
t.Error("cleanup func should not have been called")
},
})
ctx, cancel := context.WithCancel(context.Background())
cw.Watch(ctx)
cw.Unwatch()
cancel()
}
func TestContextWatcherMultipleWatchPanics(t *testing.T) {
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cw.Watch(ctx)
defer cw.Unwatch()
ctx2, cancel2 := context.WithCancel(context.Background())
defer cancel2()
require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times")
}
func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
cw.Unwatch() // unwatch when not / never watching
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cw.Watch(ctx)
cw.Unwatch()
cw.Unwatch() // double unwatch
}
func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
cw.Watch(ctx)
go cw.Unwatch()
go cw.Unwatch()
<-ctx.Done()
}
func TestContextWatcherStress(t *testing.T) {
var cancelFuncCalls int64
var cleanupFuncCalls int64
cw := ctxwatch.NewContextWatcher(&testHandler{
handleCancel: func(context.Context) {
atomic.AddInt64(&cancelFuncCalls, 1)
}, handleUnwatchAfterCancel: func() {
atomic.AddInt64(&cleanupFuncCalls, 1)
},
})
cycleCount := 100000
for i := 0; i < cycleCount; i++ {
ctx, cancel := context.WithCancel(context.Background())
cw.Watch(ctx)
if i%2 == 0 {
cancel()
}
// Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix.
if i%333 == 0 {
// on Windows Sleep takes more time than expected so we try to get here less frequently to keep
// CI run times reasonable
time.Sleep(time.Nanosecond)
}
cw.Unwatch()
if i%2 == 1 {
cancel()
}
}
actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls)
actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls)
if actualCancelFuncCalls == 0 {
t.Fatal("actualCancelFuncCalls == 0")
}
maxCancelFuncCalls := int64(cycleCount) / 2
if actualCancelFuncCalls > maxCancelFuncCalls {
t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls)
}
if actualCancelFuncCalls != actualCleanupFuncCalls {
t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls)
}
}
func BenchmarkContextWatcherUncancellable(b *testing.B) {
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
for i := 0; i < b.N; i++ {
cw.Watch(context.Background())
cw.Unwatch()
}
}
func BenchmarkContextWatcherCancelled(b *testing.B) {
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
for i := 0; i < b.N; i++ {
ctx, cancel := context.WithCancel(context.Background())
cw.Watch(ctx)
cancel()
cw.Unwatch()
}
}
func BenchmarkContextWatcherCancellable(b *testing.B) {
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for i := 0; i < b.N; i++ {
cw.Watch(ctx)
cw.Unwatch()
}
}

View File

@ -1,63 +0,0 @@
//go:build !windows
// +build !windows
package pgconn
import (
"os"
"os/user"
"path/filepath"
)
func defaultSettings() map[string]string {
settings := make(map[string]string)
settings["host"] = defaultHost()
settings["port"] = "5432"
// Default to the OS user name. Purposely ignoring err getting user name from
// OS. The client application will simply have to specify the user in that
// case (which they typically will be doing anyway).
user, err := user.Current()
if err == nil {
settings["user"] = user.Username
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
if _, err := os.Stat(sslcert); err == nil {
if _, err := os.Stat(sslkey); err == nil {
// Both the cert and key must be present to use them, or do not use either
settings["sslcert"] = sslcert
settings["sslkey"] = sslkey
}
}
sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt")
if _, err := os.Stat(sslrootcert); err == nil {
settings["sslrootcert"] = sslrootcert
}
}
settings["target_session_attrs"] = "any"
return settings
}
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
// checks the existence of common locations.
func defaultHost() string {
candidatePaths := []string{
"/var/run/postgresql", // Debian
"/private/tmp", // OSX - homebrew
"/tmp", // standard PostgreSQL
}
for _, path := range candidatePaths {
if _, err := os.Stat(path); err == nil {
return path
}
}
return "localhost"
}

View File

@ -1,57 +0,0 @@
package pgconn
import (
"os"
"os/user"
"path/filepath"
"strings"
)
func defaultSettings() map[string]string {
settings := make(map[string]string)
settings["host"] = defaultHost()
settings["port"] = "5432"
// Default to the OS user name. Purposely ignoring err getting user name from
// OS. The client application will simply have to specify the user in that
// case (which they typically will be doing anyway).
user, err := user.Current()
appData := os.Getenv("APPDATA")
if err == nil {
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
// but the libpq default is just the `user` portion, so we strip off the first part.
username := user.Username
if strings.Contains(username, "\\") {
username = username[strings.LastIndex(username, "\\")+1:]
}
settings["user"] = username
settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf")
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
sslcert := filepath.Join(appData, "postgresql", "postgresql.crt")
sslkey := filepath.Join(appData, "postgresql", "postgresql.key")
if _, err := os.Stat(sslcert); err == nil {
if _, err := os.Stat(sslkey); err == nil {
// Both the cert and key must be present to use them, or do not use either
settings["sslcert"] = sslcert
settings["sslkey"] = sslkey
}
}
sslrootcert := filepath.Join(appData, "postgresql", "root.crt")
if _, err := os.Stat(sslrootcert); err == nil {
settings["sslrootcert"] = sslrootcert
}
}
settings["target_session_attrs"] = "any"
return settings
}
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
// checks the existence of common locations.
func defaultHost() string {
return "localhost"
}

View File

@ -1,38 +0,0 @@
// Package pgconn is a low-level PostgreSQL database driver.
/*
pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at
nearly the same level as the C library libpq.
Establishing a Connection
Use Connect to establish a connection. It accepts a connection string in URL or keyword/value format and will read the
environment for libpq style environment variables.
Executing a Query
ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method
reads all rows into memory.
Executing Multiple Queries in a Single Round Trip
Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
result. The ReadAll method reads all query results into memory.
Pipeline Mode
Pipeline mode allows sending queries without having read the results of previously sent queries. It allows control
over exactly when and how many network round trips occur.
Context Support
All potentially blocking operations take a context.Context. The default behavior when a context is canceled is for the
method to immediately return. In most circumstances, this will also close the underlying connection. This behavior can
be customized by using BuildContextWatcherHandler on the Config to create a ctxwatch.Handler with different behavior.
This can be especially useful when queries are frequently canceled and the overhead of creating new connections is a
problem. DeadlineContextWatcherHandler and CancelRequestContextWatcherHandler can be used to introduce a delay before
interrupting the query in such a way as to close the connection.
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
client to abort.
*/
package pgconn
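// An illustrative sketch (not part of the original file) of the flow
// described above. DATABASE_URL is an assumed environment variable; the
// ExecParams signature follows this package's API:
//
//	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL"))
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer pgConn.Close(context.Background())
//
//	result := pgConn.ExecParams(context.Background(),
//		"select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read()
//	if result.Err != nil {
//		log.Fatal(result.Err)
//	}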

View File

@ -1,256 +0,0 @@
package pgconn
import (
"context"
"errors"
"fmt"
"net"
"net/url"
"regexp"
"strings"
)
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
func SafeToRetry(err error) bool {
var retryableErr interface{ SafeToRetry() bool }
if errors.As(err, &retryableErr) {
return retryableErr.SafeToRetry()
}
return false
}
// Timeout checks if err was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
// context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
func Timeout(err error) bool {
var timeoutErr *errTimeout
return errors.As(err, &timeoutErr)
}
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/current/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
SeverityUnlocalized string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe *PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// SQLState returns the SQLState of the error.
func (pe *PgError) SQLState() string {
return pe.Code
}
// ConnectError is the error returned when a connection attempt fails.
type ConnectError struct {
Config *Config // The configuration that was used in the connection attempt.
err error
}
func (e *ConnectError) Error() string {
prefix := fmt.Sprintf("failed to connect to `user=%s database=%s`:", e.Config.User, e.Config.Database)
details := e.err.Error()
if strings.Contains(details, "\n") {
return prefix + "\n\t" + strings.ReplaceAll(details, "\n", "\n\t")
} else {
return prefix + " " + details
}
}
func (e *ConnectError) Unwrap() error {
return e.err
}
type perDialConnectError struct {
address string
originalHostname string
err error
}
func (e *perDialConnectError) Error() string {
return fmt.Sprintf("%s (%s): %s", e.address, e.originalHostname, e.err.Error())
}
func (e *perDialConnectError) Unwrap() error {
return e.err
}
type connLockError struct {
status string
}
func (e *connLockError) SafeToRetry() bool {
return true // a lock failure by definition happens before the connection is used.
}
func (e *connLockError) Error() string {
return e.status
}
// ParseConfigError is the error returned when a connection string cannot be parsed.
type ParseConfigError struct {
ConnString string // The connection string that could not be parsed.
msg string
err error
}
// NewParseConfigError constructs a ParseConfigError.
func NewParseConfigError(conn, msg string, err error) error {
return &ParseConfigError{
ConnString: conn,
msg: msg,
err: err,
}
}
func (e *ParseConfigError) Error() string {
// Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better to only
// return a static string. That would ensure that the error message cannot leak a password. The ConnString field would
// allow access to the original string if desired and Unwrap would allow access to the underlying error.
connString := redactPW(e.ConnString)
if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
}
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
}
func (e *ParseConfigError) Unwrap() error {
return e.err
}
func normalizeTimeoutError(ctx context.Context, err error) error {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
if ctx.Err() == context.Canceled {
// Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error.
return context.Canceled
} else if ctx.Err() == context.DeadlineExceeded {
return &errTimeout{err: ctx.Err()}
} else {
return &errTimeout{err: netErr}
}
}
return err
}
type pgconnError struct {
msg string
err error
safeToRetry bool
}
func (e *pgconnError) Error() string {
if e.msg == "" {
return e.err.Error()
}
if e.err == nil {
return e.msg
}
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
}
func (e *pgconnError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *pgconnError) Unwrap() error {
return e.err
}
// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is
// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true.
type errTimeout struct {
err error
}
func (e *errTimeout) Error() string {
return fmt.Sprintf("timeout: %s", e.err.Error())
}
func (e *errTimeout) SafeToRetry() bool {
return SafeToRetry(e.err)
}
func (e *errTimeout) Unwrap() error {
return e.err
}
type contextAlreadyDoneError struct {
err error
}
func (e *contextAlreadyDoneError) Error() string {
return fmt.Sprintf("context already done: %s", e.err.Error())
}
func (e *contextAlreadyDoneError) SafeToRetry() bool {
return true
}
func (e *contextAlreadyDoneError) Unwrap() error {
return e.err
}
// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`.
func newContextAlreadyDoneError(ctx context.Context) error {
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
}
func redactPW(connString string) string {
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
if u, err := url.Parse(connString); err == nil {
return redactURL(u)
}
}
quotedKV := regexp.MustCompile(`password='[^']*'`)
connString = quotedKV.ReplaceAllLiteralString(connString, "password=xxxxx")
plainKV := regexp.MustCompile(`password=[^ ]*`)
connString = plainKV.ReplaceAllLiteralString(connString, "password=xxxxx")
brokenURL := regexp.MustCompile(`:[^:@]+?@`)
connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
return connString
}
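// Illustrative sketch (not part of the original file): the redaction applied
// to each connection string form, matching the expectations in the tests
// below.
//
//	redactPW("postgresql://foo:password@host") // "postgresql://foo:xxxxx@host"
//	redactPW("host=host password=secret")      // "host=host password=xxxxx"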
func redactURL(u *url.URL) string {
if u == nil {
return ""
}
if _, pwSet := u.User.Password(); pwSet {
u.User = url.UserPassword(u.User.Username(), "xxxxx")
}
return u.String()
}
type NotPreferredError struct {
err error
safeToRetry bool
}
func (e *NotPreferredError) Error() string {
return fmt.Sprintf("standby server not found: %s", e.err.Error())
}
func (e *NotPreferredError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *NotPreferredError) Unwrap() error {
return e.err
}
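// Illustrative sketch (not part of the original file): classifying a query
// error with the helpers in this file. pgConn and ctx are assumed to exist:
//
//	_, err := pgConn.Exec(ctx, "select 1").ReadAll()
//	if err != nil {
//		var pgErr *PgError
//		switch {
//		case errors.As(err, &pgErr):
//			log.Printf("server error %s: %s", pgErr.SQLState(), pgErr.Message)
//		case Timeout(err):
//			log.Print("query timed out")
//		case SafeToRetry(err):
//			log.Print("nothing was sent to the server; safe to retry")
//		}
//	}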

View File

@ -1,54 +0,0 @@
package pgconn_test
import (
"testing"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/assert"
)
func TestConfigError(t *testing.T) {
tests := []struct {
name string
err error
expectedMsg string
}{
{
name: "url with password",
err: pgconn.NewParseConfigError("postgresql://foo:password@host", "msg", nil),
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg",
},
{
name: "keyword/value with password unquoted",
err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil),
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
},
{
name: "keyword/value with password quoted",
err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil),
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
},
{
name: "weird url",
err: pgconn.NewParseConfigError("postgresql://foo::password@host:1:", "msg", nil),
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg",
},
{
name: "weird url with slash in password",
err: pgconn.NewParseConfigError("postgres://user:pass/word@host:5432/db_name", "msg", nil),
expectedMsg: "cannot parse `postgres://user:xxxxxx@host:5432/db_name`: msg",
},
{
name: "url without password",
err: pgconn.NewParseConfigError("postgresql://other@host/db", "msg", nil),
expectedMsg: "cannot parse `postgresql://other@host/db`: msg",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.EqualError(t, tt.err, tt.expectedMsg)
})
}
}

View File

@ -1,3 +0,0 @@
// File export_test exports some methods for better testing.
package pgconn

View File

@ -1,36 +0,0 @@
package pgconn_test
import (
"context"
"testing"
"time"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func closeConn(t testing.TB, conn *pgconn.PgConn) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
require.NoError(t, conn.Close(ctx))
select {
case <-conn.CleanupDone():
case <-time.After(30 * time.Second):
t.Fatal("Connection cleanup exceeded maximum time")
}
}
// Do a simple query to ensure the connection is still usable
func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read()
cancel()
require.Nil(t, result.Err)
assert.Equal(t, 3, len(result.Rows))
assert.Equal(t, "1", string(result.Rows[0][0]))
assert.Equal(t, "2", string(result.Rows[1][0]))
assert.Equal(t, "3", string(result.Rows[2][0]))
}

View File

@ -1,139 +0,0 @@
// Package bgreader provides an io.Reader that can optionally buffer reads in the background.
package bgreader
import (
"io"
"sync"
"github.com/jackc/pgx/v5/internal/iobufpool"
)
const (
StatusStopped = iota
StatusRunning
StatusStopping
)
// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use.
type BGReader struct {
r io.Reader
cond *sync.Cond
status int32
readResults []readResult
}
type readResult struct {
buf *[]byte
err error
}
// Start starts the background reader. If the background reader is already running this is a no-op. The background
// reader will stop automatically when the underlying reader returns an error.
func (r *BGReader) Start() {
r.cond.L.Lock()
defer r.cond.L.Unlock()
switch r.status {
case StatusStopped:
r.status = StatusRunning
go r.bgRead()
case StatusRunning:
// no-op
case StatusStopping:
r.status = StatusRunning
}
}
// Stop tells the background reader to stop after the in-progress Read returns. It is safe to call Stop when the
// background reader is not running.
func (r *BGReader) Stop() {
r.cond.L.Lock()
defer r.cond.L.Unlock()
switch r.status {
case StatusStopped:
// no-op
case StatusRunning:
r.status = StatusStopping
case StatusStopping:
// no-op
}
}
// Status returns the current status of the background reader.
func (r *BGReader) Status() int32 {
r.cond.L.Lock()
defer r.cond.L.Unlock()
return r.status
}
func (r *BGReader) bgRead() {
keepReading := true
for keepReading {
buf := iobufpool.Get(8192)
n, err := r.r.Read(*buf)
*buf = (*buf)[:n]
r.cond.L.Lock()
r.readResults = append(r.readResults, readResult{buf: buf, err: err})
if r.status == StatusStopping || err != nil {
r.status = StatusStopped
keepReading = false
}
r.cond.L.Unlock()
r.cond.Broadcast()
}
}
// Read implements the io.Reader interface.
func (r *BGReader) Read(p []byte) (int, error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
if len(r.readResults) > 0 {
return r.readFromReadResults(p)
}
// There are no unread background read results and the background reader is stopped.
if r.status == StatusStopped {
return r.r.Read(p)
}
// Wait for results from the background reader
for len(r.readResults) == 0 {
r.cond.Wait()
}
return r.readFromReadResults(p)
}
// readFromReadResults reads a result previously read by the background reader. r.cond.L must be held.
func (r *BGReader) readFromReadResults(p []byte) (int, error) {
buf := r.readResults[0].buf
var err error
n := copy(p, *buf)
if n == len(*buf) {
err = r.readResults[0].err
iobufpool.Put(buf)
if len(r.readResults) == 1 {
r.readResults = nil
} else {
r.readResults = r.readResults[1:]
}
} else {
*buf = (*buf)[n:]
r.readResults[0].buf = buf
}
return n, err
}
func New(r io.Reader) *BGReader {
return &BGReader{
r: r,
cond: &sync.Cond{
L: &sync.Mutex{},
},
}
}
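// Illustrative usage sketch (not part of the original file): hand an
// underlying reader to BGReader, buffer in the background during other work,
// then resume foreground reads. conn and buf are assumed to exist:
//
//	bgr := bgreader.New(conn) // conn is any io.Reader, e.g. a net.Conn
//	bgr.Start()               // background goroutine begins buffering reads
//	// ... other work while data accumulates ...
//	bgr.Stop()                // stops after the in-progress Read returns
//	n, err := bgr.Read(buf)   // drains buffered results before reading directly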

View File

@ -1,140 +0,0 @@
package bgreader_test
import (
"bytes"
"errors"
"io"
"math/rand"
"testing"
"time"
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
"github.com/stretchr/testify/require"
)
func TestBGReaderReadWhenStopped(t *testing.T) {
r := bytes.NewReader([]byte("foo bar baz"))
bgr := bgreader.New(r)
buf, err := io.ReadAll(bgr)
require.NoError(t, err)
require.Equal(t, []byte("foo bar baz"), buf)
}
func TestBGReaderReadWhenStarted(t *testing.T) {
r := bytes.NewReader([]byte("foo bar baz"))
bgr := bgreader.New(r)
bgr.Start()
buf, err := io.ReadAll(bgr)
require.NoError(t, err)
require.Equal(t, []byte("foo bar baz"), buf)
}
type mockReadFunc func(p []byte) (int, error)
type mockReader struct {
readFuncs []mockReadFunc
}
func (r *mockReader) Read(p []byte) (int, error) {
if len(r.readFuncs) == 0 {
return 0, io.EOF
}
fn := r.readFuncs[0]
r.readFuncs = r.readFuncs[1:]
return fn(p)
}
func TestBGReaderReadWaitsForBackgroundRead(t *testing.T) {
rr := &mockReader{
readFuncs: []mockReadFunc{
func(p []byte) (int, error) { time.Sleep(1 * time.Second); return copy(p, []byte("foo")), nil },
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
func(p []byte) (int, error) { return copy(p, []byte("baz")), nil },
},
}
bgr := bgreader.New(rr)
bgr.Start()
buf := make([]byte, 3)
n, err := bgr.Read(buf)
require.NoError(t, err)
require.EqualValues(t, 3, n)
require.Equal(t, []byte("foo"), buf)
}
func TestBGReaderErrorWhenStarted(t *testing.T) {
rr := &mockReader{
readFuncs: []mockReadFunc{
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
},
}
bgr := bgreader.New(rr)
bgr.Start()
buf, err := io.ReadAll(bgr)
require.Equal(t, []byte("foobarbaz"), buf)
require.EqualError(t, err, "oops")
}
func TestBGReaderErrorWhenStopped(t *testing.T) {
rr := &mockReader{
readFuncs: []mockReadFunc{
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
},
}
bgr := bgreader.New(rr)
buf, err := io.ReadAll(bgr)
require.Equal(t, []byte("foobarbaz"), buf)
require.EqualError(t, err, "oops")
}
type numberReader struct {
v uint8
rng *rand.Rand
}
func (nr *numberReader) Read(p []byte) (int, error) {
n := nr.rng.Intn(len(p))
for i := 0; i < n; i++ {
p[i] = nr.v
nr.v++
}
return n, nil
}
// TestBGReaderStress stress tests BGReader by reading a lot of bytes in random sizes while randomly starting and
// stopping the background worker from other goroutines.
func TestBGReaderStress(t *testing.T) {
nr := &numberReader{rng: rand.New(rand.NewSource(0))}
bgr := bgreader.New(nr)
bytesRead := 0
var expected uint8
buf := make([]byte, 10_000)
rng := rand.New(rand.NewSource(0))
for bytesRead < 1_000_000 {
randomNumber := rng.Intn(100)
switch {
case randomNumber < 10:
go bgr.Start()
case randomNumber < 20:
go bgr.Stop()
default:
n, err := bgr.Read(buf)
require.NoError(t, err)
for i := 0; i < n; i++ {
require.Equal(t, expected, buf[i])
expected++
}
bytesRead += n
}
}
}

View File

@ -1,100 +0,0 @@
package pgconn
import (
"errors"
"fmt"
"github.com/jackc/pgx/v5/pgproto3"
)
// NewGSSFunc creates a GSS authentication provider, for use with
// RegisterGSSProvider.
type NewGSSFunc func() (GSS, error)
var newGSS NewGSSFunc
// RegisterGSSProvider registers a GSS authentication provider. For example, if
// you need to use Kerberos to authenticate with your server, add this to your
// main package:
//
// import "github.com/otan/gopgkrb5"
//
// func init() {
// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
// }
func RegisterGSSProvider(newGSSArg NewGSSFunc) {
newGSS = newGSSArg
}
// GSS provides GSSAPI authentication (e.g., Kerberos).
type GSS interface {
GetInitToken(host, service string) ([]byte, error)
GetInitTokenFromSPN(spn string) ([]byte, error)
Continue(inToken []byte) (done bool, outToken []byte, err error)
}
func (c *PgConn) gssAuth() error {
if newGSS == nil {
return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
}
cli, err := newGSS()
if err != nil {
return err
}
var nextData []byte
if c.config.KerberosSpn != "" {
// Use the supplied SPN if provided.
nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
} else {
// Allow the kerberos service name to be overridden
service := "postgres"
if c.config.KerberosSrvName != "" {
service = c.config.KerberosSrvName
}
nextData, err = cli.GetInitToken(c.config.Host, service)
}
if err != nil {
return err
}
for {
gssResponse := &pgproto3.GSSResponse{
Data: nextData,
}
c.frontend.Send(gssResponse)
err = c.flushWithPotentialWriteReadDeadlock()
if err != nil {
return err
}
resp, err := c.rxGSSContinue()
if err != nil {
return err
}
var done bool
done, nextData, err = cli.Continue(resp.Data)
if err != nil {
return err
}
if done {
break
}
}
return nil
}
func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
switch m := msg.(type) {
case *pgproto3.AuthenticationGSSContinue:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
}

File diff suppressed because it is too large Load Diff

View File

@ -1,41 +0,0 @@
package pgconn
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestCommandTag(t *testing.T) {
t.Parallel()
tests := []struct {
commandTag CommandTag
rowsAffected int64
isInsert bool
isUpdate bool
isDelete bool
isSelect bool
}{
{commandTag: CommandTag{s: "INSERT 0 5"}, rowsAffected: 5, isInsert: true},
{commandTag: CommandTag{s: "UPDATE 0"}, rowsAffected: 0, isUpdate: true},
{commandTag: CommandTag{s: "UPDATE 1"}, rowsAffected: 1, isUpdate: true},
{commandTag: CommandTag{s: "DELETE 0"}, rowsAffected: 0, isDelete: true},
{commandTag: CommandTag{s: "DELETE 1"}, rowsAffected: 1, isDelete: true},
{commandTag: CommandTag{s: "DELETE 1234567890"}, rowsAffected: 1234567890, isDelete: true},
{commandTag: CommandTag{s: "SELECT 1"}, rowsAffected: 1, isSelect: true},
{commandTag: CommandTag{s: "SELECT 99999999999"}, rowsAffected: 99999999999, isSelect: true},
{commandTag: CommandTag{s: "CREATE TABLE"}, rowsAffected: 0},
{commandTag: CommandTag{s: "ALTER TABLE"}, rowsAffected: 0},
{commandTag: CommandTag{s: "DROP TABLE"}, rowsAffected: 0},
}
for i, tt := range tests {
ct := tt.commandTag
assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag)
assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag)
}
}

View File

@ -1,90 +0,0 @@
package pgconn_test
import (
"context"
"math/rand"
"os"
"runtime"
"strconv"
"testing"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/require"
)
func TestConnStress(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
actionCount := 10000
if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" {
stressFactor, err := strconv.ParseInt(s, 10, 64)
require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR")
actionCount *= int(stressFactor)
}
setupStressDB(t, pgConn)
actions := []struct {
name string
fn func(*pgconn.PgConn) error
}{
{"Exec Select", stressExecSelect},
{"ExecParams Select", stressExecParamsSelect},
{"Batch", stressBatch},
}
for i := 0; i < actionCount; i++ {
action := actions[rand.Intn(len(actions))]
err := action.fn(pgConn)
require.Nilf(t, err, "%d: %s", i, action.name)
}
// Each call with a context starts a goroutine. Ensure they are cleaned up when context is not canceled.
numGoroutine := runtime.NumGoroutine()
require.Truef(t, numGoroutine < 1000, "goroutines appear to be orphaned: %d in process", numGoroutine)
}
func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) {
_, err := pgConn.Exec(context.Background(), `
create temporary table widgets(
id serial primary key,
name varchar not null,
description text,
creation_time timestamptz default now()
);
insert into widgets(name, description) values
('Foo', 'bar'),
('baz', 'Something really long Something really long Something really long Something really long Something really long'),
('a', 'b')`).ReadAll()
require.NoError(t, err)
}
func stressExecSelect(pgConn *pgconn.PgConn) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := pgConn.Exec(ctx, "select * from widgets").ReadAll()
return err
}
func stressExecParamsSelect(pgConn *pgconn.PgConn) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
result := pgConn.ExecParams(ctx, "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read()
return result.Err
}
func stressBatch(pgConn *pgconn.PgConn) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
batch := &pgconn.Batch{}
batch.ExecParams("select * from widgets", nil, nil, nil, nil)
batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil)
_, err := pgConn.ExecBatch(ctx, batch).ReadAll()
return err
}

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More