Compare commits


No commits in common. "master" and "v5.0.0-beta.1" have entirely different histories.

234 changed files with 5474 additions and 15649 deletions


@@ -23,7 +23,7 @@ import (
"log"
"os"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v4"
)
func main() {
@@ -37,8 +37,6 @@ func main() {
}
```
Please run your example with the race detector enabled. For example, `go run -race main.go` or `go test -race`.
**Expected behavior**
A clear and concise description of what you expected to happen.


@@ -2,155 +2,87 @@ name: CI
on:
push:
branches: [master]
branches: [ master, v5-dev ]
pull_request:
branches: [master]
branches: [ master ]
jobs:
test:
name: Test
runs-on: ubuntu-22.04
runs-on: ubuntu-20.04
strategy:
matrix:
go-version: ["1.23", "1.24"]
pg-version: [13, 14, 15, 16, 17, cockroachdb]
go-version: [1.18]
pg-version: [10, 11, 12, 13, 14, cockroachdb]
include:
- pg-version: 10
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
- pg-version: 11
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
- pg-version: 12
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
- pg-version: 13
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
- pg-version: 14
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
- pg-version: 15
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
- pg-version: 16
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
- pg-version: 17
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
- pg-version: cockroachdb
pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
steps:
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: Set up Go ${{ matrix.go-version }}
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
- name: Set up Go 1.x
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go-version }}
- name: Setup database server for testing
run: ci/setup_test.bash
env:
PGVERSION: ${{ matrix.pg-version }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
# - name: Setup upterm session
# uses: lhotari/action-upterm@v1
# with:
# ## limits ssh access and adds the ssh public key for the user which triggered the workflow
# limit-access-to-actor: true
# env:
# PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
# PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
# PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
# PGX_TEST_SCRAM_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-scram-password-conn-string }}
# PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
# PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}
# PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
# PGX_SSL_PASSWORD: ${{ matrix.pgx-ssl-password }}
# PGX_TEST_TLS_CLIENT_CONN_STRING: ${{ matrix.pgx-test-tls-client-conn-string }}
- name: Setup database server for testing
run: ci/setup_test.bash
env:
PGVERSION: ${{ matrix.pg-version }}
- name: Check formatting
run: |
gofmt -l -s -w .
git status
git diff --exit-code
- name: Test
# parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner.
run: go test -parallel=1 -race ./...
env:
PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
PGX_TEST_SCRAM_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-scram-password-conn-string }}
PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}
# TestConnectTLS fails. However, it succeeds if I connect to the CI server with upterm and run it. Give up on that test for now.
# PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
PGX_SSL_PASSWORD: ${{ matrix.pgx-ssl-password }}
PGX_TEST_TLS_CLIENT_CONN_STRING: ${{ matrix.pgx-test-tls-client-conn-string }}
test-windows:
name: Test Windows
runs-on: windows-latest
strategy:
matrix:
go-version: ["1.23", "1.24"]
steps:
- name: Setup PostgreSQL
id: postgres
uses: ikalnytskyi/action-setup-postgres@v4
with:
database: pgx_test
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: Set up Go ${{ matrix.go-version }}
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
- name: Initialize test database
run: |
psql -f testsetup/postgresql_setup.sql pgx_test
env:
PGSERVICE: ${{ steps.postgres.outputs.service-name }}
shell: bash
- name: Test
# parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner.
run: go test -parallel=1 -race ./...
env:
PGX_TEST_DATABASE: ${{ steps.postgres.outputs.connection-uri }}
- name: Test
run: go test -race ./...
env:
PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
PGX_TEST_CONN_STRING: ${{ matrix.pgx-test-conn-string }}
PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}

.gitignore

@@ -22,6 +22,3 @@ _testmain.go
*.exe
.envrc
/.testdb
.DS_Store

.travis.yml

@@ -0,0 +1,9 @@
language: go
go:
- 1.x
- tip
matrix:
allow_failures:
- go: tip


@@ -1,286 +1,4 @@
# 5.7.5 (May 17, 2025)
* Support sslnegotiation connection option (divyam234)
* Update golang.org/x/crypto to v0.37.0. This placates security scanners that were unable to see that pgx did not use the behavior affected by https://pkg.go.dev/vuln/GO-2025-3487.
* TraceLog now logs Acquire and Release at the debug level (dave sinclair)
* Add support for PGTZ environment variable
* Add support for PGOPTIONS environment variable
* Unpin memory used by Rows quicker
* Remove PlanScan memoization. This resolves a rare issue where scanning could be broken for one type by first scanning another. The problem was in the memoization system and benchmarking revealed that memoization was not providing any meaningful benefit.
# 5.7.4 (March 24, 2025)
* Fix / revert change to scanning JSON `null` (Felix Röhrich)
# 5.7.3 (March 21, 2025)
* Expose EmptyAcquireWaitTime in pgxpool.Stat (vamshiaruru32)
* Improve SQL sanitizer performance (ninedraft)
* Fix Scan confusion with json(b), sql.Scanner, and automatic dereferencing (moukoublen, felix-roehrich)
* Fix Values() for xml type always returning nil instead of []byte
* Add ability to send Flush message in pipeline mode (zenkovev)
* Fix pgtype.Timestamp's JSON behavior to match PostgreSQL (pconstantinou)
* Better error messages when scanning structs (logicbomb)
* Fix handling of error on batch write (bonnefoa)
* Match libpq's connection fallback behavior more closely (felix-roehrich)
* Add MinIdleConns to pgxpool (djahandarie)
# 5.7.2 (December 21, 2024)
* Fix prepared statement already exists on batch prepare failure
* Add commit query to tx options (Lucas Hild)
* Fix pgtype.Timestamp json unmarshal (Shean de Montigny-Desautels)
* Add message body size limits in frontend and backend (zene)
* Add xid8 type
* Ensure planning encodes and scans cannot infinitely recurse
* Implement pgtype.UUID.String() (Konstantin Grachev)
* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev)
* Update golang.org/x/crypto
* Fix json(b) columns prefer sql.Scanner interface like database/sql (Ludovico Russo)
# 5.7.1 (September 10, 2024)
* Fix data race in tracelog.TraceLog
* Update puddle to v2.2.2. This removes the import of nanotime via linkname.
* Update golang.org/x/crypto and golang.org/x/text
# 5.7.0 (September 7, 2024)
* Add support for sslrootcert=system (Yann Soubeyrand)
* Add LoadTypes to load multiple types in a single SQL query (Nick Farrell)
* Add XMLCodec supports encoding + scanning XML column type like json (nickcruess-soda)
* Add MultiTrace (Stepan Rabotkin)
* Add TraceLogConfig with customizable TimeKey (stringintech)
* pgx.ErrNoRows wraps sql.ErrNoRows to aid in database/sql compatibility with native pgx functions (merlin)
* Support scanning binary formatted uint32 into string / TextScanner (jennifersp)
* Fix interval encoding to allow 0s and avoid extra spaces (Carlos Pérez-Aradros Herce)
* Update pgservicefile - fixes panic when parsing invalid file
* Better error message when reading past end of batch
* Don't print url when url.Parse returns an error (Kevin Biju)
* Fix snake case name normalization collision in RowToStructByName with db tag (nolandseigler)
* Fix: Scan and encode types with underlying types of arrays
# 5.6.0 (May 25, 2024)
* Add StrictNamedArgs (Tomas Zahradnicek)
* Add support for macaddr8 type (Carlos Pérez-Aradros Herce)
* Add SeverityUnlocalized field to PgError / Notice
* Performance optimization of RowToStructByPos/Name (Zach Olstein)
* Allow customizing context canceled behavior for pgconn
* Add ScanLocation to pgtype.Timestamp[tz]Codec
* Add custom data to pgconn.PgConn
* Fix ResultReader.Read() to handle nil values
* Do not encode interval microseconds when they are 0 (Carlos Pérez-Aradros Herce)
* pgconn.SafeToRetry checks for wrapped errors (tjasko)
* Failed connection attempts include all errors
* Optimize LargeObject.Read (Mitar)
* Add tracing for connection acquire and release from pool (ngavinsir)
* Fix encode driver.Valuer not called when nil
* Add support for custom JSON marshal and unmarshal (Mitar)
* Use Go default keepalive for TCP connections (Hans-Joachim Kliemeck)
# 5.5.5 (March 9, 2024)
Use spaces instead of parentheses for SQL sanitization.
This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as
`set foo to $1` where the substitution is taking place in a location where an arbitrary expression is not allowed.
# 5.5.4 (March 4, 2024)
Fix CVE-2024-27304
SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer
overflow in the calculated message size can cause the one large message to be sent as multiple messages under the
attacker's control.
Thanks to Paul Gerste for reporting this issue.
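As an illustration of the overflow class (a hypothetical sketch, not pgx's actual message encoding):

```go
package main

import "fmt"

func main() {
	// A length prefix computed in 32 bits wraps for payloads over 4 GB, so
	// the declared size no longer matches the bytes actually sent, and the
	// trailing bytes can be parsed as additional, attacker-chosen messages.
	var payloadLen int64 = 5 << 30 // ~5 GB
	declared := uint32(payloadLen + 4)
	fmt.Println(declared) // 1073741828, far smaller than the real size
}
```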
* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix)
* Fix simple protocol encoding of json.RawMessage
* Fix *Pipeline.getResults should close pipeline on error
* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman)
* Fix deallocation of invalidated cached statements in a transaction
* Handle invalid sslkey file
* Fix scan float4 into sql.Scanner
* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads.
# 5.5.3 (February 3, 2024)
* Fix: prepared statement already exists
* Improve CopyFrom auto-conversion of text-ish values
* Add ltree type support (Florent Viel)
* Make some properties of Batch and QueuedQuery public (Pavlo Golub)
* Add AppendRows function (Edoardo Spadolini)
* Optimize convert UUID [16]byte to string (Kirill Malikov)
* Fix: LargeObject Read and Write of more than ~1GB at a time (Mitar)
# 5.5.2 (January 13, 2024)
* Allow NamedArgs to start with underscore
* pgproto3: Maximum message body length support (jeremy.spriet)
* Upgrade golang.org/x/crypto to v0.17.0
* Add snake_case support to RowToStructByName (Tikhon Fedulov)
* Fix: update description cache after exec prepare (James Hartig)
* Fix: pipeline checks if it is closed (James Hartig and Ryan Fowler)
* Fix: normalize timeout / context errors during TLS startup (Samuel Stauffer)
* Add OnPgError for easier centralized error handling (James Hartig)
# 5.5.1 (December 9, 2023)
* Add CopyFromFunc helper function. (robford)
* Add PgConn.Deallocate method that uses PostgreSQL protocol Close message.
* pgx uses new PgConn.Deallocate method. This allows deallocating statements to work in a failed transaction. This fixes a case where the prepared statement map could become invalid.
* Fix: Prefer driver.Valuer over json.Marshaler for json fields. (Jacopo)
* Fix: simple protocol SQL sanitizer previously panicked if an invalid $0 placeholder was used. This now returns an error instead. (maksymnevajdev)
* Add pgtype.Numeric.ScanScientific (Eshton Robateau)
# 5.5.0 (November 4, 2023)
* Add CollectExactlyOneRow. (Julien GOTTELAND)
* Add OpenDBFromPool to create *database/sql.DB from *pgxpool.Pool. (Lev Zakharov)
* Prepare can automatically choose statement name based on sql. This makes it easier to explicitly manage prepared statements.
* Statement cache now uses deterministic, stable statement names.
* database/sql prepared statement names are deterministically generated.
* Fix: SendBatch wasn't respecting context cancellation.
* Fix: Timeout error from pipeline is now normalized.
* Fix: database/sql encoding json.RawMessage to []byte.
* CancelRequest: Wait for the cancel request to be acknowledged by the server. This should improve PgBouncer compatibility. (Anton Levakin)
* stdlib: Use Ping instead of CheckConn in ResetSession
* Add json.Marshaler and json.Unmarshaler for Float4, Float8 (Kirill Mironov)
# 5.4.3 (August 5, 2023)
* Fix: QCharArrayOID was defined with the wrong OID (Christoph Engelbert)
* Fix: connect_timeout for sslmode=allow|prefer (smaher-edb)
* Fix: pgxpool: background health check cannot overflow pool
* Fix: Check for nil in defer when sending batch (recover properly from panic)
* Fix: json scan of non-string pointer to pointer
* Fix: zeronull.Timestamptz should use pgtype.Timestamptz
* Fix: NewConnsCount was not correctly counting connections created by Acquire directly. (James Hartig)
* RowTo(AddrOf)StructByPos ignores fields with "-" db tag
* Optimization: improve text format numeric parsing (horpto)
# 5.4.2 (July 11, 2023)
* Fix: RowScanner errors are fatal to Rows
* Fix: Enable failover efforts when pg_hba.conf disallows non-ssl connections (Brandon Kauffman)
* Hstore text codec internal improvements (Evan Jones)
* Fix: Stop timers for background reader when not in use. Fixes memory leak when closing connections (Adrian-Stefan Mares)
* Fix: Stop background reader as soon as possible.
* Add PgConn.SyncConn(). This combined with the above fix makes it safe to directly use the underlying net.Conn.
# 5.4.1 (June 18, 2023)
* Fix: concurrency bug with pgtypeDefaultMap and simple protocol (Lev Zakharov)
* Add TxOptions.BeginQuery to allow overriding the default BEGIN query
# 5.4.0 (June 14, 2023)
* Replace platform specific syscalls for non-blocking IO with more traditional goroutines and deadlines. This returns to the v4 approach with some additional improvements and fixes. This restores the ability to use a pgx.Conn over an ssh.Conn as well as other non-TCP or Unix socket connections. In addition, it is a significantly simpler implementation that is less likely to have cross platform issues.
* Optimization: The default type registrations are now shared among all connections. This saves about 100KB of memory per connection. `pgtype.Type` and `pgtype.Codec` values are now required to be immutable after registration. This was already necessary in most cases but wasn't documented until now. (Lev Zakharov)
* Fix: Ensure pgxpool.Pool.QueryRow.Scan releases connection on panic
* CancelRequest: don't try to read the reply (Nicola Murino)
* Fix: correctly handle bool type aliases (Wichert Akkerman)
* Fix: pgconn.CancelRequest: Fix unix sockets: don't use RemoteAddr()
* Fix: pgx.Conn memory leak with prepared statement caching (Evan Jones)
* Add BeforeClose to pgxpool.Pool (Evan Cordell)
* Fix: various hstore fixes and optimizations (Evan Jones)
* Fix: RowToStructByPos with embedded unexported struct
* Support different bool string representations (Lev Zakharov)
* Fix: error when using BatchResults.Exec on a select that returns an error after some rows.
* Fix: pipelineBatchResults.Exec() not returning error from ResultReader
* Fix: pipeline batch results not closing pipeline when error occurs while reading directly from results instead of using
a callback.
* Fix: scanning a table type into a struct
* Fix: scan array of record to pointer to slice of struct
* Fix: handle null for json (Cemre Mengu)
* Batch Query callback is called even when there is an error
* Add RowTo(AddrOf)StructByNameLax (Audi P. Risa P)
# 5.3.1 (February 27, 2023)
* Fix: Support v4 and v5 stdlib in same program (Tomáš Procházka)
* Fix: sql.Scanner not being used in certain cases
* Add text format jsonpath support
* Fix: fake non-blocking read adaptive wait time
# 5.3.0 (February 11, 2023)
* Fix: json values work with sql.Scanner
* Fixed / improved error messages (Mark Chambers and Yevgeny Pats)
* Fix: support scan into single dimensional arrays
* Fix: MaxConnLifetimeJitter setting actually jitter (Ben Weintraub)
* Fix: driver.Value representation of bytea should be []byte not string
* Fix: better handling of unregistered OIDs
* CopyFrom can use query cache to avoid extra round trip to get OIDs (Alejandro Do Nascimento Mora)
* Fix: encode to json ignoring driver.Valuer
* Support sql.Scanner on renamed base type
* Fix: pgtype.Numeric text encoding of negative numbers (Mark Chambers)
* Fix: connect with multiple hostnames when one can't be resolved
* Upgrade puddle to remove dependency on uber/atomic and fix alignment issue on 32-bit platform
* Fix: scanning json column into **string
* Multiple reductions in memory allocations
* Fake non-blocking read adapts its max wait time
* Improve CopyFrom performance and reduce memory usage
* Fix: encode []any to array
* Fix: LoadType for composite with dropped attributes (Felix Röhrich)
* Support v4 and v5 stdlib in same program
* Fix: text format array decoding with string of "NULL"
* Prefer binary format for arrays
# 5.2.0 (December 5, 2022)
* `tracelog.TraceLog` implements the pgx.PrepareTracer interface. (Vitalii Solodilov)
* Optimize creating begin transaction SQL string (Petr Evdokimov and ksco)
* `Conn.LoadType` supports range and multirange types (Vitalii Solodilov)
* Fix scan `uint` and `uint64` `ScanNumeric`. This resolves a PostgreSQL `numeric` being incorrectly scanned into `uint` and `uint64`.
# 5.1.1 (November 17, 2022)
* Fix simple query sanitizer where query text contains a Unicode replacement character.
* Remove erroneous `name` argument from `DeallocateAll()`. Technically, this is a breaking change, but given that method was only added 5 days ago this change was accepted. (Bodo Kaiser)
# 5.1.0 (November 12, 2022)
* Update puddle to v2.1.2. This resolves a race condition and a deadlock in pgxpool.
* `QueryRewriter.RewriteQuery` now returns an error. Technically, this is a breaking change for any external implementers, but given the minimal likelihood that there are actually any external implementers this change was accepted.
* Expose `GetSSLPassword` support to pgx.
* Fix encode `ErrorResponse` unknown field handling. This would only affect pgproto3 being used directly as a proxy with a non-PostgreSQL server that included additional error fields.
* Fix date text format encoding with 5 digit years.
* Fix date values passed to a `sql.Scanner` as `string` instead of `time.Time`.
* DateCodec.DecodeValue can return `pgtype.InfinityModifier` instead of `string` for infinite values. This now matches the behavior of the timestamp types.
* Add domain type support to `Conn.LoadType()`.
* Add `RowToStructByName` and `RowToAddrOfStructByName`. (Pavlo Golub)
* Add `Conn.DeallocateAll()` to clear all prepared statements including the statement cache. (Bodo Kaiser)
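As an illustration, `RowToStructByName` is typically paired with `CollectRows`; a minimal sketch with a hypothetical `users` table and `User` struct:

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// User is a hypothetical struct whose db tags match the selected columns.
type User struct {
	ID   int32  `db:"id"`
	Name string `db:"name"`
}

func loadUsers(ctx context.Context, conn *pgx.Conn) ([]User, error) {
	rows, err := conn.Query(ctx, "select id, name from users")
	if err != nil {
		return nil, err
	}
	// CollectRows closes rows and maps each row onto User by column name.
	return pgx.CollectRows(rows, pgx.RowToStructByName[User])
}
```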
# 5.0.4 (October 24, 2022)
* Fix: CollectOneRow prefers PostgreSQL error over pgx.ErrorNoRows
* Fix: some reflect Kind checks to first check for nil
* Bump golang.org/x/text dependency to placate snyk
* Fix: RowToStructByPos on structs with multiple anonymous sub-structs (Baptiste Fontaine)
* Fix: Exec checks if tx is closed
# 5.0.3 (October 14, 2022)
* Fix `driver.Valuer` handling edge cases that could cause infinite loop or crash
# v5.0.2 (October 8, 2022)
* Fix date encoding in text format to always use 2 digits for month and day
* Prefer driver.Valuer over wrap plans when encoding
* Fix scan to pointer to pointer to renamed type
* Allow scanning NULL even if PG and Go types are incompatible
# v5.0.1 (September 24, 2022)
* Fix 32-bit atomic usage
* Add MarshalJSON for Float8 (yogipristiawan)
* Add `[` and `]` to text encoding of `Lseg`
* Fix sqlScannerWrapper NULL handling
# v5.0.0 (September 17, 2022)
# Unreleased v5
## Merged Packages
@@ -304,11 +22,9 @@ pgconn now supports pipeline mode.
`*PgConn.ReceiveResults` removed. Use pipeline mode instead.
`Timeout()` no longer considers `context.Canceled` as a timeout error. `context.DeadlineExceeded` still is considered a timeout error.
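A rough sketch of pipeline mode at the pgconn level (illustrative; assumes the v5 `Pipeline` methods named below):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5/pgconn"
)

func pipelineExample(ctx context.Context, pgConn *pgconn.PgConn) error {
	// Queue a prepare and an execution, then send both in one round trip.
	p := pgConn.StartPipeline(ctx)
	defer p.Close()
	p.SendPrepare("ps1", "select $1::text", nil)
	p.SendQueryPrepared("ps1", [][]byte{[]byte("hello")}, nil, nil)
	if err := p.Sync(); err != nil {
		return err
	}
	// Results come back in order; *pgconn.PipelineSync marks the end.
	for {
		res, err := p.GetResults()
		if err != nil {
			return err
		}
		switch r := res.(type) {
		case *pgconn.ResultReader:
			if _, err := r.Close(); err != nil { // drain this query's rows
				return err
			}
		case *pgconn.PipelineSync:
			return nil
		}
	}
}
```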
## pgxpool
`Connect` and `ConnectConfig` have been renamed to `New` and `NewWithConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect.
`Connect` and `ConnectConfig` have been renamed to `New` and `NewConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect.
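For example, a minimal sketch using the new name:

```go
package example

import (
	"context"
	"os"

	"github.com/jackc/pgx/v5/pgxpool"
)

func newPool(ctx context.Context) (*pgxpool.Pool, error) {
	// v5: pgxpool.New replaces v4's pgxpool.Connect. The pool always
	// connects lazily, so no connection is established here.
	return pgxpool.New(ctx, os.Getenv("DATABASE_URL"))
}
```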
## pgtype
@@ -317,10 +33,7 @@ The `pgtype` package has been significantly changed.
### NULL Representation
Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a
`Valid` `bool` field to harmonize with how `database/sql` represents `NULL` and to make the zero value useable.
Previously, a type that implemented `driver.Valuer` would have the `Value` method called even on a nil pointer. All nils
whether typed or untyped now represent `NULL`.
`Valid` `bool` field to harmonize with how `database/sql` represents NULL and to make the zero value useable.
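A minimal sketch of the v5 representation:

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
)

func nullExample(ctx context.Context, conn *pgx.Conn) error {
	present := pgtype.Int4{Int32: 42, Valid: true} // a non-NULL value
	_ = present

	// The zero value is NULL: after scanning NULL, got.Valid remains false.
	var got pgtype.Int4
	return conn.QueryRow(ctx, "select null::int4").Scan(&got)
}
```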
### Codec and Value Split
@@ -334,9 +47,9 @@ generally defined by implementing an interface that a particular `Codec` underst
### Array Types
All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This also
means that less common array types such as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional
arrays.
All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This
significantly reduced the amount of code and the compiled binary size. This also means that less common array types such
as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional arrays.
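A sketch of scanning a two-dimensional array (assuming `pgtype.Array[T]` as a scan target):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
)

func arrayExample(ctx context.Context, conn *pgx.Conn) error {
	// Array[T] carries the flattened elements plus dimension metadata, so
	// multi-dimensional arrays round trip without per-type generated code.
	var a pgtype.Array[int32]
	return conn.QueryRow(ctx, "select '{{1,2},{3,4}}'::int4[]").Scan(&a)
}
```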
### Composite Types
@@ -350,7 +63,7 @@ easily be handled. Multirange types are handled similarly with `MultirangeCodec`
### pgxtype
`LoadDataType` moved to `*Conn` as `LoadType`.
load data type moved to conn
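For example (the type name is hypothetical):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

func registerCustomType(ctx context.Context, conn *pgx.Conn) error {
	// LoadType reads the type's OID and layout from the server; registering
	// it lets pgx encode and decode values of that type.
	t, err := conn.LoadType(ctx, "my_composite_type") // hypothetical name
	if err != nil {
		return err
	}
	conn.TypeMap().RegisterType(t)
	return nil
}
```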
### Bytea
@@ -384,7 +97,7 @@ This matches the convention set by `database/sql`. In addition, for comparable t
### 3rd Party Type Integrations
* Extracted integrations with https://github.com/shopspring/decimal and https://github.com/gofrs/uuid to
* Extracted integrations with github.com/shopspring/decimal and github.com/gofrs/uuid to
https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims
the pgx dependency tree.


@@ -1,121 +0,0 @@
# Contributing
## Discuss Significant Changes
Before you invest a significant amount of time on a change, please create a discussion or issue describing your
proposal. This will help to ensure your proposed change has a reasonable chance of being merged.
## Avoid Dependencies
Adding a dependency is a big deal. While on occasion a new dependency may be accepted, the default answer to any change
that adds a dependency is no.
## Development Environment Setup
pgx tests naturally require a PostgreSQL database. The test suite connects to the database specified in the
`PGX_TEST_DATABASE` environment variable, which can be either a URL or key-value pairs. In addition, the standard `PG*`
environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify
environment variable handling.
### Using an Existing PostgreSQL Cluster
If you already have a PostgreSQL development server this is the quickest way to start and run the majority of the pgx
test suite. Some tests that require server configuration changes (e.g. those testing different authentication methods)
will be skipped.
Create and set up a test database:
```
export PGDATABASE=pgx_test
createdb
psql -c 'create extension hstore;'
psql -c 'create extension ltree;'
psql -c 'create domain uint64 as numeric(20,0);'
```
Ensure a `postgres` user exists. This happens by default in normal PostgreSQL installs, but some installation methods
such as Homebrew do not.
```
createuser -s postgres
```
Ensure your `PGX_TEST_DATABASE` environment variable points to the database you just created and run the tests.
```
export PGX_TEST_DATABASE="host=/private/tmp database=pgx_test"
go test ./...
```
This will run the vast majority of the tests, but some tests will be skipped (e.g. those testing different connection methods).
### Creating a New PostgreSQL Cluster Exclusively for Testing
The following environment variables need to be set both for initial setup and whenever the tests are run. (direnv is
highly recommended). Depending on your platform, you may need to change the host for `PGX_TEST_UNIX_SOCKET_CONN_STRING`.
```
export PGPORT=5015
export PGUSER=postgres
export PGDATABASE=pgx_test
export POSTGRESQL_DATA_DIR=postgresql
export PGX_TEST_DATABASE="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/private/tmp database=pgx_test"
export PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test"
export PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_pw password=secret"
export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem"
export PGX_SSL_PASSWORD=certpw
export PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test sslcert=`pwd`/.testdb/pgx_sslcert.crt sslkey=`pwd`/.testdb/pgx_sslcert.key"
```
Create a new database cluster.
```
initdb --locale=en_US -E UTF-8 --username=postgres .testdb/$POSTGRESQL_DATA_DIR
echo "listen_addresses = '127.0.0.1'" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
echo "port = $PGPORT" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
cat testsetup/postgresql_ssl.conf >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
cp testsetup/pg_hba.conf .testdb/$POSTGRESQL_DATA_DIR/pg_hba.conf
cd .testdb
# Generate CA, server, and encrypted client certificates.
go run ../testsetup/generate_certs.go
# Copy certificates to server directory and set permissions.
cp ca.pem $POSTGRESQL_DATA_DIR/root.crt
cp localhost.key $POSTGRESQL_DATA_DIR/server.key
chmod 600 $POSTGRESQL_DATA_DIR/server.key
cp localhost.crt $POSTGRESQL_DATA_DIR/server.crt
cd ..
```
Start the new cluster. This will be necessary whenever you are running pgx tests.
```
postgres -D .testdb/$POSTGRESQL_DATA_DIR
```
Set up the test database in the new cluster.
```
createdb
psql --no-psqlrc -f testsetup/postgresql_setup.sql
```
### PgBouncer
There are tests specific for PgBouncer that will be executed if `PGX_TEST_PGBOUNCER_CONN_STRING` is set.
### Optional Tests
pgx supports multiple connection types and means of authentication. These tests are optional. They will only run if the
appropriate environment variables are set. In addition, there may be tests specific to particular PostgreSQL versions,
non-PostgreSQL servers (e.g. CockroachDB), or connection poolers (e.g. PgBouncer). Run `go test ./... -v | grep SKIP`
to see if any tests are being skipped.

README.md

@@ -1,12 +1,15 @@
[![Go Reference](https://pkg.go.dev/badge/github.com/jackc/pgx/v5.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5)
[![Build Status](https://github.com/jackc/pgx/actions/workflows/ci.yml/badge.svg)](https://github.com/jackc/pgx/actions/workflows/ci.yml)
[![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5)
[![Build Status](https://travis-ci.org/jackc/pgx.svg)](https://travis-ci.org/jackc/pgx)
# pgx - PostgreSQL Driver and Toolkit
*This is the v5 development branch. It is still in active development and testing.*
pgx is a pure Go driver and toolkit for PostgreSQL.
The pgx driver is a low-level, high performance interface that exposes PostgreSQL-specific features such as `LISTEN` /
`NOTIFY` and `COPY`. It also includes an adapter for the standard `database/sql` interface.
pgx aims to be low-level, fast, and performant, while also enabling PostgreSQL-specific features that the standard `database/sql` package does not allow for.
The driver component of pgx can be used alongside the standard `database/sql` package.
The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol
and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers,
@@ -48,55 +51,91 @@ func main() {
See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information.
## Choosing Between the pgx and database/sql Interfaces
It is recommended to use the pgx interface if:
1. The application only targets PostgreSQL.
2. No other libraries that require `database/sql` are in use.
The pgx interface is faster and exposes more features.
The `database/sql` interface only allows the underlying driver to return or receive the following types: `int64`,
`float64`, `bool`, `[]byte`, `string`, `time.Time`, or `nil`. Handling other types requires implementing the
`database/sql.Scanner` and `database/sql/driver.Valuer` interfaces, which require transmission of values in text format. The binary format can be substantially faster, which is what the pgx interface uses.
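For comparison, a minimal sketch of both interfaces (assumes a connection string in `DATABASE_URL`):

```go
package main

import (
	"context"
	"database/sql"
	"os"

	"github.com/jackc/pgx/v5"
	_ "github.com/jackc/pgx/v5/stdlib" // registers the "pgx" database/sql driver
)

func main() {
	ctx := context.Background()

	// Native pgx interface: full feature set, binary protocol by default.
	conn, err := pgx.Connect(ctx, os.Getenv("DATABASE_URL"))
	if err != nil {
		panic(err)
	}
	defer conn.Close(ctx)

	// database/sql interface via the stdlib adapter, for code that needs it.
	db, err := sql.Open("pgx", os.Getenv("DATABASE_URL"))
	if err != nil {
		panic(err)
	}
	defer db.Close()
}
```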
## Features
pgx supports many features beyond what is available through `database/sql`:
* Support for approximately 70 different PostgreSQL types
* Automatic statement preparation and caching
* Batch queries
* Single-round trip query mode
* Full TLS connection control
* Binary format support for custom types (allows for much quicker encoding/decoding)
* `COPY` protocol support for faster bulk data loads
* Tracing and logging support
* COPY protocol support for faster bulk data loads
* Extendable logging support
* Connection pool with after-connect hook for arbitrary connection setup
* `LISTEN` / `NOTIFY`
* Listen / notify
* Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings
* `hstore` support
* `json` and `jsonb` support
* Hstore support
* JSON and JSONB support
* Maps `inet` and `cidr` PostgreSQL types to `netip.Addr` and `netip.Prefix`
* Large object support
* NULL mapping to pointer to pointer
* NULL mapping to Null* struct or pointer to pointer
* Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types
* Notice response handling
* Simulated nested transactions with savepoints
## Choosing Between the pgx and database/sql Interfaces
## Performance
The pgx interface is faster. Many PostgreSQL specific features such as `LISTEN` / `NOTIFY` and `COPY` are not available
through the `database/sql` interface.
There are three areas in particular where pgx can provide a significant performance advantage over the standard
`database/sql` interface and other drivers:
The pgx interface is recommended when:
1. The application only targets PostgreSQL.
2. No other libraries that require `database/sql` are in use.
It is also possible to use the `database/sql` interface and convert a connection to the lower-level pgx interface as needed.
1. PostgreSQL specific types - Types such as arrays can be parsed much quicker because pgx uses the binary format.
2. Automatic statement preparation and caching - pgx will prepare and cache statements by default. This can provide a
significant free improvement to code that does not explicitly use prepared statements. Under certain workloads, it can
perform nearly 3x the number of queries per second.
3. Batched queries - Multiple queries can be batched together to minimize network round trips; see the sketch after this list.
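A sketch of point 3 (hypothetical `ledger` table):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

func batchExample(ctx context.Context, conn *pgx.Conn) (int64, error) {
	batch := &pgx.Batch{}
	batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
	batch.Queue("select sum(amount) from ledger")

	// All queued queries are sent in a single network round trip.
	br := conn.SendBatch(ctx, batch)
	defer br.Close()

	if _, err := br.Exec(); err != nil { // result of the insert
		return 0, err
	}
	var sum int64
	err := br.QueryRow().Scan(&sum) // result of the select
	return sum, err
}
```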
## Testing
See [CONTRIBUTING.md](./CONTRIBUTING.md) for setup instructions.
pgx tests naturally require a PostgreSQL database. The test suite connects to the database specified in the `PGX_TEST_DATABASE` environment
variable, which can be either a URL or DSN. In addition, the standard `PG*` environment variables will be respected.
Consider using [direnv](https://github.com/direnv/direnv) to simplify environment variable handling.
## Architecture
### Example Test Environment
See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.com/watch?v=sXMSWhcHCf8) for a description of pgx architecture.
Connect to your PostgreSQL server and run:
```
create database pgx_test;
```
Connect to the newly-created database and run:
```
create domain uint64 as numeric(20,0);
```
Now, you can run the tests:
```
PGX_TEST_DATABASE="host=/var/run/postgresql database=pgx_test" go test ./...
```
In addition, there are tests specific for PgBouncer that will be executed if `PGX_TEST_PGBOUNCER_CONN_STRING` is set.
## Supported Go and PostgreSQL Versions
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.23 and higher and PostgreSQL 13 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
~~pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.17 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).~~
`v5` is targeted at Go 1.18+. The general release of `v5` is not planned until second half of 2022 so it is expected that the policy of supporting the two most recent versions of Go will be maintained or restored soon after its release.
## Version Policy
pgx follows semantic versioning for the documented public API on stable releases. `v5` is the latest stable major version.
pgx follows semantic versioning for the documented public API on stable releases. `v4` is the latest stable major version.
## PGX Family Libraries
@@ -120,14 +159,8 @@ pgerrcode contains constants for the PostgreSQL error codes.
* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
* [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos))
* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)
## Adapters for 3rd Party Tracers
* [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer)
## Adapters for 3rd Party Loggers
These adapters can be used with the tracelog package.
@@ -137,50 +170,13 @@ These adapters can be used with the tracelog package.
* [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus)
* [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap)
* [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog)
* [github.com/mcosta74/pgx-slog](https://github.com/mcosta74/pgx-slog)
* [github.com/kataras/pgx-golog](https://github.com/kataras/pgx-golog)
## 3rd Party Libraries with PGX Support
### [github.com/pashagolub/pgxmock](https://github.com/pashagolub/pgxmock)
pgxmock is a mock library implementing pgx interfaces.
pgxmock has one and only one purpose: to simulate pgx behavior in tests, without needing a real database connection.
### [github.com/georgysavva/scany](https://github.com/georgysavva/scany)
Library for scanning data from a database into Go structs and more.
### [github.com/vingarcia/ksql](https://github.com/vingarcia/ksql)
A carefully designed SQL client that makes using SQL easier, more productive, and less error-prone in Go.
### [github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
Adds GSSAPI / Kerberos authentication support.
### [github.com/wcamarao/pmx](https://github.com/wcamarao/pmx)
Explicit data mapping and scanning library for Go structs and slices.
### [github.com/stephenafamo/scan](https://github.com/stephenafamo/scan)
Type safe and flexible package for scanning database data into Go types.
Supports structs, maps, slices, and custom mapping functions.
### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)
Code first migration library for native pgx (no database/sql abstraction).
### [github.com/amirsalarsafaei/sqlc-pgx-monitoring](https://github.com/amirsalarsafaei/sqlc-pgx-monitoring)
A database monitoring/metrics library for pgx and sqlc. Trace, log and monitor your sqlc query performance using OpenTelemetry.
### [github.com/nikolayk812/pgx-outbox](https://github.com/nikolayk812/pgx-outbox)
Simple Go implementation of the transactional outbox pattern for PostgreSQL using the jackc/pgx driver.
### [github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy)
Simplifies working with the pgx library, providing convenient scanning of nested structures.


@@ -2,7 +2,7 @@ require "erb"
rule '.go' => '.go.erb' do |task|
erb = ERB.new(File.read(task.source))
File.write(task.name, "// Code generated from #{task.source}. DO NOT EDIT.\n\n" + erb.result(binding))
File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding))
sh "goimports", "-w", task.name
end

batch.go

@@ -10,9 +10,9 @@ import (
// QueuedQuery is a query that has been queued for execution via a Batch.
type QueuedQuery struct {
SQL string
Arguments []any
Fn batchItemFunc
query string
arguments []any
fn batchItemFunc
sd *pgconn.StatementDescription
}
@@ -20,11 +20,14 @@ type batchItemFunc func(br BatchResults) error
// Query sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
qq.Fn = func(br BatchResults) error {
rows, _ := br.Query()
qq.fn = func(br BatchResults) error {
rows, err := br.Query()
if err != nil {
return err
}
defer rows.Close()
err := fn(rows)
err = fn(rows)
if err != nil {
return err
}
@@ -36,7 +39,7 @@ func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
// Query sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
qq.Fn = func(br BatchResults) error {
qq.fn = func(br BatchResults) error {
row := br.QueryRow()
return fn(row)
}
@@ -44,7 +47,7 @@ func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
// Exec sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
qq.Fn = func(br BatchResults) error {
qq.fn = func(br BatchResults) error {
ct, err := br.Exec()
if err != nil {
return err
@@ -57,28 +60,22 @@ func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
// Batch queries are a way of bundling multiple queries together to avoid
// unnecessary network round trips. A Batch must only be sent once.
type Batch struct {
QueuedQueries []*QueuedQuery
queuedQueries []*QueuedQuery
}
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. The only pgx option
// argument that is supported is QueryRewriter. Queries are executed using the connection's DefaultQueryExecMode.
//
// While query can contain multiple statements if the connection's DefaultQueryExecMode is QueryModeSimple, this should
// be avoided. QueuedQuery.Fn must not be set as it will only be called for the first query. That is, QueuedQuery.Query,
// QueuedQuery.QueryRow, and QueuedQuery.Exec must not be called. In addition, any error messages or tracing that
// include the current query may reference the wrong query.
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
qq := &QueuedQuery{
SQL: query,
Arguments: arguments,
query: query,
arguments: arguments,
}
b.QueuedQueries = append(b.QueuedQueries, qq)
b.queuedQueries = append(b.queuedQueries, qq)
return qq
}
// Len returns number of queries that have been queued so far.
func (b *Batch) Len() int {
return len(b.QueuedQueries)
return len(b.queuedQueries)
}
type BatchResults interface {
@@ -132,7 +129,7 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {
if !br.mrr.NextResult() {
err := br.mrr.Close()
if err == nil {
err = errors.New("no more results in batch")
err = errors.New("no result")
}
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
@@ -145,10 +142,7 @@ }
}
commandTag, err := br.mrr.ResultReader().Close()
if err != nil {
br.err = err
br.mrr.Close()
}
br.err = err
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
@@ -184,7 +178,7 @@ func (br *batchResults) Query() (Rows, error) {
if !br.mrr.NextResult() {
rows.err = br.mrr.Close()
if rows.err == nil {
rows.err = errors.New("no more results in batch")
rows.err = errors.New("no result")
}
rows.closed = true
@@ -231,10 +225,10 @@ func (br *batchResults) Close() error {
}
// Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
if br.b.QueuedQueries[br.qqIdx].Fn != nil {
err := br.b.QueuedQueries[br.qqIdx].Fn(br)
if err != nil {
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
if br.b.queuedQueries[br.qqIdx].fn != nil {
err := br.b.queuedQueries[br.qqIdx].fn(br)
if err != nil && br.err == nil {
br.err = err
}
} else {
@@ -257,10 +251,10 @@ func (br *batchResults) earlyError() error {
}
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
bi := br.b.QueuedQueries[br.qqIdx]
query = bi.SQL
args = bi.Arguments
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
bi := br.b.queuedQueries[br.qqIdx]
query = bi.query
args = bi.arguments
ok = true
br.qqIdx++
}
@@ -291,15 +285,12 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
return pgconn.CommandTag{}, br.err
}
query, arguments, err := br.nextQueryAndArgs()
if err != nil {
return pgconn.CommandTag{}, err
}
query, arguments, _ := br.nextQueryAndArgs()
results, err := br.pipeline.GetResults()
if err != nil {
br.err = err
return pgconn.CommandTag{}, br.err
return pgconn.CommandTag{}, err
}
var commandTag pgconn.CommandTag
switch results := results.(type) {
@@ -318,7 +309,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
})
}
return commandTag, br.err
return commandTag, err
}
// Query reads the results from the next query in the batch as if the query has been sent with Query.
@@ -337,9 +328,9 @@ func (br *pipelineBatchResults) Query() (Rows, error) {
return &baseRows{err: br.err, closed: true}, br.err
}
query, arguments, err := br.nextQueryAndArgs()
if err != nil {
return &baseRows{err: err, closed: true}, err
query, arguments, ok := br.nextQueryAndArgs()
if !ok {
query = "batch query"
}
rows := br.conn.getRows(br.ctx, query, arguments)
@@ -393,20 +384,24 @@ func (br *pipelineBatchResults) Close() error {
}
}()
if br.err == nil && br.lastRows != nil && br.lastRows.err != nil {
if br.err != nil {
return br.err
}
if br.lastRows != nil && br.lastRows.err != nil {
br.err = br.lastRows.err
return br.err
}
if br.closed {
return br.err
return nil
}
// Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
if br.b.QueuedQueries[br.qqIdx].Fn != nil {
err := br.b.QueuedQueries[br.qqIdx].Fn(br)
if err != nil {
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
if br.b.queuedQueries[br.qqIdx].fn != nil {
err := br.b.queuedQueries[br.qqIdx].fn(br)
if err != nil && br.err == nil {
br.err = err
}
} else {
@@ -428,16 +423,13 @@ func (br *pipelineBatchResults) earlyError() error {
return br.err
}
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, err error) {
if br.b == nil {
return "", nil, errors.New("no reference to batch")
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
bi := br.b.queuedQueries[br.qqIdx]
query = bi.query
args = bi.arguments
ok = true
br.qqIdx++
}
if br.qqIdx >= len(br.b.QueuedQueries) {
return "", nil, errors.New("no more results in batch")
}
bi := br.b.QueuedQueries[br.qqIdx]
br.qqIdx++
return bi.SQL, bi.Arguments, nil
return
}


@@ -6,7 +6,6 @@ import (
"fmt"
"os"
"testing"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
@@ -18,10 +17,7 @@ import (
func TestConnSendBatch(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
sql := `create temporary table ledger(
@@ -40,7 +36,7 @@ func TestConnSendBatch(t *testing.T) {
batch.Queue("select * from ledger where false")
batch.Queue("select sum(amount) from ledger")
br := conn.SendBatch(ctx, batch)
br := conn.SendBatch(context.Background(), batch)
ct, err := br.Exec()
if err != nil {
@@ -156,10 +152,7 @@ func TestConnSendBatch(t *testing.T) {
func TestConnSendBatchQueuedQuery(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
sql := `create temporary table ledger(
@@ -244,7 +237,7 @@ func TestConnSendBatchQueuedQuery(t *testing.T) {
return nil
})
err := conn.SendBatch(ctx, batch).Close()
err := conn.SendBatch(context.Background(), batch).Close()
assert.NoError(t, err)
})
}
@@ -252,10 +245,7 @@ func TestConnSendBatchQueuedQuery(t *testing.T) {
func TestConnSendBatchMany(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
@@ -272,7 +262,7 @@ }
}
batch.Queue("select count(*) from ledger")
br := conn.SendBatch(ctx, batch)
br := conn.SendBatch(context.Background(), batch)
for i := 0; i < numInserts; i++ {
ct, err := br.Exec()
@@ -290,45 +280,6 @@ })
})
}
// https://github.com/jackc/pgx/issues/1801#issuecomment-2203784178
func TestConnSendBatchReadResultsWhenNothingQueued(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{}
br := conn.SendBatch(ctx, batch)
commandTag, err := br.Exec()
require.Equal(t, "", commandTag.String())
require.EqualError(t, err, "no more results in batch")
err = br.Close()
require.NoError(t, err)
})
}
func TestConnSendBatchReadMoreResultsThanQueriesSent(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{}
batch.Queue("select 1")
br := conn.SendBatch(ctx, batch)
commandTag, err := br.Exec()
require.Equal(t, "SELECT 1", commandTag.String())
require.NoError(t, err)
commandTag, err = br.Exec()
require.Equal(t, "", commandTag.String())
require.EqualError(t, err, "no more results in batch")
err = br.Close()
require.NoError(t, err)
})
}
func TestConnSendBatchWithPreparedStatement(t *testing.T) {
t.Parallel()
@@ -339,12 +290,9 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) {
pgx.QueryExecModeExec,
// Don't test simple mode with prepared statements.
}
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
_, err := conn.Prepare(ctx, "ps1", "select n from generate_series(0,$1::int) n")
_, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
if err != nil {
t.Fatal(err)
}
@ -356,7 +304,7 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) {
batch.Queue("ps1", 5)
}
br := conn.SendBatch(ctx, batch)
br := conn.SendBatch(context.Background(), batch)
for i := 0; i < queryCount; i++ {
rows, err := br.Query()
@ -389,16 +337,13 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) {
func TestConnSendBatchWithQueryRewriter(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{}
batch.Queue("something to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{1}})
batch.Queue("something else to be replaced", &testQueryRewriter{sql: "select $1::text", args: []any{"hello"}})
batch.Queue("more to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{3}})
br := conn.SendBatch(ctx, batch)
br := conn.SendBatch(context.Background(), batch)
var n int32
err := br.QueryRow().Scan(&n)
@ -423,9 +368,6 @@ func TestConnSendBatchWithQueryRewriter(t *testing.T) {
func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
@ -438,7 +380,7 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.
pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
_, err = conn.Prepare(ctx, "ps1", "select n from generate_series(0,$1::int) n")
_, err = conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
if err != nil {
t.Fatal(err)
}
@ -450,7 +392,7 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.
batch.Queue("ps1", 5)
}
br := conn.SendBatch(ctx, batch)
br := conn.SendBatch(context.Background(), batch)
for i := 0; i < queryCount; i++ {
rows, err := br.Query()
@ -484,16 +426,13 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.
func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{}
batch.Queue("select n from generate_series(0,5) n")
batch.Queue("select n from generate_series(0,5) n")
br := conn.SendBatch(ctx, batch)
br := conn.SendBatch(context.Background(), batch)
rows, err := br.Query()
if err != nil {
@ -546,16 +485,13 @@ func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
func TestConnSendBatchQueryError(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{}
batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0")
batch.Queue("select n from generate_series(0,5) n")
br := conn.SendBatch(ctx, batch)
br := conn.SendBatch(context.Background(), batch)
rows, err := br.Query()
if err != nil {
@ -587,15 +523,12 @@ func TestConnSendBatchQueryError(t *testing.T) {
func TestConnSendBatchQuerySyntaxError(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{}
batch.Queue("select 1 1")
br := conn.SendBatch(ctx, batch)
br := conn.SendBatch(context.Background(), batch)
var n int32
err := br.QueryRow().Scan(&n)
@ -614,10 +547,7 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) {
func TestConnSendBatchQueryRowInsert(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
sql := `create temporary table ledger(
id serial primary key,
@ -630,7 +560,7 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) {
batch.Queue("select 1")
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
br := conn.SendBatch(ctx, batch)
br := conn.SendBatch(context.Background(), batch)
var value int
err := br.QueryRow().Scan(&value)
@ -654,10 +584,7 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) {
func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
sql := `create temporary table ledger(
id serial primary key,
@ -670,7 +597,7 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
batch.Queue("select 1 union all select 2 union all select 3")
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
br := conn.SendBatch(ctx, batch)
br := conn.SendBatch(context.Background(), batch)
rows, err := br.Query()
if err != nil {
@ -694,10 +621,7 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
func TestTxSendBatch(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
sql := `create temporary table ledger1(
id serial primary key,
@ -711,7 +635,7 @@ func TestTxSendBatch(t *testing.T) {
);`
mustExec(t, conn, sql)
tx, _ := conn.Begin(ctx)
tx, _ := conn.Begin(context.Background())
batch := &pgx.Batch{}
batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
@ -728,7 +652,7 @@ func TestTxSendBatch(t *testing.T) {
batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2)
batch.Queue("select amount from ledger2 where id = $1", id)
br = tx.SendBatch(ctx, batch)
br = tx.SendBatch(context.Background(), batch)
ct, err := br.Exec()
if err != nil {
@ -745,10 +669,10 @@ func TestTxSendBatch(t *testing.T) {
}
br.Close()
tx.Commit(ctx)
tx.Commit(context.Background())
var count int
conn.QueryRow(ctx, "select count(1) from ledger1 where id = $1", id).Scan(&count)
conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count)
if count != 1 {
t.Errorf("count => %v, want %v", count, 1)
}
@ -764,10 +688,7 @@ func TestTxSendBatch(t *testing.T) {
func TestTxSendBatchRollback(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
sql := `create temporary table ledger1(
id serial primary key,
@ -775,11 +696,11 @@ func TestTxSendBatchRollback(t *testing.T) {
);`
mustExec(t, conn, sql)
tx, _ := conn.Begin(ctx)
tx, _ := conn.Begin(context.Background())
batch := &pgx.Batch{}
batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
br := tx.SendBatch(ctx, batch)
br := tx.SendBatch(context.Background(), batch)
var id int
err := br.QueryRow().Scan(&id)
@ -787,9 +708,9 @@ func TestTxSendBatchRollback(t *testing.T) {
t.Error(err)
}
br.Close()
tx.Rollback(ctx)
tx.Rollback(context.Background())
row := conn.QueryRow(ctx, "select count(1) from ledger1 where id = $1", id)
row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id)
var count int
row.Scan(&count)
if count != 0 {
@ -799,62 +720,10 @@ func TestTxSendBatchRollback(t *testing.T) {
})
}
// https://github.com/jackc/pgx/issues/1578
func TestSendBatchErrorWhileReadingResultsWithoutCallback(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{}
batch.Queue("select 4 / $1::int", 0)
batchResult := conn.SendBatch(ctx, batch)
_, execErr := batchResult.Exec()
require.Error(t, execErr)
closeErr := batchResult.Close()
require.Equal(t, execErr, closeErr)
// Try to use the connection.
_, err := conn.Exec(ctx, "select 1")
require.NoError(t, err)
})
}
func TestSendBatchErrorWhileReadingResultsWithExecWhereSomeRowsAreReturned(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{}
batch.Queue("select 4 / n from generate_series(-2, 2) n")
batchResult := conn.SendBatch(ctx, batch)
_, execErr := batchResult.Exec()
require.Error(t, execErr)
closeErr := batchResult.Close()
require.Equal(t, execErr, closeErr)
// Try to use the connection.
_, err := conn.Exec(ctx, "select 1")
require.NoError(t, err)
})
}
func TestConnBeginBatchDeferredError(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
@ -870,7 +739,7 @@ func TestConnBeginBatchDeferredError(t *testing.T) {
batch.Queue(`update t set n=n+1 where id='b' returning *`)
br := conn.SendBatch(ctx, batch)
br := conn.SendBatch(context.Background(), batch)
rows, err := br.Query()
if err != nil {
@ -899,9 +768,6 @@ func TestConnBeginBatchDeferredError(t *testing.T) {
}
func TestConnSendBatchNoStatementCache(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
config.StatementCacheCapacity = 0
@ -910,13 +776,10 @@ func TestConnSendBatchNoStatementCache(t *testing.T) {
conn := mustConnect(t, config)
defer closeConn(t, conn)
testConnSendBatch(t, ctx, conn, 3)
testConnSendBatch(t, conn, 3)
}
func TestConnSendBatchPrepareStatementCache(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
config.StatementCacheCapacity = 32
@ -924,13 +787,10 @@ func TestConnSendBatchPrepareStatementCache(t *testing.T) {
conn := mustConnect(t, config)
defer closeConn(t, conn)
testConnSendBatch(t, ctx, conn, 3)
testConnSendBatch(t, conn, 3)
}
func TestConnSendBatchDescribeStatementCache(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
config.DescriptionCacheCapacity = 32
@ -938,16 +798,16 @@ func TestConnSendBatchDescribeStatementCache(t *testing.T) {
conn := mustConnect(t, config)
defer closeConn(t, conn)
testConnSendBatch(t, ctx, conn, 3)
testConnSendBatch(t, conn, 3)
}
func testConnSendBatch(t *testing.T, ctx context.Context, conn *pgx.Conn, queryCount int) {
func testConnSendBatch(t *testing.T, conn *pgx.Conn, queryCount int) {
batch := &pgx.Batch{}
for j := 0; j < queryCount; j++ {
batch.Queue("select n from generate_series(0,5) n")
}
br := conn.SendBatch(ctx, batch)
br := conn.SendBatch(context.Background(), batch)
for j := 0; j < queryCount; j++ {
rows, err := br.Query()
@ -970,12 +830,12 @@ func testConnSendBatch(t *testing.T, ctx context.Context, conn *pgx.Conn, queryC
func TestSendBatchSimpleProtocol(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
conn := mustConnect(t, config)
defer closeConn(t, conn)
@ -1008,41 +868,8 @@ func TestSendBatchSimpleProtocol(t *testing.T) {
assert.False(t, rows.Next())
}
// https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887
func TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
mustExec(t, conn, `create temporary table foo(col1 text primary key);`)
batch := &pgx.Batch{}
batch.Queue("select col1 from foo")
batch.Queue("select col1 from baz")
err := conn.SendBatch(ctx, batch).Close()
require.EqualError(t, err, `ERROR: relation "baz" does not exist (SQLSTATE 42P01)`)
mustExec(t, conn, `create temporary table baz(col1 text primary key);`)
// Since table baz now exists, the batch should succeed.
batch = &pgx.Batch{}
batch.Queue("select col1 from foo")
batch.Queue("select col1 from baz")
err = conn.SendBatch(ctx, batch).Close()
require.NoError(t, err)
})
}
func ExampleConn_SendBatch() {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
if err != nil {
fmt.Printf("Unable to establish connection: %v", err)
return
@ -1085,7 +912,7 @@ func ExampleConn_SendBatch() {
return err
})
err = conn.SendBatch(ctx, batch).Close()
err = conn.SendBatch(context.Background(), batch).Close()
if err != nil {
fmt.Printf("SendBatch error: %v", err)
return

View File

@ -13,6 +13,7 @@ import (
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/internal/nbconn"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/stretchr/testify/require"
@ -151,7 +152,7 @@ func BenchmarkMinimalPgConnPreparedSelect(b *testing.B) {
for rr.NextRow() {
for i := range rr.Values() {
if !bytes.Equal(rr.Values()[0], encodedBytes) {
if bytes.Compare(rr.Values()[0], encodedBytes) != 0 {
b.Fatalf("unexpected values: %s %s", rr.Values()[i], encodedBytes)
}
}
@ -339,9 +340,8 @@ type benchmarkWriteTableCopyFromSrc struct {
}
func (s *benchmarkWriteTableCopyFromSrc) Next() bool {
next := s.idx < s.count
s.idx++
return next
return s.idx < s.count
}
func (s *benchmarkWriteTableCopyFromSrc) Values() ([]any, error) {
@ -407,34 +407,6 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) {
}
}
func benchmarkWriteNRowsViaBatchInsert(b *testing.B, n int) {
conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
defer closeConn(b, conn)
mustExec(b, conn, benchmarkWriteTableCreateSQL)
_, err := conn.Prepare(context.Background(), "insert_t", benchmarkWriteTableInsertSQL)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
src := newBenchmarkWriteTableCopyFromSrc(n)
batch := &pgx.Batch{}
for src.Next() {
values, _ := src.Values()
batch.Queue("insert_t", values...)
}
err = conn.SendBatch(context.Background(), batch).Close()
if err != nil {
b.Fatal(err)
}
}
}
type queryArgs []any
func (qa *queryArgs) Append(v any) string {
@ -512,7 +484,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc
}
if err := tx.Commit(context.Background()); err != nil {
return 0, err
return 0, nil
}
return rowCount, nil
@ -588,22 +560,6 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) {
}
}
func BenchmarkWrite2RowsViaInsert(b *testing.B) {
benchmarkWriteNRowsViaInsert(b, 2)
}
func BenchmarkWrite2RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 2)
}
func BenchmarkWrite2RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 2)
}
func BenchmarkWrite2RowsViaCopy(b *testing.B) {
benchmarkWriteNRowsViaCopy(b, 2)
}
func BenchmarkWrite5RowsViaInsert(b *testing.B) {
benchmarkWriteNRowsViaInsert(b, 5)
}
@ -611,9 +567,6 @@ func BenchmarkWrite5RowsViaInsert(b *testing.B) {
func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 5)
}
func BenchmarkWrite5RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 5)
}
func BenchmarkWrite5RowsViaCopy(b *testing.B) {
benchmarkWriteNRowsViaCopy(b, 5)
@ -626,9 +579,6 @@ func BenchmarkWrite10RowsViaInsert(b *testing.B) {
func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 10)
}
func BenchmarkWrite10RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 10)
}
func BenchmarkWrite10RowsViaCopy(b *testing.B) {
benchmarkWriteNRowsViaCopy(b, 10)
@ -641,9 +591,6 @@ func BenchmarkWrite100RowsViaInsert(b *testing.B) {
func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 100)
}
func BenchmarkWrite100RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 100)
}
func BenchmarkWrite100RowsViaCopy(b *testing.B) {
benchmarkWriteNRowsViaCopy(b, 100)
@ -657,10 +604,6 @@ func BenchmarkWrite1000RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 1000)
}
func BenchmarkWrite1000RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 1000)
}
func BenchmarkWrite1000RowsViaCopy(b *testing.B) {
benchmarkWriteNRowsViaCopy(b, 1000)
}
@ -672,9 +615,6 @@ func BenchmarkWrite10000RowsViaInsert(b *testing.B) {
func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 10000)
}
func BenchmarkWrite10000RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 10000)
}
func BenchmarkWrite10000RowsViaCopy(b *testing.B) {
benchmarkWriteNRowsViaCopy(b, 10000)
@ -944,7 +884,6 @@ type BenchRowSimple struct {
BirthDate time.Time
Weight int32
Height int32
Tags []string
UpdateTime time.Time
}
@ -958,13 +897,13 @@ func BenchmarkSelectRowsScanSimple(b *testing.B) {
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
br := &BenchRowSimple{}
for i := 0; i < b.N; i++ {
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
if err != nil {
b.Fatal(err)
}
for rows.Next() {
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime)
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime)
}
if rows.Err() != nil {
@ -983,7 +922,6 @@ type BenchRowStringBytes struct {
BirthDate time.Time
Weight int32
Height int32
Tags []string
UpdateTime time.Time
}
@ -997,13 +935,13 @@ func BenchmarkSelectRowsScanStringBytes(b *testing.B) {
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
br := &BenchRowStringBytes{}
for i := 0; i < b.N; i++ {
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
if err != nil {
b.Fatal(err)
}
for rows.Next() {
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime)
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime)
}
if rows.Err() != nil {
@ -1022,7 +960,6 @@ type BenchRowDecoder struct {
BirthDate pgtype.Date
Weight pgtype.Int4
Height pgtype.Int4
Tags pgtype.FlatArray[string]
UpdateTime pgtype.Timestamptz
}
@ -1048,7 +985,7 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) {
for i := 0; i < b.N; i++ {
rows, err := conn.Query(
context.Background(),
"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
pgx.QueryResultFormats{format.code},
rowCount,
)
@ -1057,7 +994,7 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) {
}
for rows.Next() {
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime)
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime)
}
if rows.Err() != nil {
@ -1079,7 +1016,7 @@ func BenchmarkSelectRowsPgConnExecText(b *testing.B) {
for _, rowCount := range rowCounts {
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
for i := 0; i < b.N; i++ {
mrr := conn.PgConn().Exec(context.Background(), fmt.Sprintf("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + %d) n", rowCount))
mrr := conn.PgConn().Exec(context.Background(), fmt.Sprintf("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + %d) n", rowCount))
for mrr.NextResult() {
rr := mrr.ResultReader()
for rr.NextRow() {
@ -1116,11 +1053,11 @@ func BenchmarkSelectRowsPgConnExecParams(b *testing.B) {
for i := 0; i < b.N; i++ {
rr := conn.PgConn().ExecParams(
context.Background(),
"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
nil,
nil,
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code},
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
)
for rr.NextRow() {
rr.Values()
@ -1137,107 +1074,13 @@ func BenchmarkSelectRowsPgConnExecParams(b *testing.B) {
}
}
func BenchmarkSelectRowsSimpleCollectRowsRowToStructByPos(b *testing.B) {
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(b, conn)
rowCounts := getSelectRowsCounts(b)
for _, rowCount := range rowCounts {
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
for i := 0; i < b.N; i++ {
rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
benchRows, err := pgx.CollectRows(rows, pgx.RowToStructByPos[BenchRowSimple])
if err != nil {
b.Fatal(err)
}
if len(benchRows) != int(rowCount) {
b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows))
}
}
})
}
}
func BenchmarkSelectRowsSimpleAppendRowsRowToStructByPos(b *testing.B) {
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(b, conn)
rowCounts := getSelectRowsCounts(b)
for _, rowCount := range rowCounts {
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
benchRows := make([]BenchRowSimple, 0, rowCount)
for i := 0; i < b.N; i++ {
benchRows = benchRows[:0]
rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
var err error
benchRows, err = pgx.AppendRows(benchRows, rows, pgx.RowToStructByPos[BenchRowSimple])
if err != nil {
b.Fatal(err)
}
if len(benchRows) != int(rowCount) {
b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows))
}
}
})
}
}
func BenchmarkSelectRowsSimpleCollectRowsRowToStructByName(b *testing.B) {
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(b, conn)
rowCounts := getSelectRowsCounts(b)
for _, rowCount := range rowCounts {
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
for i := 0; i < b.N; i++ {
rows, _ := conn.Query(context.Background(), "select n as id, 'Adam' as first_name, 'Smith ' || n as last_name, 'male' as sex, '1952-06-16'::date as birth_date, 258 as weight, 72 as height, '{foo,bar,baz}'::text[] as tags, '2001-01-28 01:02:03-05'::timestamptz as update_time from generate_series(100001, 100000 + $1) n", rowCount)
benchRows, err := pgx.CollectRows(rows, pgx.RowToStructByName[BenchRowSimple])
if err != nil {
b.Fatal(err)
}
if len(benchRows) != int(rowCount) {
b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows))
}
}
})
}
}
func BenchmarkSelectRowsSimpleAppendRowsRowToStructByName(b *testing.B) {
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(b, conn)
rowCounts := getSelectRowsCounts(b)
for _, rowCount := range rowCounts {
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
benchRows := make([]BenchRowSimple, 0, rowCount)
for i := 0; i < b.N; i++ {
benchRows = benchRows[:0]
rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
var err error
benchRows, err = pgx.AppendRows(benchRows, rows, pgx.RowToStructByPos[BenchRowSimple])
if err != nil {
b.Fatal(err)
}
if len(benchRows) != int(rowCount) {
b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows))
}
}
})
}
}
func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(b, conn)
rowCounts := getSelectRowsCounts(b)
_, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
_, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
if err != nil {
b.Fatal(err)
}
@ -1259,7 +1102,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
"ps1",
[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
nil,
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code},
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
)
for rr.NextRow() {
rr.Values()
@ -1277,7 +1120,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
}
type queryRecorder struct {
conn net.Conn
conn nbconn.Conn
writeBuf []byte
readCount int
}
@ -1293,6 +1136,14 @@ func (qr *queryRecorder) Write(b []byte) (n int, err error) {
return qr.conn.Write(b)
}
func (qr *queryRecorder) BufferReadUntilBlock() error {
return qr.conn.BufferReadUntilBlock()
}
func (qr *queryRecorder) Flush() error {
return qr.conn.Flush()
}
func (qr *queryRecorder) Close() error {
return qr.conn.Close()
}
@ -1338,7 +1189,7 @@ func BenchmarkSelectRowsRawPrepared(b *testing.B) {
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")).PgConn()
defer conn.Close(context.Background())
_, err := conn.Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
_, err := conn.Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
if err != nil {
b.Fatal(err)
}
@ -1361,7 +1212,7 @@ func BenchmarkSelectRowsRawPrepared(b *testing.B) {
"ps1",
[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
nil,
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code},
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
)
_, err := rr.Close()
require.NoError(b, err)

View File

@ -9,41 +9,40 @@ then
sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list"
sudo apt-get update -qq
sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION
sudo cp testsetup/pg_hba.conf /etc/postgresql/$PGVERSION/main/pg_hba.conf
sudo sh -c "echo \"listen_addresses = '127.0.0.1'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
sudo sh -c "cat testsetup/postgresql_ssl.conf >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
cd testsetup
# Generate CA, server, and encrypted client certificates.
go run generate_certs.go
# Copy certificates to server directory and set permissions.
sudo cp ca.pem /var/lib/postgresql/$PGVERSION/main/root.crt
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/root.crt
sudo cp localhost.key /var/lib/postgresql/$PGVERSION/main/server.key
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.key
sudo chmod 600 /var/lib/postgresql/$PGVERSION/main/server.key
sudo cp localhost.crt /var/lib/postgresql/$PGVERSION/main/server.crt
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.crt
cp ca.pem /tmp
cp pgx_sslcert.key /tmp
cp pgx_sslcert.crt /tmp
cd ..
sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf
echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf
echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf
if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then
echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
fi
sudo /etc/init.d/postgresql restart
createdb -U postgres pgx_test
psql -U postgres -f testsetup/postgresql_setup.sql pgx_test
psql -U postgres -c 'create database pgx_test'
psql -U postgres pgx_test -c 'create extension hstore'
psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)'
psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'"
psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'"
psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'"
psql -U postgres -c "create user `whoami`"
psql -U postgres -c "create user pgx_replication with replication password 'secret'"
# The tricky test user, below, has to actually exist so that it can be used in a test
# of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles.
psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'"
fi
if [[ "${PGVERSION-}" =~ ^cockroach ]]
then
wget -qO- https://binaries.cockroachdb.com/cockroach-v24.3.3.linux-amd64.tgz | tar xvz
sudo mv cockroach-v24.3.3.linux-amd64/cockroach /usr/local/bin/
wget -qO- https://binaries.cockroachdb.com/cockroach-v20.2.5.linux-amd64.tgz | tar xvz
sudo mv cockroach-v20.2.5.linux-amd64/cockroach /usr/local/bin/
cockroach start-single-node --insecure --background --listen-addr=localhost
cockroach sql --insecure -e 'create database pgx_test'
fi

conn.go (576 changed lines)
View File

@ -2,15 +2,13 @@ package pgx
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/jackc/pgx/v5/internal/anynil"
"github.com/jackc/pgx/v5/internal/sanitize"
"github.com/jackc/pgx/v5/internal/stmtcache"
"github.com/jackc/pgx/v5/pgconn"
@ -37,18 +35,13 @@ type ConnConfig struct {
// DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol
// and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as
// PGBouncer. In this case it may be preferable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same
// PGBouncer. In this case it may be preferrable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same
// functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument.
DefaultQueryExecMode QueryExecMode
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
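
To make the mode switch concrete, here is a minimal sketch (not part of this diff) of setting the exec mode both connection-wide and per query; `DATABASE_URL` is a placeholder environment variable.

```
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()

	config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
	if err != nil {
		fmt.Println(err)
		return
	}
	// Connection-wide: avoid automatic prepared statements behind a pooler.
	config.DefaultQueryExecMode = pgx.QueryExecModeExec

	conn, err := pgx.ConnectConfig(ctx, config)
	if err != nil {
		fmt.Println(err)
		return
	}
	defer conn.Close(ctx)

	// Per query: a QueryExecMode may be passed as the first query argument.
	var n int32
	err = conn.QueryRow(ctx, "select 1", pgx.QueryExecModeSimpleProtocol).Scan(&n)
	fmt.Println(n, err)
}
```
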
// ParseConfigOptions contains options that control how a config is built such as getsslpassword.
type ParseConfigOptions struct {
pgconn.ParseConfigOptions
}
// Copy returns a deep copy of the config that is safe to use and modify.
// The only exception is the tls.Config:
// according to the tls.Config docs it must not be modified after creation.
@ -101,33 +94,11 @@ func (ident Identifier) Sanitize() string {
return strings.Join(parts, ".")
}
var (
// ErrNoRows occurs when rows are expected but none are returned.
ErrNoRows = newProxyErr(sql.ErrNoRows, "no rows in result set")
// ErrTooManyRows occurs when more rows than expected are returned.
ErrTooManyRows = errors.New("too many rows in result set")
)
// ErrNoRows occurs when rows are expected but none are returned.
var ErrNoRows = errors.New("no rows in result set")
func newProxyErr(background error, msg string) error {
return &proxyError{
msg: msg,
background: background,
}
}
type proxyError struct {
msg string
background error
}
func (err *proxyError) Error() string { return err.msg }
func (err *proxyError) Unwrap() error { return err.background }
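
A small sketch of what the proxy error buys on the master side: `errors.Is` matches both the pgx and `database/sql` sentinels.

```
package main

import (
	"database/sql"
	"errors"
	"fmt"

	"github.com/jackc/pgx/v5"
)

func main() {
	// The proxy error above unwraps to sql.ErrNoRows, so on the master side
	// both sentinel checks succeed.
	err := fmt.Errorf("lookup failed: %w", pgx.ErrNoRows)
	fmt.Println(errors.Is(err, pgx.ErrNoRows)) // true
	fmt.Println(errors.Is(err, sql.ErrNoRows)) // true
}
```
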
var (
errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
)
var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
// Connect establishes a connection with a PostgreSQL server with a connection string. See
// pgconn.Connect for details.
@ -139,16 +110,6 @@ func Connect(ctx context.Context, connString string) (*Conn, error) {
return connect(ctx, connConfig)
}
// ConnectWithOptions behaves exactly like Connect with the addition of options. At the present options is only used to
// provide a GetSSLPassword function.
func ConnectWithOptions(ctx context.Context, connString string, options ParseConfigOptions) (*Conn, error) {
connConfig, err := ParseConfigWithOptions(connString, options)
if err != nil {
return nil, err
}
return connect(ctx, connConfig)
}
// ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct.
// connConfig must have been created by ParseConfig.
func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
@ -159,10 +120,22 @@ func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
return connect(ctx, connConfig)
}
// ParseConfigWithOptions behaves exactly as ParseConfig does with the addition of options. At the present options is
// only used to provide a GetSSLPassword function.
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*ConnConfig, error) {
config, err := pgconn.ParseConfigWithOptions(connString, options.ParseConfigOptions)
// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig
// does. In addition, it accepts the following options:
//
// default_query_exec_mode
// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See
// QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement".
//
// statement_cache_capacity
// The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode.
// Default: 512.
//
// description_cache_capacity
// The maximum size of the description cache used when executing a query with "cache_describe" query exec mode.
// Default: 512.
func ParseConfig(connString string) (*ConnConfig, error) {
config, err := pgconn.ParseConfig(connString)
if err != nil {
return nil, err
}
@ -172,7 +145,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
delete(config.RuntimeParams, "statement_cache_capacity")
n, err := strconv.ParseInt(s, 10, 32)
if err != nil {
return nil, pgconn.NewParseConfigError(connString, "cannot parse statement_cache_capacity", err)
return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err)
}
statementCacheCapacity = int(n)
}
@ -182,7 +155,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
delete(config.RuntimeParams, "description_cache_capacity")
n, err := strconv.ParseInt(s, 10, 32)
if err != nil {
return nil, pgconn.NewParseConfigError(connString, "cannot parse description_cache_capacity", err)
return nil, fmt.Errorf("cannot parse description_cache_capacity: %w", err)
}
descriptionCacheCapacity = int(n)
}
@ -202,7 +175,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
case "simple_protocol":
defaultQueryExecMode = QueryExecModeSimpleProtocol
default:
return nil, pgconn.NewParseConfigError(connString, "invalid default_query_exec_mode", err)
return nil, fmt.Errorf("invalid default_query_exec_mode: %v", err)
}
}
@ -218,24 +191,6 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
return connConfig, nil
}
// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that [pgconn.ParseConfig]
// does. In addition, it accepts the following options:
//
// - default_query_exec_mode.
// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See
// QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement".
//
// - statement_cache_capacity.
// The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode.
// Default: 512.
//
// - description_cache_capacity.
// The maximum size of the description cache used when executing a query with "cache_describe" query exec mode.
// Default: 512.
func ParseConfig(connString string) (*ConnConfig, error) {
return ParseConfigWithOptions(connString, ParseConfigOptions{})
}
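
As a sketch of the connection-string form of these options (the URL, database, and credentials below are hypothetical):

```
package main

import (
	"fmt"

	"github.com/jackc/pgx/v5"
)

func main() {
	// Option names come from the ParseConfig doc comment above and ride
	// along with the usual libpq-style parameters.
	connString := "postgres://app:secret@localhost:5432/appdb" +
		"?default_query_exec_mode=describe_exec&statement_cache_capacity=256"

	config, err := pgx.ParseConfig(connString)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(config.DefaultQueryExecMode, config.StatementCacheCapacity)
}
```
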
// connect connects to a database. connect takes ownership of config. The caller must not use or access it again.
func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
if connectTracer, ok := config.Tracer.(ConnectTracer); ok {
@ -293,7 +248,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
return c, nil
}
// Close closes a connection. It is safe to call Close on an already closed
// Close closes a connection. It is safe to call Close on a already closed
// connection.
func (c *Conn) Close(ctx context.Context) error {
if c.IsClosed() {
@ -304,15 +259,12 @@ func (c *Conn) Close(ctx context.Context) error {
return err
}
// Prepare creates a prepared statement with name and sql. sql can contain placeholders for bound parameters. These
// placeholders are referenced positionally as $1, $2, etc. name can be used instead of sql with Query, QueryRow, and
// Exec to execute the statement. It can also be used with Batch.Queue.
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
// for bound parameters. These placeholders are referenced positional as $1, $2, etc.
//
// The underlying PostgreSQL identifier for the prepared statement will be name if name != sql or a digest of sql if
// name == sql.
//
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same name and sql arguments. This
// allows a code path to Prepare and Query/Exec without concern for if the statement has already been prepared.
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same
// name and sql arguments. This allows a code path to Prepare and Query/Exec without
// concern for if the statement has already been prepared.
func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
if c.prepareTracer != nil {
ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql})
@ -334,60 +286,22 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.Statem
}()
}
var psName, psKey string
if name == sql {
digest := sha256.Sum256([]byte(sql))
psName = "stmt_" + hex.EncodeToString(digest[0:24])
psKey = sql
} else {
psName = name
psKey = name
}
sd, err = c.pgConn.Prepare(ctx, psName, sql, nil)
sd, err = c.pgConn.Prepare(ctx, name, sql, nil)
if err != nil {
return nil, err
}
if psKey != "" {
c.preparedStatements[psKey] = sd
if name != "" {
c.preparedStatements[name] = sd
}
return sd, nil
}
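
A hedged sketch of the documented name-based flow; the statement name `user_name` and the `users` table are illustrative:

```
import (
	"context"

	"github.com/jackc/pgx/v5"
)

func prepareAndQuery(ctx context.Context, conn *pgx.Conn) (string, error) {
	// Idempotent: calling Prepare again with the same name and sql is safe.
	if _, err := conn.Prepare(ctx, "user_name", "select name from users where id = $1"); err != nil {
		return "", err
	}
	var name string
	// The statement name may be used in place of the sql text.
	err := conn.QueryRow(ctx, "user_name", 1).Scan(&name)
	return name, err
}
```
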
// Deallocate releases a prepared statement. Calling Deallocate on a non-existent prepared statement will succeed.
// Deallocate released a prepared statement
func (c *Conn) Deallocate(ctx context.Context, name string) error {
var psName string
sd := c.preparedStatements[name]
if sd != nil {
psName = sd.Name
} else {
psName = name
}
err := c.pgConn.Deallocate(ctx, psName)
if err != nil {
return err
}
if sd != nil {
delete(c.preparedStatements, name)
}
return nil
}
// DeallocateAll releases all previously prepared statements from the server and client, where it also resets the statement and description cache.
func (c *Conn) DeallocateAll(ctx context.Context) error {
c.preparedStatements = map[string]*pgconn.StatementDescription{}
if c.config.StatementCacheCapacity > 0 {
c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
}
if c.config.DescriptionCacheCapacity > 0 {
c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
}
_, err := c.pgConn.Exec(ctx, "deallocate all").ReadAll()
delete(c.preparedStatements, name)
_, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll()
return err
}
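
A brief usage sketch of both calls, reusing the hypothetical `user_name` statement from the Prepare sketch above; `DeallocateAll` is the master-side API:

```
import (
	"context"

	"github.com/jackc/pgx/v5"
)

func dropStatements(ctx context.Context, conn *pgx.Conn) error {
	if err := conn.Deallocate(ctx, "user_name"); err != nil {
		return err
	}
	// Master side only: also resets the statement and description caches.
	return conn.DeallocateAll(ctx)
}
```
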
@ -420,7 +334,7 @@ func (c *Conn) IsClosed() bool {
return c.pgConn.IsClosed()
}
func (c *Conn) die() {
func (c *Conn) die(err error) {
if c.IsClosed() {
return
}
@ -434,9 +348,11 @@ func quoteIdentifier(s string) string {
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}
// Ping delegates to the underlying *pgconn.PgConn.Ping.
// Ping executes an empty sql statement against the *Conn
// If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned.
func (c *Conn) Ping(ctx context.Context) error {
return c.pgConn.Ping(ctx)
_, err := c.Exec(ctx, ";")
return err
}
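
The call shape is the same on both sides; a trivial health-check sketch:

```
import (
	"context"

	"github.com/jackc/pgx/v5"
)

// On the master side Ping delegates to *pgconn.PgConn.Ping; on the beta
// side it executed the empty statement ";".
func healthy(ctx context.Context, conn *pgx.Conn) bool {
	return conn.Ping(ctx) == nil
}
```
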
// PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the
@ -491,10 +407,7 @@ optionLoop:
}
if queryRewriter != nil {
sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
if err != nil {
return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %w", err)
}
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
}
// Always use simple protocol when there are no arguments.
@ -513,7 +426,7 @@ optionLoop:
}
sd := c.statementCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql)
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
if err != nil {
return pgconn.CommandTag{}, err
}
@ -531,7 +444,6 @@ optionLoop:
if err != nil {
return pgconn.CommandTag{}, err
}
c.descriptionCache.Put(sd)
}
return c.execParams(ctx, sd, arguments)
@ -560,7 +472,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []a
mrr := c.pgConn.Exec(ctx, sql)
for mrr.NextResult() {
commandTag, _ = mrr.ResultReader().Close()
commandTag, err = mrr.ResultReader().Close()
}
err = mrr.Close()
return commandTag, err
@ -588,6 +500,14 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription
return result.CommandTag, result.Err
}
type unknownArgumentTypeQueryExecModeExecError struct {
arg any
}
func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg)
}
func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) {
err := c.eqb.Build(c.typeMap, nil, args)
if err != nil {
@ -618,57 +538,40 @@ type QueryExecMode int32
const (
_ QueryExecMode = iota
// Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single round
// trip after the statement is cached. This is the default. If the database schema is modified or the search_path is
// changed after a statement is cached then the first execution of a previously cached query may fail. e.g. If the
// number of columns returned by a "SELECT *" changes or the type of a column is changed.
// Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single
// round trip after the statement is cached. This is the default.
QueryExecModeCacheStatement
// Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the extended
// protocol. Queries are executed in a single round trip after the description is cached. If the database schema is
// modified or the search_path is changed after a statement is cached then the first execution of a previously cached
// query may fail. e.g. If the number of columns returned by a "SELECT *" changes or the type of a column is changed.
// Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the
// extended protocol. Queries are executed in a single round trip after the description is cached. If the database
// schema is modified or the search_path is changed this may result in undetected result decoding errors.
QueryExecModeCacheDescribe
// Get the statement description on every execution. This uses the extended protocol. Queries require two round trips
// to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the
// statement description on the first round trip and then uses it to execute the query on the second round trip. This
// may cause problems with connection poolers that switch the underlying connection between round trips. It is safe
// even when the database schema is modified concurrently.
// to execute. It does not use prepared statements (allowing usage with most connection poolers) and is safe even
// when the database schema is modified concurrently.
QueryExecModeDescribeExec
// Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol
// with text formatted parameters and results. Queries are executed in a single round trip. Type mappings can be
// registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are
// unregistered or ambiguous. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know
// unregistered or ambigious. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know
// the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot.
//
// On rare occasions user defined types may behave differently when encoded in the text format instead of the binary
// format. For example, this could happen if a "type RomanNumeral int32" implements fmt.Stringer to format integers as
// Roman numerals (e.g. 7 is VII). The binary format would properly encode the integer 7 as the binary value for 7.
// But the text format would encode the integer 7 as the string "VII". As QueryExecModeExec uses the text format, it
// is possible that changing query mode from another mode to QueryExecModeExec could change the behavior of the query.
// This should not occur with types pgx supports directly and can be avoided by registering the types with
// pgtype.Map.RegisterDefaultPgType and implementing the appropriate type interfaces. In the case of RomanNumeral, it
// should implement pgtype.Int64Valuer.
QueryExecModeExec
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. This is
// especially significant for []byte values. []byte values are encoded as PostgreSQL bytea. string must be used
// instead for text type values including json and jsonb. Type mappings can be registered with
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous.
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use a
// map[string]string directly as an argument. This mode cannot. Queries are executed in a single round trip.
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments.
// Queries are executed in a single round trip. Type mappings can be registered with
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous.
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use
// a map[string]string directly as an argument. This mode cannot.
//
// QueryExecModeSimpleProtocol should have the same user application visible behavior as QueryExecModeExec. This includes
// the warning regarding differences in text format and binary format encoding with user defined types. There may be
// other minor exceptions such as behavior when multiple result returning queries are erroneously sent in a single
// string.
// QueryExecModeSimpleProtocol should have the same user application visible behavior as QueryExecModeExec with minor
// exceptions such as behavior when multiple result returning queries are erroneously sent in a single string.
//
// QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer
// QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol should
// only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does not
// support the extended protocol.
// QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol
// should only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does
// not support the extended protocol.
QueryExecModeSimpleProtocol
)
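
To ground the []byte caveat above, a sketch (illustrative `docs` table) of passing JSON text under the simple protocol:

```
import (
	"context"

	"github.com/jackc/pgx/v5"
)

func insertJSON(ctx context.Context, conn *pgx.Conn, payload []byte) error {
	_, err := conn.Exec(ctx,
		"insert into docs(body) values($1)",
		pgx.QueryExecModeSimpleProtocol,
		string(payload), // string -> text/json; []byte would become bytea
	)
	return err
}
```
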
@ -697,7 +600,7 @@ type QueryResultFormatsByOID map[uint32]int16
// QueryRewriter rewrites a query when used as the first arguments to a query method.
type QueryRewriter interface {
RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error)
RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any)
}
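
A hypothetical implementation sketch using the master-side signature (the beta-side signature above returns no error); `pgx.NamedArgs` is the in-tree implementation of this interface:

```
import (
	"context"

	"github.com/jackc/pgx/v5"
)

// commentTag prefixes every query it rewrites with a comment tag.
type commentTag struct{ tag string }

func (c commentTag) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (string, []any, error) {
	return "/* " + c.tag + " */ " + sql, args, nil
}

// Usage: rows, err := conn.Query(ctx, "select 1", commentTag{tag: "reports"})
```
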
// Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query
@ -708,9 +611,6 @@ type QueryRewriter interface {
// returned Rows even if an error is returned. The error will be available in rows.Err() after rows are closed. It
// is allowed to ignore the error returned from Query and handle it in Rows.
//
// It is possible for a call of FieldDescriptions on the returned Rows to return nil even if the Query call did not
// return an error.
//
// It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be
// collected before processing rather than processed while receiving each row. This avoids the possibility of the
// application processing rows from a query that the server rejected. The CollectRows function is useful here.
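
A sketch of the collect-then-process pattern the comment recommends, using `CollectRows` with `RowTo` (both from pgx v5):

```
import (
	"context"

	"github.com/jackc/pgx/v5"
)

func evens(ctx context.Context, conn *pgx.Conn) ([]int32, error) {
	// Ignoring Query's error is allowed; it resurfaces via the rows.
	rows, _ := conn.Query(ctx, "select n from generate_series(0, 10) n where n % 2 = 0")
	// CollectRows closes rows and surfaces any deferred error.
	return pgx.CollectRows(rows, pgx.RowTo[int32])
}
```
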
@ -759,16 +659,7 @@ optionLoop:
}
if queryRewriter != nil {
var err error
originalSQL := sql
originalArgs := args
sql, args, err = queryRewriter.RewriteQuery(ctx, c, sql, args)
if err != nil {
rows := c.getRows(ctx, originalSQL, originalArgs)
err = fmt.Errorf("rewrite query failed: %w", err)
rows.fatal(err)
return rows, err
}
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
}
// Bypass any statement caching.
@ -777,16 +668,50 @@ optionLoop:
}
c.eqb.reset()
anynil.NormalizeSlice(args)
rows := c.getRows(ctx, sql, args)
var err error
sd, explicitPreparedStatement := c.preparedStatements[sql]
if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec {
if sd == nil {
sd, err = c.getStatementDescription(ctx, mode, sql)
if err != nil {
rows.fatal(err)
return rows, err
switch mode {
case QueryExecModeCacheStatement:
if c.statementCache == nil {
err = errDisabledStatementCache
rows.fatal(err)
return rows, err
}
sd = c.statementCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
if err != nil {
rows.fatal(err)
return rows, err
}
c.statementCache.Put(sd)
}
case QueryExecModeCacheDescribe:
if c.descriptionCache == nil {
err = errDisabledDescriptionCache
rows.fatal(err)
return rows, err
}
sd = c.descriptionCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, "", sql)
if err != nil {
rows.fatal(err)
return rows, err
}
c.descriptionCache.Put(sd)
}
case QueryExecModeDescribeExec:
sd, err = c.Prepare(ctx, "", sql)
if err != nil {
rows.fatal(err)
return rows, err
}
}
}
@ -856,47 +781,6 @@ optionLoop:
return rows, rows.err
}
// getStatementDescription returns the statement description of the sql query
// according to the given mode.
//
// If the mode is one that doesn't require knowing the param and result OIDs,
// then nil is returned without error.
func (c *Conn) getStatementDescription(
ctx context.Context,
mode QueryExecMode,
sql string,
) (sd *pgconn.StatementDescription, err error) {
switch mode {
case QueryExecModeCacheStatement:
if c.statementCache == nil {
return nil, errDisabledStatementCache
}
sd = c.statementCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql)
if err != nil {
return nil, err
}
c.statementCache.Put(sd)
}
case QueryExecModeCacheDescribe:
if c.descriptionCache == nil {
return nil, errDisabledDescriptionCache
}
sd = c.descriptionCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, "", sql)
if err != nil {
return nil, err
}
c.descriptionCache.Put(sd)
}
case QueryExecModeDescribeExec:
return c.Prepare(ctx, "", sql)
}
return sd, err
}
// QueryRow is a convenience wrapper over Query. Any error that occurs while
// querying is deferred until calling Scan on the returned Row. That Row will
// error with ErrNoRows if no rows are returned.
@ -908,9 +792,6 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
// SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
// is used again.
//
// Depending on the QueryExecMode, all queries may be prepared before any are executed. This means that creating a table
// and using it in a subsequent query in the same batch can fail.
func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
if c.batchTracer != nil {
ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b})
@ -926,14 +807,15 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
return &batchResults{ctx: ctx, conn: c, err: err}
}
for _, bi := range b.QueuedQueries {
mode := c.config.DefaultQueryExecMode
for _, bi := range b.queuedQueries {
var queryRewriter QueryRewriter
sql := bi.SQL
arguments := bi.Arguments
sql := bi.query
arguments := bi.arguments
optionLoop:
for len(arguments) > 0 {
// Update Batch.Queue function comment when additional options are implemented
switch arg := arguments[0].(type) {
case QueryRewriter:
queryRewriter = arg
@ -944,26 +826,20 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
}
if queryRewriter != nil {
var err error
sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %w", err)}
}
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
}
bi.SQL = sql
bi.Arguments = arguments
bi.query = sql
bi.arguments = arguments
}
// TODO: changing mode per batch? Update Batch.Queue function comment when implemented
mode := c.config.DefaultQueryExecMode
if mode == QueryExecModeSimpleProtocol {
return c.sendBatchQueryExecModeSimpleProtocol(ctx, b)
}
// All other modes use extended protocol and thus can use prepared statements.
for _, bi := range b.QueuedQueries {
if sd, ok := c.preparedStatements[bi.SQL]; ok {
for _, bi := range b.queuedQueries {
if sd, ok := c.preparedStatements[bi.query]; ok {
bi.sd = sd
}
}
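
For orientation, a minimal SendBatch usage sketch. The two queued queries are deliberately independent, since depending on the QueryExecMode all statements may be prepared before any are executed:

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// batchExample queues two independent queries and reads their results in
// queue order. The whole batch travels in one round trip, and br must be
// closed before the connection is used again.
func batchExample(ctx context.Context, conn *pgx.Conn) error {
	b := &pgx.Batch{}
	b.Queue("select 1")
	b.Queue("select $1::text", "hello")

	br := conn.SendBatch(ctx, b)
	defer br.Close()

	var n int
	if err := br.QueryRow().Scan(&n); err != nil {
		return err
	}
	var s string
	return br.QueryRow().Scan(&s)
}
```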
@ -984,11 +860,11 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
var sb strings.Builder
for i, bi := range b.QueuedQueries {
for i, bi := range b.queuedQueries {
if i > 0 {
sb.WriteByte(';')
}
sql, err := c.sanitizeForSimpleQuery(bi.SQL, bi.Arguments...)
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
@ -1007,21 +883,21 @@ func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batc
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
batch := &pgconn.Batch{}
for _, bi := range b.QueuedQueries {
for _, bi := range b.queuedQueries {
sd := bi.sd
if sd != nil {
err := c.eqb.Build(c.typeMap, sd, bi.Arguments)
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
} else {
err := c.eqb.Build(c.typeMap, nil, bi.Arguments)
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecParams(bi.SQL, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
}
}
@ -1040,24 +916,24 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchR
func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
if c.statementCache == nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true}
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache}
}
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.QueuedQueries {
for _, bi := range b.queuedQueries {
if bi.sd == nil {
sd := c.statementCache.Get(bi.SQL)
sd := c.statementCache.Get(bi.query)
if sd != nil {
bi.sd = sd
} else {
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
bi.sd = distinctNewQueries[idx]
} else {
sd = &pgconn.StatementDescription{
Name: stmtcache.StatementName(bi.SQL),
SQL: bi.SQL,
Name: stmtcache.NextStatementName(),
SQL: bi.query,
}
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd)
@ -1072,23 +948,23 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc
func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
if c.descriptionCache == nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true}
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache}
}
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.QueuedQueries {
for _, bi := range b.queuedQueries {
if bi.sd == nil {
sd := c.descriptionCache.Get(bi.SQL)
sd := c.descriptionCache.Get(bi.query)
if sd != nil {
bi.sd = sd
} else {
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
bi.sd = distinctNewQueries[idx]
} else {
sd = &pgconn.StatementDescription{
SQL: bi.SQL,
SQL: bi.query,
}
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd)
@ -1105,13 +981,13 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch)
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.QueuedQueries {
for _, bi := range b.queuedQueries {
if bi.sd == nil {
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
bi.sd = distinctNewQueries[idx]
} else {
sd := &pgconn.StatementDescription{
SQL: bi.SQL,
SQL: bi.query,
}
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd)
@ -1124,82 +1000,63 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch)
}
func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
pipeline := c.pgConn.StartPipeline(ctx)
pipeline := c.pgConn.StartPipeline(context.Background())
defer func() {
if pbr != nil && pbr.err != nil {
if pbr.err != nil {
pipeline.Close()
}
}()
// Prepare any needed queries
if len(distinctNewQueries) > 0 {
err := func() (err error) {
for _, sd := range distinctNewQueries {
pipeline.SendPrepare(sd.Name, sd.SQL, nil)
}
for _, sd := range distinctNewQueries {
pipeline.SendPrepare(sd.Name, sd.SQL, nil)
}
// Store all statements we are preparing into the cache. It's fine if it overflows because HandleInvalidated will
// clean them up later.
if sdCache != nil {
for _, sd := range distinctNewQueries {
sdCache.Put(sd)
}
}
// If something goes wrong preparing the statements, we need to invalidate the cache entries we just added.
defer func() {
if err != nil && sdCache != nil {
for _, sd := range distinctNewQueries {
sdCache.Invalidate(sd.SQL)
}
}
}()
err = pipeline.Sync()
if err != nil {
return err
}
for _, sd := range distinctNewQueries {
results, err := pipeline.GetResults()
if err != nil {
return err
}
resultSD, ok := results.(*pgconn.StatementDescription)
if !ok {
return fmt.Errorf("expected statement description, got %T", results)
}
// Fill in the previously empty / pending statement descriptions.
sd.ParamOIDs = resultSD.ParamOIDs
sd.Fields = resultSD.Fields
}
err := pipeline.Sync()
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
}
for _, sd := range distinctNewQueries {
results, err := pipeline.GetResults()
if err != nil {
return err
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
}
_, ok := results.(*pgconn.PipelineSync)
resultSD, ok := results.(*pgconn.StatementDescription)
if !ok {
return fmt.Errorf("expected sync, got %T", results)
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)}
}
return nil
}()
// Fill in the previously empty / pending statement descriptions.
sd.ParamOIDs = resultSD.ParamOIDs
sd.Fields = resultSD.Fields
}
results, err := pipeline.GetResults()
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
}
_, ok := results.(*pgconn.PipelineSync)
if !ok {
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)}
}
}
// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
if sdCache != nil {
for _, sd := range distinctNewQueries {
sdCache.Put(sd)
}
}
// Queue the queries.
for _, bi := range b.QueuedQueries {
err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments)
for _, bi := range b.queuedQueries {
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
if err != nil {
// wrap the error so the user can understand which query failed inside the batch
err = fmt.Errorf("error building query %s: %w", bi.SQL, err)
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
}
if bi.sd.Name == "" {
@ -1211,7 +1068,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
err := pipeline.Sync()
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
}
return &pipelineBatchResults{
@ -1243,15 +1100,7 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
return sanitize.SanitizeSQL(sql, valueArgs...)
}
// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. typeName must be
// the name of a type where the underlying type(s) is already understood by pgx. It is for derived types. In particular,
// typeName must be one of the following:
// - An array type name of a type that is already registered. e.g. "_foo" when "foo" is registered.
// - A composite type name where all field types are already registered.
// - A domain type name where the base type is already registered.
// - An enum type name.
// - A range type name where the element type is already registered.
// - A multirange type name where the element type is already registered.
// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration.
func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) {
var oid uint32
@ -1261,9 +1110,8 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err
}
var typtype string
var typbasetype uint32
err = c.QueryRow(ctx, "select typtype::text, typbasetype from pg_type where oid=$1", oid).Scan(&typtype, &typbasetype)
err = c.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype)
if err != nil {
return nil, err
}
@ -1288,39 +1136,8 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err
}
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil
case "d": // domain
dt, ok := c.TypeMap().TypeForOID(typbasetype)
if !ok {
return nil, errors.New("domain base type OID not registered")
}
return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil
case "e": // enum
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
case "r": // range
elementOID, err := c.getRangeElementOID(ctx, oid)
if err != nil {
return nil, err
}
dt, ok := c.TypeMap().TypeForOID(elementOID)
if !ok {
return nil, errors.New("range element OID not registered")
}
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}}, nil
case "m": // multirange
elementOID, err := c.getMultiRangeElementOID(ctx, oid)
if err != nil {
return nil, err
}
dt, ok := c.TypeMap().TypeForOID(elementOID)
if !ok {
return nil, errors.New("multirange element OID not registered")
}
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}, nil
default:
return &pgtype.Type{}, errors.New("unknown typtype")
}
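
A hedged sketch of the intended LoadType flow, using a hypothetical enum named color:

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// registerColor loads and registers a hypothetical enum created with
//	create type color as enum ('red', 'green', 'blue')
// so it can be used in the binary format (e.g. with CopyFrom).
func registerColor(ctx context.Context, conn *pgx.Conn) error {
	t, err := conn.LoadType(ctx, "color")
	if err != nil {
		return err
	}
	conn.TypeMap().RegisterType(t)

	// The array type can be loaded once its element type is registered.
	arr, err := conn.LoadType(ctx, "_color")
	if err != nil {
		return err
	}
	conn.TypeMap().RegisterType(arr)
	return nil
}
```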
@ -1337,28 +1154,6 @@ func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, erro
return typelem, nil
}
func (c *Conn) getRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
var typelem uint32
err := c.QueryRow(ctx, "select rngsubtype from pg_range where rngtypid=$1", oid).Scan(&typelem)
if err != nil {
return 0, err
}
return typelem, nil
}
func (c *Conn) getMultiRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
var typelem uint32
err := c.QueryRow(ctx, "select rngtypid from pg_range where rngmultitypid=$1", oid).Scan(&typelem)
if err != nil {
return 0, err
}
return typelem, nil
}
func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) {
var typrelid uint32
@ -1373,8 +1168,6 @@ func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.Com
rows, _ := c.Query(ctx, `select attname, atttypid
from pg_attribute
where attrelid=$1
and not attisdropped
and attnum > 0
order by attnum`,
typrelid,
)
@ -1394,17 +1187,17 @@ order by attnum`,
}
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' {
if c.pgConn.TxStatus() != 'I' {
return nil
}
if c.descriptionCache != nil {
c.descriptionCache.RemoveInvalidated()
c.descriptionCache.HandleInvalidated()
}
var invalidatedStatements []*pgconn.StatementDescription
if c.statementCache != nil {
invalidatedStatements = c.statementCache.GetInvalidated()
invalidatedStatements = c.statementCache.HandleInvalidated()
}
if len(invalidatedStatements) == 0 {
@ -1428,10 +1221,5 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
}
c.statementCache.RemoveInvalidated()
for _, sd := range invalidatedStatements {
delete(c.preparedStatements, sd.Name)
}
return nil
}

View File

@ -1,55 +0,0 @@
package pgx
import (
"context"
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func mustParseConfig(t testing.TB, connString string) *ConnConfig {
config, err := ParseConfig(connString)
require.Nil(t, err)
return config
}
func mustConnect(t testing.TB, config *ConnConfig) *Conn {
conn, err := ConnectConfig(context.Background(), config)
if err != nil {
t.Fatalf("Unable to establish connection: %v", err)
}
return conn
}
// Ensures the connection limits the size of its cached objects.
// This test examines the internals of *Conn so must be in the same package.
func TestStmtCacheSizeLimit(t *testing.T) {
const cacheLimit = 16
connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
connConfig.StatementCacheCapacity = cacheLimit
conn := mustConnect(t, connConfig)
defer func() {
err := conn.Close(context.Background())
if err != nil {
t.Fatal(err)
}
}()
// run a set of unique queries that should overflow the cache
ctx := context.Background()
for i := 0; i < cacheLimit*2; i++ {
uniqueString := fmt.Sprintf("unique %d", i)
uniqueSQL := fmt.Sprintf("select '%s'", uniqueString)
var output string
err := conn.QueryRow(ctx, uniqueSQL).Scan(&output)
require.NoError(t, err)
require.Equal(t, uniqueString, output)
}
// preparedStatements contains cacheLimit+1 because deallocation happens before the query
assert.Len(t, conn.preparedStatements, cacheLimit+1)
assert.Equal(t, cacheLimit, conn.statementCache.Len())
}

View File

@ -3,7 +3,6 @@ package pgx_test
import (
"bytes"
"context"
"database/sql"
"os"
"strings"
"sync"
@ -198,28 +197,10 @@ func TestParseConfigExtractsDefaultQueryExecMode(t *testing.T) {
}
}
func TestParseConfigErrors(t *testing.T) {
t.Parallel()
for _, tt := range []struct {
connString string
expectedErrSubstring string
}{
{"default_query_exec_mode=does_not_exist", "does_not_exist"},
} {
config, err := pgx.ParseConfig(tt.connString)
require.Nil(t, config)
require.ErrorContains(t, err, tt.expectedErrSubstring)
}
}
func TestExec(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results.String() != "CREATE TABLE" {
t.Error("Unexpected results from Exec")
}
@ -255,17 +236,14 @@ type testQueryRewriter struct {
args []any
}
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return qr.sql, qr.args, nil
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any) {
return qr.sql, qr.args
}
func TestExecWithQueryRewriter(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
qr := testQueryRewriter{sql: "select $1::int", args: []any{42}}
_, err := conn.Exec(ctx, "should be replaced", &qr)
require.NoError(t, err)
@ -275,10 +253,7 @@ func TestExecWithQueryRewriter(t *testing.T) {
func TestExecFailure(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
if _, err := conn.Exec(context.Background(), "selct;"); err == nil {
t.Fatal("Expected SQL syntax error")
}
@ -294,10 +269,7 @@ func TestExecFailure(t *testing.T) {
func TestExecFailureWithArguments(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(context.Background(), "selct $1;", 1)
if err == nil {
t.Fatal("Expected SQL syntax error")
@ -312,11 +284,8 @@ func TestExecFailureWithArguments(t *testing.T) {
func TestExecContextWithoutCancelation(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
ctx, cancelFunc := context.WithCancel(ctx)
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
commandTag, err := conn.Exec(ctx, "create temporary table foo(id integer primary key);")
@ -333,11 +302,8 @@ func TestExecContextWithoutCancelation(t *testing.T) {
func TestExecContextFailureWithoutCancelation(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
ctx, cancelFunc := context.WithCancel(ctx)
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
_, err := conn.Exec(ctx, "selct;")
@ -358,11 +324,8 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) {
func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
ctx, cancelFunc := context.WithCancel(ctx)
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
_, err := conn.Exec(ctx, "selct $1;", 1)
@ -461,7 +424,7 @@ func TestPrepare(t *testing.T) {
t.Errorf("Prepared statement did not return expected value: %v", s)
}
err = conn.DeallocateAll(context.Background())
err = conn.Deallocate(context.Background(), "test")
if err != nil {
t.Errorf("conn.Deallocate failed: %v", err)
}
@ -483,42 +446,37 @@ func TestPrepareBadSQLFailure(t *testing.T) {
func TestPrepareIdempotency(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
for i := 0; i < 2; i++ {
_, err := conn.Prepare(context.Background(), "test", "select 42::integer")
if err != nil {
t.Fatalf("%d. Unable to prepare statement: %v", i, err)
}
var n int32
err = conn.QueryRow(context.Background(), "test").Scan(&n)
if err != nil {
t.Errorf("%d. Executing prepared statement failed: %v", i, err)
}
if n != int32(42) {
t.Errorf("%d. Prepared statement did not return expected value: %v", i, n)
}
for i := 0; i < 2; i++ {
_, err := conn.Prepare(context.Background(), "test", "select 42::integer")
if err != nil {
t.Fatalf("%d. Unable to prepare statement: %v", i, err)
}
_, err := conn.Prepare(context.Background(), "test", "select 'fail'::varchar")
if err == nil {
t.Fatalf("Prepare statement with same name but different SQL should have failed but it didn't")
return
var n int32
err = conn.QueryRow(context.Background(), "test").Scan(&n)
if err != nil {
t.Errorf("%d. Executing prepared statement failed: %v", i, err)
}
})
if n != int32(42) {
t.Errorf("%d. Prepared statement did not return expected value: %v", i, n)
}
}
_, err := conn.Prepare(context.Background(), "test", "select 'fail'::varchar")
if err == nil {
t.Fatalf("Prepare statement with same name but different SQL should have failed but it didn't")
return
}
}
func TestPrepareStatementCacheModes(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Prepare(context.Background(), "test", "select $1::text")
require.NoError(t, err)
@ -529,91 +487,6 @@ func TestPrepareStatementCacheModes(t *testing.T) {
})
}
func TestPrepareWithDigestedName(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
sql := "select $1::text"
sd, err := conn.Prepare(ctx, sql, sql)
require.NoError(t, err)
require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name)
var s string
err = conn.QueryRow(ctx, sql, "hello").Scan(&s)
require.NoError(t, err)
require.Equal(t, "hello", s)
err = conn.Deallocate(ctx, sql)
require.NoError(t, err)
})
}
// https://github.com/jackc/pgx/pull/1795
func TestDeallocateInAbortedTransaction(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
tx, err := conn.Begin(ctx)
require.NoError(t, err)
sql := "select $1::text"
sd, err := tx.Prepare(ctx, sql, sql)
require.NoError(t, err)
require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name)
var s string
err = tx.QueryRow(ctx, sql, "hello").Scan(&s)
require.NoError(t, err)
require.Equal(t, "hello", s)
_, err = tx.Exec(ctx, "select 1/0") // abort transaction with divide by zero error
require.Error(t, err)
err = conn.Deallocate(ctx, sql)
require.NoError(t, err)
err = tx.Rollback(ctx)
require.NoError(t, err)
sd, err = conn.Prepare(ctx, sql, sql)
require.NoError(t, err)
require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name)
})
}
func TestDeallocateMissingPreparedStatementStillClearsFromPreparedStatementMap(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Prepare(ctx, "ps", "select $1::text")
require.NoError(t, err)
_, err = conn.Exec(ctx, "deallocate ps")
require.NoError(t, err)
err = conn.Deallocate(ctx, "ps")
require.NoError(t, err)
_, err = conn.Prepare(ctx, "ps", "select $1::text, $2::text")
require.NoError(t, err)
var s1, s2 string
err = conn.QueryRow(ctx, "ps", "hello", "world").Scan(&s1, &s2)
require.NoError(t, err)
require.Equal(t, "hello", s1)
require.Equal(t, "world", s2)
})
}
func TestListenNotify(t *testing.T) {
t.Parallel()
@ -653,7 +526,6 @@ func TestListenNotify(t *testing.T) {
defer cancel()
notification, err = listener.WaitForNotification(ctx)
assert.True(t, pgconn.Timeout(err))
assert.Nil(t, notification)
// listener can listen again after a timeout
mustExec(t, notifier, "notify chat")
@ -673,7 +545,6 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
listenerDone := make(chan bool)
notifierDone := make(chan bool)
listening := make(chan bool)
go func() {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
@ -682,7 +553,6 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
}()
mustExec(t, conn, "listen busysafe")
listening <- true
for i := 0; i < 5000; i++ {
var sum int32
@ -705,7 +575,7 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
}
if rows.Err() != nil {
t.Errorf("conn.Query failed: %v", rows.Err())
t.Errorf("conn.Query failed: %v", err)
return
}
@ -718,6 +588,8 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
t.Errorf("Wrong number of rows: %v", rowCount)
return
}
time.Sleep(1 * time.Microsecond)
}
}()
@ -728,10 +600,9 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
notifierDone <- true
}()
<-listening
for i := 0; i < 100000; i++ {
mustExec(t, conn, "notify busysafe, 'hello'")
time.Sleep(1 * time.Microsecond)
}
}()
@ -844,10 +715,7 @@ func TestFatalTxError(t *testing.T) {
func TestInsertBoolArray(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results.String() != "CREATE TABLE" {
t.Error("Unexpected results from Exec")
}
@ -862,10 +730,7 @@ func TestInsertBoolArray(t *testing.T) {
func TestInsertTimestampArray(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results.String() != "CREATE TABLE" {
t.Error("Unexpected results from Exec")
}
@ -947,10 +812,7 @@ func TestConnInitTypeMap(t *testing.T) {
}
func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server does not support domain types (https://github.com/cockroachdb/cockroach/issues/27796)")
var n uint64
@ -966,10 +828,7 @@ func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) {
}
func TestDomainType(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server does not support domain types (https://github.com/cockroachdb/cockroach/issues/27796)")
// Domain type uint64 is a PostgreSQL domain of underlying type numeric.
@ -978,21 +837,24 @@ func TestDomainType(t *testing.T) {
// uint64 but a result OID of the underlying numeric.
var s string
err := conn.QueryRow(ctx, "select $1::uint64", "24").Scan(&s)
err := conn.QueryRow(context.Background(), "select $1::uint64", "24").Scan(&s)
require.NoError(t, err)
require.Equal(t, "24", s)
// Register type
uint64Type, err := conn.LoadType(ctx, "uint64")
require.NoError(t, err)
conn.TypeMap().RegisterType(uint64Type)
var uint64OID uint32
err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID)
if err != nil {
t.Fatalf("did not find uint64 OID, %v", err)
}
conn.TypeMap().RegisterType(&pgtype.Type{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}})
var n uint64
err = conn.QueryRow(ctx, "select $1::uint64", uint64(24)).Scan(&n)
err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n)
require.NoError(t, err)
// String is still an acceptable argument after registration
err = conn.QueryRow(ctx, "select $1::uint64", "7").Scan(&n)
err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n)
if err != nil {
t.Fatal(err)
}
@ -1003,10 +865,7 @@ func TestDomainType(t *testing.T) {
}
func TestLoadTypeSameNameInDifferentSchemas(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server does not support composite types (https://github.com/cockroachdb/cockroach/issues/27792)")
tx, err := conn.Begin(ctx)
@ -1047,111 +906,6 @@ create type pgx_b.point as (c text);
})
}
func TestLoadCompositeType(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server does not support composite types (https://github.com/cockroachdb/cockroach/issues/27792)")
tx, err := conn.Begin(ctx)
require.NoError(t, err)
defer tx.Rollback(ctx)
_, err = tx.Exec(ctx, "create type compositetype as (attr1 int, attr2 int)")
require.NoError(t, err)
_, err = tx.Exec(ctx, "alter type compositetype drop attribute attr1")
require.NoError(t, err)
_, err = conn.LoadType(ctx, "compositetype")
require.NoError(t, err)
})
}
func TestLoadRangeType(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server does not support range types")
tx, err := conn.Begin(ctx)
require.NoError(t, err)
defer tx.Rollback(ctx)
_, err = tx.Exec(ctx, "create type examplefloatrange as range (subtype=float8, subtype_diff=float8mi)")
require.NoError(t, err)
// Register types
newRangeType, err := conn.LoadType(ctx, "examplefloatrange")
require.NoError(t, err)
conn.TypeMap().RegisterType(newRangeType)
conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange")
var inputRangeType = pgtype.Range[float64]{
Lower: 1.0,
Upper: 2.0,
LowerType: pgtype.Inclusive,
UpperType: pgtype.Inclusive,
Valid: true,
}
var outputRangeType pgtype.Range[float64]
err = tx.QueryRow(ctx, "SELECT $1::examplefloatrange", inputRangeType).Scan(&outputRangeType)
require.NoError(t, err)
require.Equal(t, inputRangeType, outputRangeType)
})
}
func TestLoadMultiRangeType(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server does not support range types")
pgxtest.SkipPostgreSQLVersionLessThan(t, conn, 14) // the multirange data type was added in PostgreSQL 14
tx, err := conn.Begin(ctx)
require.NoError(t, err)
defer tx.Rollback(ctx)
_, err = tx.Exec(ctx, "create type examplefloatrange as range (subtype=float8, subtype_diff=float8mi, multirange_type_name=examplefloatmultirange)")
require.NoError(t, err)
// Register types
newRangeType, err := conn.LoadType(ctx, "examplefloatrange")
require.NoError(t, err)
conn.TypeMap().RegisterType(newRangeType)
conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange")
newMultiRangeType, err := conn.LoadType(ctx, "examplefloatmultirange")
require.NoError(t, err)
conn.TypeMap().RegisterType(newMultiRangeType)
conn.TypeMap().RegisterDefaultPgType(pgtype.Multirange[pgtype.Range[float64]]{}, "examplefloatmultirange")
var inputMultiRangeType = pgtype.Multirange[pgtype.Range[float64]]{
{
Lower: 1.0,
Upper: 2.0,
LowerType: pgtype.Inclusive,
UpperType: pgtype.Inclusive,
Valid: true,
},
{
Lower: 3.0,
Upper: 4.0,
LowerType: pgtype.Exclusive,
UpperType: pgtype.Exclusive,
Valid: true,
},
}
var outputMultiRangeType pgtype.Multirange[pgtype.Range[float64]]
err = tx.QueryRow(ctx, "SELECT $1::examplefloatmultirange", inputMultiRangeType).Scan(&outputMultiRangeType)
require.NoError(t, err)
require.Equal(t, inputMultiRangeType, outputMultiRangeType)
})
}
func TestStmtCacheInvalidationConn(t *testing.T) {
ctx := context.Background()
@ -1294,10 +1048,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
}
func TestInsertDurationInterval(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(context.Background(), "create temporary table t(duration INTERVAL(0) NOT NULL)")
require.NoError(t, err)
@ -1331,7 +1082,7 @@ func TestRawValuesUnderlyingMemoryReused(t *testing.T) {
rows.Close()
require.NoError(t, rows.Err())
if !bytes.Equal(original, buf) {
if bytes.Compare(original, buf) != 0 {
return
}
}
@ -1339,82 +1090,3 @@ func TestRawValuesUnderlyingMemoryReused(t *testing.T) {
t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not")
})
}
// https://github.com/jackc/pgx/issues/1847
func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "CockroachDB returns decimal instead of integer for integer division")
var n int32
err := conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 1, n)
// Divide by zero causes an error. baseRows.Close() calls Invalidate on the statement cache whenever an error was
// encountered by the query. Use this to purposely invalidate the cached statement. If we had access to private fields of conn
// we could call conn.statementCache.InvalidateAll() instead.
err = conn.QueryRow(ctx, "select 1 / $1::int", 0).Scan(&n)
require.Error(t, err)
ctx2, cancel2 := context.WithCancel(ctx)
cancel2()
err = conn.QueryRow(ctx2, "select 1 / $1::int", 1).Scan(&n)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
err = conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 1, n)
})
}
// https://github.com/jackc/pgx/issues/1847
func TestConnDeallocateInvalidatedCachedStatementsInTransactionWithBatch(t *testing.T) {
t.Parallel()
ctx := context.Background()
connString := os.Getenv("PGX_TEST_DATABASE")
config := mustParseConfig(t, connString)
config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
config.StatementCacheCapacity = 2
conn, err := pgx.ConnectConfig(ctx, config)
require.NoError(t, err)
tx, err := conn.Begin(ctx)
require.NoError(t, err)
defer tx.Rollback(ctx)
_, err = tx.Exec(ctx, "select $1::int + 1", 1)
require.NoError(t, err)
_, err = tx.Exec(ctx, "select $1::int + 2", 1)
require.NoError(t, err)
// This should invalidate the first cached statement.
_, err = tx.Exec(ctx, "select $1::int + 3", 1)
require.NoError(t, err)
batch := &pgx.Batch{}
batch.Queue("select $1::int + 1", 1)
err = tx.SendBatch(ctx, batch).Close()
require.NoError(t, err)
err = tx.Rollback(ctx)
require.NoError(t, err)
ensureConnValid(t, conn)
}
func TestErrNoRows(t *testing.T) {
t.Parallel()
// ensure we preserve old error message
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNoRows must match sql.ErrNoRows")
}
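
As a usage sketch of that equivalence (the widgets table is hypothetical):

```go
package example

import (
	"context"
	"errors"

	"github.com/jackc/pgx/v5"
)

// findWidget queries a hypothetical widgets table. Because pgx.ErrNoRows
// matches sql.ErrNoRows (per the assertion above), callers can test against
// either sentinel with errors.Is.
func findWidget(ctx context.Context, conn *pgx.Conn, id int) (string, bool, error) {
	var name string
	err := conn.QueryRow(ctx, "select name from widgets where id = $1", id).Scan(&name)
	if errors.Is(err, pgx.ErrNoRows) {
		return "", false, nil
	}
	return name, err == nil, err
}
```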

View File

@ -64,33 +64,6 @@ func (cts *copyFromSlice) Err() error {
return cts.err
}
// CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values.
// nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil,
// or it returns an error. If nxtf returns an error, the copy is aborted.
func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource {
return &copyFromFunc{next: nxtf}
}
type copyFromFunc struct {
next func() ([]any, error)
valueRow []any
err error
}
func (g *copyFromFunc) Next() bool {
g.valueRow, g.err = g.next()
// only return true if valueRow exists and no error
return g.valueRow != nil && g.err == nil
}
func (g *copyFromFunc) Values() ([]any, error) {
return g.valueRow, g.err
}
func (g *copyFromFunc) Err() error {
return g.err
}
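
A short usage sketch for CopyFromFunc, streaming rows from a CSV reader into a hypothetical table foo(a, b):

```go
package example

import (
	"context"
	"encoding/csv"
	"io"

	"github.com/jackc/pgx/v5"
)

// copyCSV streams records from a CSV reader into a hypothetical two-column
// table without materializing all rows in memory.
func copyCSV(ctx context.Context, conn *pgx.Conn, r *csv.Reader) (int64, error) {
	src := pgx.CopyFromFunc(func() ([]any, error) {
		rec, err := r.Read()
		if err == io.EOF {
			return nil, nil // row=nil, err=nil signals end of data
		}
		if err != nil {
			return nil, err // a non-nil error aborts the copy
		}
		return []any{rec[0], rec[1]}, nil
	})
	return conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, src)
}
```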
// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
type CopyFromSource interface {
// Next returns true if there is another row and makes the next row data
@ -112,7 +85,6 @@ type copyFrom struct {
columnNames []string
rowSrc CopyFromSource
readerErrChan chan error
mode QueryExecMode
}
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
@ -133,29 +105,9 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
}
quotedColumnNames := cbuf.String()
var sd *pgconn.StatementDescription
switch ct.mode {
case QueryExecModeExec, QueryExecModeSimpleProtocol:
// These modes don't support the binary format. Before the inclusion of the
// QueryExecModes, Conn.Prepare was called on every COPY operation to get
// the OIDs. These prepared statements were not cached.
//
// Since that's the same behavior provided by QueryExecModeDescribeExec,
// we'll default to that mode.
ct.mode = QueryExecModeDescribeExec
fallthrough
case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec:
var err error
sd, err = ct.conn.getStatementDescription(
ctx,
ct.mode,
fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName),
)
if err != nil {
return 0, fmt.Errorf("statement description failed: %w", err)
}
default:
return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
if err != nil {
return 0, err
}
r, w := io.Pipe()
@ -215,13 +167,8 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
}
func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
const sendBufSize = 65536 - 5 // The packet has a 5-byte header
lastBufLen := 0
largestRowLen := 0
for ct.rowSrc.Next() {
lastBufLen = len(buf)
values, err := ct.rowSrc.Values()
if err != nil {
return false, nil, err
@ -238,15 +185,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
}
}
rowLen := len(buf) - lastBufLen
if rowLen > largestRowLen {
largestRowLen = rowLen
}
// Try not to overflow size of the buffer PgConn.CopyFrom will be reading into. If that happens then the nature of
// io.Pipe means that the next Read will be short. This can lead to pathological send sizes such as 65531, 13, 65531,
// 13, 65531, 13, 65531, 13.
if len(buf) > sendBufSize-largestRowLen {
if len(buf) > 65536 {
return true, buf, nil
}
}
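
The largestRowLen-based variant of the flush condition can be read as: flush once appending another row as large as the biggest seen so far could overflow the receiver's read buffer. A standalone sketch of that check (a hypothetical helper mirroring the constants in this file):

```go
package example

// sendBufSize mirrors the constant above: the 64 KiB copy packet minus its
// 5-byte header.
const sendBufSize = 65536 - 5

// shouldFlush reports whether the accumulated buffer should be sent now,
// flushing early enough that appending another row of the largest size seen
// so far cannot overflow the receiver's read buffer and trigger the short
// alternating reads described above (65531, 13, 65531, 13, ...).
func shouldFlush(bufLen, largestRowLen int) bool {
	return bufLen > sendBufSize-largestRowLen
}
```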
@ -254,14 +193,12 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
return false, buf, nil
}
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. It returns the number of rows copied and
// an error.
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
// It returns the number of rows copied and an error.
//
// CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered
// for the type of each column. Almost all types implemented by pgx support the binary format.
//
// Even though enum types appear to be strings, they still must be registered to be used with CopyFrom. This can be done with
// Conn.LoadType and pgtype.Map.RegisterType.
// CopyFrom requires all values use the binary format. Almost all types
// implemented by pgx use the binary format by default. Types implementing
// Encoder can only be used if they encode to the binary format.
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
ct := &copyFrom{
conn: c,
@ -269,7 +206,6 @@ func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames [
columnNames: columnNames,
rowSrc: rowSrc,
readerErrChan: make(chan error),
mode: c.config.DefaultQueryExecMode,
}
return ct.run(ctx)
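
And a minimal CopyFrom usage sketch with an in-memory source (table foo is hypothetical):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// bulkInsert copies in-memory rows into a hypothetical table foo(a int, b text).
// Every column type must have a registered codec supporting the binary format.
func bulkInsert(ctx context.Context, conn *pgx.Conn) (int64, error) {
	rows := [][]any{
		{int32(1), "alpha"},
		{int32(2), "beta"},
	}
	return conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(rows))
}
```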

View File

@ -14,137 +14,9 @@ import (
"github.com/stretchr/testify/require"
)
func TestConnCopyWithAllQueryExecModes(t *testing.T) {
for _, mode := range pgxtest.AllQueryExecModes {
t.Run(mode.String(), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
cfg.DefaultQueryExecMode = mode
conn := mustConnect(t, cfg)
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int2,
b int4,
c int8,
d text,
e timestamptz
)`)
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
inputRows := [][]any{
{int16(0), int32(1), int64(2), "abc", tzedTime},
{nil, nil, nil, nil, nil},
}
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e"}, pgx.CopyFromRows(inputRows))
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
if int(copyCount) != len(inputRows) {
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err := conn.Query(ctx, "select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
var outputRows [][]any
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
if !reflect.DeepEqual(inputRows, outputRows) {
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
}
ensureConnValid(t, conn)
})
}
}
func TestConnCopyWithKnownOIDQueryExecModes(t *testing.T) {
for _, mode := range pgxtest.KnownOIDQueryExecModes {
t.Run(mode.String(), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
cfg.DefaultQueryExecMode = mode
conn := mustConnect(t, cfg)
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int2,
b int4,
c int8,
d varchar,
e text,
f date,
g timestamptz
)`)
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
inputRows := [][]any{
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
{nil, nil, nil, nil, nil, nil, nil},
}
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
if int(copyCount) != len(inputRows) {
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err := conn.Query(ctx, "select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
var outputRows [][]any
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
if !reflect.DeepEqual(inputRows, outputRows) {
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
}
ensureConnValid(t, conn)
})
}
}
func TestConnCopyFromSmall(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
@ -165,7 +37,7 @@ func TestConnCopyFromSmall(t *testing.T) {
{nil, nil, nil, nil, nil, nil, nil},
}
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
@ -173,7 +45,7 @@ func TestConnCopyFromSmall(t *testing.T) {
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err := conn.Query(ctx, "select * from foo")
rows, err := conn.Query(context.Background(), "select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
@ -201,9 +73,6 @@ func TestConnCopyFromSmall(t *testing.T) {
func TestConnCopyFromSliceSmall(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
@ -224,7 +93,7 @@ func TestConnCopyFromSliceSmall(t *testing.T) {
{nil, nil, nil, nil, nil, nil, nil},
}
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"},
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"},
pgx.CopyFromSlice(len(inputRows), func(i int) ([]any, error) {
return inputRows[i], nil
}))
@ -235,7 +104,7 @@ func TestConnCopyFromSliceSmall(t *testing.T) {
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err := conn.Query(ctx, "select * from foo")
rows, err := conn.Query(context.Background(), "select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
@ -263,12 +132,11 @@ func TestConnCopyFromSliceSmall(t *testing.T) {
func TestConnCopyFromLarge(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
pgxtest.SkipCockroachDB(t, conn, "Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/52722)")
mustExec(t, conn, `create temporary table foo(
a int2,
b int4,
@ -288,7 +156,7 @@ func TestConnCopyFromLarge(t *testing.T) {
inputRows = append(inputRows, []any{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}})
}
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
@ -296,7 +164,7 @@ func TestConnCopyFromLarge(t *testing.T) {
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err := conn.Query(ctx, "select * from foo")
rows, err := conn.Query(context.Background(), "select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
@ -324,12 +192,10 @@ func TestConnCopyFromLarge(t *testing.T) {
func TestConnCopyFromEnum(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
ctx := context.Background()
tx, err := conn.Begin(ctx)
require.NoError(t, err)
defer tx.Rollback(ctx)
@ -354,7 +220,7 @@ func TestConnCopyFromEnum(t *testing.T) {
conn.TypeMap().RegisterType(typ)
}
_, err = tx.Exec(ctx, `create temporary table foo(
_, err = tx.Exec(ctx, `create table foo(
a text,
b color,
c fruit,
@ -369,11 +235,11 @@ func TestConnCopyFromEnum(t *testing.T) {
{nil, nil, nil, nil, nil, nil},
}
copyCount, err := tx.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f"}, pgx.CopyFromRows(inputRows))
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f"}, pgx.CopyFromRows(inputRows))
require.NoError(t, err)
require.EqualValues(t, len(inputRows), copyCount)
rows, err := tx.Query(ctx, "select * from foo")
rows, err := conn.Query(ctx, "select * from foo")
require.NoError(t, err)
var outputRows [][]any
@ -389,18 +255,12 @@ func TestConnCopyFromEnum(t *testing.T) {
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
}
err = tx.Rollback(ctx)
require.NoError(t, err)
ensureConnValid(t, conn)
}
func TestConnCopyFromJSON(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
@ -420,7 +280,7 @@ func TestConnCopyFromJSON(t *testing.T) {
{nil, nil},
}
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
@ -428,7 +288,7 @@ func TestConnCopyFromJSON(t *testing.T) {
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err := conn.Query(ctx, "select * from foo")
rows, err := conn.Query(context.Background(), "select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
@ -478,9 +338,6 @@ func (cfs *clientFailSource) Err() error {
func TestConnCopyFromFailServerSideMidway(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
@ -495,7 +352,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
{int32(3), "def"},
}
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
if err == nil {
t.Errorf("Expected CopyFrom return error, but it did not")
}
@ -506,7 +363,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
}
rows, err := conn.Query(ctx, "select * from foo")
rows, err := conn.Query(context.Background(), "select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
@ -557,9 +414,6 @@ func (fs *failSource) Err() error {
func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
@ -571,7 +425,7 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
startTime := time.Now()
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &failSource{})
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &failSource{})
if err == nil {
t.Errorf("Expected CopyFrom return error, but it did not")
}
@ -588,7 +442,7 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
t.Errorf("Failing CopyFrom shouldn't have taken so long: %v", copyTime)
}
rows, err := conn.Query(ctx, "select * from foo")
rows, err := conn.Query(context.Background(), "select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
@ -637,9 +491,6 @@ func (fs *slowFailRaceSource) Err() error {
func TestConnCopyFromSlowFailRace(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
@ -648,7 +499,7 @@ func TestConnCopyFromSlowFailRace(t *testing.T) {
b bytea not null
)`)
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{})
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{})
if err == nil {
t.Errorf("Expected CopyFrom return error, but it did not")
}
@ -665,9 +516,6 @@ func TestConnCopyFromSlowFailRace(t *testing.T) {
func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
@ -675,7 +523,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
a bytea not null
)`)
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{})
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{})
if err == nil {
t.Errorf("Expected CopyFrom return error, but it did not")
}
@ -683,7 +531,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
}
rows, err := conn.Query(ctx, "select * from foo")
rows, err := conn.Query(context.Background(), "select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
@ -728,9 +576,6 @@ func (cfs *clientFinalErrSource) Err() error {
func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
@ -738,7 +583,7 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
a bytea not null
)`)
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{})
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{})
if err == nil {
t.Errorf("Expected CopyFrom return error, but it did not")
}
@ -746,7 +591,7 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
}
rows, err := conn.Query(ctx, "select * from foo")
rows, err := conn.Query(context.Background(), "select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
@ -770,125 +615,3 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
ensureConnValid(t, conn)
}
func TestConnCopyFromAutomaticStringConversion(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int8
)`)
inputRows := [][]interface{}{
{"42"},
{"7"},
{8},
}
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
require.NoError(t, err)
require.EqualValues(t, len(inputRows), copyCount)
rows, _ := conn.Query(ctx, "select * from foo")
nums, err := pgx.CollectRows(rows, pgx.RowTo[int64])
require.NoError(t, err)
require.Equal(t, []int64{42, 7, 8}, nums)
ensureConnValid(t, conn)
}
// https://github.com/jackc/pgx/discussions/1891
func TestConnCopyFromAutomaticStringConversionArray(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a numeric[]
)`)
inputRows := [][]interface{}{
{[]string{"42"}},
{[]string{"7"}},
{[]string{"8", "9"}},
{[][]string{{"10", "11"}, {"12", "13"}}},
}
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
require.NoError(t, err)
require.EqualValues(t, len(inputRows), copyCount)
// For simplicity, the test reads the values back as int64 with nested arrays flattened.
rows, _ := conn.Query(ctx, "select * from foo")
nums, err := pgx.CollectRows(rows, pgx.RowTo[[]int64])
require.NoError(t, err)
require.Equal(t, [][]int64{{42}, {7}, {8, 9}, {10, 11, 12, 13}}, nums)
ensureConnValid(t, conn)
}
func TestCopyFromFunc(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int
)`)
dataCh := make(chan int, 1)
const channelItems = 10
go func() {
for i := 0; i < channelItems; i++ {
dataCh <- i
}
close(dataCh)
}()
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"},
pgx.CopyFromFunc(func() ([]any, error) {
v, ok := <-dataCh
if !ok {
return nil, nil
}
return []any{v}, nil
}))
require.ErrorIs(t, err, nil)
require.EqualValues(t, channelItems, copyCount)
rows, err := conn.Query(context.Background(), "select * from foo order by a")
require.NoError(t, err)
nums, err := pgx.CollectRows(rows, pgx.RowTo[int64])
require.NoError(t, err)
require.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, nums)
// simulate a failure
copyCount, err = conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"},
pgx.CopyFromFunc(func() func() ([]any, error) {
x := 9
return func() ([]any, error) {
x++
if x > 100 {
return nil, fmt.Errorf("simulated error")
}
return []any{x}, nil
}
}()))
require.NotErrorIs(t, err, nil)
require.EqualValues(t, 0, copyCount) // no change, due to error
ensureConnValid(t, conn)
}


@ -1,256 +0,0 @@
package pgx
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/jackc/pgx/v5/pgtype"
)
/*
buildLoadDerivedTypesSQL generates the correct query for retrieving type information.
pgVersion: the major version of the PostgreSQL server
typeNames: the names of the types to load. If nil, load all types.
*/
func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string {
supportsMultirange := (pgVersion >= 14)
var typeNamesClause string
if typeNames == nil {
// This should not occur; this will not return any types
typeNamesClause = "= ''"
} else {
typeNamesClause = "= ANY($1)"
}
parts := make([]string, 0, 10)
// Each of the type names provided might be found in pg_class or pg_type.
// Additionally, it may or may not include a schema portion.
parts = append(parts, `
WITH RECURSIVE
-- find the OIDs in pg_class which match one of the provided type names
selected_classes(oid,reltype) AS (
-- this query uses the namespace search path, so will match type names without a schema prefix
SELECT pg_class.oid, pg_class.reltype
FROM pg_catalog.pg_class
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace
WHERE pg_catalog.pg_table_is_visible(pg_class.oid)
AND relname `, typeNamesClause, `
UNION ALL
-- this query will only match type names which include the schema prefix
SELECT pg_class.oid, pg_class.reltype
FROM pg_class
INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid)
WHERE nspname || '.' || relname `, typeNamesClause, `
),
selected_types(oid) AS (
-- collect the OIDs from pg_types which correspond to the selected classes
SELECT reltype AS oid
FROM selected_classes
UNION ALL
-- as well as any other type names which match our criteria
SELECT pg_type.oid
FROM pg_type
LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid)
WHERE typname `, typeNamesClause, `
OR nspname || '.' || typname `, typeNamesClause, `
),
-- this builds a parent/child mapping of objects, allowing us to know
-- all the child (ie: dependent) types that a parent (type) requires
-- As can be seen, there are 3 ways this can occur (the last of which
-- is due to being a composite class, where the composite fields are children)
pc(parent, child) AS (
SELECT parent.oid, parent.typelem
FROM pg_type parent
WHERE parent.typtype = 'b' AND parent.typelem != 0
UNION ALL
SELECT parent.oid, parent.typbasetype
FROM pg_type parent
WHERE parent.typtypmod = -1 AND parent.typbasetype != 0
UNION ALL
SELECT pg_type.oid, atttypid
FROM pg_attribute
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
WHERE NOT attisdropped
AND attnum > 0
),
-- Now construct a recursive query which includes a 'depth' element.
-- This is used to ensure that the "youngest" children are registered before
-- their parents.
relationships(parent, child, depth) AS (
SELECT DISTINCT 0::OID, selected_types.oid, 0
FROM selected_types
UNION ALL
SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1
FROM selected_classes c
inner join pg_type ON (c.reltype = pg_type.oid)
inner join pg_attribute on (c.oid = pg_attribute.attrelid)
UNION ALL
SELECT pc.parent, pc.child, relationships.depth + 1
FROM pc
INNER JOIN relationships ON (pc.parent = relationships.child)
),
-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration
composite AS (
SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids
FROM pg_attribute
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
WHERE NOT attisdropped
AND attnum > 0
GROUP BY pg_type.oid
)
-- Bring together this information, showing all the information which might possibly be required
-- to complete the registration, applying filters to only show the items which relate to the selected
-- types/classes.
SELECT typname,
pg_namespace.nspname,
typtype,
typbasetype,
typelem,
pg_type.oid,`)
if supportsMultirange {
parts = append(parts, `
COALESCE(multirange.rngtypid, 0) AS rngtypid,`)
} else {
parts = append(parts, `
0 AS rngtypid,`)
}
parts = append(parts, `
COALESCE(pg_range.rngsubtype, 0) AS rngsubtype,
attnames, atttypids
FROM relationships
INNER JOIN pg_type ON (pg_type.oid = relationships.child)
LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`)
if supportsMultirange {
parts = append(parts, `
LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`)
}
parts = append(parts, `
LEFT OUTER JOIN composite USING (oid)
LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid)
WHERE NOT (typtype = 'b' AND typelem = 0)`)
parts = append(parts, `
GROUP BY typname, pg_namespace.nspname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`)
if supportsMultirange {
parts = append(parts, `
multirange.rngtypid,`)
}
parts = append(parts, `
attnames, atttypids
ORDER BY MAX(depth) desc, typname;`)
return strings.Join(parts, "")
}
type derivedTypeInfo struct {
Oid, Typbasetype, Typelem, Rngsubtype, Rngtypid uint32
TypeName, Typtype, NspName string
Attnames []string
Atttypids []uint32
}
// LoadTypes performs a single (complex) query, returning all the required
// information to register the named types, as well as any other types directly
// or indirectly required to complete the registration.
// The result of this call can be passed into RegisterTypes to complete the process.
func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) {
m := c.TypeMap()
if len(typeNames) == 0 {
return nil, fmt.Errorf("No type names were supplied.")
}
// Disregard server version errors. This will result in the SQL not supporting
// recent structures such as multirange.
serverVersion, _ := serverVersion(c)
sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
if err != nil {
return nil, fmt.Errorf("While generating load types query: %w", err)
}
defer rows.Close()
result := make([]*pgtype.Type, 0, 100)
for rows.Next() {
ti := derivedTypeInfo{}
err = rows.Scan(&ti.TypeName, &ti.NspName, &ti.Typtype, &ti.Typbasetype, &ti.Typelem, &ti.Oid, &ti.Rngtypid, &ti.Rngsubtype, &ti.Attnames, &ti.Atttypids)
if err != nil {
return nil, fmt.Errorf("While scanning type information: %w", err)
}
var type_ *pgtype.Type
switch ti.Typtype {
case "b": // array
dt, ok := m.TypeForOID(ti.Typelem)
if !ok {
return nil, fmt.Errorf("Array element OID %v not registered while loading pgtype %q", ti.Typelem, ti.TypeName)
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}
case "c": // composite
var fields []pgtype.CompositeCodecField
for i, fieldName := range ti.Attnames {
dt, ok := m.TypeForOID(ti.Atttypids[i])
if !ok {
return nil, fmt.Errorf("Unknown field for composite type %q: field %q (OID %v) is not already registered.", ti.TypeName, fieldName, ti.Atttypids[i])
}
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.CompositeCodec{Fields: fields}}
case "d": // domain
dt, ok := m.TypeForOID(ti.Typbasetype)
if !ok {
return nil, fmt.Errorf("Domain base type OID %v was not already registered, needed for %q", ti.Typbasetype, ti.TypeName)
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: dt.Codec}
case "e": // enum
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.EnumCodec{}}
case "r": // range
dt, ok := m.TypeForOID(ti.Rngsubtype)
if !ok {
return nil, fmt.Errorf("Range element OID %v was not already registered, needed for %q", ti.Rngsubtype, ti.TypeName)
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.RangeCodec{ElementType: dt}}
case "m": // multirange
dt, ok := m.TypeForOID(ti.Rngtypid)
if !ok {
return nil, fmt.Errorf("Multirange element OID %v was not already registered, needed for %q", ti.Rngtypid, ti.TypeName)
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}
default:
return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
}
// type_ cannot be nil here; every switch branch either assigns it or returns an error.
m.RegisterType(type_)
if ti.NspName != "" {
nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
m.RegisterType(nspType)
result = append(result, nspType)
}
result = append(result, type_)
}
return result, nil
}
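// Usage sketch (illustrative, not part of this file; "mytype" is a placeholder):
// types loaded on one connection can be registered on another connection's
// type map, since LoadTypes has already registered them on its own.
//
//	types, err := conn.LoadTypes(ctx, []string{"mytype"})
//	if err != nil {
//		return err
//	}
//	for _, t := range types {
//		otherConn.TypeMap().RegisterType(t)
//	}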
// serverVersion returns the postgresql server version.
func serverVersion(c *Conn) (int64, error) {
serverVersionStr := c.PgConn().ParameterStatus("server_version")
serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
// if no leading digits were found (e.g. not PostgreSQL), report an error
if serverVersionStr == "" {
return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr)
}
version, err := strconv.ParseInt(serverVersionStr, 10, 64)
if err != nil {
return 0, fmt.Errorf("postgres version parsing failed: %w", err)
}
return version, nil
}


@ -1,40 +0,0 @@
package pgx_test
import (
"context"
"testing"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/require"
)
func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `
drop type if exists dtype_test;
drop domain if exists anotheruint64;
create domain anotheruint64 as numeric(20,0);
create type dtype_test as (
a text,
b int4,
c anotheruint64,
d anotheruint64[]
);`)
require.NoError(t, err)
defer conn.Exec(ctx, "drop type dtype_test")
defer conn.Exec(ctx, "drop domain anotheruint64")
types, err := conn.LoadTypes(ctx, []string{"dtype_test"})
require.NoError(t, err)
require.Len(t, types, 6)
require.Equal(t, types[0].Name, "public.anotheruint64")
require.Equal(t, types[1].Name, "anotheruint64")
require.Equal(t, types[2].Name, "public._anotheruint64")
require.Equal(t, types[3].Name, "_anotheruint64")
require.Equal(t, types[4].Name, "public.dtype_test")
require.Equal(t, types[5].Name, "dtype_test")
})
}

doc.go

@ -7,25 +7,24 @@ details.
Establishing a Connection
The primary way of establishing a connection is with [pgx.Connect]:
The primary way of establishing a connection is with `pgx.Connect`.
conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
The database connection string can be in URL or key/value format. Both PostgreSQL settings and pgx settings can be
specified here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the
connection with [ConnectConfig] to configure settings such as tracing that cannot be configured with a connection
string.
The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified
here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with
`ConnectConfig` to configure settings such as tracing that cannot be configured with a connection string.
Connection Pool
[*pgx.Conn] represents a single connection to the database and is not concurrency safe. Use package
`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use package
github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool.
Query Interface
pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and
ForEachRow that are a simpler and safer way of processing rows than manually calling defer rows.Close(), rows.Next(),
rows.Scan, and rows.Err().
ForEachRow that are a simpler and safer way of processing rows than manually calling rows.Next(), rows.Scan, and
rows.Err().
CollectRows can be used to collect all returned rows into a slice.
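For example (a short sketch; error handling abbreviated):

	rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 5)
	numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32])
	if err != nil {
		return err
	}
	// numbers is []int32{1, 2, 3, 4, 5}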
@ -41,7 +40,7 @@ directly.
var sum, n int32
rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
_, err := pgx.ForEachRow(rows, []any{&n}, func() error {
_, err := pgx.ForEachRow(rows, []any{&n}, func(pgx.QueryFuncRow) error {
sum += n
return nil
})
@ -70,9 +69,8 @@ Use Exec to execute a query that does not return a result set.
PostgreSQL Data Types
pgx uses the pgtype package to converting Go values to and from PostgreSQL values. It supports many PostgreSQL types
directly and is customizable and extendable. User defined data types such as enums, domains, and composite types may
require type registration. See that package's documentation for details.
The package pgtype provides extensive and customizable support for converting Go values to and from PostgreSQL values
including array and composite types. See that package's documentation for details.
Transactions
@ -99,8 +97,7 @@ Transactions are started by calling Begin.
The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions.
These are internally implemented with savepoints.
Use BeginTx to control the transaction mode. BeginTx also can be used to ensure a new transaction is created instead of
a pseudo nested transaction.
Use BeginTx to control the transaction mode.
BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the
transaction depending on the return value of the function. These can be simpler and less error prone to use.
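For example, a sketch using BeginFunc (assuming a table foo with an integer column a):

	err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
		_, err := tx.Exec(context.Background(), "insert into foo(a) values (1)")
		return err
	})
	if err != nil {
		return err
	}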
@ -163,19 +160,17 @@ notification is received or the context is canceled.
_, err := conn.Exec(context.Background(), "listen channelname")
if err != nil {
return err
return nil
}
notification, err := conn.WaitForNotification(context.Background())
if err != nil {
return err
if notification, err := conn.WaitForNotification(context.Background()); err != nil {
// do something with notification
}
// do something with notification
Tracing and Logging
pgx supports tracing by setting ConnConfig.Tracer. To combine several tracers you can use the multitracer.Tracer.
pgx supports tracing by setting ConnConfig.Tracer.
In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer.
@ -183,12 +178,12 @@ For debug tracing of the actual PostgreSQL wire protocol messages see github.com
Lower Level PostgreSQL Functionality
github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn is
github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn in
implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer.
PgBouncer
By default pgx automatically uses prepared statements. Prepared statements are incompatible with PgBouncer. This can be
By default pgx automatically uses prepared statements. Prepared statements are incompaptible with PgBouncer. This can be
disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode.
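For example, a sketch of selecting a different mode at connection time:

	config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
	if err != nil {
		return err
	}
	config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
	conn, err := pgx.ConnectConfig(context.Background(), config)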
*/
package pgx


@ -2,7 +2,7 @@ package main
import (
"context"
"io"
"io/ioutil"
"log"
"net/http"
"os"
@ -29,7 +29,7 @@ func getUrlHandler(w http.ResponseWriter, req *http.Request) {
func putUrlHandler(w http.ResponseWriter, req *http.Request) {
id := req.URL.Path
var url string
if body, err := io.ReadAll(req.Body); err == nil {
if body, err := ioutil.ReadAll(req.Body); err == nil {
url = string(body)
} else {
http.Error(w, "Internal server error", http.StatusInternalServerError)


@ -3,6 +3,7 @@ package pgx
import (
"fmt"
"github.com/jackc/pgx/v5/internal/anynil"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
)
@ -21,15 +22,10 @@ type ExtendedQueryBuilder struct {
func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
eqb.reset()
anynil.NormalizeSlice(args)
if sd == nil {
for i := range args {
err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i])
if err != nil {
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
return err
}
}
return nil
return eqb.appendParamsForQueryExecModeExec(m, args)
}
if len(sd.ParamOIDs) != len(args) {
@ -39,7 +35,7 @@ func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescri
for i := range args {
err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i])
if err != nil {
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
err = fmt.Errorf("failed to encode args[%d]: %v", i, err)
return err
}
}
@ -55,33 +51,14 @@ func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescri
// must be an untyped nil.
func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error {
if format == -1 {
preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg)
preferredErr := eqb.appendParam(m, oid, preferredFormat, arg)
if preferredErr == nil {
return nil
}
var otherFormat int16
if preferredFormat == TextFormatCode {
otherFormat = BinaryFormatCode
} else {
otherFormat = TextFormatCode
}
otherErr := eqb.appendParam(m, oid, otherFormat, arg)
if otherErr == nil {
return nil
}
return preferredErr // return the error from the preferred format
format = eqb.chooseParameterFormatCode(m, oid, arg)
}
eqb.ParamFormats = append(eqb.ParamFormats, format)
v, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
if err != nil {
return err
}
eqb.ParamFormats = append(eqb.ParamFormats, format)
eqb.ParamValues = append(eqb.ParamValues, v)
return nil
@ -116,6 +93,10 @@ func (eqb *ExtendedQueryBuilder) reset() {
}
func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
if anynil.Is(arg) {
return nil, nil
}
if eqb.paramValueBytes == nil {
eqb.paramValueBytes = make([]byte, 0, 128)
}
@ -144,3 +125,61 @@ func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid ui
return m.FormatCodeForOID(oid)
}
// appendParamsForQueryExecModeExec appends the args to eqb.
//
// Parameters must be encoded in the text format because of differences in type conversion between timestamps and
// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the
// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both
// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL
// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date.
// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion
// before converting it to date. This means that dates can be shifted by one day. In text format without that double
// type conversion it takes the date directly and ignores time zone (i.e. it works).
//
// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is
// no way to safely use binary or to specify the parameter OIDs.
func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error {
for _, arg := range args {
if arg == nil {
err := eqb.appendParam(m, 0, TextFormatCode, arg)
if err != nil {
return err
}
} else {
dt, ok := m.TypeForValue(arg)
if !ok {
var tv pgtype.TextValuer
if tv, ok = arg.(pgtype.TextValuer); ok {
t, err := tv.TextValue()
if err != nil {
return err
}
dt, ok = m.TypeForOID(pgtype.TextOID)
if ok {
arg = t
}
}
}
if !ok {
var str fmt.Stringer
if str, ok = arg.(fmt.Stringer); ok {
dt, ok = m.TypeForOID(pgtype.TextOID)
if ok {
arg = str.String()
}
}
}
if !ok {
return &unknownArgumentTypeQueryExecModeExecError{arg: arg}
}
err := eqb.appendParam(m, dt.OID, TextFormatCode, arg)
if err != nil {
return err
}
}
}
return nil
}

go.mod

@ -1,21 +1,20 @@
module github.com/jackc/pgx/v5
go 1.23.0
go 1.18
require (
github.com/jackc/pgpassfile v1.0.0
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761
github.com/jackc/puddle/v2 v2.2.2
github.com/stretchr/testify v1.8.1
golang.org/x/crypto v0.37.0
golang.org/x/sync v0.13.0
golang.org/x/text v0.24.0
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b
github.com/jackc/puddle/v2 v2.0.0-beta.1
github.com/stretchr/testify v1.8.0
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
golang.org/x/text v0.3.7
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/kr/pretty v0.3.0 // indirect
github.com/kr/pretty v0.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

go.sum

@ -1,45 +1,36 @@
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg=
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
github.com/jackc/puddle/v2 v2.0.0-beta.1 h1:Y4Ao+kFWANtDhWUkdw1JcbH+x84/aq6WUfhVQ1wdib8=
github.com/jackc/puddle/v2 v2.0.0-beta.1/go.mod h1:itE7ZJY8xnoo0JqJEpSMprN0f+NQkMCuEV/N9j8h0oc=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b h1:QAqMVf3pSa6eeTsuklijukjXBlj7Es2QQplab+/RbQ4=
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=


@ -79,7 +79,7 @@ func ensureConnValid(t testing.TB, conn *pgx.Conn) {
}
if rows.Err() != nil {
t.Fatalf("conn.Query failed: %v", rows.Err())
t.Fatalf("conn.Query failed: %v", err)
}
if rowCount != 10 {

internal/anynil/anynil.go

@ -0,0 +1,36 @@
package anynil
import "reflect"
// Is returns true if value is any type of nil, e.g. nil or []byte(nil).
func Is(value any) bool {
if value == nil {
return true
}
refVal := reflect.ValueOf(value)
switch refVal.Kind() {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
return refVal.IsNil()
default:
return false
}
}
// Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified.
func Normalize(v any) any {
if Is(v) {
return nil
}
return v
}
// NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is
// mutated in place.
func NormalizeSlice(s []any) {
for i := range s {
if Is(s[i]) {
s[i] = nil
}
}
}
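// Usage sketch (illustrative, not part of this file):
//
//	var b []byte                // a typed nil
//	anynil.Is(b)                // true: nil slice inside an interface
//	v := anynil.Normalize(b)    // v is untyped nil
//	args := []any{b, "x"}
//	anynil.NormalizeSlice(args) // args[0] is now untyped nil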


@ -1,7 +1,4 @@
// Package iobufpool implements a global segregated-fit pool of buffers for IO.
//
// It uses *[]byte instead of []byte to avoid the sync.Pool allocation with Put. Unfortunately, using a pointer to avoid
// an allocation is purposely not documented. https://github.com/golang/go/issues/16323
package iobufpool
import "sync"
@ -13,27 +10,17 @@ var pools [18]*sync.Pool
func init() {
for i := range pools {
bufLen := 1 << (minPoolExpOf2 + i)
pools[i] = &sync.Pool{
New: func() any {
buf := make([]byte, bufLen)
return &buf
},
}
pools[i] = &sync.Pool{New: func() any { return make([]byte, bufLen) }}
}
}
// Get gets a []byte of len size with cap <= size*2.
func Get(size int) *[]byte {
func Get(size int) []byte {
i := getPoolIdx(size)
if i >= len(pools) {
buf := make([]byte, size)
return &buf
return make([]byte, size)
}
ptrBuf := (pools[i].Get().(*[]byte))
*ptrBuf = (*ptrBuf)[:size]
return ptrBuf
return pools[i].Get().([]byte)[:size]
}
func getPoolIdx(size int) int {
@ -49,8 +36,8 @@ func getPoolIdx(size int) int {
}
// Put returns buf to the pool.
func Put(buf *[]byte) {
i := putPoolIdx(cap(*buf))
func Put(buf []byte) {
i := putPoolIdx(cap(buf))
if i < 0 {
return
}
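// Why the master side pools *[]byte (illustrative sketch, not part of this
// file): storing a []byte in a sync.Pool converts the three-word slice header
// to an interface value, which heap-allocates on every Put. A *[]byte is
// pointer-shaped, so it fits in the interface without allocating.
//
//	var p = sync.Pool{New: func() any { b := make([]byte, 1024); return &b }}
//	buf := p.Get().(*[]byte)
//	// ... use *buf ...
//	p.Put(buf) // no slice-header allocation here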


@ -30,15 +30,15 @@ func TestGetCap(t *testing.T) {
}
for _, tt := range tests {
buf := iobufpool.Get(tt.requestedLen)
assert.Equalf(t, tt.requestedLen, len(*buf), "bad len for requestedLen: %d", len(*buf), tt.requestedLen)
assert.Equalf(t, tt.expectedCap, cap(*buf), "bad cap for requestedLen: %d", tt.requestedLen)
assert.Equalf(t, tt.requestedLen, len(buf), "bad len for requestedLen: %d", len(buf), tt.requestedLen)
assert.Equalf(t, tt.expectedCap, cap(buf), "bad cap for requestedLen: %d", tt.requestedLen)
}
}
func TestPutHandlesWrongSizedBuffers(t *testing.T) {
for putBufSize := range []int{0, 1, 128, 250, 256, 257, 1023, 1024, 1025, 1 << 28} {
putBuf := make([]byte, putBufSize)
iobufpool.Put(&putBuf)
iobufpool.Put(putBuf)
tests := []struct {
requestedLen int
@ -62,8 +62,8 @@ func TestPutHandlesWrongSizedBuffers(t *testing.T) {
}
for _, tt := range tests {
getBuf := iobufpool.Get(tt.requestedLen)
assert.Equalf(t, tt.requestedLen, len(*getBuf), "len(putBuf): %d, requestedLen: %d", len(putBuf), tt.requestedLen)
assert.Equalf(t, tt.expectedCap, cap(*getBuf), "cap(putBuf): %d, requestedLen: %d", cap(putBuf), tt.requestedLen)
assert.Equalf(t, tt.requestedLen, len(getBuf), "len(putBuf): %d, requestedLen: %d", len(putBuf), tt.requestedLen)
assert.Equalf(t, tt.expectedCap, cap(getBuf), "cap(putBuf): %d, requestedLen: %d", cap(putBuf), tt.requestedLen)
}
}
}
@ -73,10 +73,10 @@ func TestPutGetBufferReuse(t *testing.T) {
// it not to be. So try many times.
for i := 0; i < 100000; i++ {
buf := iobufpool.Get(4)
(*buf)[0] = 1
buf[0] = 1
iobufpool.Put(buf)
buf = iobufpool.Get(4)
if (*buf)[0] == 1 {
if buf[0] == 1 {
return
}
}


@ -0,0 +1,70 @@
package nbconn
import (
"sync"
)
const minBufferQueueLen = 8
type bufferQueue struct {
lock sync.Mutex
queue [][]byte
r, w int
}
func (bq *bufferQueue) pushBack(buf []byte) {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.w >= len(bq.queue) {
bq.growQueue()
}
bq.queue[bq.w] = buf
bq.w++
}
func (bq *bufferQueue) pushFront(buf []byte) {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.w >= len(bq.queue) {
bq.growQueue()
}
copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w])
bq.queue[bq.r] = buf
bq.w++
}
func (bq *bufferQueue) popFront() []byte {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.r == bq.w {
return nil
}
buf := bq.queue[bq.r]
bq.queue[bq.r] = nil // Clear reference so it can be garbage collected.
bq.r++
if bq.r == bq.w {
bq.r = 0
bq.w = 0
if len(bq.queue) > minBufferQueueLen {
bq.queue = make([][]byte, minBufferQueueLen)
}
}
return buf
}
func (bq *bufferQueue) growQueue() {
desiredLen := (len(bq.queue) + 1) * 3 / 2
if desiredLen < minBufferQueueLen {
desiredLen = minBufferQueueLen
}
newQueue := make([][]byte, desiredLen)
copy(newQueue, bq.queue)
bq.queue = newQueue
}

internal/nbconn/nbconn.go

@ -0,0 +1,536 @@
// Package nbconn implements a non-blocking net.Conn wrapper.
//
// It is designed to solve three problems.
//
// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all
// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion.
//
// The second is the inability to use a write deadline with a TLS.Conn without killing the connection.
//
// The third is to efficiently check if a connection has been closed via a non-blocking read.
package nbconn
import (
"crypto/tls"
"errors"
"io"
"net"
"os"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/jackc/pgx/v5/internal/iobufpool"
)
var errClosed = errors.New("closed")
var ErrWouldBlock = new(wouldBlockError)
const fakeNonblockingWaitDuration = 100 * time.Millisecond
// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read
// mode.
var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC)
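// For example (sketch, mirroring TestNonBlockingRead in the nbconn tests):
//
//	conn.SetReadDeadline(NonBlockingDeadline)
//	n, err := conn.Read(buf)          // returns ErrWouldBlock instead of blocking
//	conn.SetReadDeadline(time.Time{}) // restore normal blocking reads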
// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to
// ignore all future calls.
var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC)
// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error.
type wouldBlockError struct{}
func (*wouldBlockError) Error() string {
return "would block"
}
func (*wouldBlockError) Timeout() bool { return true }
func (*wouldBlockError) Temporary() bool { return true }
// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to
// the underlying connection.
type Conn interface {
net.Conn
// Flush flushes any buffered writes.
Flush() error
// BufferReadUntilBlock reads and buffers any successfully read bytes until the read would block.
BufferReadUntilBlock() error
}
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
type NetConn struct {
conn net.Conn
rawConn syscall.RawConn
readQueue bufferQueue
writeQueue bufferQueue
readFlushLock sync.Mutex
// non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the
// callback function's closure to pass the buf argument and receive the n and err results, we avoid some allocations.
nonblockWriteBuf []byte
nonblockWriteErr error
nonblockWriteN int
readDeadlineLock sync.Mutex
readDeadline time.Time
readNonblocking bool
writeDeadlineLock sync.Mutex
writeDeadline time.Time
// Only access with atomics
closed int64 // 0 = not closed, 1 = closed
}
func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
nc := &NetConn{
conn: conn,
}
if !fakeNonBlockingIO {
if sc, ok := conn.(syscall.Conn); ok {
if rawConn, err := sc.SyscallConn(); err == nil {
nc.rawConn = rawConn
}
}
}
return nc
}
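// Usage sketch (illustrative, not part of this file):
//
//	nc := NewNetConn(tcpConn, false) // real non-blocking IO where supported
//	_, _ = nc.Write(data)            // buffered; never blocks
//	err := nc.Flush()                // actually writes to the underlying conn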
// Read implements io.Reader.
func (c *NetConn) Read(b []byte) (n int, err error) {
if c.isClosed() {
return 0, errClosed
}
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
err = c.flush()
if err != nil {
return 0, err
}
for n < len(b) {
buf := c.readQueue.popFront()
if buf == nil {
break
}
copiedN := copy(b[n:], buf)
if copiedN < len(buf) {
buf = buf[copiedN:]
c.readQueue.pushFront(buf)
} else {
iobufpool.Put(buf)
}
n += copiedN
}
// If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to
// Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block.
if n > 0 {
return n, nil
}
var readNonblocking bool
c.readDeadlineLock.Lock()
readNonblocking = c.readNonblocking
c.readDeadlineLock.Unlock()
var readN int
if readNonblocking {
readN, err = c.nonblockingRead(b[n:])
} else {
readN, err = c.conn.Read(b[n:])
}
n += readN
return n, err
}
// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is
// closed. Call Flush to actually write to the underlying connection.
func (c *NetConn) Write(b []byte) (n int, err error) {
if c.isClosed() {
return 0, errClosed
}
buf := iobufpool.Get(len(b))
copy(buf, b)
c.writeQueue.pushBack(buf)
return len(b), nil
}
func (c *NetConn) Close() (err error) {
swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1)
if !swapped {
return errClosed
}
defer func() {
closeErr := c.conn.Close()
if err == nil {
err = closeErr
}
}()
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
err = c.flush()
if err != nil {
return err
}
return nil
}
func (c *NetConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *NetConn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetDeadline is the equivalent of calling SetReadDeadline(t) and SetWriteDeadline(t).
func (c *NetConn) SetDeadline(t time.Time) error {
err := c.SetReadDeadline(t)
if err != nil {
return err
}
return c.SetWriteDeadline(t)
}
// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking.
func (c *NetConn) SetReadDeadline(t time.Time) error {
if c.isClosed() {
return errClosed
}
c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()
if c.readDeadline == disableSetDeadlineDeadline {
return nil
}
if t == disableSetDeadlineDeadline {
c.readDeadline = t
return nil
}
if t == NonBlockingDeadline {
c.readNonblocking = true
t = time.Time{}
} else {
c.readNonblocking = false
}
c.readDeadline = t
return c.conn.SetReadDeadline(t)
}
func (c *NetConn) SetWriteDeadline(t time.Time) error {
if c.isClosed() {
return errClosed
}
c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()
if c.writeDeadline == disableSetDeadlineDeadline {
return nil
}
if t == disableSetDeadlineDeadline {
c.writeDeadline = t
return nil
}
c.writeDeadline = t
return c.conn.SetWriteDeadline(t)
}
func (c *NetConn) Flush() error {
if c.isClosed() {
return errClosed
}
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
return c.flush()
}
// flush does the actual work of flushing the writeQueue. readFlushLock must already be held.
func (c *NetConn) flush() error {
var stopChan chan struct{}
var errChan chan error
defer func() {
if stopChan != nil {
select {
case stopChan <- struct{}{}:
case <-errChan:
}
}
}()
for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() {
remainingBuf := buf
for len(remainingBuf) > 0 {
n, err := c.nonblockingWrite(remainingBuf)
remainingBuf = remainingBuf[n:]
if err != nil {
if !errors.Is(err, ErrWouldBlock) {
buf = buf[:len(remainingBuf)]
copy(buf, remainingBuf)
c.writeQueue.pushFront(buf)
return err
}
// Writing was blocked. Reading might unblock it.
if stopChan == nil {
stopChan, errChan = c.bufferNonblockingRead()
}
select {
case err := <-errChan:
stopChan = nil
return err
default:
}
}
}
iobufpool.Put(buf)
}
return nil
}
func (c *NetConn) BufferReadUntilBlock() error {
for {
buf := iobufpool.Get(8 * 1024)
n, err := c.nonblockingRead(buf)
if n > 0 {
buf = buf[:n]
c.readQueue.pushBack(buf)
}
if err != nil {
if errors.Is(err, ErrWouldBlock) {
return nil
} else {
return err
}
}
}
}
func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
stopChan = make(chan struct{})
errChan = make(chan error, 1)
go func() {
for {
err := c.BufferReadUntilBlock()
if err != nil {
errChan <- err
return
}
select {
case <-stopChan:
return
default:
}
}
}()
return stopChan, errChan
}
func (c *NetConn) isClosed() bool {
closed := atomic.LoadInt64(&c.closed)
return closed == 1
}
func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
if c.rawConn == nil {
return c.fakeNonblockingWrite(b)
} else {
return c.realNonblockingWrite(b)
}
}
func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()
deadline := time.Now().Add(fakeNonblockingWaitDuration)
if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) {
err = c.conn.SetWriteDeadline(deadline)
if err != nil {
return 0, err
}
defer func() {
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
c.conn.SetWriteDeadline(c.writeDeadline)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = ErrWouldBlock
}
}
}()
}
return c.conn.Write(b)
}
// realNonblockingWrite does a non-blocking write. readFlushLock must already be held.
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
c.nonblockWriteBuf = b
c.nonblockWriteN = 0
c.nonblockWriteErr = nil
err = c.rawConn.Write(func(fd uintptr) (done bool) {
c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf)
return true
})
n = c.nonblockWriteN
if err == nil && c.nonblockWriteErr != nil {
if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) {
err = ErrWouldBlock
} else {
err = c.nonblockWriteErr
}
}
if err != nil {
// n may be -1 when an error occurs.
if n < 0 {
n = 0
}
return n, err
}
return n, nil
}
func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
if c.rawConn == nil {
return c.fakeNonblockingRead(b)
} else {
return c.realNonblockingRead(b)
}
}
func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()
deadline := time.Now().Add(fakeNonblockingWaitDuration)
if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) {
err = c.conn.SetReadDeadline(deadline)
if err != nil {
return 0, err
}
defer func() {
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
c.conn.SetReadDeadline(c.readDeadline)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = ErrWouldBlock
}
}
}()
}
return c.conn.Read(b)
}
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
var funcErr error
err = c.rawConn.Read(func(fd uintptr) (done bool) {
n, funcErr = syscall.Read(int(fd), b)
return true
})
if err == nil && funcErr != nil {
if errors.Is(funcErr, syscall.EWOULDBLOCK) {
err = ErrWouldBlock
} else {
err = funcErr
}
}
if err != nil {
// n may be -1 when an error occurs.
if n < 0 {
n = 0
}
return n, err
}
// A syscall read that returns no error and 0 bytes means EOF.
if n == 0 {
return 0, io.EOF
}
return n, nil
}
// syscall.Conn is an interface.
// TLSClient establishes a TLS connection as a client over conn using config.
//
// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby
// potentially causing read and write deadlines to behave unexpectedly, Handshake is called explicitly before the
// *TLSConn is returned.
func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) {
tc := tls.Client(conn, config)
err := tc.Handshake()
if err != nil {
return nil, err
}
// Ensure last written part of Handshake is actually sent.
err = conn.Flush()
if err != nil {
return nil, err
}
return &TLSConn{
tlsConn: tc,
nbConn: conn,
}, nil
}
// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a
// tls.Conn.
type TLSConn struct {
tlsConn *tls.Conn
nbConn *NetConn
}
func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) }
func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) }
func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() }
func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() }
func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() }
func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() }
func (tc *TLSConn) Close() error {
// tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then
// sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our
// own 5 second deadline then make all set deadlines no-op.
tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5))
tc.tlsConn.SetDeadline(disableSetDeadlineDeadline)
return tc.tlsConn.Close()
}
func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) }
func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) }
func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) }


@ -0,0 +1,584 @@
package nbconn_test
import (
"crypto/tls"
"errors"
"io"
"net"
"strings"
"testing"
"time"
"github.com/jackc/pgx/v5/internal/nbconn"
"github.com/stretchr/testify/require"
)
// Test keys generated with:
//
// $ openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -sha256 -nodes -days 20000 -subj '/CN=localhost'
var testTLSPublicKey = []byte(`-----BEGIN CERTIFICATE-----
MIICpjCCAY4CCQCjQKYdUDQzKDANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls
b2NhbGhvc3QwIBcNMjIwNjA0MTY1MzE2WhgPMjA3NzAzMDcxNjUzMTZaMBQxEjAQ
BgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
ALHbOu80cfSPufKTZsKf3E5rCXHeIHjaIbgHEXA2SW/n77U8oZX518s+27FO0sK5
yA0WnEIwY34PU359sNR5KelARGnaeh3HdaGm1nuyyxBtwwAqIuM0UxGAMF/mQ4lT
caZPxG+7WlYDqnE3eVXUtG4c+T7t5qKAB3MtfbzKFSjczkWkroi6cTypmHArGghT
0VWWVu0s9oNp5q8iWchY2o9f0aIjmKv6FgtilO+geev+4U+QvtvrziR5BO3/3EgW
c5TUVcf+lwkvp8ziXvargmjjnNTyeF37y4KpFcex0v7z7hSrUK4zU0+xRn7Bp17v
7gzj0xN+HCsUW1cjPFNezX0CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAbEBzewzg
Z5F+BqMSxP3HkMCkLLH0N9q0/DkZaVyZ38vrjcjaDYuabq28kA2d5dc5jxsQpvTw
HTGqSv1ZxJP3pBFv6jLSh8xaM6tUkk482Q6DnZGh97CD4yup/yJzkn5nv9OHtZ9g
TnaQeeXgOz0o5Zq9IpzHJb19ysya3UCIK8oKXbSO4Qd168seCq75V2BFHDpmejjk
D92eT6WODlzzvZbhzA1F3/cUilZdhbQtJMqdecKvD+yrBpzGVqzhWQsXwsRAU1fB
hShx+D14zUGM2l4wlVzOAuGh4ZL7x3AjJsc86TsCavTspS0Xl51j+mRbiULq7G7Y
E7ZYmaKTMOhvkg==
-----END CERTIFICATE-----`)
// The strings.ReplaceAll is used to placate any secret scanners that would squawk if they saw a private key embedded in
// source code.
var testTLSPrivateKey = []byte(strings.ReplaceAll(`-----BEGIN TESTING KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCx2zrvNHH0j7ny
k2bCn9xOawlx3iB42iG4BxFwNklv5++1PKGV+dfLPtuxTtLCucgNFpxCMGN+D1N+
fbDUeSnpQERp2nodx3WhptZ7sssQbcMAKiLjNFMRgDBf5kOJU3GmT8Rvu1pWA6px
N3lV1LRuHPk+7eaigAdzLX28yhUo3M5FpK6IunE8qZhwKxoIU9FVllbtLPaDaeav
IlnIWNqPX9GiI5ir+hYLYpTvoHnr/uFPkL7b684keQTt/9xIFnOU1FXH/pcJL6fM
4l72q4Jo45zU8nhd+8uCqRXHsdL+8+4Uq1CuM1NPsUZ+wade7+4M49MTfhwrFFtX
IzxTXs19AgMBAAECggEBAJcHt5ARVQN8WUbobMawwX/F3QtYuPJnKWMAfYpwTwQ8
TI32orCcrObmxeBXMxowcPTMUnzSYmpV0W0EhvimuzRbYr0Qzcoj6nwPFOuN9GpL
CuBE58NQV4nw9SM6gfdHaKb17bWDvz5zdnUVym9cZKts5yrNEqDDX5Aq/S8n27gJ
/qheXwSxwETVO6kMEW1ndNIWDP8DPQ0E4O//RuMZwxpnZdnjGKkdVNy8I1BpgDgn
lwgkE3H3IciASki1GYXoyvrIiRwMQVzvYD2zcgwK9OZSjZe0TGwAGa+eQdbs3A9I
Ir1kYn6ZMGMRFJA2XHJW3hMZdWB/t2xMBGy75Uv9sAECgYEA1o+oRUYwwQ1MwBo9
YA6c00KjhFgrjdzyKPQrN14Q0dw5ErqRkhp2cs7BRdCDTDrjAegPc3Otg7uMa1vp
RgU/C72jwzFLYATvn+RLGRYRyqIE+bQ22/lLnXTrp4DCfdMrqWuQbIYouGHqfQrq
MfdtSUpQ6VZCi9zHehXOYwBMvQECgYEA1DTQFpe+tndIFmguxxaBwDltoPh5omzd
3vA7iFct2+UYk5W9shfAekAaZk2WufKmmC3OfBWYyIaJ7QwQpuGDS3zwjy6WFMTE
Otp2CypFCVahwHcvn2jYHmDMT0k0Pt6X2S3GAyWTyEPv7mAfKR1OWUYi7ZgdXpt0
TtL3Z3JyhH0CgYEAwveHUGuXodUUCPvPCZo9pzrGm1wDN8WtxskY/Bbd8dTLh9lA
riKdv3Vg6q+un3ZjETht0dsrsKib0HKUZqwdve11AcmpVHcnx4MLOqBzSk4vdzfr
IbhGna3A9VRrZyqcYjb75aGDHwjaqwVgCkdrZ03AeEeJ8M2N9cIa6Js9IAECgYBu
nlU24cVdspJWc9qml3ntrUITnlMxs1R5KXuvF9rk/OixzmYDV1RTpeTdHWcL6Yyk
WYSAtHVfWpq9ggOQKpBZonh3+w3rJ6MvFsBgE5nHQ2ywOrENhQbb1xPJ5NwiRcCc
Srsk2srNo3SIK30y3n8AFIqSljABKEIZ8Olc+JDvtQKBgQCiKz43zI6a0HscgZ77
DCBduWP4nk8BM7QTFxs9VypjrylMDGGtTKHc5BLA5fNZw97Hb7pcicN7/IbUnQUD
pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG
2qWm8jTPeDC3sq+67s2oojHf+Q==
-----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY"))
func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) {
for _, tt := range []struct {
name string
makeConns func(t *testing.T) (local, remote net.Conn)
useTLS bool
fakeNonBlockingIO bool
}{
{
name: "Pipe",
makeConns: makePipeConns,
useTLS: false,
fakeNonBlockingIO: true,
},
{
name: "TCP with Fake Non-blocking IO",
makeConns: makeTCPConns,
useTLS: false,
fakeNonBlockingIO: true,
},
{
name: "TLS over TCP with Fake Non-blocking IO",
makeConns: makeTCPConns,
useTLS: true,
fakeNonBlockingIO: true,
},
{
name: "TCP with Real Non-blocking IO",
makeConns: makeTCPConns,
useTLS: false,
fakeNonBlockingIO: false,
},
{
name: "TLS over TCP with Real Non-blocking IO",
makeConns: makeTCPConns,
useTLS: true,
fakeNonBlockingIO: false,
},
} {
t.Run(tt.name, func(t *testing.T) {
local, remote := tt.makeConns(t)
// Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get
// garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never
// uses remote it may be garbage collected leading to the connection being closed.
defer local.Close()
defer remote.Close()
var conn nbconn.Conn
netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO)
if tt.useTLS {
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
require.NoError(t, err)
tlsServer := tls.Server(remote, &tls.Config{
Certificates: []tls.Certificate{cert},
})
serverTLSHandshakeChan := make(chan error)
go func() {
err := tlsServer.Handshake()
serverTLSHandshakeChan <- err
}()
tlsConn, err := nbconn.TLSClient(netConn, &tls.Config{InsecureSkipVerify: true})
require.NoError(t, err)
conn = tlsConn
err = <-serverTLSHandshakeChan
require.NoError(t, err)
remote = tlsServer
} else {
conn = netConn
}
f(t, conn, remote)
})
}
}
// makePipeConns returns a connected pair of net.Conns created with net.Pipe(). It is entirely synchronous so it is
// useful for testing an exact sequence of reads and writes with the underlying connection blocking.
func makePipeConns(t *testing.T) (local, remote net.Conn) {
local, remote = net.Pipe()
t.Cleanup(func() {
local.Close()
remote.Close()
})
return local, remote
}
// makeTCPConns returns a connected pair of net.Conns running over TCP on localhost.
func makeTCPConns(t *testing.T) (local, remote net.Conn) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
type acceptResultT struct {
conn net.Conn
err error
}
acceptChan := make(chan acceptResultT)
go func() {
conn, err := ln.Accept()
acceptChan <- acceptResultT{conn: conn, err: err}
}()
local, err = net.Dial("tcp", ln.Addr().String())
require.NoError(t, err)
acceptResult := <-acceptChan
require.NoError(t, acceptResult.err)
remote = acceptResult.conn
return local, remote
}
func TestWriteIsBuffered(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
// net.Pipe is synchronous so the Write would block if not buffered.
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
errChan := make(chan error, 1)
go func() {
err := conn.Flush()
errChan <- err
}()
readBuf := make([]byte, len(writeBuf))
_, err = remote.Read(readBuf)
require.NoError(t, err)
require.NoError(t, <-errChan)
})
}
func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
err := conn.SetWriteDeadline(time.Now())
require.NoError(t, err)
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
})
}
func TestReadFlushesWriteBuffer(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
errChan := make(chan error, 2)
go func() {
readBuf := make([]byte, len(writeBuf))
_, err := remote.Read(readBuf)
errChan <- err
_, err = remote.Write([]byte("okay"))
errChan <- err
}()
readBuf := make([]byte, 4)
_, err = conn.Read(readBuf)
require.NoError(t, err)
require.Equal(t, []byte("okay"), readBuf)
require.NoError(t, <-errChan)
require.NoError(t, <-errChan)
})
}
func TestCloseFlushesWriteBuffer(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
errChan := make(chan error, 1)
go func() {
readBuf := make([]byte, len(writeBuf))
_, err := remote.Read(readBuf)
errChan <- err
}()
err = conn.Close()
require.NoError(t, err)
require.NoError(t, <-errChan)
})
}
// This test exercises the non-blocking write path. Because writes are buffered it is difficult to trigger this with
// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by both sides writing
// large values.
func TestInternalNonBlockingWrite(t *testing.T) {
const deadlockSize = 4 * 1024 * 1024
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := make([]byte, deadlockSize)
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, deadlockSize, n)
errChan := make(chan error, 1)
go func() {
remoteWriteBuf := make([]byte, deadlockSize)
_, err := remote.Write(remoteWriteBuf)
if err != nil {
errChan <- err
return
}
readBuf := make([]byte, deadlockSize)
_, err = io.ReadFull(remote, readBuf)
errChan <- err
}()
readBuf := make([]byte, deadlockSize)
_, err = io.ReadFull(conn, readBuf)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
require.NoError(t, <-errChan)
})
}
func TestInternalNonBlockingWriteWithDeadline(t *testing.T) {
const deadlockSize = 4 * 1024 * 1024
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := make([]byte, deadlockSize)
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, deadlockSize, n)
err = conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
require.NoError(t, err)
err = conn.Flush()
require.Error(t, err)
require.Contains(t, err.Error(), "i/o timeout")
})
}
func TestNonBlockingRead(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
err := conn.SetReadDeadline(nbconn.NonBlockingDeadline)
require.NoError(t, err)
buf := make([]byte, 4)
n, err := conn.Read(buf)
require.ErrorIs(t, err, nbconn.ErrWouldBlock)
require.EqualValues(t, 0, n)
errChan := make(chan error, 1)
go func() {
_, err := remote.Write([]byte("okay"))
errChan <- err
}()
err = conn.SetReadDeadline(time.Time{})
require.NoError(t, err)
n, err = conn.Read(buf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
})
}
func TestBufferNonBlockingRead(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
err := conn.BufferReadUntilBlock()
require.NoError(t, err)
errChan := make(chan error, 1)
go func() {
_, err := remote.Write([]byte("okay"))
errChan <- err
}()
for i := 0; i < 1000; i++ {
err = conn.BufferReadUntilBlock()
if !errors.Is(err, nbconn.ErrWouldBlock) {
break
}
time.Sleep(time.Millisecond)
}
require.NoError(t, err)
buf := make([]byte, 4)
n, err := conn.Read(buf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
require.Equal(t, []byte("okay"), buf)
})
}
func TestReadPreviouslyBuffered(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous, conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
readBuf := make([]byte, 5)
n, err := conn.Read(readBuf)
require.NoError(t, err)
require.EqualValues(t, 5, n)
require.Equal(t, []byte("alpha"), readBuf)
})
}
func TestReadMoreThanPreviouslyBufferedDoesNotBlock(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous, conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
readBuf := make([]byte, 10)
n, err := conn.Read(readBuf)
require.NoError(t, err)
require.EqualValues(t, 5, n)
require.Equal(t, []byte("alpha"), readBuf[:n])
})
}
func TestReadPreviouslyBufferedPartialRead(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous, conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
readBuf := make([]byte, 2)
n, err := conn.Read(readBuf)
require.NoError(t, err)
require.EqualValues(t, 2, n)
require.Equal(t, []byte("al"), readBuf)
readBuf = make([]byte, 3)
n, err = conn.Read(readBuf)
require.NoError(t, err)
require.EqualValues(t, 3, n)
require.Equal(t, []byte("pha"), readBuf)
})
}
func TestReadMultiplePreviouslyBuffered(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
_, err = remote.Write([]byte("beta"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous, conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
readBuf := make([]byte, 9)
n, err := io.ReadFull(conn, readBuf)
require.NoError(t, err)
require.EqualValues(t, 9, n)
require.Equal(t, []byte("alphabeta"), readBuf)
})
}
func TestReadPreviouslyBufferedAndReadMore(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
flushCompleteChan := make(chan struct{})
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
<-flushCompleteChan
_, err = remote.Write([]byte("beta"))
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous, conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
close(flushCompleteChan)
readBuf := make([]byte, 9)
n, err := io.ReadFull(conn, readBuf)
require.NoError(t, err)
require.EqualValues(t, 9, n)
require.Equal(t, []byte("alphabeta"), readBuf)
err = <-errChan
require.NoError(t, err)
})
}

View File

@ -23,7 +23,7 @@ func TestScript(t *testing.T) {
script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "select 42"}))
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.RowDescription{
Fields: []pgproto3.FieldDescription{
{
pgproto3.FieldDescription{
Name: []byte("?column?"),
TableOID: 0,
TableAttributeNumber: 0,
@ -69,7 +69,9 @@ func TestScript(t *testing.T) {
}
}()
host, port, _ := strings.Cut(ln.Addr().String(), ":")
parts := strings.Split(ln.Addr().String(), ":")
host := parts[0]
port := parts[1]
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)

View File

@ -1,60 +0,0 @@
#!/usr/bin/env bash
current_branch=$(git rev-parse --abbrev-ref HEAD)
if [ "$current_branch" == "HEAD" ]; then
current_branch=$(git rev-parse HEAD)
fi
restore_branch() {
echo "Restoring original branch/commit: $current_branch"
git checkout "$current_branch"
}
trap restore_branch EXIT
# Check if there are uncommitted changes
if ! git diff --quiet || ! git diff --cached --quiet; then
echo "There are uncommitted changes. Please commit or stash them before running this script."
exit 1
fi
# Ensure that at least one commit argument is passed
if [ "$#" -lt 1 ]; then
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
exit 1
fi
commits=("$@")
benchmarks_dir=benchmarks
if ! mkdir -p "${benchmarks_dir}"; then
echo "Unable to create dir for benchmarks data"
exit 1
fi
# Benchmark results
bench_files=()
# Run benchmark for each listed commit
for i in "${!commits[@]}"; do
commit="${commits[i]}"
git checkout "$commit" || {
echo "Failed to checkout $commit"
exit 1
}
# Sanitized commit message
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
# Benchmark data will go there
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
if ! go test -bench=. -count=10 >"$bench_file"; then
echo "Benchmarking failed for commit $commit"
exit 1
fi
bench_files+=("$bench_file")
done
# Requires benchstat: go install golang.org/x/perf/cmd/benchstat@latest
benchstat "${bench_files[@]}"
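# Example invocation (the script name and commits are illustrative):
#   ./benchmark_commits.sh HEAD~1 HEAD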

View File

@ -4,10 +4,8 @@ import (
"bytes"
"encoding/hex"
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
)
@ -20,81 +18,44 @@ type Query struct {
Parts []Part
}
// utf8.DecodeRune returns utf8.RuneError for errors. But that is actually rune U+FFFD -- the unicode replacement
// character. utf8.RuneError is not an error if it also has width 3, the encoded width of U+FFFD itself.
//
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3
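// For illustration, a minimal check (s and pos are placeholders, not part of this file)
// that distinguishes a genuine replacement character from invalid UTF-8:
//
//	r, width := utf8.DecodeRuneInString(s[pos:])
//	if r == utf8.RuneError && width != replacementcharacterwidth {
//		// invalid UTF-8: a real U+FFFD in the input decodes with width 3
//	}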
const maxBufSize = 16384 // 16 Ki
var bufPool = &pool[*bytes.Buffer]{
new: func() *bytes.Buffer {
return &bytes.Buffer{}
},
reset: func(b *bytes.Buffer) bool {
n := b.Len()
b.Reset()
return n < maxBufSize
},
}
var null = []byte("null")
func (q *Query) Sanitize(args ...any) (string, error) {
argUse := make([]bool, len(args))
buf := bufPool.get()
defer bufPool.put(buf)
buf := &bytes.Buffer{}
for _, part := range q.Parts {
var str string
switch part := part.(type) {
case string:
buf.WriteString(part)
str = part
case int:
argIdx := part - 1
var p []byte
if argIdx < 0 {
return "", fmt.Errorf("first sql argument must be > 0")
}
if argIdx >= len(args) {
return "", fmt.Errorf("insufficient arguments")
}
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
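// For example, without this padding, sanitizing the query part "select 1-" with the
// argument int64(-1) would produce "select 1--1", where "--" starts a line comment
// that swallows the rest of the statement.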
buf.WriteByte(' ')
arg := args[argIdx]
switch arg := arg.(type) {
case nil:
p = null
str = "null"
case int64:
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
str = strconv.FormatInt(arg, 10)
case float64:
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
str = strconv.FormatFloat(arg, 'f', -1, 64)
case bool:
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
str = strconv.FormatBool(arg)
case []byte:
p = QuoteBytes(buf.AvailableBuffer(), arg)
str = QuoteBytes(arg)
case string:
p = QuoteString(buf.AvailableBuffer(), arg)
str = QuoteString(arg)
case time.Time:
p = arg.Truncate(time.Microsecond).
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
default:
return "", fmt.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true
buf.Write(p)
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
buf.WriteByte(' ')
default:
return "", fmt.Errorf("invalid Part type: %T", part)
}
buf.WriteString(str)
}
for i, used := range argUse {
@ -106,99 +67,26 @@ func (q *Query) Sanitize(args ...any) (string, error) {
}
func NewQuery(sql string) (*Query, error) {
query := &Query{}
query.init(sql)
return query, nil
}
var sqlLexerPool = &pool[*sqlLexer]{
new: func() *sqlLexer {
return &sqlLexer{}
},
reset: func(sl *sqlLexer) bool {
*sl = sqlLexer{}
return true
},
}
func (q *Query) init(sql string) {
parts := q.Parts[:0]
if parts == nil {
// dirty, but fast heuristic to preallocate for ~90% of use cases
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
parts = make([]Part, 0, n)
l := &sqlLexer{
src: sql,
stateFn: rawState,
}
l := sqlLexerPool.get()
defer sqlLexerPool.put(l)
l.src = sql
l.stateFn = rawState
l.parts = parts
for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
q.Parts = l.parts
query := &Query{Parts: l.parts}
return query, nil
}
func QuoteString(dst []byte, str string) []byte {
const quote = '\''
// Preallocate space for the worst case scenario
dst = slices.Grow(dst, len(str)*2+2)
// Add opening quote
dst = append(dst, quote)
// Iterate through the string without allocating
for i := 0; i < len(str); i++ {
if str[i] == quote {
dst = append(dst, quote, quote)
} else {
dst = append(dst, str[i])
}
}
// Add closing quote
dst = append(dst, quote)
return dst
func QuoteString(str string) string {
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
}
func QuoteBytes(dst, buf []byte) []byte {
if len(buf) == 0 {
return append(dst, `'\x'`...)
}
// Calculate required length
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
// Ensure dst has enough capacity
if cap(dst)-len(dst) < requiredLen {
newDst := make([]byte, len(dst), len(dst)+requiredLen)
copy(newDst, dst)
dst = newDst
}
// Record original length and extend slice
origLen := len(dst)
dst = dst[:origLen+requiredLen]
// Add prefix
dst[origLen] = '\''
dst[origLen+1] = '\\'
dst[origLen+2] = 'x'
// Encode bytes directly into dst
hex.Encode(dst[origLen+3:len(dst)-1], buf)
// Add suffix
dst[len(dst)-1] = '\''
return dst
func QuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
}
type sqlLexer struct {
@ -250,13 +138,11 @@ func rawState(l *sqlLexer) stateFn {
return multilineCommentState
}
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
@ -274,13 +160,11 @@ func singleQuoteState(l *sqlLexer) stateFn {
}
l.pos += width
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
@ -298,13 +182,11 @@ func doubleQuoteState(l *sqlLexer) stateFn {
}
l.pos += width
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
@ -346,13 +228,11 @@ func escapeStringState(l *sqlLexer) stateFn {
}
l.pos += width
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
@ -369,13 +249,11 @@ func oneLineCommentState(l *sqlLexer) stateFn {
case '\n', '\r':
return rawState
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
@ -405,56 +283,22 @@ func multilineCommentState(l *sqlLexer) stateFn {
l.nested--
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
var queryPool = &pool[*Query]{
new: func() *Query {
return &Query{}
},
reset: func(q *Query) bool {
n := len(q.Parts)
q.Parts = q.Parts[:0]
return n < 64 // drop queries that are too large
},
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func SanitizeSQL(sql string, args ...any) (string, error) {
query := queryPool.get()
query.init(sql)
defer queryPool.put(query)
query, err := NewQuery(sql)
if err != nil {
return "", err
}
return query.Sanitize(args...)
}
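// A minimal usage sketch (assuming standard_conforming_strings is on):
//
//	sql, err := SanitizeSQL("select * from users where id = $1", int64(7))
//	if err != nil {
//		// handle error
//	}
//	// sql now contains the literal 7 in place of $1, padded with spaces.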
type pool[E any] struct {
p sync.Pool
new func() E
reset func(E) bool
}
func (pool *pool[E]) get() E {
v, ok := pool.p.Get().(E)
if !ok {
v = pool.new()
}
return v
}
func (p *pool[E]) put(v E) {
if p.reset(v) {
p.p.Put(v)
}
}
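// A hypothetical additional pool, sketched to show how the generic wrapper is used:
//
//	var builderPool = &pool[*strings.Builder]{
//		new:   func() *strings.Builder { return &strings.Builder{} },
//		reset: func(b *strings.Builder) bool { b.Reset(); return true },
//	}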

View File

@ -1,62 +0,0 @@
// sanitize_benchmark_test.go
package sanitize_test
import (
"testing"
"time"
"github.com/jackc/pgx/v5/internal/sanitize"
)
var benchmarkSanitizeResult string
const benchmarkQuery = "" +
`SELECT *
FROM "water_containers"
WHERE NOT "id" = $1 -- int64
AND "tags" NOT IN $2 -- nil
AND "volume" > $3 -- float64
AND "transportable" = $4 -- bool
AND position($5 IN "sign") -- bytes
AND "label" LIKE $6 -- string
AND "created_at" > $7; -- time.Time`
var benchmarkArgs = []any{
int64(12345),
nil,
float64(500),
true,
[]byte("8BADF00D"),
"kombucha's han'dy awokowa",
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
}
func BenchmarkSanitize(b *testing.B) {
query, err := sanitize.NewQuery(benchmarkQuery)
if err != nil {
b.Fatalf("failed to create query: %v", err)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize query: %v", err)
}
}
}
var benchmarkNewSQLResult string
func BenchmarkSanitizeSQL(b *testing.B) {
b.ReportAllocs()
var err error
for i := 0; i < b.N; i++ {
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize SQL: %v", err)
}
}
}

View File

@ -1,55 +0,0 @@
package sanitize_test
import (
"strings"
"testing"
"github.com/jackc/pgx/v5/internal/sanitize"
)
func FuzzQuoteString(f *testing.F) {
const prefix = "prefix"
f.Add("new\nline")
f.Add("sample text")
f.Add("sample q'u'o't'e's")
f.Add("select 'quoted $42', $1")
f.Fuzz(func(t *testing.T, input string) {
got := string(sanitize.QuoteString([]byte(prefix), input))
want := oldQuoteString(input)
quoted, ok := strings.CutPrefix(got, prefix)
if !ok {
t.Fatalf("result has no prefix")
}
if want != quoted {
t.Errorf("got %q", got)
t.Fatalf("want %q", want)
}
})
}
func FuzzQuoteBytes(f *testing.F) {
const prefix = "prefix"
f.Add([]byte(nil))
f.Add([]byte("\n"))
f.Add([]byte("sample text"))
f.Add([]byte("sample q'u'o't'e's"))
f.Add([]byte("select 'quoted $42', $1"))
f.Fuzz(func(t *testing.T, input []byte) {
got := string(sanitize.QuoteBytes([]byte(prefix), input))
want := oldQuoteBytes(input)
quoted, ok := strings.CutPrefix(got, prefix)
if !ok {
t.Fatalf("result has no prefix")
}
if want != quoted {
t.Errorf("got %q", got)
t.Fatalf("want %q", want)
}
})
}

View File

@ -1,8 +1,6 @@
package sanitize_test
import (
"encoding/hex"
"strings"
"testing"
"time"
@ -90,16 +88,6 @@ func TestNewQuery(t *testing.T) {
sql: "select 42, -- \\nis a Deep Thought's favorite number\r$1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\r", 1}},
},
{
// https://github.com/jackc/pgx/issues/1380
sql: "select 'hello w<>rld'",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello w<>rld'"}},
},
{
// Unterminated quoted string
sql: "select 'hello world",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello world"}},
},
}
for i, tt := range successTests {
@ -134,57 +122,47 @@ func TestQuerySanitize(t *testing.T) {
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{int64(42)},
expected: `select 42 `,
expected: `select 42`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{float64(1.23)},
expected: `select 1.23 `,
expected: `select 1.23`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{true},
expected: `select true `,
expected: `select true`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{[]byte{0, 1, 2, 3, 255}},
expected: `select '\x00010203ff' `,
expected: `select '\x00010203ff'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{nil},
expected: `select null `,
expected: `select null`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{"foobar"},
expected: `select 'foobar' `,
expected: `select 'foobar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{"foo'bar"},
expected: `select 'foo''bar' `,
expected: `select 'foo''bar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []any{`foo\'bar`},
expected: `select 'foo\''bar' `,
expected: `select 'foo\''bar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}},
args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)},
expected: `insert '2020-03-01 23:59:59.999999Z' `,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
args: []any{int64(-1)},
expected: `select 1- -1 `,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
args: []any{float64(-1)},
expected: `select 1- -1 `,
expected: `insert '2020-03-01 23:59:59.999999Z'`,
},
}
@ -229,55 +207,3 @@ func TestQuerySanitize(t *testing.T) {
}
}
}
func TestQuoteString(t *testing.T) {
tc := func(name, input string) {
t.Run(name, func(t *testing.T) {
t.Parallel()
got := string(sanitize.QuoteString(nil, input))
want := oldQuoteString(input)
if got != want {
t.Errorf("got: %s", got)
t.Fatalf("want: %s", want)
}
})
}
tc("empty", "")
tc("text", "abcd")
tc("with quotes", `one's hat is always a cat`)
}
// This function was used before optimizations.
// Keep it for testing purposes: we want to ensure there are no breaking changes.
func oldQuoteString(str string) string {
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
}
func TestQuoteBytes(t *testing.T) {
tc := func(name string, input []byte) {
t.Run(name, func(t *testing.T) {
t.Parallel()
got := string(sanitize.QuoteBytes(nil, input))
want := oldQuoteBytes(input)
if got != want {
t.Errorf("got: %s", got)
t.Fatalf("want: %s", want)
}
})
}
tc("nil", nil)
tc("empty", []byte{})
tc("text", []byte("abcd"))
}
// This function was used before optimizations.
// Keep it for testing purposes: we want to ensure there are no breaking changes.
func oldQuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
}

View File

@ -34,8 +34,7 @@ func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
}
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or
// sd.SQL has been invalidated and HandleInvalidated has not been called yet.
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
if sd.SQL == "" {
panic("cannot store statement description with empty SQL")
@ -45,13 +44,6 @@ func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
return
}
// The statement may have been invalidated but not yet handled. Do not readd it to the cache.
for _, invalidSD := range c.invalidStmts {
if invalidSD.SQL == sd.SQL {
return
}
}
if c.l.Len() == c.cap {
c.invalidateOldest()
}
@ -81,16 +73,10 @@ func (c *LRUCache) InvalidateAll() {
c.l = list.New()
}
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
return c.invalidStmts
}
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *LRUCache) RemoveInvalidated() {
func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription {
invalidStmts := c.invalidStmts
c.invalidStmts = nil
return invalidStmts
}
// Len returns the number of cached prepared statement descriptions.

View File

@ -2,17 +2,18 @@
package stmtcache
import (
"crypto/sha256"
"encoding/hex"
"strconv"
"sync/atomic"
"github.com/jackc/pgx/v5/pgconn"
)
// StatementName returns a statement name that will be stable for sql across multiple connections and program
// executions.
func StatementName(sql string) string {
digest := sha256.Sum256([]byte(sql))
return "stmtcache_" + hex.EncodeToString(digest[0:24])
var stmtCounter int64
// NextStatementName returns a statement name that will be unique for the lifetime of the program.
func NextStatementName() string {
n := atomic.AddInt64(&stmtCounter, 1)
return "stmtcache_" + strconv.FormatInt(n, 10)
}
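// For example, StatementName("select 1") always produces the same name of the form
// "stmtcache_" followed by 48 hex digits, so identical SQL maps to the same prepared
// statement name on every connection and across program restarts.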
// Cache caches statement descriptions.
@ -29,13 +30,8 @@ type Cache interface {
// InvalidateAll invalidates all statement descriptions.
InvalidateAll()
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
GetInvalidated() []*pgconn.StatementDescription
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
RemoveInvalidated()
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
HandleInvalidated() []*pgconn.StatementDescription
// Len returns the number of cached prepared statement descriptions.
Len() int
@ -43,3 +39,19 @@ type Cache interface {
// Cap returns the maximum number of cached prepared statement descriptions.
Cap() int
}
func IsStatementInvalid(err error) bool {
pgErr, ok := err.(*pgconn.PgError)
if !ok {
return false
}
// https://github.com/jackc/pgx/issues/1162
//
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
// have, so it should be safe.
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
return possibleInvalidCachedPlanError
}

View File

@ -54,16 +54,10 @@ func (c *UnlimitedCache) InvalidateAll() {
c.m = make(map[string]*pgconn.StatementDescription)
}
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription {
return c.invalidStmts
}
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *UnlimitedCache) RemoveInvalidated() {
func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription {
invalidStmts := c.invalidStmts
c.invalidStmts = nil
return invalidStmts
}
// Len returns the number of cached prepared statement descriptions.

View File

@ -4,15 +4,8 @@ import (
"context"
"errors"
"io"
"github.com/jackc/pgx/v5/pgtype"
)
// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of
// PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data
// in the message, maxLargeObjectMessageLength should be no larger than 1 GB - 1 KB.
var maxLargeObjectMessageLength = 1024*1024*1024 - 1024
// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it
// was created.
//
@ -75,65 +68,32 @@ type LargeObject struct {
// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written.
func (o *LargeObject) Write(p []byte) (int, error) {
nTotal := 0
for {
expected := len(p) - nTotal
if expected == 0 {
break
} else if expected > maxLargeObjectMessageLength {
expected = maxLargeObjectMessageLength
}
var n int
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p[nTotal:nTotal+expected]).Scan(&n)
if err != nil {
return nTotal, err
}
if n < 0 {
return nTotal, errors.New("failed to write to large object")
}
nTotal += n
if n < expected {
return nTotal, errors.New("short write to large object")
} else if n > expected {
return nTotal, errors.New("invalid write to large object")
}
var n int
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n)
if err != nil {
return n, err
}
return nTotal, nil
if n < 0 {
return 0, errors.New("failed to write to large object")
}
return n, nil
}
// Read reads up to len(p) bytes into p returning the number of bytes read.
func (o *LargeObject) Read(p []byte) (int, error) {
nTotal := 0
for {
expected := len(p) - nTotal
if expected == 0 {
break
} else if expected > maxLargeObjectMessageLength {
expected = maxLargeObjectMessageLength
}
res := pgtype.PreallocBytes(p[nTotal:])
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res)
// We compute expected so that it always fits into p, so PreallocBytes's ScanBytes
// should never have to allocate a new slice.
nTotal += len(res)
if err != nil {
return nTotal, err
}
if len(res) < expected {
return nTotal, io.EOF
} else if len(res) > expected {
return nTotal, errors.New("invalid read of large object")
}
var res []byte
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res)
copy(p, res)
if err != nil {
return len(res), err
}
return nTotal, nil
if len(res) < len(p) {
err = io.EOF
}
return len(res), err
}
// Seek moves the current location pointer to the new location specified by offset.

View File

@ -1,20 +0,0 @@
package pgx
import (
"testing"
)
// SetMaxLargeObjectMessageLength sets internal maxLargeObjectMessageLength variable
// to the given length for the duration of the test.
//
// Tests using this helper should not use t.Parallel().
func SetMaxLargeObjectMessageLength(t *testing.T, length int) {
t.Helper()
original := maxLargeObjectMessageLength
t.Cleanup(func() {
maxLargeObjectMessageLength = original
})
maxLargeObjectMessageLength = length
}

View File

@ -13,10 +13,9 @@ import (
)
func TestLargeObjects(t *testing.T) {
// We use a very short limit to test chunking logic.
pgx.SetMaxLargeObjectMessageLength(t, 2)
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
@ -35,10 +34,9 @@ func TestLargeObjects(t *testing.T) {
}
func TestLargeObjectsSimpleProtocol(t *testing.T) {
// We use a very short limit to test chunking logic.
pgx.SetMaxLargeObjectMessageLength(t, 2)
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
@ -162,10 +160,9 @@ func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) {
}
func TestLargeObjectsMultipleTransactions(t *testing.T) {
// We use a very short limit to test chunking logic.
pgx.SetMaxLargeObjectMessageLength(t, 2)
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))

View File

@ -1,152 +0,0 @@
// Package multitracer provides a Tracer that can combine several tracers into one.
package multitracer
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// Tracer can combine several tracers into one.
// You can use New to automatically split tracers by interface.
type Tracer struct {
QueryTracers []pgx.QueryTracer
BatchTracers []pgx.BatchTracer
CopyFromTracers []pgx.CopyFromTracer
PrepareTracers []pgx.PrepareTracer
ConnectTracers []pgx.ConnectTracer
PoolAcquireTracers []pgxpool.AcquireTracer
PoolReleaseTracers []pgxpool.ReleaseTracer
}
// New returns a new Tracer that combines the given tracers, automatically splitting them by interface.
func New(tracers ...pgx.QueryTracer) *Tracer {
var t Tracer
for _, tracer := range tracers {
t.QueryTracers = append(t.QueryTracers, tracer)
if batchTracer, ok := tracer.(pgx.BatchTracer); ok {
t.BatchTracers = append(t.BatchTracers, batchTracer)
}
if copyFromTracer, ok := tracer.(pgx.CopyFromTracer); ok {
t.CopyFromTracers = append(t.CopyFromTracers, copyFromTracer)
}
if prepareTracer, ok := tracer.(pgx.PrepareTracer); ok {
t.PrepareTracers = append(t.PrepareTracers, prepareTracer)
}
if connectTracer, ok := tracer.(pgx.ConnectTracer); ok {
t.ConnectTracers = append(t.ConnectTracers, connectTracer)
}
if poolAcquireTracer, ok := tracer.(pgxpool.AcquireTracer); ok {
t.PoolAcquireTracers = append(t.PoolAcquireTracers, poolAcquireTracer)
}
if poolReleaseTracer, ok := tracer.(pgxpool.ReleaseTracer); ok {
t.PoolReleaseTracers = append(t.PoolReleaseTracers, poolReleaseTracer)
}
}
return &t
}
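// A usage sketch (connString and the tracer values are illustrative; Tracer satisfies
// pgx.QueryTracer, so it can be assigned to a connection config):
//
//	config, err := pgx.ParseConfig(connString)
//	if err != nil {
//		// handle error
//	}
//	config.Tracer = multitracer.New(tracerA, tracerB)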
func (t *Tracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
for _, tracer := range t.QueryTracers {
ctx = tracer.TraceQueryStart(ctx, conn, data)
}
return ctx
}
func (t *Tracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
for _, tracer := range t.QueryTracers {
tracer.TraceQueryEnd(ctx, conn, data)
}
}
func (t *Tracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
for _, tracer := range t.BatchTracers {
ctx = tracer.TraceBatchStart(ctx, conn, data)
}
return ctx
}
func (t *Tracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
for _, tracer := range t.BatchTracers {
tracer.TraceBatchQuery(ctx, conn, data)
}
}
func (t *Tracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
for _, tracer := range t.BatchTracers {
tracer.TraceBatchEnd(ctx, conn, data)
}
}
func (t *Tracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
for _, tracer := range t.CopyFromTracers {
ctx = tracer.TraceCopyFromStart(ctx, conn, data)
}
return ctx
}
func (t *Tracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
for _, tracer := range t.CopyFromTracers {
tracer.TraceCopyFromEnd(ctx, conn, data)
}
}
func (t *Tracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
for _, tracer := range t.PrepareTracers {
ctx = tracer.TracePrepareStart(ctx, conn, data)
}
return ctx
}
func (t *Tracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
for _, tracer := range t.PrepareTracers {
tracer.TracePrepareEnd(ctx, conn, data)
}
}
func (t *Tracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
for _, tracer := range t.ConnectTracers {
ctx = tracer.TraceConnectStart(ctx, data)
}
return ctx
}
func (t *Tracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
for _, tracer := range t.ConnectTracers {
tracer.TraceConnectEnd(ctx, data)
}
}
func (t *Tracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context {
for _, tracer := range t.PoolAcquireTracers {
ctx = tracer.TraceAcquireStart(ctx, pool, data)
}
return ctx
}
func (t *Tracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
for _, tracer := range t.PoolAcquireTracers {
tracer.TraceAcquireEnd(ctx, pool, data)
}
}
func (t *Tracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) {
for _, tracer := range t.PoolReleaseTracers {
tracer.TraceRelease(pool, data)
}
}

View File

@ -1,115 +0,0 @@
package multitracer_test
import (
"context"
"testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/multitracer"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/require"
)
type testFullTracer struct{}
func (tt *testFullTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
}
func (tt *testFullTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
}
func (tt *testFullTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
}
func (tt *testFullTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
}
func (tt *testFullTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
}
func (tt *testFullTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
}
func (tt *testFullTracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
}
func (tt *testFullTracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) {
}
type testCopyTracer struct{}
func (tt *testCopyTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
return ctx
}
func (tt *testCopyTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
}
func (tt *testCopyTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
return ctx
}
func (tt *testCopyTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
}
func TestNew(t *testing.T) {
t.Parallel()
fullTracer := &testFullTracer{}
copyTracer := &testCopyTracer{}
mt := multitracer.New(fullTracer, copyTracer)
require.Equal(
t,
&multitracer.Tracer{
QueryTracers: []pgx.QueryTracer{
fullTracer,
copyTracer,
},
BatchTracers: []pgx.BatchTracer{
fullTracer,
},
CopyFromTracers: []pgx.CopyFromTracer{
fullTracer,
copyTracer,
},
PrepareTracers: []pgx.PrepareTracer{
fullTracer,
},
ConnectTracers: []pgx.ConnectTracer{
fullTracer,
},
PoolAcquireTracers: []pgxpool.AcquireTracer{
fullTracer,
},
PoolReleaseTracers: []pgxpool.ReleaseTracer{
fullTracer,
},
},
mt,
)
}

View File

@ -2,7 +2,6 @@ package pgx
import (
"context"
"fmt"
"strconv"
"strings"
"unicode/utf8"
@ -13,43 +12,12 @@ import (
//
// For example, the following two queries are equivalent:
//
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2)
//
// Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be
// letters, numbers, or underscores.
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}))
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2}))
type NamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface.
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return rewriteQuery(na, sql, false)
}
// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all
// named arguments that the sql query uses, and no extra arguments.
type StrictNamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface.
func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return rewriteQuery(sna, sql, true)
}
type namedArg string
type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []any
nameToOrdinal map[namedArg]int
}
type stateFn func(*sqlLexer) stateFn
func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
@ -73,24 +41,27 @@ func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string,
newArgs = make([]any, len(l.nameToOrdinal))
for name, ordinal := range l.nameToOrdinal {
var found bool
newArgs[ordinal-1], found = na[string(name)]
if isStrict && !found {
return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
}
newArgs[ordinal-1] = na[string(name)]
}
if isStrict {
for name := range na {
if _, found := l.nameToOrdinal[namedArg(name)]; !found {
return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
}
}
}
return sb.String(), newArgs, nil
return sb.String(), newArgs
}
type namedArg string
type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []any
nameToOrdinal map[namedArg]int
}
type stateFn func(*sqlLexer) stateFn
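// The rewriter is a small state machine: each stateFn consumes input and returns the
// next state, with nil ending the scan. The driving loop is simply:
//
//	for l.stateFn != nil {
//		l.stateFn = l.stateFn(l)
//	}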
func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
@ -109,7 +80,7 @@ func rawState(l *sqlLexer) stateFn {
return doubleQuoteState
case '@':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if isLetter(nextRune) || nextRune == '_' {
if isLetter(nextRune) {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}

View File

@ -6,7 +6,6 @@ import (
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNamedArgsRewriteQuery(t *testing.T) {
@ -38,10 +37,10 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
expectedArgs: []any{int32(42), "foo"},
},
{
sql: "select @Abc::int, @b_4::text, @_c::int",
namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo", "_c": int32(1)},
expectedSQL: "select $1::int, $2::text, $3::int",
expectedArgs: []any{int32(42), "foo", int32(1)},
sql: "select @Abc::int, @b_4::text",
namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo"},
expectedSQL: "select $1::int, $2::text",
expectedArgs: []any{int32(42), "foo"},
},
{
sql: "at end @",
@ -50,15 +49,15 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
expectedArgs: []any{},
},
{
sql: "ignores without valid character after @ foo bar",
sql: "ignores without letter after @ foo bar",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "ignores without valid character after @ foo bar",
expectedSQL: "ignores without letter after @ foo bar",
expectedArgs: []any{},
},
{
sql: "name cannot start with number @1 foo bar",
sql: "name must start with letter @1 foo bar",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "name cannot start with number @1 foo bar",
expectedSQL: "name must start with letter @1 foo bar",
expectedArgs: []any{},
},
{
@ -93,70 +92,11 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
where id = $1;`,
expectedArgs: []any{int32(42)},
},
{
sql: "extra provided argument",
namedArgs: pgx.NamedArgs{"extra": int32(1)},
expectedSQL: "extra provided argument",
expectedArgs: []any{},
},
{
sql: "@missing argument",
namedArgs: pgx.NamedArgs{},
expectedSQL: "$1 argument",
expectedArgs: []any{nil},
},
// test comments and quotes
} {
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
require.NoError(t, err)
sql, args := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
}
}
func TestStrictNamedArgsRewriteQuery(t *testing.T) {
t.Parallel()
for i, tt := range []struct {
sql string
namedArgs pgx.StrictNamedArgs
expectedSQL string
expectedArgs []any
isExpectedError bool
}{
{
sql: "no arguments",
namedArgs: pgx.StrictNamedArgs{},
expectedSQL: "no arguments",
expectedArgs: []any{},
isExpectedError: false,
},
{
sql: "@all @matches",
namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)},
expectedSQL: "$1 $2",
expectedArgs: []any{int32(1), int32(2)},
isExpectedError: false,
},
{
sql: "extra provided argument",
namedArgs: pgx.StrictNamedArgs{"extra": int32(1)},
isExpectedError: true,
},
{
sql: "@missing argument",
namedArgs: pgx.StrictNamedArgs{},
isExpectedError: true,
},
} {
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil)
if tt.isExpectedError {
assert.Errorf(t, err, "%d", i)
} else {
require.NoErrorf(t, err, "%d", i)
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
}
}
}

View File

@ -26,4 +26,28 @@ if err != nil {
## Testing
See CONTRIBUTING.md for setup instructions.
The pgconn tests require a PostgreSQL database. They connect to the database specified in the `PGX_TEST_CONN_STRING`
environment variable, which can be a URL or DSN. In addition, the standard `PG*` environment variables are
respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify environment variable handling.
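For example, a minimal `.envrc` for direnv (values are illustrative) might contain:

```bash
export PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test"
```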
### Example Test Environment
Connect to your PostgreSQL server and run:
```
create database pgx_test;
```
Now you can run the tests:
```bash
PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./...
```
### Connection and Authentication Tests
Pgconn supports multiple connection types and means of authentication. These tests are optional. They
will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being
skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example setup if you
need to change authentication code.

View File

@ -42,12 +42,12 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
Data: sc.clientFirstMessage(),
}
c.frontend.Send(saslInitialResponse)
err = c.flushWithPotentialWriteReadDeadlock()
err = c.frontend.Flush()
if err != nil {
return err
}
// Receive server-first-message payload in an AuthenticationSASLContinue.
// Receive server-first-message payload in a AuthenticationSASLContinue.
saslContinue, err := c.rxSASLContinue()
if err != nil {
return err
@ -62,12 +62,12 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
Data: []byte(sc.clientFinalMessage()),
}
c.frontend.Send(saslResponse)
err = c.flushWithPotentialWriteReadDeadlock()
err = c.frontend.Flush()
if err != nil {
return err
}
// Receive server-final-message payload in an AuthenticationSASLFinal.
// Receive server-final-message payload in a AuthenticationSASLFinal.
saslFinal, err := c.rxSASLFinal()
if err != nil {
return err

View File

@ -53,7 +53,7 @@ func BenchmarkExec(b *testing.B) {
for _, bm := range benchmarks {
bm := bm
b.Run(bm.name, func(b *testing.B) {
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_DATABASE"))
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.Nil(b, err)
defer closeConn(b, conn)
@ -97,7 +97,7 @@ func BenchmarkExec(b *testing.B) {
}
func BenchmarkExecPossibleToCancel(b *testing.B) {
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
require.Nil(b, err)
defer closeConn(b, conn)
@ -159,7 +159,7 @@ func BenchmarkExecPrepared(b *testing.B) {
for _, bm := range benchmarks {
bm := bm
b.Run(bm.name, func(b *testing.B) {
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_DATABASE"))
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.Nil(b, err)
defer closeConn(b, conn)
@ -197,7 +197,7 @@ func BenchmarkExecPrepared(b *testing.B) {
}
func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
require.Nil(b, err)
defer closeConn(b, conn)
@ -238,7 +238,7 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
}
// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) {
// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
// require.Nil(b, err)
// defer closeConn(b, conn)

View File

@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"net"
"net/url"
@ -19,7 +20,6 @@ import (
"github.com/jackc/pgpassfile"
"github.com/jackc/pgservicefile"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3"
)
@ -27,7 +27,7 @@ type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
type GetSSLPasswordFunc func(ctx context.Context) string
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A
// manually initialized Config will cause ConnectConfig to panic.
type Config struct {
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
@ -40,19 +40,12 @@ type Config struct {
DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc
// BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called
// when a context passed to a PgConn method is canceled.
BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
KerberosSrvName string
KerberosSpn string
Fallbacks []*FallbackConfig
SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
@ -68,17 +61,12 @@ type Config struct {
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler
// OnPgError is a callback function called when a Postgres error is received by the server. The default handler will close
// the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure
// that you close on FATAL errors by returning false.
OnPgError PgErrorHandler
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
// ParseConfigOptions contains options that control how a config is built such as GetSSLPassword.
// ParseConfigOptions contains options that control how a config is built such as getsslpassword.
type ParseConfigOptions struct {
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the libpq function
// GetSSLPassword gets the password to decrypt an SSL client certificate. This is analogous to the libpq function
// PQsetSSLKeyPassHook_OpenSSL.
GetSSLPassword GetSSLPasswordFunc
}
@ -120,14 +108,6 @@ type FallbackConfig struct {
TLSConfig *tls.Config // nil disables TLS
}
// connectOneConfig is the configuration for a single attempt to connect to a single host.
type connectOneConfig struct {
network string
address string
originalHostname string // original hostname before resolving
tlsConfig *tls.Config // nil disables TLS
}
// isAbsolutePath checks if the provided value is an absolute path either
// beginning with a forward slash (as on Linux-based systems) or with a capital
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
@ -162,11 +142,11 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format. See
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be empty
// to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
//
// # Example Keyword/Value
// # Example DSN
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
//
// # Example URL
@ -185,7 +165,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
//
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
// via database URL or keyword/value:
// via database URL or DSN:
//
// PGHOST
// PGPORT
@ -200,11 +180,9 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// PGSSLKEY
// PGSSLROOTCERT
// PGSSLPASSWORD
// PGOPTIONS
// PGAPPNAME
// PGCONNECT_TIMEOUT
// PGTARGETSESSIONATTRS
// PGTZ
//
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
//
@ -233,9 +211,9 @@ func NetworkAddress(host string, port uint16) (network, address string) {
//
// In addition, ParseConfig accepts the following options:
//
// - servicefile.
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
// part of the connection string.
// servicefile
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
// part of the connection string.
func ParseConfig(connString string) (*Config, error) {
var parseConfigOptions ParseConfigOptions
return ParseConfigWithOptions(connString, parseConfigOptions)
@ -251,16 +229,16 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
connStringSettings := make(map[string]string)
if connString != "" {
var err error
// connString may be a database URL or in PostgreSQL keyword/value format
// connString may be a database URL or a DSN
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
}
} else {
connStringSettings, err = parseKeywordValueSettings(connString)
connStringSettings, err = parseDSNSettings(connString)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as keyword/value", err: err}
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
}
}
}
@ -269,7 +247,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if service, present := settings["service"]; present {
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err}
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
}
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
@ -284,22 +262,12 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
return pgproto3.NewFrontend(r, w)
},
BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler {
return &DeadlineContextWatcherHandler{Conn: pgConn.conn}
},
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
// automatically close the connection on any FATAL error
if strings.EqualFold(pgErr.Severity, "FATAL") {
return false
}
return true
},
}
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err}
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
}
config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
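
The effect of connect_timeout is visible on the parsed config itself; a small sketch (master-side field names as in the hunk above):

```go
package main

import (
	"fmt"
	"log"

	"github.com/jackc/pgx/v5/pgconn"
)

func main() {
	cfg, err := pgconn.ParseConfig("host=localhost connect_timeout=5")
	if err != nil {
		log.Fatal(err)
	}
	// connect_timeout is applied twice: as Config.ConnectTimeout and as a
	// DialFunc that puts the same deadline on the underlying dial.
	fmt.Println(cfg.ConnectTimeout) // 5s
}
```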
@ -322,9 +290,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
"sslkey": {},
"sslcert": {},
"sslrootcert": {},
"sslnegotiation": {},
"sslpassword": {},
"sslsni": {},
"krbspn": {},
"krbsrvname": {},
"target_session_attrs": {},
@ -362,7 +328,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
port, err := parsePort(portStr)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err}
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
}
var tlsConfigs []*tls.Config
@ -374,7 +340,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
var err error
tlsConfigs, err = configTLS(settings, host, options)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err}
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
}
}
@ -391,7 +357,6 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
config.Port = fallbacks[0].Port
config.TLSConfig = fallbacks[0].TLSConfig
config.Fallbacks = fallbacks[1:]
config.SSLNegotiation = settings["sslnegotiation"]
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
if err == nil {
@ -419,7 +384,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
case "any":
// do nothing
default:
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
}
return config, nil
@ -452,15 +417,11 @@ func parseEnvSettings() map[string]string {
"PGSSLMODE": "sslmode",
"PGSSLKEY": "sslkey",
"PGSSLCERT": "sslcert",
"PGSSLSNI": "sslsni",
"PGSSLROOTCERT": "sslrootcert",
"PGSSLPASSWORD": "sslpassword",
"PGSSLNEGOTIATION": "sslnegotiation",
"PGTARGETSESSIONATTRS": "target_session_attrs",
"PGSERVICE": "service",
"PGSERVICEFILE": "servicefile",
"PGTZ": "timezone",
"PGOPTIONS": "options",
}
for envname, realname := range nameMap {
@ -476,17 +437,14 @@ func parseEnvSettings() map[string]string {
func parseURLSettings(connString string) (map[string]string, error) {
settings := make(map[string]string)
parsedURL, err := url.Parse(connString)
url, err := url.Parse(connString)
if err != nil {
if urlErr := new(url.Error); errors.As(err, &urlErr) {
return nil, urlErr.Err
}
return nil, err
}
if parsedURL.User != nil {
settings["user"] = parsedURL.User.Username()
if password, present := parsedURL.User.Password(); present {
if url.User != nil {
settings["user"] = url.User.Username()
if password, present := url.User.Password(); present {
settings["password"] = password
}
}
@ -494,7 +452,7 @@ func parseURLSettings(connString string) (map[string]string, error) {
// Handle multiple host:port pairs in url.Host by splitting them into host,host,host and port,port,port.
var hosts []string
var ports []string
for _, host := range strings.Split(parsedURL.Host, ",") {
for _, host := range strings.Split(url.Host, ",") {
if host == "" {
continue
}
@ -520,7 +478,7 @@ func parseURLSettings(connString string) (map[string]string, error) {
settings["port"] = strings.Join(ports, ",")
}
database := strings.TrimLeft(parsedURL.Path, "/")
database := strings.TrimLeft(url.Path, "/")
if database != "" {
settings["database"] = database
}
@ -529,7 +487,7 @@ func parseURLSettings(connString string) (map[string]string, error) {
"dbname": "database",
}
for k, v := range parsedURL.Query() {
for k, v := range url.Query() {
if k2, present := nameMap[k]; present {
k = k2
}
@ -546,7 +504,7 @@ func isIPOnly(host string) bool {
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
func parseKeywordValueSettings(s string) (map[string]string, error) {
func parseDSNSettings(s string) (map[string]string, error) {
settings := make(map[string]string)
nameMap := map[string]string{
@ -557,7 +515,7 @@ func parseKeywordValueSettings(s string) (map[string]string, error) {
var key, val string
eqIdx := strings.IndexRune(s, '=')
if eqIdx < 0 {
return nil, errors.New("invalid keyword/value")
return nil, errors.New("invalid dsn")
}
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
@ -609,7 +567,7 @@ func parseKeywordValueSettings(s string) (map[string]string, error) {
}
if key == "" {
return nil, errors.New("invalid keyword/value")
return nil, errors.New("invalid dsn")
}
settings[key] = val
@ -654,56 +612,14 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
sslcert := settings["sslcert"]
sslkey := settings["sslkey"]
sslpassword := settings["sslpassword"]
sslsni := settings["sslsni"]
sslnegotiation := settings["sslnegotiation"]
// Match libpq default behavior
if sslmode == "" {
sslmode = "prefer"
}
if sslsni == "" {
sslsni = "1"
}
tlsConfig := &tls.Config{}
if sslnegotiation == "direct" {
tlsConfig.NextProtos = []string{"postgresql"}
if sslmode == "prefer" {
sslmode = "require"
}
}
if sslrootcert != "" {
var caCertPool *x509.CertPool
if sslrootcert == "system" {
var err error
caCertPool, err = x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("unable to load system certificate pool: %w", err)
}
sslmode = "verify-full"
} else {
caCertPool = x509.NewCertPool()
caPath := sslrootcert
caCert, err := os.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("unable to read CA file: %w", err)
}
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, errors.New("unable to add CA to cert pool")
}
}
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
}
switch sslmode {
case "disable":
return []*tls.Config{nil}, nil
@ -761,19 +677,33 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
return nil, errors.New("sslmode is invalid")
}
if sslrootcert != "" {
caCertPool := x509.NewCertPool()
caPath := sslrootcert
caCert, err := ioutil.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("unable to read CA file: %w", err)
}
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, errors.New("unable to add CA to cert pool")
}
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
}
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
}
if sslcert != "" && sslkey != "" {
buf, err := os.ReadFile(sslkey)
buf, err := ioutil.ReadFile(sslkey)
if err != nil {
return nil, fmt.Errorf("unable to read sslkey: %w", err)
}
block, _ := pem.Decode(buf)
if block == nil {
return nil, errors.New("failed to decode sslkey")
}
var pemKey []byte
var decryptedKey []byte
var decryptedError error
@ -808,7 +738,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
} else {
pemKey = pem.EncodeToMemory(block)
}
certfile, err := os.ReadFile(sslcert)
certfile, err := ioutil.ReadFile(sslcert)
if err != nil {
return nil, fmt.Errorf("unable to read cert: %w", err)
}
@ -819,13 +749,6 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
tlsConfig.Certificates = []tls.Certificate{cert}
}
// Set Server Name Indication (SNI), if enabled by connection parameters.
// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
// or IPv6).
if sslsni == "1" && net.ParseIP(host) == nil {
tlsConfig.ServerName = host
}
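
The rule is observable from the outside by parsing a hostname versus an IP literal (master side only; the beta side of this diff has no SNI support):

```go
package main

import (
	"fmt"
	"log"

	"github.com/jackc/pgx/v5/pgconn"
)

func main() {
	byName, err := pgconn.ParseConfig("postgres://jack:secret@sni.test:5432/mydb?sslmode=require")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(byName.TLSConfig.ServerName) // sni.test

	byIP, err := pgconn.ParseConfig("postgres://jack:secret@1.1.1.1:5432/mydb?sslmode=require")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(byIP.TLSConfig.ServerName == "") // true: no SNI for IP literals
}
```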
switch sslmode {
case "allow":
return []*tls.Config{nil, tlsConfig}, nil
@ -850,8 +773,7 @@ func parsePort(s string) (uint16, error) {
}
func makeDefaultDialer() *net.Dialer {
// rely on GOLANG KeepAlive settings
return &net.Dialer{}
return &net.Dialer{KeepAlive: 5 * time.Minute}
}
func makeDefaultResolver() *net.Resolver {
@ -875,75 +797,75 @@ func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
return d.DialContext
}
// ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-write.
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll()
if err != nil {
return err
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result[0].Rows[0][0]) == "on" {
if string(result.Rows[0][0]) == "on" {
return errors.New("read only connection")
}
return nil
}
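
These validators are meant to be assigned to Config.ValidateConnect, where each candidate host is checked after connecting and a failure falls through to the next fallback; a sketch:

```go
package main

import (
	"context"
	"log"

	"github.com/jackc/pgx/v5/pgconn"
)

func main() {
	cfg, err := pgconn.ParseConfig("postgres://jack:secret@foo,bar/mydb")
	if err != nil {
		log.Fatal(err)
	}
	// Reject read-only servers; Connect moves on to the next fallback.
	cfg.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite

	conn, err := pgconn.ConnectConfig(context.Background(), cfg)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(context.Background())
}
```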
// ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-only.
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll()
if err != nil {
return err
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result[0].Rows[0][0]) != "on" {
if string(result.Rows[0][0]) != "on" {
return errors.New("connection is not read only")
}
return nil
}
// ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=standby.
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
if err != nil {
return err
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result[0].Rows[0][0]) != "t" {
if string(result.Rows[0][0]) != "t" {
return errors.New("server is not in hot standby mode")
}
return nil
}
// ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=primary.
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
if err != nil {
return err
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result[0].Rows[0][0]) == "t" {
if string(result.Rows[0][0]) == "t" {
return errors.New("server is in standby mode")
}
return nil
}
// ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=prefer-standby.
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
if err != nil {
return err
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result[0].Rows[0][0]) != "t" {
if string(result.Rows[0][0]) != "t" {
return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
}

View File

@ -4,11 +4,10 @@ import (
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"os"
"os/user"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"
"time"
@ -18,25 +17,8 @@ import (
"github.com/stretchr/testify/require"
)
func skipOnWindows(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("FIXME: skipping on Windows, investigate why this test fails in CI environment")
}
}
func getDefaultPort(t *testing.T) uint16 {
if envPGPORT := os.Getenv("PGPORT"); envPGPORT != "" {
p, err := strconv.ParseUint(envPGPORT, 10, 16)
require.NoError(t, err)
return uint16(p)
}
return 5432
}
func getDefaultUser(t *testing.T) string {
if pguser := os.Getenv("PGUSER"); pguser != "" {
return pguser
}
func TestParseConfig(t *testing.T) {
t.Parallel()
var osUserName string
osUser, err := user.Current()
@ -50,20 +32,10 @@ func getDefaultUser(t *testing.T) string {
}
}
return osUserName
}
func TestParseConfig(t *testing.T) {
skipOnWindows(t)
t.Parallel()
config, err := pgconn.ParseConfig("")
require.NoError(t, err)
defaultHost := config.Host
defaultUser := getDefaultUser(t)
defaultPort := getDefaultPort(t)
tests := []struct {
name string
connString string
@ -81,11 +53,10 @@ func TestParseConfig(t *testing.T) {
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
},
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
{
&pgconn.FallbackConfig{
Host: "localhost",
Port: 5432,
TLSConfig: nil,
@ -118,12 +89,11 @@ func TestParseConfig(t *testing.T) {
TLSConfig: nil,
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
{
&pgconn.FallbackConfig{
Host: "localhost",
Port: 5432,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
},
},
},
@ -141,11 +111,10 @@ func TestParseConfig(t *testing.T) {
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
},
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
{
&pgconn.FallbackConfig{
Host: "localhost",
Port: 5432,
TLSConfig: nil,
@ -164,7 +133,6 @@ func TestParseConfig(t *testing.T) {
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
},
RuntimeParams: map[string]string{},
},
@ -180,7 +148,6 @@ func TestParseConfig(t *testing.T) {
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
},
RuntimeParams: map[string]string{},
},
@ -231,7 +198,7 @@ func TestParseConfig(t *testing.T) {
name: "database url missing user and password",
connString: "postgres://localhost:5432/mydb?sslmode=disable",
config: &pgconn.Config{
User: defaultUser,
User: osUserName,
Host: "localhost",
Port: 5432,
Database: "mydb",
@ -256,9 +223,9 @@ func TestParseConfig(t *testing.T) {
name: "database url unix domain socket host",
connString: "postgres:///foo?host=/tmp",
config: &pgconn.Config{
User: defaultUser,
User: osUserName,
Host: "/tmp",
Port: defaultPort,
Port: 5432,
Database: "foo",
TLSConfig: nil,
RuntimeParams: map[string]string{},
@ -268,9 +235,9 @@ func TestParseConfig(t *testing.T) {
name: "database url unix domain socket host on windows",
connString: "postgres:///foo?host=C:\\tmp",
config: &pgconn.Config{
User: defaultUser,
User: osUserName,
Host: "C:\\tmp",
Port: defaultPort,
Port: 5432,
Database: "foo",
TLSConfig: nil,
RuntimeParams: map[string]string{},
@ -280,9 +247,9 @@ func TestParseConfig(t *testing.T) {
name: "database url dbname",
connString: "postgres://localhost/?dbname=foo&sslmode=disable",
config: &pgconn.Config{
User: defaultUser,
User: osUserName,
Host: "localhost",
Port: defaultPort,
Port: 5432,
Database: "foo",
TLSConfig: nil,
RuntimeParams: map[string]string{},
@ -330,14 +297,14 @@ func TestParseConfig(t *testing.T) {
config: &pgconn.Config{
User: "jack",
Host: "2001:db8::1",
Port: defaultPort,
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "Key/value everything",
name: "DSN everything",
connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5",
config: &pgconn.Config{
User: "jack",
@ -354,7 +321,7 @@ func TestParseConfig(t *testing.T) {
},
},
{
name: "Key/value with escaped single quote",
name: "DSN with escaped single quote",
connString: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb sslmode=disable",
config: &pgconn.Config{
User: "jack's",
@ -367,7 +334,7 @@ func TestParseConfig(t *testing.T) {
},
},
{
name: "Key/value with escaped backslash",
name: "DSN with escaped backslash",
connString: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb sslmode=disable",
config: &pgconn.Config{
User: "jack",
@ -380,48 +347,48 @@ func TestParseConfig(t *testing.T) {
},
},
{
name: "Key/value with single quoted values",
name: "DSN with single quoted values",
connString: "user='jack' host='localhost' dbname='mydb' sslmode='disable'",
config: &pgconn.Config{
User: "jack",
Host: "localhost",
Port: defaultPort,
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "Key/value with single quoted value with escaped single quote",
name: "DSN with single quoted value with escaped single quote",
connString: "user='jack\\'s' host='localhost' dbname='mydb' sslmode='disable'",
config: &pgconn.Config{
User: "jack's",
Host: "localhost",
Port: defaultPort,
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "Key/value with empty single quoted value",
name: "DSN with empty single quoted value",
connString: "user='jack' password='' host='localhost' dbname='mydb' sslmode='disable'",
config: &pgconn.Config{
User: "jack",
Host: "localhost",
Port: defaultPort,
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "Key/value with space between key and value",
name: "DSN with space between key and value",
connString: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb' sslmode='disable'",
config: &pgconn.Config{
User: "jack",
Host: "localhost",
Port: defaultPort,
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
@ -434,19 +401,19 @@ func TestParseConfig(t *testing.T) {
User: "jack",
Password: "secret",
Host: "foo",
Port: defaultPort,
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
{
&pgconn.FallbackConfig{
Host: "bar",
Port: defaultPort,
Port: 5432,
TLSConfig: nil,
},
{
&pgconn.FallbackConfig{
Host: "baz",
Port: defaultPort,
Port: 5432,
TLSConfig: nil,
},
},
@ -464,12 +431,12 @@ func TestParseConfig(t *testing.T) {
TLSConfig: nil,
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
{
&pgconn.FallbackConfig{
Host: "bar",
Port: 2,
TLSConfig: nil,
},
{
&pgconn.FallbackConfig{
Host: "baz",
Port: 3,
TLSConfig: nil,
@ -492,7 +459,7 @@ func TestParseConfig(t *testing.T) {
},
},
{
name: "Key/value multiple hosts one port",
name: "DSN multiple hosts one port",
connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable",
config: &pgconn.Config{
User: "jack",
@ -503,12 +470,12 @@ func TestParseConfig(t *testing.T) {
TLSConfig: nil,
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
{
&pgconn.FallbackConfig{
Host: "bar",
Port: 5432,
TLSConfig: nil,
},
{
&pgconn.FallbackConfig{
Host: "baz",
Port: 5432,
TLSConfig: nil,
@ -517,7 +484,7 @@ func TestParseConfig(t *testing.T) {
},
},
{
name: "Key/value multiple hosts multiple ports",
name: "DSN multiple hosts multiple ports",
connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 dbname=mydb sslmode=disable",
config: &pgconn.Config{
User: "jack",
@ -528,12 +495,12 @@ func TestParseConfig(t *testing.T) {
TLSConfig: nil,
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
{
&pgconn.FallbackConfig{
Host: "bar",
Port: 2,
TLSConfig: nil,
},
{
&pgconn.FallbackConfig{
Host: "baz",
Port: 3,
TLSConfig: nil,
@ -542,47 +509,44 @@ func TestParseConfig(t *testing.T) {
},
},
{
name: "multiple hosts and fallback tls",
name: "multiple hosts and fallback tsl",
connString: "user=jack password=secret host=foo,bar,baz dbname=mydb sslmode=prefer",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "foo",
Port: defaultPort,
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "foo",
},
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
{
&pgconn.FallbackConfig{
Host: "foo",
Port: defaultPort,
Port: 5432,
TLSConfig: nil,
},
{
&pgconn.FallbackConfig{
Host: "bar",
Port: defaultPort,
Port: 5432,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "bar",
}},
{
&pgconn.FallbackConfig{
Host: "bar",
Port: defaultPort,
Port: 5432,
TLSConfig: nil,
},
{
&pgconn.FallbackConfig{
Host: "baz",
Port: defaultPort,
Port: 5432,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "baz",
}},
{
&pgconn.FallbackConfig{
Host: "baz",
Port: defaultPort,
Port: 5432,
TLSConfig: nil,
},
},
@ -684,82 +648,6 @@ func TestParseConfig(t *testing.T) {
RuntimeParams: map[string]string{},
},
},
{
name: "SNI is set by default",
connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "sni.test",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "sni.test",
},
RuntimeParams: map[string]string{},
},
},
{
name: "SNI is not set for IPv4",
connString: "postgres://jack:secret@1.1.1.1:5432/mydb?sslmode=require",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "1.1.1.1",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
{
name: "SNI is not set for IPv6",
connString: "postgres://jack:secret@[::1]:5432/mydb?sslmode=require",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "::1",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
{
name: "SNI is not set when disabled (URL-style)",
connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require&sslsni=0",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "sni.test",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
{
name: "SNI is not set when disabled (key/value style)",
connString: "user=jack password=secret host=sni.test dbname=mydb sslmode=require sslsni=0",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "sni.test",
Port: defaultPort,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
}
for i, tt := range tests {
@ -773,18 +661,18 @@ func TestParseConfig(t *testing.T) {
}
// https://github.com/jackc/pgconn/issues/47
func TestParseConfigKVWithTrailingEmptyEqualDoesNotPanic(t *testing.T) {
func TestParseConfigDSNWithTrailingEmptyEqualDoesNotPanic(t *testing.T) {
_, err := pgconn.ParseConfig("host= user= password= port= database=")
require.NoError(t, err)
}
func TestParseConfigKVLeadingEqual(t *testing.T) {
func TestParseConfigDSNLeadingEqual(t *testing.T) {
_, err := pgconn.ParseConfig("= user=jack")
require.Error(t, err)
}
// https://github.com/jackc/pgconn/issues/49
func TestParseConfigKVTrailingBackslash(t *testing.T) {
func TestParseConfigDSNTrailingBackslash(t *testing.T) {
_, err := pgconn.ParseConfig(`x=x\`)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid backslash")
@ -817,7 +705,7 @@ func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) {
}
func TestConfigCopyCanBeUsedToConnect(t *testing.T) {
connString := os.Getenv("PGX_TEST_DATABASE")
connString := os.Getenv("PGX_TEST_CONN_STRING")
original, err := pgconn.ParseConfig(connString)
require.NoError(t, err)
@ -932,7 +820,20 @@ func TestParseConfigEnvLibpq(t *testing.T) {
}
}
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT", "PGSSLSNI", "PGTZ", "PGOPTIONS"}
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"}
savedEnv := make(map[string]string)
for _, n := range pgEnvvars {
savedEnv[n] = os.Getenv(n)
}
defer func() {
for k, v := range savedEnv {
err := os.Setenv(k, v)
if err != nil {
t.Fatalf("Unable to restore environment: %v", err)
}
}
}()
tests := []struct {
name string
@ -952,7 +853,7 @@ func TestParseConfigEnvLibpq(t *testing.T) {
},
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
{
&pgconn.FallbackConfig{
Host: "123.123.123.123",
Port: 5432,
TLSConfig: nil,
@ -971,8 +872,6 @@ func TestParseConfigEnvLibpq(t *testing.T) {
"PGCONNECT_TIMEOUT": "10",
"PGSSLMODE": "disable",
"PGAPPNAME": "pgxtest",
"PGTZ": "America/New_York",
"PGOPTIONS": "-c search_path=myschema",
},
config: &pgconn.Config{
Host: "123.123.123.123",
@ -982,31 +881,20 @@ func TestParseConfigEnvLibpq(t *testing.T) {
Password: "baz",
ConnectTimeout: 10 * time.Second,
TLSConfig: nil,
RuntimeParams: map[string]string{"application_name": "pgxtest", "timezone": "America/New_York", "options": "-c search_path=myschema"},
},
},
{
name: "SNI can be disabled via environment variable",
envvars: map[string]string{
"PGHOST": "test.foo",
"PGSSLMODE": "require",
"PGSSLSNI": "0",
},
config: &pgconn.Config{
User: osUserName,
Host: "test.foo",
Port: 5432,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
RuntimeParams: map[string]string{"application_name": "pgxtest"},
},
},
}
for i, tt := range tests {
for _, env := range pgEnvvars {
t.Setenv(env, tt.envvars[env])
for _, n := range pgEnvvars {
err := os.Unsetenv(n)
require.NoError(t, err)
}
for k, v := range tt.envvars {
err := os.Setenv(k, v)
require.NoError(t, err)
}
config, err := pgconn.ParseConfig("")
@ -1019,14 +907,18 @@ func TestParseConfigEnvLibpq(t *testing.T) {
}
func TestParseConfigReadsPgPassfile(t *testing.T) {
skipOnWindows(t)
t.Parallel()
tfName := filepath.Join(t.TempDir(), "config")
err := os.WriteFile(tfName, []byte("test1:5432:curlydb:curly:nyuknyuknyuk"), 0600)
tf, err := ioutil.TempFile("", "")
require.NoError(t, err)
connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tfName)
defer tf.Close()
defer os.Remove(tf.Name())
_, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk"))
require.NoError(t, err)
connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name())
expected := &pgconn.Config{
User: "curly",
Password: "nyuknyuknyuk",
@ -1044,12 +936,15 @@ func TestParseConfigReadsPgPassfile(t *testing.T) {
}
func TestParseConfigReadsPgServiceFile(t *testing.T) {
skipOnWindows(t)
t.Parallel()
tfName := filepath.Join(t.TempDir(), "config")
tf, err := ioutil.TempFile("", "")
require.NoError(t, err)
err := os.WriteFile(tfName, []byte(`
defer tf.Close()
defer os.Remove(tf.Name())
_, err = tf.Write([]byte(`
[abc]
host=abc.example.com
port=9999
@ -1061,11 +956,9 @@ host = def.example.com
dbname = defdb
user = defuser
application_name = spaced string
`), 0600)
`))
require.NoError(t, err)
defaultPort := getDefaultPort(t)
tests := []struct {
name string
connString string
@ -1073,7 +966,7 @@ application_name = spaced string
}{
{
name: "abc",
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tfName, "abc"),
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "abc"),
config: &pgconn.Config{
Host: "abc.example.com",
Database: "abcdb",
@ -1081,11 +974,10 @@ application_name = spaced string
Port: 9999,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "abc.example.com",
},
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
{
&pgconn.FallbackConfig{
Host: "abc.example.com",
Port: 9999,
TLSConfig: nil,
@ -1095,21 +987,20 @@ application_name = spaced string
},
{
name: "def",
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tfName, "def"),
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "def"),
config: &pgconn.Config{
Host: "def.example.com",
Port: defaultPort,
Port: 5432,
Database: "defdb",
User: "defuser",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "def.example.com",
},
RuntimeParams: map[string]string{"application_name": "spaced string"},
Fallbacks: []*pgconn.FallbackConfig{
{
&pgconn.FallbackConfig{
Host: "def.example.com",
Port: defaultPort,
Port: 5432,
TLSConfig: nil,
},
},
@ -1117,7 +1008,7 @@ application_name = spaced string
},
{
name: "conn string has precedence",
connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tfName, "abc"),
connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tf.Name(), "abc"),
config: &pgconn.Config{
Host: "other.example.com",
Database: "abcdb",

View File

@ -5,8 +5,8 @@ nearly the same level is the C library libpq.
Establishing a Connection
Use Connect to establish a connection. It accepts a connection string in URL or keyword/value format and will read the
environment for libpq style environment variables.
Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
libpq style environment variables.
Executing a Query
@ -20,17 +20,13 @@ result. The ReadAll method reads all query results into memory.
Pipeline Mode
Pipeline mode allows sending queries without having read the results of previously sent queries. It allows control of
exactly how many network round trips occur and when they happen.
Pipeline mode allows sending queries without having read the results of previously sent queries. It allows
control of exactly how many and when network round trips occur.
Context Support
All potentially blocking operations take a context.Context. The default behavior when a context is canceled is for the
method to immediately return. In most circumstances, this will also close the underlying connection. This behavior can
be customized by using BuildContextWatcherHandler on the Config to create a ctxwatch.Handler with different behavior.
This can be especially useful when queries are frequently canceled and the overhead of creating new connections is
a problem. DeadlineContextWatcherHandler and CancelRequestContextWatcherHandler can be used to introduce a delay before
interrupting the query in such a way as to close the connection.
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
method immediately returns. In most circumstances, this will close the underlying connection.
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
client to abort.
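
A sketch of the context-cancellation path described above (connection string env var as used by the tests in this diff):

```go
package main

import (
	"context"
	"log"
	"os"
	"time"

	"github.com/jackc/pgx/v5/pgconn"
)

func main() {
	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(context.Background())

	// On timeout the method returns promptly; by default the underlying
	// connection is closed as a side effect. CancelRequest is the
	// non-destructive alternative.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	res := conn.ExecParams(ctx, "select pg_sleep(10)", nil, nil, nil, nil).Read()
	log.Println(res.Err)
}
```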

View File

@ -12,15 +12,14 @@ import (
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
func SafeToRetry(err error) bool {
var retryableErr interface{ SafeToRetry() bool }
if errors.As(err, &retryableErr) {
return retryableErr.SafeToRetry()
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
return e.SafeToRetry()
}
return false
}
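
A sketch of the intended call-site pattern, retrying only when no data reached the server:

```go
package pgutil

import (
	"context"

	"github.com/jackc/pgx/v5/pgconn"
)

// execWithRetry retries sql only when pgconn guarantees the server never
// received it, so the statement cannot execute twice.
func execWithRetry(ctx context.Context, conn *pgconn.PgConn, sql string) error {
	for {
		_, err := conn.Exec(ctx, sql).ReadAll()
		if err == nil || !pgconn.SafeToRetry(err) {
			return err
		}
	}
}
```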
// Timeout checks if err was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
// context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
func Timeout(err error) bool {
var timeoutErr *errTimeout
return errors.As(err, &timeoutErr)
@ -30,24 +29,23 @@ func Timeout(err error) bool {
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
SeverityUnlocalized string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe *PgError) Error() string {
@ -59,37 +57,22 @@ func (pe *PgError) SQLState() string {
return pe.Code
}
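
Callers usually reach PgError through errors.As; for example, checking for a unique-violation SQLSTATE:

```go
package pgutil

import (
	"errors"

	"github.com/jackc/pgx/v5/pgconn"
)

// isUniqueViolation reports whether err is a server error with SQLSTATE 23505.
func isUniqueViolation(err error) bool {
	var pgErr *pgconn.PgError
	return errors.As(err, &pgErr) && pgErr.Code == "23505"
}
```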
// ConnectError is the error returned when a connection attempt fails.
type ConnectError struct {
Config *Config // The configuration that was used in the connection attempt.
type connectError struct {
config *Config
msg string
err error
}
func (e *ConnectError) Error() string {
prefix := fmt.Sprintf("failed to connect to `user=%s database=%s`:", e.Config.User, e.Config.Database)
details := e.err.Error()
if strings.Contains(details, "\n") {
return prefix + "\n\t" + strings.ReplaceAll(details, "\n", "\n\t")
} else {
return prefix + " " + details
func (e *connectError) Error() string {
sb := &strings.Builder{}
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
if e.err != nil {
fmt.Fprintf(sb, " (%s)", e.err.Error())
}
return sb.String()
}
func (e *ConnectError) Unwrap() error {
return e.err
}
type perDialConnectError struct {
address string
originalHostname string
err error
}
func (e *perDialConnectError) Error() string {
return fmt.Sprintf("%s (%s): %s", e.address, e.originalHostname, e.err.Error())
}
func (e *perDialConnectError) Unwrap() error {
func (e *connectError) Unwrap() error {
return e.err
}
@ -105,47 +88,29 @@ func (e *connLockError) Error() string {
return e.status
}
// ParseConfigError is the error returned when a connection string cannot be parsed.
type ParseConfigError struct {
ConnString string // The connection string that could not be parsed.
type parseConfigError struct {
connString string
msg string
err error
}
func NewParseConfigError(conn, msg string, err error) error {
return &ParseConfigError{
ConnString: conn,
msg: msg,
err: err,
}
}
func (e *ParseConfigError) Error() string {
// Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better to only
// return a static string. That would ensure that the error message cannot leak a password. The ConnString field would
// allow access to the original string if desired and Unwrap would allow access to the underlying error.
connString := redactPW(e.ConnString)
func (e *parseConfigError) Error() string {
connString := redactPW(e.connString)
if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
}
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
}
func (e *ParseConfigError) Unwrap() error {
func (e *parseConfigError) Unwrap() error {
return e.err
}
func normalizeTimeoutError(ctx context.Context, err error) error {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
if ctx.Err() == context.Canceled {
// Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error.
return context.Canceled
} else if ctx.Err() == context.DeadlineExceeded {
return &errTimeout{err: ctx.Err()}
} else {
return &errTimeout{err: netErr}
}
// preferContextOverNetTimeoutError returns an errTimeout wrapping ctx.Err() when ctx.Err() is non-nil and err is a
// net.Error with Timeout() == true. Otherwise it returns err unchanged.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return &errTimeout{err: ctx.Err()}
}
return err
}
@ -219,10 +184,10 @@ func redactPW(connString string) string {
return redactURL(u)
}
}
quotedKV := regexp.MustCompile(`password='[^']*'`)
connString = quotedKV.ReplaceAllLiteralString(connString, "password=xxxxx")
plainKV := regexp.MustCompile(`password=[^ ]*`)
connString = plainKV.ReplaceAllLiteralString(connString, "password=xxxxx")
quotedDSN := regexp.MustCompile(`password='[^']*'`)
connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
plainDSN := regexp.MustCompile(`password=[^ ]*`)
connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
brokenURL := regexp.MustCompile(`:[^:@]+?@`)
connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
return connString
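
The redaction is purely regex-based. An illustrative, self-contained sketch of the same three patterns (the real code routes well-formed URLs through url.Parse and redactURL first):

```go
package main

import (
	"fmt"
	"regexp"
)

func main() {
	redact := func(s string) string {
		s = regexp.MustCompile(`password='[^']*'`).ReplaceAllLiteralString(s, "password=xxxxx")
		s = regexp.MustCompile(`password=[^ ]*`).ReplaceAllLiteralString(s, "password=xxxxx")
		s = regexp.MustCompile(`:[^:@]+?@`).ReplaceAllLiteralString(s, ":xxxxxx@")
		return s
	}
	fmt.Println(redact("host=host password='pass word' user=user")) // host=host password=xxxxx user=user
	fmt.Println(redact("postgres://jack:secret@host/db"))           // postgres://jack:xxxxxx@host/db
}
```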

View File

@ -19,18 +19,18 @@ func TestConfigError(t *testing.T) {
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg",
},
{
name: "keyword/value with password unquoted",
name: "dsn with password unquoted",
err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil),
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
},
{
name: "keyword/value with password quoted",
name: "dsn with password quoted",
err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil),
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
},
{
name: "weird url",
err: pgconn.NewParseConfigError("postgresql://foo::password@host:1:", "msg", nil),
err: pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil),
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg",
},
{

View File

@ -1,3 +1,11 @@
// File export_test exports some methods for better testing.
package pgconn
func NewParseConfigError(conn, msg string, err error) error {
return &parseConfigError{
connString: conn,
msg: msg,
err: err,
}
}

View File

@ -12,19 +12,19 @@ import (
)
func closeConn(t testing.TB, conn *pgconn.PgConn) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
require.NoError(t, conn.Close(ctx))
select {
case <-conn.CleanupDone():
case <-time.After(30 * time.Second):
case <-time.After(5 * time.Second):
t.Fatal("Connection cleanup exceeded maximum time")
}
}
// Do a simple query to ensure the connection is still usable
func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read()
cancel()

View File

@ -1,139 +0,0 @@
// Package bgreader provides an io.Reader that can optionally buffer reads in the background.
package bgreader
import (
"io"
"sync"
"github.com/jackc/pgx/v5/internal/iobufpool"
)
const (
StatusStopped = iota
StatusRunning
StatusStopping
)
// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use.
type BGReader struct {
r io.Reader
cond *sync.Cond
status int32
readResults []readResult
}
type readResult struct {
buf *[]byte
err error
}
// Start starts the background reader. If the background reader is already running this is a no-op. The background
// reader will stop automatically when the underlying reader returns an error.
func (r *BGReader) Start() {
r.cond.L.Lock()
defer r.cond.L.Unlock()
switch r.status {
case StatusStopped:
r.status = StatusRunning
go r.bgRead()
case StatusRunning:
// no-op
case StatusStopping:
r.status = StatusRunning
}
}
// Stop tells the background reader to stop after the in-progress Read returns. It is safe to call Stop when the
// background reader is not running.
func (r *BGReader) Stop() {
r.cond.L.Lock()
defer r.cond.L.Unlock()
switch r.status {
case StatusStopped:
// no-op
case StatusRunning:
r.status = StatusStopping
case StatusStopping:
// no-op
}
}
// Status returns the current status of the background reader.
func (r *BGReader) Status() int32 {
r.cond.L.Lock()
defer r.cond.L.Unlock()
return r.status
}
func (r *BGReader) bgRead() {
keepReading := true
for keepReading {
buf := iobufpool.Get(8192)
n, err := r.r.Read(*buf)
*buf = (*buf)[:n]
r.cond.L.Lock()
r.readResults = append(r.readResults, readResult{buf: buf, err: err})
if r.status == StatusStopping || err != nil {
r.status = StatusStopped
keepReading = false
}
r.cond.L.Unlock()
r.cond.Broadcast()
}
}
// Read implements the io.Reader interface.
func (r *BGReader) Read(p []byte) (int, error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
if len(r.readResults) > 0 {
return r.readFromReadResults(p)
}
// There are no unread background read results and the background reader is stopped.
if r.status == StatusStopped {
return r.r.Read(p)
}
// Wait for results from the background reader
for len(r.readResults) == 0 {
r.cond.Wait()
}
return r.readFromReadResults(p)
}
// readFromReadResults reads a result previously read by the background reader. r.cond.L must be held.
func (r *BGReader) readFromReadResults(p []byte) (int, error) {
buf := r.readResults[0].buf
var err error
n := copy(p, *buf)
if n == len(*buf) {
err = r.readResults[0].err
iobufpool.Put(buf)
if len(r.readResults) == 1 {
r.readResults = nil
} else {
r.readResults = r.readResults[1:]
}
} else {
*buf = (*buf)[n:]
r.readResults[0].buf = buf
}
return n, err
}
func New(r io.Reader) *BGReader {
return &BGReader{
r: r,
cond: &sync.Cond{
L: &sync.Mutex{},
},
}
}
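
Usage sketch (note the import path is internal to pgx, so this only builds inside the module): the wrapper behaves like the underlying io.Reader until Start is called, after which reads are staged by a background goroutine.

```go
package main

import (
	"bytes"
	"fmt"
	"io"

	"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
)

func main() {
	bgr := bgreader.New(bytes.NewReader([]byte("foo bar baz")))
	bgr.Start() // reads now happen eagerly in a background goroutine
	buf, err := io.ReadAll(bgr)
	fmt.Println(string(buf), err) // foo bar baz <nil>
}
```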

View File

@ -1,140 +0,0 @@
package bgreader_test
import (
"bytes"
"errors"
"io"
"math/rand"
"testing"
"time"
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
"github.com/stretchr/testify/require"
)
func TestBGReaderReadWhenStopped(t *testing.T) {
r := bytes.NewReader([]byte("foo bar baz"))
bgr := bgreader.New(r)
buf, err := io.ReadAll(bgr)
require.NoError(t, err)
require.Equal(t, []byte("foo bar baz"), buf)
}
func TestBGReaderReadWhenStarted(t *testing.T) {
r := bytes.NewReader([]byte("foo bar baz"))
bgr := bgreader.New(r)
bgr.Start()
buf, err := io.ReadAll(bgr)
require.NoError(t, err)
require.Equal(t, []byte("foo bar baz"), buf)
}
type mockReadFunc func(p []byte) (int, error)
type mockReader struct {
readFuncs []mockReadFunc
}
func (r *mockReader) Read(p []byte) (int, error) {
if len(r.readFuncs) == 0 {
return 0, io.EOF
}
fn := r.readFuncs[0]
r.readFuncs = r.readFuncs[1:]
return fn(p)
}
func TestBGReaderReadWaitsForBackgroundRead(t *testing.T) {
rr := &mockReader{
readFuncs: []mockReadFunc{
func(p []byte) (int, error) { time.Sleep(1 * time.Second); return copy(p, []byte("foo")), nil },
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
func(p []byte) (int, error) { return copy(p, []byte("baz")), nil },
},
}
bgr := bgreader.New(rr)
bgr.Start()
buf := make([]byte, 3)
n, err := bgr.Read(buf)
require.NoError(t, err)
require.EqualValues(t, 3, n)
require.Equal(t, []byte("foo"), buf)
}
func TestBGReaderErrorWhenStarted(t *testing.T) {
rr := &mockReader{
readFuncs: []mockReadFunc{
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
},
}
bgr := bgreader.New(rr)
bgr.Start()
buf, err := io.ReadAll(bgr)
require.Equal(t, []byte("foobarbaz"), buf)
require.EqualError(t, err, "oops")
}
func TestBGReaderErrorWhenStopped(t *testing.T) {
rr := &mockReader{
readFuncs: []mockReadFunc{
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
},
}
bgr := bgreader.New(rr)
buf, err := io.ReadAll(bgr)
require.Equal(t, []byte("foobarbaz"), buf)
require.EqualError(t, err, "oops")
}
type numberReader struct {
v uint8
rng *rand.Rand
}
func (nr *numberReader) Read(p []byte) (int, error) {
n := nr.rng.Intn(len(p))
for i := 0; i < n; i++ {
p[i] = nr.v
nr.v++
}
return n, nil
}
// TestBGReaderStress stress tests BGReader by reading a lot of bytes in random sizes while randomly starting and
// stopping the background worker from other goroutines.
func TestBGReaderStress(t *testing.T) {
nr := &numberReader{rng: rand.New(rand.NewSource(0))}
bgr := bgreader.New(nr)
bytesRead := 0
var expected uint8
buf := make([]byte, 10_000)
rng := rand.New(rand.NewSource(0))
for bytesRead < 1_000_000 {
randomNumber := rng.Intn(100)
switch {
case randomNumber < 10:
go bgr.Start()
case randomNumber < 20:
go bgr.Stop()
default:
n, err := bgr.Read(buf)
require.NoError(t, err)
for i := 0; i < n; i++ {
require.Equal(t, expected, buf[i])
expected++
}
bytesRead += n
}
}
}

View File

@ -8,8 +8,9 @@ import (
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
// time.
type ContextWatcher struct {
handler Handler
unwatchChan chan struct{}
onCancel func()
onUnwatchAfterCancel func()
unwatchChan chan struct{}
lock sync.Mutex
watchInProgress bool
@ -19,10 +20,11 @@ type ContextWatcher struct {
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
// onUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
// onCancel called.
func NewContextWatcher(handler Handler) *ContextWatcher {
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
cw := &ContextWatcher{
handler: handler,
unwatchChan: make(chan struct{}),
onCancel: onCancel,
onUnwatchAfterCancel: onUnwatchAfterCancel,
unwatchChan: make(chan struct{}),
}
return cw
@ -44,7 +46,7 @@ func (cw *ContextWatcher) Watch(ctx context.Context) {
go func() {
select {
case <-ctx.Done():
cw.handler.HandleCancel(ctx)
cw.onCancel()
cw.onCancelWasCalled = true
<-cw.unwatchChan
case <-cw.unwatchChan:
@ -64,17 +66,8 @@ func (cw *ContextWatcher) Unwatch() {
if cw.watchInProgress {
cw.unwatchChan <- struct{}{}
if cw.onCancelWasCalled {
cw.handler.HandleUnwatchAfterCancel()
cw.onUnwatchAfterCancel()
}
cw.watchInProgress = false
}
}
type Handler interface {
// HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the
// context that was canceled.
HandleCancel(canceledCtx context.Context)
// HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched.
HandleUnwatchAfterCancel()
}
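
A minimal Handler implementation for illustration (the testHandler in the next file does the same through function fields):

```go
package ctxdemo

import (
	"context"
	"log"
)

// loggingHandler satisfies ctxwatch.Handler and merely records events.
type loggingHandler struct{}

func (loggingHandler) HandleCancel(ctx context.Context) {
	log.Printf("watched context canceled: %v", ctx.Err())
}

func (loggingHandler) HandleUnwatchAfterCancel() {
	log.Print("unwatched after a cancellation; cleanup could run here")
}
```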

View File

@ -6,32 +6,17 @@ import (
"testing"
"time"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
"github.com/stretchr/testify/require"
)
type testHandler struct {
handleCancel func(context.Context)
handleUnwatchAfterCancel func()
}
func (h *testHandler) HandleCancel(ctx context.Context) {
h.handleCancel(ctx)
}
func (h *testHandler) HandleUnwatchAfterCancel() {
h.handleUnwatchAfterCancel()
}
func TestContextWatcherContextCancelled(t *testing.T) {
canceledChan := make(chan struct{})
cleanupCalled := false
cw := ctxwatch.NewContextWatcher(&testHandler{
handleCancel: func(context.Context) {
canceledChan <- struct{}{}
}, handleUnwatchAfterCancel: func() {
cleanupCalled = true
},
cw := ctxwatch.NewContextWatcher(func() {
canceledChan <- struct{}{}
}, func() {
cleanupCalled = true
})
ctx, cancel := context.WithCancel(context.Background())
@ -49,13 +34,11 @@ func TestContextWatcherContextCancelled(t *testing.T) {
require.True(t, cleanupCalled, "Cleanup func was not called")
}
func TestContextWatcherUnwatchedBeforeContextCancelled(t *testing.T) {
cw := ctxwatch.NewContextWatcher(&testHandler{
handleCancel: func(context.Context) {
t.Error("cancel func should not have been called")
}, handleUnwatchAfterCancel: func() {
t.Error("cleanup func should not have been called")
},
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
cw := ctxwatch.NewContextWatcher(func() {
t.Error("cancel func should not have been called")
}, func() {
t.Error("cleanup func should not have been called")
})
ctx, cancel := context.WithCancel(context.Background())
@ -65,12 +48,11 @@ func TestContextWatcherUnwatchedBeforeContextCancelled(t *testing.T) {
}
func TestContextWatcherMultipleWatchPanics(t *testing.T) {
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cw.Watch(ctx)
defer cw.Unwatch()
ctx2, cancel2 := context.WithCancel(context.Background())
defer cancel2()
@ -78,7 +60,7 @@ func TestContextWatcherMultipleWatchPanics(t *testing.T) {
}
func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
cw.Unwatch() // unwatch when not / never watching
ctx, cancel := context.WithCancel(context.Background())
@ -89,7 +71,7 @@ func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
}
func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
@ -105,12 +87,10 @@ func TestContextWatcherStress(t *testing.T) {
var cancelFuncCalls int64
var cleanupFuncCalls int64
cw := ctxwatch.NewContextWatcher(&testHandler{
handleCancel: func(context.Context) {
atomic.AddInt64(&cancelFuncCalls, 1)
}, handleUnwatchAfterCancel: func() {
atomic.AddInt64(&cleanupFuncCalls, 1)
},
cw := ctxwatch.NewContextWatcher(func() {
atomic.AddInt64(&cancelFuncCalls, 1)
}, func() {
atomic.AddInt64(&cleanupFuncCalls, 1)
})
cycleCount := 100000
@ -123,9 +103,7 @@ func TestContextWatcherStress(t *testing.T) {
}
// Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix.
if i%333 == 0 {
// On Windows, Sleep takes more time than expected, so we get here less frequently to avoid
// making CI take a long time.
if i%3 == 0 {
time.Sleep(time.Nanosecond)
}
@ -153,7 +131,7 @@ func TestContextWatcherStress(t *testing.T) {
}
func BenchmarkContextWatcherUncancellable(b *testing.B) {
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
for i := 0; i < b.N; i++ {
cw.Watch(context.Background())
@ -162,7 +140,7 @@ func BenchmarkContextWatcherUncancellable(b *testing.B) {
}
func BenchmarkContextWatcherCancelled(b *testing.B) {
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
for i := 0; i < b.N; i++ {
ctx, cancel := context.WithCancel(context.Background())
@ -173,7 +151,7 @@ func BenchmarkContextWatcherCancelled(b *testing.B) {
}
func BenchmarkContextWatcherCancellable(b *testing.B) {
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

View File

@ -63,7 +63,7 @@ func (c *PgConn) gssAuth() error {
Data: nextData,
}
c.frontend.Send(gssResponse)
err = c.flushWithPotentialWriteReadDeadlock()
err = c.frontend.Flush()
if err != nil {
return err
}

File diff suppressed because it is too large

View File

@ -14,7 +14,7 @@ import (
)
func TestConnStress(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer closeConn(t, pgConn)

File diff suppressed because it is too large

View File

@ -1,6 +1,6 @@
# pgproto3
Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3.
Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more.

View File

@ -35,10 +35,11 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
return finishMessage(dst, sp)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
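
Both variants implement the same wire framing: one type byte, then a big-endian Int32 length that counts itself plus the body. beginMessage/finishMessage centralize the length back-patching that the older code does inline with sp and pgio.SetInt32; a dependency-free sketch of the pattern:

```go
package wiredemo

import "encoding/binary"

// appendMessage frames body as a PostgreSQL v3 message: a 1-byte type
// followed by a 4-byte length that includes the length field itself.
func appendMessage(dst []byte, typ byte, body []byte) []byte {
	dst = append(dst, typ)
	sp := len(dst)                // remember where the length goes
	dst = append(dst, 0, 0, 0, 0) // placeholder, patched below
	dst = append(dst, body...)
	binary.BigEndian.PutUint32(dst[sp:], uint32(len(dst)-sp)) // 4 + len(body)
	return dst
}
```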

View File

@ -27,10 +27,11 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
return nil
}
func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 4)
dst = pgio.AppendUint32(dst, AuthTypeGSS)
return finishMessage(dst, sp)
return dst
}
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {

View File

@ -31,11 +31,12 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
return nil
}
func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
dst = append(dst, a.Data...)
return finishMessage(dst, sp)
return dst
}
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {

View File

@ -38,11 +38,12 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 12)
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
dst = append(dst, src.Salt[:]...)
return finishMessage(dst, sp)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -35,10 +35,11 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
func (src *AuthenticationOk) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendUint32(dst, AuthTypeOk)
return finishMessage(dst, sp)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -47,8 +47,10 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASL)
for _, s := range src.AuthMechanisms {
@ -57,7 +59,9 @@ func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
}
dst = append(dst, 0)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -38,11 +38,17 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
dst = append(dst, src.Data...)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -38,11 +38,17 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
dst = append(dst, src.Data...)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Unmarshaler.

View File

@ -16,8 +16,7 @@ type Backend struct {
// before it is actually transmitted (i.e. before Flush).
tracer *tracer
wbuf []byte
encodeError error
wbuf []byte
// Frontend message flyweights
bind Bind
@ -39,7 +38,6 @@ type Backend struct {
terminate Terminate
bodyLen int
maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error.
msgType byte
partialMsg bool
authType uint32
@ -56,21 +54,11 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
return &Backend{cr: cr, w: w}
}
// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error
// encountered will be returned from Flush.
// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
// called.
func (b *Backend) Send(msg BackendMessage) {
if b.encodeError != nil {
return
}
prevLen := len(b.wbuf)
newBuf, err := msg.Encode(b.wbuf)
if err != nil {
b.encodeError = err
return
}
b.wbuf = newBuf
b.wbuf = msg.Encode(b.wbuf)
if b.tracer != nil {
b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
}
@ -78,12 +66,6 @@ func (b *Backend) Send(msg BackendMessage) {
// Flush writes any pending messages to the frontend (i.e. the client).
func (b *Backend) Flush() error {
if err := b.encodeError; err != nil {
b.encodeError = nil
b.wbuf = b.wbuf[:0]
return &writeError{err: err, safeToRetry: true}
}
n, err := b.w.Write(b.wbuf)
const maxLen = 1024
@ -175,16 +157,7 @@ func (b *Backend) Receive() (FrontendMessage, error) {
}
b.msgType = header[0]
msgLength := int(binary.BigEndian.Uint32(header[1:]))
if msgLength < 4 {
return nil, fmt.Errorf("invalid message length: %d", msgLength)
}
b.bodyLen = msgLength - 4
if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen {
return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen}
}
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
b.partialMsg = true
}
@ -223,7 +196,7 @@ func (b *Backend) Receive() (FrontendMessage, error) {
case AuthTypeCleartextPassword, AuthTypeMD5Password:
fallthrough
default:
// to maintain backwards compatibility
// to maintain backwards compatability
msg = &PasswordMessage{}
}
case 'Q':
@ -260,11 +233,11 @@ func (b *Backend) Receive() (FrontendMessage, error) {
// contextual identification of FrontendMessages. For example, in the
// PG message flow documentation for PasswordMessage:
//
// Byte1('p')
// Byte1('p')
//
// Identifies the message as a password response. Note that this is also used for
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
// the context.
// Identifies the message as a password response. Note that this is also used for
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
// the context.
//
// Since the Frontend does not know about the state of a backend, it is important
// to call SetAuthType() after an authentication request is received by the Frontend.
@ -287,13 +260,3 @@ func (b *Backend) SetAuthType(authType uint32) error {
return nil
}
// SetMaxBodyLen sets the maximum length of a message body in octets.
// If a message body exceeds this length, Receive will return an error.
// This is useful for protecting against malicious clients that send
// large messages with the intent of causing memory exhaustion.
// The default value is 0.
// If maxBodyLen is 0, then no maximum is enforced.
func (b *Backend) SetMaxBodyLen(maxBodyLen int) {
b.maxBodyLen = maxBodyLen
}
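
Once Encode can fail, Send cannot blindly append: master parks the first encode failure in b.encodeError, later Sends become no-ops, and Flush reports the error, clears the buffer, and resets the state. A standalone sketch of that sticky-error buffering, independent of pgproto3's message types:

```
package main

import (
	"errors"
	"fmt"
)

// encoder is a stand-in for pgproto3's BackendMessage Encode method.
type encoder interface {
	Encode(dst []byte) ([]byte, error)
}

type sender struct {
	wbuf      []byte
	encodeErr error // first encode failure; reported by Flush
}

// Send buffers msg. After the first failure it becomes a no-op, so a
// half-encoded message can never reach the wire.
func (s *sender) Send(msg encoder) {
	if s.encodeErr != nil {
		return
	}
	buf, err := msg.Encode(s.wbuf)
	if err != nil {
		s.encodeErr = err
		return
	}
	s.wbuf = buf
}

// Flush surfaces any deferred encode error and discards the buffer,
// leaving the sender reusable (the hunk's safeToRetry intent).
func (s *sender) Flush() error {
	if err := s.encodeErr; err != nil {
		s.encodeErr = nil
		s.wbuf = s.wbuf[:0]
		return err
	}
	// ... write s.wbuf to the connection here ...
	s.wbuf = s.wbuf[:0]
	return nil
}

type badMsg struct{}

func (badMsg) Encode(dst []byte) ([]byte, error) { return nil, errors.New("too many parameters") }

func main() {
	s := &sender{}
	s.Send(badMsg{})
	fmt.Println(s.Flush()) // too many parameters
	fmt.Println(s.Flush()) // <nil>: error was consumed and state reset
}
```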

View File

@ -29,11 +29,12 @@ func (dst *BackendKeyData) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'K')
func (src *BackendKeyData) Encode(dst []byte) []byte {
dst = append(dst, 'K')
dst = pgio.AppendUint32(dst, 12)
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return finishMessage(dst, sp)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) {
"username": "tester",
},
}
dst, err := want.Encode([]byte{})
require.NoError(t, err)
dst := []byte{}
dst = want.Encode(dst)
server := &interruptReader{}
server.push(dst)
@ -120,21 +120,3 @@ func TestStartupMessage(t *testing.T) {
}
})
}
func TestBackendReceiveExceededMaxBodyLen(t *testing.T) {
t.Parallel()
server := &interruptReader{}
server.push([]byte{'Q', 0, 0, 10, 10})
backend := pgproto3.NewBackend(server, nil)
// Set max body len to 5
backend.SetMaxBodyLen(5)
// Receive regular msg
msg, err := backend.Receive()
assert.Nil(t, msg)
var invalidBodyLenErr *pgproto3.ExceededMaxBodyLenErr
assert.ErrorAs(t, err, &invalidBodyLenErr)
}
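
The payload in the removed test is worth decoding: 'Q' followed by the big-endian length 0x00000a0a (2570), so the implied body is 2566 bytes, far over the configured cap of 5, and Receive must fail before attempting to read it. A sketch of the header guard shown in the backend.go hunk above:

```
package main

import (
	"encoding/binary"
	"fmt"
)

// checkHeader mirrors the guard in Backend.Receive: the length field
// includes its own 4 bytes, so the body is the field minus 4. The real
// code returns *ExceededMaxBodyLenErr; a plain error stands in here.
func checkHeader(header [5]byte, maxBodyLen int) (msgType byte, bodyLen int, err error) {
	msgType = header[0]
	msgLen := int(binary.BigEndian.Uint32(header[1:]))
	if msgLen < 4 {
		return 0, 0, fmt.Errorf("invalid message length: %d", msgLen)
	}
	bodyLen = msgLen - 4
	if maxBodyLen > 0 && bodyLen > maxBodyLen {
		return 0, 0, fmt.Errorf("message body len %d exceeds max %d", bodyLen, maxBodyLen)
	}
	return msgType, bodyLen, nil
}

func main() {
	// The bytes pushed by the removed test.
	fmt.Println(checkHeader([5]byte{'Q', 0, 0, 10, 10}, 5))
}
```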

View File

@ -5,9 +5,7 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -110,25 +108,21 @@ func (dst *Bind) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Bind) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'B')
func (src *Bind) Encode(dst []byte) []byte {
dst = append(dst, 'B')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.DestinationPortal...)
dst = append(dst, 0)
dst = append(dst, src.PreparedStatement...)
dst = append(dst, 0)
if len(src.ParameterFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many parameter format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
for _, fc := range src.ParameterFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
if len(src.Parameters) > math.MaxUint16 {
return nil, errors.New("too many parameters")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
for _, p := range src.Parameters {
if p == nil {
@ -140,15 +134,14 @@ func (src *Bind) Encode(dst []byte) ([]byte, error) {
dst = append(dst, p...)
}
if len(src.ResultFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many result format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
for _, fc := range src.ResultFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
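
The counts in Bind travel as uint16 on the wire, which is why the master side rejects more than math.MaxUint16 entries; without the guard, the uint16 conversion wraps silently and the declared count no longer matches the bytes that follow. A two-line demonstration of the hazard (not pgx code):

```
package main

import "fmt"

func main() {
	params := 70000             // more than a Bind message can declare
	fmt.Println(uint16(params)) // 4464: silent wraparound, corrupting the message
}
```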

View File

@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '2', 0, 0, 0, 4), nil
func (src *BindComplete) Encode(dst []byte) []byte {
return append(dst, '2', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -1,20 +0,0 @@
package pgproto3_test
import (
"testing"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/stretchr/testify/require"
)
func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) {
t.Parallel()
// Maximum allowed size.
_, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil)
require.NoError(t, err)
// 1 byte too big
_, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil)
require.Error(t, err)
}

View File

@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
func (src *CancelRequest) Encode(dst []byte) []byte {
dst = pgio.AppendInt32(dst, 16)
dst = pgio.AppendInt32(dst, cancelRequestCode)
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst, nil
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
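
CancelRequest is the odd one out: it carries no type byte, only a fixed 16-byte frame of length, magic request code, process ID, and secret key, so its Encode cannot fail on either side of this diff. A sketch of the frame, with example ID/key values:

```
package main

import (
	"encoding/binary"
	"fmt"
)

const cancelRequestCode = 80877102 // (1234 << 16) | 5678, per the protocol docs

func main() {
	dst := make([]byte, 0, 16)
	dst = binary.BigEndian.AppendUint32(dst, 16) // whole-message length
	dst = binary.BigEndian.AppendUint32(dst, cancelRequestCode)
	dst = binary.BigEndian.AppendUint32(dst, 42)   // ProcessID (example value)
	dst = binary.BigEndian.AppendUint32(dst, 4242) // SecretKey (example value)
	fmt.Printf("% x\n", dst)
}
```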

View File

@ -14,7 +14,7 @@ import (
type chunkReader struct {
r io.Reader
buf *[]byte
buf []byte
rp, wp int // buf read position and write position
minBufSize int
@ -45,7 +45,7 @@ func newChunkReader(r io.Reader, minBufSize int) *chunkReader {
func (r *chunkReader) Next(n int) (buf []byte, err error) {
// Reset the buffer if it is empty
if r.rp == r.wp {
if len(*r.buf) != r.minBufSize {
if len(r.buf) != r.minBufSize {
iobufpool.Put(r.buf)
r.buf = iobufpool.Get(r.minBufSize)
}
@ -55,15 +55,15 @@ func (r *chunkReader) Next(n int) (buf []byte, err error) {
// n bytes already in buf
if (r.wp - r.rp) >= n {
buf = (*r.buf)[r.rp : r.rp+n : r.rp+n]
buf = r.buf[r.rp : r.rp+n : r.rp+n]
r.rp += n
return buf, err
}
// buf is smaller than requested number of bytes
if len(*r.buf) < n {
if len(r.buf) < n {
bigBuf := iobufpool.Get(n)
r.wp = copy((*bigBuf), (*r.buf)[r.rp:r.wp])
r.wp = copy(bigBuf, r.buf[r.rp:r.wp])
r.rp = 0
iobufpool.Put(r.buf)
r.buf = bigBuf
@ -71,20 +71,20 @@ func (r *chunkReader) Next(n int) (buf []byte, err error) {
// buf is large enough, but need to shift filled area to start to make enough contiguous space
minReadCount := n - (r.wp - r.rp)
if (len(*r.buf) - r.wp) < minReadCount {
r.wp = copy((*r.buf), (*r.buf)[r.rp:r.wp])
if (len(r.buf) - r.wp) < minReadCount {
r.wp = copy(r.buf, r.buf[r.rp:r.wp])
r.rp = 0
}
// Read at least the required number of bytes from the underlying io.Reader
readBytesCount, err := io.ReadAtLeast(r.r, (*r.buf)[r.wp:], minReadCount)
readBytesCount, err := io.ReadAtLeast(r.r, r.buf[r.wp:], minReadCount)
r.wp += readBytesCount
// fmt.Println("read", n)
if err != nil {
return nil, err
}
buf = (*r.buf)[r.rp : r.rp+n : r.rp+n]
buf = r.buf[r.rp : r.rp+n : r.rp+n]
r.rp += n
return buf, nil
}
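
The chunkReader change from buf []byte to buf *[]byte follows from iobufpool handing out *[]byte: storing a bare slice in a sync.Pool-backed pool allocates a fresh slice header on every Put, the concern behind staticcheck's SA6002. A minimal stand-in pool along those assumed lines, ignoring the real pool's size tiers:

```
package main

import (
	"fmt"
	"sync"
)

// Storing *[]byte keeps the slice header off the heap on each Put.
var pool = sync.Pool{
	New: func() any { b := make([]byte, 8192); return &b },
}

func get(size int) *[]byte {
	buf := pool.Get().(*[]byte)
	if cap(*buf) < size {
		pool.Put(buf) // too small for this request; leave it for others
		b := make([]byte, size)
		return &b
	}
	*buf = (*buf)[:size]
	return buf
}

func put(buf *[]byte) { pool.Put(buf) }

func main() {
	buf := get(16)
	copy(*buf, "hello")
	fmt.Println(len(*buf), cap(*buf) >= 16) // 16 true
	put(buf)
}
```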

View File

@ -17,7 +17,7 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(n1, src[0:2]) {
if bytes.Compare(n1, src[0:2]) != 0 {
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1)
}
@ -25,11 +25,11 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(n2, src[2:4]) {
if bytes.Compare(n2, src[2:4]) != 0 {
t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2)
}
if !bytes.Equal((*r.buf)[:len(src)], src) {
if bytes.Compare(r.buf[:len(src)], src) != 0 {
t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf)
}
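
The test tweak replaces bytes.Compare(a, b) != 0 with !bytes.Equal(a, b): Equal states the intent directly and skips computing an ordering. For illustration:

```
package main

import (
	"bytes"
	"fmt"
)

func main() {
	a, b := []byte("abc"), []byte("abd")
	fmt.Println(bytes.Equal(a, b))        // false
	fmt.Println(bytes.Compare(a, b) != 0) // true: same test, but via an ordering
}
```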

View File

@ -4,6 +4,8 @@ import (
"bytes"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
type Close struct {
@ -35,12 +37,18 @@ func (dst *Close) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Close) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'C')
func (src *Close) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...)
dst = append(dst, 0)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CloseComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '3', 0, 0, 0, 4), nil
func (src *CloseComplete) Encode(dst []byte) []byte {
return append(dst, '3', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -3,6 +3,8 @@ package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type CommandComplete struct {
@ -29,11 +31,17 @@ func (dst *CommandComplete) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CommandComplete) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'C')
func (src *CommandComplete) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.CommandTag...)
dst = append(dst, 0)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -5,7 +5,6 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -45,18 +44,19 @@ func (dst *CopyBothResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'W')
func (src *CopyBothResponse) Encode(dst []byte) []byte {
dst = append(dst, 'W')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -5,7 +5,6 @@ import (
"github.com/jackc/pgx/v5/pgproto3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestEncodeDecode(t *testing.T) {
@ -14,7 +13,6 @@ func TestEncodeDecode(t *testing.T) {
err := dstResp.Decode(srcBytes[5:])
assert.NoError(t, err, "No errors on decode")
dstBytes := []byte{}
dstBytes, err = dstResp.Encode(dstBytes)
require.NoError(t, err)
dstBytes = dstResp.Encode(dstBytes)
assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match")
}

View File

@ -3,6 +3,8 @@ package pgproto3
import (
"encoding/hex"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type CopyData struct {
@ -23,10 +25,11 @@ func (dst *CopyData) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyData) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'd')
func (src *CopyData) Encode(dst []byte) []byte {
dst = append(dst, 'd')
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
dst = append(dst, src.Data...)
return finishMessage(dst, sp)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyDone) Encode(dst []byte) ([]byte, error) {
return append(dst, 'c', 0, 0, 0, 4), nil
func (src *CopyDone) Encode(dst []byte) []byte {
return append(dst, 'c', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -3,6 +3,8 @@ package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type CopyFail struct {
@ -26,11 +28,17 @@ func (dst *CopyFail) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyFail) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'f')
func (src *CopyFail) Encode(dst []byte) []byte {
dst = append(dst, 'f')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Message...)
dst = append(dst, 0)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -5,7 +5,6 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -45,19 +44,20 @@ func (dst *CopyInResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'G')
func (src *CopyInResponse) Encode(dst []byte) []byte {
dst = append(dst, 'G')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -5,7 +5,6 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -44,20 +43,21 @@ func (dst *CopyOutResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'H')
func (src *CopyOutResponse) Encode(dst []byte) []byte {
dst = append(dst, 'H')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -4,8 +4,6 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -65,12 +63,11 @@ func (dst *DataRow) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *DataRow) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'D')
func (src *DataRow) Encode(dst []byte) []byte {
dst = append(dst, 'D')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
if len(src.Values) > math.MaxUint16 {
return nil, errors.New("too many values")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
for _, v := range src.Values {
if v == nil {
@ -82,7 +79,9 @@ func (src *DataRow) Encode(dst []byte) ([]byte, error) {
dst = append(dst, v...)
}
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -4,6 +4,8 @@ import (
"bytes"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
type Describe struct {
@ -35,12 +37,18 @@ func (dst *Describe) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Describe) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'D')
func (src *Describe) Encode(dst []byte) []byte {
dst = append(dst, 'D')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...)
dst = append(dst, 0)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -1,7 +1,7 @@
// Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3.
// Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
//
// The primary interfaces are Frontend and Backend. They correspond to a client and server respectively. Messages are
// sent with Send (or a specialized Send variant). Messages are automatically buffered to minimize small writes. Call
// sent with Send (or a specialized Send variant). Messages are automatically bufferred to minimize small writes. Call
// Flush to ensure a message has actually been sent.
//
// The Trace method of Frontend and Backend can be used to examine the wire-level message traffic. It outputs in a

View File

@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) {
return append(dst, 'I', 0, 0, 0, 4), nil
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
return append(dst, 'I', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -2,6 +2,7 @@ package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"strconv"
)
@ -110,113 +111,120 @@ func (dst *ErrorResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'E')
dst = src.appendFields(dst)
return finishMessage(dst, sp)
func (src *ErrorResponse) Encode(dst []byte) []byte {
return append(dst, src.marshalBinary('E')...)
}
func (src *ErrorResponse) appendFields(dst []byte) []byte {
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte(typeByte)
buf.Write(bigEndian.Uint32(0))
if src.Severity != "" {
dst = append(dst, 'S')
dst = append(dst, src.Severity...)
dst = append(dst, 0)
buf.WriteByte('S')
buf.WriteString(src.Severity)
buf.WriteByte(0)
}
if src.SeverityUnlocalized != "" {
dst = append(dst, 'V')
dst = append(dst, src.SeverityUnlocalized...)
dst = append(dst, 0)
buf.WriteByte('V')
buf.WriteString(src.SeverityUnlocalized)
buf.WriteByte(0)
}
if src.Code != "" {
dst = append(dst, 'C')
dst = append(dst, src.Code...)
dst = append(dst, 0)
buf.WriteByte('C')
buf.WriteString(src.Code)
buf.WriteByte(0)
}
if src.Message != "" {
dst = append(dst, 'M')
dst = append(dst, src.Message...)
dst = append(dst, 0)
buf.WriteByte('M')
buf.WriteString(src.Message)
buf.WriteByte(0)
}
if src.Detail != "" {
dst = append(dst, 'D')
dst = append(dst, src.Detail...)
dst = append(dst, 0)
buf.WriteByte('D')
buf.WriteString(src.Detail)
buf.WriteByte(0)
}
if src.Hint != "" {
dst = append(dst, 'H')
dst = append(dst, src.Hint...)
dst = append(dst, 0)
buf.WriteByte('H')
buf.WriteString(src.Hint)
buf.WriteByte(0)
}
if src.Position != 0 {
dst = append(dst, 'P')
dst = append(dst, strconv.Itoa(int(src.Position))...)
dst = append(dst, 0)
buf.WriteByte('P')
buf.WriteString(strconv.Itoa(int(src.Position)))
buf.WriteByte(0)
}
if src.InternalPosition != 0 {
dst = append(dst, 'p')
dst = append(dst, strconv.Itoa(int(src.InternalPosition))...)
dst = append(dst, 0)
buf.WriteByte('p')
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
buf.WriteByte(0)
}
if src.InternalQuery != "" {
dst = append(dst, 'q')
dst = append(dst, src.InternalQuery...)
dst = append(dst, 0)
buf.WriteByte('q')
buf.WriteString(src.InternalQuery)
buf.WriteByte(0)
}
if src.Where != "" {
dst = append(dst, 'W')
dst = append(dst, src.Where...)
dst = append(dst, 0)
buf.WriteByte('W')
buf.WriteString(src.Where)
buf.WriteByte(0)
}
if src.SchemaName != "" {
dst = append(dst, 's')
dst = append(dst, src.SchemaName...)
dst = append(dst, 0)
buf.WriteByte('s')
buf.WriteString(src.SchemaName)
buf.WriteByte(0)
}
if src.TableName != "" {
dst = append(dst, 't')
dst = append(dst, src.TableName...)
dst = append(dst, 0)
buf.WriteByte('t')
buf.WriteString(src.TableName)
buf.WriteByte(0)
}
if src.ColumnName != "" {
dst = append(dst, 'c')
dst = append(dst, src.ColumnName...)
dst = append(dst, 0)
buf.WriteByte('c')
buf.WriteString(src.ColumnName)
buf.WriteByte(0)
}
if src.DataTypeName != "" {
dst = append(dst, 'd')
dst = append(dst, src.DataTypeName...)
dst = append(dst, 0)
buf.WriteByte('d')
buf.WriteString(src.DataTypeName)
buf.WriteByte(0)
}
if src.ConstraintName != "" {
dst = append(dst, 'n')
dst = append(dst, src.ConstraintName...)
dst = append(dst, 0)
buf.WriteByte('n')
buf.WriteString(src.ConstraintName)
buf.WriteByte(0)
}
if src.File != "" {
dst = append(dst, 'F')
dst = append(dst, src.File...)
dst = append(dst, 0)
buf.WriteByte('F')
buf.WriteString(src.File)
buf.WriteByte(0)
}
if src.Line != 0 {
dst = append(dst, 'L')
dst = append(dst, strconv.Itoa(int(src.Line))...)
dst = append(dst, 0)
buf.WriteByte('L')
buf.WriteString(strconv.Itoa(int(src.Line)))
buf.WriteByte(0)
}
if src.Routine != "" {
dst = append(dst, 'R')
dst = append(dst, src.Routine...)
dst = append(dst, 0)
buf.WriteByte('R')
buf.WriteString(src.Routine)
buf.WriteByte(0)
}
for k, v := range src.UnknownFields {
dst = append(dst, k)
dst = append(dst, v...)
dst = append(dst, 0)
buf.WriteByte(k)
buf.WriteByte(0)
buf.WriteString(v)
buf.WriteByte(0)
}
dst = append(dst, 0)
buf.WriteByte(0)
return dst
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes()
}
// MarshalJSON implements encoding/json.Marshaler.
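
Both versions emit the same wire format for ErrorResponse: a sequence of one-byte field codes ('S', 'C', 'M', ...), each followed by a NUL-terminated string, closed by a single zero byte; master appends into dst while beta routes through a bytes.Buffer and back-fills the length. (Note the two sides disagree on UnknownFields: the beta loop writes an extra zero byte between key and value.) A minimal decoder for the field list, assuming the 'E' type byte and length have already been stripped:

```
package main

import (
	"bytes"
	"fmt"
)

// parseFields decodes an ErrorResponse/NoticeResponse body: repeated
// (code byte, NUL-terminated string) pairs ending with a lone 0 byte.
func parseFields(body []byte) (map[byte]string, error) {
	fields := map[byte]string{}
	for len(body) > 0 {
		code := body[0]
		if code == 0 {
			return fields, nil // terminator
		}
		body = body[1:]
		i := bytes.IndexByte(body, 0)
		if i < 0 {
			return nil, fmt.Errorf("field %q is not NUL-terminated", code)
		}
		fields[code] = string(body[:i])
		body = body[i+1:]
	}
	return nil, fmt.Errorf("missing terminator")
}

func main() {
	body := []byte("SERROR\x00C42703\x00Mcolumn does not exist\x00\x00")
	fmt.Println(parseFields(body))
}
```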

View File

@ -46,7 +46,7 @@ func (p *PgFortuneBackend) Run() error {
return fmt.Errorf("error generating query response: %w", err)
}
buf := mustEncode((&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
buf := (&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
{
Name: []byte("fortune"),
TableOID: 0,
@ -56,10 +56,10 @@ func (p *PgFortuneBackend) Run() error {
TypeModifier: -1,
Format: 0,
},
}}).Encode(nil))
buf = mustEncode((&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf))
buf = mustEncode((&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf))
buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
}}).Encode(nil)
buf = (&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf)
buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
_, err = p.conn.Write(buf)
if err != nil {
return fmt.Errorf("error writing query response: %w", err)
@ -80,8 +80,8 @@ func (p *PgFortuneBackend) handleStartup() error {
switch startupMessage.(type) {
case *pgproto3.StartupMessage:
buf := mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))
buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
buf := (&pgproto3.AuthenticationOk{}).Encode(nil)
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
_, err = p.conn.Write(buf)
if err != nil {
return fmt.Errorf("error sending ready for query: %w", err)
@ -102,10 +102,3 @@ func (p *PgFortuneBackend) handleStartup() error {
func (p *PgFortuneBackend) Close() error {
return p.conn.Close()
}
func mustEncode(buf []byte, err error) []byte {
if err != nil {
panic(err)
}
return buf
}

View File

@ -36,12 +36,19 @@ func (dst *Execute) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Execute) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'E')
func (src *Execute) Encode(dst []byte) []byte {
dst = append(dst, 'E')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Portal...)
dst = append(dst, 0)
dst = pgio.AppendUint32(dst, src.MaxRows)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
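
Execute uses the same framing as Close and Describe: type byte, length, NUL-terminated portal name, then a uint32 row limit where zero means fetch all. A worked encoding of the smallest case, the unnamed portal with no limit:

```
package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	dst := []byte{'E'}
	sp := len(dst)
	dst = append(dst, 0, 0, 0, 0)               // length placeholder
	dst = append(dst, 0)                        // "" portal name + NUL
	dst = binary.BigEndian.AppendUint32(dst, 0) // MaxRows: 0 = fetch all
	binary.BigEndian.PutUint32(dst[sp:], uint32(len(dst)-sp))
	fmt.Printf("% x\n", dst) // 45 00 00 00 09 00 00 00 00 00
}
```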

Some files were not shown because too many files have changed in this diff.