Compare commits


84 Commits

Author SHA1 Message Date
Jack Christensen 0f77a2d028
Merge pull request #2293 from divyam234/master
feat: add support for direct sslnegotiation
2025-03-31 08:13:19 -05:00
divyam234 ddd966f09f
update 2025-03-31 15:06:55 +02:00
divyam234 924834b5b4
add pgmock tests 2025-03-31 15:02:07 +02:00
divyam234 9b15554c51
respect sslmode set by user 2025-03-30 16:35:43 +02:00
divyam234 037e4cf9a2
feat: add support for direct sslnegotiation 2025-03-30 16:21:52 +02:00
Jack Christensen 04bcc0219d Add v5.7.4 to changelog 2025-03-24 20:04:45 -05:00
Jack Christensen 0e0a7d8344
Merge pull request #2288 from felix-roehrich/fr/fix-plan-scan
Revert change in `if` from #2236.
2025-03-24 19:46:22 -05:00
Felix Röhrich 63422c7d6c revert change in if 2025-03-24 15:01:50 +01:00
Jack Christensen 5c1fbf4806 Update changelog for v5.7.3 2025-03-21 21:02:03 -05:00
Jack Christensen 05fe5f8b05 Explain seemingly redundant rows.Close() in CollectOneRow
fixes https://github.com/jackc/pgx/issues/2282
2025-03-21 20:33:32 -05:00
Jack Christensen 70c9a147a2
Merge pull request #2279 from djahandarie/min-idle-conns
Add MinIdleConns
2025-03-21 20:25:19 -05:00
Darius Jahandarie 6603ddfbe4
add MinIdleConns 2025-03-15 19:14:26 +09:00
Jack Christensen 70f7cad222 Add link to https://github.com/Arlandaren/pgxWrappy 2025-02-28 20:59:28 -06:00
Jack Christensen 6bf1b0b1b9 Add database/sql to overview of scanning 2025-02-22 08:42:26 -06:00
Jack Christensen 14bda65a0c Correct pgtype docs 2025-02-22 08:34:31 -06:00
Jack Christensen 9e3c4fb40f
Merge pull request #2257 from felix-roehrich/fr/change-connect-logic
Change connection logic to be more forgiving
2025-02-19 07:36:35 -06:00
Felix Röhrich 05e72a5ab1 make connection logic more forgiving 2025-02-17 21:24:38 +01:00
Jack Christensen 47d631e34b Added missed change to v5.7.2 changelog 2025-02-08 10:21:39 -06:00
Jack Christensen 58b05f567c Add https://github.com/nikolayk812/pgx-outbox to README.md
fixes https://github.com/jackc/pgx/issues/2239
2025-01-25 08:59:52 -06:00
Jack Christensen dcb7193669
Merge pull request #2236 from felix-roehrich/fr/fix-plan-scan
Alternative implementation for JSONCodec.PlanScan
2025-01-25 08:56:38 -06:00
Jack Christensen 1abf7d9050
Merge pull request #2240 from bonnefoa/fix-watch-panic
Unwatch and close connection on a batch write error
2025-01-25 08:38:33 -06:00
Jack Christensen b5efc90a32
Merge pull request #2028 from jackc/enable-composite-tests-on-cockroachdb
Enable composite tests on cockroachdb
2025-01-25 08:22:32 -06:00
Jack Christensen a26c93551f Skip TestCompositeCodecTranscodeStructWrapperForTable 2025-01-25 08:15:40 -06:00
Jack Christensen 2100e1da46 Use latest version of CockroachDB for CI 2025-01-25 08:04:42 -06:00
Jack Christensen 2d21a2b80d
Merge pull request #2228 from jackc/fix-xml-decode-value
XMLCodec: fix DecodeValue to return a []byte
2025-01-25 07:24:30 -06:00
Jack Christensen 5f33ee5f07 Call out []byte in QueryExecModeSimpleProtocol documentation
https://github.com/jackc/pgx/issues/2231
2025-01-25 07:15:02 -06:00
Anthonin Bonnefoy 228cfffc20 Unwatch and close connection on a batch write error
Previously, a conn.Write error would simply unlock pgconn, leaving the
connection as Idle and reusable while the multiResultReader would be
closed. From this state, calling multiResultReader.Close won't try to
receiveMessage and thus won't unwatch and close the connection, since the
reader is already closed. This leaves the connection "open", and the next
time it's used a "Watch already in progress" panic could be triggered.

This patch fixes the issue by unwatching and closing the connection on a
batch write error (a sketch of the resulting error path follows this commit
list). The same was done on Sync.Encode error, even though that path is
unreachable as Sync.Encode never returns an error.
2025-01-24 08:49:07 +01:00
Felix Röhrich a5353af354 rework JSONCodec.PlanScan 2025-01-22 22:35:35 +01:00
Jack Christensen 0bc29e3000
Merge pull request #2225 from logicbomb/improve-error-message
Include the field name in error messages when scanning structs
2025-01-18 10:41:13 -06:00
Jack Christensen 9cce05944a
Merge pull request #2216 from pconstantinou/master
Timestamp incorrectly adds 'Z' when serializing from JSON to indicate GMT, fixes bug #2215
2025-01-18 10:17:43 -06:00
Jason Turim 9c0ad690a9 Include the field name in error messages when scanning structs 2025-01-11 14:31:24 -05:00
Jack Christensen 03f08abda3 Fix in Unmarshal function rather than DecodeValue
This preserves backwards compatibility in the unlikely event someone is
using an alternative XML unmarshaler that does support unmarshalling
into *any.
2025-01-11 11:26:46 -06:00
Jack Christensen 2c1b1c389a
Merge pull request #2200 from zenkovev/flush_request_in_pipeline
add flush request in pipeline
2025-01-11 11:15:36 -06:00
Jack Christensen 329cb45913 XMLCodec: fix DecodeValue to return a []byte
Previously, DecodeValue would always return nil with the default
Unmarshal function.

fixes https://github.com/jackc/pgx/issues/2227
2025-01-11 10:55:48 -06:00
zenkovev c96a55f8c0 private const for pipelineRequestType 2025-01-11 19:54:18 +03:00
Jack Christensen e87760682f Update oldest supported Go version to 1.22 2025-01-11 07:49:50 -06:00
Jack Christensen f681632c68 Drop PG 12 support and add PG 17 to CI 2025-01-11 07:49:26 -06:00
Phil Constantinou 3c640a44b6 Making the tests a little cleaner and clear 2025-01-06 09:24:55 -08:00
zenkovev de3f868c1d pipeline queue for client requests 2025-01-06 13:54:48 +03:00
Phil Constantinou 5424d3c873 Return error and make sure they are unit tested 2025-01-05 19:45:45 -08:00
Phil Constantinou 42d3d00734 Parse as a UTC time 2025-01-05 19:19:17 -08:00
Phil Constantinou cdc672cf3f Make JSON output confirm to ISO8601 timestamp without a timezone 2025-01-05 13:05:51 -08:00
Phil Constantinou 52e2858629 Added unit test and fixed typo 2025-01-02 13:36:33 -08:00
Phil Constantinou e352784fed Add Z only if needed. 2025-01-02 12:50:29 -08:00
Jack Christensen c2175fe46e
Merge pull request #2213 from moukoublen/fix_2204
Fix #2204
2024-12-30 20:35:41 -06:00
Jack Christensen 659823f8f3 Add link to github.com/amirsalarsafaei/sqlc-pgx-monitoring
fixes https://github.com/jackc/pgx/issues/2212
2024-12-30 20:27:10 -06:00
Jack Christensen ca04098fab
Merge pull request #2136 from ninedraft/optimize-sanitize
Reduce SQL sanitizer allocations
2024-12-30 20:24:13 -06:00
Jack Christensen 4ff0a454e0
Merge pull request #2211 from EinoPlasma/master
Fixes for Method Comment and Typo in Test Function Name
2024-12-30 20:12:22 -06:00
Jack Christensen 00b86ca3db
Merge pull request #2208 from vamshiaruru/feat/expose_empty_acquire_wait_time_from_puddle
Expose puddle.Pool's EmptyAcquireWaitTime in pgxpool's Stats
2024-12-30 20:03:51 -06:00
Kostas Stamatakis 61a0227241
simplify test 2024-12-30 23:15:46 +02:00
Kostas Stamatakis 2190a8e0d1
cleanup and add test for json codec 2024-12-30 23:09:19 +02:00
Kostas Stamatakis 6e9fa42fef
fix #2204 2024-12-30 22:54:42 +02:00
EinoPlasma 6d9e6a726e Fix typo in test function name 2024-12-29 21:03:38 +08:00
EinoPlasma 02e387ea64 Fix method comment in PasswordMessage 2024-12-29 20:59:24 +08:00
merlin e452f80b1d
TestErrNoRows: remove bad test case 2024-12-28 13:39:01 +02:00
merlin da0315d1a4
optimisations of quote functions by @sean- 2024-12-28 13:31:09 +02:00
merlin 120c89fe0d
fix preallocations of quoted string 2024-12-28 13:31:09 +02:00
merlin 057937db27
add prefix to quoters tests 2024-12-28 13:31:09 +02:00
merlin 47cbd8edb8
drop too large values from memory pools 2024-12-28 13:31:09 +02:00
merlin 90a77b13b2
add docs to sanitize tests 2024-12-28 13:31:08 +02:00
merlin 59d6aa87b9
rework QuoteString and QuoteBytes as append-style 2024-12-28 13:31:08 +02:00
merlin 39ffc8b7a4
add lexer and query pools
use lexer pool
2024-12-28 13:31:08 +02:00
merlin c4c1076d28
add FuzzQuoteString and FuzzQuoteBytes 2024-12-28 13:31:08 +02:00
merlin 4293b25262
decrease number of samples in go benchmark 2024-12-28 13:31:08 +02:00
merlin ea1e13a660
quoteString 2024-12-28 13:31:08 +02:00
merlin 58d4c0c94f
quoteBytes
check new quoteBytes
2024-12-28 13:31:08 +02:00
merlin 1752f7b4c1
docs 2024-12-28 13:31:08 +02:00
merlin ee718a110d
append AvailableBuffer 2024-12-28 13:31:08 +02:00
merlin 546ad2f4e2
shared bytestring 2024-12-28 13:31:08 +02:00
merlin efc2c9ff44
buf pool 2024-12-28 13:31:08 +02:00
merlin aabed18db8
add benchmark tool
fix benchmark script
2024-12-28 13:31:08 +02:00
merlin afa974fb05
base case
make benchmark more extensive

add quote to string

add BenchmarkSanitizeSQL
2024-12-28 13:31:08 +02:00
Vamshi Aruru 12b37f3218 Expose puddle.Pool's EmptyAcquireWaitTime in pgxpool's Stats
Addresses: https://github.com/jackc/pgx/issues/2205
2024-12-26 13:46:49 +05:30
Jack Christensen bcf3fbd780
Merge pull request #2206 from alexandear/refactor-impossible-cond
Refactor Conn.LoadTypes by removing redundant check
2024-12-24 11:14:17 -06:00
Jack Christensen f7c3d190ad
Merge pull request #2203 from martinyonatann/chore/check-array-and-remove-imposible-condition
check array just using `len` and remove `impossible condition`
2024-12-24 11:10:45 -06:00
Jack Christensen 473a241b96
Merge pull request #2202 from martinyonatann/chore/remove-unused-parameter
remove unused func and parameter
2024-12-24 09:32:07 -06:00
Oleksandr Redko 311f72afdc Refactor Conn.LoadTypes by removing redundant check 2024-12-24 12:58:15 +02:00
martinpasaribu 877111ceeb
check array just using len and remove impossible condition 2024-12-22 23:57:28 +07:00
martinpasaribu dc3aea06b5
remove unused func and parameter 2024-12-22 23:48:08 +07:00
Jack Christensen e5d321f920
Merge pull request #2197 from alexandear/fix-generated-hdr
Update comments in generated code to align with Go standards
2024-12-21 12:40:23 -06:00
Oleksandr Redko 17cd36818c Update comments in generated code to align with Go standards 2024-12-21 20:21:32 +02:00
zenkovev 76593f37f7 add flush request in pipeline 2024-12-17 11:49:13 +03:00
Jack Christensen 29751194ef Test composites on CockroachDB 2024-05-25 07:49:00 -05:00
Jack Christensen c1f4cbb5cd Upgrade CockroachDB on CI 2024-05-25 07:48:47 -05:00
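
The batch write fix described in commit 228cfffc20 above reduces to closing the connection instead of merely unlocking it on a write error. A minimal sketch of the resulting error path, using the names from the pgconn.go diff below:

    // Sketch of ExecBatch's write-error handling after the fix: unwatch the
    // context and close the connection asynchronously so a half-written
    // connection can never be handed out as Idle again.
    _, err := pgConn.conn.Write(batch.buf)
    if err != nil {
        pgConn.contextWatcher.Unwatch()
        multiResult.err = normalizeTimeoutError(multiResult.ctx, err)
        multiResult.closed = true
        pgConn.asyncClose() // previously pgConn.unlock(), which left the connection reusable
        return multiResult
    }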
41 changed files with 1539 additions and 244 deletions

View File

@@ -14,18 +14,8 @@ jobs:
strategy:
matrix:
go-version: ["1.22", "1.23"]
pg-version: [12, 13, 14, 15, 16, cockroachdb]
pg-version: [13, 14, 15, 16, 17, cockroachdb]
include:
- pg-version: 12
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
- pg-version: 13
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
@@ -66,6 +56,16 @@ jobs:
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
- pg-version: 17
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
- pg-version: cockroachdb
pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"

View File

@@ -1,3 +1,20 @@
# 5.7.4 (March 24, 2025)
* Fix / revert change to scanning JSON `null` (Felix Röhrich)
# 5.7.3 (March 21, 2025)
* Expose EmptyAcquireWaitTime in pgxpool.Stat (vamshiaruru32)
* Improve SQL sanitizer performance (ninedraft)
* Fix Scan confusion with json(b), sql.Scanner, and automatic dereferencing (moukoublen, felix-roehrich)
* Fix Values() for xml type always returning nil instead of []byte
* Add ability to send Flush message in pipeline mode (zenkovev)
* Fix pgtype.Timestamp's JSON behavior to match PostgreSQL (pconstantinou)
* Better error messages when scanning structs (logicbomb)
* Fix handling of error on batch write (bonnefoa)
* Match libpq's connection fallback behavior more closely (felix-roehrich)
* Add MinIdleConns to pgxpool (djahandarie)
# 5.7.2 (December 21, 2024)
* Fix prepared statement already exists on batch prepare failure
@@ -9,6 +26,7 @@
* Implement pgtype.UUID.String() (Konstantin Grachev)
* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev)
* Update golang.org/x/crypto
* Fix json(b) columns prefer sql.Scanner interface like database/sql (Ludovico Russo)
# 5.7.1 (September 10, 2024)
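
The "Add MinIdleConns to pgxpool" entry in 5.7.3 above maps to a new pgxpool.Config field. A minimal sketch, assuming the v5.7.3 pgxpool API (the DSN is a placeholder):

    package main

    import (
        "context"
        "log"

        "github.com/jackc/pgx/v5/pgxpool"
    )

    func main() {
        cfg, err := pgxpool.ParseConfig("postgres://app:secret@localhost:5432/app")
        if err != nil {
            log.Fatal(err)
        }
        cfg.MinIdleConns = 4 // pool keeps at least 4 idle connections warm
        pool, err := pgxpool.NewWithConfig(context.Background(), cfg)
        if err != nil {
            log.Fatal(err)
        }
        defer pool.Close()
    }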

View File

@@ -92,7 +92,7 @@ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.
## Supported Go and PostgreSQL Versions
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.21 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.22 and higher and PostgreSQL 13 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
## Version Policy
@@ -172,3 +172,15 @@ Supports structs, maps, slices and custom mapping functions.
### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)
Code first migration library for native pgx (no database/sql abstraction).
### [github.com/amirsalarsafaei/sqlc-pgx-monitoring](https://github.com/amirsalarsafaei/sqlc-pgx-monitoring)
A database monitoring/metrics library for pgx and sqlc. Trace, log and monitor your sqlc query performance using OpenTelemetry.
### [https://github.com/nikolayk812/pgx-outbox](https://github.com/nikolayk812/pgx-outbox)
Simple Go implementation of the transactional outbox pattern for PostgreSQL, using the jackc/pgx driver.
### [https://github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy)
Simplifies working with the pgx library, providing convenient scanning of nested structures.

View File

@@ -2,7 +2,7 @@ require "erb"
rule '.go' => '.go.erb' do |task|
erb = ERB.new(File.read(task.source))
File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding))
File.write(task.name, "// Code generated from #{task.source}. DO NOT EDIT.\n\n" + erb.result(binding))
sh "goimports", "-w", task.name
end

View File

@@ -42,8 +42,8 @@ fi
if [[ "${PGVERSION-}" =~ ^cockroach ]]
then
wget -qO- https://binaries.cockroachdb.com/cockroach-v23.1.3.linux-amd64.tgz | tar xvz
sudo mv cockroach-v23.1.3.linux-amd64/cockroach /usr/local/bin/
wget -qO- https://binaries.cockroachdb.com/cockroach-v24.3.3.linux-amd64.tgz | tar xvz
sudo mv cockroach-v24.3.3.linux-amd64/cockroach /usr/local/bin/
cockroach start-single-node --insecure --background --listen-addr=localhost
cockroach sql --insecure -e 'create database pgx_test'
fi

conn.go (21 changes)
View File

@@ -420,7 +420,7 @@ func (c *Conn) IsClosed() bool {
return c.pgConn.IsClosed()
}
func (c *Conn) die(err error) {
func (c *Conn) die() {
if c.IsClosed() {
return
}
@@ -588,14 +588,6 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription
return result.CommandTag, result.Err
}
type unknownArgumentTypeQueryExecModeExecError struct {
arg any
}
func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg)
}
func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) {
err := c.eqb.Build(c.typeMap, nil, args)
if err != nil {
@@ -661,11 +653,12 @@ const (
// should implement pgtype.Int64Valuer.
QueryExecModeExec
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. Queries
// are executed in a single round trip. Type mappings can be registered with pgtype.Map.RegisterDefaultPgType. Queries
// will be rejected that have arguments that are unregistered or ambiguous. e.g. A map[string]string may have the
// PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use a map[string]string directly as an
// argument. This mode cannot.
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. This is
// especially significant for []byte values. []byte values are encoded as PostgreSQL bytea. string must be used
// instead for text type values including json and jsonb. Type mappings can be registered with
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous.
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use a
// map[string]string directly as an argument. This mode cannot. Queries are executed in a single round trip.
//
// QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec. This includes
// the warning regarding differences in text format and binary format encoding with user defined types. There may be
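
To make the []byte caveat in the comment above concrete: under QueryExecModeSimpleProtocol a []byte argument encodes as bytea, so JSON documents must be passed as string. A hedged sketch (conn and ctx assumed in scope; the docs table and body column are hypothetical):

    js := `{"a": 1}`
    // A string encodes as text, which PostgreSQL accepts for json/jsonb.
    _, err := conn.Exec(ctx, "insert into docs(body) values($1)",
        pgx.QueryExecModeSimpleProtocol, js)
    // []byte(js) would instead encode as bytea and fail for a json column.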

View File

@@ -1417,5 +1417,4 @@ func TestErrNoRows(t *testing.T) {
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNoRows must match sql.ErrNoRows")
require.ErrorIs(t, pgx.ErrNoRows, pgx.ErrNoRows, "sql.ErrNoRows must match pgx.ErrNoRows")
}

View File

@@ -161,7 +161,7 @@ type derivedTypeInfo struct {
// The result of this call can be passed into RegisterTypes to complete the process.
func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) {
m := c.TypeMap()
if typeNames == nil || len(typeNames) == 0 {
if len(typeNames) == 0 {
return nil, fmt.Errorf("No type names were supplied.")
}
@@ -169,13 +169,7 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ
// the SQL not support recent structures such as multirange
serverVersion, _ := serverVersion(c)
sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
var rows Rows
var err error
if typeNames == nil {
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol)
} else {
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
}
rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
if err != nil {
return nil, fmt.Errorf("While generating load types query: %w", err)
}
@@ -232,15 +226,15 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ
default:
return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
}
if type_ != nil {
m.RegisterType(type_)
if ti.NspName != "" {
nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
m.RegisterType(nspType)
result = append(result, nspType)
}
result = append(result, type_)
// type_ can never be nil here
m.RegisterType(type_)
if ti.NspName != "" {
nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
m.RegisterType(nspType)
result = append(result, nspType)
}
result = append(result, type_)
}
return result, nil
}
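
For context, LoadTypes pairs with RegisterTypes as the doc comment above notes; a minimal usage sketch (conn and ctx assumed in scope, and the type names are hypothetical):

    // Load a user-defined type and its array form, then register both on
    // the connection's type map so pgx can encode/decode them.
    types, err := conn.LoadTypes(ctx, []string{"mytype", "_mytype"})
    if err != nil {
        return err
    }
    conn.TypeMap().RegisterTypes(types)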

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env bash
current_branch=$(git rev-parse --abbrev-ref HEAD)
if [ "$current_branch" == "HEAD" ]; then
current_branch=$(git rev-parse HEAD)
fi
restore_branch() {
echo "Restoring original branch/commit: $current_branch"
git checkout "$current_branch"
}
trap restore_branch EXIT
# Check if there are uncommitted changes
if ! git diff --quiet || ! git diff --cached --quiet; then
echo "There are uncommitted changes. Please commit or stash them before running this script."
exit 1
fi
# Ensure that at least one commit argument is passed
if [ "$#" -lt 1 ]; then
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
exit 1
fi
commits=("$@")
benchmarks_dir=benchmarks
if ! mkdir -p "${benchmarks_dir}"; then
echo "Unable to create dir for benchmarks data"
exit 1
fi
# Benchmark results
bench_files=()
# Run benchmark for each listed commit
for i in "${!commits[@]}"; do
commit="${commits[i]}"
git checkout "$commit" || {
echo "Failed to checkout $commit"
exit 1
}
# Sanitized commit message
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
# Benchmark data will go there
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
if ! go test -bench=. -count=10 >"$bench_file"; then
echo "Benchmarking failed for commit $commit"
exit 1
fi
bench_files+=("$bench_file")
done
# go install golang.org/x/perf/cmd/benchstat[@latest]
benchstat "${bench_files[@]}"

View File

@@ -4,8 +4,10 @@ import (
"bytes"
"encoding/hex"
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
)
@@ -24,18 +26,33 @@ type Query struct {
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3
const maxBufSize = 16384 // 16 Ki
var bufPool = &pool[*bytes.Buffer]{
new: func() *bytes.Buffer {
return &bytes.Buffer{}
},
reset: func(b *bytes.Buffer) bool {
n := b.Len()
b.Reset()
return n < maxBufSize
},
}
var null = []byte("null")
func (q *Query) Sanitize(args ...any) (string, error) {
argUse := make([]bool, len(args))
buf := &bytes.Buffer{}
buf := bufPool.get()
defer bufPool.put(buf)
for _, part := range q.Parts {
var str string
switch part := part.(type) {
case string:
str = part
buf.WriteString(part)
case int:
argIdx := part - 1
var p []byte
if argIdx < 0 {
return "", fmt.Errorf("first sql argument must be > 0")
}
@@ -43,34 +60,41 @@ func (q *Query) Sanitize(args ...any) (string, error) {
if argIdx >= len(args) {
return "", fmt.Errorf("insufficient arguments")
}
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
buf.WriteByte(' ')
arg := args[argIdx]
switch arg := arg.(type) {
case nil:
str = "null"
p = null
case int64:
str = strconv.FormatInt(arg, 10)
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
case float64:
str = strconv.FormatFloat(arg, 'f', -1, 64)
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
case bool:
str = strconv.FormatBool(arg)
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
case []byte:
str = QuoteBytes(arg)
p = QuoteBytes(buf.AvailableBuffer(), arg)
case string:
str = QuoteString(arg)
p = QuoteString(buf.AvailableBuffer(), arg)
case time.Time:
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
p = arg.Truncate(time.Microsecond).
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
default:
return "", fmt.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true
buf.Write(p)
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
str = " " + str + " "
buf.WriteByte(' ')
default:
return "", fmt.Errorf("invalid Part type: %T", part)
}
buf.WriteString(str)
}
for i, used := range argUse {
@@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) {
}
func NewQuery(sql string) (*Query, error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
query := &Query{}
query.init(sql)
return query, nil
}
var sqlLexerPool = &pool[*sqlLexer]{
new: func() *sqlLexer {
return &sqlLexer{}
},
reset: func(sl *sqlLexer) bool {
*sl = sqlLexer{}
return true
},
}
func (q *Query) init(sql string) {
parts := q.Parts[:0]
if parts == nil {
// dirty, but fast heuristic to preallocate for ~90% usecases
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
parts = make([]Part, 0, n)
}
l := sqlLexerPool.get()
defer sqlLexerPool.put(l)
l.src = sql
l.stateFn = rawState
l.parts = parts
for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
query := &Query{Parts: l.parts}
return query, nil
q.Parts = l.parts
}
func QuoteString(str string) string {
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
func QuoteString(dst []byte, str string) []byte {
const quote = '\''
// Preallocate space for the worst case scenario
dst = slices.Grow(dst, len(str)*2+2)
// Add opening quote
dst = append(dst, quote)
// Iterate through the string without allocating
for i := 0; i < len(str); i++ {
if str[i] == quote {
dst = append(dst, quote, quote)
} else {
dst = append(dst, str[i])
}
}
// Add closing quote
dst = append(dst, quote)
return dst
}
func QuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
func QuoteBytes(dst, buf []byte) []byte {
if len(buf) == 0 {
return append(dst, `'\x'`...)
}
// Calculate required length
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
// Ensure dst has enough capacity
if cap(dst)-len(dst) < requiredLen {
newDst := make([]byte, len(dst), len(dst)+requiredLen)
copy(newDst, dst)
dst = newDst
}
// Record original length and extend slice
origLen := len(dst)
dst = dst[:origLen+requiredLen]
// Add prefix
dst[origLen] = '\''
dst[origLen+1] = '\\'
dst[origLen+2] = 'x'
// Encode bytes directly into dst
hex.Encode(dst[origLen+3:len(dst)-1], buf)
// Add suffix
dst[len(dst)-1] = '\''
return dst
}
type sqlLexer struct {
@@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn {
}
}
var queryPool = &pool[*Query]{
new: func() *Query {
return &Query{}
},
reset: func(q *Query) bool {
n := len(q.Parts)
q.Parts = q.Parts[:0]
return n < 64 // drop too large queries
},
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func SanitizeSQL(sql string, args ...any) (string, error) {
query, err := NewQuery(sql)
if err != nil {
return "", err
}
query := queryPool.get()
query.init(sql)
defer queryPool.put(query)
return query.Sanitize(args...)
}
type pool[E any] struct {
p sync.Pool
new func() E
reset func(E) bool
}
func (pool *pool[E]) get() E {
v, ok := pool.p.Get().(E)
if !ok {
v = pool.new()
}
return v
}
func (p *pool[E]) put(v E) {
if p.reset(v) {
p.p.Put(v)
}
}
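
The reworked quote functions follow the strconv.Append* pattern: they append into a caller-supplied buffer and return the extended slice, avoiding intermediate string allocations. A usage sketch (note that sanitize is an internal pgx package, so this is illustrative only):

    // Both functions append to dst and return it, like strconv.AppendInt.
    buf := make([]byte, 0, 64)
    buf = sanitize.QuoteString(buf, "one's hat") // appends 'one''s hat'
    buf = append(buf, ' ')
    buf = sanitize.QuoteBytes(buf, []byte{0xde, 0xad}) // appends '\xdead'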

View File

@@ -0,0 +1,62 @@
// sanitize_benchmark_test.go
package sanitize_test
import (
"testing"
"time"
"github.com/jackc/pgx/v5/internal/sanitize"
)
var benchmarkSanitizeResult string
const benchmarkQuery = "" +
`SELECT *
FROM "water_containers"
WHERE NOT "id" = $1 -- int64
AND "tags" NOT IN $2 -- nil
AND "volume" > $3 -- float64
AND "transportable" = $4 -- bool
AND position($5 IN "sign") -- bytes
AND "label" LIKE $6 -- string
AND "created_at" > $7; -- time.Time`
var benchmarkArgs = []any{
int64(12345),
nil,
float64(500),
true,
[]byte("8BADF00D"),
"kombucha's han'dy awokowa",
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
}
func BenchmarkSanitize(b *testing.B) {
query, err := sanitize.NewQuery(benchmarkQuery)
if err != nil {
b.Fatalf("failed to create query: %v", err)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize query: %v", err)
}
}
}
var benchmarkNewSQLResult string
func BenchmarkSanitizeSQL(b *testing.B) {
b.ReportAllocs()
var err error
for i := 0; i < b.N; i++ {
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize SQL: %v", err)
}
}
}

View File

@@ -0,0 +1,55 @@
package sanitize_test
import (
"strings"
"testing"
"github.com/jackc/pgx/v5/internal/sanitize"
)
func FuzzQuoteString(f *testing.F) {
const prefix = "prefix"
f.Add("new\nline")
f.Add("sample text")
f.Add("sample q'u'o't'e's")
f.Add("select 'quoted $42', $1")
f.Fuzz(func(t *testing.T, input string) {
got := string(sanitize.QuoteString([]byte(prefix), input))
want := oldQuoteString(input)
quoted, ok := strings.CutPrefix(got, prefix)
if !ok {
t.Fatalf("result has no prefix")
}
if want != quoted {
t.Errorf("got %q", got)
t.Fatalf("want %q", want)
}
})
}
func FuzzQuoteBytes(f *testing.F) {
const prefix = "prefix"
f.Add([]byte(nil))
f.Add([]byte("\n"))
f.Add([]byte("sample text"))
f.Add([]byte("sample q'u'o't'e's"))
f.Add([]byte("select 'quoted $42', $1"))
f.Fuzz(func(t *testing.T, input []byte) {
got := string(sanitize.QuoteBytes([]byte(prefix), input))
want := oldQuoteBytes(input)
quoted, ok := strings.CutPrefix(got, prefix)
if !ok {
t.Fatalf("result has no prefix")
}
if want != quoted {
t.Errorf("got %q", got)
t.Fatalf("want %q", want)
}
})
}

View File

@@ -1,6 +1,8 @@ package sanitize_test
package sanitize_test
import (
"encoding/hex"
"strings"
"testing"
"time"
@@ -227,3 +229,55 @@ func TestQuerySanitize(t *testing.T) {
}
}
}
func TestQuoteString(t *testing.T) {
tc := func(name, input string) {
t.Run(name, func(t *testing.T) {
t.Parallel()
got := string(sanitize.QuoteString(nil, input))
want := oldQuoteString(input)
if got != want {
t.Errorf("got: %s", got)
t.Fatalf("want: %s", want)
}
})
}
tc("empty", "")
tc("text", "abcd")
tc("with quotes", `one's hat is always a cat`)
}
// This function was used before optimizations.
// Keep it for testing purposes - we want to ensure there are no breaking changes.
func oldQuoteString(str string) string {
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
}
func TestQuoteBytes(t *testing.T) {
tc := func(name string, input []byte) {
t.Run(name, func(t *testing.T) {
t.Parallel()
got := string(sanitize.QuoteBytes(nil, input))
want := oldQuoteBytes(input)
if got != want {
t.Errorf("got: %s", got)
t.Fatalf("want: %s", want)
}
})
}
tc("nil", nil)
tc("empty", []byte{})
tc("text", []byte("abcd"))
}
// This function was used before optimizations.
// Keep it for testing purposes - we want to ensure there are no breaking changes.
func oldQuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
}

View File

@@ -51,6 +51,8 @@ type Config struct {
KerberosSpn string
Fallbacks []*FallbackConfig
SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
@@ -318,6 +320,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
"sslkey": {},
"sslcert": {},
"sslrootcert": {},
"sslnegotiation": {},
"sslpassword": {},
"sslsni": {},
"krbspn": {},
@@ -386,6 +389,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
config.Port = fallbacks[0].Port
config.TLSConfig = fallbacks[0].TLSConfig
config.Fallbacks = fallbacks[1:]
config.SSLNegotiation = settings["sslnegotiation"]
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
if err == nil {
@@ -449,6 +453,7 @@ func parseEnvSettings() map[string]string {
"PGSSLSNI": "sslsni",
"PGSSLROOTCERT": "sslrootcert",
"PGSSLPASSWORD": "sslpassword",
"PGSSLNEGOTIATION": "sslnegotiation",
"PGTARGETSESSIONATTRS": "target_session_attrs",
"PGSERVICE": "service",
"PGSERVICEFILE": "servicefile",
@@ -646,6 +651,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
sslkey := settings["sslkey"]
sslpassword := settings["sslpassword"]
sslsni := settings["sslsni"]
sslnegotiation := settings["sslnegotiation"]
// Match libpq default behavior
if sslmode == "" {
@@ -657,6 +663,13 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
tlsConfig := &tls.Config{}
if sslnegotiation == "direct" {
tlsConfig.NextProtos = []string{"postgresql"}
if sslmode == "prefer" {
sslmode = "require"
}
}
if sslrootcert != "" {
var caCertPool *x509.CertPool

View File

@@ -49,7 +49,7 @@ func TestContextWatcherContextCancelled(t *testing.T) {
require.True(t, cleanupCalled, "Cleanup func was not called")
}
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
func TestContextWatcherUnwatchedBeforeContextCancelled(t *testing.T) {
cw := ctxwatch.NewContextWatcher(&testHandler{
handleCancel: func(context.Context) {
t.Error("cancel func should not have been called")

View File

@@ -1,6 +1,7 @@ package pgconn
package pgconn
import (
"container/list"
"context"
"crypto/md5"
"crypto/tls"
@@ -267,12 +268,15 @@ func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []*
var pgErr *PgError
if errors.As(err, &pgErr) {
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
// pgx will try next host even if libpq does not in certain cases (see #2246)
// consider changing this in the next major version
const ERRCODE_INVALID_PASSWORD = "28P01"
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
// auth failed due to invalid password, db does not exist or user has no permission
if pgErr.Code == ERRCODE_INVALID_PASSWORD ||
pgErr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && c.tlsConfig != nil ||
pgErr.Code == ERRCODE_INVALID_CATALOG_NAME ||
pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
return nil, allErrors
@@ -321,7 +325,15 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
if connectConfig.tlsConfig != nil {
pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn})
pgConn.contextWatcher.Watch(ctx)
tlsConn, err := startTLS(pgConn.conn, connectConfig.tlsConfig)
var (
tlsConn net.Conn
err error
)
if config.SSLNegotiation == "direct" {
tlsConn = tls.Client(pgConn.conn, connectConfig.tlsConfig)
} else {
tlsConn, err = startTLS(pgConn.conn, connectConfig.tlsConfig)
}
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil {
pgConn.conn.Close()
@@ -1408,9 +1420,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
type MultiResultReader struct {
pgConn *PgConn
ctx context.Context
pipeline *Pipeline
pgConn *PgConn
ctx context.Context
rr *ResultReader
@@ -1443,12 +1454,8 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
mrr.closed = true
if mrr.pipeline != nil {
mrr.pipeline.expectedReadyForQueryCount--
} else {
mrr.pgConn.contextWatcher.Unwatch()
mrr.pgConn.unlock()
}
mrr.pgConn.contextWatcher.Unwatch()
mrr.pgConn.unlock()
case *pgproto3.ErrorResponse:
mrr.err = ErrorResponseToPgError(msg)
}
@@ -1672,7 +1679,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
case *pgproto3.EmptyQueryResponse:
rr.concludeCommand(CommandTag{}, nil)
case *pgproto3.ErrorResponse:
rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg))
pgErr := ErrorResponseToPgError(msg)
if rr.pipeline != nil {
rr.pipeline.state.HandleError(pgErr)
}
rr.concludeCommand(CommandTag{}, pgErr)
}
return msg, nil
@@ -1773,9 +1784,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
if batch.err != nil {
pgConn.contextWatcher.Unwatch()
multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err)
multiResult.closed = true
multiResult.err = batch.err
pgConn.unlock()
pgConn.asyncClose()
return multiResult
}
@@ -1783,9 +1795,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
defer pgConn.exitPotentialWriteReadDeadlock()
_, err := pgConn.conn.Write(batch.buf)
if err != nil {
pgConn.contextWatcher.Unwatch()
multiResult.err = normalizeTimeoutError(multiResult.ctx, err)
multiResult.closed = true
multiResult.err = err
pgConn.unlock()
pgConn.asyncClose()
return multiResult
}
@@ -1999,9 +2012,7 @@ type Pipeline struct {
conn *PgConn
ctx context.Context
expectedReadyForQueryCount int
pendingSync bool
state pipelineState
err error
closed bool
}
@@ -2012,6 +2023,122 @@ type PipelineSync struct{}
// CloseComplete is returned by GetResults when a CloseComplete message is received.
type CloseComplete struct{}
type pipelineRequestType int
const (
pipelineNil pipelineRequestType = iota
pipelinePrepare
pipelineQueryParams
pipelineQueryPrepared
pipelineDeallocate
pipelineSyncRequest
pipelineFlushRequest
)
type pipelineRequestEvent struct {
RequestType pipelineRequestType
WasSentToServer bool
BeforeFlushOrSync bool
}
type pipelineState struct {
requestEventQueue list.List
lastRequestType pipelineRequestType
pgErr *PgError
expectedReadyForQueryCount int
}
func (s *pipelineState) Init() {
s.requestEventQueue.Init()
s.lastRequestType = pipelineNil
}
func (s *pipelineState) RegisterSendingToServer() {
for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() {
val := elem.Value.(pipelineRequestEvent)
if val.WasSentToServer {
return
}
val.WasSentToServer = true
elem.Value = val
}
}
func (s *pipelineState) registerFlushingBufferOnServer() {
for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() {
val := elem.Value.(pipelineRequestEvent)
if val.BeforeFlushOrSync {
return
}
val.BeforeFlushOrSync = true
elem.Value = val
}
}
func (s *pipelineState) PushBackRequestType(req pipelineRequestType) {
if req == pipelineNil {
return
}
if req != pipelineFlushRequest {
s.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: req})
}
if req == pipelineFlushRequest || req == pipelineSyncRequest {
s.registerFlushingBufferOnServer()
}
s.lastRequestType = req
if req == pipelineSyncRequest {
s.expectedReadyForQueryCount++
}
}
func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType {
for {
elem := s.requestEventQueue.Front()
if elem == nil {
return pipelineNil
}
val := elem.Value.(pipelineRequestEvent)
if !(val.WasSentToServer && val.BeforeFlushOrSync) {
return pipelineNil
}
s.requestEventQueue.Remove(elem)
if val.RequestType == pipelineSyncRequest {
s.pgErr = nil
}
if s.pgErr == nil {
return val.RequestType
}
}
}
func (s *pipelineState) HandleError(err *PgError) {
s.pgErr = err
}
func (s *pipelineState) HandleReadyForQuery() {
s.expectedReadyForQueryCount--
}
func (s *pipelineState) PendingSync() bool {
var notPendingSync bool
if elem := s.requestEventQueue.Back(); elem != nil {
val := elem.Value.(pipelineRequestEvent)
notPendingSync = (val.RequestType == pipelineSyncRequest) && val.WasSentToServer
} else {
notPendingSync = (s.lastRequestType == pipelineSyncRequest) || (s.lastRequestType == pipelineNil)
}
return !notPendingSync
}
func (s *pipelineState) ExpectedReadyForQuery() int {
return s.expectedReadyForQueryCount
}
// StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent
// to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection
// to normal mode. While in pipeline mode, no methods that communicate with the server may be called except
@@ -2020,16 +2147,21 @@ type CloseComplete struct{}
// Prefer ExecBatch when only sending one group of queries at once.
func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline {
if err := pgConn.lock(); err != nil {
return &Pipeline{
pipeline := &Pipeline{
closed: true,
err: err,
}
pipeline.state.Init()
return pipeline
}
pgConn.pipeline = Pipeline{
conn: pgConn,
ctx: ctx,
}
pgConn.pipeline.state.Init()
pipeline := &pgConn.pipeline
if ctx != context.Background() {
@@ -2052,10 +2184,10 @@ func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) {
if p.closed {
return
}
p.pendingSync = true
p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
p.state.PushBackRequestType(pipelinePrepare)
}
// SendDeallocate deallocates a prepared statement.
@@ -2063,9 +2195,9 @@ func (p *Pipeline) SendDeallocate(name string) {
if p.closed {
return
}
p.pendingSync = true
p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
p.state.PushBackRequestType(pipelineDeallocate)
}
// SendQueryParams is the pipeline version of *PgConn.QueryParams.
@@ -2073,12 +2205,12 @@ func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs [
if p.closed {
return
}
p.pendingSync = true
p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
p.conn.frontend.SendExecute(&pgproto3.Execute{})
p.state.PushBackRequestType(pipelineQueryParams)
}
// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared.
@@ -2086,11 +2218,42 @@ func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, para
if p.closed {
return
}
p.pendingSync = true
p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
p.conn.frontend.SendExecute(&pgproto3.Execute{})
p.state.PushBackRequestType(pipelineQueryPrepared)
}
// SendFlushRequest sends a request for the server to flush its output buffer.
//
// The server flushes its output buffer automatically as a result of Sync being called,
// or on any request when not in pipeline mode; this function is useful to cause the server
// to flush its output buffer in pipeline mode without establishing a synchronization point.
// Note that the request is not itself flushed to the server automatically; use Flush if
// necessary. This copies the behavior of libpq PQsendFlushRequest.
func (p *Pipeline) SendFlushRequest() {
if p.closed {
return
}
p.conn.frontend.Send(&pgproto3.Flush{})
p.state.PushBackRequestType(pipelineFlushRequest)
}
// SendPipelineSync marks a synchronization point in a pipeline by sending a sync message
// without flushing the send buffer. This serves as the delimiter of an implicit
// transaction and an error recovery point.
//
// Note that the request is not itself flushed to the server automatically; use Flush if
// necessary. This copies the behavior of libpq PQsendPipelineSync.
func (p *Pipeline) SendPipelineSync() {
if p.closed {
return
}
p.conn.frontend.SendSync(&pgproto3.Sync{})
p.state.PushBackRequestType(pipelineSyncRequest)
}
// Flush flushes the queued requests without establishing a synchronization point.
@@ -2115,28 +2278,14 @@ func (p *Pipeline) Flush() error {
return err
}
p.state.RegisterSendingToServer()
return nil
}
// Sync establishes a synchronization point and flushes the queued requests.
func (p *Pipeline) Sync() error {
if p.closed {
if p.err != nil {
return p.err
}
return errors.New("pipeline closed")
}
p.conn.frontend.SendSync(&pgproto3.Sync{})
err := p.Flush()
if err != nil {
return err
}
p.pendingSync = false
p.expectedReadyForQueryCount++
return nil
p.SendPipelineSync()
return p.Flush()
}
// GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or
@@ -2150,7 +2299,7 @@ func (p *Pipeline) GetResults() (results any, err error) {
return nil, errors.New("pipeline closed")
}
if p.expectedReadyForQueryCount == 0 {
if p.state.ExtractFrontRequestType() == pipelineNil {
return nil, nil
}
@@ -2195,13 +2344,13 @@ func (p *Pipeline) getResults() (results any, err error) {
case *pgproto3.CloseComplete:
return &CloseComplete{}, nil
case *pgproto3.ReadyForQuery:
p.expectedReadyForQueryCount--
p.state.HandleReadyForQuery()
return &PipelineSync{}, nil
case *pgproto3.ErrorResponse:
pgErr := ErrorResponseToPgError(msg)
p.state.HandleError(pgErr)
return nil, pgErr
}
}
}
@@ -2231,6 +2380,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
// These should never happen here. But don't take chances that could lead to a deadlock.
case *pgproto3.ErrorResponse:
pgErr := ErrorResponseToPgError(msg)
p.state.HandleError(pgErr)
return nil, pgErr
case *pgproto3.CommandComplete:
p.conn.asyncClose()
@@ -2250,7 +2400,7 @@ func (p *Pipeline) Close() error {
p.closed = true
if p.pendingSync {
if p.state.PendingSync() {
p.conn.asyncClose()
p.err = errors.New("pipeline has unsynced requests")
p.conn.contextWatcher.Unwatch()
@@ -2259,7 +2409,7 @@ func (p *Pipeline) Close() error {
return p.err
}
for p.expectedReadyForQueryCount > 0 {
for p.state.ExpectedReadyForQuery() > 0 {
_, err := p.getResults()
if err != nil {
p.err = err
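
Taken together, SendFlushRequest plus Flush lets a client read results mid-pipeline without establishing a Sync error-recovery point. A hedged usage sketch mirroring the tests below (pgConn and ctx assumed in scope):

    pipeline := pgConn.StartPipeline(ctx)
    pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
    pipeline.SendFlushRequest()              // ask the server to flush its output buffer
    if err := pipeline.Flush(); err != nil { // flush our own send buffer
        return err
    }
    results, err := pipeline.GetResults() // a *pgconn.ResultReader, no PipelineSync needed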

View File

@@ -14,6 +14,7 @@ import (
"os"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
@@ -1420,6 +1421,52 @@ func TestConnExecBatch(t *testing.T) {
assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
}
type mockConnection struct {
net.Conn
writeLatency *time.Duration
}
func (m mockConnection) Write(b []byte) (n int, err error) {
time.Sleep(*m.writeLatency)
return m.Conn.Write(b)
}
func TestConnExecBatchWriteError(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
var mockConn mockConnection
writeLatency := 0 * time.Second
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := net.Dial(network, address)
mockConn = mockConnection{conn, &writeLatency}
return mockConn, err
}
pgConn, err := pgconn.ConnectConfig(ctx, config)
require.NoError(t, err)
defer closeConn(t, pgConn)
batch := &pgconn.Batch{}
pgConn.Conn()
ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel2()
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
writeLatency = 2 * time.Second
mrr := pgConn.ExecBatch(ctx2, batch)
err = mrr.Close()
require.Error(t, err)
assert.ErrorIs(t, err, context.DeadlineExceeded)
require.True(t, pgConn.IsClosed())
}
func TestConnExecBatchDeferredError(t *testing.T) {
t.Parallel()
@@ -3105,6 +3152,344 @@ func TestPipelineQueryErrorBetweenSyncs(t *testing.T) {
ensureConnValid(t, pgConn)
}
func TestPipelineFlushForSingleRequests(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
pipeline := pgConn.StartPipeline(ctx)
pipeline.SendPrepare("ps", "select $1::text as msg", nil)
pipeline.SendFlushRequest()
err = pipeline.Flush()
require.NoError(t, err)
results, err := pipeline.GetResults()
require.NoError(t, err)
sd, ok := results.(*pgconn.StatementDescription)
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
require.Len(t, sd.Fields, 1)
require.Equal(t, "msg", string(sd.Fields[0].Name))
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil)
pipeline.SendFlushRequest()
err = pipeline.Flush()
require.NoError(t, err)
results, err = pipeline.GetResults()
require.NoError(t, err)
rr, ok := results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult := rr.Read()
require.NoError(t, readResult.Err)
require.Len(t, readResult.Rows, 1)
require.Len(t, readResult.Rows[0], 1)
require.Equal(t, "hello", string(readResult.Rows[0][0]))
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
pipeline.SendDeallocate("ps")
pipeline.SendFlushRequest()
err = pipeline.Flush()
require.NoError(t, err)
results, err = pipeline.GetResults()
require.NoError(t, err)
_, ok = results.(*pgconn.CloseComplete)
require.Truef(t, ok, "expected CloseComplete, got: %#v", results)
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
pipeline.SendFlushRequest()
err = pipeline.Flush()
require.NoError(t, err)
results, err = pipeline.GetResults()
require.NoError(t, err)
rr, ok = results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult = rr.Read()
require.NoError(t, readResult.Err)
require.Len(t, readResult.Rows, 1)
require.Len(t, readResult.Rows[0], 1)
require.Equal(t, "1", string(readResult.Rows[0][0]))
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
err = pipeline.Sync()
require.NoError(t, err)
results, err = pipeline.GetResults()
require.NoError(t, err)
_, ok = results.(*pgconn.PipelineSync)
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
err = pipeline.Close()
require.NoError(t, err)
ensureConnValid(t, pgConn)
}
func TestPipelineFlushForRequestSeries(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
pipeline := pgConn.StartPipeline(ctx)
pipeline.SendPrepare("ps", "select $1::bigint as num", nil)
err = pipeline.Sync()
require.NoError(t, err)
results, err := pipeline.GetResults()
require.NoError(t, err)
sd, ok := results.(*pgconn.StatementDescription)
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
require.Len(t, sd.Fields, 1)
require.Equal(t, "num", string(sd.Fields[0].Name))
require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)
results, err = pipeline.GetResults()
require.NoError(t, err)
_, ok = results.(*pgconn.PipelineSync)
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("1")}, nil, nil)
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("2")}, nil, nil)
pipeline.SendFlushRequest()
err = pipeline.Flush()
require.NoError(t, err)
results, err = pipeline.GetResults()
require.NoError(t, err)
rr, ok := results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult := rr.Read()
require.NoError(t, readResult.Err)
require.Len(t, readResult.Rows, 1)
require.Len(t, readResult.Rows[0], 1)
require.Equal(t, "1", string(readResult.Rows[0][0]))
results, err = pipeline.GetResults()
require.NoError(t, err)
rr, ok = results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult = rr.Read()
require.NoError(t, readResult.Err)
require.Len(t, readResult.Rows, 1)
require.Len(t, readResult.Rows[0], 1)
require.Equal(t, "2", string(readResult.Rows[0][0]))
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("3")}, nil, nil)
err = pipeline.Flush()
require.NoError(t, err)
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("4")}, nil, nil)
pipeline.SendFlushRequest()
err = pipeline.Flush()
require.NoError(t, err)
results, err = pipeline.GetResults()
require.NoError(t, err)
rr, ok = results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult = rr.Read()
require.NoError(t, readResult.Err)
require.Len(t, readResult.Rows, 1)
require.Len(t, readResult.Rows[0], 1)
require.Equal(t, "3", string(readResult.Rows[0][0]))
results, err = pipeline.GetResults()
require.NoError(t, err)
rr, ok = results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult = rr.Read()
require.NoError(t, readResult.Err)
require.Len(t, readResult.Rows, 1)
require.Len(t, readResult.Rows[0], 1)
require.Equal(t, "4", string(readResult.Rows[0][0]))
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("5")}, nil, nil)
pipeline.SendFlushRequest()
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("6")}, nil, nil)
pipeline.SendFlushRequest()
err = pipeline.Flush()
require.NoError(t, err)
results, err = pipeline.GetResults()
require.NoError(t, err)
rr, ok = results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult = rr.Read()
require.NoError(t, readResult.Err)
require.Len(t, readResult.Rows, 1)
require.Len(t, readResult.Rows[0], 1)
require.Equal(t, "5", string(readResult.Rows[0][0]))
results, err = pipeline.GetResults()
require.NoError(t, err)
rr, ok = results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult = rr.Read()
require.NoError(t, readResult.Err)
require.Len(t, readResult.Rows, 1)
require.Len(t, readResult.Rows[0], 1)
require.Equal(t, "6", string(readResult.Rows[0][0]))
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
err = pipeline.Sync()
require.NoError(t, err)
results, err = pipeline.GetResults()
require.NoError(t, err)
_, ok = results.(*pgconn.PipelineSync)
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
err = pipeline.Close()
require.NoError(t, err)
ensureConnValid(t, pgConn)
}
func TestPipelineFlushWithError(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
pipeline := pgConn.StartPipeline(ctx)
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil)
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
pipeline.SendFlushRequest()
err = pipeline.Flush()
require.NoError(t, err)
results, err := pipeline.GetResults()
require.NoError(t, err)
rr, ok := results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult := rr.Read()
require.NoError(t, readResult.Err)
require.Len(t, readResult.Rows, 1)
require.Len(t, readResult.Rows[0], 1)
require.Equal(t, "1", string(readResult.Rows[0][0]))
results, err = pipeline.GetResults()
require.NoError(t, err)
rr, ok = results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult = rr.Read()
var pgErr *pgconn.PgError
require.ErrorAs(t, readResult.Err, &pgErr)
require.Equal(t, "22012", pgErr.Code)
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
pipeline.SendFlushRequest()
err = pipeline.Flush()
require.NoError(t, err)
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
pipeline.SendPipelineSync()
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
pipeline.SendFlushRequest()
err = pipeline.Flush()
require.NoError(t, err)
results, err = pipeline.GetResults()
require.NoError(t, err)
_, ok = results.(*pgconn.PipelineSync)
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
results, err = pipeline.GetResults()
require.NoError(t, err)
rr, ok = results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult = rr.Read()
require.NoError(t, readResult.Err)
require.Len(t, readResult.Rows, 1)
require.Len(t, readResult.Rows[0], 1)
require.Equal(t, "5", string(readResult.Rows[0][0]))
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
err = pipeline.Sync()
require.NoError(t, err)
results, err = pipeline.GetResults()
require.NoError(t, err)
_, ok = results.(*pgconn.PipelineSync)
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
err = pipeline.Close()
require.NoError(t, err)
ensureConnValid(t, pgConn)
}
func TestPipelineCloseReadsUnreadResults(t *testing.T) {
t.Parallel()
@ -3435,6 +3820,173 @@ func TestSNISupport(t *testing.T) {
}
}
func TestConnectWithDirectSSLNegotiation(t *testing.T) {
t.Parallel()
tests := []struct {
name string
connString string
expectDirectNego bool
}{
{
name: "Default negotiation (postgres)",
connString: "sslmode=require",
expectDirectNego: false,
},
{
name: "Direct negotiation",
connString: "sslmode=require sslnegotiation=direct",
expectDirectNego: true,
},
{
name: "Explicit postgres negotiation",
connString: "sslmode=require sslnegotiation=postgres",
expectDirectNego: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
script := &pgmock.Script{
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
}
ln, err := net.Listen("tcp", "127.0.0.1:")
require.NoError(t, err)
defer ln.Close()
_, port, err := net.SplitHostPort(ln.Addr().String())
require.NoError(t, err)
var directNegoObserved atomic.Bool
serverErrCh := make(chan error, 1)
go func() {
defer close(serverErrCh)
conn, err := ln.Accept()
if err != nil {
serverErrCh <- fmt.Errorf("accept error: %w", err)
return
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(5 * time.Second))
firstByte := make([]byte, 1)
_, err = conn.Read(firstByte)
if err != nil {
serverErrCh <- fmt.Errorf("read first byte error: %w", err)
return
}
// Check whether the client opened with a TLS record (direct negotiation) or a
// PostgreSQL SSLRequest. TLS record content types are 20-23, while an
// SSLRequest begins with a 4-byte big-endian length, so its first byte is 0.
isDirect := firstByte[0] >= 20 && firstByte[0] <= 23
directNegoObserved.Store(isDirect)
var tlsConn *tls.Conn
if !isDirect {
// Handle standard PostgreSQL SSL negotiation
// Read the rest of the SSL request message
sslRequestRemainder := make([]byte, 7)
_, err = io.ReadFull(conn, sslRequestRemainder)
if err != nil {
serverErrCh <- fmt.Errorf("read ssl request remainder error: %w", err)
return
}
// Send SSL acceptance response
_, err = conn.Write([]byte("S"))
if err != nil {
serverErrCh <- fmt.Errorf("write ssl acceptance error: %w", err)
return
}
// Setup TLS server without needing to reuse the first byte
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
if err != nil {
serverErrCh <- fmt.Errorf("cert error: %w", err)
return
}
tlsConn = tls.Server(conn, &tls.Config{
Certificates: []tls.Certificate{cert},
})
} else {
// Handle direct TLS negotiation
// Setup TLS server with the first byte already read
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
if err != nil {
serverErrCh <- fmt.Errorf("cert error: %w", err)
return
}
// Use a wrapper to inject the first byte back into the TLS handshake
bufConn := &prefixConn{
Conn: conn,
prefixData: firstByte,
}
tlsConn = tls.Server(bufConn, &tls.Config{
Certificates: []tls.Certificate{cert},
})
}
// Complete TLS handshake
if err := tlsConn.Handshake(); err != nil {
serverErrCh <- fmt.Errorf("TLS handshake error: %w", err)
return
}
defer tlsConn.Close()
err = script.Run(pgproto3.NewBackend(tlsConn, tlsConn))
if err != nil {
serverErrCh <- fmt.Errorf("pgmock run error: %w", err)
return
}
}()
connStr := fmt.Sprintf("%s host=localhost port=%s sslmode=require sslinsecure=1",
tt.connString, port)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
conn, err := pgconn.Connect(ctx, connStr)
require.NoError(t, err)
defer conn.Close(ctx)
err = <-serverErrCh
require.NoError(t, err)
require.Equal(t, tt.expectDirectNego, directNegoObserved.Load())
})
}
}
// prefixConn implements a net.Conn that prepends some data to the first Read
type prefixConn struct {
net.Conn
prefixData []byte
prefixConsumed bool
}
func (c *prefixConn) Read(b []byte) (n int, err error) {
if !c.prefixConsumed && len(c.prefixData) > 0 {
n = copy(b, c.prefixData)
c.prefixData = c.prefixData[n:]
c.prefixConsumed = len(c.prefixData) == 0
return n, nil
}
return c.Conn.Read(b)
}
// https://github.com/jackc/pgx/issues/1920
func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
t.Parallel()


@ -12,7 +12,7 @@ type PasswordMessage struct {
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*PasswordMessage) Frontend() {}
// Frontend identifies this message as an authentication response.
// InitialResponse identifies this message as an authentication response.
func (*PasswordMessage) InitialResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message


@ -11,8 +11,6 @@ import (
)
func TestCompositeCodecTranscode(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop type if exists ct_test;
@ -91,8 +89,6 @@ func (p *point3d) ScanIndex(i int) any {
}
func TestCompositeCodecTranscodeStruct(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop type if exists point3d;
@ -128,8 +124,6 @@ create type point3d as (
}
func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop type if exists point3d;
@ -169,8 +163,6 @@ create type point3d as (
}
func TestCompositeCodecDecodeValue(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop type if exists point3d;
@ -214,7 +206,7 @@ create type point3d as (
//
// https://github.com/jackc/pgx/issues/1576
func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
skipCockroachDB(t, "Server does not support composite types from table definitions")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {


@ -53,8 +53,8 @@ similar fashion to database/sql. The second is to use a pointer to a pointer.
return err
}
When using nullable pgtype types as parameters for queries, one has to remember
to explicitly set their Valid field to true, otherwise the parameter's value will be NULL.
When using nullable pgtype types as parameters for queries, one has to remember to explicitly set their Valid field to
true, otherwise the parameter's value will be NULL.
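As a minimal sketch (conn, ctx, and table t are illustrative assumptions), the first parameter below is sent as 42 and
the second as NULL because its Valid field was never set:

	notNull := pgtype.Int4{Int32: 42, Valid: true}
	null := pgtype.Int4{Int32: 42} // Valid is false, so this parameter is NULL
	_, err := conn.Exec(ctx, "insert into t (a, b) values ($1, $2)", notNull, null)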
JSON Support
@ -159,11 +159,16 @@ example_child_records_test.go for an example.
Overview of Scanning Implementation
The first step is to use the OID to lookup the correct Codec. If the OID is unavailable, Map will try to find the OID
from previous calls of Map.RegisterDefaultPgType. The Map will call the Codec's PlanScan method to get a plan for
scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types are
interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner and
PointValuer interfaces.
The first step is to use the OID to look up the correct Codec. The Map will call the Codec's PlanScan method to get a
plan for scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentimes these Go types
are interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner
and PointValuer interfaces.
If a Go value is not supported directly by a Codec then Map will try to see if it is a sql.Scanner. If it is, then that
interface will be used to scan the value. Most sql.Scanners require the input to be in the text format (e.g. UUIDs and
numeric). However, pgx will typically have received the value in the binary format. In this case the binary value will be
parsed, reencoded as text, and then passed to the sql.Scanner. This may incur additional overhead for query results with
a large number of affected values.
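For illustration, a hedged sketch of a type pgx could use through this sql.Scanner fallback (myUUID is an assumption,
not a pgx type):

	type myUUID struct{ s string }

	// Scan implements sql.Scanner. pgx passes the text representation here,
	// reencoding from binary first when necessary.
	func (u *myUUID) Scan(src any) error {
		s, ok := src.(string)
		if !ok {
			return fmt.Errorf("cannot scan %T into myUUID", src)
		}
		u.s = s
		return nil
	}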
If a Go value is not supported directly by a Codec then Map will try wrapping it with additional logic and try again.
For example, Int8Codec does not support scanning into a renamed type (e.g. type myInt64 int64). But Map will detect that


@ -1,4 +1,5 @@
// Do not edit. Generated from pgtype/int.go.erb
// Code generated from pgtype/int.go.erb. DO NOT EDIT.
package pgtype
import (


@ -1,4 +1,5 @@
// Do not edit. Generated from pgtype/int_test.go.erb
// Code generated from pgtype/int_test.go.erb. DO NOT EDIT.
package pgtype_test
import (


@ -1,3 +1,5 @@
// Code generated from pgtype/integration_benchmark_test.go.erb. DO NOT EDIT.
package pgtype_test
import (


@ -71,6 +71,27 @@ func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Enco
}
}
// JSON needs its own scan plan for pointers to handle 'null'::json(b).
// Consider making pointerPointerScanPlan more flexible in the future.
type jsonPointerScanPlan struct {
next ScanPlan
}
func (p jsonPointerScanPlan) Scan(src []byte, dst any) error {
el := reflect.ValueOf(dst).Elem()
if src == nil || string(src) == "null" {
el.SetZero()
return nil
}
el.Set(reflect.New(el.Type().Elem()))
if p.next != nil {
return p.next.Scan(src, el.Interface())
}
return nil
}
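// A hedged usage sketch of the plan above: both SQL NULL and a JSON 'null'
// document zero the outer pointer rather than erroring.
//
//	var s *string
//	err := conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&s) // s stays nil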
type encodePlanJSONCodecEitherFormatString struct{}
func (encodePlanJSONCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) {
@ -117,58 +138,36 @@ func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (
return buf, nil
}
func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch target.(type) {
case *string:
return scanPlanAnyToString{}
case **string:
// This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better
// solution would be.
//
// https://github.com/jackc/pgx/issues/1470 -- **string
// https://github.com/jackc/pgx/issues/1691 -- ** anything else
if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok {
if nextPlan := m.planScan(oid, format, nextDst, 0); nextPlan != nil {
if _, failed := nextPlan.(*scanPlanFail); !failed {
wrapperPlan.SetNext(nextPlan)
return wrapperPlan
}
}
}
case *[]byte:
return scanPlanJSONToByteSlice{}
case BytesScanner:
return scanPlanBinaryBytesToBytesScanner{}
}
// Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence.
//
// https://github.com/jackc/pgx/issues/1418
if isSQLScanner(target) {
return &scanPlanSQLScanner{formatCode: format}
}
return &scanPlanJSONToJSONUnmarshal{
unmarshal: c.Unmarshal,
}
func (c *JSONCodec) PlanScan(m *Map, oid uint32, formatCode int16, target any) ScanPlan {
return c.planScan(m, oid, formatCode, target, 0)
}
// we need to check if the target is a pointer to a sql.Scanner (or any of the pointer ref tree implements a sql.Scanner).
//
// https://github.com/jackc/pgx/issues/2146
func isSQLScanner(v any) bool {
val := reflect.ValueOf(v)
for val.Kind() == reflect.Ptr {
if _, ok := val.Interface().(sql.Scanner); ok {
return true
}
val = val.Elem()
// JSON cannot fall back to pointerPointerScanPlan because of 'null'::json(b),
// so we need to duplicate the logic here.
func (c *JSONCodec) planScan(m *Map, oid uint32, formatCode int16, target any, depth int) ScanPlan {
if depth > 8 {
return &scanPlanFail{m: m, oid: oid, formatCode: formatCode}
}
switch target.(type) {
case *string:
return &scanPlanAnyToString{}
case *[]byte:
return &scanPlanJSONToByteSlice{}
case BytesScanner:
return &scanPlanBinaryBytesToBytesScanner{}
case sql.Scanner:
return &scanPlanSQLScanner{formatCode: formatCode}
}
rv := reflect.ValueOf(target)
if rv.Kind() == reflect.Pointer && rv.Elem().Kind() == reflect.Pointer {
var plan jsonPointerScanPlan
plan.next = c.planScan(m, oid, formatCode, rv.Elem().Interface(), depth+1)
return plan
} else {
return &scanPlanJSONToJSONUnmarshal{unmarshal: c.Unmarshal}
}
return false
}
type scanPlanAnyToString struct{}
@ -212,7 +211,12 @@ func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
return fmt.Errorf("cannot scan NULL into %T", dst)
}
elem := reflect.ValueOf(dst).Elem()
v := reflect.ValueOf(dst)
if v.Kind() != reflect.Pointer || v.IsNil() {
return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst)
}
elem := v.Elem()
elem.Set(reflect.Zero(elem.Type()))
return s.unmarshal(src, dst)


@ -6,6 +6,7 @@ import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"reflect"
"testing"
@ -48,6 +49,7 @@ func TestJSONCodec(t *testing.T) {
Age int `json:"age"`
}
var str string
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{
{nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))},
{map[string]any(nil), new(*string), isExpectedEq((*string)(nil))},
@ -65,6 +67,9 @@ func TestJSONCodec(t *testing.T) {
{Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))},
// Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146)
{Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))},
// Test driver.Scanner without pointer receiver (https://github.com/jackc/pgx/issues/2204)
{NonPointerJSONScanner{V: stringPtr("{}")}, NonPointerJSONScanner{V: &str}, func(a any) bool { return str == "{}" }},
})
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
@ -136,6 +141,27 @@ func (i Issue2146) Value() (driver.Value, error) {
return string(b), err
}
type NonPointerJSONScanner struct {
V *string
}
func (i NonPointerJSONScanner) Scan(src any) error {
switch c := src.(type) {
case string:
*i.V = c
case []byte:
*i.V = string(c)
default:
return errors.New("unknown source type")
}
return nil
}
func (i NonPointerJSONScanner) Value() (driver.Value, error) {
return i.V, nil
}
// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648
func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
@ -166,11 +192,15 @@ func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
// A string cannot scan a NULL.
str := "foobar"
err = conn.QueryRow(ctx, "select null::json").Scan(&str)
require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string")
fieldName := "json"
if conn.PgConn().ParameterStatus("crdb_version") != "" {
fieldName = "jsonb" // Seems like CockroachDB treats json as jsonb.
}
require.EqualError(t, err, fmt.Sprintf("can't scan into dest[0] (col: %s): cannot scan NULL into *string", fieldName))
// A non-string cannot scan a NULL.
err = conn.QueryRow(ctx, "select null::json").Scan(&n)
require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *int")
require.EqualError(t, err, fmt.Sprintf("can't scan into dest[0] (col: %s): cannot scan NULL into *int", fieldName))
})
}
@ -267,7 +297,8 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
Unmarshal: func(data []byte, v any) error {
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
},
}})
},
})
}
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
@ -278,3 +309,54 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
}},
})
}
func TestJSONCodecScanToNonPointerValues(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
n := 44
err := conn.QueryRow(ctx, "select '42'::jsonb").Scan(n)
require.Error(t, err)
var i *int
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(i)
require.Error(t, err)
m := 0
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(&m)
require.NoError(t, err)
require.Equal(t, 42, m)
})
}
func TestJSONCodecScanNull(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
var dest struct{}
err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot scan NULL into *struct {}")
err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&dest)
require.NoError(t, err)
var destPointer *struct{}
err = conn.QueryRow(ctx, "select null::jsonb").Scan(&destPointer)
require.NoError(t, err)
require.Nil(t, destPointer)
err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&destPointer)
require.NoError(t, err)
require.Nil(t, destPointer)
var raw json.RawMessage
require.NoError(t, conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&raw))
require.Equal(t, json.RawMessage("null"), raw)
})
}
func TestJSONCodecScanNullToPointerToSQLScanner(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
var dest *Issue2146
err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest)
require.NoError(t, err)
require.Nil(t, dest)
})
}


@ -66,11 +66,11 @@ func TestJSONBCodecUnmarshalSQLNull(t *testing.T) {
// A string cannot scan a NULL.
str := "foobar"
err = conn.QueryRow(ctx, "select null::jsonb").Scan(&str)
require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string")
require.EqualError(t, err, "can't scan into dest[0] (col: jsonb): cannot scan NULL into *string")
// A non-string cannot scan a NULL.
err = conn.QueryRow(ctx, "select null::jsonb").Scan(&n)
require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *int")
require.EqualError(t, err, "can't scan into dest[0] (col: jsonb): cannot scan NULL into *int")
})
}


@ -396,11 +396,7 @@ type scanPlanSQLScanner struct {
}
func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
scanner := getSQLScanner(dst)
if scanner == nil {
return fmt.Errorf("cannot scan into %T", dst)
}
scanner := dst.(sql.Scanner)
if src == nil {
// This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the
@ -413,21 +409,6 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
}
}
// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
func getSQLScanner(target any) sql.Scanner {
val := reflect.ValueOf(target)
for val.Kind() == reflect.Ptr {
if _, ok := val.Interface().(sql.Scanner); ok {
if val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
return val.Interface().(sql.Scanner)
}
val = val.Elem()
}
return nil
}
type scanPlanString struct{}
func (scanPlanString) Scan(src []byte, dst any) error {


@ -91,7 +91,25 @@ func initDefaultMap() {
defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}})
defaultMap.RegisterType(&Type{Name: "xid8", OID: XID8OID, Codec: Uint64Codec{}})
defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{Marshal: xml.Marshal, Unmarshal: xml.Unmarshal}})
defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{
Marshal: xml.Marshal,
// xml.Unmarshal does not support unmarshalling into *any. However, XMLCodec.DecodeValue calls Unmarshal with a
// *any. Wrap xml.Unmarshal with a function that copies the data into a new byte slice in this case. Not implementing
// directly in XMLCodec.DecodeValue to allow for the unlikely possibility that someone uses an alternative XML
// unmarshaler that does support unmarshalling into *any.
//
// https://github.com/jackc/pgx/issues/2227
// https://github.com/jackc/pgx/pull/2228
Unmarshal: func(data []byte, v any) error {
if v, ok := v.(*any); ok {
dstBuf := make([]byte, len(data))
copy(dstBuf, data)
*v = dstBuf
return nil
}
return xml.Unmarshal(data, v)
},
}})
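// A hedged illustration of the wrapper's effect (not part of the change
// itself), assuming the *any path routes through the codec's DecodeValue as
// it does for rows.Values(): scanning an xml value into *any yields the raw
// document bytes instead of an unmarshal error.
//
//	m := pgtype.NewMap()
//	var v any
//	_ = m.Scan(pgtype.XMLOID, pgtype.TextFormatCode, []byte("<foo/>"), &v)
//	// v.([]byte) is []byte("<foo/>")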
// Range types
defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}})


@ -12,6 +12,7 @@ import (
)
const pgTimestampFormat = "2006-01-02 15:04:05.999999999"
const jsonISO8601 = "2006-01-02T15:04:05.999999999"
type TimestampScanner interface {
ScanTimestamp(v Timestamp) error
@ -76,7 +77,7 @@ func (ts Timestamp) MarshalJSON() ([]byte, error) {
switch ts.InfinityModifier {
case Finite:
s = ts.Time.Format(time.RFC3339Nano)
s = ts.Time.Format(jsonISO8601)
case Infinity:
s = "infinity"
case NegativeInfinity:
@ -104,15 +105,23 @@ func (ts *Timestamp) UnmarshalJSON(b []byte) error {
case "-infinity":
*ts = Timestamp{Valid: true, InfinityModifier: -Infinity}
default:
// PostgreSQL uses ISO 8601 wihout timezone for to_json function and casting from a string to timestampt
tim, err := time.Parse(time.RFC3339Nano, *s+"Z")
if err != nil {
return err
// Parse time with or without timezone
tss := *s
// PostgreSQL uses ISO 8601 without timezone for the to_json function and when casting from a string to timestamp
tim, err := time.Parse(time.RFC3339Nano, tss)
if err == nil {
*ts = Timestamp{Time: tim, Valid: true}
return nil
}
*ts = Timestamp{Time: tim, Valid: true}
tim, err = time.ParseInLocation(jsonISO8601, tss, time.UTC)
if err == nil {
*ts = Timestamp{Time: tim, Valid: true}
return nil
}
ts.Valid = false
return fmt.Errorf("cannot unmarshal %s to timestamp with layout %s or %s (%w)",
*s, time.RFC3339Nano, jsonISO8601, err)
}
return nil
}
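A brief sketch of the two layouts now accepted (values are illustrative):

	var ts pgtype.Timestamp
	_ = ts.UnmarshalJSON([]byte(`"2012-03-29T10:05:45Z"`)) // RFC 3339
	_ = ts.UnmarshalJSON([]byte(`"2012-03-29T10:05:45"`))  // ISO 8601 without timezone, interpreted as UTC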


@ -2,12 +2,14 @@ package pgtype_test
import (
"context"
"encoding/json"
"testing"
"time"
pgx "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -100,13 +102,24 @@ func TestTimestampCodecDecodeTextInvalid(t *testing.T) {
}
func TestTimestampMarshalJSON(t *testing.T) {
tsStruct := struct {
TS pgtype.Timestamp `json:"ts"`
}{}
tm := time.Date(2012, 3, 29, 10, 5, 45, 0, time.UTC)
tsString := "\"" + tm.Format("2006-01-02T15:04:05") + "\"" // `"2012-03-29T10:05:45"`
var pgt pgtype.Timestamp
_ = pgt.Scan(tm)
successfulTests := []struct {
source pgtype.Timestamp
result string
}{
{source: pgtype.Timestamp{}, result: "null"},
{source: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.UTC), Valid: true}, result: "\"2012-03-29T10:05:45Z\""},
{source: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.UTC), Valid: true}, result: "\"2012-03-29T10:05:45.555Z\""},
{source: pgtype.Timestamp{Time: tm, Valid: true}, result: tsString},
{source: pgt, result: tsString},
{source: pgtype.Timestamp{Time: tm.Add(time.Second * 555 / 1000), Valid: true}, result: `"2012-03-29T10:05:45.555"`},
{source: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}, result: "\"infinity\""},
{source: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "\"-infinity\""},
}
@ -116,12 +129,32 @@ func TestTimestampMarshalJSON(t *testing.T) {
t.Errorf("%d: %v", i, err)
}
if string(r) != tt.result {
if !assert.Equal(t, tt.result, string(r)) {
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r))
}
tsStruct.TS = tt.source
b, err := json.Marshal(tsStruct)
assert.NoErrorf(t, err, "failed to marshal %v %s", tt.source, err)
t2 := tsStruct
t2.TS = pgtype.Timestamp{} // Clear out the value so that we can compare after unmarshalling
err = json.Unmarshal(b, &t2)
assert.NoErrorf(t, err, "failed to unmarshal %v with %s", tt.source, err)
assert.True(t, tsStruct.TS.Time.Unix() == t2.TS.Time.Unix())
}
}
func TestTimestampUnmarshalJSONErrors(t *testing.T) {
tsStruct := struct {
TS pgtype.Timestamp `json:"ts"`
}{}
goodJson1 := []byte(`{"ts":"2012-03-29T10:05:45"}`)
assert.NoError(t, json.Unmarshal(goodJson1, &tsStruct))
goodJson2 := []byte(`{"ts":"2012-03-29T10:05:45Z"}`)
assert.NoError(t, json.Unmarshal(goodJson2, &tsStruct))
badJson := []byte(`{"ts":"2012-03-29"}`)
assert.Error(t, json.Unmarshal(badJson, &tsStruct))
}
func TestTimestampUnmarshalJSON(t *testing.T) {
successfulTests := []struct {
source string


@ -79,7 +79,7 @@ func TestXMLCodecUnmarshalSQLNull(t *testing.T) {
// A string cannot scan a NULL.
str := "foobar"
err = conn.QueryRow(ctx, "select null::xml").Scan(&str)
assert.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string")
assert.EqualError(t, err, "can't scan into dest[0] (col: xml): cannot scan NULL into *string")
})
}
@ -97,3 +97,32 @@ func TestXMLCodecPointerToPointerToString(t *testing.T) {
require.Nil(t, s)
})
}
func TestXMLCodecDecodeValue(t *testing.T) {
skipCockroachDB(t, "CockroachDB does not support XML.")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) {
for _, tt := range []struct {
sql string
expected any
}{
{
sql: `select '<foo>bar</foo>'::xml`,
expected: []byte("<foo>bar</foo>"),
},
} {
t.Run(tt.sql, func(t *testing.T) {
rows, err := conn.Query(ctx, tt.sql)
require.NoError(t, err)
for rows.Next() {
values, err := rows.Values()
require.NoError(t, err)
require.Len(t, values, 1)
require.Equal(t, tt.expected, values[0])
}
require.NoError(t, rows.Err())
})
}
})
}


@ -1,4 +1,5 @@
// Do not edit. Generated from pgtype/zeronull/int.go.erb
// Code generated from pgtype/zeronull/int.go.erb. DO NOT EDIT.
package zeronull
import (


@ -1,4 +1,5 @@
// Do not edit. Generated from pgtype/zeronull/int_test.go.erb
// Code generated from pgtype/zeronull/int_test.go.erb. DO NOT EDIT.
package zeronull_test
import (


@ -147,6 +147,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName
assert.Equalf(t, expected.MaxConnIdleTime, actual.MaxConnIdleTime, "%s - MaxConnIdleTime", testName)
assert.Equalf(t, expected.MaxConns, actual.MaxConns, "%s - MaxConns", testName)
assert.Equalf(t, expected.MinConns, actual.MinConns, "%s - MinConns", testName)
assert.Equalf(t, expected.MinIdleConns, actual.MinIdleConns, "%s - MinIdleConns", testName)
assert.Equalf(t, expected.HealthCheckPeriod, actual.HealthCheckPeriod, "%s - HealthCheckPeriod", testName)
assertConnConfigsEqual(t, expected.ConnConfig, actual.ConnConfig, testName)


@ -17,6 +17,7 @@ import (
var defaultMaxConns = int32(4)
var defaultMinConns = int32(0)
var defaultMinIdleConns = int32(0)
var defaultMaxConnLifetime = time.Hour
var defaultMaxConnIdleTime = time.Minute * 30
var defaultHealthCheckPeriod = time.Minute
@ -87,6 +88,7 @@ type Pool struct {
afterRelease func(*pgx.Conn) bool
beforeClose func(*pgx.Conn)
minConns int32
minIdleConns int32
maxConns int32
maxConnLifetime time.Duration
maxConnLifetimeJitter time.Duration
@ -144,6 +146,13 @@ type Config struct {
// to create new connections.
MinConns int32
// MinIdleConns is the minimum number of idle connections in the pool. You can increase this to ensure that
// there are always idle connections available. This can help reduce tail latencies during request processing,
// as you can avoid the latency of establishing a new connection while handling requests. It is superior
// to MinConns for this purpose.
// Similar to MinConns, the pool might temporarily dip below MinIdleConns after a connection closes.
MinIdleConns int32
// HealthCheckPeriod is the duration between checks of the health of idle connections.
HealthCheckPeriod time.Duration
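A minimal usage sketch (connection string values are illustrative): the new option can be set on the Config directly
or through the pool_min_idle_conns parameter parsed below.

	cfg, err := pgxpool.ParseConfig("host=localhost pool_min_idle_conns=4")
	if err != nil {
		return err
	}
	pool, err := pgxpool.NewWithConfig(context.Background(), cfg)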
@ -189,6 +198,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
afterRelease: config.AfterRelease,
beforeClose: config.BeforeClose,
minConns: config.MinConns,
minIdleConns: config.MinIdleConns,
maxConns: config.MaxConns,
maxConnLifetime: config.MaxConnLifetime,
maxConnLifetimeJitter: config.MaxConnLifetimeJitter,
@ -271,7 +281,8 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
}
go func() {
p.createIdleResources(ctx, int(p.minConns))
targetIdleResources := max(int(p.minConns), int(p.minIdleConns))
p.createIdleResources(ctx, targetIdleResources)
p.backgroundHealthCheck()
}()
@ -334,6 +345,17 @@ func ParseConfig(connString string) (*Config, error) {
config.MinConns = defaultMinConns
}
if s, ok := config.ConnConfig.Config.RuntimeParams["pool_min_idle_conns"]; ok {
delete(connConfig.Config.RuntimeParams, "pool_min_idle_conns")
n, err := strconv.ParseInt(s, 10, 32)
if err != nil {
return nil, fmt.Errorf("cannot parse pool_min_idle_conns: %w", err)
}
config.MinIdleConns = int32(n)
} else {
config.MinIdleConns = defaultMinIdleConns
}
if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime"]; ok {
delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime")
d, err := time.ParseDuration(s)
@ -472,7 +494,9 @@ func (p *Pool) checkMinConns() error {
// TotalConns can include ones that are being destroyed but we should have
// sleep(500ms) around all of the destroys to help prevent that from throwing
// off this check
toCreate := p.minConns - p.Stat().TotalConns()
// Create the number of connections needed to get to both minConns and minIdleConns
toCreate := max(p.minConns-p.Stat().TotalConns(), p.minIdleConns-p.Stat().IdleConns())
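// For example (illustrative numbers): with minConns=2, minIdleConns=4,
// 3 total connections and 1 of them idle, toCreate = max(2-3, 4-1) = 3.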
if toCreate > 0 {
return p.createIdleResources(context.Background(), int(toCreate))
}


@ -43,10 +43,11 @@ func TestConnectConfig(t *testing.T) {
func TestParseConfigExtractsPoolArguments(t *testing.T) {
t.Parallel()
config, err := pgxpool.ParseConfig("pool_max_conns=42 pool_min_conns=1")
config, err := pgxpool.ParseConfig("pool_max_conns=42 pool_min_conns=1 pool_min_idle_conns=2")
assert.NoError(t, err)
assert.EqualValues(t, 42, config.MaxConns)
assert.EqualValues(t, 1, config.MinConns)
assert.EqualValues(t, 2, config.MinIdleConns)
assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_max_conns")
assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_min_conns")
}


@ -82,3 +82,10 @@ func (s *Stat) MaxLifetimeDestroyCount() int64 {
func (s *Stat) MaxIdleDestroyCount() int64 {
return s.idleDestroyCount
}
// EmptyAcquireWaitTime returns the cumulative time waited for successful acquires
// from the pool for a resource to be released or constructed because the pool was
// empty.
func (s *Stat) EmptyAcquireWaitTime() time.Duration {
return s.s.EmptyAcquireWaitTime()
}
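A hedged usage sketch of reading the new counter:

	stat := pool.Stat()
	fmt.Printf("time spent waiting on an empty pool: %v\n", stat.EmptyAcquireWaitTime())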


@ -420,7 +420,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) {
t.Fatal("Expected Rows to have an error after an improper read but it didn't")
}
if rows.Err().Error() != "can't scan into dest[0]: cannot scan int4 (OID 23) in binary format into *time.Time" {
if rows.Err().Error() != "can't scan into dest[0] (col: n): cannot scan int4 (OID 23) in binary format into *time.Time" {
t.Fatalf("Expected different Rows.Err(): %v", rows.Err())
}

rows.go

@ -272,7 +272,7 @@ func (rows *baseRows) Scan(dest ...any) error {
err := rows.scanPlans[i].Scan(values[i], dst)
if err != nil {
err = ScanArgError{ColumnIndex: i, Err: err}
err = ScanArgError{ColumnIndex: i, FieldName: fieldDescriptions[i].Name, Err: err}
rows.fatal(err)
return err
}
@ -334,11 +334,16 @@ func (rows *baseRows) Conn() *Conn {
type ScanArgError struct {
ColumnIndex int
FieldName string
Err error
}
func (e ScanArgError) Error() string {
return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err)
if e.FieldName == "?column?" { // Don't include the fieldname if it's unknown
return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err)
}
return fmt.Sprintf("can't scan into dest[%d] (col: %s): %v", e.ColumnIndex, e.FieldName, e.Err)
}
func (e ScanArgError) Unwrap() error {
@ -366,7 +371,7 @@ func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, v
err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d)
if err != nil {
return ScanArgError{ColumnIndex: i, Err: err}
return ScanArgError{ColumnIndex: i, FieldName: fieldDescriptions[i].Name, Err: err}
}
}
@ -468,6 +473,8 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
return value, err
}
// The defer rows.Close() won't have executed yet. If the query returned more than one row, rows would still be open.
// rows.Close() must be called before rows.Err() so we explicitly call it here.
rows.Close()
return value, rows.Err()
}

tx.go

@ -3,7 +3,6 @@ package pgx
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
@ -103,7 +102,7 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
if err != nil {
// begin should never fail unless there is an underlying connection issue or
// a context timeout. In either case, the connection is possibly broken.
c.die(errors.New("failed to begin transaction"))
c.die()
return nil, err
}
@ -216,7 +215,7 @@ func (tx *dbTx) Rollback(ctx context.Context) error {
tx.closed = true
if err != nil {
// A rollback failure leaves the connection in an undefined state
tx.conn.die(fmt.Errorf("rollback failed: %w", err))
tx.conn.die()
return err
}


@ -3,6 +3,7 @@ package pgx_test
import (
"bytes"
"context"
"fmt"
"net"
"os"
"reflect"
@ -215,7 +216,12 @@ func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typena
input := []int{1, 2, 234432}
var output []int16
err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" {
fieldName := typename
if conn.PgConn().ParameterStatus("crdb_version") != "" && typename == "json" {
fieldName = "jsonb" // Seems like CockroachDB treats json as jsonb.
}
expectedMessage := fmt.Sprintf("can't scan into dest[0] (col: %s): json: cannot unmarshal number 234432 into Go value of type int16", fieldName)
if err == nil || err.Error() != expectedMessage {
t.Errorf("%s: Expected *json.UnmarshalTypeError, but got %v", typename, err)
}
}