mirror of
https://github.com/jackc/pgx.git
synced 2025-09-04 19:37:10 +00:00
Compare commits
No commits in common. "master" and "v3.4.0" have entirely different histories.
54
.github/ISSUE_TEMPLATE/bug_report.md
vendored
54
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@ -1,54 +0,0 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: bug
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce the behavior:
|
||||
|
||||
If possible, please provide runnable example such as:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
func main() {
|
||||
conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer conn.Close(context.Background())
|
||||
|
||||
// Your code here...
|
||||
}
|
||||
```
|
||||
|
||||
Please run your example with the race detector enabled. For example, `go run -race main.go` or `go test -race`.
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Actual behavior**
|
||||
A clear and concise description of what actually happened.
|
||||
|
||||
**Version**
|
||||
- Go: `$ go version` -> [e.g. go version go1.18.3 darwin/amd64]
|
||||
- PostgreSQL: `$ psql --no-psqlrc --tuples-only -c 'select version()'` -> [e.g. PostgreSQL 14.4 on x86_64-apple-darwin21.5.0, compiled by Apple clang version 13.1.6 (clang-1316.0.21.2.5), 64-bit]
|
||||
- pgx: `$ grep 'github.com/jackc/pgx/v[0-9]' go.mod` -> [e.g. v4.16.1]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@ -1,20 +0,0 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
10
.github/ISSUE_TEMPLATE/other-issues.md
vendored
10
.github/ISSUE_TEMPLATE/other-issues.md
vendored
@ -1,10 +0,0 @@
|
||||
---
|
||||
name: Other issues
|
||||
about: Any issue that is not a bug or a feature request
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
Please describe the issue in detail. If this is a question about how to use pgx please use discussions instead.
|
156
.github/workflows/ci.yml
vendored
156
.github/workflows/ci.yml
vendored
@ -1,156 +0,0 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master]
|
||||
pull_request:
|
||||
branches: [master]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ["1.23", "1.24"]
|
||||
pg-version: [13, 14, 15, 16, 17, cockroachdb]
|
||||
include:
|
||||
- pg-version: 13
|
||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
|
||||
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
|
||||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||
pgx-ssl-password: certpw
|
||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||
- pg-version: 14
|
||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
|
||||
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
|
||||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||
pgx-ssl-password: certpw
|
||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||
- pg-version: 15
|
||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
|
||||
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
|
||||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||
pgx-ssl-password: certpw
|
||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||
- pg-version: 16
|
||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
|
||||
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
|
||||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||
pgx-ssl-password: certpw
|
||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||
- pg-version: 17
|
||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
|
||||
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
|
||||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||
pgx-ssl-password: certpw
|
||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||
- pg-version: cockroachdb
|
||||
pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
|
||||
- name: Setup database server for testing
|
||||
run: ci/setup_test.bash
|
||||
env:
|
||||
PGVERSION: ${{ matrix.pg-version }}
|
||||
|
||||
# - name: Setup upterm session
|
||||
# uses: lhotari/action-upterm@v1
|
||||
# with:
|
||||
# ## limits ssh access and adds the ssh public key for the user which triggered the workflow
|
||||
# limit-access-to-actor: true
|
||||
# env:
|
||||
# PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
|
||||
# PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
|
||||
# PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
|
||||
# PGX_TEST_SCRAM_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-scram-password-conn-string }}
|
||||
# PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
|
||||
# PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}
|
||||
# PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
|
||||
# PGX_SSL_PASSWORD: ${{ matrix.pgx-ssl-password }}
|
||||
# PGX_TEST_TLS_CLIENT_CONN_STRING: ${{ matrix.pgx-test-tls-client-conn-string }}
|
||||
|
||||
- name: Check formatting
|
||||
run: |
|
||||
gofmt -l -s -w .
|
||||
git status
|
||||
git diff --exit-code
|
||||
|
||||
- name: Test
|
||||
# parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner.
|
||||
run: go test -parallel=1 -race ./...
|
||||
env:
|
||||
PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
|
||||
PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
|
||||
PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
|
||||
PGX_TEST_SCRAM_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-scram-password-conn-string }}
|
||||
PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
|
||||
PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}
|
||||
# TestConnectTLS fails. However, it succeeds if I connect to the CI server with upterm and run it. Give up on that test for now.
|
||||
# PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
|
||||
PGX_SSL_PASSWORD: ${{ matrix.pgx-ssl-password }}
|
||||
PGX_TEST_TLS_CLIENT_CONN_STRING: ${{ matrix.pgx-test-tls-client-conn-string }}
|
||||
|
||||
test-windows:
|
||||
name: Test Windows
|
||||
runs-on: windows-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ["1.23", "1.24"]
|
||||
|
||||
steps:
|
||||
- name: Setup PostgreSQL
|
||||
id: postgres
|
||||
uses: ikalnytskyi/action-setup-postgres@v4
|
||||
with:
|
||||
database: pgx_test
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
|
||||
- name: Initialize test database
|
||||
run: |
|
||||
psql -f testsetup/postgresql_setup.sql pgx_test
|
||||
env:
|
||||
PGSERVICE: ${{ steps.postgres.outputs.service-name }}
|
||||
shell: bash
|
||||
|
||||
- name: Test
|
||||
# parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner.
|
||||
run: go test -parallel=1 -race ./...
|
||||
env:
|
||||
PGX_TEST_DATABASE: ${{ steps.postgres.outputs.connection-uri }}
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -21,7 +21,5 @@ _testmain.go
|
||||
|
||||
*.exe
|
||||
|
||||
conn_config_test.go
|
||||
.envrc
|
||||
/.testdb
|
||||
|
||||
.DS_Store
|
||||
|
@ -1,21 +0,0 @@
|
||||
# See for configurations: https://golangci-lint.run/usage/configuration/
|
||||
version: 2
|
||||
|
||||
# See: https://golangci-lint.run/usage/formatters/
|
||||
formatters:
|
||||
default: none
|
||||
enable:
|
||||
- gofmt # https://pkg.go.dev/cmd/gofmt
|
||||
- gofumpt # https://github.com/mvdan/gofumpt
|
||||
|
||||
settings:
|
||||
gofmt:
|
||||
simplify: true # Simplify code: gofmt with `-s` option.
|
||||
|
||||
gofumpt:
|
||||
# Module path which contains the source code being formatted.
|
||||
# Default: ""
|
||||
module-path: github.com/jackc/pgx/v5 # Should match with module in go.mod
|
||||
# Choose whether to use the extra rules.
|
||||
# Default: false
|
||||
extra-rules: true
|
33
.travis.yml
Normal file
33
.travis.yml
Normal file
@ -0,0 +1,33 @@
|
||||
language: go
|
||||
|
||||
go:
|
||||
- 1.x
|
||||
- tip
|
||||
|
||||
# Derived from https://github.com/lib/pq/blob/master/.travis.yml
|
||||
before_install:
|
||||
- ./travis/before_install.bash
|
||||
|
||||
env:
|
||||
global:
|
||||
- PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
matrix:
|
||||
- CRATEVERSION=2.1
|
||||
- PGVERSION=10
|
||||
- PGVERSION=9.6
|
||||
- PGVERSION=9.5
|
||||
- PGVERSION=9.4
|
||||
- PGVERSION=9.3
|
||||
|
||||
before_script:
|
||||
- ./travis/before_script.bash
|
||||
|
||||
install:
|
||||
- ./travis/install.bash
|
||||
|
||||
script:
|
||||
- ./travis/script.bash
|
||||
|
||||
matrix:
|
||||
allow_failures:
|
||||
- go: tip
|
744
CHANGELOG.md
744
CHANGELOG.md
@ -1,462 +1,384 @@
|
||||
# 5.7.5 (May 17, 2025)
|
||||
|
||||
* Support sslnegotiation connection option (divyam234)
|
||||
* Update golang.org/x/crypto to v0.37.0. This placates security scanners that were unable to see that pgx did not use the behavior affected by https://pkg.go.dev/vuln/GO-2025-3487.
|
||||
* TraceLog now logs Acquire and Release at the debug level (dave sinclair)
|
||||
* Add support for PGTZ environment variable
|
||||
* Add support for PGOPTIONS environment variable
|
||||
* Unpin memory used by Rows quicker
|
||||
* Remove PlanScan memoization. This resolves a rare issue where scanning could be broken for one type by first scanning another. The problem was in the memoization system and benchmarking revealed that memoization was not providing any meaningful benefit.
|
||||
|
||||
# 5.7.4 (March 24, 2025)
|
||||
|
||||
* Fix / revert change to scanning JSON `null` (Felix Röhrich)
|
||||
|
||||
# 5.7.3 (March 21, 2025)
|
||||
|
||||
* Expose EmptyAcquireWaitTime in pgxpool.Stat (vamshiaruru32)
|
||||
* Improve SQL sanitizer performance (ninedraft)
|
||||
* Fix Scan confusion with json(b), sql.Scanner, and automatic dereferencing (moukoublen, felix-roehrich)
|
||||
* Fix Values() for xml type always returning nil instead of []byte
|
||||
* Add ability to send Flush message in pipeline mode (zenkovev)
|
||||
* Fix pgtype.Timestamp's JSON behavior to match PostgreSQL (pconstantinou)
|
||||
* Better error messages when scanning structs (logicbomb)
|
||||
* Fix handling of error on batch write (bonnefoa)
|
||||
* Match libpq's connection fallback behavior more closely (felix-roehrich)
|
||||
* Add MinIdleConns to pgxpool (djahandarie)
|
||||
|
||||
# 5.7.2 (December 21, 2024)
|
||||
|
||||
* Fix prepared statement already exists on batch prepare failure
|
||||
* Add commit query to tx options (Lucas Hild)
|
||||
* Fix pgtype.Timestamp json unmarshal (Shean de Montigny-Desautels)
|
||||
* Add message body size limits in frontend and backend (zene)
|
||||
* Add xid8 type
|
||||
* Ensure planning encodes and scans cannot infinitely recurse
|
||||
* Implement pgtype.UUID.String() (Konstantin Grachev)
|
||||
* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev)
|
||||
* Update golang.org/x/crypto
|
||||
* Fix json(b) columns prefer sql.Scanner interface like database/sql (Ludovico Russo)
|
||||
|
||||
# 5.7.1 (September 10, 2024)
|
||||
|
||||
* Fix data race in tracelog.TraceLog
|
||||
* Update puddle to v2.2.2. This removes the import of nanotime via linkname.
|
||||
* Update golang.org/x/crypto and golang.org/x/text
|
||||
|
||||
# 5.7.0 (September 7, 2024)
|
||||
|
||||
* Add support for sslrootcert=system (Yann Soubeyrand)
|
||||
* Add LoadTypes to load multiple types in a single SQL query (Nick Farrell)
|
||||
* Add XMLCodec supports encoding + scanning XML column type like json (nickcruess-soda)
|
||||
* Add MultiTrace (Stepan Rabotkin)
|
||||
* Add TraceLogConfig with customizable TimeKey (stringintech)
|
||||
* pgx.ErrNoRows wraps sql.ErrNoRows to aid in database/sql compatibility with native pgx functions (merlin)
|
||||
* Support scanning binary formatted uint32 into string / TextScanner (jennifersp)
|
||||
* Fix interval encoding to allow 0s and avoid extra spaces (Carlos Pérez-Aradros Herce)
|
||||
* Update pgservicefile - fixes panic when parsing invalid file
|
||||
* Better error message when reading past end of batch
|
||||
* Don't print url when url.Parse returns an error (Kevin Biju)
|
||||
* Fix snake case name normalization collision in RowToStructByName with db tag (nolandseigler)
|
||||
* Fix: Scan and encode types with underlying types of arrays
|
||||
|
||||
# 5.6.0 (May 25, 2024)
|
||||
|
||||
* Add StrictNamedArgs (Tomas Zahradnicek)
|
||||
* Add support for macaddr8 type (Carlos Pérez-Aradros Herce)
|
||||
* Add SeverityUnlocalized field to PgError / Notice
|
||||
* Performance optimization of RowToStructByPos/Name (Zach Olstein)
|
||||
* Allow customizing context canceled behavior for pgconn
|
||||
* Add ScanLocation to pgtype.Timestamp[tz]Codec
|
||||
* Add custom data to pgconn.PgConn
|
||||
* Fix ResultReader.Read() to handle nil values
|
||||
* Do not encode interval microseconds when they are 0 (Carlos Pérez-Aradros Herce)
|
||||
* pgconn.SafeToRetry checks for wrapped errors (tjasko)
|
||||
* Failed connection attempts include all errors
|
||||
* Optimize LargeObject.Read (Mitar)
|
||||
* Add tracing for connection acquire and release from pool (ngavinsir)
|
||||
* Fix encode driver.Valuer not called when nil
|
||||
* Add support for custom JSON marshal and unmarshal (Mitar)
|
||||
* Use Go default keepalive for TCP connections (Hans-Joachim Kliemeck)
|
||||
|
||||
# 5.5.5 (March 9, 2024)
|
||||
|
||||
Use spaces instead of parentheses for SQL sanitization.
|
||||
|
||||
This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as
|
||||
`set foo to $1` where the substitution is taking place in a location where an arbitrary expression is not allowed.
|
||||
|
||||
# 5.5.4 (March 4, 2024)
|
||||
|
||||
Fix CVE-2024-27304
|
||||
|
||||
SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer
|
||||
overflow in the calculated message size can cause the one large message to be sent as multiple messages under the
|
||||
attacker's control.
|
||||
|
||||
Thanks to Paul Gerste for reporting this issue.
|
||||
|
||||
* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix)
|
||||
* Fix simple protocol encoding of json.RawMessage
|
||||
* Fix *Pipeline.getResults should close pipeline on error
|
||||
* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman)
|
||||
* Fix deallocation of invalidated cached statements in a transaction
|
||||
* Handle invalid sslkey file
|
||||
* Fix scan float4 into sql.Scanner
|
||||
* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads.
|
||||
|
||||
# 5.5.3 (February 3, 2024)
|
||||
|
||||
* Fix: prepared statement already exists
|
||||
* Improve CopyFrom auto-conversion of text-ish values
|
||||
* Add ltree type support (Florent Viel)
|
||||
* Make some properties of Batch and QueuedQuery public (Pavlo Golub)
|
||||
* Add AppendRows function (Edoardo Spadolini)
|
||||
* Optimize convert UUID [16]byte to string (Kirill Malikov)
|
||||
* Fix: LargeObject Read and Write of more than ~1GB at a time (Mitar)
|
||||
|
||||
# 5.5.2 (January 13, 2024)
|
||||
|
||||
* Allow NamedArgs to start with underscore
|
||||
* pgproto3: Maximum message body length support (jeremy.spriet)
|
||||
* Upgrade golang.org/x/crypto to v0.17.0
|
||||
* Add snake_case support to RowToStructByName (Tikhon Fedulov)
|
||||
* Fix: update description cache after exec prepare (James Hartig)
|
||||
* Fix: pipeline checks if it is closed (James Hartig and Ryan Fowler)
|
||||
* Fix: normalize timeout / context errors during TLS startup (Samuel Stauffer)
|
||||
* Add OnPgError for easier centralized error handling (James Hartig)
|
||||
|
||||
# 5.5.1 (December 9, 2023)
|
||||
|
||||
* Add CopyFromFunc helper function. (robford)
|
||||
* Add PgConn.Deallocate method that uses PostgreSQL protocol Close message.
|
||||
* pgx uses new PgConn.Deallocate method. This allows deallocating statements to work in a failed transaction. This fixes a case where the prepared statement map could become invalid.
|
||||
* Fix: Prefer driver.Valuer over json.Marshaler for json fields. (Jacopo)
|
||||
* Fix: simple protocol SQL sanitizer previously panicked if an invalid $0 placeholder was used. This now returns an error instead. (maksymnevajdev)
|
||||
* Add pgtype.Numeric.ScanScientific (Eshton Robateau)
|
||||
|
||||
# 5.5.0 (November 4, 2023)
|
||||
|
||||
* Add CollectExactlyOneRow. (Julien GOTTELAND)
|
||||
* Add OpenDBFromPool to create *database/sql.DB from *pgxpool.Pool. (Lev Zakharov)
|
||||
* Prepare can automatically choose statement name based on sql. This makes it easier to explicitly manage prepared statements.
|
||||
* Statement cache now uses deterministic, stable statement names.
|
||||
* database/sql prepared statement names are deterministically generated.
|
||||
* Fix: SendBatch wasn't respecting context cancellation.
|
||||
* Fix: Timeout error from pipeline is now normalized.
|
||||
* Fix: database/sql encoding json.RawMessage to []byte.
|
||||
* CancelRequest: Wait for the cancel request to be acknowledged by the server. This should improve PgBouncer compatibility. (Anton Levakin)
|
||||
* stdlib: Use Ping instead of CheckConn in ResetSession
|
||||
* Add json.Marshaler and json.Unmarshaler for Float4, Float8 (Kirill Mironov)
|
||||
|
||||
# 5.4.3 (August 5, 2023)
|
||||
|
||||
* Fix: QCharArrayOID was defined with the wrong OID (Christoph Engelbert)
|
||||
* Fix: connect_timeout for sslmode=allow|prefer (smaher-edb)
|
||||
* Fix: pgxpool: background health check cannot overflow pool
|
||||
* Fix: Check for nil in defer when sending batch (recover properly from panic)
|
||||
* Fix: json scan of non-string pointer to pointer
|
||||
* Fix: zeronull.Timestamptz should use pgtype.Timestamptz
|
||||
* Fix: NewConnsCount was not correctly counting connections created by Acquire directly. (James Hartig)
|
||||
* RowTo(AddrOf)StructByPos ignores fields with "-" db tag
|
||||
* Optimization: improve text format numeric parsing (horpto)
|
||||
|
||||
# 5.4.2 (July 11, 2023)
|
||||
|
||||
* Fix: RowScanner errors are fatal to Rows
|
||||
* Fix: Enable failover efforts when pg_hba.conf disallows non-ssl connections (Brandon Kauffman)
|
||||
* Hstore text codec internal improvements (Evan Jones)
|
||||
* Fix: Stop timers for background reader when not in use. Fixes memory leak when closing connections (Adrian-Stefan Mares)
|
||||
* Fix: Stop background reader as soon as possible.
|
||||
* Add PgConn.SyncConn(). This combined with the above fix makes it safe to directly use the underlying net.Conn.
|
||||
|
||||
# 5.4.1 (June 18, 2023)
|
||||
|
||||
* Fix: concurrency bug with pgtypeDefaultMap and simple protocol (Lev Zakharov)
|
||||
* Add TxOptions.BeginQuery to allow overriding the default BEGIN query
|
||||
|
||||
# 5.4.0 (June 14, 2023)
|
||||
|
||||
* Replace platform specific syscalls for non-blocking IO with more traditional goroutines and deadlines. This returns to the v4 approach with some additional improvements and fixes. This restores the ability to use a pgx.Conn over an ssh.Conn as well as other non-TCP or Unix socket connections. In addition, it is a significantly simpler implementation that is less likely to have cross platform issues.
|
||||
* Optimization: The default type registrations are now shared among all connections. This saves about 100KB of memory per connection. `pgtype.Type` and `pgtype.Codec` values are now required to be immutable after registration. This was already necessary in most cases but wasn't documented until now. (Lev Zakharov)
|
||||
* Fix: Ensure pgxpool.Pool.QueryRow.Scan releases connection on panic
|
||||
* CancelRequest: don't try to read the reply (Nicola Murino)
|
||||
* Fix: correctly handle bool type aliases (Wichert Akkerman)
|
||||
* Fix: pgconn.CancelRequest: Fix unix sockets: don't use RemoteAddr()
|
||||
* Fix: pgx.Conn memory leak with prepared statement caching (Evan Jones)
|
||||
* Add BeforeClose to pgxpool.Pool (Evan Cordell)
|
||||
* Fix: various hstore fixes and optimizations (Evan Jones)
|
||||
* Fix: RowToStructByPos with embedded unexported struct
|
||||
* Support different bool string representations (Lev Zakharov)
|
||||
* Fix: error when using BatchResults.Exec on a select that returns an error after some rows.
|
||||
* Fix: pipelineBatchResults.Exec() not returning error from ResultReader
|
||||
* Fix: pipeline batch results not closing pipeline when error occurs while reading directly from results instead of using
|
||||
a callback.
|
||||
* Fix: scanning a table type into a struct
|
||||
* Fix: scan array of record to pointer to slice of struct
|
||||
* Fix: handle null for json (Cemre Mengu)
|
||||
* Batch Query callback is called even when there is an error
|
||||
* Add RowTo(AddrOf)StructByNameLax (Audi P. Risa P)
|
||||
|
||||
# 5.3.1 (February 27, 2023)
|
||||
|
||||
* Fix: Support v4 and v5 stdlib in same program (Tomáš Procházka)
|
||||
* Fix: sql.Scanner not being used in certain cases
|
||||
* Add text format jsonpath support
|
||||
* Fix: fake non-blocking read adaptive wait time
|
||||
|
||||
# 5.3.0 (February 11, 2023)
|
||||
|
||||
* Fix: json values work with sql.Scanner
|
||||
* Fixed / improved error messages (Mark Chambers and Yevgeny Pats)
|
||||
* Fix: support scan into single dimensional arrays
|
||||
* Fix: MaxConnLifetimeJitter setting actually jitter (Ben Weintraub)
|
||||
* Fix: driver.Value representation of bytea should be []byte not string
|
||||
* Fix: better handling of unregistered OIDs
|
||||
* CopyFrom can use query cache to avoid extra round trip to get OIDs (Alejandro Do Nascimento Mora)
|
||||
* Fix: encode to json ignoring driver.Valuer
|
||||
* Support sql.Scanner on renamed base type
|
||||
* Fix: pgtype.Numeric text encoding of negative numbers (Mark Chambers)
|
||||
* Fix: connect with multiple hostnames when one can't be resolved
|
||||
* Upgrade puddle to remove dependency on uber/atomic and fix alignment issue on 32-bit platform
|
||||
* Fix: scanning json column into **string
|
||||
* Multiple reductions in memory allocations
|
||||
* Fake non-blocking read adapts its max wait time
|
||||
* Improve CopyFrom performance and reduce memory usage
|
||||
* Fix: encode []any to array
|
||||
* Fix: LoadType for composite with dropped attributes (Felix Röhrich)
|
||||
* Support v4 and v5 stdlib in same program
|
||||
* Fix: text format array decoding with string of "NULL"
|
||||
* Prefer binary format for arrays
|
||||
|
||||
# 5.2.0 (December 5, 2022)
|
||||
|
||||
* `tracelog.TraceLog` implements the pgx.PrepareTracer interface. (Vitalii Solodilov)
|
||||
* Optimize creating begin transaction SQL string (Petr Evdokimov and ksco)
|
||||
* `Conn.LoadType` supports range and multirange types (Vitalii Solodilov)
|
||||
* Fix scan `uint` and `uint64` `ScanNumeric`. This resolves a PostgreSQL `numeric` being incorrectly scanned into `uint` and `uint64`.
|
||||
# 3.4.0 (May 3, 2019)
|
||||
|
||||
## Features
|
||||
|
||||
* Improved .pgpass handling (Dmitry Smal)
|
||||
* Adds RowsAffected for CopyToWriter and CopyFromReader (Nikolay Vorobev)
|
||||
* Support binding of []int type to array integer (David Bariod)
|
||||
* Expose registered driver instance to aid integration with other libraries (PLATEL Kévin)
|
||||
* Allow normal queries on replication connections (Jan Vcelak)
|
||||
* Add support for creating a DB from pgx.Pool (fzerorubigd)
|
||||
* SCRAM authentication
|
||||
* pgtype.Date JSON marshal/unmarshal (Andrey Kuzmin)
|
||||
|
||||
## Fixes
|
||||
|
||||
* Fix encoding of ErrorResponse (Josh Leverette)
|
||||
* Use more detailed error output of unknown field (Ilya Sivanev)
|
||||
* "Temporary" Write errors no longer silently break connections.
|
||||
* Fix PreferSimpleProtocol overwrite (Ilya Sinelnikov)
|
||||
* Fix enum handling (Robert Lin)
|
||||
* Copy protocol fixes (Andrey)
|
||||
|
||||
## Changes
|
||||
|
||||
* Do not attempt recovery from any Write error.
|
||||
* Use LogLevel type instead of int for conn config
|
||||
|
||||
# 3.3.0 (December 1, 2018)
|
||||
|
||||
## Features
|
||||
|
||||
* Add CopyFromReader and CopyToWriter (Murat Kabilov)
|
||||
* Add MacaddrArray (Anthony Regeda)
|
||||
* Add float types to FieldDescription.Type (David Yamnitsky)
|
||||
* Add CheckedOutConnections helper method (MOZGIII)
|
||||
* Add host query parameter to support Unix sockets (Jörg Thalheim)
|
||||
* Custom cancelation hook for use with PostgreSQL-like databases (James Hartig)
|
||||
* Added LastStmtSent for safe retry logic (James Hartig)
|
||||
|
||||
## Fixes
|
||||
|
||||
* Do not silently ignore assign NULL to \*string
|
||||
* Fix issue with JSON and driver.Valuer conversion
|
||||
* Fix race with stdlib Driver.configs Open (Greg Curtis)
|
||||
|
||||
## Changes
|
||||
|
||||
* Connection pool uses connections in queue order instead of stack. This
|
||||
minimized the time any connection is idle vs. any other connection.
|
||||
(Anthony Regeda)
|
||||
* FieldDescription.Modifier is int32 instead of uint32
|
||||
* tls: stop sending ssl_renegotiation_limit in startup message (Tejas Manohar)
|
||||
|
||||
# 3.2.0 (August 7, 2018)
|
||||
|
||||
## Features
|
||||
|
||||
* Support sslkey, sslcert, and sslrootcert URI params (Sean Chittenden)
|
||||
* Allow any scheme in ParseURI (for convenience with cockroachdb) (Sean Chittenden)
|
||||
* Add support for domain types
|
||||
* Add zerolog logging adaptor (Justin Reagor)
|
||||
* Add driver.Connector support / Go 1.10 support (James Lawrence)
|
||||
* Allow nested database/sql/driver.Drivers (Jackson Owens)
|
||||
* Support int64 and uint64 numeric array (Anthony Regeda)
|
||||
* Add nul support to pgtype.Bool (Tarik Demirci)
|
||||
* Add types to decode error messages (Damir Vandic)
|
||||
|
||||
|
||||
## Fixes
|
||||
|
||||
* Fix Rows.Values returning same value for multiple columns of same complex type
|
||||
* Fix StartReplication() syntax (steampunkcoder)
|
||||
* Fix precision loss for test format geometric types
|
||||
* Allows scanning jsonb column into `*json.RawMessage`
|
||||
* Allow recovery to savepoint in failed transaction
|
||||
* Fix deadlock when CopyFromSource panics
|
||||
* Include PreferSimpleProtocol in config Merge (Murat Kabilov)
|
||||
|
||||
## Changes
|
||||
|
||||
* pgtype.JSON(B).Value now returns []byte instead of string. This allows
|
||||
database/sql to scan json(b) into \*json.RawMessage. This is a tiny behavior
|
||||
change, but database/sql Scan should automatically convert []byte to string, so
|
||||
there shouldn't be any incompatibility.
|
||||
|
||||
# 3.1.0 (January 15, 2018)
|
||||
|
||||
## Features
|
||||
|
||||
* Add QueryEx, QueryRowEx, ExecEx, and RollbackEx to Tx
|
||||
* Add more ColumnType support (Timothée Peignier)
|
||||
* Add UUIDArray type (Kelsey Francis)
|
||||
* Add zap log adapter (Kelsey Francis)
|
||||
* Add CreateReplicationSlotEx that consistent_point and snapshot_name (Mark Fletcher)
|
||||
* Add BeginBatch to Tx (Gaspard Douady)
|
||||
* Support CrateDB (Felix Geisendörfer)
|
||||
* Allow use of logrus logger with fields configured (André Bierlein)
|
||||
* Add array of enum support
|
||||
* Add support for bit type
|
||||
* Handle timeout parameters (Timothée Peignier)
|
||||
* Allow overriding connection info (James Lawrence)
|
||||
* Add support for bpchar type (Iurii Krasnoshchok)
|
||||
* Add ConnConfig.PreferSimpleProtocol
|
||||
|
||||
## Fixes
|
||||
|
||||
* Fix numeric EncodeBinary bug (Wei Congrui)
|
||||
* Fix logrus updated package name (Damir Vandic)
|
||||
* Fix some invalid one round trip execs failing to return non-nil error. (Kelsey Francis)
|
||||
* Return ErrClosedPool when Acquire() with closed pool (Mike Graf)
|
||||
* Fix decoding row with same type values
|
||||
* Always return non-nil \*Rows from Query to fix QueryRow (Kelsey Francis)
|
||||
* Fix pgtype types that can Set database/sql/driver.driver.Valuer
|
||||
* Prefix types in namespaces other than pg_catalog or public (Kelsey Francis)
|
||||
* Fix incomplete selects during batch (Gaspard Douady and Jack Christensen)
|
||||
* Support nil pointers to value implementing driver.Valuer
|
||||
* Fix time logging for QueryEx
|
||||
* Fix ranges with text format where end is unbounded
|
||||
* Detect erroneous JSON(B) encoding
|
||||
* Fix missing interval mapping
|
||||
* ConnPool begin should not retry if ctx is done (Gaspard Douady)
|
||||
* Fix reading interrupted messages could break connection
|
||||
* Return error on unknown oid while decoding record instead of panic (Iurii Krasnoshchok)
|
||||
|
||||
# 5.1.1 (November 17, 2022)
|
||||
## Changes
|
||||
|
||||
* Fix simple query sanitizer where query text contains a Unicode replacement character.
|
||||
* Remove erroneous `name` argument from `DeallocateAll()`. Technically, this is a breaking change, but given that method was only added 5 days ago this change was accepted. (Bodo Kaiser)
|
||||
* Align sslmode "require" more closely to libpq (Johan Brandhorst)
|
||||
|
||||
# 5.1.0 (November 12, 2022)
|
||||
# 3.0.1 (August 12, 2017)
|
||||
|
||||
* Update puddle to v2.1.2. This resolves a race condition and a deadlock in pgxpool.
|
||||
* `QueryRewriter.RewriteQuery` now returns an error. Technically, this is a breaking change for any external implementers, but given the minimal likelihood that there are actually any external implementers this change was accepted.
|
||||
* Expose `GetSSLPassword` support to pgx.
|
||||
* Fix encode `ErrorResponse` unknown field handling. This would only affect pgproto3 being used directly as a proxy with a non-PostgreSQL server that included additional error fields.
|
||||
* Fix date text format encoding with 5 digit years.
|
||||
* Fix date values passed to a `sql.Scanner` as `string` instead of `time.Time`.
|
||||
* DateCodec.DecodeValue can return `pgtype.InfinityModifier` instead of `string` for infinite values. This now matches the behavior of the timestamp types.
|
||||
* Add domain type support to `Conn.LoadType()`.
|
||||
* Add `RowToStructByName` and `RowToAddrOfStructByName`. (Pavlo Golub)
|
||||
* Add `Conn.DeallocateAll()` to clear all prepared statements including the statement cache. (Bodo Kaiser)
|
||||
## Fixes
|
||||
|
||||
* Fix compilation on 32-bit platform
|
||||
* Fix invalid MarshalJSON of types with status Undefined
|
||||
* Fix pid logging
|
||||
|
||||
# 5.0.4 (October 24, 2022)
|
||||
# 3.0.0 (July 24, 2017)
|
||||
|
||||
* Fix: CollectOneRow prefers PostgreSQL error over pgx.ErrorNoRows
|
||||
* Fix: some reflect Kind checks to first check for nil
|
||||
* Bump golang.org/x/text dependency to placate snyk
|
||||
* Fix: RowToStructByPos on structs with multiple anonymous sub-structs (Baptiste Fontaine)
|
||||
* Fix: Exec checks if tx is closed
|
||||
## Changes
|
||||
|
||||
# 5.0.3 (October 14, 2022)
|
||||
* Pid to PID in accordance with Go naming conventions.
|
||||
* Conn.Pid changed to accessor method Conn.PID()
|
||||
* Conn.SecretKey removed
|
||||
* Remove Conn.TxStatus
|
||||
* Logger interface reduced to single Log method.
|
||||
* Replace BeginIso with BeginEx. BeginEx adds support for read/write mode and deferrable mode.
|
||||
* Transaction isolation level constants are now typed strings instead of bare strings.
|
||||
* Conn.WaitForNotification now takes context.Context instead of time.Duration for cancellation support.
|
||||
* Conn.WaitForNotification no longer automatically pings internally every 15 seconds.
|
||||
* ReplicationConn.WaitForReplicationMessage now takes context.Context instead of time.Duration for cancellation support.
|
||||
* Reject scanning binary format values into a string (e.g. binary encoded timestamptz to string). See https://github.com/jackc/pgx/issues/219 and https://github.com/jackc/pgx/issues/228
|
||||
* No longer can read raw bytes of any value into a []byte. Use pgtype.GenericBinary if this functionality is needed.
|
||||
* Remove CopyTo (functionality is now in CopyFrom)
|
||||
* OID constants moved from pgx to pgtype package
|
||||
* Replaced Scanner, Encoder, and PgxScanner interfaces with pgtype system
|
||||
* Removed ValueReader
|
||||
* ConnPool.Close no longer waits for all acquired connections to be released. Instead, it immediately closes all available connections, and closes acquired connections when they are released in the same manner as ConnPool.Reset.
|
||||
* Removed Rows.Fatal(error)
|
||||
* Removed Rows.AfterClose()
|
||||
* Removed Rows.Conn()
|
||||
* Removed Tx.AfterClose()
|
||||
* Removed Tx.Conn()
|
||||
* Use Go casing convention for OID, UUID, JSON(B), ACLItem, CID, TID, XID, and CIDR
|
||||
* Replaced stdlib.OpenFromConnPool with DriverConfig system
|
||||
|
||||
* Fix `driver.Valuer` handling edge cases that could cause infinite loop or crash
|
||||
## Features
|
||||
|
||||
# v5.0.2 (October 8, 2022)
|
||||
* Entirely revamped pluggable type system that supports approximately 60 PostgreSQL types.
|
||||
* Types support database/sql interfaces and therefore can be used with other drivers
|
||||
* Added context methods supporting cancellation where appropriate
|
||||
* Added simple query protocol support
|
||||
* Added single round-trip query mode
|
||||
* Added batch query operations
|
||||
* Added OnNotice
|
||||
* github.com/pkg/errors used where possible for errors
|
||||
* Added stdlib.DriverConfig which directly allows full configuration of underlying pgx connections without needing to use a pgx.ConnPool
|
||||
* Added AcquireConn and ReleaseConn to stdlib to allow acquiring a connection from a database/sql connection.
|
||||
|
||||
* Fix date encoding in text format to always use 2 digits for month and day
|
||||
* Prefer driver.Valuer over wrap plans when encoding
|
||||
* Fix scan to pointer to pointer to renamed type
|
||||
* Allow scanning NULL even if PG and Go types are incompatible
|
||||
# 2.11.0 (June 5, 2017)
|
||||
|
||||
# v5.0.1 (September 24, 2022)
|
||||
## Fixes
|
||||
|
||||
* Fix 32-bit atomic usage
|
||||
* Add MarshalJSON for Float8 (yogipristiawan)
|
||||
* Add `[` and `]` to text encoding of `Lseg`
|
||||
* Fix sqlScannerWrapper NULL handling
|
||||
* Fix race with concurrent execution of stdlib.OpenFromConnPool (Terin Stock)
|
||||
|
||||
# v5.0.0 (September 17, 2022)
|
||||
## Features
|
||||
|
||||
## Merged Packages
|
||||
* .pgpass support (j7b)
|
||||
* Add missing CopyFrom delegators to Tx and ConnPool (Jack Christensen)
|
||||
* Add ParseConnectionString (James Lawrence)
|
||||
|
||||
`github.com/jackc/pgtype`, `github.com/jackc/pgconn`, and `github.com/jackc/pgproto3` are now included in the main
|
||||
`github.com/jackc/pgx` repository. Previously there was confusion as to where issues should be reported, additional
|
||||
release work due to releasing multiple packages, and less clear changelogs.
|
||||
## Performance
|
||||
|
||||
## pgconn
|
||||
* Optimize HStore encoding (René Kroon)
|
||||
|
||||
`CommandTag` is now an opaque type instead of directly exposing an underlying `[]byte`.
|
||||
# 2.10.0 (March 17, 2017)
|
||||
|
||||
The return value `ResultReader.Values()` is no longer safe to retain a reference to after a subsequent call to `NextRow()` or `Close()`.
|
||||
## Fixes
|
||||
|
||||
`Trace()` method adds low level message tracing similar to the `PQtrace` function in `libpq`.
|
||||
* database/sql driver created through stdlib.OpenFromConnPool closes connections when requested by database/sql rather than release to underlying connection pool.
|
||||
|
||||
pgconn now uses non-blocking IO. This is a significant internal restructuring, but it should not cause any visible changes on its own. However, it is important in implementing other new features.
|
||||
# 2.11.0 (June 5, 2017)
|
||||
|
||||
`CheckConn()` checks a connection's liveness by doing a non-blocking read. This can be used to detect database restarts or network interruptions without executing a query or a ping.
|
||||
## Fixes
|
||||
|
||||
pgconn now supports pipeline mode.
|
||||
* Fix race with concurrent execution of stdlib.OpenFromConnPool (Terin Stock)
|
||||
|
||||
`*PgConn.ReceiveResults` removed. Use pipeline mode instead.
|
||||
## Features
|
||||
|
||||
`Timeout()` no longer considers `context.Canceled` as a timeout error. `context.DeadlineExceeded` still is considered a timeout error.
|
||||
* .pgpass support (j7b)
|
||||
* Add missing CopyFrom delegators to Tx and ConnPool (Jack Christensen)
|
||||
* Add ParseConnectionString (James Lawrence)
|
||||
|
||||
## pgxpool
|
||||
## Performance
|
||||
|
||||
`Connect` and `ConnectConfig` have been renamed to `New` and `NewWithConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect.
|
||||
* Optimize HStore encoding (René Kroon)
|
||||
|
||||
## pgtype
|
||||
# 2.10.0 (March 17, 2017)
|
||||
|
||||
The `pgtype` package has been significantly changed.
|
||||
## Fixes
|
||||
|
||||
### NULL Representation
|
||||
* Oid underlying type changed to uint32, previously it was incorrectly int32 (Manni Wood)
|
||||
* Explicitly close checked-in connections on ConnPool.Reset, previously they were closed by GC
|
||||
|
||||
Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a
|
||||
`Valid` `bool` field to harmonize with how `database/sql` represents `NULL` and to make the zero value useable.
|
||||
## Features
|
||||
|
||||
Previously, a type that implemented `driver.Valuer` would have the `Value` method called even on a nil pointer. All nils
|
||||
whether typed or untyped now represent `NULL`.
|
||||
* Add xid type support (Manni Wood)
|
||||
* Add cid type support (Manni Wood)
|
||||
* Add tid type support (Manni Wood)
|
||||
* Add "char" type support (Manni Wood)
|
||||
* Add NullOid type (Manni Wood)
|
||||
* Add json/jsonb binary support to allow use with CopyTo
|
||||
* Add named error ErrAcquireTimeout (Alexander Staubo)
|
||||
* Add logical replication decoding (Kris Wehner)
|
||||
* Add PgxScanner interface to allow types to simultaneously support database/sql and pgx (Jack Christensen)
|
||||
* Add CopyFrom with schema support (Jack Christensen)
|
||||
|
||||
### Codec and Value Split
|
||||
## Compatibility
|
||||
|
||||
Previously, the type system combined decoding and encoding values with the value types. e.g. Type `Int8` both handled
|
||||
encoding and decoding the PostgreSQL representation and acted as a value object. This caused some difficulties when
|
||||
there was not an exact 1 to 1 relationship between the Go types and the PostgreSQL types For example, scanning a
|
||||
PostgreSQL binary `numeric` into a Go `float64` was awkward (see https://github.com/jackc/pgtype/issues/147). This
|
||||
concepts have been separated. A `Codec` only has responsibility for encoding and decoding values. Value types are
|
||||
generally defined by implementing an interface that a particular `Codec` understands (e.g. `PointScanner` and
|
||||
`PointValuer` for the PostgreSQL `point` type).
|
||||
* jsonb now defaults to binary format. This means passing a []byte to a jsonb column will no longer work.
|
||||
* CopyTo is now deprecated but will continue to work.
|
||||
|
||||
### Array Types
|
||||
# 2.9.0 (August 26, 2016)
|
||||
|
||||
All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This also
|
||||
means that less common array types such as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional
|
||||
arrays.
|
||||
## Fixes
|
||||
|
||||
### Composite Types
|
||||
* Fix *ConnPool.Deallocate() not deleting prepared statement from map
|
||||
* Fix stdlib not logging unprepared query SQL (Krzysztof Dryś)
|
||||
* Fix Rows.Values() with varchar binary format
|
||||
* Concurrent ConnPool.Acquire calls with Dialer timeouts now timeout in the expected amount of time (Konstantin Dzreev)
|
||||
|
||||
Composite types must be registered before use. `CompositeFields` may still be used to construct and destruct composite
|
||||
values, but any type may now implement `CompositeIndexGetter` and `CompositeIndexScanner` to be used as a composite.
|
||||
## Features
|
||||
|
||||
### Range Types
|
||||
* Add CopyTo
|
||||
* Add PrepareEx
|
||||
* Add basic record to []interface{} decoding
|
||||
* Encode and decode between all Go and PostgreSQL integer types with bounds checking
|
||||
* Decode inet/cidr to net.IP
|
||||
* Encode/decode [][]byte to/from bytea[]
|
||||
* Encode/decode named types whose underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64
|
||||
|
||||
Range types are now handled with types `RangeCodec` and `Range[T]`. This allows additional user defined range types to
|
||||
easily be handled. Multirange types are handled similarly with `MultirangeCodec` and `Multirange[T]`.
|
||||
## Performance
|
||||
|
||||
### pgxtype
|
||||
* Substantial reduction in memory allocations
|
||||
|
||||
`LoadDataType` moved to `*Conn` as `LoadType`.
|
||||
# 2.8.1 (March 24, 2016)
|
||||
|
||||
### Bytea
|
||||
## Features
|
||||
|
||||
The `Bytea` and `GenericBinary` types have been replaced. Use the following instead:
|
||||
* Scan accepts nil argument to ignore a column
|
||||
|
||||
* `[]byte` - For normal usage directly use `[]byte`.
|
||||
* `DriverBytes` - Uses driver memory only available until next database method call. Avoids a copy and an allocation.
|
||||
* `PreallocBytes` - Uses preallocated byte slice to avoid an allocation.
|
||||
* `UndecodedBytes` - Avoids any decoding. Allows working with raw bytes.
|
||||
## Fixes
|
||||
|
||||
### Dropped lib/pq Support
|
||||
* Fix compilation on 32-bit architecture
|
||||
* Fix Tx.status not being set on error on Commit
|
||||
* Fix Listen/Unlisten with special characters
|
||||
|
||||
`pgtype` previously supported and was tested against [lib/pq](https://github.com/lib/pq). While it will continue to work
|
||||
in most cases this is no longer supported.
|
||||
# 2.8.0 (March 18, 2016)
|
||||
|
||||
### database/sql Scan
|
||||
## Fixes
|
||||
|
||||
Previously, most `Scan` implementations would convert `[]byte` to `string` automatically to decode a text value. Now
|
||||
only `string` is handled. This is to allow the possibility of future binary support in `database/sql` mode by
|
||||
considering `[]byte` to be binary format and `string` text format. This change should have no effect for any use with
|
||||
`pgx`. The previous behavior was only necessary for `lib/pq` compatibility.
|
||||
* Fix unrecognized commit failure
|
||||
* Fix msgReader.rxMsg bug when msgReader already has error
|
||||
* Go float64 can no longer be encoded to a PostgreSQL float4
|
||||
* Fix connection corruption when query with error is closed early
|
||||
|
||||
Added `*Map.SQLScanner` to create a `sql.Scanner` for types such as `[]int32` and `Range[T]` that do not implement
|
||||
`sql.Scanner` directly.
|
||||
## Features
|
||||
|
||||
### Number Type Fields Include Bit size
|
||||
This release adds multiple extension points helpful when wrapping pgx with
|
||||
custom application behavior. pgx can now use custom types designed for the
|
||||
standard database/sql package such as
|
||||
[github.com/shopspring/decimal](https://github.com/shopspring/decimal).
|
||||
|
||||
`Int2`, `Int4`, `Int8`, `Float4`, `Float8`, and `Uint32` fields now include bit size. e.g. `Int` is renamed to `Int64`.
|
||||
This matches the convention set by `database/sql`. In addition, for comparable types like `pgtype.Int8` and
|
||||
`sql.NullInt64` the structures are identical. This means they can be directly converted one to another.
|
||||
|
||||
### 3rd Party Type Integrations
|
||||
|
||||
* Extracted integrations with https://github.com/shopspring/decimal and https://github.com/gofrs/uuid to
|
||||
https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims
|
||||
the pgx dependency tree.
|
||||
|
||||
### Other Changes
|
||||
|
||||
* `Bit` and `Varbit` are both replaced by the `Bits` type.
|
||||
* `CID`, `OID`, `OIDValue`, and `XID` are replaced by the `Uint32` type.
|
||||
* `Hstore` is now defined as `map[string]*string`.
|
||||
* `JSON` and `JSONB` types removed. Use `[]byte` or `string` directly.
|
||||
* `QChar` type removed. Use `rune` or `byte` directly.
|
||||
* `Inet` and `Cidr` types removed. Use `netip.Addr` and `netip.Prefix` directly. These types are more memory efficient than the previous `net.IPNet`.
|
||||
* `Macaddr` type removed. Use `net.HardwareAddr` directly.
|
||||
* Renamed `pgtype.ConnInfo` to `pgtype.Map`.
|
||||
* Renamed `pgtype.DataType` to `pgtype.Type`.
|
||||
* Renamed `pgtype.None` to `pgtype.Finite`.
|
||||
* `RegisterType` now accepts a `*Type` instead of `Type`.
|
||||
* Assorted array helper methods and types made private.
|
||||
|
||||
## stdlib
|
||||
|
||||
* Removed `AcquireConn` and `ReleaseConn` as that functionality has been built in since Go 1.13.
|
||||
|
||||
## Reduced Memory Usage by Reusing Read Buffers
|
||||
|
||||
Previously, the connection read buffer would allocate large chunks of memory and never reuse them. This allowed
|
||||
transferring ownership to anything such as scanned values without incurring an additional allocation and memory copy.
|
||||
However, this came at the cost of overall increased memory allocation size. But worse it was also possible to pin large
|
||||
chunks of memory by retaining a reference to a small value that originally came directly from the read buffer. Now
|
||||
ownership remains with the read buffer and anything needing to retain a value must make a copy.
|
||||
|
||||
## Query Execution Modes
|
||||
|
||||
Control over automatic prepared statement caching and simple protocol use are now combined into query execution mode.
|
||||
See documentation for `QueryExecMode`.
|
||||
|
||||
## QueryRewriter Interface and NamedArgs
|
||||
|
||||
pgx now supports named arguments with the `NamedArgs` type. This is implemented via the new `QueryRewriter` interface which
|
||||
allows arbitrary rewriting of query SQL and arguments.
|
||||
|
||||
## RowScanner Interface
|
||||
|
||||
The `RowScanner` interface allows a single argument to Rows.Scan to scan the entire row.
|
||||
|
||||
## Rows Result Helpers
|
||||
|
||||
* `CollectRows` and `RowTo*` functions simplify collecting results into a slice.
|
||||
* `CollectOneRow` collects one row using `RowTo*` functions.
|
||||
* `ForEachRow` simplifies scanning each row and executing code using the scanned values. `ForEachRow` replaces `QueryFunc`.
|
||||
|
||||
## Tx Helpers
|
||||
|
||||
Rather than every type that implemented `Begin` or `BeginTx` methods also needing to implement `BeginFunc` and
|
||||
`BeginTxFunc` these methods have been converted to functions that take a db that implements `Begin` or `BeginTx`.
|
||||
|
||||
## Improved Batch Query Ergonomics
|
||||
|
||||
Previously, the code for building a batch went in one place before the call to `SendBatch`, and the code for reading the
|
||||
results went in one place after the call to `SendBatch`. This could make it difficult to match up the query and the code
|
||||
to handle the results. Now `Queue` returns a `QueuedQuery` which has methods `Query`, `QueryRow`, and `Exec` which can
|
||||
be used to register a callback function that will handle the result. Callback functions are called automatically when
|
||||
`BatchResults.Close` is called.
|
||||
|
||||
## SendBatch Uses Pipeline Mode When Appropriate
|
||||
|
||||
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1
|
||||
for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements
|
||||
in a single network round trip. So it would only take 2 round trips.
|
||||
|
||||
## Tracing and Logging
|
||||
|
||||
Internal logging support has been replaced with tracing hooks. This allows custom tracing integration with tools like OpenTelemetry. Package tracelog provides an adapter for pgx v4 loggers to act as a tracer.
|
||||
|
||||
All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency
|
||||
tree.
|
||||
* Add *Tx.AfterClose() hook
|
||||
* Add *Tx.Conn()
|
||||
* Add *Tx.Status()
|
||||
* Add *Tx.Err()
|
||||
* Add *Rows.AfterClose() hook
|
||||
* Add *Rows.Conn()
|
||||
* Add *Conn.SetLogger() to allow changing logger
|
||||
* Add *Conn.SetLogLevel() to allow changing log level
|
||||
* Add ConnPool.Reset method
|
||||
* Add support for database/sql.Scanner and database/sql/driver.Valuer interfaces
|
||||
* Rows.Scan errors now include which argument caused error
|
||||
* Add Encode() to allow custom Encoders to reuse internal encoding functionality
|
||||
* Add Decode() to allow customer Decoders to reuse internal decoding functionality
|
||||
* Add ConnPool.Prepare method
|
||||
* Add ConnPool.Deallocate method
|
||||
* Add Scan to uint32 and uint64 (utrack)
|
||||
* Add encode and decode to []uint16, []uint32, and []uint64 (Max Musatov)
|
||||
|
||||
## Performance
|
||||
|
||||
* []byte skips encoding/decoding
|
||||
|
||||
# 2.7.1 (October 26, 2015)
|
||||
|
||||
* Disable SSL renegotiation
|
||||
|
||||
# 2.7.0 (October 16, 2015)
|
||||
|
||||
* Add RuntimeParams to ConnConfig
|
||||
* ParseURI extracts RuntimeParams
|
||||
* ParseDSN extracts RuntimeParams
|
||||
* ParseEnvLibpq extracts PGAPPNAME
|
||||
* Prepare is now idempotent
|
||||
* Rows.Values now supports oid type
|
||||
* ConnPool.Release automatically unlistens connections (Joseph Glanville)
|
||||
* Add trace log level
|
||||
* Add more efficient log leveling
|
||||
* Retry automatically on ConnPool.Begin (Joseph Glanville)
|
||||
* Encode from net.IP to inet and cidr
|
||||
* Generalize encoding pointer to string to any PostgreSQL type
|
||||
* Add UUID encoding from pointer to string (Joseph Glanville)
|
||||
* Add null mapping to pointer to pointer (Jonathan Rudenberg)
|
||||
* Add JSON and JSONB type support (Joseph Glanville)
|
||||
|
||||
# 2.6.0 (September 3, 2015)
|
||||
|
||||
* Add inet and cidr type support
|
||||
* Add binary decoding to TimestampOid in stdlib driver (Samuel Stauffer)
|
||||
* Add support for specifying sslmode in connection strings (Rick Snyder)
|
||||
* Allow ConnPool to have MaxConnections of 1
|
||||
* Add basic PGSSLMODE to support to ParseEnvLibpq
|
||||
* Add fallback TLS config
|
||||
* Expose specific error for TSL refused
|
||||
* More error details exposed in PgError
|
||||
* Support custom dialer (Lewis Marshall)
|
||||
|
||||
# 2.5.0 (April 15, 2015)
|
||||
|
||||
* Fix stdlib nil support (Blaž Hrastnik)
|
||||
* Support custom Scanner not reading entire value
|
||||
* Fix empty array scanning (Laurent Debacker)
|
||||
* Add ParseDSN (deoxxa)
|
||||
* Add timestamp support to NullTime
|
||||
* Remove unused text format scanners
|
||||
* Return error when too many parameters on Prepare
|
||||
* Add Travis CI integration (Jonathan Rudenberg)
|
||||
* Large object support (Jonathan Rudenberg)
|
||||
* Fix reading null byte arrays (Karl Seguin)
|
||||
* Add timestamptz[] support
|
||||
* Add timestamp[] support (Karl Seguin)
|
||||
* Add bool[] support (Karl Seguin)
|
||||
* Allow writing []byte into text and varchar columns without type conversion (Hari Bhaskaran)
|
||||
* Fix ConnPool Close panic
|
||||
* Add Listen / notify example
|
||||
* Reduce memory allocations (Karl Seguin)
|
||||
|
||||
# 2.4.0 (October 3, 2014)
|
||||
|
||||
* Add per connection oid to name map
|
||||
* Add Hstore support (Andy Walker)
|
||||
* Move introductory docs to godoc from readme
|
||||
* Fix documentation references to TextEncoder and BinaryEncoder
|
||||
* Add keep-alive to TCP connections (Andy Walker)
|
||||
* Add support for EmptyQueryResponse / Allow no-op Exec (Andy Walker)
|
||||
* Allow reading any type into []byte
|
||||
* WaitForNotification detects lost connections quicker
|
||||
|
||||
# 2.3.0 (September 16, 2014)
|
||||
|
||||
* Truncate logged strings and byte slices
|
||||
* Extract more error information from PostgreSQL
|
||||
* Fix data race with Rows and ConnPool
|
||||
|
121
CONTRIBUTING.md
121
CONTRIBUTING.md
@ -1,121 +0,0 @@
|
||||
# Contributing
|
||||
|
||||
## Discuss Significant Changes
|
||||
|
||||
Before you invest a significant amount of time on a change, please create a discussion or issue describing your
|
||||
proposal. This will help to ensure your proposed change has a reasonable chance of being merged.
|
||||
|
||||
## Avoid Dependencies
|
||||
|
||||
Adding a dependency is a big deal. While on occasion a new dependency may be accepted, the default answer to any change
|
||||
that adds a dependency is no.
|
||||
|
||||
## Development Environment Setup
|
||||
|
||||
pgx tests naturally require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_DATABASE`
|
||||
environment variable. The `PGX_TEST_DATABASE` environment variable can either be a URL or key-value pairs. In addition,
|
||||
the standard `PG*` environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to
|
||||
simplify environment variable handling.
|
||||
|
||||
### Using an Existing PostgreSQL Cluster
|
||||
|
||||
If you already have a PostgreSQL development server this is the quickest way to start and run the majority of the pgx
|
||||
test suite. Some tests will be skipped that require server configuration changes (e.g. those testing different
|
||||
authentication methods).
|
||||
|
||||
Create and setup a test database:
|
||||
|
||||
```
|
||||
export PGDATABASE=pgx_test
|
||||
createdb
|
||||
psql -c 'create extension hstore;'
|
||||
psql -c 'create extension ltree;'
|
||||
psql -c 'create domain uint64 as numeric(20,0);'
|
||||
```
|
||||
|
||||
Ensure a `postgres` user exists. This happens by default in normal PostgreSQL installs, but some installation methods
|
||||
such as Homebrew do not.
|
||||
|
||||
```
|
||||
createuser -s postgres
|
||||
```
|
||||
|
||||
Ensure your `PGX_TEST_DATABASE` environment variable points to the database you just created and run the tests.
|
||||
|
||||
```
|
||||
export PGX_TEST_DATABASE="host=/private/tmp database=pgx_test"
|
||||
go test ./...
|
||||
```
|
||||
|
||||
This will run the vast majority of the tests, but some tests will be skipped (e.g. those testing different connection methods).
|
||||
|
||||
### Creating a New PostgreSQL Cluster Exclusively for Testing
|
||||
|
||||
The following environment variables need to be set both for initial setup and whenever the tests are run. (direnv is
|
||||
highly recommended). Depending on your platform, you may need to change the host for `PGX_TEST_UNIX_SOCKET_CONN_STRING`.
|
||||
|
||||
```
|
||||
export PGPORT=5015
|
||||
export PGUSER=postgres
|
||||
export PGDATABASE=pgx_test
|
||||
export POSTGRESQL_DATA_DIR=postgresql
|
||||
|
||||
export PGX_TEST_DATABASE="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
|
||||
export PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/private/tmp database=pgx_test"
|
||||
export PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
|
||||
export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test"
|
||||
export PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
|
||||
export PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_pw password=secret"
|
||||
export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem"
|
||||
export PGX_SSL_PASSWORD=certpw
|
||||
export PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test sslcert=`pwd`/.testdb/pgx_sslcert.crt sslkey=`pwd`/.testdb/pgx_sslcert.key"
|
||||
```
|
||||
|
||||
Create a new database cluster.
|
||||
|
||||
```
|
||||
initdb --locale=en_US -E UTF-8 --username=postgres .testdb/$POSTGRESQL_DATA_DIR
|
||||
|
||||
echo "listen_addresses = '127.0.0.1'" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
|
||||
echo "port = $PGPORT" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
|
||||
cat testsetup/postgresql_ssl.conf >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
|
||||
cp testsetup/pg_hba.conf .testdb/$POSTGRESQL_DATA_DIR/pg_hba.conf
|
||||
|
||||
cd .testdb
|
||||
|
||||
# Generate CA, server, and encrypted client certificates.
|
||||
go run ../testsetup/generate_certs.go
|
||||
|
||||
# Copy certificates to server directory and set permissions.
|
||||
cp ca.pem $POSTGRESQL_DATA_DIR/root.crt
|
||||
cp localhost.key $POSTGRESQL_DATA_DIR/server.key
|
||||
chmod 600 $POSTGRESQL_DATA_DIR/server.key
|
||||
cp localhost.crt $POSTGRESQL_DATA_DIR/server.crt
|
||||
|
||||
cd ..
|
||||
```
|
||||
|
||||
|
||||
Start the new cluster. This will be necessary whenever you are running pgx tests.
|
||||
|
||||
```
|
||||
postgres -D .testdb/$POSTGRESQL_DATA_DIR
|
||||
```
|
||||
|
||||
Setup the test database in the new cluster.
|
||||
|
||||
```
|
||||
createdb
|
||||
psql --no-psqlrc -f testsetup/postgresql_setup.sql
|
||||
```
|
||||
|
||||
### PgBouncer
|
||||
|
||||
There are tests specific for PgBouncer that will be executed if `PGX_TEST_PGBOUNCER_CONN_STRING` is set.
|
||||
|
||||
### Optional Tests
|
||||
|
||||
pgx supports multiple connection types and means of authentication. These tests are optional. They will only run if the
|
||||
appropriate environment variables are set. In addition, there may be tests specific to particular PostgreSQL versions,
|
||||
non-PostgreSQL servers (e.g. CockroachDB), or connection poolers (e.g. PgBouncer). `go test ./... -v | grep SKIP` to see
|
||||
if any tests are being skipped.
|
4
LICENSE
4
LICENSE
@ -1,4 +1,4 @@
|
||||
Copyright (c) 2013-2021 Jack Christensen
|
||||
Copyright (c) 2013 Jack Christensen
|
||||
|
||||
MIT License
|
||||
|
||||
@ -19,4 +19,4 @@ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
223
README.md
223
README.md
@ -1,191 +1,152 @@
|
||||
[](https://pkg.go.dev/github.com/jackc/pgx/v5)
|
||||
[](https://github.com/jackc/pgx/actions/workflows/ci.yml)
|
||||
[](https://godoc.org/github.com/jackc/pgx)
|
||||
[](https://travis-ci.org/jackc/pgx)
|
||||
|
||||
# pgx - PostgreSQL Driver and Toolkit
|
||||
|
||||
pgx is a pure Go driver and toolkit for PostgreSQL.
|
||||
pgx is a pure Go driver and toolkit for PostgreSQL. pgx is different from other drivers such as [pq](http://godoc.org/github.com/lib/pq) because, while it can operate as a database/sql compatible driver, pgx is also usable directly. It offers a native interface similar to database/sql that offers better performance and more features.
|
||||
|
||||
The pgx driver is a low-level, high performance interface that exposes PostgreSQL-specific features such as `LISTEN` /
|
||||
`NOTIFY` and `COPY`. It also includes an adapter for the standard `database/sql` interface.
|
||||
|
||||
The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol
|
||||
and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers,
|
||||
proxies, load balancers, logical replication clients, etc.
|
||||
|
||||
## Example Usage
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// urlExample := "postgres://username:password@localhost:5432/database_name"
|
||||
conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Unable to connect to database: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer conn.Close(context.Background())
|
||||
|
||||
var name string
|
||||
var weight int64
|
||||
err = conn.QueryRow(context.Background(), "select name, weight from widgets where id=$1", 42).Scan(&name, &weight)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "QueryRow failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Println(name, weight)
|
||||
var name string
|
||||
var weight int64
|
||||
err := conn.QueryRow("select name, weight from widgets where id=$1", 42).Scan(&name, &weight)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
```
|
||||
|
||||
See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information.
|
||||
|
||||
## Features
|
||||
|
||||
* Support for approximately 70 different PostgreSQL types
|
||||
* Automatic statement preparation and caching
|
||||
pgx supports many additional features beyond what is available through database/sql.
|
||||
|
||||
* Support for approximately 60 different PostgreSQL types
|
||||
* Batch queries
|
||||
* Single-round trip query mode
|
||||
* Full TLS connection control
|
||||
* Binary format support for custom types (allows for much quicker encoding/decoding)
|
||||
* `COPY` protocol support for faster bulk data loads
|
||||
* Tracing and logging support
|
||||
* Connection pool with after-connect hook for arbitrary connection setup
|
||||
* `LISTEN` / `NOTIFY`
|
||||
* Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings
|
||||
* `hstore` support
|
||||
* `json` and `jsonb` support
|
||||
* Maps `inet` and `cidr` PostgreSQL types to `netip.Addr` and `netip.Prefix`
|
||||
* Binary format support for custom types (can be much faster)
|
||||
* Copy protocol support for faster bulk data loads
|
||||
* Extendable logging support including built-in support for log15 and logrus
|
||||
* Connection pool with after connect hook to do arbitrary connection setup
|
||||
* Listen / notify
|
||||
* PostgreSQL array to Go slice mapping for integers, floats, and strings
|
||||
* Hstore support
|
||||
* JSON and JSONB support
|
||||
* Maps inet and cidr PostgreSQL types to net.IPNet and net.IP
|
||||
* Large object support
|
||||
* NULL mapping to pointer to pointer
|
||||
* Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types
|
||||
* Notice response handling
|
||||
* Simulated nested transactions with savepoints
|
||||
* NULL mapping to Null* struct or pointer to pointer.
|
||||
* Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types
|
||||
* Logical replication connections, including receiving WAL and sending standby status updates
|
||||
* Notice response handling (this is different than listen / notify)
|
||||
|
||||
## Choosing Between the pgx and database/sql Interfaces
|
||||
## Performance
|
||||
|
||||
The pgx interface is faster. Many PostgreSQL specific features such as `LISTEN` / `NOTIFY` and `COPY` are not available
|
||||
through the `database/sql` interface.
|
||||
pgx performs roughly equivalent to [go-pg](https://github.com/go-pg/pg) and is almost always faster than [pq](http://godoc.org/github.com/lib/pq). When parsing large result sets the percentage difference can be significant (16483 queries/sec for pgx vs. 10106 queries/sec for pq -- 63% faster).
|
||||
|
||||
The pgx interface is recommended when:
|
||||
In many use cases a significant cause of latency is network round trips between the application and the server. pgx supports query batching to bundle multiple queries into a single round trip. Even in the case of a connection with the lowest possible latency, a local Unix domain socket, batching as few as three queries together can yield an improvement of 57%. With a typical network connection the results can be even more substantial.
|
||||
|
||||
1. The application only targets PostgreSQL.
|
||||
2. No other libraries that require `database/sql` are in use.
|
||||
See this [gist](https://gist.github.com/jackc/4996e8648a0c59839bff644f49d6e434) for the underlying benchmark results or checkout [go_db_bench](https://github.com/jackc/go_db_bench) to run tests for yourself.
|
||||
|
||||
It is also possible to use the `database/sql` interface and convert a connection to the lower-level pgx interface as needed.
|
||||
In addition to the native driver, pgx also includes a number of packages that provide additional functionality.
|
||||
|
||||
## Testing
|
||||
## github.com/jackc/pgx/stdlib
|
||||
|
||||
See [CONTRIBUTING.md](./CONTRIBUTING.md) for setup instructions.
|
||||
database/sql compatibility layer for pgx. pgx can be used as a normal database/sql driver, but at any time the native interface may be acquired for more performance or PostgreSQL specific functionality.
|
||||
|
||||
## Architecture
|
||||
## github.com/jackc/pgx/pgtype
|
||||
|
||||
See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.com/watch?v=sXMSWhcHCf8) for a description of pgx architecture.
|
||||
Approximately 60 PostgreSQL types are supported including uuid, hstore, json, bytea, numeric, interval, inet, and arrays. These types support database/sql interfaces and are usable even outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver.
|
||||
|
||||
## Supported Go and PostgreSQL Versions
|
||||
## github.com/jackc/pgx/pgproto3
|
||||
|
||||
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.23 and higher and PostgreSQL 13 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
|
||||
pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling.
|
||||
|
||||
## Version Policy
|
||||
|
||||
pgx follows semantic versioning for the documented public API on stable releases. `v5` is the latest stable major version.
|
||||
|
||||
## PGX Family Libraries
|
||||
|
||||
### [github.com/jackc/pglogrepl](https://github.com/jackc/pglogrepl)
|
||||
|
||||
pglogrepl provides functionality to act as a client for PostgreSQL logical replication.
|
||||
|
||||
### [github.com/jackc/pgmock](https://github.com/jackc/pgmock)
|
||||
## github.com/jackc/pgx/pgmock
|
||||
|
||||
pgmock offers the ability to create a server that mocks the PostgreSQL wire protocol. This is used internally to test pgx by purposely inducing unusual errors. pgproto3 and pgmock together provide most of the foundational tooling required to implement a PostgreSQL proxy or MitM (such as for a custom connection pooler).
|
||||
|
||||
### [github.com/jackc/tern](https://github.com/jackc/tern)
|
||||
## Documentation
|
||||
|
||||
tern is a stand-alone SQL migration system.
|
||||
pgx includes extensive documentation in the godoc format. It is viewable online at [godoc.org](https://godoc.org/github.com/jackc/pgx).
|
||||
|
||||
### [github.com/jackc/pgerrcode](https://github.com/jackc/pgerrcode)
|
||||
## Testing
|
||||
|
||||
pgerrcode contains constants for the PostgreSQL error codes.
|
||||
pgx supports multiple connection and authentication types. Setting up a test
|
||||
environment that can test all of them can be cumbersome. In particular,
|
||||
Windows cannot test Unix domain socket connections. Because of this pgx will
|
||||
skip tests for connection types that are not configured.
|
||||
|
||||
## Adapters for 3rd Party Types
|
||||
### Normal Test Environment
|
||||
|
||||
* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
|
||||
* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
|
||||
* [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos))
|
||||
* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)
|
||||
To setup the normal test environment, first install these dependencies:
|
||||
|
||||
go get github.com/cockroachdb/apd
|
||||
go get github.com/hashicorp/go-version
|
||||
go get github.com/jackc/fake
|
||||
go get github.com/lib/pq
|
||||
go get github.com/pkg/errors
|
||||
go get github.com/satori/go.uuid
|
||||
go get github.com/shopspring/decimal
|
||||
go get github.com/sirupsen/logrus
|
||||
go get go.uber.org/zap
|
||||
go get gopkg.in/inconshreveable/log15.v2
|
||||
|
||||
## Adapters for 3rd Party Tracers
|
||||
Then run the following SQL:
|
||||
|
||||
* [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer)
|
||||
* [github.com/exaring/otelpgx](https://github.com/exaring/otelpgx)
|
||||
create user pgx_md5 password 'secret';
|
||||
create user " tricky, ' } "" \ test user " password 'secret';
|
||||
create database pgx_test;
|
||||
create user pgx_replication with replication password 'secret';
|
||||
|
||||
## Adapters for 3rd Party Loggers
|
||||
Connect to database pgx_test and run:
|
||||
|
||||
These adapters can be used with the tracelog package.
|
||||
create extension hstore;
|
||||
create domain uint64 as numeric(20,0);
|
||||
|
||||
* [github.com/jackc/pgx-go-kit-log](https://github.com/jackc/pgx-go-kit-log)
|
||||
* [github.com/jackc/pgx-log15](https://github.com/jackc/pgx-log15)
|
||||
* [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus)
|
||||
* [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap)
|
||||
* [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog)
|
||||
* [github.com/mcosta74/pgx-slog](https://github.com/mcosta74/pgx-slog)
|
||||
* [github.com/kataras/pgx-golog](https://github.com/kataras/pgx-golog)
|
||||
Next open conn_config_test.go.example and make a copy without the
|
||||
.example. If your PostgreSQL server is accepting connections on 127.0.0.1,
|
||||
then you are done.
|
||||
|
||||
## 3rd Party Libraries with PGX Support
|
||||
### Connection and Authentication Test Environment
|
||||
|
||||
### [github.com/pashagolub/pgxmock](https://github.com/pashagolub/pgxmock)
|
||||
Complete the normal test environment setup and also do the following.
|
||||
|
||||
pgxmock is a mock library implementing pgx interfaces.
|
||||
pgxmock has one and only purpose - to simulate pgx behavior in tests, without needing a real database connection.
|
||||
Run the following SQL:
|
||||
|
||||
### [github.com/georgysavva/scany](https://github.com/georgysavva/scany)
|
||||
create user pgx_none;
|
||||
create user pgx_pw password 'secret';
|
||||
|
||||
Library for scanning data from a database into Go structs and more.
|
||||
Add the following to your pg_hba.conf:
|
||||
|
||||
### [github.com/vingarcia/ksql](https://github.com/vingarcia/ksql)
|
||||
If you are developing on Unix with domain socket connections:
|
||||
|
||||
A carefully designed SQL client for making using SQL easier,
|
||||
more productive, and less error-prone on Golang.
|
||||
local pgx_test pgx_none trust
|
||||
local pgx_test pgx_pw password
|
||||
local pgx_test pgx_md5 md5
|
||||
|
||||
### [github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
|
||||
If you are developing on Windows with TCP connections:
|
||||
|
||||
Adds GSSAPI / Kerberos authentication support.
|
||||
host pgx_test pgx_none 127.0.0.1/32 trust
|
||||
host pgx_test pgx_pw 127.0.0.1/32 password
|
||||
host pgx_test pgx_md5 127.0.0.1/32 md5
|
||||
|
||||
### [github.com/wcamarao/pmx](https://github.com/wcamarao/pmx)
|
||||
### Replication Test Environment
|
||||
|
||||
Explicit data mapping and scanning library for Go structs and slices.
|
||||
Add a replication user:
|
||||
|
||||
### [github.com/stephenafamo/scan](https://github.com/stephenafamo/scan)
|
||||
create user pgx_replication with replication password 'secret';
|
||||
|
||||
Type safe and flexible package for scanning database data into Go types.
|
||||
Supports, structs, maps, slices and custom mapping functions.
|
||||
Add a replication line to your pg_hba.conf:
|
||||
|
||||
### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)
|
||||
host replication pgx_replication 127.0.0.1/32 md5
|
||||
|
||||
Code first migration library for native pgx (no database/sql abstraction).
|
||||
Change the following settings in your postgresql.conf:
|
||||
|
||||
### [github.com/amirsalarsafaei/sqlc-pgx-monitoring](https://github.com/amirsalarsafaei/sqlc-pgx-monitoring)
|
||||
wal_level=logical
|
||||
max_wal_senders=5
|
||||
max_replication_slots=5
|
||||
|
||||
A database monitoring/metrics library for pgx and sqlc. Trace, log and monitor your sqlc query performance using OpenTelemetry.
|
||||
Set `replicationConnConfig` appropriately in `conn_config_test.go`.
|
||||
|
||||
### [https://github.com/nikolayk812/pgx-outbox](https://github.com/nikolayk812/pgx-outbox)
|
||||
## Version Policy
|
||||
|
||||
Simple Golang implementation for transactional outbox pattern for PostgreSQL using jackc/pgx driver.
|
||||
|
||||
### [https://github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy)
|
||||
|
||||
Simplifies working with the pgx library, providing convenient scanning of nested structures.
|
||||
|
||||
## [https://github.com/KoNekoD/pgx-colon-query-rewriter](https://github.com/KoNekoD/pgx-colon-query-rewriter)
|
||||
|
||||
Implementation of the pgx query rewriter to use ':' instead of '@' in named query parameters.
|
||||
pgx follows semantic versioning for the documented public API on stable releases. Branch `v3` is the latest stable release. `master` can contain new features or behavior that will change or be removed before being merged to the stable `v3` branch (in practice, this occurs very rarely). `v2` is the previous stable release.
|
||||
|
18
Rakefile
18
Rakefile
@ -1,18 +0,0 @@
|
||||
require "erb"
|
||||
|
||||
rule '.go' => '.go.erb' do |task|
|
||||
erb = ERB.new(File.read(task.source))
|
||||
File.write(task.name, "// Code generated from #{task.source}. DO NOT EDIT.\n\n" + erb.result(binding))
|
||||
sh "goimports", "-w", task.name
|
||||
end
|
||||
|
||||
generated_code_files = [
|
||||
"pgtype/int.go",
|
||||
"pgtype/int_test.go",
|
||||
"pgtype/integration_benchmark_test.go",
|
||||
"pgtype/zeronull/int.go",
|
||||
"pgtype/zeronull/int_test.go"
|
||||
]
|
||||
|
||||
desc "Generate code"
|
||||
task generate: generated_code_files
|
@ -10,7 +10,7 @@
|
||||
// https://github.com/lib/pq/pull/788
|
||||
// https://github.com/lib/pq/pull/833
|
||||
|
||||
package pgconn
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -22,7 +22,7 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"golang.org/x/text/secure/precis"
|
||||
)
|
||||
@ -30,7 +30,7 @@ import (
|
||||
const clientNonceLen = 18
|
||||
|
||||
// Perform SCRAM authentication.
|
||||
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
func (c *Conn) scramAuth(serverAuthMechanisms []string) error {
|
||||
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -41,18 +41,17 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
AuthMechanism: "SCRAM-SHA-256",
|
||||
Data: sc.clientFirstMessage(),
|
||||
}
|
||||
c.frontend.Send(saslInitialResponse)
|
||||
err = c.flushWithPotentialWriteReadDeadlock()
|
||||
_, err = c.conn.Write(saslInitialResponse.Encode(nil))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive server-first-message payload in an AuthenticationSASLContinue.
|
||||
saslContinue, err := c.rxSASLContinue()
|
||||
// Receive server-first-message payload in a AuthenticationSASLContinue.
|
||||
authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = sc.recvServerFirstMessage(saslContinue.Data)
|
||||
err = sc.recvServerFirstMessage(authMsg.SASLData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -61,48 +60,33 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
saslResponse := &pgproto3.SASLResponse{
|
||||
Data: []byte(sc.clientFinalMessage()),
|
||||
}
|
||||
c.frontend.Send(saslResponse)
|
||||
err = c.flushWithPotentialWriteReadDeadlock()
|
||||
_, err = c.conn.Write(saslResponse.Encode(nil))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive server-final-message payload in an AuthenticationSASLFinal.
|
||||
saslFinal, err := c.rxSASLFinal()
|
||||
// Receive server-final-message payload in a AuthenticationSASLFinal.
|
||||
authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sc.recvServerFinalMessage(saslFinal.Data)
|
||||
return sc.recvServerFinalMessage(authMsg.SASLData)
|
||||
}
|
||||
|
||||
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
|
||||
msg, err := c.receiveMessage()
|
||||
func (c *Conn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch m := msg.(type) {
|
||||
case *pgproto3.AuthenticationSASLContinue:
|
||||
return m, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return nil, ErrorResponseToPgError(m)
|
||||
authMsg, ok := msg.(*pgproto3.Authentication)
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected message type")
|
||||
}
|
||||
if authMsg.Type != typ {
|
||||
return nil, errors.New("unexpected auth type")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
|
||||
}
|
||||
|
||||
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
|
||||
msg, err := c.receiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch m := msg.(type) {
|
||||
case *pgproto3.AuthenticationSASLFinal:
|
||||
return m, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return nil, ErrorResponseToPgError(m)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
|
||||
return authMsg, nil
|
||||
}
|
||||
|
||||
type scramClient struct {
|
||||
@ -198,12 +182,12 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
|
||||
var err error
|
||||
sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
|
||||
return fmt.Errorf("invalid SCRAM salt received from server: %v", err)
|
||||
}
|
||||
|
||||
sc.iterations, err = strconv.Atoi(string(iterationsStr))
|
||||
if err != nil || sc.iterations <= 0 {
|
||||
return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
|
||||
return fmt.Errorf("invalid SCRAM iteration count received from server: %s", iterationsStr)
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
|
||||
@ -263,9 +247,9 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte {
|
||||
return buf
|
||||
}
|
||||
|
||||
func computeServerSignature(saltedPassword, authMessage []byte) []byte {
|
||||
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
|
||||
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
||||
serverSignature := computeHMAC(serverKey, authMessage)
|
||||
serverSignature := computeHMAC(serverKey[:], authMessage)
|
||||
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
|
||||
base64.StdEncoding.Encode(buf, serverSignature)
|
||||
return buf
|
660
batch.go
660
batch.go
@ -2,468 +2,310 @@ package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
// QueuedQuery is a query that has been queued for execution via a Batch.
|
||||
type QueuedQuery struct {
|
||||
SQL string
|
||||
Arguments []any
|
||||
Fn batchItemFunc
|
||||
sd *pgconn.StatementDescription
|
||||
}
|
||||
|
||||
type batchItemFunc func(br BatchResults) error
|
||||
|
||||
// Query sets fn to be called when the response to qq is received.
|
||||
func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
|
||||
qq.Fn = func(br BatchResults) error {
|
||||
rows, _ := br.Query()
|
||||
defer rows.Close()
|
||||
|
||||
err := fn(rows)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
return rows.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Query sets fn to be called when the response to qq is received.
|
||||
func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
|
||||
qq.Fn = func(br BatchResults) error {
|
||||
row := br.QueryRow()
|
||||
return fn(row)
|
||||
}
|
||||
}
|
||||
|
||||
// Exec sets fn to be called when the response to qq is received.
|
||||
//
|
||||
// Note: for simple batch insert uses where it is not required to handle
|
||||
// each potential error individually, it's sufficient to not set any callbacks,
|
||||
// and just handle the return value of BatchResults.Close.
|
||||
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
|
||||
qq.Fn = func(br BatchResults) error {
|
||||
ct, err := br.Exec()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fn(ct)
|
||||
}
|
||||
type batchItem struct {
|
||||
query string
|
||||
arguments []interface{}
|
||||
parameterOIDs []pgtype.OID
|
||||
resultFormatCodes []int16
|
||||
}
|
||||
|
||||
// Batch queries are a way of bundling multiple queries together to avoid
|
||||
// unnecessary network round trips. A Batch must only be sent once.
|
||||
// unnecessary network round trips.
|
||||
type Batch struct {
|
||||
QueuedQueries []*QueuedQuery
|
||||
conn *Conn
|
||||
connPool *ConnPool
|
||||
items []*batchItem
|
||||
resultsRead int
|
||||
pendingCommandComplete bool
|
||||
ctx context.Context
|
||||
err error
|
||||
inTx bool
|
||||
}
|
||||
|
||||
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. The only pgx option
|
||||
// argument that is supported is QueryRewriter. Queries are executed using the connection's DefaultQueryExecMode.
|
||||
// BeginBatch returns a *Batch query for c.
|
||||
func (c *Conn) BeginBatch() *Batch {
|
||||
return &Batch{conn: c}
|
||||
}
|
||||
|
||||
// BeginBatch returns a *Batch query for tx. Since this *Batch is already part
|
||||
// of a transaction it will not automatically be wrapped in a transaction.
|
||||
func (tx *Tx) BeginBatch() *Batch {
|
||||
return &Batch{conn: tx.conn, inTx: true}
|
||||
}
|
||||
|
||||
// Conn returns the underlying connection that b will or was performed on.
|
||||
func (b *Batch) Conn() *Conn {
|
||||
return b.conn
|
||||
}
|
||||
|
||||
// Queue queues a query to batch b. parameterOIDs are required if there are
|
||||
// parameters and query is not the name of a prepared statement.
|
||||
// resultFormatCodes are required if there is a result.
|
||||
func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgtype.OID, resultFormatCodes []int16) {
|
||||
b.items = append(b.items, &batchItem{
|
||||
query: query,
|
||||
arguments: arguments,
|
||||
parameterOIDs: parameterOIDs,
|
||||
resultFormatCodes: resultFormatCodes,
|
||||
})
|
||||
}
|
||||
|
||||
// Send sends all queued queries to the server at once.
|
||||
// If the batch is created from a conn Object then All queries are wrapped
|
||||
// in a transaction. The transaction can optionally be configured with
|
||||
// txOptions. The context is in effect until the Batch is closed.
|
||||
//
|
||||
// While query can contain multiple statements if the connection's DefaultQueryExecMode is QueryModeSimple, this should
|
||||
// be avoided. QueuedQuery.Fn must not be set as it will only be called for the first query. That is, QueuedQuery.Query,
|
||||
// QueuedQuery.QueryRow, and QueuedQuery.Exec must not be called. In addition, any error messages or tracing that
|
||||
// include the current query may reference the wrong query.
|
||||
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
|
||||
qq := &QueuedQuery{
|
||||
SQL: query,
|
||||
Arguments: arguments,
|
||||
}
|
||||
b.QueuedQueries = append(b.QueuedQueries, qq)
|
||||
return qq
|
||||
}
|
||||
|
||||
// Len returns number of queries that have been queued so far.
|
||||
func (b *Batch) Len() int {
|
||||
return len(b.QueuedQueries)
|
||||
}
|
||||
|
||||
type BatchResults interface {
|
||||
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer
|
||||
// calling Exec on the QueuedQuery, or just calling Close.
|
||||
Exec() (pgconn.CommandTag, error)
|
||||
|
||||
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer
|
||||
// calling Query on the QueuedQuery.
|
||||
Query() (Rows, error)
|
||||
|
||||
// QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow.
|
||||
// Prefer calling QueryRow on the QueuedQuery.
|
||||
QueryRow() Row
|
||||
|
||||
// Close closes the batch operation. All unread results are read and any callback functions registered with
|
||||
// QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an
|
||||
// error or the batch encounters an error subsequent callback functions will not be called.
|
||||
//
|
||||
// For simple batch inserts inside a transaction or similar queries, it's sufficient to not set any callbacks,
|
||||
// and just handle the return value of Close.
|
||||
//
|
||||
// Close must be called before the underlying connection can be used again. Any error that occurred during a batch
|
||||
// operation may have made it impossible to resyncronize the connection with the server. In this case the underlying
|
||||
// connection will have been closed.
|
||||
//
|
||||
// Close is safe to call multiple times. If it returns an error subsequent calls will return the same error. Callback
|
||||
// functions will not be rerun.
|
||||
Close() error
|
||||
}
|
||||
|
||||
type batchResults struct {
|
||||
ctx context.Context
|
||||
conn *Conn
|
||||
mrr *pgconn.MultiResultReader
|
||||
err error
|
||||
b *Batch
|
||||
qqIdx int
|
||||
closed bool
|
||||
endTraced bool
|
||||
}
|
||||
|
||||
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
|
||||
func (br *batchResults) Exec() (pgconn.CommandTag, error) {
|
||||
if br.err != nil {
|
||||
return pgconn.CommandTag{}, br.err
|
||||
}
|
||||
if br.closed {
|
||||
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
|
||||
// Warning: Send writes all queued queries before reading any results. This can
|
||||
// cause a deadlock if an excessive number of queries are queued. It is highly
|
||||
// advisable to use a timeout context to protect against this possibility.
|
||||
// Unfortunately, this excessive number can vary based on operating system,
|
||||
// connection type (TCP or Unix domain socket), and type of query. Unix domain
|
||||
// sockets seem to be much more susceptible to this issue than TCP connections.
|
||||
// However, it usually is at least several thousand.
|
||||
//
|
||||
// The deadlock occurs when the batched queries to be sent are so large that the
|
||||
// PostgreSQL server cannot receive it all at once. PostgreSQL received some of
|
||||
// the queued queries and starts executing them. As PostgreSQL executes the
|
||||
// queries it sends responses back. pgx will not read any of these responses
|
||||
// until it has finished sending. Therefore, if all network buffers are full pgx
|
||||
// will not be able to finish sending the queries and PostgreSQL will not be
|
||||
// able to finish sending the responses.
|
||||
//
|
||||
// See https://github.com/jackc/pgx/issues/374.
|
||||
func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
|
||||
if b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
|
||||
query, arguments, _ := br.nextQueryAndArgs()
|
||||
b.ctx = ctx
|
||||
|
||||
if !br.mrr.NextResult() {
|
||||
err := br.mrr.Close()
|
||||
if err == nil {
|
||||
err = errors.New("no more results in batch")
|
||||
}
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
SQL: query,
|
||||
Args: arguments,
|
||||
Err: err,
|
||||
})
|
||||
}
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
|
||||
commandTag, err := br.mrr.ResultReader().Close()
|
||||
err := b.conn.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
br.err = err
|
||||
br.mrr.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
SQL: query,
|
||||
Args: arguments,
|
||||
CommandTag: commandTag,
|
||||
Err: br.err,
|
||||
})
|
||||
if err := b.conn.ensureConnectionReadyForQuery(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return commandTag, br.err
|
||||
buf := b.conn.wbuf
|
||||
if !b.inTx {
|
||||
buf = appendQuery(buf, txOptions.beginSQL())
|
||||
}
|
||||
|
||||
err = b.conn.initContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, bi := range b.items {
|
||||
var psName string
|
||||
var psParameterOIDs []pgtype.OID
|
||||
|
||||
if ps, ok := b.conn.preparedStatements[bi.query]; ok {
|
||||
psName = ps.Name
|
||||
psParameterOIDs = ps.ParameterOIDs
|
||||
} else {
|
||||
psParameterOIDs = bi.parameterOIDs
|
||||
buf = appendParse(buf, "", bi.query, psParameterOIDs)
|
||||
}
|
||||
|
||||
var err error
|
||||
buf, err = appendBind(buf, "", psName, b.conn.ConnInfo, psParameterOIDs, bi.arguments, bi.resultFormatCodes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf = appendDescribe(buf, 'P', "")
|
||||
buf = appendExecute(buf, "", 0)
|
||||
}
|
||||
|
||||
buf = appendSync(buf)
|
||||
b.conn.pendingReadyForQueryCount++
|
||||
|
||||
if !b.inTx {
|
||||
buf = appendQuery(buf, "commit")
|
||||
b.conn.pendingReadyForQueryCount++
|
||||
}
|
||||
|
||||
_, err = b.conn.conn.Write(buf)
|
||||
if err != nil {
|
||||
b.conn.die(err)
|
||||
return err
|
||||
}
|
||||
|
||||
for !b.inTx {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
return nil
|
||||
default:
|
||||
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query reads the results from the next query in the batch as if the query has been sent with Query.
|
||||
func (br *batchResults) Query() (Rows, error) {
|
||||
query, arguments, ok := br.nextQueryAndArgs()
|
||||
if !ok {
|
||||
query = "batch query"
|
||||
// ExecResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with Exec.
|
||||
func (b *Batch) ExecResults() (CommandTag, error) {
|
||||
if b.err != nil {
|
||||
return "", b.err
|
||||
}
|
||||
|
||||
if br.err != nil {
|
||||
return &baseRows{err: br.err, closed: true}, br.err
|
||||
select {
|
||||
case <-b.ctx.Done():
|
||||
b.die(b.ctx.Err())
|
||||
return "", b.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if br.closed {
|
||||
alreadyClosedErr := fmt.Errorf("batch already closed")
|
||||
return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
|
||||
if err := b.ensureCommandComplete(); err != nil {
|
||||
b.die(err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
rows := br.conn.getRows(br.ctx, query, arguments)
|
||||
rows.batchTracer = br.conn.batchTracer
|
||||
b.resultsRead++
|
||||
|
||||
if !br.mrr.NextResult() {
|
||||
rows.err = br.mrr.Close()
|
||||
if rows.err == nil {
|
||||
rows.err = errors.New("no more results in batch")
|
||||
}
|
||||
rows.closed = true
|
||||
b.pendingCommandComplete = true
|
||||
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
SQL: query,
|
||||
Args: arguments,
|
||||
Err: rows.err,
|
||||
})
|
||||
for {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return rows, rows.err
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CommandComplete:
|
||||
b.pendingCommandComplete = false
|
||||
return CommandTag(msg.CommandTag), nil
|
||||
default:
|
||||
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// QueryResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with Query.
|
||||
func (b *Batch) QueryResults() (*Rows, error) {
|
||||
rows := b.conn.getRows("batch query", nil)
|
||||
|
||||
if b.err != nil {
|
||||
rows.fatal(b.err)
|
||||
return rows, b.err
|
||||
}
|
||||
|
||||
rows.resultReader = br.mrr.ResultReader()
|
||||
select {
|
||||
case <-b.ctx.Done():
|
||||
b.die(b.ctx.Err())
|
||||
rows.fatal(b.err)
|
||||
return rows, b.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if err := b.ensureCommandComplete(); err != nil {
|
||||
b.die(err)
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
b.resultsRead++
|
||||
|
||||
b.pendingCommandComplete = true
|
||||
|
||||
fieldDescriptions, err := b.conn.readUntilRowDescription()
|
||||
if err != nil {
|
||||
b.die(err)
|
||||
rows.fatal(b.err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
rows.batch = b
|
||||
rows.fields = fieldDescriptions
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
|
||||
func (br *batchResults) QueryRow() Row {
|
||||
rows, _ := br.Query()
|
||||
return (*connRow)(rows.(*baseRows))
|
||||
// QueryRowResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with QueryRow.
|
||||
func (b *Batch) QueryRowResults() *Row {
|
||||
rows, _ := b.QueryResults()
|
||||
return (*Row)(rows)
|
||||
|
||||
}
|
||||
|
||||
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
||||
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
|
||||
func (br *batchResults) Close() error {
|
||||
defer func() {
|
||||
if !br.endTraced {
|
||||
if br.conn != nil && br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
|
||||
}
|
||||
br.endTraced = true
|
||||
}
|
||||
// Close closes the batch operation. Any error that occured during a batch
|
||||
// operation may have made it impossible to resyncronize the connection with the
|
||||
// server. In this case the underlying connection will have been closed.
|
||||
func (b *Batch) Close() (err error) {
|
||||
if b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
|
||||
invalidateCachesOnBatchResultsError(br.conn, br.b, br.err)
|
||||
defer func() {
|
||||
err = b.conn.termContext(err)
|
||||
if b.conn != nil && b.connPool != nil {
|
||||
b.connPool.Release(b.conn)
|
||||
}
|
||||
}()
|
||||
|
||||
if br.err != nil {
|
||||
return br.err
|
||||
}
|
||||
|
||||
if br.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read and run fn for all remaining items
|
||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
|
||||
if br.b.QueuedQueries[br.qqIdx].Fn != nil {
|
||||
err := br.b.QueuedQueries[br.qqIdx].Fn(br)
|
||||
if err != nil {
|
||||
br.err = err
|
||||
}
|
||||
} else {
|
||||
br.Exec()
|
||||
for i := b.resultsRead; i < len(b.items); i++ {
|
||||
if _, err = b.ExecResults(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
br.closed = true
|
||||
|
||||
err := br.mrr.Close()
|
||||
if br.err == nil {
|
||||
br.err = err
|
||||
if err = b.conn.ensureConnectionReadyForQuery(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return br.err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (br *batchResults) earlyError() error {
|
||||
return br.err
|
||||
func (b *Batch) die(err error) {
|
||||
if b.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
b.err = err
|
||||
b.conn.die(err)
|
||||
|
||||
if b.conn != nil && b.connPool != nil {
|
||||
b.connPool.Release(b.conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||
if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
|
||||
bi := br.b.QueuedQueries[br.qqIdx]
|
||||
query = bi.SQL
|
||||
args = bi.Arguments
|
||||
ok = true
|
||||
br.qqIdx++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type pipelineBatchResults struct {
|
||||
ctx context.Context
|
||||
conn *Conn
|
||||
pipeline *pgconn.Pipeline
|
||||
lastRows *baseRows
|
||||
err error
|
||||
b *Batch
|
||||
qqIdx int
|
||||
closed bool
|
||||
endTraced bool
|
||||
}
|
||||
|
||||
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
|
||||
func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
|
||||
if br.err != nil {
|
||||
return pgconn.CommandTag{}, br.err
|
||||
}
|
||||
if br.closed {
|
||||
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
|
||||
}
|
||||
if br.lastRows != nil && br.lastRows.err != nil {
|
||||
return pgconn.CommandTag{}, br.err
|
||||
}
|
||||
|
||||
query, arguments, err := br.nextQueryAndArgs()
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
|
||||
results, err := br.pipeline.GetResults()
|
||||
if err != nil {
|
||||
br.err = err
|
||||
return pgconn.CommandTag{}, br.err
|
||||
}
|
||||
var commandTag pgconn.CommandTag
|
||||
switch results := results.(type) {
|
||||
case *pgconn.ResultReader:
|
||||
commandTag, br.err = results.Close()
|
||||
default:
|
||||
return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results)
|
||||
}
|
||||
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
SQL: query,
|
||||
Args: arguments,
|
||||
CommandTag: commandTag,
|
||||
Err: br.err,
|
||||
})
|
||||
}
|
||||
|
||||
return commandTag, br.err
|
||||
}
|
||||
|
||||
// Query reads the results from the next query in the batch as if the query has been sent with Query.
|
||||
func (br *pipelineBatchResults) Query() (Rows, error) {
|
||||
if br.err != nil {
|
||||
return &baseRows{err: br.err, closed: true}, br.err
|
||||
}
|
||||
|
||||
if br.closed {
|
||||
alreadyClosedErr := fmt.Errorf("batch already closed")
|
||||
return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
|
||||
}
|
||||
|
||||
if br.lastRows != nil && br.lastRows.err != nil {
|
||||
br.err = br.lastRows.err
|
||||
return &baseRows{err: br.err, closed: true}, br.err
|
||||
}
|
||||
|
||||
query, arguments, err := br.nextQueryAndArgs()
|
||||
if err != nil {
|
||||
return &baseRows{err: err, closed: true}, err
|
||||
}
|
||||
|
||||
rows := br.conn.getRows(br.ctx, query, arguments)
|
||||
rows.batchTracer = br.conn.batchTracer
|
||||
br.lastRows = rows
|
||||
|
||||
results, err := br.pipeline.GetResults()
|
||||
if err != nil {
|
||||
br.err = err
|
||||
rows.err = err
|
||||
rows.closed = true
|
||||
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
SQL: query,
|
||||
Args: arguments,
|
||||
Err: err,
|
||||
})
|
||||
func (b *Batch) ensureCommandComplete() error {
|
||||
for b.pendingCommandComplete {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
switch results := results.(type) {
|
||||
case *pgconn.ResultReader:
|
||||
rows.resultReader = results
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CommandComplete:
|
||||
b.pendingCommandComplete = false
|
||||
return nil
|
||||
default:
|
||||
err = fmt.Errorf("unexpected pipeline result: %T", results)
|
||||
br.err = err
|
||||
rows.err = err
|
||||
rows.closed = true
|
||||
}
|
||||
}
|
||||
|
||||
return rows, rows.err
|
||||
}
|
||||
|
||||
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
|
||||
func (br *pipelineBatchResults) QueryRow() Row {
|
||||
rows, _ := br.Query()
|
||||
return (*connRow)(rows.(*baseRows))
|
||||
}
|
||||
|
||||
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
||||
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
|
||||
func (br *pipelineBatchResults) Close() error {
|
||||
defer func() {
|
||||
if !br.endTraced {
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
|
||||
}
|
||||
br.endTraced = true
|
||||
}
|
||||
|
||||
invalidateCachesOnBatchResultsError(br.conn, br.b, br.err)
|
||||
}()
|
||||
|
||||
if br.err == nil && br.lastRows != nil && br.lastRows.err != nil {
|
||||
br.err = br.lastRows.err
|
||||
return br.err
|
||||
}
|
||||
|
||||
if br.closed {
|
||||
return br.err
|
||||
}
|
||||
|
||||
// Read and run fn for all remaining items
|
||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
|
||||
if br.b.QueuedQueries[br.qqIdx].Fn != nil {
|
||||
err := br.b.QueuedQueries[br.qqIdx].Fn(br)
|
||||
err = b.conn.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
br.err = err
|
||||
}
|
||||
} else {
|
||||
br.Exec()
|
||||
}
|
||||
}
|
||||
|
||||
br.closed = true
|
||||
|
||||
err := br.pipeline.Close()
|
||||
if br.err == nil {
|
||||
br.err = err
|
||||
}
|
||||
|
||||
return br.err
|
||||
}
|
||||
|
||||
func (br *pipelineBatchResults) earlyError() error {
|
||||
return br.err
|
||||
}
|
||||
|
||||
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, err error) {
|
||||
if br.b == nil {
|
||||
return "", nil, errors.New("no reference to batch")
|
||||
}
|
||||
|
||||
if br.qqIdx >= len(br.b.QueuedQueries) {
|
||||
return "", nil, errors.New("no more results in batch")
|
||||
}
|
||||
|
||||
bi := br.b.QueuedQueries[br.qqIdx]
|
||||
br.qqIdx++
|
||||
return bi.SQL, bi.Arguments, nil
|
||||
}
|
||||
|
||||
// invalidates statement and description caches on batch results error
|
||||
func invalidateCachesOnBatchResultsError(conn *Conn, b *Batch, err error) {
|
||||
if err != nil && conn != nil && b != nil {
|
||||
if sc := conn.statementCache; sc != nil {
|
||||
for _, bi := range b.QueuedQueries {
|
||||
sc.Invalidate(bi.SQL)
|
||||
}
|
||||
}
|
||||
|
||||
if sc := conn.descriptionCache; sc != nil {
|
||||
for _, bi := range b.QueuedQueries {
|
||||
sc.Invalidate(bi.SQL)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
1667
batch_test.go
1667
batch_test.go
File diff suppressed because it is too large
Load Diff
55
bench-tmp_test.go
Normal file
55
bench-tmp_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkPgtypeInt4ParseBinary(b *testing.B) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
_, err := conn.Prepare("selectBinary", "select n::int4 from generate_series(1, 100) n")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var n int32
|
||||
|
||||
rows, err := conn.Query("selectBinary")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&n)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
b.Fatal(rows.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPgtypeInt4EncodeBinary(b *testing.B) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
_, err := conn.Prepare("encodeBinary", "select $1::int4, $2::int4, $3::int4, $4::int4, $5::int4, $6::int4, $7::int4")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := conn.Query("encodeBinary", int32(i), int32(i), int32(i), int32(i), int32(i), int32(i), int32(i))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
}
|
1169
bench_test.go
1169
bench_test.go
File diff suppressed because it is too large
Load Diff
89
chunkreader/chunkreader.go
Normal file
89
chunkreader/chunkreader.go
Normal file
@ -0,0 +1,89 @@
|
||||
package chunkreader
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
type ChunkReader struct {
|
||||
r io.Reader
|
||||
|
||||
buf []byte
|
||||
rp, wp int // buf read position and write position
|
||||
|
||||
options Options
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
MinBufLen int // Minimum buffer length
|
||||
}
|
||||
|
||||
func NewChunkReader(r io.Reader) *ChunkReader {
|
||||
cr, err := NewChunkReaderEx(r, Options{})
|
||||
if err != nil {
|
||||
panic("default options can't be bad")
|
||||
}
|
||||
|
||||
return cr
|
||||
}
|
||||
|
||||
func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) {
|
||||
if options.MinBufLen == 0 {
|
||||
options.MinBufLen = 4096
|
||||
}
|
||||
|
||||
return &ChunkReader{
|
||||
r: r,
|
||||
buf: make([]byte, options.MinBufLen),
|
||||
options: options,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Next returns buf filled with the next n bytes. If an error occurs, buf will
|
||||
// be nil.
|
||||
func (r *ChunkReader) Next(n int) (buf []byte, err error) {
|
||||
// n bytes already in buf
|
||||
if (r.wp - r.rp) >= n {
|
||||
buf = r.buf[r.rp : r.rp+n]
|
||||
r.rp += n
|
||||
return buf, err
|
||||
}
|
||||
|
||||
// available space in buf is less than n
|
||||
if len(r.buf) < n {
|
||||
r.copyBufContents(r.newBuf(n))
|
||||
}
|
||||
|
||||
// buf is large enough, but need to shift filled area to start to make enough contiguous space
|
||||
minReadCount := n - (r.wp - r.rp)
|
||||
if (len(r.buf) - r.wp) < minReadCount {
|
||||
newBuf := r.newBuf(n)
|
||||
r.copyBufContents(newBuf)
|
||||
}
|
||||
|
||||
if err := r.appendAtLeast(minReadCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf = r.buf[r.rp : r.rp+n]
|
||||
r.rp += n
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (r *ChunkReader) appendAtLeast(fillLen int) error {
|
||||
n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen)
|
||||
r.wp += n
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ChunkReader) newBuf(size int) []byte {
|
||||
if size < r.options.MinBufLen {
|
||||
size = r.options.MinBufLen
|
||||
}
|
||||
return make([]byte, size)
|
||||
}
|
||||
|
||||
func (r *ChunkReader) copyBufContents(dest []byte) {
|
||||
r.wp = copy(dest, r.buf[r.rp:r.wp])
|
||||
r.rp = 0
|
||||
r.buf = dest
|
||||
}
|
96
chunkreader/chunkreader_test.go
Normal file
96
chunkreader/chunkreader_test.go
Normal file
@ -0,0 +1,96 @@
|
||||
package chunkreader
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
|
||||
server := &bytes.Buffer{}
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
src := []byte{1, 2, 3, 4}
|
||||
server.Write(src)
|
||||
|
||||
n1, err := r.Next(2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n1, src[0:2]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1)
|
||||
}
|
||||
|
||||
n2, err := r.Next(2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n2, src[2:4]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2)
|
||||
}
|
||||
|
||||
if bytes.Compare(r.buf, src) != 0 {
|
||||
t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf)
|
||||
}
|
||||
if r.rp != 4 {
|
||||
t.Fatalf("Expected r.rp to be %v, but it was %v", 4, r.rp)
|
||||
}
|
||||
if r.wp != 4 {
|
||||
t.Fatalf("Expected r.wp to be %v, but it was %v", 4, r.wp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) {
|
||||
server := &bytes.Buffer{}
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
server.Write(src)
|
||||
|
||||
n1, err := r.Next(5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n1, src[0:5]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1)
|
||||
}
|
||||
if len(r.buf) != 5 {
|
||||
t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkReaderDoesNotReuseBuf(t *testing.T) {
|
||||
server := &bytes.Buffer{}
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
server.Write(src)
|
||||
|
||||
n1, err := r.Next(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n1, src[0:4]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1)
|
||||
}
|
||||
|
||||
n2, err := r.Next(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n2, src[4:8]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2)
|
||||
}
|
||||
|
||||
if bytes.Compare(n1, src[0:4]) != 0 {
|
||||
t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1)
|
||||
}
|
||||
}
|
@ -1,61 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -eux
|
||||
|
||||
if [[ "${PGVERSION-}" =~ ^[0-9.]+$ ]]
|
||||
then
|
||||
sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common
|
||||
sudo rm -rf /var/lib/postgresql
|
||||
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
|
||||
sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list"
|
||||
sudo apt-get update -qq
|
||||
sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION
|
||||
|
||||
sudo cp testsetup/pg_hba.conf /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
sudo sh -c "echo \"listen_addresses = '127.0.0.1'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
|
||||
sudo sh -c "cat testsetup/postgresql_ssl.conf >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
|
||||
|
||||
cd testsetup
|
||||
|
||||
# Generate CA, server, and encrypted client certificates.
|
||||
go run generate_certs.go
|
||||
|
||||
# Copy certificates to server directory and set permissions.
|
||||
sudo cp ca.pem /var/lib/postgresql/$PGVERSION/main/root.crt
|
||||
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/root.crt
|
||||
sudo cp localhost.key /var/lib/postgresql/$PGVERSION/main/server.key
|
||||
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.key
|
||||
sudo chmod 600 /var/lib/postgresql/$PGVERSION/main/server.key
|
||||
sudo cp localhost.crt /var/lib/postgresql/$PGVERSION/main/server.crt
|
||||
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.crt
|
||||
|
||||
cp ca.pem /tmp
|
||||
cp pgx_sslcert.key /tmp
|
||||
cp pgx_sslcert.crt /tmp
|
||||
|
||||
cd ..
|
||||
|
||||
sudo /etc/init.d/postgresql restart
|
||||
|
||||
createdb -U postgres pgx_test
|
||||
psql -U postgres -f testsetup/postgresql_setup.sql pgx_test
|
||||
fi
|
||||
|
||||
if [[ "${PGVERSION-}" =~ ^cockroach ]]
|
||||
then
|
||||
wget -qO- https://binaries.cockroachdb.com/cockroach-v24.3.3.linux-amd64.tgz | tar xvz
|
||||
sudo mv cockroach-v24.3.3.linux-amd64/cockroach /usr/local/bin/
|
||||
cockroach start-single-node --insecure --background --listen-addr=localhost
|
||||
cockroach sql --insecure -e 'create database pgx_test'
|
||||
fi
|
||||
|
||||
if [ "${CRATEVERSION-}" != "" ]
|
||||
then
|
||||
docker run \
|
||||
-p "6543:5432" \
|
||||
-d \
|
||||
crate:"$CRATEVERSION" \
|
||||
crate \
|
||||
-Cnetwork.host=0.0.0.0 \
|
||||
-Ctransport.host=localhost \
|
||||
-Clicense.enterprise=false
|
||||
fi
|
79
conn_config_test.go.example
Normal file
79
conn_config_test.go.example
Normal file
@ -0,0 +1,79 @@
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
// "crypto/tls"
|
||||
// "crypto/x509"
|
||||
// "fmt"
|
||||
// "go/build"
|
||||
// "io/ioutil"
|
||||
// "path"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
var defaultConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
|
||||
// To skip tests for specific connection / authentication types set that connection param to nil
|
||||
var tcpConnConfig *pgx.ConnConfig = nil
|
||||
var unixSocketConnConfig *pgx.ConnConfig = nil
|
||||
var md5ConnConfig *pgx.ConnConfig = nil
|
||||
var plainPasswordConnConfig *pgx.ConnConfig = nil
|
||||
var invalidUserConnConfig *pgx.ConnConfig = nil
|
||||
var tlsConnConfig *pgx.ConnConfig = nil
|
||||
var customDialerConnConfig *pgx.ConnConfig = nil
|
||||
var replicationConnConfig *pgx.ConnConfig = nil
|
||||
var cratedbConnConfig *pgx.ConnConfig = nil
|
||||
|
||||
// var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"}
|
||||
// var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
// var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"}
|
||||
// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
|
||||
// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
// var replicationConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_replication", Password: "secret", Database: "pgx_test"}
|
||||
|
||||
// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
//
|
||||
//// or to test client certs:
|
||||
//
|
||||
// var tlsConnConfig *pgx.ConnConfig
|
||||
//
|
||||
// func init() {
|
||||
// homeDir := build.Default.GOPATH
|
||||
// tlsConnConfig = &pgx.ConnConfig{
|
||||
// Host: "127.0.0.1",
|
||||
// User: "pgx_md5",
|
||||
// Password: "secret",
|
||||
// Database: "pgx_test",
|
||||
// TLSConfig: &tls.Config{
|
||||
// InsecureSkipVerify: true,
|
||||
// },
|
||||
// }
|
||||
// caCertPool := x509.NewCertPool()
|
||||
//
|
||||
// caPath := path.Join(homeDir, "/src/github.com/jackc/pgx/rootCA.pem")
|
||||
// caCert, err := ioutil.ReadFile(caPath)
|
||||
// if err != nil {
|
||||
// panic(fmt.Sprintf("unable to read CA file: %v", err))
|
||||
// }
|
||||
//
|
||||
// if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
// panic("unable to add CA to cert pool")
|
||||
// }
|
||||
//
|
||||
// tlsConnConfig.TLSConfig.RootCAs = caCertPool
|
||||
// tlsConnConfig.TLSConfig.ClientCAs = caCertPool
|
||||
//
|
||||
// sslCert := path.Join(homeDir, "/src/github.com/jackc/pgx/pg_md5.crt")
|
||||
// sslKey := path.Join(homeDir, "/src/github.com/jackc/pgx/pg_md5.key")
|
||||
// if (sslCert != "" && sslKey == "") || (sslCert == "" && sslKey != "") {
|
||||
// panic(`both "sslcert" and "sslkey" are required`)
|
||||
// }
|
||||
//
|
||||
// cert, err := tls.LoadX509KeyPair(sslCert, sslKey)
|
||||
// if err != nil {
|
||||
// panic(fmt.Sprintf("unable to read cert: %v", err))
|
||||
// }
|
||||
//
|
||||
// tlsConnConfig.TLSConfig.Certificates = []tls.Certificate{cert}
|
||||
// }
|
36
conn_config_test.go.travis
Normal file
36
conn_config_test.go.travis
Normal file
@ -0,0 +1,36 @@
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"github.com/jackc/pgx"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
var tcpConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
var unixSocketConnConfig = &pgx.ConnConfig{Host: "/var/run/postgresql", User: "postgres", Database: "pgx_test"}
|
||||
var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
var plainPasswordConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"}
|
||||
var invalidUserConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
|
||||
var tlsConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_ssl", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
var customDialerConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
var replicationConnConfig *pgx.ConnConfig = nil
|
||||
var cratedbConnConfig *pgx.ConnConfig = nil
|
||||
|
||||
func init() {
|
||||
pgVersion := os.Getenv("PGVERSION")
|
||||
|
||||
if len(pgVersion) > 0 {
|
||||
v, err := strconv.ParseFloat(pgVersion, 64)
|
||||
if err == nil && v >= 9.6 {
|
||||
replicationConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_replication", Password: "secret", Database: "pgx_test"}
|
||||
}
|
||||
}
|
||||
|
||||
crateVersion := os.Getenv("CRATEVERSION")
|
||||
if crateVersion != "" {
|
||||
cratedbConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", Port: 6543, User: "pgx", Password: "", Database: "pgx_test"}
|
||||
}
|
||||
}
|
||||
|
@ -1,55 +0,0 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func mustParseConfig(t testing.TB, connString string) *ConnConfig {
|
||||
config, err := ParseConfig(connString)
|
||||
require.Nil(t, err)
|
||||
return config
|
||||
}
|
||||
|
||||
func mustConnect(t testing.TB, config *ConnConfig) *Conn {
|
||||
conn, err := ConnectConfig(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to establish connection: %v", err)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
// Ensures the connection limits the size of its cached objects.
|
||||
// This test examines the internals of *Conn so must be in the same package.
|
||||
func TestStmtCacheSizeLimit(t *testing.T) {
|
||||
const cacheLimit = 16
|
||||
|
||||
connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
connConfig.StatementCacheCapacity = cacheLimit
|
||||
conn := mustConnect(t, connConfig)
|
||||
defer func() {
|
||||
err := conn.Close(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// run a set of unique queries that should overflow the cache
|
||||
ctx := context.Background()
|
||||
for i := 0; i < cacheLimit*2; i++ {
|
||||
uniqueString := fmt.Sprintf("unique %d", i)
|
||||
uniqueSQL := fmt.Sprintf("select '%s'", uniqueString)
|
||||
var output string
|
||||
err := conn.QueryRow(ctx, uniqueSQL).Scan(&output)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uniqueString, output)
|
||||
}
|
||||
// preparedStatements contains cacheLimit+1 because deallocation happens before the query
|
||||
assert.Len(t, conn.preparedStatements, cacheLimit+1)
|
||||
assert.Equal(t, cacheLimit, conn.statementCache.Len())
|
||||
}
|
582
conn_pool.go
Normal file
582
conn_pool.go
Normal file
@ -0,0 +1,582 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
type ConnPoolConfig struct {
|
||||
ConnConfig
|
||||
MaxConnections int // max simultaneous connections to use, default 5, must be at least 2
|
||||
AfterConnect func(*Conn) error // function to call on every new connection
|
||||
AcquireTimeout time.Duration // max wait time when all connections are busy (0 means no timeout)
|
||||
}
|
||||
|
||||
type ConnPool struct {
|
||||
allConnections []*Conn
|
||||
availableConnections []*Conn
|
||||
cond *sync.Cond
|
||||
config ConnConfig // config used when establishing connection
|
||||
inProgressConnects int
|
||||
maxConnections int
|
||||
resetCount int
|
||||
afterConnect func(*Conn) error
|
||||
logger Logger
|
||||
logLevel LogLevel
|
||||
closed bool
|
||||
preparedStatements map[string]*PreparedStatement
|
||||
acquireTimeout time.Duration
|
||||
connInfo *pgtype.ConnInfo
|
||||
}
|
||||
|
||||
type ConnPoolStat struct {
|
||||
MaxConnections int // max simultaneous connections to use
|
||||
CurrentConnections int // current live connections
|
||||
AvailableConnections int // unused live connections
|
||||
}
|
||||
|
||||
// CheckedOutConnections returns the amount of connections that are currently
|
||||
// checked out from the pool.
|
||||
func (stat *ConnPoolStat) CheckedOutConnections() int {
|
||||
return stat.CurrentConnections - stat.AvailableConnections
|
||||
}
|
||||
|
||||
// ErrAcquireTimeout occurs when an attempt to acquire a connection times out.
|
||||
var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool")
|
||||
|
||||
// ErrClosedPool occurs on an attempt to acquire a connection from a closed pool.
|
||||
var ErrClosedPool = errors.New("cannot acquire from closed pool")
|
||||
|
||||
// NewConnPool creates a new ConnPool. config.ConnConfig is passed through to
|
||||
// Connect directly.
|
||||
func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) {
|
||||
p = new(ConnPool)
|
||||
p.config = config.ConnConfig
|
||||
p.connInfo = minimalConnInfo
|
||||
p.maxConnections = config.MaxConnections
|
||||
if p.maxConnections == 0 {
|
||||
p.maxConnections = 5
|
||||
}
|
||||
if p.maxConnections < 1 {
|
||||
return nil, errors.New("MaxConnections must be at least 1")
|
||||
}
|
||||
p.acquireTimeout = config.AcquireTimeout
|
||||
if p.acquireTimeout < 0 {
|
||||
return nil, errors.New("AcquireTimeout must be equal to or greater than 0")
|
||||
}
|
||||
|
||||
p.afterConnect = config.AfterConnect
|
||||
|
||||
if config.LogLevel != 0 {
|
||||
p.logLevel = config.LogLevel
|
||||
} else {
|
||||
// Preserve pre-LogLevel behavior by defaulting to LogLevelDebug
|
||||
p.logLevel = LogLevelDebug
|
||||
}
|
||||
p.logger = config.Logger
|
||||
if p.logger == nil {
|
||||
p.logLevel = LogLevelNone
|
||||
}
|
||||
|
||||
p.allConnections = make([]*Conn, 0, p.maxConnections)
|
||||
p.availableConnections = make([]*Conn, 0, p.maxConnections)
|
||||
p.preparedStatements = make(map[string]*PreparedStatement)
|
||||
p.cond = sync.NewCond(new(sync.Mutex))
|
||||
|
||||
// Initially establish one connection
|
||||
var c *Conn
|
||||
c, err = p.createConnection()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p.allConnections = append(p.allConnections, c)
|
||||
p.availableConnections = append(p.availableConnections, c)
|
||||
p.connInfo = c.ConnInfo.DeepCopy()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Acquire takes exclusive use of a connection until it is released.
|
||||
func (p *ConnPool) Acquire() (*Conn, error) {
|
||||
p.cond.L.Lock()
|
||||
c, err := p.acquire(nil)
|
||||
p.cond.L.Unlock()
|
||||
return c, err
|
||||
}
|
||||
|
||||
// deadlinePassed returns true if the given deadline has passed.
|
||||
func (p *ConnPool) deadlinePassed(deadline *time.Time) bool {
|
||||
return deadline != nil && time.Now().After(*deadline)
|
||||
}
|
||||
|
||||
// acquire performs acquision assuming pool is already locked
|
||||
func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
|
||||
if p.closed {
|
||||
return nil, ErrClosedPool
|
||||
}
|
||||
|
||||
// A connection is available
|
||||
// The pool works like a queue. Available connection will be returned
|
||||
// from the head. A new connection will be added to the tail.
|
||||
numAvailable := len(p.availableConnections)
|
||||
if numAvailable > 0 {
|
||||
c := p.availableConnections[0]
|
||||
c.poolResetCount = p.resetCount
|
||||
copy(p.availableConnections, p.availableConnections[1:])
|
||||
p.availableConnections = p.availableConnections[:numAvailable-1]
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Set initial timeout/deadline value. If the method (acquire) happens to
|
||||
// recursively call itself the deadline should retain its value.
|
||||
if deadline == nil && p.acquireTimeout > 0 {
|
||||
tmp := time.Now().Add(p.acquireTimeout)
|
||||
deadline = &tmp
|
||||
}
|
||||
|
||||
// Make sure the deadline (if it is) has not passed yet
|
||||
if p.deadlinePassed(deadline) {
|
||||
return nil, ErrAcquireTimeout
|
||||
}
|
||||
|
||||
// If there is a deadline then start a timeout timer
|
||||
var timer *time.Timer
|
||||
if deadline != nil {
|
||||
timer = time.AfterFunc(deadline.Sub(time.Now()), func() {
|
||||
p.cond.Broadcast()
|
||||
})
|
||||
defer timer.Stop()
|
||||
}
|
||||
|
||||
// No connections are available, but we can create more
|
||||
if len(p.allConnections)+p.inProgressConnects < p.maxConnections {
|
||||
// Create a new connection.
|
||||
// Careful here: createConnectionUnlocked() removes the current lock,
|
||||
// creates a connection and then locks it back.
|
||||
c, err := p.createConnectionUnlocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.poolResetCount = p.resetCount
|
||||
p.allConnections = append(p.allConnections, c)
|
||||
return c, nil
|
||||
}
|
||||
// All connections are in use and we cannot create more
|
||||
if p.logLevel >= LogLevelWarn {
|
||||
p.logger.Log(LogLevelWarn, "waiting for available connection", nil)
|
||||
}
|
||||
|
||||
// Wait until there is an available connection OR room to create a new connection
|
||||
for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections {
|
||||
if p.deadlinePassed(deadline) {
|
||||
return nil, ErrAcquireTimeout
|
||||
}
|
||||
p.cond.Wait()
|
||||
}
|
||||
|
||||
// Stop the timer so that we do not spawn it on every acquire call.
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
return p.acquire(deadline)
|
||||
}
|
||||
|
||||
// Release gives up use of a connection.
|
||||
func (p *ConnPool) Release(conn *Conn) {
|
||||
if conn.ctxInProgress {
|
||||
panic("should never release when context is in progress")
|
||||
}
|
||||
|
||||
if conn.txStatus != 'I' {
|
||||
conn.Exec("rollback")
|
||||
}
|
||||
|
||||
if len(conn.channels) > 0 {
|
||||
if err := conn.Unlisten("*"); err != nil {
|
||||
conn.die(err)
|
||||
}
|
||||
conn.channels = make(map[string]struct{})
|
||||
}
|
||||
conn.notifications = nil
|
||||
|
||||
p.cond.L.Lock()
|
||||
|
||||
if conn.poolResetCount != p.resetCount {
|
||||
conn.Close()
|
||||
p.cond.L.Unlock()
|
||||
p.cond.Signal()
|
||||
return
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
p.availableConnections = append(p.availableConnections, conn)
|
||||
} else {
|
||||
p.removeFromAllConnections(conn)
|
||||
}
|
||||
p.cond.L.Unlock()
|
||||
p.cond.Signal()
|
||||
}
|
||||
|
||||
// removeFromAllConnections Removes the given connection from the list.
|
||||
// It returns true if the connection was found and removed or false otherwise.
|
||||
func (p *ConnPool) removeFromAllConnections(conn *Conn) bool {
|
||||
for i, c := range p.allConnections {
|
||||
if conn == c {
|
||||
p.allConnections = append(p.allConnections[:i], p.allConnections[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Close ends the use of a connection pool. It prevents any new connections from
|
||||
// being acquired and closes available underlying connections. Any acquired
|
||||
// connections will be closed when they are released.
|
||||
func (p *ConnPool) Close() {
|
||||
p.cond.L.Lock()
|
||||
defer p.cond.L.Unlock()
|
||||
|
||||
p.closed = true
|
||||
|
||||
for _, c := range p.availableConnections {
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
// This will cause any checked out connections to be closed on release
|
||||
p.resetCount++
|
||||
}
|
||||
|
||||
// Reset closes all open connections, but leaves the pool open. It is intended
|
||||
// for use when an error is detected that would disrupt all connections (such as
|
||||
// a network interruption or a server state change).
|
||||
//
|
||||
// It is safe to reset a pool while connections are checked out. Those
|
||||
// connections will be closed when they are returned to the pool.
|
||||
func (p *ConnPool) Reset() {
|
||||
p.cond.L.Lock()
|
||||
defer p.cond.L.Unlock()
|
||||
|
||||
p.resetCount++
|
||||
p.allConnections = p.allConnections[0:0]
|
||||
|
||||
for _, conn := range p.availableConnections {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
p.availableConnections = p.availableConnections[0:0]
|
||||
}
|
||||
|
||||
// invalidateAcquired causes all acquired connections to be closed when released.
|
||||
// The pool must already be locked.
|
||||
func (p *ConnPool) invalidateAcquired() {
|
||||
p.resetCount++
|
||||
|
||||
for _, c := range p.availableConnections {
|
||||
c.poolResetCount = p.resetCount
|
||||
}
|
||||
|
||||
p.allConnections = p.allConnections[:len(p.availableConnections)]
|
||||
copy(p.allConnections, p.availableConnections)
|
||||
}
|
||||
|
||||
// Stat returns connection pool statistics
|
||||
func (p *ConnPool) Stat() (s ConnPoolStat) {
|
||||
p.cond.L.Lock()
|
||||
defer p.cond.L.Unlock()
|
||||
|
||||
s.MaxConnections = p.maxConnections
|
||||
s.CurrentConnections = len(p.allConnections)
|
||||
s.AvailableConnections = len(p.availableConnections)
|
||||
return
|
||||
}
|
||||
|
||||
func (p *ConnPool) createConnection() (*Conn, error) {
|
||||
c, err := connect(p.config, p.connInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p.afterConnectionCreated(c)
|
||||
}
|
||||
|
||||
// createConnectionUnlocked Removes the current lock, creates a new connection, and
|
||||
// then locks it back.
|
||||
// Here is the point: lets say our pool dialer's OpenTimeout is set to 3 seconds.
|
||||
// And we have a pool with 20 connections in it, and we try to acquire them all at
|
||||
// startup.
|
||||
// If it happens that the remote server is not accessible, then the first connection
|
||||
// in the pool blocks all the others for 3 secs, before it gets the timeout. Then
|
||||
// connection #2 holds the lock and locks everything for the next 3 secs until it
|
||||
// gets OpenTimeout err, etc. And the very last 20th connection will fail only after
|
||||
// 3 * 20 = 60 secs.
|
||||
// To avoid this we put Connect(p.config) outside of the lock (it is thread safe)
|
||||
// what would allow us to make all the 20 connection in parallel (more or less).
|
||||
func (p *ConnPool) createConnectionUnlocked() (*Conn, error) {
|
||||
p.inProgressConnects++
|
||||
p.cond.L.Unlock()
|
||||
c, err := Connect(p.config)
|
||||
p.cond.L.Lock()
|
||||
p.inProgressConnects--
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p.afterConnectionCreated(c)
|
||||
}
|
||||
|
||||
// afterConnectionCreated executes (if it is) afterConnect() callback and prepares
|
||||
// all the known statements for the new connection.
|
||||
func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
|
||||
if p.afterConnect != nil {
|
||||
err := p.afterConnect(c)
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for _, ps := range p.preparedStatements {
|
||||
if _, err := c.Prepare(ps.Name, ps.SQL); err != nil {
|
||||
c.die(err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Exec acquires a connection, delegates the call to that connection, and releases the connection
|
||||
func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
var c *Conn
|
||||
if c, err = p.Acquire(); err != nil {
|
||||
return
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
return c.Exec(sql, arguments...)
|
||||
}
|
||||
|
||||
func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
var c *Conn
|
||||
if c, err = p.Acquire(); err != nil {
|
||||
return
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
return c.ExecEx(ctx, sql, options, arguments...)
|
||||
}
|
||||
|
||||
// Query acquires a connection and delegates the call to that connection. When
|
||||
// *Rows are closed, the connection is released automatically.
|
||||
func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
// Because checking for errors can be deferred to the *Rows, build one with the error
|
||||
return &Rows{closed: true, err: err}, err
|
||||
}
|
||||
|
||||
rows, err := c.Query(sql, args...)
|
||||
if err != nil {
|
||||
p.Release(c)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
rows.connPool = p
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func (p *ConnPool) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
// Because checking for errors can be deferred to the *Rows, build one with the error
|
||||
return &Rows{closed: true, err: err}, err
|
||||
}
|
||||
|
||||
rows, err := c.QueryEx(ctx, sql, options, args...)
|
||||
if err != nil {
|
||||
p.Release(c)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
rows.connPool = p
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// QueryRow acquires a connection and delegates the call to that connection. The
|
||||
// connection is released automatically after Scan is called on the returned
|
||||
// *Row.
|
||||
func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row {
|
||||
rows, _ := p.Query(sql, args...)
|
||||
return (*Row)(rows)
|
||||
}
|
||||
|
||||
func (p *ConnPool) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row {
|
||||
rows, _ := p.QueryEx(ctx, sql, options, args...)
|
||||
return (*Row)(rows)
|
||||
}
|
||||
|
||||
// Begin acquires a connection and begins a transaction on it. When the
|
||||
// transaction is closed the connection will be automatically released.
|
||||
func (p *ConnPool) Begin() (*Tx, error) {
|
||||
return p.BeginEx(context.Background(), nil)
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement on a connection in the pool to test the
|
||||
// statement is valid. If it succeeds all connections accessed through the pool
|
||||
// will have the statement available.
|
||||
//
|
||||
// Prepare creates a prepared statement with name and sql. sql can contain
|
||||
// placeholders for bound parameters. These placeholders are referenced
|
||||
// positional as $1, $2, etc.
|
||||
//
|
||||
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with
|
||||
// the same name and sql arguments. This allows a code path to Prepare and
|
||||
// Query/Exec/PrepareEx without concern for if the statement has already been prepared.
|
||||
func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) {
|
||||
return p.PrepareEx(context.Background(), name, sql, nil)
|
||||
}
|
||||
|
||||
// PrepareEx creates a prepared statement on a connection in the pool to test the
|
||||
// statement is valid. If it succeeds all connections accessed through the pool
|
||||
// will have the statement available.
|
||||
//
|
||||
// PrepareEx creates a prepared statement with name and sql. sql can contain placeholders
|
||||
// for bound parameters. These placeholders are referenced positional as $1, $2, etc.
|
||||
// It differs from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct
|
||||
//
|
||||
// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same
|
||||
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec/Prepare without
|
||||
// concern for if the statement has already been prepared.
|
||||
func (p *ConnPool) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) {
|
||||
p.cond.L.Lock()
|
||||
defer p.cond.L.Unlock()
|
||||
|
||||
if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql {
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
c, err := p.acquire(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.availableConnections = append(p.availableConnections, c)
|
||||
|
||||
// Double check that the statement was not prepared by someone else
|
||||
// while we were acquiring the connection (since acquire is not fully
|
||||
// blocking now, see createConnectionUnlocked())
|
||||
if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql {
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
ps, err := c.PrepareEx(ctx, name, sql, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, c := range p.availableConnections {
|
||||
_, err := c.PrepareEx(ctx, name, sql, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
p.invalidateAcquired()
|
||||
p.preparedStatements[name] = ps
|
||||
|
||||
return ps, err
|
||||
}
|
||||
|
||||
// Deallocate releases a prepared statement from all connections in the pool.
|
||||
func (p *ConnPool) Deallocate(name string) (err error) {
|
||||
p.cond.L.Lock()
|
||||
defer p.cond.L.Unlock()
|
||||
|
||||
for _, c := range p.availableConnections {
|
||||
if err := c.Deallocate(name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
p.invalidateAcquired()
|
||||
delete(p.preparedStatements, name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeginEx acquires a connection and starts a transaction with txOptions
|
||||
// determining the transaction mode. When the transaction is closed the
|
||||
// connection will be automatically released.
|
||||
func (p *ConnPool) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) {
|
||||
for {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tx, err := c.BeginEx(ctx, txOptions)
|
||||
if err != nil {
|
||||
alive := c.IsAlive()
|
||||
p.Release(c)
|
||||
|
||||
// If connection is still alive then the error is not something trying
|
||||
// again on a new connection would fix, so just return the error. But
|
||||
// if the connection is dead try to acquire a new connection and try
|
||||
// again.
|
||||
if alive || ctx.Err() != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
tx.connPool = p
|
||||
return tx, nil
|
||||
}
|
||||
}
|
||||
|
||||
// CopyFrom acquires a connection, delegates the call to that connection, and releases the connection
|
||||
func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
return c.CopyFrom(tableName, columnNames, rowSrc)
|
||||
}
|
||||
|
||||
// CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection
|
||||
func (p *ConnPool) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
return c.CopyFromReader(r, sql)
|
||||
}
|
||||
|
||||
// CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection
|
||||
func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
return c.CopyToWriter(w, sql, args...)
|
||||
}
|
||||
|
||||
// BeginBatch acquires a connection and begins a batch on that connection. When
|
||||
// *Batch is finished, the connection is released automatically.
|
||||
func (p *ConnPool) BeginBatch() *Batch {
|
||||
c, err := p.Acquire()
|
||||
return &Batch{conn: c, connPool: p, err: err}
|
||||
}
|
44
conn_pool_private_test.go
Normal file
44
conn_pool_private_test.go
Normal file
@ -0,0 +1,44 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func compareConnSlices(slice1, slice2 []*Conn) bool {
|
||||
if len(slice1) != len(slice2) {
|
||||
return false
|
||||
}
|
||||
for i, c := range slice1 {
|
||||
if c != slice2[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestConnPoolRemoveFromAllConnections(t *testing.T) {
|
||||
t.Parallel()
|
||||
pool := ConnPool{}
|
||||
conn1 := &Conn{}
|
||||
conn2 := &Conn{}
|
||||
conn3 := &Conn{}
|
||||
|
||||
// First element
|
||||
pool.allConnections = []*Conn{conn1, conn2, conn3}
|
||||
pool.removeFromAllConnections(conn1)
|
||||
if !compareConnSlices(pool.allConnections, []*Conn{conn2, conn3}) {
|
||||
t.Fatal("First element test failed")
|
||||
}
|
||||
// Element somewhere in the middle
|
||||
pool.allConnections = []*Conn{conn1, conn2, conn3}
|
||||
pool.removeFromAllConnections(conn2)
|
||||
if !compareConnSlices(pool.allConnections, []*Conn{conn1, conn3}) {
|
||||
t.Fatal("Middle element test failed")
|
||||
}
|
||||
// Last element
|
||||
pool.allConnections = []*Conn{conn1, conn2, conn3}
|
||||
pool.removeFromAllConnections(conn3)
|
||||
if !compareConnSlices(pool.allConnections, []*Conn{conn1, conn2}) {
|
||||
t.Fatal("Last element test failed")
|
||||
}
|
||||
}
|
1083
conn_pool_test.go
Normal file
1083
conn_pool_test.go
Normal file
File diff suppressed because it is too large
Load Diff
2988
conn_test.go
2988
conn_test.go
File diff suppressed because it is too large
Load Diff
400
copy_from.go
400
copy_from.go
@ -2,22 +2,22 @@ package pgx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
|
||||
// making it usable by *Conn.CopyFrom.
|
||||
func CopyFromRows(rows [][]any) CopyFromSource {
|
||||
func CopyFromRows(rows [][]interface{}) CopyFromSource {
|
||||
return ©FromRows{rows: rows, idx: -1}
|
||||
}
|
||||
|
||||
type copyFromRows struct {
|
||||
rows [][]any
|
||||
rows [][]interface{}
|
||||
idx int
|
||||
}
|
||||
|
||||
@ -26,7 +26,7 @@ func (ctr *copyFromRows) Next() bool {
|
||||
return ctr.idx < len(ctr.rows)
|
||||
}
|
||||
|
||||
func (ctr *copyFromRows) Values() ([]any, error) {
|
||||
func (ctr *copyFromRows) Values() ([]interface{}, error) {
|
||||
return ctr.rows[ctr.idx], nil
|
||||
}
|
||||
|
||||
@ -34,63 +34,6 @@ func (ctr *copyFromRows) Err() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CopyFromSlice returns a CopyFromSource interface over a dynamic func
|
||||
// making it usable by *Conn.CopyFrom.
|
||||
func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource {
|
||||
return ©FromSlice{next: next, idx: -1, len: length}
|
||||
}
|
||||
|
||||
type copyFromSlice struct {
|
||||
next func(int) ([]any, error)
|
||||
idx int
|
||||
len int
|
||||
err error
|
||||
}
|
||||
|
||||
func (cts *copyFromSlice) Next() bool {
|
||||
cts.idx++
|
||||
return cts.idx < cts.len
|
||||
}
|
||||
|
||||
func (cts *copyFromSlice) Values() ([]any, error) {
|
||||
values, err := cts.next(cts.idx)
|
||||
if err != nil {
|
||||
cts.err = err
|
||||
}
|
||||
return values, err
|
||||
}
|
||||
|
||||
func (cts *copyFromSlice) Err() error {
|
||||
return cts.err
|
||||
}
|
||||
|
||||
// CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values.
|
||||
// nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil,
|
||||
// or it returns an error. If nxtf returns an error, the copy is aborted.
|
||||
func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource {
|
||||
return ©FromFunc{next: nxtf}
|
||||
}
|
||||
|
||||
type copyFromFunc struct {
|
||||
next func() ([]any, error)
|
||||
valueRow []any
|
||||
err error
|
||||
}
|
||||
|
||||
func (g *copyFromFunc) Next() bool {
|
||||
g.valueRow, g.err = g.next()
|
||||
// only return true if valueRow exists and no error
|
||||
return g.valueRow != nil && g.err == nil
|
||||
}
|
||||
|
||||
func (g *copyFromFunc) Values() ([]any, error) {
|
||||
return g.valueRow, g.err
|
||||
}
|
||||
|
||||
func (g *copyFromFunc) Err() error {
|
||||
return g.err
|
||||
}
|
||||
|
||||
// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
|
||||
type CopyFromSource interface {
|
||||
// Next returns true if there is another row and makes the next row data
|
||||
@ -99,7 +42,7 @@ type CopyFromSource interface {
|
||||
Next() bool
|
||||
|
||||
// Values returns the values for the current row.
|
||||
Values() ([]any, error)
|
||||
Values() ([]interface{}, error)
|
||||
|
||||
// Err returns any error that has been encountered by the CopyFromSource. If
|
||||
// this is not nil *Conn.CopyFrom will abort the copy.
|
||||
@ -112,17 +55,42 @@ type copyFrom struct {
|
||||
columnNames []string
|
||||
rowSrc CopyFromSource
|
||||
readerErrChan chan error
|
||||
mode QueryExecMode
|
||||
}
|
||||
|
||||
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
||||
if ct.conn.copyFromTracer != nil {
|
||||
ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{
|
||||
TableName: ct.tableName,
|
||||
ColumnNames: ct.columnNames,
|
||||
})
|
||||
}
|
||||
func (ct *copyFrom) readUntilReadyForQuery() {
|
||||
for {
|
||||
msg, err := ct.conn.rxMsg()
|
||||
if err != nil {
|
||||
ct.readerErrChan <- err
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
ct.conn.rxReadyForQuery(msg)
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
case *pgproto3.CommandComplete:
|
||||
case *pgproto3.ErrorResponse:
|
||||
ct.readerErrChan <- ct.conn.rxErrorResponse(msg)
|
||||
default:
|
||||
err = ct.conn.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
ct.readerErrChan <- ct.conn.processContextFreeMsg(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ct *copyFrom) waitForReaderDone() error {
|
||||
var err error
|
||||
for err = range ct.readerErrChan {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (ct *copyFrom) run() (int, error) {
|
||||
quotedTableName := ct.tableName.Sanitize()
|
||||
cbuf := &bytes.Buffer{}
|
||||
for i, cn := range ct.columnNames {
|
||||
@ -133,144 +101,238 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
||||
}
|
||||
quotedColumnNames := cbuf.String()
|
||||
|
||||
var sd *pgconn.StatementDescription
|
||||
switch ct.mode {
|
||||
case QueryExecModeExec, QueryExecModeSimpleProtocol:
|
||||
// These modes don't support the binary format. Before the inclusion of the
|
||||
// QueryExecModes, Conn.Prepare was called on every COPY operation to get
|
||||
// the OIDs. These prepared statements were not cached.
|
||||
//
|
||||
// Since that's the same behavior provided by QueryExecModeDescribeExec,
|
||||
// we'll default to that mode.
|
||||
ct.mode = QueryExecModeDescribeExec
|
||||
fallthrough
|
||||
case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec:
|
||||
var err error
|
||||
sd, err = ct.conn.getStatementDescription(
|
||||
ctx,
|
||||
ct.mode,
|
||||
fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName),
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("statement description failed: %w", err)
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
|
||||
ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
r, w := io.Pipe()
|
||||
doneChan := make(chan struct{})
|
||||
err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
err = ct.conn.readUntilCopyInResponse()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
|
||||
buf := ct.conn.wbuf
|
||||
panicked := true
|
||||
|
||||
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
|
||||
moreRows := true
|
||||
for moreRows {
|
||||
var err error
|
||||
moreRows, buf, err = ct.buildCopyBuf(buf, sd)
|
||||
if err != nil {
|
||||
w.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
|
||||
if ct.rowSrc.Err() != nil {
|
||||
w.CloseWithError(ct.rowSrc.Err())
|
||||
return
|
||||
}
|
||||
|
||||
if len(buf) > 0 {
|
||||
_, err = w.Write(buf)
|
||||
if err != nil {
|
||||
w.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
buf = buf[:0]
|
||||
go ct.readUntilReadyForQuery()
|
||||
defer ct.waitForReaderDone()
|
||||
defer func() {
|
||||
if panicked {
|
||||
ct.conn.die(errors.New("panic while in copy from"))
|
||||
}
|
||||
|
||||
w.Close()
|
||||
}()
|
||||
|
||||
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||
buf := ct.conn.wbuf
|
||||
buf = append(buf, copyData)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
r.Close()
|
||||
<-doneChan
|
||||
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
|
||||
var sentCount int
|
||||
|
||||
moreRows := true
|
||||
for moreRows {
|
||||
select {
|
||||
case err = <-ct.readerErrChan:
|
||||
panicked = false
|
||||
return 0, err
|
||||
default:
|
||||
}
|
||||
|
||||
var addedRows int
|
||||
var err error
|
||||
moreRows, buf, addedRows, err = ct.buildCopyBuf(buf, ps)
|
||||
if err != nil {
|
||||
panicked = false
|
||||
ct.cancelCopyIn()
|
||||
return 0, err
|
||||
}
|
||||
sentCount += addedRows
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
_, err = ct.conn.conn.Write(buf)
|
||||
if err != nil {
|
||||
panicked = false
|
||||
ct.conn.die(err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Directly manipulate wbuf to reset to reuse the same buffer
|
||||
buf = buf[0:5]
|
||||
|
||||
if ct.conn.copyFromTracer != nil {
|
||||
ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{
|
||||
CommandTag: commandTag,
|
||||
Err: err,
|
||||
})
|
||||
}
|
||||
|
||||
return commandTag.RowsAffected(), err
|
||||
if ct.rowSrc.Err() != nil {
|
||||
panicked = false
|
||||
ct.cancelCopyIn()
|
||||
return 0, ct.rowSrc.Err()
|
||||
}
|
||||
|
||||
buf = pgio.AppendInt16(buf, -1) // terminate the copy stream
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
buf = append(buf, copyDone)
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
_, err = ct.conn.conn.Write(buf)
|
||||
if err != nil {
|
||||
panicked = false
|
||||
ct.conn.die(err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = ct.waitForReaderDone()
|
||||
if err != nil {
|
||||
panicked = false
|
||||
return 0, err
|
||||
}
|
||||
|
||||
panicked = false
|
||||
return sentCount, nil
|
||||
}
|
||||
|
||||
func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
|
||||
const sendBufSize = 65536 - 5 // The packet has a 5-byte header
|
||||
lastBufLen := 0
|
||||
largestRowLen := 0
|
||||
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, int, error) {
|
||||
var rowCount int
|
||||
|
||||
for ct.rowSrc.Next() {
|
||||
lastBufLen = len(buf)
|
||||
|
||||
values, err := ct.rowSrc.Values()
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
return false, nil, 0, err
|
||||
}
|
||||
if len(values) != len(ct.columnNames) {
|
||||
return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
||||
return false, nil, 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
||||
}
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
||||
for i, val := range values {
|
||||
buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val)
|
||||
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
return false, nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
rowLen := len(buf) - lastBufLen
|
||||
if rowLen > largestRowLen {
|
||||
largestRowLen = rowLen
|
||||
}
|
||||
rowCount++
|
||||
|
||||
// Try not to overflow size of the buffer PgConn.CopyFrom will be reading into. If that happens then the nature of
|
||||
// io.Pipe means that the next Read will be short. This can lead to pathological send sizes such as 65531, 13, 65531
|
||||
// 13, 65531, 13, 65531, 13.
|
||||
if len(buf) > sendBufSize-largestRowLen {
|
||||
return true, buf, nil
|
||||
if len(buf) > 65536 {
|
||||
return true, buf, rowCount, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, buf, nil
|
||||
return false, buf, rowCount, nil
|
||||
}
|
||||
|
||||
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. It returns the number of rows copied and
|
||||
// an error.
|
||||
func (c *Conn) readUntilCopyInResponse() error {
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CopyInResponse:
|
||||
return nil
|
||||
default:
|
||||
err = c.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ct *copyFrom) cancelCopyIn() error {
|
||||
buf := ct.conn.wbuf
|
||||
buf = append(buf, copyFail)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, "client error: abort"...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
_, err := ct.conn.conn.Write(buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
|
||||
// It returns the number of rows copied and an error.
|
||||
//
|
||||
// CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered
|
||||
// for the type of each column. Almost all types implemented by pgx support the binary format.
|
||||
//
|
||||
// Even though enum types appear to be strings they still must be registered to use with CopyFrom. This can be done with
|
||||
// Conn.LoadType and pgtype.Map.RegisterType.
|
||||
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
|
||||
// CopyFrom requires all values use the binary format. Almost all types
|
||||
// implemented by pgx use the binary format by default. Types implementing
|
||||
// Encoder can only be used if they encode to the binary format.
|
||||
func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
|
||||
ct := ©From{
|
||||
conn: c,
|
||||
tableName: tableName,
|
||||
columnNames: columnNames,
|
||||
rowSrc: rowSrc,
|
||||
readerErrChan: make(chan error),
|
||||
mode: c.config.DefaultQueryExecMode,
|
||||
}
|
||||
|
||||
return ct.run(ctx)
|
||||
return ct.run()
|
||||
}
|
||||
|
||||
// CopyFromReader uses the PostgreSQL textual format of the copy protocol
|
||||
func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
|
||||
if err := c.sendSimpleQuery(sql); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := c.readUntilCopyInResponse(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
buf := c.wbuf
|
||||
|
||||
buf = append(buf, copyData)
|
||||
sp := len(buf)
|
||||
for {
|
||||
n, err := r.Read(buf[5:cap(buf)])
|
||||
if err == io.EOF && n == 0 {
|
||||
break
|
||||
}
|
||||
buf = buf[0 : n+5]
|
||||
pgio.SetInt32(buf[sp:], int32(n+4))
|
||||
|
||||
if _, err := c.conn.Write(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
buf = buf[:0]
|
||||
buf = append(buf, copyDone)
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
if _, err := c.conn.Write(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
return "", err
|
||||
case *pgproto3.CommandComplete:
|
||||
return CommandTag(msg.CommandTag), nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return "", c.rxErrorResponse(msg)
|
||||
default:
|
||||
return "", c.processContextFreeMsg(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
63
copy_to.go
Normal file
63
copy_to.go
Normal file
@ -0,0 +1,63 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
)
|
||||
|
||||
func (c *Conn) readUntilCopyOutResponse() error {
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CopyOutResponse:
|
||||
return nil
|
||||
default:
|
||||
err = c.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) {
|
||||
if err := c.sendSimpleQuery(sql, args...); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := c.readUntilCopyOutResponse(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CopyDone:
|
||||
break
|
||||
case *pgproto3.CopyData:
|
||||
_, err := w.Write(msg.Data)
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
return "", err
|
||||
}
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
return "", nil
|
||||
case *pgproto3.CommandComplete:
|
||||
return CommandTag(msg.CommandTag), nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return "", c.rxErrorResponse(msg)
|
||||
default:
|
||||
return "", c.processContextFreeMsg(msg)
|
||||
}
|
||||
}
|
||||
}
|
115
copy_to_test.go
Normal file
115
copy_to_test.go
Normal file
@ -0,0 +1,115 @@
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
func TestConnCopyToWriterSmall(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int2,
|
||||
b int4,
|
||||
c int8,
|
||||
d varchar,
|
||||
e text,
|
||||
f date,
|
||||
g json
|
||||
)`)
|
||||
mustExec(t, conn, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`)
|
||||
mustExec(t, conn, `insert into foo values (null, null, null, null, null, null, null)`)
|
||||
|
||||
inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" +
|
||||
"\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
|
||||
|
||||
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
|
||||
|
||||
res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyToWriter: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != 2 {
|
||||
t.Errorf("Expected CopyToWriter to return 2 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
|
||||
t.Errorf("Input rows and output rows do not equal:\n%q\n%q", string(inputBytes), string(outputWriter.Bytes()))
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToWriterLarge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int2,
|
||||
b int4,
|
||||
c int8,
|
||||
d varchar,
|
||||
e text,
|
||||
f date,
|
||||
g json,
|
||||
h bytea
|
||||
)`)
|
||||
inputBytes := make([]byte, 0)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
mustExec(t, conn, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`)
|
||||
inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...)
|
||||
}
|
||||
|
||||
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
|
||||
|
||||
res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != 1000 {
|
||||
t.Errorf("Expected CopyToWriter to return 1 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
|
||||
t.Errorf("Input rows and output rows do not equal")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToWriterQueryError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
outputWriter := bytes.NewBuffer(make([]byte, 0))
|
||||
|
||||
res, err := conn.CopyToWriter(outputWriter, "cropy foo to stdout")
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyToWriter return error, but it did not")
|
||||
}
|
||||
|
||||
if _, ok := err.(pgx.PgError); !ok {
|
||||
t.Errorf("Expected CopyToWriter return pgx.PgError, but instead it returned: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyToWriter to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
256
derived_types.go
256
derived_types.go
@ -1,256 +0,0 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
/*
|
||||
buildLoadDerivedTypesSQL generates the correct query for retrieving type information.
|
||||
|
||||
pgVersion: the major version of the PostgreSQL server
|
||||
typeNames: the names of the types to load. If nil, load all types.
|
||||
*/
|
||||
func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string {
|
||||
supportsMultirange := (pgVersion >= 14)
|
||||
var typeNamesClause string
|
||||
|
||||
if typeNames == nil {
|
||||
// This should not occur; this will not return any types
|
||||
typeNamesClause = "= ''"
|
||||
} else {
|
||||
typeNamesClause = "= ANY($1)"
|
||||
}
|
||||
parts := make([]string, 0, 10)
|
||||
|
||||
// Each of the type names provided might be found in pg_class or pg_type.
|
||||
// Additionally, it may or may not include a schema portion.
|
||||
parts = append(parts, `
|
||||
WITH RECURSIVE
|
||||
-- find the OIDs in pg_class which match one of the provided type names
|
||||
selected_classes(oid,reltype) AS (
|
||||
-- this query uses the namespace search path, so will match type names without a schema prefix
|
||||
SELECT pg_class.oid, pg_class.reltype
|
||||
FROM pg_catalog.pg_class
|
||||
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace
|
||||
WHERE pg_catalog.pg_table_is_visible(pg_class.oid)
|
||||
AND relname `, typeNamesClause, `
|
||||
UNION ALL
|
||||
-- this query will only match type names which include the schema prefix
|
||||
SELECT pg_class.oid, pg_class.reltype
|
||||
FROM pg_class
|
||||
INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid)
|
||||
WHERE nspname || '.' || relname `, typeNamesClause, `
|
||||
),
|
||||
selected_types(oid) AS (
|
||||
-- collect the OIDs from pg_types which correspond to the selected classes
|
||||
SELECT reltype AS oid
|
||||
FROM selected_classes
|
||||
UNION ALL
|
||||
-- as well as any other type names which match our criteria
|
||||
SELECT pg_type.oid
|
||||
FROM pg_type
|
||||
LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid)
|
||||
WHERE typname `, typeNamesClause, `
|
||||
OR nspname || '.' || typname `, typeNamesClause, `
|
||||
),
|
||||
-- this builds a parent/child mapping of objects, allowing us to know
|
||||
-- all the child (ie: dependent) types that a parent (type) requires
|
||||
-- As can be seen, there are 3 ways this can occur (the last of which
|
||||
-- is due to being a composite class, where the composite fields are children)
|
||||
pc(parent, child) AS (
|
||||
SELECT parent.oid, parent.typelem
|
||||
FROM pg_type parent
|
||||
WHERE parent.typtype = 'b' AND parent.typelem != 0
|
||||
UNION ALL
|
||||
SELECT parent.oid, parent.typbasetype
|
||||
FROM pg_type parent
|
||||
WHERE parent.typtypmod = -1 AND parent.typbasetype != 0
|
||||
UNION ALL
|
||||
SELECT pg_type.oid, atttypid
|
||||
FROM pg_attribute
|
||||
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
|
||||
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
|
||||
WHERE NOT attisdropped
|
||||
AND attnum > 0
|
||||
),
|
||||
-- Now construct a recursive query which includes a 'depth' element.
|
||||
-- This is used to ensure that the "youngest" children are registered before
|
||||
-- their parents.
|
||||
relationships(parent, child, depth) AS (
|
||||
SELECT DISTINCT 0::OID, selected_types.oid, 0
|
||||
FROM selected_types
|
||||
UNION ALL
|
||||
SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1
|
||||
FROM selected_classes c
|
||||
inner join pg_type ON (c.reltype = pg_type.oid)
|
||||
inner join pg_attribute on (c.oid = pg_attribute.attrelid)
|
||||
UNION ALL
|
||||
SELECT pc.parent, pc.child, relationships.depth + 1
|
||||
FROM pc
|
||||
INNER JOIN relationships ON (pc.parent = relationships.child)
|
||||
),
|
||||
-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration
|
||||
composite AS (
|
||||
SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids
|
||||
FROM pg_attribute
|
||||
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
|
||||
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
|
||||
WHERE NOT attisdropped
|
||||
AND attnum > 0
|
||||
GROUP BY pg_type.oid
|
||||
)
|
||||
-- Bring together this information, showing all the information which might possibly be required
|
||||
-- to complete the registration, applying filters to only show the items which relate to the selected
|
||||
-- types/classes.
|
||||
SELECT typname,
|
||||
pg_namespace.nspname,
|
||||
typtype,
|
||||
typbasetype,
|
||||
typelem,
|
||||
pg_type.oid,`)
|
||||
if supportsMultirange {
|
||||
parts = append(parts, `
|
||||
COALESCE(multirange.rngtypid, 0) AS rngtypid,`)
|
||||
} else {
|
||||
parts = append(parts, `
|
||||
0 AS rngtypid,`)
|
||||
}
|
||||
parts = append(parts, `
|
||||
COALESCE(pg_range.rngsubtype, 0) AS rngsubtype,
|
||||
attnames, atttypids
|
||||
FROM relationships
|
||||
INNER JOIN pg_type ON (pg_type.oid = relationships.child)
|
||||
LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`)
|
||||
if supportsMultirange {
|
||||
parts = append(parts, `
|
||||
LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`)
|
||||
}
|
||||
|
||||
parts = append(parts, `
|
||||
LEFT OUTER JOIN composite USING (oid)
|
||||
LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid)
|
||||
WHERE NOT (typtype = 'b' AND typelem = 0)`)
|
||||
parts = append(parts, `
|
||||
GROUP BY typname, pg_namespace.nspname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`)
|
||||
if supportsMultirange {
|
||||
parts = append(parts, `
|
||||
multirange.rngtypid,`)
|
||||
}
|
||||
parts = append(parts, `
|
||||
attnames, atttypids
|
||||
ORDER BY MAX(depth) desc, typname;`)
|
||||
return strings.Join(parts, "")
|
||||
}
|
||||
|
||||
type derivedTypeInfo struct {
|
||||
Oid, Typbasetype, Typelem, Rngsubtype, Rngtypid uint32
|
||||
TypeName, Typtype, NspName string
|
||||
Attnames []string
|
||||
Atttypids []uint32
|
||||
}
|
||||
|
||||
// LoadTypes performs a single (complex) query, returning all the required
|
||||
// information to register the named types, as well as any other types directly
|
||||
// or indirectly required to complete the registration.
|
||||
// The result of this call can be passed into RegisterTypes to complete the process.
|
||||
func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) {
|
||||
m := c.TypeMap()
|
||||
if len(typeNames) == 0 {
|
||||
return nil, fmt.Errorf("No type names were supplied.")
|
||||
}
|
||||
|
||||
// Disregard server version errors. This will result in
|
||||
// the SQL not support recent structures such as multirange
|
||||
serverVersion, _ := serverVersion(c)
|
||||
sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
|
||||
rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("While generating load types query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
result := make([]*pgtype.Type, 0, 100)
|
||||
for rows.Next() {
|
||||
ti := derivedTypeInfo{}
|
||||
err = rows.Scan(&ti.TypeName, &ti.NspName, &ti.Typtype, &ti.Typbasetype, &ti.Typelem, &ti.Oid, &ti.Rngtypid, &ti.Rngsubtype, &ti.Attnames, &ti.Atttypids)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("While scanning type information: %w", err)
|
||||
}
|
||||
var type_ *pgtype.Type
|
||||
switch ti.Typtype {
|
||||
case "b": // array
|
||||
dt, ok := m.TypeForOID(ti.Typelem)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Array element OID %v not registered while loading pgtype %q", ti.Typelem, ti.TypeName)
|
||||
}
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}
|
||||
case "c": // composite
|
||||
var fields []pgtype.CompositeCodecField
|
||||
for i, fieldName := range ti.Attnames {
|
||||
dt, ok := m.TypeForOID(ti.Atttypids[i])
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Unknown field for composite type %q: field %q (OID %v) is not already registered.", ti.TypeName, fieldName, ti.Atttypids[i])
|
||||
}
|
||||
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
|
||||
}
|
||||
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.CompositeCodec{Fields: fields}}
|
||||
case "d": // domain
|
||||
dt, ok := m.TypeForOID(ti.Typbasetype)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Domain base type OID %v was not already registered, needed for %q", ti.Typbasetype, ti.TypeName)
|
||||
}
|
||||
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: dt.Codec}
|
||||
case "e": // enum
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.EnumCodec{}}
|
||||
case "r": // range
|
||||
dt, ok := m.TypeForOID(ti.Rngsubtype)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Range element OID %v was not already registered, needed for %q", ti.Rngsubtype, ti.TypeName)
|
||||
}
|
||||
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.RangeCodec{ElementType: dt}}
|
||||
case "m": // multirange
|
||||
dt, ok := m.TypeForOID(ti.Rngtypid)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Multirange element OID %v was not already registered, needed for %q", ti.Rngtypid, ti.TypeName)
|
||||
}
|
||||
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
|
||||
}
|
||||
|
||||
// the type_ is imposible to be null
|
||||
m.RegisterType(type_)
|
||||
if ti.NspName != "" {
|
||||
nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
|
||||
m.RegisterType(nspType)
|
||||
result = append(result, nspType)
|
||||
}
|
||||
result = append(result, type_)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// serverVersion returns the postgresql server version.
|
||||
func serverVersion(c *Conn) (int64, error) {
|
||||
serverVersionStr := c.PgConn().ParameterStatus("server_version")
|
||||
serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
|
||||
// if not PostgreSQL do nothing
|
||||
if serverVersionStr == "" {
|
||||
return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr)
|
||||
}
|
||||
|
||||
version, err := strconv.ParseInt(serverVersionStr, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("postgres version parsing failed: %w", err)
|
||||
}
|
||||
return version, nil
|
||||
}
|
@ -1,40 +0,0 @@
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) {
|
||||
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
|
||||
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
_, err := conn.Exec(ctx, `
|
||||
drop type if exists dtype_test;
|
||||
drop domain if exists anotheruint64;
|
||||
|
||||
create domain anotheruint64 as numeric(20,0);
|
||||
create type dtype_test as (
|
||||
a text,
|
||||
b int4,
|
||||
c anotheruint64,
|
||||
d anotheruint64[]
|
||||
);`)
|
||||
require.NoError(t, err)
|
||||
defer conn.Exec(ctx, "drop type dtype_test")
|
||||
defer conn.Exec(ctx, "drop domain anotheruint64")
|
||||
|
||||
types, err := conn.LoadTypes(ctx, []string{"dtype_test"})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, types, 6)
|
||||
require.Equal(t, types[0].Name, "public.anotheruint64")
|
||||
require.Equal(t, types[1].Name, "anotheruint64")
|
||||
require.Equal(t, types[2].Name, "public._anotheruint64")
|
||||
require.Equal(t, types[3].Name, "_anotheruint64")
|
||||
require.Equal(t, types[4].Name, "public.dtype_test")
|
||||
require.Equal(t, types[5].Name, "dtype_test")
|
||||
})
|
||||
}
|
291
doc.go
291
doc.go
@ -1,66 +1,58 @@
|
||||
// Package pgx is a PostgreSQL database driver.
|
||||
/*
|
||||
pgx provides a native PostgreSQL driver and can act as a database/sql driver. The native PostgreSQL interface is similar
|
||||
to the database/sql interface while providing better speed and access to PostgreSQL specific features. Use
|
||||
github.com/jackc/pgx/v5/stdlib to use pgx as a database/sql compatible driver. See that package's documentation for
|
||||
details.
|
||||
|
||||
Establishing a Connection
|
||||
|
||||
The primary way of establishing a connection is with [pgx.Connect]:
|
||||
|
||||
conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
|
||||
|
||||
The database connection string can be in URL or key/value format. Both PostgreSQL settings and pgx settings can be
|
||||
specified here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the
|
||||
connection with [ConnectConfig] to configure settings such as tracing that cannot be configured with a connection
|
||||
string.
|
||||
|
||||
Connection Pool
|
||||
|
||||
[*pgx.Conn] represents a single connection to the database and is not concurrency safe. Use package
|
||||
github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool.
|
||||
pgx provides lower level access to PostgreSQL than the standard database/sql.
|
||||
It remains as similar to the database/sql interface as possible while
|
||||
providing better speed and access to PostgreSQL specific features. Import
|
||||
github.com/jackc/pgx/stdlib to use pgx as a database/sql compatible driver.
|
||||
|
||||
Query Interface
|
||||
|
||||
pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and
|
||||
ForEachRow that are a simpler and safer way of processing rows than manually calling defer rows.Close(), rows.Next(),
|
||||
rows.Scan, and rows.Err().
|
||||
pgx implements Query and Scan in the familiar database/sql style.
|
||||
|
||||
CollectRows can be used collect all returned rows into a slice.
|
||||
var sum int32
|
||||
|
||||
rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 5)
|
||||
numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32])
|
||||
// Send the query to the server. The returned rows MUST be closed
|
||||
// before conn can be used again.
|
||||
rows, err := conn.Query("select generate_series(1,$1)", 10)
|
||||
if err != nil {
|
||||
return err
|
||||
return err
|
||||
}
|
||||
// numbers => [1 2 3 4 5]
|
||||
|
||||
ForEachRow can be used to execute a callback function for every row. This is often easier than iterating over rows
|
||||
directly.
|
||||
// rows.Close is called by rows.Next when all rows are read
|
||||
// or an error occurs in Next or Scan. So it may optionally be
|
||||
// omitted if nothing in the rows.Next loop can panic. It is
|
||||
// safe to close rows multiple times.
|
||||
defer rows.Close()
|
||||
|
||||
var sum, n int32
|
||||
rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
|
||||
_, err := pgx.ForEachRow(rows, []any{&n}, func() error {
|
||||
sum += n
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
// Iterate through the result set
|
||||
for rows.Next() {
|
||||
var n int32
|
||||
err = rows.Scan(&n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sum += n
|
||||
}
|
||||
|
||||
// Any errors encountered by rows.Next or rows.Scan will be returned here
|
||||
if rows.Err() != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// No errors found - do something with sum
|
||||
|
||||
pgx also implements QueryRow in the same style as database/sql.
|
||||
|
||||
var name string
|
||||
var weight int64
|
||||
err := conn.QueryRow(context.Background(), "select name, weight from widgets where id=$1", 42).Scan(&name, &weight)
|
||||
err := conn.QueryRow("select name, weight from widgets where id=$1", 42).Scan(&name, &weight)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Use Exec to execute a query that does not return a result set.
|
||||
|
||||
commandTag, err := conn.Exec(context.Background(), "delete from widgets where id=$1", 42)
|
||||
commandTag, err := conn.Exec("delete from widgets where id=$1", 42)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -68,127 +60,194 @@ Use Exec to execute a query that does not return a result set.
|
||||
return errors.New("No row found to delete")
|
||||
}
|
||||
|
||||
PostgreSQL Data Types
|
||||
Connection Pool
|
||||
|
||||
pgx uses the pgtype package to converting Go values to and from PostgreSQL values. It supports many PostgreSQL types
|
||||
directly and is customizable and extendable. User defined data types such as enums, domains, and composite types may
|
||||
require type registration. See that package's documentation for details.
|
||||
Connection pool usage is explicit and configurable. In pgx, a connection can be
|
||||
created and managed directly, or a connection pool with a configurable maximum
|
||||
connections can be used. The connection pool offers an after connect hook that
|
||||
allows every connection to be automatically setup before being made available in
|
||||
the connection pool.
|
||||
|
||||
It delegates methods such as QueryRow to an automatically checked out and
|
||||
released connection so you can avoid manually acquiring and releasing
|
||||
connections when you do not need that level of control.
|
||||
|
||||
var name string
|
||||
var weight int64
|
||||
err := pool.QueryRow("select name, weight from widgets where id=$1", 42).Scan(&name, &weight)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Base Type Mapping
|
||||
|
||||
pgx maps between all common base types directly between Go and PostgreSQL. In
|
||||
particular:
|
||||
|
||||
Go PostgreSQL
|
||||
-----------------------
|
||||
string varchar
|
||||
text
|
||||
|
||||
// Integers are automatically be converted to any other integer type if
|
||||
// it can be done without overflow or underflow.
|
||||
int8
|
||||
int16 smallint
|
||||
int32 int
|
||||
int64 bigint
|
||||
int
|
||||
uint8
|
||||
uint16
|
||||
uint32
|
||||
uint64
|
||||
uint
|
||||
|
||||
// Floats are strict and do not automatically convert like integers.
|
||||
float32 float4
|
||||
float64 float8
|
||||
|
||||
time.Time date
|
||||
timestamp
|
||||
timestamptz
|
||||
|
||||
[]byte bytea
|
||||
|
||||
|
||||
Null Mapping
|
||||
|
||||
pgx can map nulls in two ways. The first is package pgtype provides types that
|
||||
have a data field and a status field. They work in a similar fashion to
|
||||
database/sql. The second is to use a pointer to a pointer.
|
||||
|
||||
var foo pgtype.Varchar
|
||||
var bar *string
|
||||
err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&foo, &bar)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Array Mapping
|
||||
|
||||
pgx maps between int16, int32, int64, float32, float64, and string Go slices
|
||||
and the equivalent PostgreSQL array type. Go slices of native types do not
|
||||
support nulls, so if a PostgreSQL array that contains a null is read into a
|
||||
native Go slice an error will occur. The pgtype package includes many more
|
||||
array types for PostgreSQL types that do not directly map to native Go types.
|
||||
|
||||
JSON and JSONB Mapping
|
||||
|
||||
pgx includes built-in support to marshal and unmarshal between Go types and
|
||||
the PostgreSQL JSON and JSONB.
|
||||
|
||||
Inet and CIDR Mapping
|
||||
|
||||
pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In
|
||||
addition, as a convenience pgx will encode from a net.IP; it will assume a /32
|
||||
netmask for IPv4 and a /128 for IPv6.
|
||||
|
||||
Custom Type Support
|
||||
|
||||
pgx includes support for the common data types like integers, floats, strings,
|
||||
dates, and times that have direct mappings between Go and SQL. In addition,
|
||||
pgx uses the github.com/jackc/pgx/pgtype library to support more types. See
|
||||
documention for that library for instructions on how to implement custom
|
||||
types.
|
||||
|
||||
See example_custom_type_test.go for an example of a custom type for the
|
||||
PostgreSQL point type.
|
||||
|
||||
pgx also includes support for custom types implementing the database/sql.Scanner
|
||||
and database/sql/driver.Valuer interfaces.
|
||||
|
||||
If pgx does cannot natively encode a type and that type is a renamed type (e.g.
|
||||
type MyTime time.Time) pgx will attempt to encode the underlying type. While
|
||||
this is usually desired behavior it can produce suprising behavior if one the
|
||||
underlying type and the renamed type each implement database/sql interfaces and
|
||||
the other implements pgx interfaces. It is recommended that this situation be
|
||||
avoided by implementing pgx interfaces on the renamed type.
|
||||
|
||||
Raw Bytes Mapping
|
||||
|
||||
[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified
|
||||
to PostgreSQL.
|
||||
|
||||
Transactions
|
||||
|
||||
Transactions are started by calling Begin.
|
||||
Transactions are started by calling Begin or BeginEx. The BeginEx variant
|
||||
can create a transaction with a specified isolation level.
|
||||
|
||||
tx, err := conn.Begin(context.Background())
|
||||
tx, err := conn.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Rollback is safe to call even if the tx is already closed, so if
|
||||
// the tx commits successfully, this is a no-op
|
||||
defer tx.Rollback(context.Background())
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||
_, err = tx.Exec("insert into foo(id) values (1)")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Commit(context.Background())
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions.
|
||||
These are internally implemented with savepoints.
|
||||
|
||||
Use BeginTx to control the transaction mode. BeginTx also can be used to ensure a new transaction is created instead of
|
||||
a pseudo nested transaction.
|
||||
|
||||
BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the
|
||||
transaction depending on the return value of the function. These can be simpler and less error prone to use.
|
||||
|
||||
err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
|
||||
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Prepared Statements
|
||||
|
||||
Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx
|
||||
includes an automatic statement cache by default. Queries run through the normal Query, QueryRow, and Exec functions are
|
||||
automatically prepared on first execution and the prepared statement is reused on subsequent executions. See ParseConfig
|
||||
for information on how to customize or disable the statement cache.
|
||||
|
||||
Copy Protocol
|
||||
|
||||
Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a
|
||||
CopyFromSource interface. If the data is already in a [][]any use CopyFromRows to wrap it in a CopyFromSource interface.
|
||||
Or implement CopyFromSource to avoid buffering the entire data set in memory.
|
||||
Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL
|
||||
copy protocol. CopyFrom accepts a CopyFromSource interface. If the data is already
|
||||
in a [][]interface{} use CopyFromRows to wrap it in a CopyFromSource interface. Or
|
||||
implement CopyFromSource to avoid buffering the entire data set in memory.
|
||||
|
||||
rows := [][]any{
|
||||
rows := [][]interface{}{
|
||||
{"John", "Smith", int32(36)},
|
||||
{"Jane", "Doe", int32(29)},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(
|
||||
context.Background(),
|
||||
pgx.Identifier{"people"},
|
||||
[]string{"first_name", "last_name", "age"},
|
||||
pgx.CopyFromRows(rows),
|
||||
)
|
||||
|
||||
When you already have a typed array using CopyFromSlice can be more convenient.
|
||||
|
||||
rows := []User{
|
||||
{"John", "Smith", 36},
|
||||
{"Jane", "Doe", 29},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(
|
||||
context.Background(),
|
||||
pgx.Identifier{"people"},
|
||||
[]string{"first_name", "last_name", "age"},
|
||||
pgx.CopyFromSlice(len(rows), func(i int) ([]any, error) {
|
||||
return []any{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil
|
||||
}),
|
||||
)
|
||||
|
||||
CopyFrom can be faster than an insert with as few as 5 rows.
|
||||
|
||||
Listen and Notify
|
||||
|
||||
pgx can listen to the PostgreSQL notification system with the `Conn.WaitForNotification` method. It blocks until a
|
||||
notification is received or the context is canceled.
|
||||
pgx can listen to the PostgreSQL notification system with the
|
||||
WaitForNotification function. It takes a maximum time to wait for a
|
||||
notification.
|
||||
|
||||
_, err := conn.Exec(context.Background(), "listen channelname")
|
||||
err := conn.Listen("channelname")
|
||||
if err != nil {
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
notification, err := conn.WaitForNotification(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
if notification, err := conn.WaitForNotification(time.Second); err != nil {
|
||||
// do something with notification
|
||||
}
|
||||
// do something with notification
|
||||
|
||||
TLS
|
||||
|
||||
Tracing and Logging
|
||||
The pgx ConnConfig struct has a TLSConfig field. If this field is
|
||||
nil, then TLS will be disabled. If it is present, then it will be used to
|
||||
configure the TLS connection. This allows total configuration of the TLS
|
||||
connection.
|
||||
|
||||
pgx supports tracing by setting ConnConfig.Tracer. To combine several tracers you can use the multitracer.Tracer.
|
||||
pgx has never explicitly supported Postgres < 9.6's `ssl_renegotiation` option.
|
||||
As of v3.3.0, it doesn't send `ssl_renegotiation: 0` either to support Redshift
|
||||
(https://github.com/jackc/pgx/pull/476). If you need TLS Renegotiation,
|
||||
consider supplying `ConnConfig.TLSConfig` with a non-zero `Renegotiation`
|
||||
value and if it's not the default on your server, set `ssl_renegotiation`
|
||||
via `ConnConfig.RuntimeParams`.
|
||||
|
||||
In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer.
|
||||
Logging
|
||||
|
||||
For debug tracing of the actual PostgreSQL wire protocol messages see github.com/jackc/pgx/v5/pgproto3.
|
||||
|
||||
Lower Level PostgreSQL Functionality
|
||||
|
||||
github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn is
|
||||
implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer.
|
||||
|
||||
PgBouncer
|
||||
|
||||
By default pgx automatically uses prepared statements. Prepared statements are incompatible with PgBouncer. This can be
|
||||
disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode.
|
||||
pgx defines a simple logger interface. Connections optionally accept a logger
|
||||
that satisfies this interface. Set LogLevel to control logging verbosity.
|
||||
Adapters for github.com/inconshreveable/log15, github.com/sirupsen/logrus, and
|
||||
the testing log are provided in the log directory.
|
||||
*/
|
||||
package pgx
|
||||
|
105
example_custom_type_test.go
Normal file
105
example_custom_type_test.go
Normal file
@ -0,0 +1,105 @@
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`)
|
||||
|
||||
// Point represents a point that may be null.
|
||||
type Point struct {
|
||||
X, Y float64 // Coordinates of point
|
||||
Status pgtype.Status
|
||||
}
|
||||
|
||||
func (dst *Point) Set(src interface{}) error {
|
||||
return errors.Errorf("cannot convert %v to Point", src)
|
||||
}
|
||||
|
||||
func (dst *Point) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case pgtype.Present:
|
||||
return dst
|
||||
case pgtype.Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *Point) AssignTo(dst interface{}) error {
|
||||
return errors.Errorf("cannot assign %v to %T", src, dst)
|
||||
}
|
||||
|
||||
func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Point{Status: pgtype.Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
s := string(src)
|
||||
match := pointRegexp.FindStringSubmatch(s)
|
||||
if match == nil {
|
||||
return errors.Errorf("Received invalid point: %v", s)
|
||||
}
|
||||
|
||||
x, err := strconv.ParseFloat(match[1], 64)
|
||||
if err != nil {
|
||||
return errors.Errorf("Received invalid point: %v", s)
|
||||
}
|
||||
y, err := strconv.ParseFloat(match[2], 64)
|
||||
if err != nil {
|
||||
return errors.Errorf("Received invalid point: %v", s)
|
||||
}
|
||||
|
||||
*dst = Point{X: x, Y: y, Status: pgtype.Present}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Point) String() string {
|
||||
if src.Status == pgtype.Null {
|
||||
return "null point"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%.1f, %.1f", src.X, src.Y)
|
||||
}
|
||||
|
||||
func Example_CustomType() {
|
||||
conn, err := pgx.Connect(*defaultConnConfig)
|
||||
if err != nil {
|
||||
fmt.Printf("Unable to establish connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Override registered handler for point
|
||||
conn.ConnInfo.RegisterDataType(pgtype.DataType{
|
||||
Value: &Point{},
|
||||
Name: "point",
|
||||
OID: 600,
|
||||
})
|
||||
|
||||
p := &Point{}
|
||||
err = conn.QueryRow("select null::point").Scan(p)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(p)
|
||||
|
||||
err = conn.QueryRow("select point(1.5,2.5)").Scan(p)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(p)
|
||||
// Output:
|
||||
// null point
|
||||
// 1.5, 2.5
|
||||
}
|
@ -1,15 +1,13 @@
|
||||
package pgtype_test
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
func Example_json() {
|
||||
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
func Example_JSON() {
|
||||
conn, err := pgx.Connect(*defaultConnConfig)
|
||||
if err != nil {
|
||||
fmt.Printf("Unable to establish connection: %v", err)
|
||||
return
|
||||
@ -27,7 +25,7 @@ func Example_json() {
|
||||
|
||||
var output person
|
||||
|
||||
err = conn.QueryRow(context.Background(), "select $1::json", input).Scan(&output)
|
||||
err = conn.QueryRow("select $1::json", input).Scan(&output)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
@ -3,5 +3,5 @@
|
||||
* chat is a command line chat program using listen/notify.
|
||||
* todo is a command line todo list that demonstrates basic CRUD actions.
|
||||
* url_shortener contains a simple example of using pgx in a web context.
|
||||
* [Tern](https://github.com/jackc/tern) is a migration tool that uses pgx.
|
||||
* [Tern](https://github.com/jackc/tern) is a migration tool that uses pgx (uses v1 of pgx).
|
||||
* [The Pithy Reader](https://github.com/jackc/tpr) is a RSS aggregator that uses pgx.
|
||||
|
@ -8,7 +8,12 @@ between them.
|
||||
|
||||
## Connection configuration
|
||||
|
||||
The database connection is configured via DATABASE_URL and standard PostgreSQL environment variables (PGHOST, PGUSER, etc.)
|
||||
The database connection is configured via the standard PostgreSQL environment variables.
|
||||
|
||||
* PGHOST - defaults to localhost
|
||||
* PGUSER - defaults to current OS user
|
||||
* PGPASSWORD - defaults to empty string
|
||||
* PGDATABASE - defaults to user name
|
||||
|
||||
You can either export them then run chat:
|
||||
|
||||
|
@ -6,14 +6,19 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
var pool *pgxpool.Pool
|
||||
var pool *pgx.ConnPool
|
||||
|
||||
func main() {
|
||||
var err error
|
||||
pool, err = pgxpool.New(context.Background(), os.Getenv("DATABASE_URL"))
|
||||
config, err := pgx.ParseEnvLibpq()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Unable to parse environment:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
pool, err = pgx.NewConnPool(pgx.ConnPoolConfig{ConnConfig: config})
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Unable to connect to database:", err)
|
||||
os.Exit(1)
|
||||
@ -35,7 +40,7 @@ Type "exit" to quit.`)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
_, err = pool.Exec(context.Background(), "select pg_notify('chat', $1)", msg)
|
||||
_, err = pool.Exec("select pg_notify('chat', $1)", msg)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error sending notification:", err)
|
||||
os.Exit(1)
|
||||
@ -48,21 +53,17 @@ Type "exit" to quit.`)
|
||||
}
|
||||
|
||||
func listen() {
|
||||
conn, err := pool.Acquire(context.Background())
|
||||
conn, err := pool.Acquire()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error acquiring connection:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer conn.Release()
|
||||
defer pool.Release(conn)
|
||||
|
||||
_, err = conn.Exec(context.Background(), "listen chat")
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error listening to chat channel:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
conn.Listen("chat")
|
||||
|
||||
for {
|
||||
notification, err := conn.Conn().WaitForNotification(context.Background())
|
||||
notification, err := conn.WaitForNotification(context.Background())
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error waiting for notification:", err)
|
||||
os.Exit(1)
|
||||
|
@ -19,7 +19,12 @@ Build todo:
|
||||
|
||||
## Connection configuration
|
||||
|
||||
The database connection is configured via DATABASE_URL and standard PostgreSQL environment variables (PGHOST, PGUSER, etc.)
|
||||
The database connection is configured via enviroment variables.
|
||||
|
||||
* PGHOST - defaults to localhost
|
||||
* PGUSER - defaults to current OS user
|
||||
* PGPASSWORD - defaults to empty string
|
||||
* PGDATABASE - defaults to user name
|
||||
|
||||
You can either export them then run todo:
|
||||
|
||||
@ -40,7 +45,7 @@ Or you can prefix the todo execution with the environment variables:
|
||||
|
||||
## Update a task
|
||||
|
||||
./todo update 1 'Learn more go'
|
||||
./todo add 1 'Learn more go'
|
||||
|
||||
## Delete a task
|
||||
|
||||
|
@ -1,19 +1,22 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/jackc/pgx"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
var conn *pgx.Conn
|
||||
|
||||
func main() {
|
||||
var err error
|
||||
conn, err = pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
|
||||
config, err := pgx.ParseEnvLibpq()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Unable to parse environment:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
conn, err = pgx.Connect(config)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Unable to connection to database: %v\n", err)
|
||||
os.Exit(1)
|
||||
@ -71,7 +74,7 @@ func main() {
|
||||
}
|
||||
|
||||
func listTasks() error {
|
||||
rows, _ := conn.Query(context.Background(), "select * from tasks")
|
||||
rows, _ := conn.Query("select * from tasks")
|
||||
|
||||
for rows.Next() {
|
||||
var id int32
|
||||
@ -87,17 +90,17 @@ func listTasks() error {
|
||||
}
|
||||
|
||||
func addTask(description string) error {
|
||||
_, err := conn.Exec(context.Background(), "insert into tasks(description) values($1)", description)
|
||||
_, err := conn.Exec("insert into tasks(description) values($1)", description)
|
||||
return err
|
||||
}
|
||||
|
||||
func updateTask(itemNum int32, description string) error {
|
||||
_, err := conn.Exec(context.Background(), "update tasks set description=$1 where id=$2", description, itemNum)
|
||||
_, err := conn.Exec("update tasks set description=$1 where id=$2", description, itemNum)
|
||||
return err
|
||||
}
|
||||
|
||||
func removeTask(itemNum int32) error {
|
||||
_, err := conn.Exec(context.Background(), "delete from tasks where id=$1", itemNum)
|
||||
_, err := conn.Exec("delete from tasks where id=$1", itemNum)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -6,28 +6,20 @@ This is a sample REST URL shortener service implemented using pgx as the connect
|
||||
|
||||
Create a PostgreSQL database and run structure.sql into it to create the necessary data schema.
|
||||
|
||||
Configure the database connection with `DATABASE_URL` or standard PostgreSQL (`PG*`) environment variables or
|
||||
Edit connectionOptions in main.go with the location and credentials for your database.
|
||||
|
||||
Run main.go:
|
||||
|
||||
```
|
||||
go run main.go
|
||||
```
|
||||
go run main.go
|
||||
|
||||
## Create or Update a Shortened URL
|
||||
|
||||
```
|
||||
curl -X PUT -d 'http://www.google.com' http://localhost:8080/google
|
||||
```
|
||||
curl -X PUT -d 'http://www.google.com' http://localhost:8080/google
|
||||
|
||||
## Get a Shortened URL
|
||||
|
||||
```
|
||||
curl http://localhost:8080/google
|
||||
```
|
||||
curl http://localhost:8080/google
|
||||
|
||||
## Delete a Shortened URL
|
||||
|
||||
```
|
||||
curl -X DELETE http://localhost:8080/google
|
||||
```
|
||||
curl -X DELETE http://localhost:8080/google
|
||||
|
@ -1,21 +1,43 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/log/log15adapter"
|
||||
log "gopkg.in/inconshreveable/log15.v2"
|
||||
)
|
||||
|
||||
var db *pgxpool.Pool
|
||||
var pool *pgx.ConnPool
|
||||
|
||||
// afterConnect creates the prepared statements that this application uses
|
||||
func afterConnect(conn *pgx.Conn) (err error) {
|
||||
_, err = conn.Prepare("getUrl", `
|
||||
select url from shortened_urls where id=$1
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.Prepare("deleteUrl", `
|
||||
delete from shortened_urls where id=$1
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.Prepare("putUrl", `
|
||||
insert into shortened_urls(id, url) values ($1, $2)
|
||||
on conflict (id) do update set url=excluded.url
|
||||
`)
|
||||
return
|
||||
}
|
||||
|
||||
func getUrlHandler(w http.ResponseWriter, req *http.Request) {
|
||||
var url string
|
||||
err := db.QueryRow(context.Background(), "select url from shortened_urls where id=$1", req.URL.Path).Scan(&url)
|
||||
err := pool.QueryRow("getUrl", req.URL.Path).Scan(&url)
|
||||
switch err {
|
||||
case nil:
|
||||
http.Redirect(w, req, url, http.StatusSeeOther)
|
||||
@ -29,15 +51,14 @@ func getUrlHandler(w http.ResponseWriter, req *http.Request) {
|
||||
func putUrlHandler(w http.ResponseWriter, req *http.Request) {
|
||||
id := req.URL.Path
|
||||
var url string
|
||||
if body, err := io.ReadAll(req.Body); err == nil {
|
||||
if body, err := ioutil.ReadAll(req.Body); err == nil {
|
||||
url = string(body)
|
||||
} else {
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := db.Exec(context.Background(), `insert into shortened_urls(id, url) values ($1, $2)
|
||||
on conflict (id) do update set url=excluded.url`, id, url); err == nil {
|
||||
if _, err := pool.Exec("putUrl", id, url); err == nil {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
@ -45,7 +66,7 @@ func putUrlHandler(w http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
func deleteUrlHandler(w http.ResponseWriter, req *http.Request) {
|
||||
if _, err := db.Exec(context.Background(), "delete from shortened_urls where id=$1", req.URL.Path); err == nil {
|
||||
if _, err := pool.Exec("deleteUrl", req.URL.Path); err == nil {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
@ -70,21 +91,32 @@ func urlHandler(w http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
func main() {
|
||||
poolConfig, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
log.Fatalln("Unable to parse DATABASE_URL:", err)
|
||||
}
|
||||
logger := log15adapter.NewLogger(log.New("module", "pgx"))
|
||||
|
||||
db, err = pgxpool.NewWithConfig(context.Background(), poolConfig)
|
||||
var err error
|
||||
connPoolConfig := pgx.ConnPoolConfig{
|
||||
ConnConfig: pgx.ConnConfig{
|
||||
Host: "127.0.0.1",
|
||||
User: "jack",
|
||||
Password: "jack",
|
||||
Database: "url_shortener",
|
||||
Logger: logger,
|
||||
},
|
||||
MaxConnections: 5,
|
||||
AfterConnect: afterConnect,
|
||||
}
|
||||
pool, err = pgx.NewConnPool(connPoolConfig)
|
||||
if err != nil {
|
||||
log.Fatalln("Unable to create connection pool:", err)
|
||||
log.Crit("Unable to create connection pool", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
http.HandleFunc("/", urlHandler)
|
||||
|
||||
log.Println("Starting URL shortener on localhost:8080")
|
||||
log.Info("Starting URL shortener on localhost:8080")
|
||||
err = http.ListenAndServe("localhost:8080", nil)
|
||||
if err != nil {
|
||||
log.Fatalln("Unable to start web server:", err)
|
||||
log.Crit("Unable to start web server", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
@ -1,146 +0,0 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
// ExtendedQueryBuilder is used to choose the parameter formats, to format the parameters and to choose the result
|
||||
// formats for an extended query.
|
||||
type ExtendedQueryBuilder struct {
|
||||
ParamValues [][]byte
|
||||
paramValueBytes []byte
|
||||
ParamFormats []int16
|
||||
ResultFormats []int16
|
||||
}
|
||||
|
||||
// Build sets ParamValues, ParamFormats, and ResultFormats for use with *PgConn.ExecParams or *PgConn.ExecPrepared. If
|
||||
// sd is nil then QueryExecModeExec behavior will be used.
|
||||
func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
|
||||
eqb.reset()
|
||||
|
||||
if sd == nil {
|
||||
for i := range args {
|
||||
err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(sd.ParamOIDs) != len(args) {
|
||||
return fmt.Errorf("mismatched param and argument count")
|
||||
}
|
||||
|
||||
for i := range args {
|
||||
err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for i := range sd.Fields {
|
||||
eqb.appendResultFormat(m.FormatCodeForOID(sd.Fields[i].DataTypeOID))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// appendParam appends a parameter to the query. format may be -1 to automatically choose the format. If arg is nil it
|
||||
// must be an untyped nil.
|
||||
func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error {
|
||||
if format == -1 {
|
||||
preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg)
|
||||
preferredErr := eqb.appendParam(m, oid, preferredFormat, arg)
|
||||
if preferredErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var otherFormat int16
|
||||
if preferredFormat == TextFormatCode {
|
||||
otherFormat = BinaryFormatCode
|
||||
} else {
|
||||
otherFormat = TextFormatCode
|
||||
}
|
||||
|
||||
otherErr := eqb.appendParam(m, oid, otherFormat, arg)
|
||||
if otherErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return preferredErr // return the error from the preferred format
|
||||
}
|
||||
|
||||
v, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eqb.ParamFormats = append(eqb.ParamFormats, format)
|
||||
eqb.ParamValues = append(eqb.ParamValues, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// appendResultFormat appends a result format to the query.
|
||||
func (eqb *ExtendedQueryBuilder) appendResultFormat(format int16) {
|
||||
eqb.ResultFormats = append(eqb.ResultFormats, format)
|
||||
}
|
||||
|
||||
// reset readies eqb to build another query.
|
||||
func (eqb *ExtendedQueryBuilder) reset() {
|
||||
eqb.ParamValues = eqb.ParamValues[0:0]
|
||||
eqb.paramValueBytes = eqb.paramValueBytes[0:0]
|
||||
eqb.ParamFormats = eqb.ParamFormats[0:0]
|
||||
eqb.ResultFormats = eqb.ResultFormats[0:0]
|
||||
|
||||
if cap(eqb.ParamValues) > 64 {
|
||||
eqb.ParamValues = make([][]byte, 0, 64)
|
||||
}
|
||||
|
||||
if cap(eqb.paramValueBytes) > 256 {
|
||||
eqb.paramValueBytes = make([]byte, 0, 256)
|
||||
}
|
||||
|
||||
if cap(eqb.ParamFormats) > 64 {
|
||||
eqb.ParamFormats = make([]int16, 0, 64)
|
||||
}
|
||||
if cap(eqb.ResultFormats) > 64 {
|
||||
eqb.ResultFormats = make([]int16, 0, 64)
|
||||
}
|
||||
}
|
||||
|
||||
func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
|
||||
if eqb.paramValueBytes == nil {
|
||||
eqb.paramValueBytes = make([]byte, 0, 128)
|
||||
}
|
||||
|
||||
pos := len(eqb.paramValueBytes)
|
||||
|
||||
buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if buf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
eqb.paramValueBytes = buf
|
||||
return eqb.paramValueBytes[pos:], nil
|
||||
}
|
||||
|
||||
// chooseParameterFormatCode determines the correct format code for an
|
||||
// argument to a prepared statement. It defaults to TextFormatCode if no
|
||||
// determination can be made.
|
||||
func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 {
|
||||
switch arg.(type) {
|
||||
case string, *string:
|
||||
return TextFormatCode
|
||||
}
|
||||
|
||||
return m.FormatCodeForOID(oid)
|
||||
}
|
119
fastpath.go
Normal file
119
fastpath.go
Normal file
@ -0,0 +1,119 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
func newFastpath(cn *Conn) *fastpath {
|
||||
return &fastpath{cn: cn, fns: make(map[string]pgtype.OID)}
|
||||
}
|
||||
|
||||
type fastpath struct {
|
||||
cn *Conn
|
||||
fns map[string]pgtype.OID
|
||||
}
|
||||
|
||||
func (f *fastpath) functionOID(name string) pgtype.OID {
|
||||
return f.fns[name]
|
||||
}
|
||||
|
||||
func (f *fastpath) addFunction(name string, oid pgtype.OID) {
|
||||
f.fns[name] = oid
|
||||
}
|
||||
|
||||
func (f *fastpath) addFunctions(rows *Rows) error {
|
||||
for rows.Next() {
|
||||
var name string
|
||||
var oid pgtype.OID
|
||||
if err := rows.Scan(&name, &oid); err != nil {
|
||||
return err
|
||||
}
|
||||
f.addFunction(name, oid)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
type fpArg []byte
|
||||
|
||||
func fpIntArg(n int32) fpArg {
|
||||
res := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(res, uint32(n))
|
||||
return res
|
||||
}
|
||||
|
||||
func fpInt64Arg(n int64) fpArg {
|
||||
res := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(res, uint64(n))
|
||||
return res
|
||||
}
|
||||
|
||||
func (f *fastpath) Call(oid pgtype.OID, args []fpArg) (res []byte, err error) {
|
||||
if err := f.cn.ensureConnectionReadyForQuery(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf := f.cn.wbuf
|
||||
buf = append(buf, 'F') // function call
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
buf = pgio.AppendInt32(buf, int32(oid)) // function object id
|
||||
buf = pgio.AppendInt16(buf, 1) // # of argument format codes
|
||||
buf = pgio.AppendInt16(buf, 1) // format code: binary
|
||||
buf = pgio.AppendInt16(buf, int16(len(args))) // # of arguments
|
||||
for _, arg := range args {
|
||||
buf = pgio.AppendInt32(buf, int32(len(arg))) // length of argument
|
||||
buf = append(buf, arg...) // argument value
|
||||
}
|
||||
buf = pgio.AppendInt16(buf, 1) // response format code (binary)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
if _, err := f.cn.conn.Write(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.cn.pendingReadyForQueryCount++
|
||||
|
||||
for {
|
||||
msg, err := f.cn.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.FunctionCallResponse:
|
||||
res = make([]byte, len(msg.Result))
|
||||
copy(res, msg.Result)
|
||||
case *pgproto3.ReadyForQuery:
|
||||
f.cn.rxReadyForQuery(msg)
|
||||
// done
|
||||
return res, err
|
||||
default:
|
||||
if err := f.cn.processContextFreeMsg(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fastpath) CallFn(fn string, args []fpArg) ([]byte, error) {
|
||||
return f.Call(f.functionOID(fn), args)
|
||||
}
|
||||
|
||||
func fpInt32(data []byte, err error) (int32, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := int32(binary.BigEndian.Uint32(data))
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func fpInt64(data []byte, err error) (int64, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(binary.BigEndian.Uint64(data)), nil
|
||||
}
|
21
go.mod
21
go.mod
@ -1,21 +0,0 @@
|
||||
module github.com/jackc/pgx/v5
|
||||
|
||||
go 1.23.0
|
||||
|
||||
require (
|
||||
github.com/jackc/pgpassfile v1.0.0
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761
|
||||
github.com/jackc/puddle/v2 v2.2.2
|
||||
github.com/stretchr/testify v1.8.1
|
||||
golang.org/x/crypto v0.37.0
|
||||
golang.org/x/sync v0.13.0
|
||||
golang.org/x/text v0.24.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/kr/pretty v0.3.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
45
go.sum
45
go.sum
@ -1,45 +0,0 @@
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k=
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
|
||||
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
|
||||
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
61
go_stdlib.go
Normal file
61
go_stdlib.go
Normal file
@ -0,0 +1,61 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// This file contains code copied from the Go standard library due to the
|
||||
// required function not being public.
|
||||
|
||||
// Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
// From database/sql/convert.go
|
||||
|
||||
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
|
||||
|
||||
// callValuerValue returns vr.Value(), with one exception:
|
||||
// If vr.Value is an auto-generated method on a pointer type and the
|
||||
// pointer is nil, it would panic at runtime in the panicwrap
|
||||
// method. Treat it like nil instead.
|
||||
// Issue 8415.
|
||||
//
|
||||
// This is so people can implement driver.Value on value types and
|
||||
// still use nil pointers to those types to mean nil/NULL, just like
|
||||
// string/*string.
|
||||
//
|
||||
// This function is mirrored in the database/sql/driver package.
|
||||
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
|
||||
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
|
||||
rv.IsNil() &&
|
||||
rv.Type().Elem().Implements(valuerReflectType) {
|
||||
return nil, nil
|
||||
}
|
||||
return vr.Value()
|
||||
}
|
100
helper_test.go
100
helper_test.go
@ -1,45 +1,21 @@
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxtest"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
var defaultConnTestRunner pgxtest.ConnTestRunner
|
||||
|
||||
func init() {
|
||||
defaultConnTestRunner = pgxtest.DefaultConnTestRunner()
|
||||
defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
|
||||
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
return config
|
||||
}
|
||||
}
|
||||
|
||||
func mustConnectString(t testing.TB, connString string) *pgx.Conn {
|
||||
conn, err := pgx.Connect(context.Background(), connString)
|
||||
func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn {
|
||||
conn, err := pgx.Connect(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to establish connection: %v", err)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
func mustParseConfig(t testing.TB, connString string) *pgx.ConnConfig {
|
||||
config, err := pgx.ParseConfig(connString)
|
||||
require.Nil(t, err)
|
||||
return config
|
||||
}
|
||||
|
||||
func mustConnect(t testing.TB, config *pgx.ConnConfig) *pgx.Conn {
|
||||
conn, err := pgx.ConnectConfig(context.Background(), config)
|
||||
func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.ReplicationConn {
|
||||
conn, err := pgx.ReplicationConnect(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to establish connection: %v", err)
|
||||
}
|
||||
@ -47,25 +23,32 @@ func mustConnect(t testing.TB, config *pgx.ConnConfig) *pgx.Conn {
|
||||
}
|
||||
|
||||
func closeConn(t testing.TB, conn *pgx.Conn) {
|
||||
err := conn.Close(context.Background())
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Close unexpectedly failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...any) (commandTag pgconn.CommandTag) {
|
||||
func closeReplicationConn(t testing.TB, conn *pgx.ReplicationConn) {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Close unexpectedly failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgx.CommandTag) {
|
||||
var err error
|
||||
if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil {
|
||||
if commandTag, err = conn.Exec(sql, arguments...); err != nil {
|
||||
t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Do a simple query to ensure the connection is still usable
|
||||
func ensureConnValid(t testing.TB, conn *pgx.Conn) {
|
||||
func ensureConnValid(t *testing.T, conn *pgx.Conn) {
|
||||
var sum, rowCount int32
|
||||
|
||||
rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
|
||||
rows, err := conn.Query("select generate_series(1,$1)", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: %v", err)
|
||||
}
|
||||
@ -79,7 +62,7 @@ func ensureConnValid(t testing.TB, conn *pgx.Conn) {
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Fatalf("conn.Query failed: %v", rows.Err())
|
||||
t.Fatalf("conn.Query failed: %v", err)
|
||||
}
|
||||
|
||||
if rowCount != 10 {
|
||||
@ -89,50 +72,3 @@ func ensureConnValid(t testing.TB, conn *pgx.Conn) {
|
||||
t.Error("Wrong values returned")
|
||||
}
|
||||
}
|
||||
|
||||
func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) {
|
||||
if !assert.NotNil(t, expected) {
|
||||
return
|
||||
}
|
||||
if !assert.NotNil(t, actual) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName)
|
||||
assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
|
||||
assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName)
|
||||
assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName)
|
||||
assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName)
|
||||
assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
|
||||
assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
|
||||
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
|
||||
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
|
||||
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
|
||||
assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
|
||||
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
||||
|
||||
// Can't test function equality, so just test that they are set or not.
|
||||
assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
|
||||
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
|
||||
|
||||
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
|
||||
if expected.TLSConfig != nil {
|
||||
assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
|
||||
assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
|
||||
}
|
||||
}
|
||||
|
||||
if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
|
||||
for i := range expected.Fallbacks {
|
||||
assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
|
||||
assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
|
||||
|
||||
if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
|
||||
if expected.Fallbacks[i].TLSConfig != nil {
|
||||
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
|
||||
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,70 +0,0 @@
|
||||
// Package iobufpool implements a global segregated-fit pool of buffers for IO.
|
||||
//
|
||||
// It uses *[]byte instead of []byte to avoid the sync.Pool allocation with Put. Unfortunately, using a pointer to avoid
|
||||
// an allocation is purposely not documented. https://github.com/golang/go/issues/16323
|
||||
package iobufpool
|
||||
|
||||
import "sync"
|
||||
|
||||
const minPoolExpOf2 = 8
|
||||
|
||||
var pools [18]*sync.Pool
|
||||
|
||||
func init() {
|
||||
for i := range pools {
|
||||
bufLen := 1 << (minPoolExpOf2 + i)
|
||||
pools[i] = &sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, bufLen)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get gets a []byte of len size with cap <= size*2.
|
||||
func Get(size int) *[]byte {
|
||||
i := getPoolIdx(size)
|
||||
if i >= len(pools) {
|
||||
buf := make([]byte, size)
|
||||
return &buf
|
||||
}
|
||||
|
||||
ptrBuf := (pools[i].Get().(*[]byte))
|
||||
*ptrBuf = (*ptrBuf)[:size]
|
||||
|
||||
return ptrBuf
|
||||
}
|
||||
|
||||
func getPoolIdx(size int) int {
|
||||
size--
|
||||
size >>= minPoolExpOf2
|
||||
i := 0
|
||||
for size > 0 {
|
||||
size >>= 1
|
||||
i++
|
||||
}
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
// Put returns buf to the pool.
|
||||
func Put(buf *[]byte) {
|
||||
i := putPoolIdx(cap(*buf))
|
||||
if i < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
pools[i].Put(buf)
|
||||
}
|
||||
|
||||
func putPoolIdx(size int) int {
|
||||
minPoolSize := 1 << minPoolExpOf2
|
||||
for i := range pools {
|
||||
if size == minPoolSize<<i {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
@ -1,36 +0,0 @@
|
||||
package iobufpool
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPoolIdx(t *testing.T) {
|
||||
tests := []struct {
|
||||
size int
|
||||
expected int
|
||||
}{
|
||||
{size: 0, expected: 0},
|
||||
{size: 1, expected: 0},
|
||||
{size: 255, expected: 0},
|
||||
{size: 256, expected: 0},
|
||||
{size: 257, expected: 1},
|
||||
{size: 511, expected: 1},
|
||||
{size: 512, expected: 1},
|
||||
{size: 513, expected: 2},
|
||||
{size: 1023, expected: 2},
|
||||
{size: 1024, expected: 2},
|
||||
{size: 1025, expected: 3},
|
||||
{size: 2047, expected: 3},
|
||||
{size: 2048, expected: 3},
|
||||
{size: 2049, expected: 4},
|
||||
{size: 8388607, expected: 15},
|
||||
{size: 8388608, expected: 15},
|
||||
{size: 8388609, expected: 16},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
idx := getPoolIdx(tt.size)
|
||||
assert.Equalf(t, tt.expected, idx, "size: %d", tt.size)
|
||||
}
|
||||
}
|
@ -1,85 +0,0 @@
|
||||
package iobufpool_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetCap(t *testing.T) {
|
||||
tests := []struct {
|
||||
requestedLen int
|
||||
expectedCap int
|
||||
}{
|
||||
{requestedLen: 0, expectedCap: 256},
|
||||
{requestedLen: 128, expectedCap: 256},
|
||||
{requestedLen: 255, expectedCap: 256},
|
||||
{requestedLen: 256, expectedCap: 256},
|
||||
{requestedLen: 257, expectedCap: 512},
|
||||
{requestedLen: 511, expectedCap: 512},
|
||||
{requestedLen: 512, expectedCap: 512},
|
||||
{requestedLen: 513, expectedCap: 1024},
|
||||
{requestedLen: 1023, expectedCap: 1024},
|
||||
{requestedLen: 1024, expectedCap: 1024},
|
||||
{requestedLen: 33554431, expectedCap: 33554432},
|
||||
{requestedLen: 33554432, expectedCap: 33554432},
|
||||
|
||||
// Above 32 MiB skip the pool and allocate exactly the requested size.
|
||||
{requestedLen: 33554433, expectedCap: 33554433},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
buf := iobufpool.Get(tt.requestedLen)
|
||||
assert.Equalf(t, tt.requestedLen, len(*buf), "bad len for requestedLen: %d", len(*buf), tt.requestedLen)
|
||||
assert.Equalf(t, tt.expectedCap, cap(*buf), "bad cap for requestedLen: %d", tt.requestedLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutHandlesWrongSizedBuffers(t *testing.T) {
|
||||
for putBufSize := range []int{0, 1, 128, 250, 256, 257, 1023, 1024, 1025, 1 << 28} {
|
||||
putBuf := make([]byte, putBufSize)
|
||||
iobufpool.Put(&putBuf)
|
||||
|
||||
tests := []struct {
|
||||
requestedLen int
|
||||
expectedCap int
|
||||
}{
|
||||
{requestedLen: 0, expectedCap: 256},
|
||||
{requestedLen: 128, expectedCap: 256},
|
||||
{requestedLen: 255, expectedCap: 256},
|
||||
{requestedLen: 256, expectedCap: 256},
|
||||
{requestedLen: 257, expectedCap: 512},
|
||||
{requestedLen: 511, expectedCap: 512},
|
||||
{requestedLen: 512, expectedCap: 512},
|
||||
{requestedLen: 513, expectedCap: 1024},
|
||||
{requestedLen: 1023, expectedCap: 1024},
|
||||
{requestedLen: 1024, expectedCap: 1024},
|
||||
{requestedLen: 33554431, expectedCap: 33554432},
|
||||
{requestedLen: 33554432, expectedCap: 33554432},
|
||||
|
||||
// Above 32 MiB skip the pool and allocate exactly the requested size.
|
||||
{requestedLen: 33554433, expectedCap: 33554433},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
getBuf := iobufpool.Get(tt.requestedLen)
|
||||
assert.Equalf(t, tt.requestedLen, len(*getBuf), "len(putBuf): %d, requestedLen: %d", len(putBuf), tt.requestedLen)
|
||||
assert.Equalf(t, tt.expectedCap, cap(*getBuf), "cap(putBuf): %d, requestedLen: %d", cap(putBuf), tt.requestedLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutGetBufferReuse(t *testing.T) {
|
||||
// There is no way to guarantee a buffer will be reused. It should be, but a GC between the Put and the Get will cause
|
||||
// it not to be. So try many times.
|
||||
for i := 0; i < 100000; i++ {
|
||||
buf := iobufpool.Get(4)
|
||||
(*buf)[0] = 1
|
||||
iobufpool.Put(buf)
|
||||
buf = iobufpool.Get(4)
|
||||
if (*buf)[0] == 1 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.Error("buffer was never reused")
|
||||
}
|
@ -1,6 +0,0 @@
|
||||
# pgio
|
||||
|
||||
Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
|
||||
|
||||
pgio provides functions for appending integers to a []byte while doing byte
|
||||
order conversion.
|
@ -1,136 +0,0 @@
|
||||
// Package pgmock provides the ability to mock a PostgreSQL server.
|
||||
package pgmock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
)
|
||||
|
||||
type Step interface {
|
||||
Step(*pgproto3.Backend) error
|
||||
}
|
||||
|
||||
type Script struct {
|
||||
Steps []Step
|
||||
}
|
||||
|
||||
func (s *Script) Run(backend *pgproto3.Backend) error {
|
||||
for _, step := range s.Steps {
|
||||
err := step.Step(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Script) Step(backend *pgproto3.Backend) error {
|
||||
return s.Run(backend)
|
||||
}
|
||||
|
||||
type expectMessageStep struct {
|
||||
want pgproto3.FrontendMessage
|
||||
any bool
|
||||
}
|
||||
|
||||
func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
msg, err := backend.Receive()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(msg, e.want) {
|
||||
return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type expectStartupMessageStep struct {
|
||||
want *pgproto3.StartupMessage
|
||||
any bool
|
||||
}
|
||||
|
||||
func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
msg, err := backend.ReceiveStartupMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.any {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(msg, e.want) {
|
||||
return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExpectMessage(want pgproto3.FrontendMessage) Step {
|
||||
return expectMessage(want, false)
|
||||
}
|
||||
|
||||
func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
|
||||
return expectMessage(want, true)
|
||||
}
|
||||
|
||||
func expectMessage(want pgproto3.FrontendMessage, any bool) Step {
|
||||
if want, ok := want.(*pgproto3.StartupMessage); ok {
|
||||
return &expectStartupMessageStep{want: want, any: any}
|
||||
}
|
||||
|
||||
return &expectMessageStep{want: want, any: any}
|
||||
}
|
||||
|
||||
type sendMessageStep struct {
|
||||
msg pgproto3.BackendMessage
|
||||
}
|
||||
|
||||
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
backend.Send(e.msg)
|
||||
return backend.Flush()
|
||||
}
|
||||
|
||||
func SendMessage(msg pgproto3.BackendMessage) Step {
|
||||
return &sendMessageStep{msg: msg}
|
||||
}
|
||||
|
||||
type waitForCloseMessageStep struct{}
|
||||
|
||||
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
for {
|
||||
msg, err := backend.Receive()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := msg.(*pgproto3.Terminate); ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WaitForClose() Step {
|
||||
return &waitForCloseMessageStep{}
|
||||
}
|
||||
|
||||
func AcceptUnauthenticatedConnRequestSteps() []Step {
|
||||
return []Step{
|
||||
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
||||
SendMessage(&pgproto3.AuthenticationOk{}),
|
||||
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
||||
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
||||
}
|
||||
}
|
@ -1,91 +0,0 @@
|
||||
package pgmock_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgmock"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestScript(t *testing.T) {
|
||||
script := &pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "select 42"}))
|
||||
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.RowDescription{
|
||||
Fields: []pgproto3.FieldDescription{
|
||||
{
|
||||
Name: []byte("?column?"),
|
||||
TableOID: 0,
|
||||
TableAttributeNumber: 0,
|
||||
DataTypeOID: 23,
|
||||
DataTypeSize: 4,
|
||||
TypeModifier: -1,
|
||||
Format: 0,
|
||||
},
|
||||
},
|
||||
}))
|
||||
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.DataRow{
|
||||
Values: [][]byte{[]byte("42")},
|
||||
}))
|
||||
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}))
|
||||
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
|
||||
script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Terminate{}))
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(serverErrChan)
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.SetDeadline(time.Now().Add(time.Second))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
err = script.Run(pgproto3.NewBackend(conn, conn))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
host, port, _ := strings.Cut(ln.Addr().String(), ":")
|
||||
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
pgConn, err := pgconn.Connect(ctx, connStr)
|
||||
require.NoError(t, err)
|
||||
results, err := pgConn.Exec(ctx, "select 42").ReadAll()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, results, 1)
|
||||
assert.Nil(t, results[0].Err)
|
||||
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
||||
assert.Len(t, results[0].Rows, 1)
|
||||
assert.Equal(t, "42", string(results[0].Rows[0][0]))
|
||||
|
||||
pgConn.Close(ctx)
|
||||
|
||||
assert.NoError(t, <-serverErrChan)
|
||||
}
|
@ -1,60 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
current_branch=$(git rev-parse --abbrev-ref HEAD)
|
||||
if [ "$current_branch" == "HEAD" ]; then
|
||||
current_branch=$(git rev-parse HEAD)
|
||||
fi
|
||||
|
||||
restore_branch() {
|
||||
echo "Restoring original branch/commit: $current_branch"
|
||||
git checkout "$current_branch"
|
||||
}
|
||||
trap restore_branch EXIT
|
||||
|
||||
# Check if there are uncommitted changes
|
||||
if ! git diff --quiet || ! git diff --cached --quiet; then
|
||||
echo "There are uncommitted changes. Please commit or stash them before running this script."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Ensure that at least one commit argument is passed
|
||||
if [ "$#" -lt 1 ]; then
|
||||
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
commits=("$@")
|
||||
benchmarks_dir=benchmarks
|
||||
|
||||
if ! mkdir -p "${benchmarks_dir}"; then
|
||||
echo "Unable to create dir for benchmarks data"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Benchmark results
|
||||
bench_files=()
|
||||
|
||||
# Run benchmark for each listed commit
|
||||
for i in "${!commits[@]}"; do
|
||||
commit="${commits[i]}"
|
||||
git checkout "$commit" || {
|
||||
echo "Failed to checkout $commit"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Sanitized commmit message
|
||||
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
|
||||
|
||||
# Benchmark data will go there
|
||||
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
|
||||
|
||||
if ! go test -bench=. -count=10 >"$bench_file"; then
|
||||
echo "Benchmarking failed for commit $commit"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
bench_files+=("$bench_file")
|
||||
done
|
||||
|
||||
# go install golang.org/x/perf/cmd/benchstat[@latest]
|
||||
benchstat "${bench_files[@]}"
|
@ -3,209 +3,97 @@ package sanitize
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Part is either a string or an int. A string is raw SQL. An int is a
|
||||
// argument placeholder.
|
||||
type Part any
|
||||
type Part interface{}
|
||||
|
||||
type Query struct {
|
||||
Parts []Part
|
||||
}
|
||||
|
||||
// utf.DecodeRune returns the utf8.RuneError for errors. But that is actually rune U+FFFD -- the unicode replacement
|
||||
// character. utf8.RuneError is not an error if it is also width 3.
|
||||
//
|
||||
// https://github.com/jackc/pgx/issues/1380
|
||||
const replacementcharacterwidth = 3
|
||||
|
||||
const maxBufSize = 16384 // 16 Ki
|
||||
|
||||
var bufPool = &pool[*bytes.Buffer]{
|
||||
new: func() *bytes.Buffer {
|
||||
return &bytes.Buffer{}
|
||||
},
|
||||
reset: func(b *bytes.Buffer) bool {
|
||||
n := b.Len()
|
||||
b.Reset()
|
||||
return n < maxBufSize
|
||||
},
|
||||
}
|
||||
|
||||
var null = []byte("null")
|
||||
|
||||
func (q *Query) Sanitize(args ...any) (string, error) {
|
||||
func (q *Query) Sanitize(args ...interface{}) (string, error) {
|
||||
argUse := make([]bool, len(args))
|
||||
buf := bufPool.get()
|
||||
defer bufPool.put(buf)
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
for _, part := range q.Parts {
|
||||
var str string
|
||||
switch part := part.(type) {
|
||||
case string:
|
||||
buf.WriteString(part)
|
||||
str = part
|
||||
case int:
|
||||
argIdx := part - 1
|
||||
var p []byte
|
||||
if argIdx < 0 {
|
||||
return "", fmt.Errorf("first sql argument must be > 0")
|
||||
}
|
||||
|
||||
if argIdx >= len(args) {
|
||||
return "", fmt.Errorf("insufficient arguments")
|
||||
return "", errors.Errorf("insufficient arguments")
|
||||
}
|
||||
|
||||
// Prevent SQL injection via Line Comment Creation
|
||||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
||||
buf.WriteByte(' ')
|
||||
|
||||
arg := args[argIdx]
|
||||
switch arg := arg.(type) {
|
||||
case nil:
|
||||
p = null
|
||||
str = "null"
|
||||
case int64:
|
||||
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
|
||||
str = strconv.FormatInt(arg, 10)
|
||||
case float64:
|
||||
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
|
||||
str = strconv.FormatFloat(arg, 'f', -1, 64)
|
||||
case bool:
|
||||
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
|
||||
str = strconv.FormatBool(arg)
|
||||
case []byte:
|
||||
p = QuoteBytes(buf.AvailableBuffer(), arg)
|
||||
str = QuoteBytes(arg)
|
||||
case string:
|
||||
p = QuoteString(buf.AvailableBuffer(), arg)
|
||||
str = QuoteString(arg)
|
||||
case time.Time:
|
||||
p = arg.Truncate(time.Microsecond).
|
||||
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
|
||||
str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
|
||||
default:
|
||||
return "", fmt.Errorf("invalid arg type: %T", arg)
|
||||
return "", errors.Errorf("invalid arg type: %T", arg)
|
||||
}
|
||||
argUse[argIdx] = true
|
||||
|
||||
buf.Write(p)
|
||||
|
||||
// Prevent SQL injection via Line Comment Creation
|
||||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
||||
buf.WriteByte(' ')
|
||||
default:
|
||||
return "", fmt.Errorf("invalid Part type: %T", part)
|
||||
return "", errors.Errorf("invalid Part type: %T", part)
|
||||
}
|
||||
buf.WriteString(str)
|
||||
}
|
||||
|
||||
for i, used := range argUse {
|
||||
if !used {
|
||||
return "", fmt.Errorf("unused argument: %d", i)
|
||||
return "", errors.Errorf("unused argument: %d", i)
|
||||
}
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func NewQuery(sql string) (*Query, error) {
|
||||
query := &Query{}
|
||||
query.init(sql)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
var sqlLexerPool = &pool[*sqlLexer]{
|
||||
new: func() *sqlLexer {
|
||||
return &sqlLexer{}
|
||||
},
|
||||
reset: func(sl *sqlLexer) bool {
|
||||
*sl = sqlLexer{}
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
func (q *Query) init(sql string) {
|
||||
parts := q.Parts[:0]
|
||||
if parts == nil {
|
||||
// dirty, but fast heuristic to preallocate for ~90% usecases
|
||||
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
|
||||
parts = make([]Part, 0, n)
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
}
|
||||
|
||||
l := sqlLexerPool.get()
|
||||
defer sqlLexerPool.put(l)
|
||||
|
||||
l.src = sql
|
||||
l.stateFn = rawState
|
||||
l.parts = parts
|
||||
|
||||
for l.stateFn != nil {
|
||||
l.stateFn = l.stateFn(l)
|
||||
}
|
||||
|
||||
q.Parts = l.parts
|
||||
query := &Query{Parts: l.parts}
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func QuoteString(dst []byte, str string) []byte {
|
||||
const quote = '\''
|
||||
|
||||
// Preallocate space for the worst case scenario
|
||||
dst = slices.Grow(dst, len(str)*2+2)
|
||||
|
||||
// Add opening quote
|
||||
dst = append(dst, quote)
|
||||
|
||||
// Iterate through the string without allocating
|
||||
for i := 0; i < len(str); i++ {
|
||||
if str[i] == quote {
|
||||
dst = append(dst, quote, quote)
|
||||
} else {
|
||||
dst = append(dst, str[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Add closing quote
|
||||
dst = append(dst, quote)
|
||||
|
||||
return dst
|
||||
func QuoteString(str string) string {
|
||||
return "'" + strings.Replace(str, "'", "''", -1) + "'"
|
||||
}
|
||||
|
||||
func QuoteBytes(dst, buf []byte) []byte {
|
||||
if len(buf) == 0 {
|
||||
return append(dst, `'\x'`...)
|
||||
}
|
||||
|
||||
// Calculate required length
|
||||
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
|
||||
|
||||
// Ensure dst has enough capacity
|
||||
if cap(dst)-len(dst) < requiredLen {
|
||||
newDst := make([]byte, len(dst), len(dst)+requiredLen)
|
||||
copy(newDst, dst)
|
||||
dst = newDst
|
||||
}
|
||||
|
||||
// Record original length and extend slice
|
||||
origLen := len(dst)
|
||||
dst = dst[:origLen+requiredLen]
|
||||
|
||||
// Add prefix
|
||||
dst[origLen] = '\''
|
||||
dst[origLen+1] = '\\'
|
||||
dst[origLen+2] = 'x'
|
||||
|
||||
// Encode bytes directly into dst
|
||||
hex.Encode(dst[origLen+3:len(dst)-1], buf)
|
||||
|
||||
// Add suffix
|
||||
dst[len(dst)-1] = '\''
|
||||
|
||||
return dst
|
||||
func QuoteBytes(buf []byte) string {
|
||||
return `'\x` + hex.EncodeToString(buf) + "'"
|
||||
}
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
nested int // multiline comment nesting level.
|
||||
stateFn stateFn
|
||||
parts []Part
|
||||
}
|
||||
@ -237,26 +125,12 @@ func rawState(l *sqlLexer) stateFn {
|
||||
l.start = l.pos
|
||||
return placeholderState
|
||||
}
|
||||
case '-':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '-' {
|
||||
l.pos += width
|
||||
return oneLineCommentState
|
||||
}
|
||||
case '/':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '*' {
|
||||
l.pos += width
|
||||
return multilineCommentState
|
||||
}
|
||||
case utf8.RuneError:
|
||||
if width != replacementcharacterwidth {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -274,13 +148,11 @@ func singleQuoteState(l *sqlLexer) stateFn {
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if width != replacementcharacterwidth {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -298,13 +170,11 @@ func doubleQuoteState(l *sqlLexer) stateFn {
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if width != replacementcharacterwidth {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -346,115 +216,22 @@ func escapeStringState(l *sqlLexer) stateFn {
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if width != replacementcharacterwidth {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func oneLineCommentState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
case '\n', '\r':
|
||||
return rawState
|
||||
case utf8.RuneError:
|
||||
if width != replacementcharacterwidth {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func multilineCommentState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '/':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '*' {
|
||||
l.pos += width
|
||||
l.nested++
|
||||
}
|
||||
case '*':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '/' {
|
||||
continue
|
||||
}
|
||||
|
||||
l.pos += width
|
||||
if l.nested == 0 {
|
||||
return rawState
|
||||
}
|
||||
l.nested--
|
||||
|
||||
case utf8.RuneError:
|
||||
if width != replacementcharacterwidth {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var queryPool = &pool[*Query]{
|
||||
new: func() *Query {
|
||||
return &Query{}
|
||||
},
|
||||
reset: func(q *Query) bool {
|
||||
n := len(q.Parts)
|
||||
q.Parts = q.Parts[:0]
|
||||
return n < 64 // drop too large queries
|
||||
},
|
||||
}
|
||||
|
||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||
// as necessary. This function is only safe when standard_conforming_strings is
|
||||
// on.
|
||||
func SanitizeSQL(sql string, args ...any) (string, error) {
|
||||
query := queryPool.get()
|
||||
query.init(sql)
|
||||
defer queryPool.put(query)
|
||||
|
||||
func SanitizeSQL(sql string, args ...interface{}) (string, error) {
|
||||
query, err := NewQuery(sql)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return query.Sanitize(args...)
|
||||
}
|
||||
|
||||
type pool[E any] struct {
|
||||
p sync.Pool
|
||||
new func() E
|
||||
reset func(E) bool
|
||||
}
|
||||
|
||||
func (pool *pool[E]) get() E {
|
||||
v, ok := pool.p.Get().(E)
|
||||
if !ok {
|
||||
v = pool.new()
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func (p *pool[E]) put(v E) {
|
||||
if p.reset(v) {
|
||||
p.p.Put(v)
|
||||
}
|
||||
}
|
||||
|
@ -1,62 +0,0 @@
|
||||
// sanitize_benchmark_test.go
|
||||
package sanitize_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||
)
|
||||
|
||||
var benchmarkSanitizeResult string
|
||||
|
||||
const benchmarkQuery = "" +
|
||||
`SELECT *
|
||||
FROM "water_containers"
|
||||
WHERE NOT "id" = $1 -- int64
|
||||
AND "tags" NOT IN $2 -- nil
|
||||
AND "volume" > $3 -- float64
|
||||
AND "transportable" = $4 -- bool
|
||||
AND position($5 IN "sign") -- bytes
|
||||
AND "label" LIKE $6 -- string
|
||||
AND "created_at" > $7; -- time.Time`
|
||||
|
||||
var benchmarkArgs = []any{
|
||||
int64(12345),
|
||||
nil,
|
||||
float64(500),
|
||||
true,
|
||||
[]byte("8BADF00D"),
|
||||
"kombucha's han'dy awokowa",
|
||||
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
func BenchmarkSanitize(b *testing.B) {
|
||||
query, err := sanitize.NewQuery(benchmarkQuery)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to create query: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to sanitize query: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var benchmarkNewSQLResult string
|
||||
|
||||
func BenchmarkSanitizeSQL(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
var err error
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to sanitize SQL: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,55 +0,0 @@
|
||||
package sanitize_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||
)
|
||||
|
||||
func FuzzQuoteString(f *testing.F) {
|
||||
const prefix = "prefix"
|
||||
f.Add("new\nline")
|
||||
f.Add("sample text")
|
||||
f.Add("sample q'u'o't'e's")
|
||||
f.Add("select 'quoted $42', $1")
|
||||
|
||||
f.Fuzz(func(t *testing.T, input string) {
|
||||
got := string(sanitize.QuoteString([]byte(prefix), input))
|
||||
want := oldQuoteString(input)
|
||||
|
||||
quoted, ok := strings.CutPrefix(got, prefix)
|
||||
if !ok {
|
||||
t.Fatalf("result has no prefix")
|
||||
}
|
||||
|
||||
if want != quoted {
|
||||
t.Errorf("got %q", got)
|
||||
t.Fatalf("want %q", want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzQuoteBytes(f *testing.F) {
|
||||
const prefix = "prefix"
|
||||
f.Add([]byte(nil))
|
||||
f.Add([]byte("\n"))
|
||||
f.Add([]byte("sample text"))
|
||||
f.Add([]byte("sample q'u'o't'e's"))
|
||||
f.Add([]byte("select 'quoted $42', $1"))
|
||||
|
||||
f.Fuzz(func(t *testing.T, input []byte) {
|
||||
got := string(sanitize.QuoteBytes([]byte(prefix), input))
|
||||
want := oldQuoteBytes(input)
|
||||
|
||||
quoted, ok := strings.CutPrefix(got, prefix)
|
||||
if !ok {
|
||||
t.Fatalf("result has no prefix")
|
||||
}
|
||||
|
||||
if want != quoted {
|
||||
t.Errorf("got %q", got)
|
||||
t.Fatalf("want %q", want)
|
||||
}
|
||||
})
|
||||
}
|
@ -1,12 +1,9 @@
|
||||
package sanitize_test
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||
"github.com/jackc/pgx/internal/sanitize"
|
||||
)
|
||||
|
||||
func TestNewQuery(t *testing.T) {
|
||||
@ -62,44 +59,6 @@ func TestNewQuery(t *testing.T) {
|
||||
sql: `select e'escape string\' $42', $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
|
||||
},
|
||||
{
|
||||
sql: `select /* a baby's toy */ 'barbie', $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select /* a baby's toy */ 'barbie', `, 1}},
|
||||
},
|
||||
{
|
||||
sql: `select /* *_* */ $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select /* *_* */ `, 1}},
|
||||
},
|
||||
{
|
||||
sql: `select 42 /* /* /* 42 */ */ */, $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select 42 /* /* /* 42 */ */ */, `, 1}},
|
||||
},
|
||||
{
|
||||
sql: "select -- a baby's toy\n'barbie', $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select -- a baby's toy\n'barbie', ", 1}},
|
||||
},
|
||||
{
|
||||
sql: "select 42 -- is a Deep Thought's favorite number",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 42 -- is a Deep Thought's favorite number"}},
|
||||
},
|
||||
{
|
||||
sql: "select 42, -- \\nis a Deep Thought's favorite number\n$1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\n", 1}},
|
||||
},
|
||||
{
|
||||
sql: "select 42, -- \\nis a Deep Thought's favorite number\r$1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\r", 1}},
|
||||
},
|
||||
{
|
||||
// https://github.com/jackc/pgx/issues/1380
|
||||
sql: "select 'hello w<>rld'",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello w<>rld'"}},
|
||||
},
|
||||
{
|
||||
// Unterminated quoted string
|
||||
sql: "select 'hello world",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello world"}},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successTests {
|
||||
@ -123,68 +82,53 @@ func TestNewQuery(t *testing.T) {
|
||||
func TestQuerySanitize(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
query sanitize.Query
|
||||
args []any
|
||||
args []interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
|
||||
args: []any{},
|
||||
args: []interface{}{},
|
||||
expected: `select 42`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{int64(42)},
|
||||
expected: `select 42 `,
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `select 42`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{float64(1.23)},
|
||||
expected: `select 1.23 `,
|
||||
args: []interface{}{float64(1.23)},
|
||||
expected: `select 1.23`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{true},
|
||||
expected: `select true `,
|
||||
args: []interface{}{true},
|
||||
expected: `select true`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{[]byte{0, 1, 2, 3, 255}},
|
||||
expected: `select '\x00010203ff' `,
|
||||
args: []interface{}{[]byte{0, 1, 2, 3, 255}},
|
||||
expected: `select '\x00010203ff'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{nil},
|
||||
expected: `select null `,
|
||||
args: []interface{}{nil},
|
||||
expected: `select null`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{"foobar"},
|
||||
expected: `select 'foobar' `,
|
||||
args: []interface{}{"foobar"},
|
||||
expected: `select 'foobar'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{"foo'bar"},
|
||||
expected: `select 'foo''bar' `,
|
||||
args: []interface{}{"foo'bar"},
|
||||
expected: `select 'foo''bar'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{`foo\'bar`},
|
||||
expected: `select 'foo\''bar' `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}},
|
||||
args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)},
|
||||
expected: `insert '2020-03-01 23:59:59.999999Z' `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
|
||||
args: []any{int64(-1)},
|
||||
expected: `select 1- -1 `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
|
||||
args: []any{float64(-1)},
|
||||
expected: `select 1- -1 `,
|
||||
args: []interface{}{`foo\'bar`},
|
||||
expected: `select 'foo\''bar'`,
|
||||
},
|
||||
}
|
||||
|
||||
@ -202,22 +146,22 @@ func TestQuerySanitize(t *testing.T) {
|
||||
|
||||
errorTests := []struct {
|
||||
query sanitize.Query
|
||||
args []any
|
||||
args []interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
|
||||
args: []any{int64(42)},
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `insufficient arguments`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}},
|
||||
args: []any{int64(42)},
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `unused argument: 0`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{42},
|
||||
args: []interface{}{42},
|
||||
expected: `invalid arg type: int`,
|
||||
},
|
||||
}
|
||||
@ -229,55 +173,3 @@ func TestQuerySanitize(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteString(t *testing.T) {
|
||||
tc := func(name, input string) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := string(sanitize.QuoteString(nil, input))
|
||||
want := oldQuoteString(input)
|
||||
|
||||
if got != want {
|
||||
t.Errorf("got: %s", got)
|
||||
t.Fatalf("want: %s", want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
tc("empty", "")
|
||||
tc("text", "abcd")
|
||||
tc("with quotes", `one's hat is always a cat`)
|
||||
}
|
||||
|
||||
// This function was used before optimizations.
|
||||
// You should keep for testing purposes - we want to ensure there are no breaking changes.
|
||||
func oldQuoteString(str string) string {
|
||||
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
|
||||
}
|
||||
|
||||
func TestQuoteBytes(t *testing.T) {
|
||||
tc := func(name string, input []byte) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := string(sanitize.QuoteBytes(nil, input))
|
||||
want := oldQuoteBytes(input)
|
||||
|
||||
if got != want {
|
||||
t.Errorf("got: %s", got)
|
||||
t.Fatalf("want: %s", want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
tc("nil", nil)
|
||||
tc("empty", []byte{})
|
||||
tc("text", []byte("abcd"))
|
||||
}
|
||||
|
||||
// This function was used before optimizations.
|
||||
// You should keep for testing purposes - we want to ensure there are no breaking changes.
|
||||
func oldQuoteBytes(buf []byte) string {
|
||||
return `'\x` + hex.EncodeToString(buf) + "'"
|
||||
}
|
||||
|
@ -1,111 +0,0 @@
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
// LRUCache implements Cache with a Least Recently Used (LRU) cache.
|
||||
type LRUCache struct {
|
||||
cap int
|
||||
m map[string]*list.Element
|
||||
l *list.List
|
||||
invalidStmts []*pgconn.StatementDescription
|
||||
}
|
||||
|
||||
// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache.
|
||||
func NewLRUCache(cap int) *LRUCache {
|
||||
return &LRUCache{
|
||||
cap: cap,
|
||||
m: make(map[string]*list.Element),
|
||||
l: list.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the statement description for sql. Returns nil if not found.
|
||||
func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
|
||||
if el, ok := c.m[key]; ok {
|
||||
c.l.MoveToFront(el)
|
||||
return el.Value.(*pgconn.StatementDescription)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or
|
||||
// sd.SQL has been invalidated and HandleInvalidated has not been called yet.
|
||||
func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
|
||||
if sd.SQL == "" {
|
||||
panic("cannot store statement description with empty SQL")
|
||||
}
|
||||
|
||||
if _, present := c.m[sd.SQL]; present {
|
||||
return
|
||||
}
|
||||
|
||||
// The statement may have been invalidated but not yet handled. Do not readd it to the cache.
|
||||
for _, invalidSD := range c.invalidStmts {
|
||||
if invalidSD.SQL == sd.SQL {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if c.l.Len() == c.cap {
|
||||
c.invalidateOldest()
|
||||
}
|
||||
|
||||
el := c.l.PushFront(sd)
|
||||
c.m[sd.SQL] = el
|
||||
}
|
||||
|
||||
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||
func (c *LRUCache) Invalidate(sql string) {
|
||||
if el, ok := c.m[sql]; ok {
|
||||
delete(c.m, sql)
|
||||
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
|
||||
c.l.Remove(el)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateAll invalidates all statement descriptions.
|
||||
func (c *LRUCache) InvalidateAll() {
|
||||
el := c.l.Front()
|
||||
for el != nil {
|
||||
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
|
||||
el = el.Next()
|
||||
}
|
||||
|
||||
c.m = make(map[string]*list.Element)
|
||||
c.l = list.New()
|
||||
}
|
||||
|
||||
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
|
||||
func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
|
||||
return c.invalidStmts
|
||||
}
|
||||
|
||||
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
|
||||
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
|
||||
// never seen by the call to GetInvalidated.
|
||||
func (c *LRUCache) RemoveInvalidated() {
|
||||
c.invalidStmts = nil
|
||||
}
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
func (c *LRUCache) Len() int {
|
||||
return c.l.Len()
|
||||
}
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
func (c *LRUCache) Cap() int {
|
||||
return c.cap
|
||||
}
|
||||
|
||||
func (c *LRUCache) invalidateOldest() {
|
||||
oldest := c.l.Back()
|
||||
sd := oldest.Value.(*pgconn.StatementDescription)
|
||||
c.invalidStmts = append(c.invalidStmts, sd)
|
||||
delete(c.m, sd.SQL)
|
||||
c.l.Remove(oldest)
|
||||
}
|
@ -1,45 +0,0 @@
|
||||
// Package stmtcache is a cache for statement descriptions.
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
// StatementName returns a statement name that will be stable for sql across multiple connections and program
|
||||
// executions.
|
||||
func StatementName(sql string) string {
|
||||
digest := sha256.Sum256([]byte(sql))
|
||||
return "stmtcache_" + hex.EncodeToString(digest[0:24])
|
||||
}
|
||||
|
||||
// Cache caches statement descriptions.
|
||||
type Cache interface {
|
||||
// Get returns the statement description for sql. Returns nil if not found.
|
||||
Get(sql string) *pgconn.StatementDescription
|
||||
|
||||
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||
Put(sd *pgconn.StatementDescription)
|
||||
|
||||
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||
Invalidate(sql string)
|
||||
|
||||
// InvalidateAll invalidates all statement descriptions.
|
||||
InvalidateAll()
|
||||
|
||||
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
|
||||
GetInvalidated() []*pgconn.StatementDescription
|
||||
|
||||
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
|
||||
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
|
||||
// never seen by the call to GetInvalidated.
|
||||
RemoveInvalidated()
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
Len() int
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
Cap() int
|
||||
}
|
@ -1,77 +0,0 @@
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
// UnlimitedCache implements Cache with no capacity limit.
|
||||
type UnlimitedCache struct {
|
||||
m map[string]*pgconn.StatementDescription
|
||||
invalidStmts []*pgconn.StatementDescription
|
||||
}
|
||||
|
||||
// NewUnlimitedCache creates a new UnlimitedCache.
|
||||
func NewUnlimitedCache() *UnlimitedCache {
|
||||
return &UnlimitedCache{
|
||||
m: make(map[string]*pgconn.StatementDescription),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the statement description for sql. Returns nil if not found.
|
||||
func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription {
|
||||
return c.m[sql]
|
||||
}
|
||||
|
||||
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||
func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) {
|
||||
if sd.SQL == "" {
|
||||
panic("cannot store statement description with empty SQL")
|
||||
}
|
||||
|
||||
if _, present := c.m[sd.SQL]; present {
|
||||
return
|
||||
}
|
||||
|
||||
c.m[sd.SQL] = sd
|
||||
}
|
||||
|
||||
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||
func (c *UnlimitedCache) Invalidate(sql string) {
|
||||
if sd, ok := c.m[sql]; ok {
|
||||
delete(c.m, sql)
|
||||
c.invalidStmts = append(c.invalidStmts, sd)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateAll invalidates all statement descriptions.
|
||||
func (c *UnlimitedCache) InvalidateAll() {
|
||||
for _, sd := range c.m {
|
||||
c.invalidStmts = append(c.invalidStmts, sd)
|
||||
}
|
||||
|
||||
c.m = make(map[string]*pgconn.StatementDescription)
|
||||
}
|
||||
|
||||
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
|
||||
func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription {
|
||||
return c.invalidStmts
|
||||
}
|
||||
|
||||
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
|
||||
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
|
||||
// never seen by the call to GetInvalidated.
|
||||
func (c *UnlimitedCache) RemoveInvalidated() {
|
||||
c.invalidStmts = nil
|
||||
}
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
func (c *UnlimitedCache) Len() int {
|
||||
return len(c.m)
|
||||
}
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
func (c *UnlimitedCache) Cap() int {
|
||||
return math.MaxInt
|
||||
}
|
208
large_objects.go
208
large_objects.go
@ -1,24 +1,56 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of
|
||||
// PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data
|
||||
// in the message,maxLargeObjectMessageLength should be no larger than 1 GB - 1 KB.
|
||||
var maxLargeObjectMessageLength = 1024*1024*1024 - 1024
|
||||
|
||||
// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it
|
||||
// was created.
|
||||
// LargeObjects is a structure used to access the large objects API. It is only
|
||||
// valid within the transaction where it was created.
|
||||
//
|
||||
// For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html
|
||||
type LargeObjects struct {
|
||||
tx Tx
|
||||
// Has64 is true if the server is capable of working with 64-bit numbers
|
||||
Has64 bool
|
||||
fp *fastpath
|
||||
}
|
||||
|
||||
const largeObjectFns = `select proname, oid from pg_catalog.pg_proc
|
||||
where proname in (
|
||||
'lo_open',
|
||||
'lo_close',
|
||||
'lo_create',
|
||||
'lo_unlink',
|
||||
'lo_lseek',
|
||||
'lo_lseek64',
|
||||
'lo_tell',
|
||||
'lo_tell64',
|
||||
'lo_truncate',
|
||||
'lo_truncate64',
|
||||
'loread',
|
||||
'lowrite')
|
||||
and pronamespace = (select oid from pg_catalog.pg_namespace where nspname = 'pg_catalog')`
|
||||
|
||||
// LargeObjects returns a LargeObjects instance for the transaction.
|
||||
func (tx *Tx) LargeObjects() (*LargeObjects, error) {
|
||||
if tx.conn.fp == nil {
|
||||
tx.conn.fp = newFastpath(tx.conn)
|
||||
}
|
||||
if _, exists := tx.conn.fp.fns["lo_open"]; !exists {
|
||||
res, err := tx.Query(largeObjectFns)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tx.conn.fp.addFunctions(res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
lo := &LargeObjects{fp: tx.conn.fp}
|
||||
_, lo.Has64 = lo.fp.fns["lo_lseek64"]
|
||||
|
||||
return lo, nil
|
||||
}
|
||||
|
||||
type LargeObjectMode int32
|
||||
@ -28,134 +60,90 @@ const (
|
||||
LargeObjectModeRead LargeObjectMode = 0x40000
|
||||
)
|
||||
|
||||
// Create creates a new large object. If oid is zero, the server assigns an unused OID.
|
||||
func (o *LargeObjects) Create(ctx context.Context, oid uint32) (uint32, error) {
|
||||
err := o.tx.QueryRow(ctx, "select lo_create($1)", oid).Scan(&oid)
|
||||
return oid, err
|
||||
// Create creates a new large object. If id is zero, the server assigns an
|
||||
// unused OID.
|
||||
func (o *LargeObjects) Create(id pgtype.OID) (pgtype.OID, error) {
|
||||
newOID, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))}))
|
||||
return pgtype.OID(newOID), err
|
||||
}
|
||||
|
||||
// Open opens an existing large object with the given mode. ctx will also be used for all operations on the opened large
|
||||
// object.
|
||||
func (o *LargeObjects) Open(ctx context.Context, oid uint32, mode LargeObjectMode) (*LargeObject, error) {
|
||||
var fd int32
|
||||
err := o.tx.QueryRow(ctx, "select lo_open($1, $2)", oid, mode).Scan(&fd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &LargeObject{fd: fd, tx: o.tx, ctx: ctx}, nil
|
||||
// Open opens an existing large object with the given mode.
|
||||
func (o *LargeObjects) Open(oid pgtype.OID, mode LargeObjectMode) (*LargeObject, error) {
|
||||
fd, err := fpInt32(o.fp.CallFn("lo_open", []fpArg{fpIntArg(int32(oid)), fpIntArg(int32(mode))}))
|
||||
return &LargeObject{fd: fd, lo: o}, err
|
||||
}
|
||||
|
||||
// Unlink removes a large object from the database.
|
||||
func (o *LargeObjects) Unlink(ctx context.Context, oid uint32) error {
|
||||
var result int32
|
||||
err := o.tx.QueryRow(ctx, "select lo_unlink($1)", oid).Scan(&result)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if result != 1 {
|
||||
return errors.New("failed to remove large object")
|
||||
}
|
||||
|
||||
return nil
|
||||
func (o *LargeObjects) Unlink(oid pgtype.OID) error {
|
||||
_, err := o.fp.CallFn("lo_unlink", []fpArg{fpIntArg(int32(oid))})
|
||||
return err
|
||||
}
|
||||
|
||||
// A LargeObject is a large object stored on the server. It is only valid within the transaction that it was initialized
|
||||
// in. It uses the context it was initialized with for all operations. It implements these interfaces:
|
||||
// A LargeObject is a large object stored on the server. It is only valid within
|
||||
// the transaction that it was initialized in. It implements these interfaces:
|
||||
//
|
||||
// io.Writer
|
||||
// io.Reader
|
||||
// io.Seeker
|
||||
// io.Closer
|
||||
// io.Writer
|
||||
// io.Reader
|
||||
// io.Seeker
|
||||
// io.Closer
|
||||
type LargeObject struct {
|
||||
ctx context.Context
|
||||
tx Tx
|
||||
fd int32
|
||||
fd int32
|
||||
lo *LargeObjects
|
||||
}
|
||||
|
||||
// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written.
|
||||
// Write writes p to the large object and returns the number of bytes written
|
||||
// and an error if not all of p was written.
|
||||
func (o *LargeObject) Write(p []byte) (int, error) {
|
||||
nTotal := 0
|
||||
for {
|
||||
expected := len(p) - nTotal
|
||||
if expected == 0 {
|
||||
break
|
||||
} else if expected > maxLargeObjectMessageLength {
|
||||
expected = maxLargeObjectMessageLength
|
||||
}
|
||||
|
||||
var n int
|
||||
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p[nTotal:nTotal+expected]).Scan(&n)
|
||||
if err != nil {
|
||||
return nTotal, err
|
||||
}
|
||||
|
||||
if n < 0 {
|
||||
return nTotal, errors.New("failed to write to large object")
|
||||
}
|
||||
|
||||
nTotal += n
|
||||
|
||||
if n < expected {
|
||||
return nTotal, errors.New("short write to large object")
|
||||
} else if n > expected {
|
||||
return nTotal, errors.New("invalid write to large object")
|
||||
}
|
||||
}
|
||||
|
||||
return nTotal, nil
|
||||
n, err := fpInt32(o.lo.fp.CallFn("lowrite", []fpArg{fpIntArg(o.fd), p}))
|
||||
return int(n), err
|
||||
}
|
||||
|
||||
// Read reads up to len(p) bytes into p returning the number of bytes read.
|
||||
func (o *LargeObject) Read(p []byte) (int, error) {
|
||||
nTotal := 0
|
||||
for {
|
||||
expected := len(p) - nTotal
|
||||
if expected == 0 {
|
||||
break
|
||||
} else if expected > maxLargeObjectMessageLength {
|
||||
expected = maxLargeObjectMessageLength
|
||||
}
|
||||
|
||||
res := pgtype.PreallocBytes(p[nTotal:])
|
||||
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res)
|
||||
// We compute expected so that it always fits into p, so it should never happen
|
||||
// that PreallocBytes's ScanBytes had to allocate a new slice.
|
||||
nTotal += len(res)
|
||||
if err != nil {
|
||||
return nTotal, err
|
||||
}
|
||||
|
||||
if len(res) < expected {
|
||||
return nTotal, io.EOF
|
||||
} else if len(res) > expected {
|
||||
return nTotal, errors.New("invalid read of large object")
|
||||
}
|
||||
res, err := o.lo.fp.CallFn("loread", []fpArg{fpIntArg(o.fd), fpIntArg(int32(len(p)))})
|
||||
if len(res) < len(p) {
|
||||
err = io.EOF
|
||||
}
|
||||
|
||||
return nTotal, nil
|
||||
return copy(p, res), err
|
||||
}
|
||||
|
||||
// Seek moves the current location pointer to the new location specified by offset.
|
||||
func (o *LargeObject) Seek(offset int64, whence int) (n int64, err error) {
|
||||
err = o.tx.QueryRow(o.ctx, "select lo_lseek64($1, $2, $3)", o.fd, offset, whence).Scan(&n)
|
||||
return n, err
|
||||
if o.lo.Has64 {
|
||||
n, err = fpInt64(o.lo.fp.CallFn("lo_lseek64", []fpArg{fpIntArg(o.fd), fpInt64Arg(offset), fpIntArg(int32(whence))}))
|
||||
} else {
|
||||
var n32 int32
|
||||
n32, err = fpInt32(o.lo.fp.CallFn("lo_lseek", []fpArg{fpIntArg(o.fd), fpIntArg(int32(offset)), fpIntArg(int32(whence))}))
|
||||
n = int64(n32)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Tell returns the current read or write location of the large object descriptor.
|
||||
// Tell returns the current read or write location of the large object
|
||||
// descriptor.
|
||||
func (o *LargeObject) Tell() (n int64, err error) {
|
||||
err = o.tx.QueryRow(o.ctx, "select lo_tell64($1)", o.fd).Scan(&n)
|
||||
return n, err
|
||||
if o.lo.Has64 {
|
||||
n, err = fpInt64(o.lo.fp.CallFn("lo_tell64", []fpArg{fpIntArg(o.fd)}))
|
||||
} else {
|
||||
var n32 int32
|
||||
n32, err = fpInt32(o.lo.fp.CallFn("lo_tell", []fpArg{fpIntArg(o.fd)}))
|
||||
n = int64(n32)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Truncate the large object to size.
|
||||
// Trunctes the large object to size.
|
||||
func (o *LargeObject) Truncate(size int64) (err error) {
|
||||
_, err = o.tx.Exec(o.ctx, "select lo_truncate64($1, $2)", o.fd, size)
|
||||
return err
|
||||
if o.lo.Has64 {
|
||||
_, err = o.lo.fp.CallFn("lo_truncate64", []fpArg{fpIntArg(o.fd), fpInt64Arg(size)})
|
||||
} else {
|
||||
_, err = o.lo.fp.CallFn("lo_truncate", []fpArg{fpIntArg(o.fd), fpIntArg(int32(size))})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Close the large object descriptor.
|
||||
// Close closees the large object descriptor.
|
||||
func (o *LargeObject) Close() error {
|
||||
_, err := o.tx.Exec(o.ctx, "select lo_close($1)", o.fd)
|
||||
_, err := o.lo.fp.CallFn("lo_close", []fpArg{fpIntArg(o.fd)})
|
||||
return err
|
||||
}
|
||||
|
@ -1,20 +0,0 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// SetMaxLargeObjectMessageLength sets internal maxLargeObjectMessageLength variable
|
||||
// to the given length for the duration of the test.
|
||||
//
|
||||
// Tests using this helper should not use t.Parallel().
|
||||
func SetMaxLargeObjectMessageLength(t *testing.T, length int) {
|
||||
t.Helper()
|
||||
|
||||
original := maxLargeObjectMessageLength
|
||||
t.Cleanup(func() {
|
||||
maxLargeObjectMessageLength = original
|
||||
})
|
||||
|
||||
maxLargeObjectMessageLength = length
|
||||
}
|
@ -1,77 +1,36 @@
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxtest"
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
func TestLargeObjects(t *testing.T) {
|
||||
// We use a very short limit to test chunking logic.
|
||||
pgx.SetMaxLargeObjectMessageLength(t, 2)
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
conn, err := pgx.Connect(*defaultConnConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support large objects")
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
tx, err := conn.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testLargeObjects(t, ctx, tx)
|
||||
}
|
||||
|
||||
func TestLargeObjectsSimpleProtocol(t *testing.T) {
|
||||
// We use a very short limit to test chunking logic.
|
||||
pgx.SetMaxLargeObjectMessageLength(t, 2)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
lo, err := tx.LargeObjects()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
|
||||
|
||||
conn, err := pgx.ConnectConfig(ctx, config)
|
||||
id, err := lo.Create(0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support large objects")
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testLargeObjects(t, ctx, tx)
|
||||
}
|
||||
|
||||
func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) {
|
||||
lo := tx.LargeObjects()
|
||||
|
||||
id, err := lo.Create(ctx, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
obj, err := lo.Open(ctx, id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite)
|
||||
obj, err := lo.Open(id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -150,44 +109,41 @@ func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = lo.Unlink(ctx, id)
|
||||
err = lo.Unlink(id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = lo.Open(ctx, id, pgx.LargeObjectModeRead)
|
||||
if e, ok := err.(*pgconn.PgError); !ok || e.Code != "42704" {
|
||||
_, err = lo.Open(id, pgx.LargeObjectModeRead)
|
||||
if e, ok := err.(pgx.PgError); !ok || e.Code != "42704" {
|
||||
t.Errorf("Expected undefined_object error (42704), got %#v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLargeObjectsMultipleTransactions(t *testing.T) {
|
||||
// We use a very short limit to test chunking logic.
|
||||
pgx.SetMaxLargeObjectMessageLength(t, 2)
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
conn, err := pgx.Connect(*defaultConnConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support large objects")
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
tx, err := conn.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
lo := tx.LargeObjects()
|
||||
|
||||
id, err := lo.Create(ctx, 0)
|
||||
lo, err := tx.LargeObjects()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
obj, err := lo.Open(ctx, id, pgx.LargeObjectModeWrite)
|
||||
id, err := lo.Create(0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
obj, err := lo.Open(id, pgx.LargeObjectModeWrite)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -201,29 +157,32 @@ func TestLargeObjectsMultipleTransactions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Commit the first transaction
|
||||
err = tx.Commit(ctx)
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// IMPORTANT: Use the same connection for another query
|
||||
query := `select n from generate_series(1,10) n`
|
||||
rows, err := conn.Query(ctx, query)
|
||||
rows, err := conn.Query(query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
// Start a new transaction
|
||||
tx2, err := conn.Begin(ctx)
|
||||
tx2, err := conn.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
lo2 := tx2.LargeObjects()
|
||||
lo2, err := tx2.LargeObjects()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Reopen the large object in the new transaction
|
||||
obj2, err := lo2.Open(ctx, id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite)
|
||||
obj2, err := lo2.Open(id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -294,13 +253,13 @@ func TestLargeObjectsMultipleTransactions(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = lo2.Unlink(ctx, id)
|
||||
err = lo2.Unlink(id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = lo2.Open(ctx, id, pgx.LargeObjectModeRead)
|
||||
if e, ok := err.(*pgconn.PgError); !ok || e.Code != "42704" {
|
||||
_, err = lo2.Open(id, pgx.LargeObjectModeRead)
|
||||
if e, ok := err.(pgx.PgError); !ok || e.Code != "42704" {
|
||||
t.Errorf("Expected undefined_object error (42704), got %#v", err)
|
||||
}
|
||||
}
|
||||
|
47
log/log15adapter/adapter.go
Normal file
47
log/log15adapter/adapter.go
Normal file
@ -0,0 +1,47 @@
|
||||
// Package log15adapter provides a logger that writes to a github.com/inconshreveable/log15.Logger
|
||||
// log.
|
||||
package log15adapter
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
// Log15Logger interface defines the subset of
|
||||
// github.com/inconshreveable/log15.Logger that this adapter uses.
|
||||
type Log15Logger interface {
|
||||
Debug(msg string, ctx ...interface{})
|
||||
Info(msg string, ctx ...interface{})
|
||||
Warn(msg string, ctx ...interface{})
|
||||
Error(msg string, ctx ...interface{})
|
||||
Crit(msg string, ctx ...interface{})
|
||||
}
|
||||
|
||||
type Logger struct {
|
||||
l Log15Logger
|
||||
}
|
||||
|
||||
func NewLogger(l Log15Logger) *Logger {
|
||||
return &Logger{l: l}
|
||||
}
|
||||
|
||||
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
logArgs := make([]interface{}, 0, len(data))
|
||||
for k, v := range data {
|
||||
logArgs = append(logArgs, k, v)
|
||||
}
|
||||
|
||||
switch level {
|
||||
case pgx.LogLevelTrace:
|
||||
l.l.Debug(msg, append(logArgs, "PGX_LOG_LEVEL", level)...)
|
||||
case pgx.LogLevelDebug:
|
||||
l.l.Debug(msg, logArgs...)
|
||||
case pgx.LogLevelInfo:
|
||||
l.l.Info(msg, logArgs...)
|
||||
case pgx.LogLevelWarn:
|
||||
l.l.Warn(msg, logArgs...)
|
||||
case pgx.LogLevelError:
|
||||
l.l.Error(msg, logArgs...)
|
||||
default:
|
||||
l.l.Error(msg, append(logArgs, "INVALID_PGX_LOG_LEVEL", level)...)
|
||||
}
|
||||
}
|
40
log/logrusadapter/adapter.go
Normal file
40
log/logrusadapter/adapter.go
Normal file
@ -0,0 +1,40 @@
|
||||
// Package logrusadapter provides a logger that writes to a github.com/sirupsen/logrus.Logger
|
||||
// log.
|
||||
package logrusadapter
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
l logrus.FieldLogger
|
||||
}
|
||||
|
||||
func NewLogger(l logrus.FieldLogger) *Logger {
|
||||
return &Logger{l: l}
|
||||
}
|
||||
|
||||
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
var logger logrus.FieldLogger
|
||||
if data != nil {
|
||||
logger = l.l.WithFields(data)
|
||||
} else {
|
||||
logger = l.l
|
||||
}
|
||||
|
||||
switch level {
|
||||
case pgx.LogLevelTrace:
|
||||
logger.WithField("PGX_LOG_LEVEL", level).Debug(msg)
|
||||
case pgx.LogLevelDebug:
|
||||
logger.Debug(msg)
|
||||
case pgx.LogLevelInfo:
|
||||
logger.Info(msg)
|
||||
case pgx.LogLevelWarn:
|
||||
logger.Warn(msg)
|
||||
case pgx.LogLevelError:
|
||||
logger.Error(msg)
|
||||
default:
|
||||
logger.WithField("INVALID_PGX_LOG_LEVEL", level).Error(msg)
|
||||
}
|
||||
}
|
@ -3,16 +3,15 @@
|
||||
package testingadapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/tracelog"
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
// TestingLogger interface defines the subset of testing.TB methods used by this
|
||||
// adapter.
|
||||
type TestingLogger interface {
|
||||
Log(args ...any)
|
||||
Log(args ...interface{})
|
||||
}
|
||||
|
||||
type Logger struct {
|
||||
@ -23,8 +22,8 @@ func NewLogger(l TestingLogger) *Logger {
|
||||
return &Logger{l: l}
|
||||
}
|
||||
|
||||
func (l *Logger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) {
|
||||
logArgs := make([]any, 0, 2+len(data))
|
||||
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
logArgs := make([]interface{}, 0, 2+len(data))
|
||||
logArgs = append(logArgs, level, msg)
|
||||
for k, v := range data {
|
||||
logArgs = append(logArgs, fmt.Sprintf("%s=%v", k, v))
|
||||
|
40
log/zapadapter/adapter.go
Normal file
40
log/zapadapter/adapter.go
Normal file
@ -0,0 +1,40 @@
|
||||
// Package zapadapter provides a logger that writes to a go.uber.org/zap.Logger.
|
||||
package zapadapter
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewLogger(logger *zap.Logger) *Logger {
|
||||
return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))}
|
||||
}
|
||||
|
||||
func (pl *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
fields := make([]zapcore.Field, len(data))
|
||||
i := 0
|
||||
for k, v := range data {
|
||||
fields[i] = zap.Reflect(k, v)
|
||||
i++
|
||||
}
|
||||
|
||||
switch level {
|
||||
case pgx.LogLevelTrace:
|
||||
pl.logger.Debug(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...)
|
||||
case pgx.LogLevelDebug:
|
||||
pl.logger.Debug(msg, fields...)
|
||||
case pgx.LogLevelInfo:
|
||||
pl.logger.Info(msg, fields...)
|
||||
case pgx.LogLevelWarn:
|
||||
pl.logger.Warn(msg, fields...)
|
||||
case pgx.LogLevelError:
|
||||
pl.logger.Error(msg, fields...)
|
||||
default:
|
||||
pl.logger.Error(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...)
|
||||
}
|
||||
}
|
40
log/zerologadapter/adapter.go
Normal file
40
log/zerologadapter/adapter.go
Normal file
@ -0,0 +1,40 @@
|
||||
// Package zerologadapter provides a logger that writes to a github.com/rs/zerolog.
|
||||
package zerologadapter
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
// NewLogger accepts a zerolog.Logger as input and returns a new custom pgx
|
||||
// logging fascade as output.
|
||||
func NewLogger(logger zerolog.Logger) *Logger {
|
||||
return &Logger{
|
||||
logger: logger.With().Str("module", "pgx").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
func (pl *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
var zlevel zerolog.Level
|
||||
switch level {
|
||||
case pgx.LogLevelNone:
|
||||
zlevel = zerolog.NoLevel
|
||||
case pgx.LogLevelError:
|
||||
zlevel = zerolog.ErrorLevel
|
||||
case pgx.LogLevelWarn:
|
||||
zlevel = zerolog.WarnLevel
|
||||
case pgx.LogLevelInfo:
|
||||
zlevel = zerolog.InfoLevel
|
||||
case pgx.LogLevelDebug:
|
||||
zlevel = zerolog.DebugLevel
|
||||
default:
|
||||
zlevel = zerolog.DebugLevel
|
||||
}
|
||||
|
||||
pgxlog := pl.logger.With().Fields(data).Logger()
|
||||
pgxlog.WithLevel(zlevel).Msg(msg)
|
||||
}
|
98
logger.go
Normal file
98
logger.go
Normal file
@ -0,0 +1,98 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// The values for log levels are chosen such that the zero value means that no
|
||||
// log level was specified.
|
||||
const (
|
||||
LogLevelTrace = 6
|
||||
LogLevelDebug = 5
|
||||
LogLevelInfo = 4
|
||||
LogLevelWarn = 3
|
||||
LogLevelError = 2
|
||||
LogLevelNone = 1
|
||||
)
|
||||
|
||||
// LogLevel represents the pgx logging level. See LogLevel* constants for
|
||||
// possible values.
|
||||
type LogLevel int
|
||||
|
||||
func (ll LogLevel) String() string {
|
||||
switch ll {
|
||||
case LogLevelTrace:
|
||||
return "trace"
|
||||
case LogLevelDebug:
|
||||
return "debug"
|
||||
case LogLevelInfo:
|
||||
return "info"
|
||||
case LogLevelWarn:
|
||||
return "warn"
|
||||
case LogLevelError:
|
||||
return "error"
|
||||
case LogLevelNone:
|
||||
return "none"
|
||||
default:
|
||||
return fmt.Sprintf("invalid level %d", ll)
|
||||
}
|
||||
}
|
||||
|
||||
// Logger is the interface used to get logging from pgx internals.
|
||||
type Logger interface {
|
||||
// Log a message at the given level with data key/value pairs. data may be nil.
|
||||
Log(level LogLevel, msg string, data map[string]interface{})
|
||||
}
|
||||
|
||||
// LogLevelFromString converts log level string to constant
|
||||
//
|
||||
// Valid levels:
|
||||
// trace
|
||||
// debug
|
||||
// info
|
||||
// warn
|
||||
// error
|
||||
// none
|
||||
func LogLevelFromString(s string) (LogLevel, error) {
|
||||
switch s {
|
||||
case "trace":
|
||||
return LogLevelTrace, nil
|
||||
case "debug":
|
||||
return LogLevelDebug, nil
|
||||
case "info":
|
||||
return LogLevelInfo, nil
|
||||
case "warn":
|
||||
return LogLevelWarn, nil
|
||||
case "error":
|
||||
return LogLevelError, nil
|
||||
case "none":
|
||||
return LogLevelNone, nil
|
||||
default:
|
||||
return 0, errors.New("invalid log level")
|
||||
}
|
||||
}
|
||||
|
||||
func logQueryArgs(args []interface{}) []interface{} {
|
||||
logArgs := make([]interface{}, 0, len(args))
|
||||
|
||||
for _, a := range args {
|
||||
switch v := a.(type) {
|
||||
case []byte:
|
||||
if len(v) < 64 {
|
||||
a = hex.EncodeToString(v)
|
||||
} else {
|
||||
a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64)
|
||||
}
|
||||
case string:
|
||||
if len(v) > 64 {
|
||||
a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64)
|
||||
}
|
||||
}
|
||||
logArgs = append(logArgs, a)
|
||||
}
|
||||
|
||||
return logArgs
|
||||
}
|
240
messages.go
Normal file
240
messages.go
Normal file
@ -0,0 +1,240 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"math"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
const (
|
||||
copyData = 'd'
|
||||
copyFail = 'f'
|
||||
copyDone = 'c'
|
||||
varHeaderSize = 4
|
||||
)
|
||||
|
||||
type FieldDescription struct {
|
||||
Name string
|
||||
Table pgtype.OID
|
||||
AttributeNumber uint16
|
||||
DataType pgtype.OID
|
||||
DataTypeSize int16
|
||||
DataTypeName string
|
||||
Modifier int32
|
||||
FormatCode int16
|
||||
}
|
||||
|
||||
func (fd FieldDescription) Length() (int64, bool) {
|
||||
switch fd.DataType {
|
||||
case pgtype.TextOID, pgtype.ByteaOID:
|
||||
return math.MaxInt64, true
|
||||
case pgtype.VarcharOID, pgtype.BPCharArrayOID:
|
||||
return int64(fd.Modifier - varHeaderSize), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func (fd FieldDescription) PrecisionScale() (precision, scale int64, ok bool) {
|
||||
switch fd.DataType {
|
||||
case pgtype.NumericOID:
|
||||
mod := fd.Modifier - varHeaderSize
|
||||
precision = int64((mod >> 16) & 0xffff)
|
||||
scale = int64(mod & 0xffff)
|
||||
return precision, scale, true
|
||||
default:
|
||||
return 0, 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func (fd FieldDescription) Type() reflect.Type {
|
||||
switch fd.DataType {
|
||||
case pgtype.Float8OID:
|
||||
return reflect.TypeOf(float64(0))
|
||||
case pgtype.Float4OID:
|
||||
return reflect.TypeOf(float32(0))
|
||||
case pgtype.Int8OID:
|
||||
return reflect.TypeOf(int64(0))
|
||||
case pgtype.Int4OID:
|
||||
return reflect.TypeOf(int32(0))
|
||||
case pgtype.Int2OID:
|
||||
return reflect.TypeOf(int16(0))
|
||||
case pgtype.VarcharOID, pgtype.BPCharArrayOID, pgtype.TextOID:
|
||||
return reflect.TypeOf("")
|
||||
case pgtype.BoolOID:
|
||||
return reflect.TypeOf(false)
|
||||
case pgtype.NumericOID:
|
||||
return reflect.TypeOf(float64(0))
|
||||
case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
|
||||
return reflect.TypeOf(time.Time{})
|
||||
case pgtype.ByteaOID:
|
||||
return reflect.TypeOf([]byte(nil))
|
||||
default:
|
||||
return reflect.TypeOf(new(interface{})).Elem()
|
||||
}
|
||||
}
|
||||
|
||||
// PgError represents an error reported by the PostgreSQL server. See
|
||||
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
|
||||
// detailed field description.
|
||||
type PgError struct {
|
||||
Severity string
|
||||
Code string
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position int32
|
||||
InternalPosition int32
|
||||
InternalQuery string
|
||||
Where string
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
DataTypeName string
|
||||
ConstraintName string
|
||||
File string
|
||||
Line int32
|
||||
Routine string
|
||||
}
|
||||
|
||||
func (pe PgError) Error() string {
|
||||
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||
}
|
||||
|
||||
// Notice represents a notice response message reported by the PostgreSQL
|
||||
// server. Be aware that this is distinct from LISTEN/NOTIFY notification.
|
||||
type Notice PgError
|
||||
|
||||
// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it.
|
||||
func appendParse(buf []byte, name string, query string, parameterOIDs []pgtype.OID) []byte {
|
||||
buf = append(buf, 'P')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, name...)
|
||||
buf = append(buf, 0)
|
||||
buf = append(buf, query...)
|
||||
buf = append(buf, 0)
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(parameterOIDs)))
|
||||
for _, oid := range parameterOIDs {
|
||||
buf = pgio.AppendUint32(buf, uint32(oid))
|
||||
}
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// appendDescribe appends a PostgreSQL wire protocol describe message to buf and returns it.
|
||||
func appendDescribe(buf []byte, objectType byte, name string) []byte {
|
||||
buf = append(buf, 'D')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, objectType)
|
||||
buf = append(buf, name...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it.
|
||||
func appendSync(buf []byte) []byte {
|
||||
buf = append(buf, 'S')
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it.
|
||||
func appendBind(
|
||||
buf []byte,
|
||||
destinationPortal,
|
||||
preparedStatement string,
|
||||
connInfo *pgtype.ConnInfo,
|
||||
parameterOIDs []pgtype.OID,
|
||||
arguments []interface{},
|
||||
resultFormatCodes []int16,
|
||||
) ([]byte, error) {
|
||||
buf = append(buf, 'B')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, destinationPortal...)
|
||||
buf = append(buf, 0)
|
||||
buf = append(buf, preparedStatement...)
|
||||
buf = append(buf, 0)
|
||||
|
||||
var err error
|
||||
arguments, err = convertDriverValuers(arguments)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(parameterOIDs)))
|
||||
for i, oid := range parameterOIDs {
|
||||
buf = pgio.AppendInt16(buf, chooseParameterFormatCode(connInfo, oid, arguments[i]))
|
||||
}
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(arguments)))
|
||||
for i, oid := range parameterOIDs {
|
||||
var err error
|
||||
buf, err = encodePreparedStatementArgument(connInfo, buf, oid, arguments[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes)))
|
||||
for _, fc := range resultFormatCodes {
|
||||
buf = pgio.AppendInt16(buf, fc)
|
||||
}
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func convertDriverValuers(args []interface{}) ([]interface{}, error) {
|
||||
for i, arg := range args {
|
||||
switch arg := arg.(type) {
|
||||
case pgtype.BinaryEncoder:
|
||||
case pgtype.TextEncoder:
|
||||
case driver.Valuer:
|
||||
v, err := callValuerValue(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args[i] = v
|
||||
}
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it.
|
||||
func appendExecute(buf []byte, portal string, maxRows uint32) []byte {
|
||||
buf = append(buf, 'E')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
buf = append(buf, portal...)
|
||||
buf = append(buf, 0)
|
||||
buf = pgio.AppendUint32(buf, maxRows)
|
||||
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// appendQuery appends a PostgreSQL wire protocol query message to buf and returns it.
|
||||
func appendQuery(buf []byte, query string) []byte {
|
||||
buf = append(buf, 'Q')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, query...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
return buf
|
||||
}
|
@ -1,152 +0,0 @@
|
||||
// Package multitracer provides a Tracer that can combine several tracers into one.
|
||||
package multitracer
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// Tracer can combine several tracers into one.
|
||||
// You can use New to automatically split tracers by interface.
|
||||
type Tracer struct {
|
||||
QueryTracers []pgx.QueryTracer
|
||||
BatchTracers []pgx.BatchTracer
|
||||
CopyFromTracers []pgx.CopyFromTracer
|
||||
PrepareTracers []pgx.PrepareTracer
|
||||
ConnectTracers []pgx.ConnectTracer
|
||||
PoolAcquireTracers []pgxpool.AcquireTracer
|
||||
PoolReleaseTracers []pgxpool.ReleaseTracer
|
||||
}
|
||||
|
||||
// New returns new Tracer from tracers with automatically split tracers by interface.
|
||||
func New(tracers ...pgx.QueryTracer) *Tracer {
|
||||
var t Tracer
|
||||
|
||||
for _, tracer := range tracers {
|
||||
t.QueryTracers = append(t.QueryTracers, tracer)
|
||||
|
||||
if batchTracer, ok := tracer.(pgx.BatchTracer); ok {
|
||||
t.BatchTracers = append(t.BatchTracers, batchTracer)
|
||||
}
|
||||
|
||||
if copyFromTracer, ok := tracer.(pgx.CopyFromTracer); ok {
|
||||
t.CopyFromTracers = append(t.CopyFromTracers, copyFromTracer)
|
||||
}
|
||||
|
||||
if prepareTracer, ok := tracer.(pgx.PrepareTracer); ok {
|
||||
t.PrepareTracers = append(t.PrepareTracers, prepareTracer)
|
||||
}
|
||||
|
||||
if connectTracer, ok := tracer.(pgx.ConnectTracer); ok {
|
||||
t.ConnectTracers = append(t.ConnectTracers, connectTracer)
|
||||
}
|
||||
|
||||
if poolAcquireTracer, ok := tracer.(pgxpool.AcquireTracer); ok {
|
||||
t.PoolAcquireTracers = append(t.PoolAcquireTracers, poolAcquireTracer)
|
||||
}
|
||||
|
||||
if poolReleaseTracer, ok := tracer.(pgxpool.ReleaseTracer); ok {
|
||||
t.PoolReleaseTracers = append(t.PoolReleaseTracers, poolReleaseTracer)
|
||||
}
|
||||
}
|
||||
|
||||
return &t
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
|
||||
for _, tracer := range t.QueryTracers {
|
||||
ctx = tracer.TraceQueryStart(ctx, conn, data)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
|
||||
for _, tracer := range t.QueryTracers {
|
||||
tracer.TraceQueryEnd(ctx, conn, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
|
||||
for _, tracer := range t.BatchTracers {
|
||||
ctx = tracer.TraceBatchStart(ctx, conn, data)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
|
||||
for _, tracer := range t.BatchTracers {
|
||||
tracer.TraceBatchQuery(ctx, conn, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
|
||||
for _, tracer := range t.BatchTracers {
|
||||
tracer.TraceBatchEnd(ctx, conn, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
|
||||
for _, tracer := range t.CopyFromTracers {
|
||||
ctx = tracer.TraceCopyFromStart(ctx, conn, data)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
|
||||
for _, tracer := range t.CopyFromTracers {
|
||||
tracer.TraceCopyFromEnd(ctx, conn, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
|
||||
for _, tracer := range t.PrepareTracers {
|
||||
ctx = tracer.TracePrepareStart(ctx, conn, data)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (t *Tracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
|
||||
for _, tracer := range t.PrepareTracers {
|
||||
tracer.TracePrepareEnd(ctx, conn, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
|
||||
for _, tracer := range t.ConnectTracers {
|
||||
ctx = tracer.TraceConnectStart(ctx, data)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
|
||||
for _, tracer := range t.ConnectTracers {
|
||||
tracer.TraceConnectEnd(ctx, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context {
|
||||
for _, tracer := range t.PoolAcquireTracers {
|
||||
ctx = tracer.TraceAcquireStart(ctx, pool, data)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
|
||||
for _, tracer := range t.PoolAcquireTracers {
|
||||
tracer.TraceAcquireEnd(ctx, pool, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) {
|
||||
for _, tracer := range t.PoolReleaseTracers {
|
||||
tracer.TraceRelease(pool, data)
|
||||
}
|
||||
}
|
@ -1,115 +0,0 @@
|
||||
package multitracer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/multitracer"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testFullTracer struct{}
|
||||
|
||||
func (tt *testFullTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
|
||||
}
|
||||
|
||||
func (tt *testFullTracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) {
|
||||
}
|
||||
|
||||
type testCopyTracer struct{}
|
||||
|
||||
func (tt *testCopyTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testCopyTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
|
||||
}
|
||||
|
||||
func (tt *testCopyTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (tt *testCopyTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fullTracer := &testFullTracer{}
|
||||
copyTracer := &testCopyTracer{}
|
||||
|
||||
mt := multitracer.New(fullTracer, copyTracer)
|
||||
require.Equal(
|
||||
t,
|
||||
&multitracer.Tracer{
|
||||
QueryTracers: []pgx.QueryTracer{
|
||||
fullTracer,
|
||||
copyTracer,
|
||||
},
|
||||
BatchTracers: []pgx.BatchTracer{
|
||||
fullTracer,
|
||||
},
|
||||
CopyFromTracers: []pgx.CopyFromTracer{
|
||||
fullTracer,
|
||||
copyTracer,
|
||||
},
|
||||
PrepareTracers: []pgx.PrepareTracer{
|
||||
fullTracer,
|
||||
},
|
||||
ConnectTracers: []pgx.ConnectTracer{
|
||||
fullTracer,
|
||||
},
|
||||
PoolAcquireTracers: []pgxpool.AcquireTracer{
|
||||
fullTracer,
|
||||
},
|
||||
PoolReleaseTracers: []pgxpool.ReleaseTracer{
|
||||
fullTracer,
|
||||
},
|
||||
},
|
||||
mt,
|
||||
)
|
||||
}
|
295
named_args.go
295
named_args.go
@ -1,295 +0,0 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$'
|
||||
// ordinal placeholder and construct the appropriate arguments.
|
||||
//
|
||||
// For example, the following two queries are equivalent:
|
||||
//
|
||||
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})
|
||||
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2)
|
||||
//
|
||||
// Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be
|
||||
// letters, numbers, or underscores.
|
||||
type NamedArgs map[string]any
|
||||
|
||||
// RewriteQuery implements the QueryRewriter interface.
|
||||
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||
return rewriteQuery(na, sql, false)
|
||||
}
|
||||
|
||||
// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all
|
||||
// named arguments that the sql query uses, and no extra arguments.
|
||||
type StrictNamedArgs map[string]any
|
||||
|
||||
// RewriteQuery implements the QueryRewriter interface.
|
||||
func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||
return rewriteQuery(sna, sql, true)
|
||||
}
|
||||
|
||||
type namedArg string
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
nested int // multiline comment nesting level.
|
||||
stateFn stateFn
|
||||
parts []any
|
||||
|
||||
nameToOrdinal map[namedArg]int
|
||||
}
|
||||
|
||||
type stateFn func(*sqlLexer) stateFn
|
||||
|
||||
func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
nameToOrdinal: make(map[namedArg]int, len(na)),
|
||||
}
|
||||
|
||||
for l.stateFn != nil {
|
||||
l.stateFn = l.stateFn(l)
|
||||
}
|
||||
|
||||
sb := strings.Builder{}
|
||||
for _, p := range l.parts {
|
||||
switch p := p.(type) {
|
||||
case string:
|
||||
sb.WriteString(p)
|
||||
case namedArg:
|
||||
sb.WriteRune('$')
|
||||
sb.WriteString(strconv.Itoa(l.nameToOrdinal[p]))
|
||||
}
|
||||
}
|
||||
|
||||
newArgs = make([]any, len(l.nameToOrdinal))
|
||||
for name, ordinal := range l.nameToOrdinal {
|
||||
var found bool
|
||||
newArgs[ordinal-1], found = na[string(name)]
|
||||
if isStrict && !found {
|
||||
return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
|
||||
}
|
||||
}
|
||||
|
||||
if isStrict {
|
||||
for name := range na {
|
||||
if _, found := l.nameToOrdinal[namedArg(name)]; !found {
|
||||
return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), newArgs, nil
|
||||
}
|
||||
|
||||
func rawState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case 'e', 'E':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '\'' {
|
||||
l.pos += width
|
||||
return escapeStringState
|
||||
}
|
||||
case '\'':
|
||||
return singleQuoteState
|
||||
case '"':
|
||||
return doubleQuoteState
|
||||
case '@':
|
||||
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if isLetter(nextRune) || nextRune == '_' {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos-width])
|
||||
}
|
||||
l.start = l.pos
|
||||
return namedArgState
|
||||
}
|
||||
case '-':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '-' {
|
||||
l.pos += width
|
||||
return oneLineCommentState
|
||||
}
|
||||
case '/':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '*' {
|
||||
l.pos += width
|
||||
return multilineCommentState
|
||||
}
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isLetter(r rune) bool {
|
||||
return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
|
||||
}
|
||||
|
||||
func namedArgState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
if r == utf8.RuneError {
|
||||
if l.pos-l.start > 0 {
|
||||
na := namedArg(l.src[l.start:l.pos])
|
||||
if _, found := l.nameToOrdinal[na]; !found {
|
||||
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
|
||||
}
|
||||
l.parts = append(l.parts, na)
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
} else if !(isLetter(r) || (r >= '0' && r <= '9') || r == '_') {
|
||||
l.pos -= width
|
||||
na := namedArg(l.src[l.start:l.pos])
|
||||
if _, found := l.nameToOrdinal[na]; !found {
|
||||
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
|
||||
}
|
||||
l.parts = append(l.parts, namedArg(na))
|
||||
l.start = l.pos
|
||||
return rawState
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func singleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doubleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '"':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '"' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func escapeStringState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func oneLineCommentState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
case '\n', '\r':
|
||||
return rawState
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func multilineCommentState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '/':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '*' {
|
||||
l.pos += width
|
||||
l.nested++
|
||||
}
|
||||
case '*':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '/' {
|
||||
continue
|
||||
}
|
||||
|
||||
l.pos += width
|
||||
if l.nested == 0 {
|
||||
return rawState
|
||||
}
|
||||
l.nested--
|
||||
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
@ -1,162 +0,0 @@
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i, tt := range []struct {
|
||||
sql string
|
||||
args []any
|
||||
namedArgs pgx.NamedArgs
|
||||
expectedSQL string
|
||||
expectedArgs []any
|
||||
}{
|
||||
{
|
||||
sql: "select * from users where id = @id",
|
||||
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||
expectedSQL: "select * from users where id = $1",
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
{
|
||||
sql: "select * from t where foo < @abc and baz = @def and bar < @abc",
|
||||
namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)},
|
||||
expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1",
|
||||
expectedArgs: []any{int32(42), int32(1)},
|
||||
},
|
||||
{
|
||||
sql: "select @a::int, @b::text",
|
||||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||
expectedSQL: "select $1::int, $2::text",
|
||||
expectedArgs: []any{int32(42), "foo"},
|
||||
},
|
||||
{
|
||||
sql: "select @Abc::int, @b_4::text, @_c::int",
|
||||
namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo", "_c": int32(1)},
|
||||
expectedSQL: "select $1::int, $2::text, $3::int",
|
||||
expectedArgs: []any{int32(42), "foo", int32(1)},
|
||||
},
|
||||
{
|
||||
sql: "at end @",
|
||||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||
expectedSQL: "at end @",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: "ignores without valid character after @ foo bar",
|
||||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||
expectedSQL: "ignores without valid character after @ foo bar",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: "name cannot start with number @1 foo bar",
|
||||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||
expectedSQL: "name cannot start with number @1 foo bar",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: `select *, '@foo' as "@bar" from users where id = @id`,
|
||||
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||
expectedSQL: `select *, '@foo' as "@bar" from users where id = $1`,
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
{
|
||||
sql: `select * -- @foo
|
||||
from users -- @single line comments
|
||||
where id = @id;`,
|
||||
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||
expectedSQL: `select * -- @foo
|
||||
from users -- @single line comments
|
||||
where id = $1;`,
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
{
|
||||
sql: `select * /* @multi line
|
||||
@comment
|
||||
*/
|
||||
/* /* with @nesting */ */
|
||||
from users
|
||||
where id = @id;`,
|
||||
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||
expectedSQL: `select * /* @multi line
|
||||
@comment
|
||||
*/
|
||||
/* /* with @nesting */ */
|
||||
from users
|
||||
where id = $1;`,
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
{
|
||||
sql: "extra provided argument",
|
||||
namedArgs: pgx.NamedArgs{"extra": int32(1)},
|
||||
expectedSQL: "extra provided argument",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: "@missing argument",
|
||||
namedArgs: pgx.NamedArgs{},
|
||||
expectedSQL: "$1 argument",
|
||||
expectedArgs: []any{nil},
|
||||
},
|
||||
|
||||
// test comments and quotes
|
||||
} {
|
||||
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
|
||||
require.NoError(t, err)
|
||||
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
|
||||
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrictNamedArgsRewriteQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i, tt := range []struct {
|
||||
sql string
|
||||
namedArgs pgx.StrictNamedArgs
|
||||
expectedSQL string
|
||||
expectedArgs []any
|
||||
isExpectedError bool
|
||||
}{
|
||||
{
|
||||
sql: "no arguments",
|
||||
namedArgs: pgx.StrictNamedArgs{},
|
||||
expectedSQL: "no arguments",
|
||||
expectedArgs: []any{},
|
||||
isExpectedError: false,
|
||||
},
|
||||
{
|
||||
sql: "@all @matches",
|
||||
namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)},
|
||||
expectedSQL: "$1 $2",
|
||||
expectedArgs: []any{int32(1), int32(2)},
|
||||
isExpectedError: false,
|
||||
},
|
||||
{
|
||||
sql: "extra provided argument",
|
||||
namedArgs: pgx.StrictNamedArgs{"extra": int32(1)},
|
||||
isExpectedError: true,
|
||||
},
|
||||
{
|
||||
sql: "@missing argument",
|
||||
namedArgs: pgx.StrictNamedArgs{},
|
||||
isExpectedError: true,
|
||||
},
|
||||
} {
|
||||
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil)
|
||||
if tt.isExpectedError {
|
||||
assert.Errorf(t, err, "%d", i)
|
||||
} else {
|
||||
require.NoErrorf(t, err, "%d", i)
|
||||
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
|
||||
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,75 +0,0 @@
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPgbouncerStatementCacheDescribe(t *testing.T) {
|
||||
connString := os.Getenv("PGX_TEST_PGBOUNCER_CONN_STRING")
|
||||
if connString == "" {
|
||||
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_PGBOUNCER_CONN_STRING")
|
||||
}
|
||||
|
||||
config := mustParseConfig(t, connString)
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
|
||||
config.DescriptionCacheCapacity = 1024
|
||||
|
||||
testPgbouncer(t, config, 10, 100)
|
||||
}
|
||||
|
||||
func TestPgbouncerSimpleProtocol(t *testing.T) {
|
||||
connString := os.Getenv("PGX_TEST_PGBOUNCER_CONN_STRING")
|
||||
if connString == "" {
|
||||
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_PGBOUNCER_CONN_STRING")
|
||||
}
|
||||
|
||||
config := mustParseConfig(t, connString)
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
|
||||
|
||||
testPgbouncer(t, config, 10, 100)
|
||||
}
|
||||
|
||||
func testPgbouncer(t *testing.T, config *pgx.ConnConfig, workers, iterations int) {
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
for i := 0; i < workers; i++ {
|
||||
go func() {
|
||||
defer func() { doneChan <- struct{}{} }()
|
||||
conn, err := pgx.ConnectConfig(context.Background(), config)
|
||||
require.Nil(t, err)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
var i32 int32
|
||||
var i64 int64
|
||||
var f32 float32
|
||||
var s string
|
||||
var s2 string
|
||||
err = conn.QueryRow(context.Background(), "select 1::int4, 2::int8, 3::float4, 'hi'::text").Scan(&i32, &i64, &f32, &s)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int32(1), i32)
|
||||
assert.Equal(t, int64(2), i64)
|
||||
assert.Equal(t, float32(3), f32)
|
||||
assert.Equal(t, "hi", s)
|
||||
|
||||
err = conn.QueryRow(context.Background(), "select 1::int8, 2::float4, 'bye'::text, 4::int4, 'whatever'::text").Scan(&i64, &f32, &s, &i32, &s2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), i64)
|
||||
assert.Equal(t, float32(2), f32)
|
||||
assert.Equal(t, "bye", s)
|
||||
assert.Equal(t, int32(4), i32)
|
||||
assert.Equal(t, "whatever", s2)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < workers; i++ {
|
||||
<-doneChan
|
||||
}
|
||||
}
|
@ -1,29 +0,0 @@
|
||||
# pgconn
|
||||
|
||||
Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq.
|
||||
It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx.
|
||||
Applications should handle normal queries with a higher level library and only use pgconn directly when required for
|
||||
low-level access to PostgreSQL functionality.
|
||||
|
||||
## Example Usage
|
||||
|
||||
```go
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
log.Fatalln("pgconn failed to connect:", err)
|
||||
}
|
||||
defer pgConn.Close(context.Background())
|
||||
|
||||
result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil)
|
||||
for result.NextRow() {
|
||||
fmt.Println("User 123 has email:", string(result.Values()[0]))
|
||||
}
|
||||
_, err = result.Close()
|
||||
if err != nil {
|
||||
log.Fatalln("failed reading result:", err)
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
See CONTRIBUTING.md for setup instructions.
|
@ -1,73 +0,0 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkCommandTagRowsAffected(b *testing.B) {
|
||||
benchmarks := []struct {
|
||||
commandTag string
|
||||
rowsAffected int64
|
||||
}{
|
||||
{"UPDATE 1", 1},
|
||||
{"UPDATE 123456789", 123456789},
|
||||
{"INSERT 0 1", 1},
|
||||
{"INSERT 0 123456789", 123456789},
|
||||
}
|
||||
|
||||
for _, bm := range benchmarks {
|
||||
ct := CommandTag{s: bm.commandTag}
|
||||
b.Run(bm.commandTag, func(b *testing.B) {
|
||||
var n int64
|
||||
for i := 0; i < b.N; i++ {
|
||||
n = ct.RowsAffected()
|
||||
}
|
||||
if n != bm.rowsAffected {
|
||||
b.Errorf("expected %d got %d", bm.rowsAffected, n)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCommandTagTypeFromString(b *testing.B) {
|
||||
ct := CommandTag{s: "UPDATE 1"}
|
||||
|
||||
var update bool
|
||||
for i := 0; i < b.N; i++ {
|
||||
update = strings.HasPrefix(ct.String(), "UPDATE")
|
||||
}
|
||||
if !update {
|
||||
b.Error("expected update")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCommandTagInsert(b *testing.B) {
|
||||
benchmarks := []struct {
|
||||
commandTag string
|
||||
is bool
|
||||
}{
|
||||
{"INSERT 1", true},
|
||||
{"INSERT 1234567890", true},
|
||||
{"UPDATE 1", false},
|
||||
{"UPDATE 1234567890", false},
|
||||
{"DELETE 1", false},
|
||||
{"DELETE 1234567890", false},
|
||||
{"SELECT 1", false},
|
||||
{"SELECT 1234567890", false},
|
||||
{"UNKNOWN 1234567890", false},
|
||||
}
|
||||
|
||||
for _, bm := range benchmarks {
|
||||
ct := CommandTag{s: bm.commandTag}
|
||||
b.Run(bm.commandTag, func(b *testing.B) {
|
||||
var is bool
|
||||
for i := 0; i < b.N; i++ {
|
||||
is = ct.Insert()
|
||||
}
|
||||
if is != bm.is {
|
||||
b.Errorf("expected %v got %v", bm.is, is)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,250 +0,0 @@
|
||||
package pgconn_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func BenchmarkConnect(b *testing.B) {
|
||||
benchmarks := []struct {
|
||||
name string
|
||||
env string
|
||||
}{
|
||||
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
|
||||
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
|
||||
}
|
||||
|
||||
for _, bm := range benchmarks {
|
||||
bm := bm
|
||||
b.Run(bm.name, func(b *testing.B) {
|
||||
connString := os.Getenv(bm.env)
|
||||
if connString == "" {
|
||||
b.Skipf("Skipping due to missing environment variable %v", bm.env)
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn, err := pgconn.Connect(context.Background(), connString)
|
||||
require.Nil(b, err)
|
||||
|
||||
err = conn.Close(context.Background())
|
||||
require.Nil(b, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExec(b *testing.B) {
|
||||
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||
benchmarks := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
}{
|
||||
// Using an empty context other than context.Background() to compare
|
||||
// performance
|
||||
{"background context", context.Background()},
|
||||
{"empty context", context.TODO()},
|
||||
}
|
||||
|
||||
for _, bm := range benchmarks {
|
||||
bm := bm
|
||||
b.Run(bm.name, func(b *testing.B) {
|
||||
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date")
|
||||
|
||||
for mrr.NextResult() {
|
||||
rr := mrr.ResultReader()
|
||||
|
||||
rowCount := 0
|
||||
for rr.NextRow() {
|
||||
rowCount++
|
||||
if len(rr.Values()) != len(expectedValues) {
|
||||
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||
}
|
||||
for i := range rr.Values() {
|
||||
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err = rr.Close()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if rowCount != 1 {
|
||||
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||
}
|
||||
}
|
||||
|
||||
err := mrr.Close()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExecPossibleToCancel(b *testing.B) {
|
||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
mrr := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date")
|
||||
|
||||
for mrr.NextResult() {
|
||||
rr := mrr.ResultReader()
|
||||
|
||||
rowCount := 0
|
||||
for rr.NextRow() {
|
||||
rowCount++
|
||||
if len(rr.Values()) != len(expectedValues) {
|
||||
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||
}
|
||||
for i := range rr.Values() {
|
||||
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err = rr.Close()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if rowCount != 1 {
|
||||
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||
}
|
||||
}
|
||||
|
||||
err := mrr.Close()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExecPrepared(b *testing.B) {
|
||||
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||
|
||||
benchmarks := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
}{
|
||||
// Using an empty context other than context.Background() to compare
|
||||
// performance
|
||||
{"background context", context.Background()},
|
||||
{"empty context", context.TODO()},
|
||||
}
|
||||
|
||||
for _, bm := range benchmarks {
|
||||
bm := bm
|
||||
b.Run(bm.name, func(b *testing.B) {
|
||||
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
_, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||
require.Nil(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rr := conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil)
|
||||
|
||||
rowCount := 0
|
||||
for rr.NextRow() {
|
||||
rowCount++
|
||||
if len(rr.Values()) != len(expectedValues) {
|
||||
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||
}
|
||||
for i := range rr.Values() {
|
||||
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err = rr.Close()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if rowCount != 1 {
|
||||
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
|
||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
_, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||
require.Nil(b, err)
|
||||
|
||||
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rr := conn.ExecPrepared(ctx, "ps1", nil, nil, nil)
|
||||
|
||||
rowCount := 0
|
||||
for rr.NextRow() {
|
||||
rowCount += 1
|
||||
if len(rr.Values()) != len(expectedValues) {
|
||||
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||
}
|
||||
for i := range rr.Values() {
|
||||
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err = rr.Close()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if rowCount != 1 {
|
||||
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) {
|
||||
// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
// require.Nil(b, err)
|
||||
// defer closeConn(b, conn)
|
||||
|
||||
// ctx, cancel := context.WithCancel(context.Background())
|
||||
// defer cancel()
|
||||
|
||||
// b.ResetTimer()
|
||||
|
||||
// for i := 0; i < b.N; i++ {
|
||||
// conn.ChanToSetDeadline().Watch(ctx)
|
||||
// conn.ChanToSetDeadline().Ignore()
|
||||
// }
|
||||
// }
|
953
pgconn/config.go
953
pgconn/config.go
@ -1,953 +0,0 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgpassfile"
|
||||
"github.com/jackc/pgservicefile"
|
||||
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
)
|
||||
|
||||
type (
|
||||
AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||
ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||
GetSSLPasswordFunc func(ctx context.Context) string
|
||||
)
|
||||
|
||||
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A
|
||||
// manually initialized Config will cause ConnectConfig to panic.
|
||||
type Config struct {
|
||||
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16
|
||||
Database string
|
||||
User string
|
||||
Password string
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
ConnectTimeout time.Duration
|
||||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
||||
BuildFrontend BuildFrontendFunc
|
||||
|
||||
// BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called
|
||||
// when a context passed to a PgConn method is canceled.
|
||||
BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler
|
||||
|
||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
|
||||
KerberosSrvName string
|
||||
KerberosSpn string
|
||||
Fallbacks []*FallbackConfig
|
||||
|
||||
SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct
|
||||
|
||||
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
||||
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
|
||||
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
|
||||
ValidateConnect ValidateConnectFunc
|
||||
|
||||
// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
|
||||
// or prepare statements). If this returns an error the connection attempt fails.
|
||||
AfterConnect AfterConnectFunc
|
||||
|
||||
// OnNotice is a callback function called when a notice response is received.
|
||||
OnNotice NoticeHandler
|
||||
|
||||
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
|
||||
OnNotification NotificationHandler
|
||||
|
||||
// OnPgError is a callback function called when a Postgres error is received by the server. The default handler will close
|
||||
// the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure
|
||||
// that you close on FATAL errors by returning false.
|
||||
OnPgError PgErrorHandler
|
||||
|
||||
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
||||
}
|
||||
|
||||
// ParseConfigOptions contains options that control how a config is built such as GetSSLPassword.
|
||||
type ParseConfigOptions struct {
|
||||
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the libpq function
|
||||
// PQsetSSLKeyPassHook_OpenSSL.
|
||||
GetSSLPassword GetSSLPasswordFunc
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the config that is safe to use and modify.
|
||||
// The only exception is the TLSConfig field:
|
||||
// according to the tls.Config docs it must not be modified after creation.
|
||||
func (c *Config) Copy() *Config {
|
||||
newConf := new(Config)
|
||||
*newConf = *c
|
||||
if newConf.TLSConfig != nil {
|
||||
newConf.TLSConfig = c.TLSConfig.Clone()
|
||||
}
|
||||
if newConf.RuntimeParams != nil {
|
||||
newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
|
||||
for k, v := range c.RuntimeParams {
|
||||
newConf.RuntimeParams[k] = v
|
||||
}
|
||||
}
|
||||
if newConf.Fallbacks != nil {
|
||||
newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
|
||||
for i, fallback := range c.Fallbacks {
|
||||
newFallback := new(FallbackConfig)
|
||||
*newFallback = *fallback
|
||||
if newFallback.TLSConfig != nil {
|
||||
newFallback.TLSConfig = fallback.TLSConfig.Clone()
|
||||
}
|
||||
newConf.Fallbacks[i] = newFallback
|
||||
}
|
||||
}
|
||||
return newConf
|
||||
}
|
||||
|
||||
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
||||
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
|
||||
type FallbackConfig struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
}
|
||||
|
||||
// connectOneConfig is the configuration for a single attempt to connect to a single host.
|
||||
type connectOneConfig struct {
|
||||
network string
|
||||
address string
|
||||
originalHostname string // original hostname before resolving
|
||||
tlsConfig *tls.Config // nil disables TLS
|
||||
}
|
||||
|
||||
// isAbsolutePath checks if the provided value is an absolute path either
|
||||
// beginning with a forward slash (as on Linux-based systems) or with a capital
|
||||
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
|
||||
func isAbsolutePath(path string) bool {
|
||||
isWindowsPath := func(p string) bool {
|
||||
if len(p) < 3 {
|
||||
return false
|
||||
}
|
||||
drive := p[0]
|
||||
colon := p[1]
|
||||
backslash := p[2]
|
||||
if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(path, "/") || isWindowsPath(path)
|
||||
}
|
||||
|
||||
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
|
||||
// net.Dial.
|
||||
func NetworkAddress(host string, port uint16) (network, address string) {
|
||||
if isAbsolutePath(host) {
|
||||
network = "unix"
|
||||
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
||||
} else {
|
||||
network = "tcp"
|
||||
address = net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||
}
|
||||
return network, address
|
||||
}
|
||||
|
||||
// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
|
||||
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
|
||||
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format. See
|
||||
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be empty
|
||||
// to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
|
||||
//
|
||||
// # Example Keyword/Value
|
||||
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
|
||||
//
|
||||
// # Example URL
|
||||
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
|
||||
//
|
||||
// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done
|
||||
// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be
|
||||
// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should
|
||||
// not be modified individually. They should all be modified or all left unchanged.
|
||||
//
|
||||
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
|
||||
// values that will be tried in order. This can be used as part of a high availability system. See
|
||||
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
|
||||
//
|
||||
// # Example URL
|
||||
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
|
||||
//
|
||||
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
|
||||
// via database URL or keyword/value:
|
||||
//
|
||||
// PGHOST
|
||||
// PGPORT
|
||||
// PGDATABASE
|
||||
// PGUSER
|
||||
// PGPASSWORD
|
||||
// PGPASSFILE
|
||||
// PGSERVICE
|
||||
// PGSERVICEFILE
|
||||
// PGSSLMODE
|
||||
// PGSSLCERT
|
||||
// PGSSLKEY
|
||||
// PGSSLROOTCERT
|
||||
// PGSSLPASSWORD
|
||||
// PGOPTIONS
|
||||
// PGAPPNAME
|
||||
// PGCONNECT_TIMEOUT
|
||||
// PGTARGETSESSIONATTRS
|
||||
// PGTZ
|
||||
//
|
||||
// See http://www.postgresql.org/docs/current/static/libpq-envars.html for details on the meaning of environment variables.
|
||||
//
|
||||
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
|
||||
// usually but not always the environment variable name downcased and without the "PG" prefix.
|
||||
//
|
||||
// Important Security Notes:
|
||||
//
|
||||
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
|
||||
// not set.
|
||||
//
|
||||
// See http://www.postgresql.org/docs/current/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
|
||||
// security each sslmode provides.
|
||||
//
|
||||
// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
|
||||
// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of
|
||||
// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
|
||||
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
|
||||
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
|
||||
// TLSConfig.
|
||||
//
|
||||
// Other known differences with libpq:
|
||||
//
|
||||
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
|
||||
// does not.
|
||||
//
|
||||
// In addition, ParseConfig accepts the following options:
|
||||
//
|
||||
// - servicefile.
|
||||
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
|
||||
// part of the connection string.
|
||||
func ParseConfig(connString string) (*Config, error) {
|
||||
var parseConfigOptions ParseConfigOptions
|
||||
return ParseConfigWithOptions(connString, parseConfigOptions)
|
||||
}
|
||||
|
||||
// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard
|
||||
// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to
|
||||
// get the SSL password.
|
||||
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) {
|
||||
defaultSettings := defaultSettings()
|
||||
envSettings := parseEnvSettings()
|
||||
|
||||
connStringSettings := make(map[string]string)
|
||||
if connString != "" {
|
||||
var err error
|
||||
// connString may be a database URL or in PostgreSQL keyword/value format
|
||||
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||
connStringSettings, err = parseURLSettings(connString)
|
||||
if err != nil {
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
|
||||
}
|
||||
} else {
|
||||
connStringSettings, err = parseKeywordValueSettings(connString)
|
||||
if err != nil {
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as keyword/value", err: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
|
||||
if service, present := settings["service"]; present {
|
||||
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
|
||||
if err != nil {
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err}
|
||||
}
|
||||
|
||||
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
createdByParseConfig: true,
|
||||
Database: settings["database"],
|
||||
User: settings["user"],
|
||||
Password: settings["password"],
|
||||
RuntimeParams: make(map[string]string),
|
||||
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
|
||||
return pgproto3.NewFrontend(r, w)
|
||||
},
|
||||
BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler {
|
||||
return &DeadlineContextWatcherHandler{Conn: pgConn.conn}
|
||||
},
|
||||
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
|
||||
// we want to automatically close any fatal errors
|
||||
if strings.EqualFold(pgErr.Severity, "FATAL") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
|
||||
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
|
||||
if err != nil {
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err}
|
||||
}
|
||||
config.ConnectTimeout = connectTimeout
|
||||
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
|
||||
} else {
|
||||
defaultDialer := makeDefaultDialer()
|
||||
config.DialFunc = defaultDialer.DialContext
|
||||
}
|
||||
|
||||
config.LookupFunc = makeDefaultResolver().LookupHost
|
||||
|
||||
notRuntimeParams := map[string]struct{}{
|
||||
"host": {},
|
||||
"port": {},
|
||||
"database": {},
|
||||
"user": {},
|
||||
"password": {},
|
||||
"passfile": {},
|
||||
"connect_timeout": {},
|
||||
"sslmode": {},
|
||||
"sslkey": {},
|
||||
"sslcert": {},
|
||||
"sslrootcert": {},
|
||||
"sslnegotiation": {},
|
||||
"sslpassword": {},
|
||||
"sslsni": {},
|
||||
"krbspn": {},
|
||||
"krbsrvname": {},
|
||||
"target_session_attrs": {},
|
||||
"service": {},
|
||||
"servicefile": {},
|
||||
}
|
||||
|
||||
// Adding kerberos configuration
|
||||
if _, present := settings["krbsrvname"]; present {
|
||||
config.KerberosSrvName = settings["krbsrvname"]
|
||||
}
|
||||
if _, present := settings["krbspn"]; present {
|
||||
config.KerberosSpn = settings["krbspn"]
|
||||
}
|
||||
|
||||
for k, v := range settings {
|
||||
if _, present := notRuntimeParams[k]; present {
|
||||
continue
|
||||
}
|
||||
config.RuntimeParams[k] = v
|
||||
}
|
||||
|
||||
fallbacks := []*FallbackConfig{}
|
||||
|
||||
hosts := strings.Split(settings["host"], ",")
|
||||
ports := strings.Split(settings["port"], ",")
|
||||
|
||||
for i, host := range hosts {
|
||||
var portStr string
|
||||
if i < len(ports) {
|
||||
portStr = ports[i]
|
||||
} else {
|
||||
portStr = ports[0]
|
||||
}
|
||||
|
||||
port, err := parsePort(portStr)
|
||||
if err != nil {
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err}
|
||||
}
|
||||
|
||||
var tlsConfigs []*tls.Config
|
||||
|
||||
// Ignore TLS settings if Unix domain socket like libpq
|
||||
if network, _ := NetworkAddress(host, port); network == "unix" {
|
||||
tlsConfigs = append(tlsConfigs, nil)
|
||||
} else {
|
||||
var err error
|
||||
tlsConfigs, err = configTLS(settings, host, options)
|
||||
if err != nil {
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tlsConfig := range tlsConfigs {
|
||||
fallbacks = append(fallbacks, &FallbackConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
config.Host = fallbacks[0].Host
|
||||
config.Port = fallbacks[0].Port
|
||||
config.TLSConfig = fallbacks[0].TLSConfig
|
||||
config.Fallbacks = fallbacks[1:]
|
||||
config.SSLNegotiation = settings["sslnegotiation"]
|
||||
|
||||
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
||||
if err == nil {
|
||||
if config.Password == "" {
|
||||
host := config.Host
|
||||
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
|
||||
host = "localhost"
|
||||
}
|
||||
|
||||
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
|
||||
}
|
||||
}
|
||||
|
||||
switch tsa := settings["target_session_attrs"]; tsa {
|
||||
case "read-write":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
|
||||
case "read-only":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
|
||||
case "primary":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
|
||||
case "standby":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
|
||||
case "prefer-standby":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
|
||||
case "any":
|
||||
// do nothing
|
||||
default:
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func mergeSettings(settingSets ...map[string]string) map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
for _, s2 := range settingSets {
|
||||
for k, v := range s2 {
|
||||
settings[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
func parseEnvSettings() map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
nameMap := map[string]string{
|
||||
"PGHOST": "host",
|
||||
"PGPORT": "port",
|
||||
"PGDATABASE": "database",
|
||||
"PGUSER": "user",
|
||||
"PGPASSWORD": "password",
|
||||
"PGPASSFILE": "passfile",
|
||||
"PGAPPNAME": "application_name",
|
||||
"PGCONNECT_TIMEOUT": "connect_timeout",
|
||||
"PGSSLMODE": "sslmode",
|
||||
"PGSSLKEY": "sslkey",
|
||||
"PGSSLCERT": "sslcert",
|
||||
"PGSSLSNI": "sslsni",
|
||||
"PGSSLROOTCERT": "sslrootcert",
|
||||
"PGSSLPASSWORD": "sslpassword",
|
||||
"PGSSLNEGOTIATION": "sslnegotiation",
|
||||
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
||||
"PGSERVICE": "service",
|
||||
"PGSERVICEFILE": "servicefile",
|
||||
"PGTZ": "timezone",
|
||||
"PGOPTIONS": "options",
|
||||
}
|
||||
|
||||
for envname, realname := range nameMap {
|
||||
value := os.Getenv(envname)
|
||||
if value != "" {
|
||||
settings[realname] = value
|
||||
}
|
||||
}
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
func parseURLSettings(connString string) (map[string]string, error) {
|
||||
settings := make(map[string]string)
|
||||
|
||||
parsedURL, err := url.Parse(connString)
|
||||
if err != nil {
|
||||
if urlErr := new(url.Error); errors.As(err, &urlErr) {
|
||||
return nil, urlErr.Err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if parsedURL.User != nil {
|
||||
settings["user"] = parsedURL.User.Username()
|
||||
if password, present := parsedURL.User.Password(); present {
|
||||
settings["password"] = password
|
||||
}
|
||||
}
|
||||
|
||||
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
|
||||
var hosts []string
|
||||
var ports []string
|
||||
for _, host := range strings.Split(parsedURL.Host, ",") {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
if isIPOnly(host) {
|
||||
hosts = append(hosts, strings.Trim(host, "[]"))
|
||||
continue
|
||||
}
|
||||
h, p, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
|
||||
}
|
||||
if h != "" {
|
||||
hosts = append(hosts, h)
|
||||
}
|
||||
if p != "" {
|
||||
ports = append(ports, p)
|
||||
}
|
||||
}
|
||||
if len(hosts) > 0 {
|
||||
settings["host"] = strings.Join(hosts, ",")
|
||||
}
|
||||
if len(ports) > 0 {
|
||||
settings["port"] = strings.Join(ports, ",")
|
||||
}
|
||||
|
||||
database := strings.TrimLeft(parsedURL.Path, "/")
|
||||
if database != "" {
|
||||
settings["database"] = database
|
||||
}
|
||||
|
||||
nameMap := map[string]string{
|
||||
"dbname": "database",
|
||||
}
|
||||
|
||||
for k, v := range parsedURL.Query() {
|
||||
if k2, present := nameMap[k]; present {
|
||||
k = k2
|
||||
}
|
||||
|
||||
settings[k] = v[0]
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
func isIPOnly(host string) bool {
|
||||
return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
|
||||
}
|
||||
|
||||
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
|
||||
|
||||
func parseKeywordValueSettings(s string) (map[string]string, error) {
|
||||
settings := make(map[string]string)
|
||||
|
||||
nameMap := map[string]string{
|
||||
"dbname": "database",
|
||||
}
|
||||
|
||||
for len(s) > 0 {
|
||||
var key, val string
|
||||
eqIdx := strings.IndexRune(s, '=')
|
||||
if eqIdx < 0 {
|
||||
return nil, errors.New("invalid keyword/value")
|
||||
}
|
||||
|
||||
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
|
||||
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
|
||||
if len(s) == 0 {
|
||||
} else if s[0] != '\'' {
|
||||
end := 0
|
||||
for ; end < len(s); end++ {
|
||||
if asciiSpace[s[end]] == 1 {
|
||||
break
|
||||
}
|
||||
if s[end] == '\\' {
|
||||
end++
|
||||
if end == len(s) {
|
||||
return nil, errors.New("invalid backslash")
|
||||
}
|
||||
}
|
||||
}
|
||||
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||
if end == len(s) {
|
||||
s = ""
|
||||
} else {
|
||||
s = s[end+1:]
|
||||
}
|
||||
} else { // quoted string
|
||||
s = s[1:]
|
||||
end := 0
|
||||
for ; end < len(s); end++ {
|
||||
if s[end] == '\'' {
|
||||
break
|
||||
}
|
||||
if s[end] == '\\' {
|
||||
end++
|
||||
}
|
||||
}
|
||||
if end == len(s) {
|
||||
return nil, errors.New("unterminated quoted string in connection info string")
|
||||
}
|
||||
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||
if end == len(s) {
|
||||
s = ""
|
||||
} else {
|
||||
s = s[end+1:]
|
||||
}
|
||||
}
|
||||
|
||||
if k, ok := nameMap[key]; ok {
|
||||
key = k
|
||||
}
|
||||
|
||||
if key == "" {
|
||||
return nil, errors.New("invalid keyword/value")
|
||||
}
|
||||
|
||||
settings[key] = val
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
|
||||
servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
|
||||
}
|
||||
|
||||
service, err := servicefile.GetService(serviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to find service: %v", serviceName)
|
||||
}
|
||||
|
||||
nameMap := map[string]string{
|
||||
"dbname": "database",
|
||||
}
|
||||
|
||||
settings := make(map[string]string, len(service.Settings))
|
||||
for k, v := range service.Settings {
|
||||
if k2, present := nameMap[k]; present {
|
||||
k = k2
|
||||
}
|
||||
settings[k] = v
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
|
||||
// necessary to allow returning multiple TLS configs as sslmode "allow" and
|
||||
// "prefer" allow fallback.
|
||||
func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) {
|
||||
host := thisHost
|
||||
sslmode := settings["sslmode"]
|
||||
sslrootcert := settings["sslrootcert"]
|
||||
sslcert := settings["sslcert"]
|
||||
sslkey := settings["sslkey"]
|
||||
sslpassword := settings["sslpassword"]
|
||||
sslsni := settings["sslsni"]
|
||||
sslnegotiation := settings["sslnegotiation"]
|
||||
|
||||
// Match libpq default behavior
|
||||
if sslmode == "" {
|
||||
sslmode = "prefer"
|
||||
}
|
||||
if sslsni == "" {
|
||||
sslsni = "1"
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
if sslnegotiation == "direct" {
|
||||
tlsConfig.NextProtos = []string{"postgresql"}
|
||||
if sslmode == "prefer" {
|
||||
sslmode = "require"
|
||||
}
|
||||
}
|
||||
|
||||
if sslrootcert != "" {
|
||||
var caCertPool *x509.CertPool
|
||||
|
||||
if sslrootcert == "system" {
|
||||
var err error
|
||||
|
||||
caCertPool, err = x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load system certificate pool: %w", err)
|
||||
}
|
||||
|
||||
sslmode = "verify-full"
|
||||
} else {
|
||||
caCertPool = x509.NewCertPool()
|
||||
|
||||
caPath := sslrootcert
|
||||
caCert, err := os.ReadFile(caPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read CA file: %w", err)
|
||||
}
|
||||
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, errors.New("unable to add CA to cert pool")
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
tlsConfig.ClientCAs = caCertPool
|
||||
}
|
||||
|
||||
switch sslmode {
|
||||
case "disable":
|
||||
return []*tls.Config{nil}, nil
|
||||
case "allow", "prefer":
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
case "require":
|
||||
// According to PostgreSQL documentation, if a root CA file exists,
|
||||
// the behavior of sslmode=require should be the same as that of verify-ca
|
||||
//
|
||||
// See https://www.postgresql.org/docs/current/libpq-ssl.html
|
||||
if sslrootcert != "" {
|
||||
goto nextCase
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
break
|
||||
nextCase:
|
||||
fallthrough
|
||||
case "verify-ca":
|
||||
// Don't perform the default certificate verification because it
|
||||
// will verify the hostname. Instead, verify the server's
|
||||
// certificate chain ourselves in VerifyPeerCertificate and
|
||||
// ignore the server name. This emulates libpq's verify-ca
|
||||
// behavior.
|
||||
//
|
||||
// See https://github.com/golang/go/issues/21971#issuecomment-332693931
|
||||
// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
|
||||
// for more info.
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
|
||||
certs := make([]*x509.Certificate, len(certificates))
|
||||
for i, asn1Data := range certificates {
|
||||
cert, err := x509.ParseCertificate(asn1Data)
|
||||
if err != nil {
|
||||
return errors.New("failed to parse certificate from server: " + err.Error())
|
||||
}
|
||||
certs[i] = cert
|
||||
}
|
||||
|
||||
// Leave DNSName empty to skip hostname verification.
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: tlsConfig.RootCAs,
|
||||
Intermediates: x509.NewCertPool(),
|
||||
}
|
||||
// Skip the first cert because it's the leaf. All others
|
||||
// are intermediates.
|
||||
for _, cert := range certs[1:] {
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
_, err := certs[0].Verify(opts)
|
||||
return err
|
||||
}
|
||||
case "verify-full":
|
||||
tlsConfig.ServerName = host
|
||||
default:
|
||||
return nil, errors.New("sslmode is invalid")
|
||||
}
|
||||
|
||||
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
||||
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
|
||||
}
|
||||
|
||||
if sslcert != "" && sslkey != "" {
|
||||
buf, err := os.ReadFile(sslkey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read sslkey: %w", err)
|
||||
}
|
||||
block, _ := pem.Decode(buf)
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode sslkey")
|
||||
}
|
||||
var pemKey []byte
|
||||
var decryptedKey []byte
|
||||
var decryptedError error
|
||||
// If PEM is encrypted, attempt to decrypt using pass phrase
|
||||
if x509.IsEncryptedPEMBlock(block) {
|
||||
// Attempt decryption with pass phrase
|
||||
// NOTE: only supports RSA (PKCS#1)
|
||||
if sslpassword != "" {
|
||||
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||
}
|
||||
// if sslpassword not provided or has decryption error when use it
|
||||
// try to find sslpassword with callback function
|
||||
if sslpassword == "" || decryptedError != nil {
|
||||
if parseConfigOptions.GetSSLPassword != nil {
|
||||
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
|
||||
}
|
||||
if sslpassword == "" {
|
||||
return nil, fmt.Errorf("unable to find sslpassword")
|
||||
}
|
||||
}
|
||||
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||
// Should we also provide warning for PKCS#1 needed?
|
||||
if decryptedError != nil {
|
||||
return nil, fmt.Errorf("unable to decrypt key: %w", err)
|
||||
}
|
||||
|
||||
pemBytes := pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: decryptedKey,
|
||||
}
|
||||
pemKey = pem.EncodeToMemory(&pemBytes)
|
||||
} else {
|
||||
pemKey = pem.EncodeToMemory(block)
|
||||
}
|
||||
certfile, err := os.ReadFile(sslcert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read cert: %w", err)
|
||||
}
|
||||
cert, err := tls.X509KeyPair(certfile, pemKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load cert: %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
// Set Server Name Indication (SNI), if enabled by connection parameters.
|
||||
// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
|
||||
// or IPv6).
|
||||
if sslsni == "1" && net.ParseIP(host) == nil {
|
||||
tlsConfig.ServerName = host
|
||||
}
|
||||
|
||||
switch sslmode {
|
||||
case "allow":
|
||||
return []*tls.Config{nil, tlsConfig}, nil
|
||||
case "prefer":
|
||||
return []*tls.Config{tlsConfig, nil}, nil
|
||||
case "require", "verify-ca", "verify-full":
|
||||
return []*tls.Config{tlsConfig}, nil
|
||||
default:
|
||||
panic("BUG: bad sslmode should already have been caught")
|
||||
}
|
||||
}
|
||||
|
||||
func parsePort(s string) (uint16, error) {
|
||||
port, err := strconv.ParseUint(s, 10, 16)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if port < 1 || port > math.MaxUint16 {
|
||||
return 0, errors.New("outside range")
|
||||
}
|
||||
return uint16(port), nil
|
||||
}
|
||||
|
||||
func makeDefaultDialer() *net.Dialer {
|
||||
// rely on GOLANG KeepAlive settings
|
||||
return &net.Dialer{}
|
||||
}
|
||||
|
||||
func makeDefaultResolver() *net.Resolver {
|
||||
return net.DefaultResolver
|
||||
}
|
||||
|
||||
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
|
||||
timeout, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if timeout < 0 {
|
||||
return 0, errors.New("negative timeout")
|
||||
}
|
||||
return time.Duration(timeout) * time.Second, nil
|
||||
}
|
||||
|
||||
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
|
||||
d := makeDefaultDialer()
|
||||
d.Timeout = timeout
|
||||
return d.DialContext
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=read-write.
|
||||
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
|
||||
result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if string(result[0].Rows[0][0]) == "on" {
|
||||
return errors.New("read only connection")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=read-only.
|
||||
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
|
||||
result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if string(result[0].Rows[0][0]) != "on" {
|
||||
return errors.New("connection is not read only")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=standby.
|
||||
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
|
||||
result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if string(result[0].Rows[0][0]) != "t" {
|
||||
return errors.New("server is not in hot standby mode")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=primary.
|
||||
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
|
||||
result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if string(result[0].Rows[0][0]) == "t" {
|
||||
return errors.New("server is in standby mode")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=prefer-standby.
|
||||
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
|
||||
result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if string(result[0].Rows[0][0]) != "t" {
|
||||
return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,80 +0,0 @@
|
||||
package ctxwatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
|
||||
// time.
|
||||
type ContextWatcher struct {
|
||||
handler Handler
|
||||
unwatchChan chan struct{}
|
||||
|
||||
lock sync.Mutex
|
||||
watchInProgress bool
|
||||
onCancelWasCalled bool
|
||||
}
|
||||
|
||||
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
|
||||
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
|
||||
// onCancel called.
|
||||
func NewContextWatcher(handler Handler) *ContextWatcher {
|
||||
cw := &ContextWatcher{
|
||||
handler: handler,
|
||||
unwatchChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
return cw
|
||||
}
|
||||
|
||||
// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called.
|
||||
func (cw *ContextWatcher) Watch(ctx context.Context) {
|
||||
cw.lock.Lock()
|
||||
defer cw.lock.Unlock()
|
||||
|
||||
if cw.watchInProgress {
|
||||
panic("Watch already in progress")
|
||||
}
|
||||
|
||||
cw.onCancelWasCalled = false
|
||||
|
||||
if ctx.Done() != nil {
|
||||
cw.watchInProgress = true
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cw.handler.HandleCancel(ctx)
|
||||
cw.onCancelWasCalled = true
|
||||
<-cw.unwatchChan
|
||||
case <-cw.unwatchChan:
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
cw.watchInProgress = false
|
||||
}
|
||||
}
|
||||
|
||||
// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was
|
||||
// called then onUnwatchAfterCancel will also be called.
|
||||
func (cw *ContextWatcher) Unwatch() {
|
||||
cw.lock.Lock()
|
||||
defer cw.lock.Unlock()
|
||||
|
||||
if cw.watchInProgress {
|
||||
cw.unwatchChan <- struct{}{}
|
||||
if cw.onCancelWasCalled {
|
||||
cw.handler.HandleUnwatchAfterCancel()
|
||||
}
|
||||
cw.watchInProgress = false
|
||||
}
|
||||
}
|
||||
|
||||
type Handler interface {
|
||||
// HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the
|
||||
// context that was canceled.
|
||||
HandleCancel(canceledCtx context.Context)
|
||||
|
||||
// HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched.
|
||||
HandleUnwatchAfterCancel()
|
||||
}
|
@ -1,185 +0,0 @@
|
||||
package ctxwatch_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testHandler struct {
|
||||
handleCancel func(context.Context)
|
||||
handleUnwatchAfterCancel func()
|
||||
}
|
||||
|
||||
func (h *testHandler) HandleCancel(ctx context.Context) {
|
||||
h.handleCancel(ctx)
|
||||
}
|
||||
|
||||
func (h *testHandler) HandleUnwatchAfterCancel() {
|
||||
h.handleUnwatchAfterCancel()
|
||||
}
|
||||
|
||||
func TestContextWatcherContextCancelled(t *testing.T) {
|
||||
canceledChan := make(chan struct{})
|
||||
cleanupCalled := false
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{
|
||||
handleCancel: func(context.Context) {
|
||||
canceledChan <- struct{}{}
|
||||
}, handleUnwatchAfterCancel: func() {
|
||||
cleanupCalled = true
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cw.Watch(ctx)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-canceledChan:
|
||||
case <-time.NewTimer(time.Second).C:
|
||||
t.Fatal("Timed out waiting for cancel func to be called")
|
||||
}
|
||||
|
||||
cw.Unwatch()
|
||||
|
||||
require.True(t, cleanupCalled, "Cleanup func was not called")
|
||||
}
|
||||
|
||||
func TestContextWatcherUnwatchedBeforeContextCancelled(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{
|
||||
handleCancel: func(context.Context) {
|
||||
t.Error("cancel func should not have been called")
|
||||
}, handleUnwatchAfterCancel: func() {
|
||||
t.Error("cleanup func should not have been called")
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cw.Watch(ctx)
|
||||
cw.Unwatch()
|
||||
cancel()
|
||||
}
|
||||
|
||||
func TestContextWatcherMultipleWatchPanics(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
cw.Watch(ctx)
|
||||
defer cw.Unwatch()
|
||||
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel2()
|
||||
require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times")
|
||||
}
|
||||
|
||||
func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
cw.Unwatch() // unwatch when not / never watching
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
cw.Watch(ctx)
|
||||
cw.Unwatch()
|
||||
cw.Unwatch() // double unwatch
|
||||
}
|
||||
|
||||
func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
cw.Watch(ctx)
|
||||
|
||||
go cw.Unwatch()
|
||||
go cw.Unwatch()
|
||||
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
func TestContextWatcherStress(t *testing.T) {
|
||||
var cancelFuncCalls int64
|
||||
var cleanupFuncCalls int64
|
||||
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{
|
||||
handleCancel: func(context.Context) {
|
||||
atomic.AddInt64(&cancelFuncCalls, 1)
|
||||
}, handleUnwatchAfterCancel: func() {
|
||||
atomic.AddInt64(&cleanupFuncCalls, 1)
|
||||
},
|
||||
})
|
||||
|
||||
cycleCount := 100000
|
||||
|
||||
for i := 0; i < cycleCount; i++ {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cw.Watch(ctx)
|
||||
if i%2 == 0 {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix.
|
||||
if i%333 == 0 {
|
||||
// on Windows Sleep takes more time than expected so we try to get here less frequently to avoid
|
||||
// the CI takes a long time
|
||||
time.Sleep(time.Nanosecond)
|
||||
}
|
||||
|
||||
cw.Unwatch()
|
||||
if i%2 == 1 {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls)
|
||||
actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls)
|
||||
|
||||
if actualCancelFuncCalls == 0 {
|
||||
t.Fatal("actualCancelFuncCalls == 0")
|
||||
}
|
||||
|
||||
maxCancelFuncCalls := int64(cycleCount) / 2
|
||||
if actualCancelFuncCalls > maxCancelFuncCalls {
|
||||
t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls)
|
||||
}
|
||||
|
||||
if actualCancelFuncCalls != actualCleanupFuncCalls {
|
||||
t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkContextWatcherUncancellable(b *testing.B) {
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
cw.Watch(context.Background())
|
||||
cw.Unwatch()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkContextWatcherCancelled(b *testing.B) {
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cw.Watch(ctx)
|
||||
cancel()
|
||||
cw.Unwatch()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkContextWatcherCancellable(b *testing.B) {
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
cw.Watch(ctx)
|
||||
cw.Unwatch()
|
||||
}
|
||||
}
|
@ -1,63 +0,0 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func defaultSettings() map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
settings["host"] = defaultHost()
|
||||
settings["port"] = "5432"
|
||||
|
||||
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||
// OS. The client application will simply have to specify the user in that
|
||||
// case (which they typically will be doing anyway).
|
||||
user, err := user.Current()
|
||||
if err == nil {
|
||||
settings["user"] = user.Username
|
||||
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
|
||||
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
|
||||
sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
|
||||
sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
|
||||
if _, err := os.Stat(sslcert); err == nil {
|
||||
if _, err := os.Stat(sslkey); err == nil {
|
||||
// Both the cert and key must be present to use them, or do not use either
|
||||
settings["sslcert"] = sslcert
|
||||
settings["sslkey"] = sslkey
|
||||
}
|
||||
}
|
||||
sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt")
|
||||
if _, err := os.Stat(sslrootcert); err == nil {
|
||||
settings["sslrootcert"] = sslrootcert
|
||||
}
|
||||
}
|
||||
|
||||
settings["target_session_attrs"] = "any"
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||
// checks the existence of common locations.
|
||||
func defaultHost() string {
|
||||
candidatePaths := []string{
|
||||
"/var/run/postgresql", // Debian
|
||||
"/private/tmp", // OSX - homebrew
|
||||
"/tmp", // standard PostgreSQL
|
||||
}
|
||||
|
||||
for _, path := range candidatePaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path
|
||||
}
|
||||
}
|
||||
|
||||
return "localhost"
|
||||
}
|
@ -1,57 +0,0 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func defaultSettings() map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
settings["host"] = defaultHost()
|
||||
settings["port"] = "5432"
|
||||
|
||||
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||
// OS. The client application will simply have to specify the user in that
|
||||
// case (which they typically will be doing anyway).
|
||||
user, err := user.Current()
|
||||
appData := os.Getenv("APPDATA")
|
||||
if err == nil {
|
||||
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
|
||||
// but the libpq default is just the `user` portion, so we strip off the first part.
|
||||
username := user.Username
|
||||
if strings.Contains(username, "\\") {
|
||||
username = username[strings.LastIndex(username, "\\")+1:]
|
||||
}
|
||||
|
||||
settings["user"] = username
|
||||
settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf")
|
||||
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
|
||||
sslcert := filepath.Join(appData, "postgresql", "postgresql.crt")
|
||||
sslkey := filepath.Join(appData, "postgresql", "postgresql.key")
|
||||
if _, err := os.Stat(sslcert); err == nil {
|
||||
if _, err := os.Stat(sslkey); err == nil {
|
||||
// Both the cert and key must be present to use them, or do not use either
|
||||
settings["sslcert"] = sslcert
|
||||
settings["sslkey"] = sslkey
|
||||
}
|
||||
}
|
||||
sslrootcert := filepath.Join(appData, "postgresql", "root.crt")
|
||||
if _, err := os.Stat(sslrootcert); err == nil {
|
||||
settings["sslrootcert"] = sslrootcert
|
||||
}
|
||||
}
|
||||
|
||||
settings["target_session_attrs"] = "any"
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||
// checks the existence of common locations.
|
||||
func defaultHost() string {
|
||||
return "localhost"
|
||||
}
|
@ -1,38 +0,0 @@
|
||||
// Package pgconn is a low-level PostgreSQL database driver.
|
||||
/*
|
||||
pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at
|
||||
nearly the same level is the C library libpq.
|
||||
|
||||
Establishing a Connection
|
||||
|
||||
Use Connect to establish a connection. It accepts a connection string in URL or keyword/value format and will read the
|
||||
environment for libpq style environment variables.
|
||||
|
||||
Executing a Query
|
||||
|
||||
ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method
|
||||
reads all rows into memory.
|
||||
|
||||
Executing Multiple Queries in a Single Round Trip
|
||||
|
||||
Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
|
||||
result. The ReadAll method reads all query results into memory.
|
||||
|
||||
Pipeline Mode
|
||||
|
||||
Pipeline mode allows sending queries without having read the results of previously sent queries. It allows control of
|
||||
exactly how many and when network round trips occur.
|
||||
|
||||
Context Support
|
||||
|
||||
All potentially blocking operations take a context.Context. The default behavior when a context is canceled is for the
|
||||
method to immediately return. In most circumstances, this will also close the underlying connection. This behavior can
|
||||
be customized by using BuildContextWatcherHandler on the Config to create a ctxwatch.Handler with different behavior.
|
||||
This can be especially useful when queries that are frequently canceled and the overhead of creating new connections is
|
||||
a problem. DeadlineContextWatcherHandler and CancelRequestContextWatcherHandler can be used to introduce a delay before
|
||||
interrupting the query in such a way as to close the connection.
|
||||
|
||||
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
|
||||
client to abort.
|
||||
*/
|
||||
package pgconn
|
256
pgconn/errors.go
256
pgconn/errors.go
@ -1,256 +0,0 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
|
||||
func SafeToRetry(err error) bool {
|
||||
var retryableErr interface{ SafeToRetry() bool }
|
||||
if errors.As(err, &retryableErr) {
|
||||
return retryableErr.SafeToRetry()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Timeout checks if err was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
|
||||
// context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
|
||||
func Timeout(err error) bool {
|
||||
var timeoutErr *errTimeout
|
||||
return errors.As(err, &timeoutErr)
|
||||
}
|
||||
|
||||
// PgError represents an error reported by the PostgreSQL server. See
|
||||
// http://www.postgresql.org/docs/current/static/protocol-error-fields.html for
|
||||
// detailed field description.
|
||||
type PgError struct {
|
||||
Severity string
|
||||
SeverityUnlocalized string
|
||||
Code string
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position int32
|
||||
InternalPosition int32
|
||||
InternalQuery string
|
||||
Where string
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
DataTypeName string
|
||||
ConstraintName string
|
||||
File string
|
||||
Line int32
|
||||
Routine string
|
||||
}
|
||||
|
||||
func (pe *PgError) Error() string {
|
||||
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||
}
|
||||
|
||||
// SQLState returns the SQLState of the error.
|
||||
func (pe *PgError) SQLState() string {
|
||||
return pe.Code
|
||||
}
|
||||
|
||||
// ConnectError is the error returned when a connection attempt fails.
|
||||
type ConnectError struct {
|
||||
Config *Config // The configuration that was used in the connection attempt.
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *ConnectError) Error() string {
|
||||
prefix := fmt.Sprintf("failed to connect to `user=%s database=%s`:", e.Config.User, e.Config.Database)
|
||||
details := e.err.Error()
|
||||
if strings.Contains(details, "\n") {
|
||||
return prefix + "\n\t" + strings.ReplaceAll(details, "\n", "\n\t")
|
||||
} else {
|
||||
return prefix + " " + details
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ConnectError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type perDialConnectError struct {
|
||||
address string
|
||||
originalHostname string
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *perDialConnectError) Error() string {
|
||||
return fmt.Sprintf("%s (%s): %s", e.address, e.originalHostname, e.err.Error())
|
||||
}
|
||||
|
||||
func (e *perDialConnectError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type connLockError struct {
|
||||
status string
|
||||
}
|
||||
|
||||
func (e *connLockError) SafeToRetry() bool {
|
||||
return true // a lock failure by definition happens before the connection is used.
|
||||
}
|
||||
|
||||
func (e *connLockError) Error() string {
|
||||
return e.status
|
||||
}
|
||||
|
||||
// ParseConfigError is the error returned when a connection string cannot be parsed.
|
||||
type ParseConfigError struct {
|
||||
ConnString string // The connection string that could not be parsed.
|
||||
msg string
|
||||
err error
|
||||
}
|
||||
|
||||
func NewParseConfigError(conn, msg string, err error) error {
|
||||
return &ParseConfigError{
|
||||
ConnString: conn,
|
||||
msg: msg,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ParseConfigError) Error() string {
|
||||
// Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only
|
||||
// return a static string. That would ensure that the error message cannot leak a password. The ConnString field would
|
||||
// allow access to the original string if desired and Unwrap would allow access to the underlying error.
|
||||
connString := redactPW(e.ConnString)
|
||||
if e.err == nil {
|
||||
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
|
||||
}
|
||||
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
|
||||
}
|
||||
|
||||
func (e *ParseConfigError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
func normalizeTimeoutError(ctx context.Context, err error) error {
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
if ctx.Err() == context.Canceled {
|
||||
// Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error.
|
||||
return context.Canceled
|
||||
} else if ctx.Err() == context.DeadlineExceeded {
|
||||
return &errTimeout{err: ctx.Err()}
|
||||
} else {
|
||||
return &errTimeout{err: netErr}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type pgconnError struct {
|
||||
msg string
|
||||
err error
|
||||
safeToRetry bool
|
||||
}
|
||||
|
||||
func (e *pgconnError) Error() string {
|
||||
if e.msg == "" {
|
||||
return e.err.Error()
|
||||
}
|
||||
if e.err == nil {
|
||||
return e.msg
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
|
||||
}
|
||||
|
||||
func (e *pgconnError) SafeToRetry() bool {
|
||||
return e.safeToRetry
|
||||
}
|
||||
|
||||
func (e *pgconnError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is
|
||||
// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true.
|
||||
type errTimeout struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *errTimeout) Error() string {
|
||||
return fmt.Sprintf("timeout: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *errTimeout) SafeToRetry() bool {
|
||||
return SafeToRetry(e.err)
|
||||
}
|
||||
|
||||
func (e *errTimeout) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type contextAlreadyDoneError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) Error() string {
|
||||
return fmt.Sprintf("context already done: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) SafeToRetry() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`.
|
||||
func newContextAlreadyDoneError(ctx context.Context) (err error) {
|
||||
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
|
||||
}
|
||||
|
||||
func redactPW(connString string) string {
|
||||
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||
if u, err := url.Parse(connString); err == nil {
|
||||
return redactURL(u)
|
||||
}
|
||||
}
|
||||
quotedKV := regexp.MustCompile(`password='[^']*'`)
|
||||
connString = quotedKV.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||
plainKV := regexp.MustCompile(`password=[^ ]*`)
|
||||
connString = plainKV.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||
brokenURL := regexp.MustCompile(`:[^:@]+?@`)
|
||||
connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
|
||||
return connString
|
||||
}
|
||||
|
||||
func redactURL(u *url.URL) string {
|
||||
if u == nil {
|
||||
return ""
|
||||
}
|
||||
if _, pwSet := u.User.Password(); pwSet {
|
||||
u.User = url.UserPassword(u.User.Username(), "xxxxx")
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
type NotPreferredError struct {
|
||||
err error
|
||||
safeToRetry bool
|
||||
}
|
||||
|
||||
func (e *NotPreferredError) Error() string {
|
||||
return fmt.Sprintf("standby server not found: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *NotPreferredError) SafeToRetry() bool {
|
||||
return e.safeToRetry
|
||||
}
|
||||
|
||||
func (e *NotPreferredError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
@ -1,54 +0,0 @@
|
||||
package pgconn_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfigError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expectedMsg string
|
||||
}{
|
||||
{
|
||||
name: "url with password",
|
||||
err: pgconn.NewParseConfigError("postgresql://foo:password@host", "msg", nil),
|
||||
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg",
|
||||
},
|
||||
{
|
||||
name: "keyword/value with password unquoted",
|
||||
err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil),
|
||||
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
|
||||
},
|
||||
{
|
||||
name: "keyword/value with password quoted",
|
||||
err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil),
|
||||
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
|
||||
},
|
||||
{
|
||||
name: "weird url",
|
||||
err: pgconn.NewParseConfigError("postgresql://foo::password@host:1:", "msg", nil),
|
||||
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg",
|
||||
},
|
||||
{
|
||||
name: "weird url with slash in password",
|
||||
err: pgconn.NewParseConfigError("postgres://user:pass/word@host:5432/db_name", "msg", nil),
|
||||
expectedMsg: "cannot parse `postgres://user:xxxxxx@host:5432/db_name`: msg",
|
||||
},
|
||||
{
|
||||
name: "url without password",
|
||||
err: pgconn.NewParseConfigError("postgresql://other@host/db", "msg", nil),
|
||||
expectedMsg: "cannot parse `postgresql://other@host/db`: msg",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.EqualError(t, tt.err, tt.expectedMsg)
|
||||
})
|
||||
}
|
||||
}
|
@ -1,3 +0,0 @@
|
||||
// File export_test exports some methods for better testing.
|
||||
|
||||
package pgconn
|
@ -1,36 +0,0 @@
|
||||
package pgconn_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func closeConn(t testing.TB, conn *pgconn.PgConn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
require.NoError(t, conn.Close(ctx))
|
||||
select {
|
||||
case <-conn.CleanupDone():
|
||||
case <-time.After(30 * time.Second):
|
||||
t.Fatal("Connection cleanup exceeded maximum time")
|
||||
}
|
||||
}
|
||||
|
||||
// Do a simple query to ensure the connection is still usable
|
||||
func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read()
|
||||
cancel()
|
||||
|
||||
require.Nil(t, result.Err)
|
||||
assert.Equal(t, 3, len(result.Rows))
|
||||
assert.Equal(t, "1", string(result.Rows[0][0]))
|
||||
assert.Equal(t, "2", string(result.Rows[1][0]))
|
||||
assert.Equal(t, "3", string(result.Rows[2][0]))
|
||||
}
|
@ -1,139 +0,0 @@
|
||||
// Package bgreader provides a io.Reader that can optionally buffer reads in the background.
|
||||
package bgreader
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
)
|
||||
|
||||
const (
|
||||
StatusStopped = iota
|
||||
StatusRunning
|
||||
StatusStopping
|
||||
)
|
||||
|
||||
// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use.
|
||||
type BGReader struct {
|
||||
r io.Reader
|
||||
|
||||
cond *sync.Cond
|
||||
status int32
|
||||
readResults []readResult
|
||||
}
|
||||
|
||||
type readResult struct {
|
||||
buf *[]byte
|
||||
err error
|
||||
}
|
||||
|
||||
// Start starts the backgrounder reader. If the background reader is already running this is a no-op. The background
|
||||
// reader will stop automatically when the underlying reader returns an error.
|
||||
func (r *BGReader) Start() {
|
||||
r.cond.L.Lock()
|
||||
defer r.cond.L.Unlock()
|
||||
|
||||
switch r.status {
|
||||
case StatusStopped:
|
||||
r.status = StatusRunning
|
||||
go r.bgRead()
|
||||
case StatusRunning:
|
||||
// no-op
|
||||
case StatusStopping:
|
||||
r.status = StatusRunning
|
||||
}
|
||||
}
|
||||
|
||||
// Stop tells the background reader to stop after the in progress Read returns. It is safe to call Stop when the
|
||||
// background reader is not running.
|
||||
func (r *BGReader) Stop() {
|
||||
r.cond.L.Lock()
|
||||
defer r.cond.L.Unlock()
|
||||
|
||||
switch r.status {
|
||||
case StatusStopped:
|
||||
// no-op
|
||||
case StatusRunning:
|
||||
r.status = StatusStopping
|
||||
case StatusStopping:
|
||||
// no-op
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the current status of the background reader.
|
||||
func (r *BGReader) Status() int32 {
|
||||
r.cond.L.Lock()
|
||||
defer r.cond.L.Unlock()
|
||||
return r.status
|
||||
}
|
||||
|
||||
func (r *BGReader) bgRead() {
|
||||
keepReading := true
|
||||
for keepReading {
|
||||
buf := iobufpool.Get(8192)
|
||||
n, err := r.r.Read(*buf)
|
||||
*buf = (*buf)[:n]
|
||||
|
||||
r.cond.L.Lock()
|
||||
r.readResults = append(r.readResults, readResult{buf: buf, err: err})
|
||||
if r.status == StatusStopping || err != nil {
|
||||
r.status = StatusStopped
|
||||
keepReading = false
|
||||
}
|
||||
r.cond.L.Unlock()
|
||||
r.cond.Broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
// Read implements the io.Reader interface.
|
||||
func (r *BGReader) Read(p []byte) (int, error) {
|
||||
r.cond.L.Lock()
|
||||
defer r.cond.L.Unlock()
|
||||
|
||||
if len(r.readResults) > 0 {
|
||||
return r.readFromReadResults(p)
|
||||
}
|
||||
|
||||
// There are no unread background read results and the background reader is stopped.
|
||||
if r.status == StatusStopped {
|
||||
return r.r.Read(p)
|
||||
}
|
||||
|
||||
// Wait for results from the background reader
|
||||
for len(r.readResults) == 0 {
|
||||
r.cond.Wait()
|
||||
}
|
||||
return r.readFromReadResults(p)
|
||||
}
|
||||
|
||||
// readBackgroundResults reads a result previously read by the background reader. r.cond.L must be held.
|
||||
func (r *BGReader) readFromReadResults(p []byte) (int, error) {
|
||||
buf := r.readResults[0].buf
|
||||
var err error
|
||||
|
||||
n := copy(p, *buf)
|
||||
if n == len(*buf) {
|
||||
err = r.readResults[0].err
|
||||
iobufpool.Put(buf)
|
||||
if len(r.readResults) == 1 {
|
||||
r.readResults = nil
|
||||
} else {
|
||||
r.readResults = r.readResults[1:]
|
||||
}
|
||||
} else {
|
||||
*buf = (*buf)[n:]
|
||||
r.readResults[0].buf = buf
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func New(r io.Reader) *BGReader {
|
||||
return &BGReader{
|
||||
r: r,
|
||||
cond: &sync.Cond{
|
||||
L: &sync.Mutex{},
|
||||
},
|
||||
}
|
||||
}
|
@ -1,140 +0,0 @@
|
||||
package bgreader_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBGReaderReadWhenStopped(t *testing.T) {
|
||||
r := bytes.NewReader([]byte("foo bar baz"))
|
||||
bgr := bgreader.New(r)
|
||||
buf, err := io.ReadAll(bgr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("foo bar baz"), buf)
|
||||
}
|
||||
|
||||
func TestBGReaderReadWhenStarted(t *testing.T) {
|
||||
r := bytes.NewReader([]byte("foo bar baz"))
|
||||
bgr := bgreader.New(r)
|
||||
bgr.Start()
|
||||
buf, err := io.ReadAll(bgr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("foo bar baz"), buf)
|
||||
}
|
||||
|
||||
type mockReadFunc func(p []byte) (int, error)
|
||||
|
||||
type mockReader struct {
|
||||
readFuncs []mockReadFunc
|
||||
}
|
||||
|
||||
func (r *mockReader) Read(p []byte) (int, error) {
|
||||
if len(r.readFuncs) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
fn := r.readFuncs[0]
|
||||
r.readFuncs = r.readFuncs[1:]
|
||||
|
||||
return fn(p)
|
||||
}
|
||||
|
||||
func TestBGReaderReadWaitsForBackgroundRead(t *testing.T) {
|
||||
rr := &mockReader{
|
||||
readFuncs: []mockReadFunc{
|
||||
func(p []byte) (int, error) { time.Sleep(1 * time.Second); return copy(p, []byte("foo")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("baz")), nil },
|
||||
},
|
||||
}
|
||||
bgr := bgreader.New(rr)
|
||||
bgr.Start()
|
||||
buf := make([]byte, 3)
|
||||
n, err := bgr.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 3, n)
|
||||
require.Equal(t, []byte("foo"), buf)
|
||||
}
|
||||
|
||||
func TestBGReaderErrorWhenStarted(t *testing.T) {
|
||||
rr := &mockReader{
|
||||
readFuncs: []mockReadFunc{
|
||||
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
|
||||
},
|
||||
}
|
||||
|
||||
bgr := bgreader.New(rr)
|
||||
bgr.Start()
|
||||
buf, err := io.ReadAll(bgr)
|
||||
require.Equal(t, []byte("foobarbaz"), buf)
|
||||
require.EqualError(t, err, "oops")
|
||||
}
|
||||
|
||||
func TestBGReaderErrorWhenStopped(t *testing.T) {
|
||||
rr := &mockReader{
|
||||
readFuncs: []mockReadFunc{
|
||||
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
|
||||
},
|
||||
}
|
||||
|
||||
bgr := bgreader.New(rr)
|
||||
buf, err := io.ReadAll(bgr)
|
||||
require.Equal(t, []byte("foobarbaz"), buf)
|
||||
require.EqualError(t, err, "oops")
|
||||
}
|
||||
|
||||
type numberReader struct {
|
||||
v uint8
|
||||
rng *rand.Rand
|
||||
}
|
||||
|
||||
func (nr *numberReader) Read(p []byte) (int, error) {
|
||||
n := nr.rng.Intn(len(p))
|
||||
for i := 0; i < n; i++ {
|
||||
p[i] = nr.v
|
||||
nr.v++
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// TestBGReaderStress stress tests BGReader by reading a lot of bytes in random sizes while randomly starting and
|
||||
// stopping the background worker from other goroutines.
|
||||
func TestBGReaderStress(t *testing.T) {
|
||||
nr := &numberReader{rng: rand.New(rand.NewSource(0))}
|
||||
bgr := bgreader.New(nr)
|
||||
|
||||
bytesRead := 0
|
||||
var expected uint8
|
||||
buf := make([]byte, 10_000)
|
||||
rng := rand.New(rand.NewSource(0))
|
||||
|
||||
for bytesRead < 1_000_000 {
|
||||
randomNumber := rng.Intn(100)
|
||||
switch {
|
||||
case randomNumber < 10:
|
||||
go bgr.Start()
|
||||
case randomNumber < 20:
|
||||
go bgr.Stop()
|
||||
default:
|
||||
n, err := bgr.Read(buf)
|
||||
require.NoError(t, err)
|
||||
for i := 0; i < n; i++ {
|
||||
require.Equal(t, expected, buf[i])
|
||||
expected++
|
||||
}
|
||||
bytesRead += n
|
||||
}
|
||||
}
|
||||
}
|
100
pgconn/krb5.go
100
pgconn/krb5.go
@ -1,100 +0,0 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
)
|
||||
|
||||
// NewGSSFunc creates a GSS authentication provider, for use with
|
||||
// RegisterGSSProvider.
|
||||
type NewGSSFunc func() (GSS, error)
|
||||
|
||||
var newGSS NewGSSFunc
|
||||
|
||||
// RegisterGSSProvider registers a GSS authentication provider. For example, if
|
||||
// you need to use Kerberos to authenticate with your server, add this to your
|
||||
// main package:
|
||||
//
|
||||
// import "github.com/otan/gopgkrb5"
|
||||
//
|
||||
// func init() {
|
||||
// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
|
||||
// }
|
||||
func RegisterGSSProvider(newGSSArg NewGSSFunc) {
|
||||
newGSS = newGSSArg
|
||||
}
|
||||
|
||||
// GSS provides GSSAPI authentication (e.g., Kerberos).
|
||||
type GSS interface {
|
||||
GetInitToken(host, service string) ([]byte, error)
|
||||
GetInitTokenFromSPN(spn string) ([]byte, error)
|
||||
Continue(inToken []byte) (done bool, outToken []byte, err error)
|
||||
}
|
||||
|
||||
func (c *PgConn) gssAuth() error {
|
||||
if newGSS == nil {
|
||||
return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
|
||||
}
|
||||
cli, err := newGSS()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var nextData []byte
|
||||
if c.config.KerberosSpn != "" {
|
||||
// Use the supplied SPN if provided.
|
||||
nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
|
||||
} else {
|
||||
// Allow the kerberos service name to be overridden
|
||||
service := "postgres"
|
||||
if c.config.KerberosSrvName != "" {
|
||||
service = c.config.KerberosSrvName
|
||||
}
|
||||
nextData, err = cli.GetInitToken(c.config.Host, service)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
gssResponse := &pgproto3.GSSResponse{
|
||||
Data: nextData,
|
||||
}
|
||||
c.frontend.Send(gssResponse)
|
||||
err = c.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := c.rxGSSContinue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var done bool
|
||||
done, nextData, err = cli.Continue(resp.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
|
||||
msg, err := c.receiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch m := msg.(type) {
|
||||
case *pgproto3.AuthenticationGSSContinue:
|
||||
return m, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return nil, ErrorResponseToPgError(m)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
|
||||
}
|
2504
pgconn/pgconn.go
2504
pgconn/pgconn.go
File diff suppressed because it is too large
Load Diff
@ -1,41 +0,0 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCommandTag(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
commandTag CommandTag
|
||||
rowsAffected int64
|
||||
isInsert bool
|
||||
isUpdate bool
|
||||
isDelete bool
|
||||
isSelect bool
|
||||
}{
|
||||
{commandTag: CommandTag{s: "INSERT 0 5"}, rowsAffected: 5, isInsert: true},
|
||||
{commandTag: CommandTag{s: "UPDATE 0"}, rowsAffected: 0, isUpdate: true},
|
||||
{commandTag: CommandTag{s: "UPDATE 1"}, rowsAffected: 1, isUpdate: true},
|
||||
{commandTag: CommandTag{s: "DELETE 0"}, rowsAffected: 0, isDelete: true},
|
||||
{commandTag: CommandTag{s: "DELETE 1"}, rowsAffected: 1, isDelete: true},
|
||||
{commandTag: CommandTag{s: "DELETE 1234567890"}, rowsAffected: 1234567890, isDelete: true},
|
||||
{commandTag: CommandTag{s: "SELECT 1"}, rowsAffected: 1, isSelect: true},
|
||||
{commandTag: CommandTag{s: "SELECT 99999999999"}, rowsAffected: 99999999999, isSelect: true},
|
||||
{commandTag: CommandTag{s: "CREATE TABLE"}, rowsAffected: 0},
|
||||
{commandTag: CommandTag{s: "ALTER TABLE"}, rowsAffected: 0},
|
||||
{commandTag: CommandTag{s: "DROP TABLE"}, rowsAffected: 0},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
ct := tt.commandTag
|
||||
assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag)
|
||||
assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag)
|
||||
assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag)
|
||||
assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag)
|
||||
assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag)
|
||||
}
|
||||
}
|
@ -1,90 +0,0 @@
|
||||
package pgconn_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnStress(t *testing.T) {
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
actionCount := 10000
|
||||
if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" {
|
||||
stressFactor, err := strconv.ParseInt(s, 10, 64)
|
||||
require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR")
|
||||
actionCount *= int(stressFactor)
|
||||
}
|
||||
|
||||
setupStressDB(t, pgConn)
|
||||
|
||||
actions := []struct {
|
||||
name string
|
||||
fn func(*pgconn.PgConn) error
|
||||
}{
|
||||
{"Exec Select", stressExecSelect},
|
||||
{"ExecParams Select", stressExecParamsSelect},
|
||||
{"Batch", stressBatch},
|
||||
}
|
||||
|
||||
for i := 0; i < actionCount; i++ {
|
||||
action := actions[rand.Intn(len(actions))]
|
||||
err := action.fn(pgConn)
|
||||
require.Nilf(t, err, "%d: %s", i, action.name)
|
||||
}
|
||||
|
||||
// Each call with a context starts a goroutine. Ensure they are cleaned up when context is not canceled.
|
||||
numGoroutine := runtime.NumGoroutine()
|
||||
require.Truef(t, numGoroutine < 1000, "goroutines appear to be orphaned: %d in process", numGoroutine)
|
||||
}
|
||||
|
||||
func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) {
|
||||
_, err := pgConn.Exec(context.Background(), `
|
||||
create temporary table widgets(
|
||||
id serial primary key,
|
||||
name varchar not null,
|
||||
description text,
|
||||
creation_time timestamptz default now()
|
||||
);
|
||||
|
||||
insert into widgets(name, description) values
|
||||
('Foo', 'bar'),
|
||||
('baz', 'Something really long Something really long Something really long Something really long Something really long'),
|
||||
('a', 'b')`).ReadAll()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func stressExecSelect(pgConn *pgconn.PgConn) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err := pgConn.Exec(ctx, "select * from widgets").ReadAll()
|
||||
return err
|
||||
}
|
||||
|
||||
func stressExecParamsSelect(pgConn *pgconn.PgConn) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
result := pgConn.ExecParams(ctx, "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read()
|
||||
return result.Err
|
||||
}
|
||||
|
||||
func stressBatch(pgConn *pgconn.PgConn) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
batch.ExecParams("select * from widgets", nil, nil, nil, nil)
|
||||
batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil)
|
||||
_, err := pgConn.ExecBatch(ctx, batch).ReadAll()
|
||||
return err
|
||||
}
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user